gemini_tokenizer/
accumulator.rs1use crate::types::*;
14
/// Collects text fragments in the order they are added.
///
/// The fragments are kept in a plain `Vec<String>`; the struct itself does no
/// deduplication or normalization. Judging by the module path this feeds a
/// tokenizer, but nothing in this type depends on that.
///
/// `Debug` and `Clone` are derived so the accumulator can be logged and
/// snapshotted like any other plain-data value.
#[derive(Debug, Clone)]
pub struct TextAccumulator {
    /// Fragments gathered so far, oldest first.
    texts: Vec<String>,
}
36
37impl TextAccumulator {
38 pub fn new() -> Self {
40 Self { texts: Vec::new() }
41 }
42
43 pub fn get_texts(&self) -> &[String] {
45 &self.texts
46 }
47
48 pub fn into_texts(self) -> Vec<String> {
50 self.texts
51 }
52
53 pub fn add_contents(&mut self, contents: &[Content]) {
55 for content in contents {
56 self.add_content(content);
57 }
58 }
59
60 pub fn add_content(&mut self, content: &Content) {
64 if let Some(parts) = &content.parts {
65 for part in parts {
66 self.add_part(part);
67 }
68 }
69 }
70
71 pub fn add_part(&mut self, part: &Part) {
78 if let Some(fc) = &part.function_call {
79 self.add_function_call(fc);
80 }
81 if let Some(fr) = &part.function_response {
82 self.add_function_response(fr);
83 }
84 if let Some(text) = &part.text {
85 self.texts.push(text.clone());
86 }
87 }
88
89 pub fn add_function_call(&mut self, function_call: &FunctionCall) {
94 if let Some(name) = &function_call.name {
95 self.texts.push(name.clone());
96 }
97 if let Some(args) = &function_call.args {
98 self.dict_traverse(args);
99 }
100 }
101
102 pub fn add_tools(&mut self, tools: &[Tool]) {
104 for tool in tools {
105 self.add_tool(tool);
106 }
107 }
108
109 pub fn add_tool(&mut self, tool: &Tool) {
113 if let Some(declarations) = &tool.function_declarations {
114 for decl in declarations {
115 self.add_function_declaration(decl);
116 }
117 }
118 }
119
120 pub fn add_function_responses(&mut self, responses: &[FunctionResponse]) {
122 for response in responses {
123 self.add_function_response(response);
124 }
125 }
126
127 pub fn add_function_response(&mut self, function_response: &FunctionResponse) {
132 if let Some(name) = &function_response.name {
133 self.texts.push(name.clone());
134 }
135 if let Some(response) = &function_response.response {
136 self.dict_traverse(response);
137 }
138 }
139
140 fn add_function_declaration(&mut self, decl: &FunctionDeclaration) {
145 if let Some(name) = &decl.name {
146 self.texts.push(name.clone());
147 }
148 if let Some(description) = &decl.description {
149 self.texts.push(description.clone());
150 }
151 if let Some(parameters) = &decl.parameters {
152 self.add_schema(parameters);
153 }
154 if let Some(response) = &decl.response {
155 self.add_schema(response);
156 }
157 }
158
159 pub fn add_schema(&mut self, schema: &Schema) {
165 if let Some(format) = &schema.format {
168 self.texts.push(format.clone());
169 }
170 if let Some(description) = &schema.description {
171 self.texts.push(description.clone());
172 }
173 if let Some(enum_values) = &schema.enum_values {
174 for v in enum_values {
175 self.texts.push(v.clone());
176 }
177 }
178 if let Some(required) = &schema.required {
179 for r in required {
180 self.texts.push(r.clone());
181 }
182 }
183 if let Some(items) = &schema.items {
184 self.add_schema(items);
185 }
186 if let Some(properties) = &schema.properties {
187 for (key, value) in properties {
188 self.texts.push(key.clone());
189 self.add_schema(value);
190 }
191 }
192 if let Some(example) = &schema.example {
193 self.any_traverse(example);
194 }
195 }
196
197 fn dict_traverse(&mut self, d: &std::collections::HashMap<String, serde_json::Value>) {
200 let keys: Vec<String> = d.keys().cloned().collect();
202 self.texts.extend(keys);
203
204 for val in d.values() {
206 self.any_traverse(val);
207 }
208 }
209
210 fn any_traverse(&mut self, value: &serde_json::Value) {
213 match value {
214 serde_json::Value::String(s) => {
215 self.texts.push(s.clone());
216 }
217 serde_json::Value::Object(map) => {
218 let keys: Vec<String> = map.keys().cloned().collect();
220 self.texts.extend(keys);
221 for val in map.values() {
223 self.any_traverse(val);
224 }
225 }
226 serde_json::Value::Array(arr) => {
227 for item in arr {
228 self.any_traverse(item);
229 }
230 }
231 _ => {}
233 }
234 }
235}
236
237impl Default for TextAccumulator {
238 fn default() -> Self {
239 Self::new()
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Shorthand for a JSON string value.
    fn json_str(s: &str) -> serde_json::Value {
        serde_json::Value::String(s.to_string())
    }

    /// Builds a string-keyed JSON map from `(key, value)` pairs.
    fn json_map(entries: Vec<(&str, serde_json::Value)>) -> HashMap<String, serde_json::Value> {
        entries
            .into_iter()
            .map(|(key, value)| (key.to_string(), value))
            .collect()
    }

    /// Asserts that each expected fragment was collected.
    ///
    /// Membership is checked rather than position because several inputs are
    /// `HashMap`s, whose iteration order is unspecified.
    fn assert_collected(acc: &TextAccumulator, expected: &[&str]) {
        let texts = acc.get_texts();
        for want in expected {
            assert!(
                texts.contains(&want.to_string()),
                "expected fragment {:?} in {:?}",
                want,
                texts
            );
        }
    }

    #[test]
    fn test_empty_accumulator() {
        let acc = TextAccumulator::new();
        assert!(acc.get_texts().is_empty());
    }

    #[test]
    fn test_add_text_content() {
        let content = Content {
            role: Some("user".to_string()),
            parts: Some(vec![Part {
                text: Some("Hello, world!".to_string()),
                ..Default::default()
            }]),
        };

        let mut acc = TextAccumulator::new();
        acc.add_content(&content);

        // A single text part must yield exactly that one fragment.
        assert_eq!(acc.get_texts(), &["Hello, world!"]);
    }

    #[test]
    fn test_add_function_call() {
        let call = FunctionCall {
            name: Some("search".to_string()),
            args: Some(json_map(vec![
                ("query", json_str("weather")),
                ("location", json_str("NYC")),
            ])),
        };

        let mut acc = TextAccumulator::new();
        acc.add_function_call(&call);

        assert_collected(&acc, &["search", "query", "location", "weather", "NYC"]);
    }

    #[test]
    fn test_add_function_response() {
        let function_response = FunctionResponse {
            name: Some("search".to_string()),
            response: Some(json_map(vec![("result", json_str("sunny"))])),
        };

        let mut acc = TextAccumulator::new();
        acc.add_function_response(&function_response);

        assert_collected(&acc, &["search", "result", "sunny"]);
    }

    #[test]
    fn test_add_schema_with_properties() {
        let name_property = Schema {
            schema_type: Some("STRING".to_string()),
            description: Some("The user's name".to_string()),
            ..Default::default()
        };
        let mut properties = HashMap::new();
        properties.insert("name".to_string(), name_property);

        let schema = Schema {
            schema_type: Some("OBJECT".to_string()),
            description: Some("A user object".to_string()),
            required: Some(vec!["name".to_string()]),
            properties: Some(properties),
            ..Default::default()
        };

        let mut acc = TextAccumulator::new();
        acc.add_schema(&schema);

        assert_collected(&acc, &["A user object", "name", "The user's name"]);
    }

    #[test]
    fn test_add_tool() {
        let location_property = Schema {
            schema_type: Some("STRING".to_string()),
            description: Some("The city name".to_string()),
            ..Default::default()
        };
        let mut props = HashMap::new();
        props.insert("location".to_string(), location_property);

        let parameters = Schema {
            schema_type: Some("OBJECT".to_string()),
            properties: Some(props),
            required: Some(vec!["location".to_string()]),
            ..Default::default()
        };
        let tool = Tool {
            function_declarations: Some(vec![FunctionDeclaration {
                name: Some("get_weather".to_string()),
                description: Some("Gets the weather for a location".to_string()),
                parameters: Some(parameters),
                response: None,
            }]),
        };

        let mut acc = TextAccumulator::new();
        acc.add_tool(&tool);

        assert_collected(
            &acc,
            &[
                "get_weather",
                "Gets the weather for a location",
                "location",
                "The city name",
            ],
        );
    }

    #[test]
    fn test_schema_enum_values() {
        let colours: Vec<String> = ["red", "green", "blue"]
            .iter()
            .map(|c| c.to_string())
            .collect();
        let schema = Schema {
            schema_type: Some("STRING".to_string()),
            enum_values: Some(colours),
            ..Default::default()
        };

        let mut acc = TextAccumulator::new();
        acc.add_schema(&schema);

        assert_collected(&acc, &["red", "green", "blue"]);
    }

    #[test]
    fn test_any_traverse_nested() {
        let nested = serde_json::json!({"nested_key": "nested_value", "list": ["a", "b"]});
        let call = FunctionCall {
            name: Some("test_fn".to_string()),
            args: Some(json_map(vec![("data", nested)])),
        };

        let mut acc = TextAccumulator::new();
        acc.add_function_call(&call);

        assert_collected(
            &acc,
            &[
                "test_fn",
                "data",
                "nested_key",
                "nested_value",
                "list",
                "a",
                "b",
            ],
        );
    }

    #[test]
    fn test_content_with_function_call_part() {
        let call = FunctionCall {
            name: Some("search".to_string()),
            args: Some(json_map(vec![("q", json_str("test"))])),
        };
        let content = Content {
            role: Some("model".to_string()),
            parts: Some(vec![Part {
                function_call: Some(call),
                ..Default::default()
            }]),
        };

        let mut acc = TextAccumulator::new();
        acc.add_content(&content);

        assert_collected(&acc, &["search", "q", "test"]);
    }
}
447
448