// llm_cost_ops/engine/normalizer.rs

use crate::domain::{Provider, Result, UsageRecord};

/// Normalizes token usage records from different LLM providers into
/// directly comparable counts for cost accounting.
pub struct TokenNormalizer;

impl TokenNormalizer {
    pub fn new() -> Self {
        Self
    }

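    /// Returns a copy of `record` with prompt and completion tokens
    /// scaled by the provider's normalization factor; the total is
    /// recomputed whenever the factor is not 1.0.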
    pub fn normalize(&self, record: &UsageRecord) -> Result<UsageRecord> {
        let mut normalized = record.clone();

        let normalization_factor = self.get_normalization_factor(&record.provider);

        if normalization_factor != 1.0 {
            normalized.prompt_tokens =
                (record.prompt_tokens as f64 * normalization_factor) as u64;
            normalized.completion_tokens =
                (record.completion_tokens as f64 * normalization_factor) as u64;
            normalized.total_tokens = normalized.prompt_tokens + normalized.completion_tokens;
        }

        Ok(normalized)
    }

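    /// Per-provider scaling factor applied to raw token counts. Every
    /// supported provider currently reports tokens in comparable units,
    /// so all factors are 1.0; the per-provider arms are kept so a
    /// divergent tokenizer can be scaled here later.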
    fn get_normalization_factor(&self, provider: &Provider) -> f64 {
        match provider {
            Provider::OpenAI => 1.0,
            Provider::Anthropic => 1.0,
            Provider::GoogleVertexAI => 1.0,
            Provider::AzureOpenAI => 1.0,
            Provider::AWSBedrock => 1.0,
            Provider::Cohere => 1.0,
            Provider::Mistral => 1.0,
            Provider::Custom(_) => 1.0,
        }
    }

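    /// Estimates the token count of `text` using the rough
    /// four-characters-per-token heuristic; suitable for pre-flight
    /// cost estimates, not billing-grade accounting.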
    pub fn estimate_tokens(&self, text: &str, provider: &Provider) -> u64 {
        // Count Unicode characters rather than bytes so multi-byte
        // UTF-8 input is not over-counted.
        let char_count = text.chars().count() as f64;
        let estimated_tokens = (char_count / 4.0).ceil() as u64;

        // Every provider currently shares the same heuristic; the match
        // is the hook for provider-specific ratios.
        match provider {
            Provider::OpenAI => estimated_tokens,
            Provider::Anthropic => estimated_tokens,
            _ => estimated_tokens,
        }
    }

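    /// Verifies that `prompt_tokens + completion_tokens` matches the
    /// reported `total_tokens`, allowing one token of slack for
    /// provider-side rounding.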
    pub fn validate_consistency(&self, record: &UsageRecord) -> Result<bool> {
        let calculated_total = record.prompt_tokens + record.completion_tokens;
        let tolerance = 1;
        Ok(calculated_total.abs_diff(record.total_tokens) <= tolerance)
    }
}

impl Default for TokenNormalizer {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::{IngestionSource, ModelIdentifier};
    use chrono::Utc;

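    // A minimal, internally consistent record (100 + 50 = 150 tokens)
    // shared by the tests below.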
    fn create_test_usage() -> UsageRecord {
        UsageRecord {
            id: uuid::Uuid::new_v4(),
            timestamp: Utc::now(),
            provider: Provider::OpenAI,
            model: ModelIdentifier::new("gpt-4".to_string(), 8192),
            organization_id: "org-test".to_string(),
            project_id: None,
            user_id: None,
            prompt_tokens: 100,
            completion_tokens: 50,
            total_tokens: 150,
            cached_tokens: None,
            reasoning_tokens: None,
            latency_ms: None,
            time_to_first_token_ms: None,
            tags: vec![],
            metadata: serde_json::Value::Null,
            ingested_at: Utc::now(),
            source: IngestionSource::Api {
                endpoint: "test".to_string(),
            },
        }
    }

    #[test]
    fn test_normalize() {
        let normalizer = TokenNormalizer::new();
        let usage = create_test_usage();

        let normalized = normalizer.normalize(&usage).unwrap();
        assert_eq!(normalized.total_tokens, usage.total_tokens);
    }

    #[test]
    fn test_estimate_tokens() {
        let normalizer = TokenNormalizer::new();
        let text = "This is a test message with approximately twenty words";

        let estimated = normalizer.estimate_tokens(text, &Provider::OpenAI);
        assert!(estimated > 0);
        assert!(estimated < 100);
    }

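    // The heuristic should map empty input to zero tokens:
    // ceil(0 / 4.0) == 0.
    #[test]
    fn test_estimate_tokens_empty() {
        let normalizer = TokenNormalizer::new();
        assert_eq!(normalizer.estimate_tokens("", &Provider::OpenAI), 0);
    }
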
    #[test]
    fn test_validate_consistency() {
        let normalizer = TokenNormalizer::new();
        let usage = create_test_usage();

        let is_consistent = normalizer.validate_consistency(&usage).unwrap();
        assert!(is_consistent);
    }

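    // A total that is off by exactly one token stays within the
    // tolerance used by validate_consistency.
    #[test]
    fn test_validate_within_tolerance() {
        let normalizer = TokenNormalizer::new();
        let mut usage = create_test_usage();
        usage.total_tokens = 151;

        assert!(normalizer.validate_consistency(&usage).unwrap());
    }
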
    #[test]
    fn test_validate_inconsistent() {
        let normalizer = TokenNormalizer::new();
        let mut usage = create_test_usage();
        usage.total_tokens = 200;

        let is_consistent = normalizer.validate_consistency(&usage).unwrap();
        assert!(!is_consistent);
    }
}