Skip to main content

latch_core/
token.rs

1use std::sync::Arc;
2
/// Shared, thread-safe callback that maps a piece of text to an estimated token count.
///
/// Because this is just an `Arc`'d closure, callers are free to supply a precise
/// implementation (for example one backed by tiktoken) instead of a heuristic.
pub type TokenEstimator = Arc<dyn Fn(&str) -> usize + Send + Sync>;
6
7/// Wrapper for TokenEstimator that implements Debug
8#[derive(Clone)]
9pub struct TokenEstimatorWrapper(pub TokenEstimator);
10
11impl std::fmt::Debug for TokenEstimatorWrapper {
12    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
13        write!(f, "TokenEstimatorWrapper")
14    }
15}
16
17/// Default lightweight token estimator.
18/// Uses heuristic based on character types:
19/// - English characters: ~4 chars/token
20/// - CJK characters (inc. Extension A): ~1.5 chars/token
21/// - Other characters: ~4 chars/token
22/// Returns 0 for empty string.
23pub fn default_token_estimator() -> TokenEstimator {
24    Arc::new(|text: &str| {
25        let total: usize = text.chars().count();
26        if total == 0 {
27            return 0;
28        }
29        let english = text.chars().filter(|c| c.is_ascii_alphabetic()).count();
30        let cjk = text
31            .chars()
32            .filter(|c| matches!(c, '\u{4e00}'..='\u{9fff}' | '\u{3400}'..='\u{4dbf}'))
33            .count();
34        let other = total - english - cjk;
35(english / 4).max(1) + (cjk * 2 / 3).max(1) + (other / 4).max(1)
36    })
37}
38
#[cfg(test)]
mod tests {
    use super::*;

    /// Documented contract: no text means no tokens.
    #[test]
    fn default_estimator_handles_empty_string() {
        let estimate = default_token_estimator();
        assert_eq!(estimate(""), 0);
    }

    /// Plain English prose should land in a small, non-zero range.
    #[test]
    fn default_estimator_english_text() {
        let estimate = default_token_estimator();
        let count = estimate("This is a test of English text");
        // 30 chars at roughly 4 chars/token is in the single digits; the
        // bounds are intentionally loose because the estimator is a heuristic.
        assert!(count > 0);
        assert!(count < 15);
    }

    /// CJK-only input must still produce a positive estimate.
    #[test]
    fn default_estimator_cjk_text() {
        let estimate = default_token_estimator();
        // Six CJK ideographs; only positivity is checked here.
        let count = estimate("这是一个测试");
        assert!(count > 0);
    }
}