pub mod accumulator;
pub mod types;

pub use accumulator::TextAccumulator;
pub use types::*;

use sentencepiece::SentencePieceProcessor;
use std::sync::{Arc, OnceLock};

/// SHA-256 of the embedded SentencePiece model file; checked by
/// `verify_model_hash`.
pub const MODEL_SHA256: &str =
    "1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c";

/// Vocabulary size of the embedded model.
pub const VOCAB_SIZE: usize = 262_144;

static MODEL_BYTES: &[u8] =
    include_bytes!("../resources/gemma3_cleaned_262144_v2.spiece.model");

static GLOBAL_PROCESSOR: OnceLock<Arc<SentencePieceProcessor>> = OnceLock::new();

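// Model names accepted by `LocalTokenizer::new`. They all map onto the
// same embedded SentencePiece model.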
const SUPPORTED_MODELS: &[&str] = &[
    "gemini-2.5-pro",
    "gemini-2.5-flash",
    "gemini-2.5-flash-lite",
    "gemini-2.0-flash",
    "gemini-2.0-flash-lite",
    "gemini-2.5-pro-preview-06-05",
    "gemini-2.5-pro-preview-05-06",
    "gemini-2.5-pro-exp-03-25",
    "gemini-live-2.5-flash",
    "gemini-2.5-flash-preview-05-20",
    "gemini-2.5-flash-preview-04-17",
    "gemini-2.5-flash-lite-preview-06-17",
    "gemini-2.0-flash-001",
    "gemini-2.0-flash-lite-001",
    "gemini-3-pro-preview",
];

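/// Errors produced while loading the embedded model or selecting a model name.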
#[derive(Debug)]
pub enum TokenizerError {
    /// The SentencePiece model could not be loaded.
    ModelLoadError(String),

    /// The embedded model bytes did not hash to the expected SHA-256.
    HashMismatch { expected: String, actual: String },

    /// The requested model name is not in `SUPPORTED_MODELS`.
    UnsupportedModel(String),
}

impl std::fmt::Display for TokenizerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TokenizerError::ModelLoadError(msg) => {
                write!(f, "failed to load SentencePiece model: {}", msg)
            }
            TokenizerError::HashMismatch { expected, actual } => {
                write!(
                    f,
                    "model hash mismatch: expected {}, got {}",
                    expected, actual
                )
            }
            TokenizerError::UnsupportedModel(name) => {
                write!(
                    f,
                    "model {} is not supported. Supported models: {}",
                    name,
                    SUPPORTED_MODELS.join(", ")
                )
            }
        }
    }
}

impl std::error::Error for TokenizerError {}

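/// A local tokenizer for Gemini models, backed by the SentencePiece model
/// embedded in this crate. Every supported model name shares the same
/// underlying processor, which is parsed once and shared through a global
/// `Arc`, so constructing additional tokenizers is cheap.
///
/// A minimal usage sketch (the `use` path assumes the crate is named
/// `tokenizer`; adjust it to the actual crate name):
///
/// ```ignore
/// use tokenizer::LocalTokenizer;
///
/// let tok = LocalTokenizer::new("gemini-2.5-pro")?;
/// let result = tok.count_tokens("Hello, world!", None);
/// println!("{} tokens", result.total_tokens);
/// ```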
#[derive(Debug)]
pub struct LocalTokenizer {
    processor: Arc<SentencePieceProcessor>,
    model_name: String,
}

impl LocalTokenizer {
    /// Creates a tokenizer for `model_name`, which must be one of
    /// `SUPPORTED_MODELS`. The embedded SentencePiece model is parsed on
    /// first use and cached for the lifetime of the process.
    pub fn new(model_name: &str) -> Result<Self, TokenizerError> {
        if !SUPPORTED_MODELS.contains(&model_name) {
            return Err(TokenizerError::UnsupportedModel(model_name.to_string()));
        }
        let processor = GLOBAL_PROCESSOR
            .get_or_init(|| {
                let p = SentencePieceProcessor::from_serialized_proto(MODEL_BYTES)
                    .expect("Critical: Embedded tokenizer model is corrupt");
                Arc::new(p)
            })
            .clone();

        Ok(Self {
            processor,
            model_name: model_name.to_string(),
        })
    }

    /// Returns the model name this tokenizer was created for.
    pub fn model_name(&self) -> &str {
        &self.model_name
    }

    /// Returns the vocabulary size reported by the underlying processor.
    pub fn vocab_size(&self) -> usize {
        self.processor.len()
    }

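    /// Counts tokens in `contents`, plus any tools, response schema, and
    /// system instruction supplied via `config`.
    ///
    /// A minimal sketch of counting with a system instruction (field
    /// values are illustrative; `tok` is a `LocalTokenizer`):
    ///
    /// ```ignore
    /// let config = CountTokensConfig {
    ///     system_instruction: Some(Content {
    ///         role: Some("system".to_string()),
    ///         parts: Some(vec![Part {
    ///             text: Some("You are a helpful assistant.".to_string()),
    ///             ..Default::default()
    ///         }]),
    ///     }),
    ///     ..Default::default()
    /// };
    /// let result = tok.count_tokens("Hello", Some(&config));
    /// assert!(result.total_tokens > 0);
    /// ```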
    pub fn count_tokens<'a>(
        &self,
        contents: impl Into<Contents<'a>>,
        config: Option<&CountTokensConfig>,
    ) -> CountTokensResult {
        let content_vec = contents_to_vec(contents.into());
        let mut acc = TextAccumulator::new();
        acc.add_contents(&content_vec);

        if let Some(config) = config {
            if let Some(tools) = &config.tools {
                acc.add_tools(tools);
            }
            if let Some(schema) = &config.response_schema {
                acc.add_schema(schema);
            }
            if let Some(system_instruction) = &config.system_instruction {
                acc.add_content(system_instruction);
            }
        }

        let mut total = 0;
        for text in acc.get_texts() {
            // Text that fails to encode contributes zero tokens rather
            // than aborting the whole count.
            total += match self.processor.encode(text) {
                Ok(pieces) => pieces.len(),
                Err(_) => 0,
            };
        }

        CountTokensResult {
            total_tokens: total,
        }
    }

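    /// Tokenizes each part of `contents` individually, returning the token
    /// IDs and token bytes per part, tagged with the role of the enclosing
    /// content.
    ///
    /// A minimal sketch (`tok` is a `LocalTokenizer`):
    ///
    /// ```ignore
    /// let result = tok.compute_tokens("Hello");
    /// let info = &result.tokens_info[0];
    /// assert_eq!(info.token_ids.len(), info.tokens.len());
    /// ```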
    pub fn compute_tokens<'a>(
        &self,
        contents: impl Into<Contents<'a>>,
    ) -> ComputeTokensResult {
        let content_vec = contents_to_vec(contents.into());
        let mut tokens_info = Vec::new();

        for content in &content_vec {
            if let Some(parts) = &content.parts {
                for part in parts {
                    let mut acc = TextAccumulator::new();
                    acc.add_part(part);

                    let mut all_ids = Vec::new();
                    let mut all_tokens = Vec::new();
                    for text in acc.get_texts() {
                        if let Ok(pieces) = self.processor.encode(text) {
                            for p in pieces {
                                all_ids.push(p.id);
                                all_tokens.push(token_piece_to_bytes(&p.piece));
                            }
                        }
                    }

                    tokens_info.push(TokensInfo {
                        token_ids: all_ids,
                        tokens: all_tokens,
                        role: content.role.clone(),
                    });
                }
            }
        }

        ComputeTokensResult { tokens_info }
    }

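    /// Returns the underlying `SentencePieceProcessor`, for callers that
    /// need direct access to the raw encoder.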
    pub fn processor(&self) -> &SentencePieceProcessor {
        &self.processor
    }
}

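/// Normalizes `Contents` input: plain text becomes a single user-role
/// `Content` with one text part; structured input is passed through.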
fn contents_to_vec(contents: Contents<'_>) -> Vec<Content> {
    match contents {
        Contents::Text(s) => vec![Content {
            role: Some("user".to_string()),
            parts: Some(vec![Part {
                text: Some(s.to_string()),
                ..Default::default()
            }]),
        }],
        Contents::Structured(c) => c.to_vec(),
    }
}

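/// Converts a SentencePiece token string to its raw bytes. Byte-fallback
/// pieces of the form `<0xNN>` decode to the single byte `NN`; otherwise
/// the U+2581 word-boundary marker is replaced with a plain space.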
fn token_piece_to_bytes(piece: &str) -> Vec<u8> {
    if piece.len() == 6 && piece.starts_with("<0x") && piece.ends_with('>') {
        if let Ok(val) = u8::from_str_radix(&piece[3..5], 16) {
            return vec![val];
        }
    }
    piece.replace('\u{2581}', " ").into_bytes()
}

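/// Verifies that the embedded model bytes hash to [`MODEL_SHA256`].
///
/// ```ignore
/// verify_model_hash().expect("embedded model should be intact");
/// ```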
pub fn verify_model_hash() -> Result<(), TokenizerError> {
    use sha2::{Digest, Sha256};
    let mut hasher = Sha256::new();
    hasher.update(MODEL_BYTES);
    let actual = format!("{:x}", hasher.finalize());
    if actual == MODEL_SHA256 {
        Ok(())
    } else {
        Err(TokenizerError::HashMismatch {
            expected: MODEL_SHA256.to_string(),
            actual,
        })
    }
}

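/// Returns the model names accepted by [`LocalTokenizer::new`].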
pub fn supported_models() -> &'static [&'static str] {
    SUPPORTED_MODELS
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn test_verify_embedded_model_hash() {
        verify_model_hash().expect("embedded model hash should match");
    }

    #[test]
    fn test_vocab_size() {
        let tok = LocalTokenizer::new("gemini-2.5-pro").expect("tokenizer should load");
        assert_eq!(tok.vocab_size(), VOCAB_SIZE);
    }

    #[test]
    fn test_model_name() {
        let tok = LocalTokenizer::new("gemini-2.0-flash").expect("tokenizer should load");
        assert_eq!(tok.model_name(), "gemini-2.0-flash");
    }

    #[test]
    fn test_unsupported_model() {
        let err = LocalTokenizer::new("gpt-4").unwrap_err();
        match err {
            TokenizerError::UnsupportedModel(name) => assert_eq!(name, "gpt-4"),
            _ => panic!("expected UnsupportedModel error"),
        }
    }

    #[test]
    fn test_all_supported_models() {
        for model in SUPPORTED_MODELS {
            LocalTokenizer::new(model)
                .unwrap_or_else(|_| panic!("{} should be supported", model));
        }
    }

    #[test]
    fn test_count_tokens_text() {
        let tok = LocalTokenizer::new("gemini-2.0-flash-001").expect("tokenizer should load");
        let result = tok.count_tokens("What is your name?", None);
        assert_eq!(result.total_tokens, 5);
    }

    #[test]
    fn test_count_tokens_empty() {
        let tok = LocalTokenizer::new("gemini-2.5-pro").expect("tokenizer should load");
        let result = tok.count_tokens("", None);
        assert_eq!(result.total_tokens, 0);
    }

    #[test]
    fn test_count_tokens_content() {
        let tok = LocalTokenizer::new("gemini-2.5-pro").expect("tokenizer should load");
        let contents = vec![Content {
            role: Some("user".to_string()),
            parts: Some(vec![Part {
                text: Some("Hello, world!".to_string()),
                ..Default::default()
            }]),
        }];
        let result = tok.count_tokens(contents.as_slice(), None);
        let direct = tok.count_tokens("Hello, world!", None);
        assert_eq!(result.total_tokens, direct.total_tokens);
    }

    #[test]
    fn test_count_tokens_vec_ref() {
        let tok = LocalTokenizer::new("gemini-2.5-pro").expect("tokenizer should load");
        let contents = vec![Content {
            role: Some("user".to_string()),
            parts: Some(vec![Part {
                text: Some("Hello".to_string()),
                ..Default::default()
            }]),
        }];
        let result = tok.count_tokens(&contents, None);
        assert!(result.total_tokens > 0);
    }

    #[test]
    fn test_count_tokens_function_call() {
        let tok = LocalTokenizer::new("gemini-2.5-pro").expect("tokenizer should load");

        let mut args = HashMap::new();
        args.insert(
            "query".to_string(),
            serde_json::Value::String("weather".to_string()),
        );

        let contents = vec![Content {
            role: Some("model".to_string()),
            parts: Some(vec![Part {
                function_call: Some(FunctionCall {
                    name: Some("search".to_string()),
                    args: Some(args),
                }),
                ..Default::default()
            }]),
        }];

        let result = tok.count_tokens(contents.as_slice(), None);
        let expected = tok.count_tokens("search", None).total_tokens
            + tok.count_tokens("query", None).total_tokens
            + tok.count_tokens("weather", None).total_tokens;
        assert_eq!(result.total_tokens, expected);
    }

    #[test]
    fn test_count_tokens_function_response() {
        let tok = LocalTokenizer::new("gemini-2.5-pro").expect("tokenizer should load");

        let mut response = HashMap::new();
        response.insert(
            "result".to_string(),
            serde_json::Value::String("sunny".to_string()),
        );

        let contents = vec![Content {
            role: Some("model".to_string()),
            parts: Some(vec![Part {
                function_response: Some(FunctionResponse {
                    name: Some("search".to_string()),
                    response: Some(response),
                }),
                ..Default::default()
            }]),
        }];

        let result = tok.count_tokens(contents.as_slice(), None);
        let expected = tok.count_tokens("search", None).total_tokens
            + tok.count_tokens("result", None).total_tokens
            + tok.count_tokens("sunny", None).total_tokens;
        assert_eq!(result.total_tokens, expected);
    }

    #[test]
    fn test_count_tokens_with_tools() {
        let tok = LocalTokenizer::new("gemini-2.5-pro").expect("tokenizer should load");

        let contents = vec![Content {
            role: Some("user".to_string()),
            parts: Some(vec![Part {
                text: Some("What is the weather?".to_string()),
                ..Default::default()
            }]),
        }];

        let config = CountTokensConfig {
            tools: Some(vec![Tool {
                function_declarations: Some(vec![FunctionDeclaration {
                    name: Some("get_weather".to_string()),
                    description: Some("Gets the current weather".to_string()),
                    parameters: Some(Schema {
                        schema_type: Some("OBJECT".to_string()),
                        properties: Some({
                            let mut props = HashMap::new();
                            props.insert(
                                "city".to_string(),
                                Schema {
                                    schema_type: Some("STRING".to_string()),
                                    description: Some("The city name".to_string()),
                                    ..Default::default()
                                },
                            );
                            props
                        }),
                        required: Some(vec!["city".to_string()]),
                        ..Default::default()
                    }),
                    response: None,
                }]),
            }]),
            ..Default::default()
        };

        let with_tools = tok.count_tokens(contents.as_slice(), Some(&config));
        let without_tools = tok.count_tokens(contents.as_slice(), None);
        assert!(with_tools.total_tokens > without_tools.total_tokens);
    }

    #[test]
    fn test_count_tokens_with_system_instruction() {
        let tok = LocalTokenizer::new("gemini-2.5-pro").expect("tokenizer should load");

        let contents = vec![Content {
            role: Some("user".to_string()),
            parts: Some(vec![Part {
                text: Some("Hello".to_string()),
                ..Default::default()
            }]),
        }];

        let config = CountTokensConfig {
            system_instruction: Some(Content {
                role: Some("system".to_string()),
                parts: Some(vec![Part {
                    text: Some("You are a helpful assistant.".to_string()),
                    ..Default::default()
                }]),
            }),
            ..Default::default()
        };

        let with_system = tok.count_tokens(contents.as_slice(), Some(&config));
        let without_system = tok.count_tokens(contents.as_slice(), None);
        assert!(with_system.total_tokens > without_system.total_tokens);
    }

    #[test]
    fn test_count_tokens_multiple_parts() {
        let tok = LocalTokenizer::new("gemini-2.5-pro").expect("tokenizer should load");
        let contents = vec![Content {
            role: Some("user".to_string()),
            parts: Some(vec![
                Part {
                    text: Some("Hello".to_string()),
                    ..Default::default()
                },
                Part {
                    text: Some("World".to_string()),
                    ..Default::default()
                },
            ]),
        }];

        let result = tok.count_tokens(contents.as_slice(), None);
        let expected = tok.count_tokens("Hello", None).total_tokens
            + tok.count_tokens("World", None).total_tokens;
        assert_eq!(result.total_tokens, expected);
    }

    #[test]
    fn test_compute_tokens_text() {
        let tok = LocalTokenizer::new("gemini-2.5-pro").expect("tokenizer should load");
        let result = tok.compute_tokens("Hello");
        assert_eq!(result.tokens_info.len(), 1);
        assert!(!result.tokens_info[0].token_ids.is_empty());
        assert_eq!(
            result.tokens_info[0].token_ids.len(),
            result.tokens_info[0].tokens.len()
        );
        assert_eq!(result.tokens_info[0].role, Some("user".to_string()));
    }

    #[test]
    fn test_compute_tokens_matches_count() {
        let tok = LocalTokenizer::new("gemini-2.5-pro").expect("tokenizer should load");
        let text = "The quick brown fox jumps over the lazy dog.";
        let count_result = tok.count_tokens(text, None);
        let compute_result = tok.compute_tokens(text);
        let total_ids: usize = compute_result
            .tokens_info
            .iter()
            .map(|ti| ti.token_ids.len())
            .sum();
        assert_eq!(total_ids, count_result.total_tokens);
    }

    #[test]
    fn test_compute_tokens_preserves_role() {
        let tok = LocalTokenizer::new("gemini-2.5-pro").expect("tokenizer should load");
        let contents = vec![
            Content {
                role: Some("user".to_string()),
                parts: Some(vec![Part {
                    text: Some("Hello".to_string()),
                    ..Default::default()
                }]),
            },
            Content {
                role: Some("model".to_string()),
                parts: Some(vec![Part {
                    text: Some("Hi there".to_string()),
                    ..Default::default()
                }]),
            },
        ];
        let result = tok.compute_tokens(contents.as_slice());
        assert_eq!(result.tokens_info.len(), 2);
        assert_eq!(result.tokens_info[0].role, Some("user".to_string()));
        assert_eq!(result.tokens_info[1].role, Some("model".to_string()));
    }

    #[test]
    fn test_count_tokens_display() {
        let result = CountTokensResult { total_tokens: 42 };
        assert_eq!(format!("{}", result), "total_tokens=42");
    }

    #[test]
    fn test_tokenizer_error_display() {
        let err = TokenizerError::ModelLoadError("test error".to_string());
        assert!(format!("{}", err).contains("test error"));

        let err = TokenizerError::HashMismatch {
            expected: "aaa".to_string(),
            actual: "bbb".to_string(),
        };
        let msg = format!("{}", err);
        assert!(msg.contains("aaa"));
        assert!(msg.contains("bbb"));

        let err = TokenizerError::UnsupportedModel("gpt-4".to_string());
        let msg = format!("{}", err);
        assert!(msg.contains("gpt-4"));
        assert!(msg.contains("not supported"));
    }

    #[test]
    fn test_token_piece_to_bytes_normal() {
        let bytes = token_piece_to_bytes("\u{2581}Hello");
        assert_eq!(bytes, b" Hello");
    }

    #[test]
    fn test_token_piece_to_bytes_hex() {
        let bytes = token_piece_to_bytes("<0xFF>");
        assert_eq!(bytes, vec![0xFF]);
    }

    #[test]
    fn test_supported_models_list() {
        let models = supported_models();
        assert!(models.contains(&"gemini-2.5-pro"));
        assert!(models.contains(&"gemini-3-pro-preview"));
    }
}