Skip to main content

neuron_context/
injector.rs

//! System prompt injection based on turn count and token thresholds.
2
/// Trigger condition for a [`SystemInjector`] rule.
///
/// Triggers are evaluated by [`SystemInjector::check`] against the current
/// turn number and token count. The payload is a plain `usize`, so the
/// trigger is a cheap `Copy` value that can be compared and hashed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InjectionTrigger {
    /// Fire every N turns (`turn % n == 0`, excluding turn 0).
    ///
    /// An interval of `0` never fires.
    EveryNTurns(usize),
    /// Fire when the token count meets or exceeds the threshold.
    OnTokenThreshold(usize),
}
11
12struct InjectionRule {
13    trigger: InjectionTrigger,
14    content: String,
15}
16
17/// Injects system prompt content based on turn or token thresholds.
18///
19/// Add rules with [`SystemInjector::add_rule`], then call [`SystemInjector::check`]
20/// each turn to get any content that should be injected.
21///
22/// # Example
23///
24/// ```
25/// use neuron_context::{SystemInjector, InjectionTrigger};
26///
27/// let mut injector = SystemInjector::new();
28/// injector.add_rule(InjectionTrigger::EveryNTurns(5), "Reminder: be concise.".into());
29/// injector.add_rule(InjectionTrigger::OnTokenThreshold(50_000), "Context is getting long.".into());
30///
31/// // Turn 5, under token threshold
32/// let injected = injector.check(5, 10_000);
33/// assert!(injected.contains(&"Reminder: be concise.".to_string()));
34///
35/// // Turn 1, over token threshold
36/// let injected = injector.check(1, 60_000);
37/// assert!(injected.contains(&"Context is getting long.".to_string()));
38/// ```
39#[derive(Default)]
40pub struct SystemInjector {
41    rules: Vec<InjectionRule>,
42}
43
44impl SystemInjector {
45    /// Creates a new `SystemInjector` with no rules.
46    #[must_use]
47    pub fn new() -> Self {
48        Self::default()
49    }
50
51    /// Adds an injection rule.
52    ///
53    /// # Arguments
54    /// * `trigger` — when this rule fires
55    /// * `content` — the text to inject when triggered
56    pub fn add_rule(&mut self, trigger: InjectionTrigger, content: String) {
57        self.rules.push(InjectionRule { trigger, content });
58    }
59
60    /// Returns all content strings whose triggers are satisfied by the given state.
61    ///
62    /// # Arguments
63    /// * `turn` — the current turn number (1-indexed for "every N" checks)
64    /// * `token_count` — the current estimated token count
65    #[must_use]
66    pub fn check(&self, turn: usize, token_count: usize) -> Vec<String> {
67        self.rules
68            .iter()
69            .filter(|rule| match rule.trigger {
70                InjectionTrigger::EveryNTurns(n) => n > 0 && turn > 0 && turn.is_multiple_of(n),
71                InjectionTrigger::OnTokenThreshold(threshold) => token_count >= threshold,
72            })
73            .map(|rule| rule.content.clone())
74            .collect()
75    }
76}