Skip to main content

gemini_tokenizer/
accumulator.rs

1// <FILE>src/accumulator.rs</FILE> - <DESC>Port of Google Python SDK's _TextsAccumulator for extracting countable text</DESC>
2// <VERS>VERSION: 0.2.0</VERS>
3// <WCTX>Ergonomics update to match Python SDK API</WCTX>
4// <CLOG>Add add_part() method, refactor add_content to delegate to it</CLOG>
5
6//! Text accumulation logic ported from the official Google Python SDK.
7//!
8//! The [`TextAccumulator`] traverses structured [`Content`], [`Tool`], and [`Schema`]
9//! objects, extracting all text segments that should be counted as tokens.
10//! This matches the behavior of `_TextsAccumulator` in
11//! `google/genai/local_tokenizer.py`.
12
13use crate::types::*;
14
/// Accumulates countable text strings from structured Gemini API objects.
///
/// This is a port of the Python SDK's `_TextsAccumulator` class.
/// It traverses `Content`, `Tool`, `FunctionCall`, `FunctionResponse`, and
/// `Schema` objects, extracting all text that Google counts when computing
/// token counts.
///
/// Segments are stored in the order they are encountered; nothing is
/// de-duplicated, so a string that appears twice in the input is recorded twice.
///
/// # How text is extracted
///
/// - **Text parts**: the text string itself
/// - **Function calls**: the function name, plus all dict keys and string values
///   from the args
/// - **Function responses**: the function name, plus all dict keys and string
///   values from the response
/// - **Function declarations** (tools): the name, description, plus recursive
///   schema traversal of parameters and response
/// - **Schemas**: format, description, enum values, required field names,
///   property keys, and recursive traversal of nested schemas
#[derive(Debug, Clone)]
pub struct TextAccumulator {
    /// Text segments collected so far, in insertion order.
    texts: Vec<String>,
}
36
37impl TextAccumulator {
38    /// Creates a new empty accumulator.
39    pub fn new() -> Self {
40        Self { texts: Vec::new() }
41    }
42
43    /// Returns the accumulated text segments.
44    pub fn get_texts(&self) -> &[String] {
45        &self.texts
46    }
47
48    /// Consumes the accumulator and returns the accumulated text segments.
49    pub fn into_texts(self) -> Vec<String> {
50        self.texts
51    }
52
53    /// Adds all text from multiple content objects.
54    pub fn add_contents(&mut self, contents: &[Content]) {
55        for content in contents {
56            self.add_content(content);
57        }
58    }
59
60    /// Adds all countable text from a single content object.
61    ///
62    /// Processes each part in the content by delegating to [`add_part`](Self::add_part).
63    pub fn add_content(&mut self, content: &Content) {
64        if let Some(parts) = &content.parts {
65            for part in parts {
66                self.add_part(part);
67            }
68        }
69    }
70
71    /// Adds countable text from a single content part.
72    ///
73    /// Processes the part's fields:
74    /// - Text: appends the text directly
75    /// - Function calls: delegates to [`add_function_call`](Self::add_function_call)
76    /// - Function responses: delegates to [`add_function_response`](Self::add_function_response)
77    pub fn add_part(&mut self, part: &Part) {
78        if let Some(fc) = &part.function_call {
79            self.add_function_call(fc);
80        }
81        if let Some(fr) = &part.function_response {
82            self.add_function_response(fr);
83        }
84        if let Some(text) = &part.text {
85            self.texts.push(text.clone());
86        }
87    }
88
89    /// Adds countable text from a function call.
90    ///
91    /// Extracts the function name and traverses the args dictionary,
92    /// collecting all keys and string values.
93    pub fn add_function_call(&mut self, function_call: &FunctionCall) {
94        if let Some(name) = &function_call.name {
95            self.texts.push(name.clone());
96        }
97        if let Some(args) = &function_call.args {
98            self.dict_traverse(args);
99        }
100    }
101
102    /// Adds countable text from multiple tools.
103    pub fn add_tools(&mut self, tools: &[Tool]) {
104        for tool in tools {
105            self.add_tool(tool);
106        }
107    }
108
109    /// Adds countable text from a single tool definition.
110    ///
111    /// Processes each function declaration in the tool.
112    pub fn add_tool(&mut self, tool: &Tool) {
113        if let Some(declarations) = &tool.function_declarations {
114            for decl in declarations {
115                self.add_function_declaration(decl);
116            }
117        }
118    }
119
120    /// Adds countable text from multiple function responses.
121    pub fn add_function_responses(&mut self, responses: &[FunctionResponse]) {
122        for response in responses {
123            self.add_function_response(response);
124        }
125    }
126
127    /// Adds countable text from a function response.
128    ///
129    /// Extracts the function name and traverses the response dictionary,
130    /// collecting all keys and string values.
131    pub fn add_function_response(&mut self, function_response: &FunctionResponse) {
132        if let Some(name) = &function_response.name {
133            self.texts.push(name.clone());
134        }
135        if let Some(response) = &function_response.response {
136            self.dict_traverse(response);
137        }
138    }
139
140    /// Adds countable text from a function declaration.
141    ///
142    /// Extracts the name, description, and recursively processes the
143    /// parameter and response schemas.
144    fn add_function_declaration(&mut self, decl: &FunctionDeclaration) {
145        if let Some(name) = &decl.name {
146            self.texts.push(name.clone());
147        }
148        if let Some(description) = &decl.description {
149            self.texts.push(description.clone());
150        }
151        if let Some(parameters) = &decl.parameters {
152            self.add_schema(parameters);
153        }
154        if let Some(response) = &decl.response {
155            self.add_schema(response);
156        }
157    }
158
159    /// Adds countable text from a schema definition.
160    ///
161    /// Extracts format, description, enum values, required field names,
162    /// property keys, and recursively processes nested schemas (items,
163    /// properties, examples).
164    pub fn add_schema(&mut self, schema: &Schema) {
165        // Note: schema.type and schema.title are tracked but NOT added to texts,
166        // matching the Python SDK behavior.
167        if let Some(format) = &schema.format {
168            self.texts.push(format.clone());
169        }
170        if let Some(description) = &schema.description {
171            self.texts.push(description.clone());
172        }
173        if let Some(enum_values) = &schema.enum_values {
174            for v in enum_values {
175                self.texts.push(v.clone());
176            }
177        }
178        if let Some(required) = &schema.required {
179            for r in required {
180                self.texts.push(r.clone());
181            }
182        }
183        if let Some(items) = &schema.items {
184            self.add_schema(items);
185        }
186        if let Some(properties) = &schema.properties {
187            for (key, value) in properties {
188                self.texts.push(key.clone());
189                self.add_schema(value);
190            }
191        }
192        if let Some(example) = &schema.example {
193            self.any_traverse(example);
194        }
195    }
196
197    /// Traverses a dictionary (JSON object), adding all keys and recursively
198    /// processing all values.
199    fn dict_traverse(&mut self, d: &std::collections::HashMap<String, serde_json::Value>) {
200        // Add all keys
201        let keys: Vec<String> = d.keys().cloned().collect();
202        self.texts.extend(keys);
203
204        // Traverse all values
205        for val in d.values() {
206            self.any_traverse(val);
207        }
208    }
209
210    /// Traverses an arbitrary JSON value, adding strings and recursing into
211    /// objects and arrays.
212    fn any_traverse(&mut self, value: &serde_json::Value) {
213        match value {
214            serde_json::Value::String(s) => {
215                self.texts.push(s.clone());
216            }
217            serde_json::Value::Object(map) => {
218                // Collect keys
219                let keys: Vec<String> = map.keys().cloned().collect();
220                self.texts.extend(keys);
221                // Recurse into values
222                for val in map.values() {
223                    self.any_traverse(val);
224                }
225            }
226            serde_json::Value::Array(arr) => {
227                for item in arr {
228                    self.any_traverse(item);
229                }
230            }
231            // Numbers, bools, nulls are not added to texts (matches Python SDK)
232            _ => {}
233        }
234    }
235}
236
237impl Default for TextAccumulator {
238    fn default() -> Self {
239        Self::new()
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Wraps a string slice in an owned JSON string value.
    fn json_str(value: &str) -> serde_json::Value {
        serde_json::Value::String(value.to_string())
    }

    /// Builds a map of string-valued JSON entries from `(key, value)` pairs.
    fn string_map(pairs: &[(&str, &str)]) -> HashMap<String, serde_json::Value> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_string(), json_str(value)))
            .collect()
    }

    /// Asserts that each expected segment is present in the accumulated texts.
    fn assert_contains_all(acc: &TextAccumulator, expected: &[&str]) {
        let texts = acc.get_texts();
        for segment in expected {
            assert!(texts.contains(&(*segment).to_string()));
        }
    }

    /// A fresh accumulator holds no text.
    #[test]
    fn test_empty_accumulator() {
        assert!(TextAccumulator::new().get_texts().is_empty());
    }

    /// A plain text part is recorded verbatim.
    #[test]
    fn test_add_text_content() {
        let content = Content {
            role: Some(String::from("user")),
            parts: Some(vec![Part {
                text: Some(String::from("Hello, world!")),
                ..Default::default()
            }]),
        };

        let mut acc = TextAccumulator::new();
        acc.add_content(&content);

        assert_eq!(acc.get_texts(), &["Hello, world!"]);
    }

    /// A function call contributes its name plus arg keys and string values.
    #[test]
    fn test_add_function_call() {
        let fc = FunctionCall {
            name: Some(String::from("search")),
            args: Some(string_map(&[("query", "weather"), ("location", "NYC")])),
        };

        let mut acc = TextAccumulator::new();
        acc.add_function_call(&fc);

        assert_contains_all(&acc, &["search", "query", "location", "weather", "NYC"]);
    }

    /// A function response contributes its name plus response keys and string values.
    #[test]
    fn test_add_function_response() {
        let fr = FunctionResponse {
            name: Some(String::from("search")),
            response: Some(string_map(&[("result", "sunny")])),
        };

        let mut acc = TextAccumulator::new();
        acc.add_function_response(&fr);

        assert_contains_all(&acc, &["search", "result", "sunny"]);
    }

    /// Schema descriptions, required names and property keys are all collected.
    #[test]
    fn test_add_schema_with_properties() {
        let properties: HashMap<String, Schema> = vec![(
            String::from("name"),
            Schema {
                schema_type: Some(String::from("STRING")),
                description: Some(String::from("The user's name")),
                ..Default::default()
            },
        )]
        .into_iter()
        .collect();

        let schema = Schema {
            schema_type: Some(String::from("OBJECT")),
            description: Some(String::from("A user object")),
            required: Some(vec![String::from("name")]),
            properties: Some(properties),
            ..Default::default()
        };

        let mut acc = TextAccumulator::new();
        acc.add_schema(&schema);

        // "name" is contributed both as a property key and as a required field.
        assert_contains_all(&acc, &["A user object", "name", "The user's name"]);
    }

    /// A tool contributes the text of its function declarations and their schemas.
    #[test]
    fn test_add_tool() {
        let location_props: HashMap<String, Schema> = vec![(
            String::from("location"),
            Schema {
                schema_type: Some(String::from("STRING")),
                description: Some(String::from("The city name")),
                ..Default::default()
            },
        )]
        .into_iter()
        .collect();

        let declaration = FunctionDeclaration {
            name: Some(String::from("get_weather")),
            description: Some(String::from("Gets the weather for a location")),
            parameters: Some(Schema {
                schema_type: Some(String::from("OBJECT")),
                properties: Some(location_props),
                required: Some(vec![String::from("location")]),
                ..Default::default()
            }),
            response: None,
        };
        let tool = Tool {
            function_declarations: Some(vec![declaration]),
        };

        let mut acc = TextAccumulator::new();
        acc.add_tool(&tool);

        assert_contains_all(
            &acc,
            &[
                "get_weather",
                "Gets the weather for a location",
                "location", // property key
                "The city name",
            ],
        );
    }

    /// Every enum value of a schema is collected.
    #[test]
    fn test_schema_enum_values() {
        let colors: Vec<String> = ["red", "green", "blue"]
            .iter()
            .map(|color| color.to_string())
            .collect();
        let schema = Schema {
            schema_type: Some(String::from("STRING")),
            enum_values: Some(colors),
            ..Default::default()
        };

        let mut acc = TextAccumulator::new();
        acc.add_schema(&schema);

        assert_contains_all(&acc, &["red", "green", "blue"]);
    }

    /// Nested objects and arrays inside args are traversed recursively.
    #[test]
    fn test_any_traverse_nested() {
        let mut args = HashMap::new();
        args.insert(
            String::from("data"),
            serde_json::json!({"nested_key": "nested_value", "list": ["a", "b"]}),
        );
        let fc = FunctionCall {
            name: Some(String::from("test_fn")),
            args: Some(args),
        };

        let mut acc = TextAccumulator::new();
        acc.add_function_call(&fc);

        assert_contains_all(
            &acc,
            &[
                "test_fn",
                "data",
                "nested_key",
                "nested_value",
                "list",
                "a",
                "b",
            ],
        );
    }

    /// A function-call part inside a content object is picked up via `add_content`.
    #[test]
    fn test_content_with_function_call_part() {
        let call = FunctionCall {
            name: Some(String::from("search")),
            args: Some(string_map(&[("q", "test")])),
        };
        let content = Content {
            role: Some(String::from("model")),
            parts: Some(vec![Part {
                function_call: Some(call),
                ..Default::default()
            }]),
        };

        let mut acc = TextAccumulator::new();
        acc.add_content(&content);

        assert_contains_all(&acc, &["search", "q", "test"]);
    }
}
447
448// <FILE>src/accumulator.rs</FILE> - <DESC>Port of Google Python SDK's _TextsAccumulator for extracting countable text</DESC>
449// <VERS>END OF VERSION: 0.2.0</VERS>