deepstrike_core/context/
token_engine.rs1use std::sync::Arc;
2
3use crate::types::message::{Content, ContentPart, Message};
4
/// Strategy for measuring and trimming text in units of tokens.
///
/// The `Send + Sync` supertraits allow an implementation to be shared across
/// threads behind a trait object.
pub trait TokenCounter: Send + Sync {
    /// Returns the number of tokens this counter attributes to `text`.
    fn count(&self, text: &str) -> u32;

    /// Returns a slice borrowed from `text` that is limited to `max_tokens`.
    ///
    /// The result shares the lifetime of `text`, so no allocation is needed.
    /// Implementations are expected to keep `count(truncate(text, n)) <= n`
    /// for non-zero budgets.
    fn truncate<'a>(&self, text: &'a str, max_tokens: u32) -> &'a str;
}
15
16pub struct CharApproxCounter;
20
21impl TokenCounter for CharApproxCounter {
22 fn count(&self, text: &str) -> u32 {
23 (text.chars().count() as u32 / 4).max(1)
24 }
25
26 fn truncate<'a>(&self, text: &'a str, max_tokens: u32) -> &'a str {
27 let max_chars = (max_tokens as usize).saturating_mul(4);
28 let mut byte_end = text.len(); let mut seen = 0usize;
30 for (byte_idx, _) in text.char_indices() {
31 if seen >= max_chars {
32 byte_end = byte_idx;
33 break;
34 }
35 seen += 1;
36 }
37 &text[..byte_end]
38 }
39}
40
41#[derive(Clone)]
45pub struct ContextTokenEngine(Arc<dyn TokenCounter>);
46
47impl ContextTokenEngine {
48 pub fn char_approx() -> Self {
49 Self(Arc::new(CharApproxCounter))
50 }
51
52 pub fn count(&self, text: &str) -> u32 {
53 self.0.count(text)
54 }
55
56 pub fn truncate<'a>(&self, text: &'a str, max_tokens: u32) -> &'a str {
57 self.0.truncate(text, max_tokens)
58 }
59
60 pub fn token_budget_to_bytes(&self, tokens: u32) -> usize {
61 (tokens as usize).saturating_mul(4)
62 }
63
64 pub fn count_message(&self, msg: &Message) -> u32 {
65 match &msg.content {
66 Content::Text(t) => self.count(t),
67 Content::Parts(parts) => parts.iter().map(|p| self.count_part(p)).sum(),
68 }
69 }
70
71 fn count_part(&self, part: &ContentPart) -> u32 {
72 match part {
73 ContentPart::Text { text } => self.count(text),
74 ContentPart::ToolResult { output, .. } => self.count(output.as_str()),
75 ContentPart::Image { .. } => 1, ContentPart::Audio { data, .. } => self.count(data.as_str()),
77 }
78 }
79
80 pub fn truncate_message(&self, msg: &Message, max_tokens: u32) -> Message {
84 match &msg.content {
85 Content::Text(t) => {
86 let kept = self.0.truncate(t, max_tokens);
87 if kept.len() < t.len() {
88 let mut m = msg.clone();
89 m.content = Content::Text(format!("{}… [truncated]", kept));
90 m.token_count = Some(max_tokens);
91 m
92 } else {
93 msg.clone()
94 }
95 }
96 Content::Parts(_) => msg.clone(),
97 }
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::message::Message;

    /// Engine under test, backed by the character-approximation counter.
    fn engine() -> ContextTokenEngine {
        ContextTokenEngine::char_approx()
    }

    #[test]
    fn count_nonzero_for_nonempty_text() {
        assert!(engine().count("hello") > 0);
    }

    #[test]
    fn count_is_char_based_not_byte_based() {
        let e = engine();
        // Four CJK characters span twelve bytes but must cost the same as
        // four ASCII characters.
        assert_eq!(e.count("你好世界"), e.count("abcd"));
    }

    #[test]
    fn truncate_stays_within_budget() {
        let e = engine();
        let text = "a".repeat(1000);
        let kept = e.truncate(&text, 10);
        assert!(e.count(kept) <= 10);
    }

    #[test]
    fn truncate_cjk_valid_utf8() {
        let e = engine();
        let text = "你好世界".repeat(100);
        // Slicing inside a multi-byte character would panic here.
        let kept = e.truncate(&text, 5);
        // A `&str` is valid UTF-8 by construction, so assert what can
        // actually go wrong: the result is a prefix cut at the expected
        // character count.
        assert!(text.starts_with(kept));
        assert_eq!(kept.chars().count(), 20);
    }

    #[test]
    fn truncate_count_le_budget() {
        let e = engine();
        let sentence = "The quick brown fox jumps over the lazy dog.";
        for max in [1u32, 5, 20, 100] {
            let kept = e.truncate(sentence, max);
            assert!(
                e.count(kept) <= max,
                "max={max} kept_count={}",
                e.count(kept)
            );
        }
    }

    #[test]
    fn truncate_message_appends_suffix_on_cut() {
        let e = engine();
        let msg = Message::user("a".repeat(200));
        let truncated = e.truncate_message(&msg, 5);
        let text = truncated.content.as_text().unwrap();
        assert!(text.ends_with("… [truncated]"), "got: {text}");
    }

    #[test]
    fn truncate_message_unchanged_when_fits() {
        let e = engine();
        let msg = Message::user("hi");
        let out = e.truncate_message(&msg, 1000);
        assert_eq!(out.content.as_text().unwrap(), "hi");
    }
}