neuron_context/injector.rs
//! System prompt injection based on turn count and token thresholds.

/// Trigger condition for a [`SystemInjector`] rule.
///
/// Both variants carry a plain `usize`, so the type is `Copy` and can be
/// compared and hashed (e.g. to deduplicate or look up rules by trigger).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InjectionTrigger {
    /// Fire every N turns (`turn % n == 0`, excluding turn 0).
    ///
    /// A period of `0` never fires.
    EveryNTurns(usize),
    /// Fire when the token count meets or exceeds the threshold
    /// (`token_count >= threshold`); a threshold of `0` therefore always fires.
    OnTokenThreshold(usize),
}

12struct InjectionRule {
13 trigger: InjectionTrigger,
14 content: String,
15}
16
17/// Injects system prompt content based on turn or token thresholds.
18///
19/// Add rules with [`SystemInjector::add_rule`], then call [`SystemInjector::check`]
20/// each turn to get any content that should be injected.
21///
22/// # Example
23///
24/// ```
25/// use neuron_context::{SystemInjector, InjectionTrigger};
26///
27/// let mut injector = SystemInjector::new();
28/// injector.add_rule(InjectionTrigger::EveryNTurns(5), "Reminder: be concise.".into());
29/// injector.add_rule(InjectionTrigger::OnTokenThreshold(50_000), "Context is getting long.".into());
30///
31/// // Turn 5, under token threshold
32/// let injected = injector.check(5, 10_000);
33/// assert!(injected.contains(&"Reminder: be concise.".to_string()));
34///
35/// // Turn 1, over token threshold
36/// let injected = injector.check(1, 60_000);
37/// assert!(injected.contains(&"Context is getting long.".to_string()));
38/// ```
39#[derive(Default)]
40pub struct SystemInjector {
41 rules: Vec<InjectionRule>,
42}
43
44impl SystemInjector {
45 /// Creates a new `SystemInjector` with no rules.
46 #[must_use]
47 pub fn new() -> Self {
48 Self::default()
49 }
50
51 /// Adds an injection rule.
52 ///
53 /// # Arguments
54 /// * `trigger` — when this rule fires
55 /// * `content` — the text to inject when triggered
56 pub fn add_rule(&mut self, trigger: InjectionTrigger, content: String) {
57 self.rules.push(InjectionRule { trigger, content });
58 }
59
60 /// Returns all content strings whose triggers are satisfied by the given state.
61 ///
62 /// # Arguments
63 /// * `turn` — the current turn number (1-indexed for "every N" checks)
64 /// * `token_count` — the current estimated token count
65 #[must_use]
66 pub fn check(&self, turn: usize, token_count: usize) -> Vec<String> {
67 self.rules
68 .iter()
69 .filter(|rule| match rule.trigger {
70 InjectionTrigger::EveryNTurns(n) => n > 0 && turn > 0 && turn.is_multiple_of(n),
71 InjectionTrigger::OnTokenThreshold(threshold) => token_count >= threshold,
72 })
73 .map(|rule| rule.content.clone())
74 .collect()
75 }
76}