llm_cost_ops/engine/normalizer.rs

use crate::domain::{Provider, Result, UsageRecord};

/// Token normalizer for handling provider-specific token counting
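///
/// Minimal usage sketch (marked `ignore` since it depends on crate-local
/// types; the count assumes the ~4-chars-per-token heuristic below):
///
/// ```ignore
/// let normalizer = TokenNormalizer::new();
/// let tokens = normalizer.estimate_tokens("hello world", &Provider::OpenAI);
/// assert_eq!(tokens, 3); // ceil(11 chars / 4)
/// ```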
pub struct TokenNormalizer;

impl TokenNormalizer {
    pub fn new() -> Self {
        Self
    }

    /// Normalize token counts across providers
    pub fn normalize(&self, record: &UsageRecord) -> Result<UsageRecord> {
        let mut normalized = record.clone();

        // Apply provider-specific normalization factors
        let normalization_factor = self.get_normalization_factor(&record.provider);

        if normalization_factor != 1.0 {
            normalized.prompt_tokens = (record.prompt_tokens as f64 * normalization_factor) as u64;
            normalized.completion_tokens =
                (record.completion_tokens as f64 * normalization_factor) as u64;
            // Recompute the total from the scaled parts so all three fields
            // stay consistent after the float-to-integer truncation above.
            normalized.total_tokens = normalized.prompt_tokens + normalized.completion_tokens;
        }

        Ok(normalized)
    }

    /// Get provider-specific normalization factor
    ///
    /// All factors are currently 1.0; this match is the hook for correcting
    /// any provider whose reported counts diverge in the future.
    fn get_normalization_factor(&self, provider: &Provider) -> f64 {
        match provider {
            Provider::OpenAI => 1.0,
            Provider::Anthropic => 1.0,
            Provider::GoogleVertexAI => 1.0,
            Provider::AzureOpenAI => 1.0,
            Provider::AWSBedrock => 1.0,
            Provider::Cohere => 1.0,
            Provider::Mistral => 1.0,
            Provider::Custom(_) => 1.0,
        }
    }

    /// Estimate tokens from text if not provided
    pub fn estimate_tokens(&self, text: &str, provider: &Provider) -> u64 {
        // Conservative estimation: ~4 characters per token. Count chars
        // rather than bytes so multi-byte UTF-8 text isn't over-counted.
        let char_count = text.chars().count() as f64;
        let estimated_tokens = (char_count / 4.0).ceil() as u64;

        // Apply provider-specific adjustment (all providers currently share
        // the base estimate)
        match provider {
            Provider::OpenAI => estimated_tokens,
            Provider::Anthropic => estimated_tokens,
            _ => estimated_tokens,
        }
    }

    /// Validate token consistency
    pub fn validate_consistency(&self, record: &UsageRecord) -> Result<bool> {
        let calculated_total = record.prompt_tokens + record.completion_tokens;
        let tolerance = 1; // Allow 1 token difference for rounding

        Ok(calculated_total.abs_diff(record.total_tokens) <= tolerance)
    }
}

impl Default for TokenNormalizer {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::{IngestionSource, ModelIdentifier};
    use chrono::Utc;

    fn create_test_usage() -> UsageRecord {
        UsageRecord {
            id: uuid::Uuid::new_v4(),
            timestamp: Utc::now(),
            provider: Provider::OpenAI,
            model: ModelIdentifier::new("gpt-4".to_string(), 8192),
            organization_id: "org-test".to_string(),
            project_id: None,
            user_id: None,
            prompt_tokens: 100,
            completion_tokens: 50,
            total_tokens: 150,
            cached_tokens: None,
            reasoning_tokens: None,
            latency_ms: None,
            time_to_first_token_ms: None,
            tags: vec![],
            metadata: serde_json::Value::Null,
            ingested_at: Utc::now(),
            source: IngestionSource::Api {
                endpoint: "test".to_string(),
            },
        }
    }

    #[test]
    fn test_normalize() {
        let normalizer = TokenNormalizer::new();
        let usage = create_test_usage();

        let normalized = normalizer.normalize(&usage).unwrap();
        assert_eq!(normalized.total_tokens, usage.total_tokens);
    }
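
    // Added check (illustrative): every factor in get_normalization_factor
    // is currently 1.0, so normalize() should be an identity for any
    // provider, not just the OpenAI record built by the helper above.
    #[test]
    fn test_normalize_is_identity_across_providers() {
        let normalizer = TokenNormalizer::new();
        let mut usage = create_test_usage();
        usage.provider = Provider::Anthropic;

        let normalized = normalizer.normalize(&usage).unwrap();
        assert_eq!(normalized.prompt_tokens, usage.prompt_tokens);
        assert_eq!(normalized.completion_tokens, usage.completion_tokens);
        assert_eq!(normalized.total_tokens, usage.total_tokens);
    }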

    #[test]
    fn test_estimate_tokens() {
        let normalizer = TokenNormalizer::new();
        let text = "This is a short test message for token estimation";

        let estimated = normalizer.estimate_tokens(text, &Provider::OpenAI);
        assert!(estimated > 0);
        assert!(estimated < 100); // Reasonable upper bound
    }
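
    // Added check (illustrative): the heuristic uses ceil(), so exactly four
    // chars is one token and a fifth character tips the estimate to two.
    #[test]
    fn test_estimate_tokens_rounds_up() {
        let normalizer = TokenNormalizer::new();

        assert_eq!(normalizer.estimate_tokens("abcd", &Provider::OpenAI), 1);
        assert_eq!(normalizer.estimate_tokens("abcde", &Provider::OpenAI), 2);
    }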

    #[test]
    fn test_validate_consistency() {
        let normalizer = TokenNormalizer::new();
        let usage = create_test_usage();

        let is_consistent = normalizer.validate_consistency(&usage).unwrap();
        assert!(is_consistent);
    }
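
    // Added check (illustrative): a mismatch of exactly one token sits at
    // the rounding tolerance, so it should still be reported as consistent.
    #[test]
    fn test_validate_consistency_at_tolerance_boundary() {
        let normalizer = TokenNormalizer::new();
        let mut usage = create_test_usage();
        usage.total_tokens = 151; // calculated total is 150; diff == 1

        assert!(normalizer.validate_consistency(&usage).unwrap());
    }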

    #[test]
    fn test_validate_inconsistent() {
        let normalizer = TokenNormalizer::new();
        let mut usage = create_test_usage();
        usage.total_tokens = 200; // Intentional mismatch

        let is_consistent = normalizer.validate_consistency(&usage).unwrap();
        assert!(!is_consistent);
    }
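
    // Added check (illustrative): empty input has zero characters, so the
    // estimator should return zero rather than a spurious minimum of one.
    #[test]
    fn test_estimate_tokens_empty_input() {
        let normalizer = TokenNormalizer::new();

        assert_eq!(normalizer.estimate_tokens("", &Provider::OpenAI), 0);
    }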
}