Skip to main content

deepstrike_core/context/
token_engine.rs

1use std::sync::Arc;
2
3use crate::types::message::{Content, ContentPart, Message};
4
/// Backend-agnostic contract for measuring and trimming text in tokens.
///
/// Implementations are required to be deterministic and must not panic for
/// any valid UTF-8 input.
pub trait TokenCounter: Send + Sync {
    /// Returns the number of tokens the UTF-8 string `text` occupies.
    fn count(&self, text: &str) -> u32;

    /// Returns the longest prefix of `text` that fits within `max_tokens`.
    ///
    /// The result borrows from `text` and is always a valid UTF-8 prefix
    /// of it.
    fn truncate<'a>(&self, text: &'a str, max_tokens: u32) -> &'a str;
}
15
16/// Char-count approximation: 4 chars ≈ 1 token.
17/// Used when no real tokeniser is available. More accurate than byte-count
18/// for CJK text (3 bytes/char but ~0.5 tokens/char).
19pub struct CharApproxCounter;
20
21impl TokenCounter for CharApproxCounter {
22    fn count(&self, text: &str) -> u32 {
23        (text.chars().count() as u32 / 4).max(1)
24    }
25
26    fn truncate<'a>(&self, text: &'a str, max_tokens: u32) -> &'a str {
27        let max_chars = (max_tokens as usize).saturating_mul(4);
28        let mut byte_end = text.len(); // default: keep all
29        let mut seen = 0usize;
30        for (byte_idx, _) in text.char_indices() {
31            if seen >= max_chars {
32                byte_end = byte_idx;
33                break;
34            }
35            seen += 1;
36        }
37        &text[..byte_end]
38    }
39}
40
41/// Cheaply cloneable token engine shared across the context subsystem.
42/// All token counting and truncation goes through this single object —
43/// pressure, compression, and render use the same backend.
44#[derive(Clone)]
45pub struct ContextTokenEngine(Arc<dyn TokenCounter>);
46
47impl ContextTokenEngine {
48    pub fn char_approx() -> Self {
49        Self(Arc::new(CharApproxCounter))
50    }
51
52    pub fn count(&self, text: &str) -> u32 {
53        self.0.count(text)
54    }
55
56    pub fn truncate<'a>(&self, text: &'a str, max_tokens: u32) -> &'a str {
57        self.0.truncate(text, max_tokens)
58    }
59
60    pub fn token_budget_to_bytes(&self, tokens: u32) -> usize {
61        (tokens as usize).saturating_mul(4)
62    }
63
64    pub fn count_message(&self, msg: &Message) -> u32 {
65        match &msg.content {
66            Content::Text(t) => self.count(t),
67            Content::Parts(parts) => parts.iter().map(|p| self.count_part(p)).sum(),
68        }
69    }
70
71    fn count_part(&self, part: &ContentPart) -> u32 {
72        match part {
73            ContentPart::Text { text } => self.count(text),
74            ContentPart::ToolResult { output, .. } => self.count(output.as_str()),
75            ContentPart::Image { .. } => 1, // structural token — content is base64/url
76            ContentPart::Audio { data, .. } => self.count(data.as_str()),
77        }
78    }
79
80    /// Truncate a text message to `max_tokens`. Returns the message unchanged
81    /// if it fits. Parts messages are never truncated — mangling structured
82    /// content produces worse outcomes than a minor token overrun.
83    pub fn truncate_message(&self, msg: &Message, max_tokens: u32) -> Message {
84        match &msg.content {
85            Content::Text(t) => {
86                let kept = self.0.truncate(t, max_tokens);
87                if kept.len() < t.len() {
88                    let mut m = msg.clone();
89                    m.content = Content::Text(format!("{}… [truncated]", kept));
90                    m.token_count = Some(max_tokens);
91                    m
92                } else {
93                    msg.clone()
94                }
95            }
96            Content::Parts(_) => msg.clone(),
97        }
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::message::Message;

    /// Engine under test: the char-approximation backend.
    fn engine() -> ContextTokenEngine {
        ContextTokenEngine::char_approx()
    }

    #[test]
    fn count_nonzero_for_nonempty_text() {
        assert!(engine().count("hello") > 0);
    }

    #[test]
    fn count_is_char_based_not_byte_based() {
        let e = engine();
        // Both inputs are 4 chars long, but the CJK one is 12 bytes.
        // The key property: count doesn't grow 3× for CJK vs ASCII.
        let cjk_count = e.count("你好世界");
        let ascii_count = e.count("abcd");
        assert_eq!(cjk_count, ascii_count);
    }

    #[test]
    fn truncate_stays_within_budget() {
        let e = engine();
        let text = "a".repeat(1000);
        let kept = e.truncate(&text, 10);
        assert!(e.count(kept) <= 10);
    }

    #[test]
    fn truncate_cjk_valid_utf8() {
        let e = engine();
        let text = "你好世界".repeat(100);
        let kept = e.truncate(&text, 5);
        // A `&str` is valid UTF-8 by construction, so check the properties
        // that can actually fail: the result is a prefix of the input and
        // holds exactly the budgeted number of whole characters.
        assert!(text.starts_with(kept));
        assert_eq!(kept.chars().count(), 20);
    }

    #[test]
    fn truncate_count_le_budget() {
        let e = engine();
        for max in [1u32, 5, 20, 100] {
            let kept = e.truncate("The quick brown fox jumps over the lazy dog.", max);
            assert!(
                e.count(kept) <= max,
                "max={max} kept_count={}",
                e.count(kept)
            );
        }
    }

    #[test]
    fn truncate_zero_budget_is_empty() {
        assert_eq!(engine().truncate("abc", 0), "");
    }

    #[test]
    fn truncate_returns_whole_text_when_it_fits() {
        assert_eq!(engine().truncate("hi", 1000), "hi");
    }

    #[test]
    fn truncate_message_appends_suffix_on_cut() {
        let e = engine();
        let msg = Message::user("a".repeat(200));
        let truncated = e.truncate_message(&msg, 5);
        let text = truncated.content.as_text().unwrap();
        assert!(text.ends_with("… [truncated]"), "got: {text}");
    }

    #[test]
    fn truncate_message_unchanged_when_fits() {
        let e = engine();
        let msg = Message::user("hi");
        let out = e.truncate_message(&msg, 1000);
        assert_eq!(out.content.as_text().unwrap(), "hi");
    }
}