Skip to main content

sh_layer1/
cost_tracker.rs

//! Cost tracking module.
//!
//! Token counting, cost calculation, and budget control.
4
5use crate::utils::generate_short_id;
6use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::time::Instant;
10
11/// 成本报告
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct CostReport {
14    /// 总输入 token 数
15    pub total_input_tokens: u64,
16    /// 总输出 token 数
17    pub total_output_tokens: u64,
18    /// 总成本(美元)
19    pub total_cost_usd: f64,
20    /// 模型成本明细
21    pub model_costs: HashMap<String, ModelCost>,
22    /// 报告时间
23    pub timestamp: String,
24}
25
26/// 单模型成本
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct ModelCost {
29    pub input_tokens: u64,
30    pub output_tokens: u64,
31    pub cost_usd: f64,
32}
33
34/// 模型定价(每百万 token)
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct ModelPricing {
37    /// 输入 token 价格(美元/百万 token)
38    pub input_price_per_million: f64,
39    /// 输出 token 价格(美元/百万 token)
40    pub output_price_per_million: f64,
41}
42
43impl ModelPricing {
44    /// 计算成本
45    pub fn calculate_cost(&self, input_tokens: u64, output_tokens: u64) -> f64 {
46        let input_cost = (input_tokens as f64 / 1_000_000.0) * self.input_price_per_million;
47        let output_cost = (output_tokens as f64 / 1_000_000.0) * self.output_price_per_million;
48        input_cost + output_cost
49    }
50}
51
52/// 默认模型定价
53fn default_pricing() -> HashMap<String, ModelPricing> {
54    let mut pricing = HashMap::new();
55
56    // Claude 模型定价
57    pricing.insert(
58        "claude-opus-4-6".to_string(),
59        ModelPricing {
60            input_price_per_million: 15.0,
61            output_price_per_million: 75.0,
62        },
63    );
64    pricing.insert(
65        "claude-sonnet-4-6".to_string(),
66        ModelPricing {
67            input_price_per_million: 3.0,
68            output_price_per_million: 15.0,
69        },
70    );
71    pricing.insert(
72        "claude-haiku-4-5".to_string(),
73        ModelPricing {
74            input_price_per_million: 0.8,
75            output_price_per_million: 4.0,
76        },
77    );
78
79    // OpenAI 模型定价
80    pricing.insert(
81        "claude-sonnet-4-6".to_string(),
82        ModelPricing {
83            input_price_per_million: 3.0,
84            output_price_per_million: 15.0,
85        },
86    );
87    pricing.insert(
88        "gpt-4o-mini".to_string(),
89        ModelPricing {
90            input_price_per_million: 0.15,
91            output_price_per_million: 0.6,
92        },
93    );
94
95    pricing
96}
97
98/// 成本追踪器
99pub struct CostTracker {
100    /// 使用记录
101    usage: RwLock<HashMap<String, UsageRecord>>,
102    /// 模型定价
103    pricing: HashMap<String, ModelPricing>,
104    /// 预算上限
105    budget_limit: RwLock<Option<f64>>,
106    /// 当前总成本
107    current_cost: RwLock<f64>,
108}
109
/// One recorded model invocation.
#[derive(Clone, Debug)]
struct UsageRecord {
    /// Name of the model that was used.
    model: String,
    /// Input tokens consumed by the call.
    input_tokens: u64,
    /// Output tokens produced by the call.
    output_tokens: u64,
    /// Moment the record was created; stored but not read anywhere in this
    /// file, hence the lint exemption.
    #[allow(dead_code)]
    timestamp: Instant,
}
119
120impl CostTracker {
121    pub fn new() -> Self {
122        Self {
123            usage: RwLock::new(HashMap::new()),
124            pricing: default_pricing(),
125            budget_limit: RwLock::new(None),
126            current_cost: RwLock::new(0.0),
127        }
128    }
129
130    /// 设置预算上限
131    pub fn set_budget_limit(&self, limit: f64) {
132        *self.budget_limit.write() = Some(limit);
133    }
134
135    /// 记录使用
136    pub fn record_usage(
137        &self,
138        model: &str,
139        input_tokens: u64,
140        output_tokens: u64,
141    ) -> anyhow::Result<()> {
142        // 计算成本
143        let pricing = self.pricing.get(model).cloned().unwrap_or(ModelPricing {
144            // 默认定价(中等模型)
145            input_price_per_million: 3.0,
146            output_price_per_million: 15.0,
147        });
148
149        let cost = pricing.calculate_cost(input_tokens, output_tokens);
150
151        // 检查预算
152        let current = *self.current_cost.read();
153        let limit = *self.budget_limit.read();
154
155        if let Some(limit) = limit {
156            if current + cost > limit {
157                return Err(anyhow::anyhow!(
158                    "Budget limit exceeded: current {:.4}, new {:.4}, limit {:.2}",
159                    current,
160                    current + cost,
161                    limit
162                ));
163            }
164        }
165
166        // 更新记录
167        let record_id = generate_short_id();
168        self.usage.write().insert(
169            record_id,
170            UsageRecord {
171                model: model.to_string(),
172                input_tokens,
173                output_tokens,
174                timestamp: Instant::now(),
175            },
176        );
177
178        // 更新当前成本
179        *self.current_cost.write() += cost;
180
181        Ok(())
182    }
183
184    /// 获取当前使用情况
185    pub fn get_current_usage(&self) -> UsageSnapshot {
186        let usage = self.usage.read();
187        let mut model_costs = HashMap::new();
188        let mut total_input = 0;
189        let mut total_output = 0;
190
191        for record in usage.values() {
192            let entry = model_costs
193                .entry(record.model.clone())
194                .or_insert(ModelCost {
195                    input_tokens: 0,
196                    output_tokens: 0,
197                    cost_usd: 0.0,
198                });
199
200            entry.input_tokens += record.input_tokens;
201            entry.output_tokens += record.output_tokens;
202
203            let pricing = self
204                .pricing
205                .get(&record.model)
206                .cloned()
207                .unwrap_or(ModelPricing {
208                    input_price_per_million: 3.0,
209                    output_price_per_million: 15.0,
210                });
211
212            entry.cost_usd += pricing.calculate_cost(record.input_tokens, record.output_tokens);
213
214            total_input += record.input_tokens;
215            total_output += record.output_tokens;
216        }
217
218        UsageSnapshot {
219            total_input_tokens: total_input,
220            total_output_tokens: total_output,
221            total_cost_usd: *self.current_cost.read(),
222            model_costs,
223            budget_remaining: self
224                .budget_limit
225                .read()
226                .map(|limit| limit - *self.current_cost.read()),
227        }
228    }
229
230    /// 预估下一步成本
231    pub fn estimate_next_step(
232        &self,
233        model: &str,
234        estimated_input: u64,
235        estimated_output: u64,
236    ) -> CostEstimate {
237        let pricing = self.pricing.get(model).cloned().unwrap_or(ModelPricing {
238            input_price_per_million: 3.0,
239            output_price_per_million: 15.0,
240        });
241
242        let estimated_cost = pricing.calculate_cost(estimated_input, estimated_output);
243
244        CostEstimate {
245            min_tokens: estimated_input,
246            max_tokens: estimated_input + estimated_output,
247            estimated_cost_usd: estimated_cost,
248            confidence: "medium".to_string(), // [NOTE] 置信度估算为基本启发式,未来可改进
249        }
250    }
251
252    /// 生成报告
253    pub fn generate_report(&self) -> CostReport {
254        let snapshot = self.get_current_usage();
255
256        CostReport {
257            total_input_tokens: snapshot.total_input_tokens,
258            total_output_tokens: snapshot.total_output_tokens,
259            total_cost_usd: snapshot.total_cost_usd,
260            model_costs: snapshot.model_costs,
261            timestamp: chrono::Utc::now().to_rfc3339(),
262        }
263    }
264
265    /// 重置追踪器
266    pub fn reset(&self) {
267        self.usage.write().clear();
268        *self.current_cost.write() = 0.0;
269    }
270}
271
272/// 使用快照
273#[derive(Debug, Clone, Serialize, Deserialize)]
274pub struct UsageSnapshot {
275    pub total_input_tokens: u64,
276    pub total_output_tokens: u64,
277    pub total_cost_usd: f64,
278    pub model_costs: HashMap<String, ModelCost>,
279    pub budget_remaining: Option<f64>,
280}
281
282/// 成本预估
283#[derive(Debug, Clone, Serialize, Deserialize)]
284pub struct CostEstimate {
285    pub min_tokens: u64,
286    pub max_tokens: u64,
287    pub estimated_cost_usd: f64,
288    pub confidence: String,
289}
290
291impl Default for CostTracker {
292    fn default() -> Self {
293        Self::new()
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing floating-point dollar amounts.
    const EPSILON: f64 = 1e-12;

    #[test]
    fn test_pricing_calculation() {
        let pricing = ModelPricing {
            input_price_per_million: 3.0,
            output_price_per_million: 15.0,
        };

        // 1000 input + 500 output tokens:
        // 0.001 * $3 + 0.0005 * $15 = $0.0105
        let cost = pricing.calculate_cost(1000, 500);
        assert!(
            (cost - 0.0105).abs() < EPSILON,
            "unexpected cost: {}",
            cost
        );
    }

    #[test]
    fn test_usage_tracking() {
        let tracker = CostTracker::new();

        tracker
            .record_usage("claude-sonnet-4-6", 1000, 500)
            .unwrap();

        let snapshot = tracker.get_current_usage();
        assert_eq!(snapshot.total_input_tokens, 1000);
        assert_eq!(snapshot.total_output_tokens, 500);
    }

    #[test]
    fn test_budget_limit() {
        let tracker = CostTracker::new();
        tracker.set_budget_limit(0.01); // one cent

        // The first call costs about $0.00105 and fits within the budget.
        tracker.record_usage("claude-sonnet-4-6", 100, 50).unwrap();

        // The second call costs about $0.105 and must be rejected.
        let result = tracker.record_usage("claude-sonnet-4-6", 10000, 5000);
        assert!(result.is_err());

        // A rejected call must leave the recorded usage untouched.
        let snapshot = tracker.get_current_usage();
        assert_eq!(snapshot.total_input_tokens, 100);
        assert_eq!(snapshot.total_output_tokens, 50);
    }

    #[test]
    fn test_multiple_models() {
        let tracker = CostTracker::new();

        // Record usage for several different models.
        tracker.record_usage("claude-opus-4-6", 1000, 500).unwrap();
        tracker
            .record_usage("claude-sonnet-4-6", 2000, 1000)
            .unwrap();
        tracker.record_usage("claude-haiku-4-5", 500, 250).unwrap();
        tracker.record_usage("gpt-4o", 1500, 750).unwrap();
        tracker.record_usage("gpt-4o-mini", 3000, 1500).unwrap();

        let snapshot = tracker.get_current_usage();

        // Token totals across all models.
        assert_eq!(snapshot.total_input_tokens, 8000);
        assert_eq!(snapshot.total_output_tokens, 4000);

        // Every model has its own breakdown entry.
        assert!(snapshot.model_costs.contains_key("claude-opus-4-6"));
        assert!(snapshot.model_costs.contains_key("claude-sonnet-4-6"));
        assert!(snapshot.model_costs.contains_key("claude-haiku-4-5"));
        assert!(snapshot.model_costs.contains_key("gpt-4o"));
        assert!(snapshot.model_costs.contains_key("gpt-4o-mini"));

        // Total cost is positive.
        assert!(snapshot.total_cost_usd > 0.0);

        // Different models accumulate different costs.
        let opus_cost = snapshot
            .model_costs
            .get("claude-opus-4-6")
            .unwrap()
            .cost_usd;
        let haiku_cost = snapshot
            .model_costs
            .get("claude-haiku-4-5")
            .unwrap()
            .cost_usd;

        // Opus is priced higher per token and was also given more tokens
        // here, so its recorded cost must exceed Haiku's.
        assert!(
            opus_cost > haiku_cost,
            "Opus should be more expensive than Haiku"
        );
    }

    #[test]
    fn test_budget_reset() {
        let tracker = CostTracker::new();
        tracker.set_budget_limit(1.0);

        // Spend part of the budget.
        tracker
            .record_usage("claude-sonnet-4-6", 5000, 2500)
            .unwrap();
        let snapshot = tracker.get_current_usage();
        assert!(snapshot.total_cost_usd > 0.0);
        assert!(snapshot.budget_remaining.is_some());
        assert!(snapshot.budget_remaining.unwrap() < 1.0);

        // Reset the tracker.
        tracker.reset();

        // State after the reset: no usage, no cost.
        let snapshot = tracker.get_current_usage();
        assert_eq!(snapshot.total_input_tokens, 0);
        assert_eq!(snapshot.total_output_tokens, 0);
        assert_eq!(snapshot.total_cost_usd, 0.0);
        assert!(snapshot.model_costs.is_empty());

        // The budget limit survives the reset and is fully available again.
        assert_eq!(snapshot.budget_remaining, Some(1.0));

        // Recording a small amount of usage still works afterwards.
        tracker
            .record_usage("claude-sonnet-4-6", 1000, 500)
            .unwrap();
        let snapshot = tracker.get_current_usage();
        assert!(snapshot.total_cost_usd > 0.0);
    }

    #[test]
    fn test_concurrent_recording() {
        use std::sync::Arc;
        use std::thread;

        let tracker = Arc::new(CostTracker::new());
        let mut handles = vec![];

        for i in 0..10 {
            let t = Arc::clone(&tracker);
            handles.push(thread::spawn(move || {
                let model = match i % 3 {
                    0 => "claude-opus-4-6",
                    1 => "claude-sonnet-4-6",
                    _ => "claude-haiku-4-5",
                };
                t.record_usage(model, 100, 50).unwrap()
            }));
        }

        // Every recording thread must finish successfully.
        for handle in handles {
            handle.join().unwrap();
        }

        let snapshot = tracker.get_current_usage();
        assert_eq!(snapshot.total_input_tokens, 1000);
        assert_eq!(snapshot.total_output_tokens, 500);
    }

    #[test]
    fn test_unknown_model_pricing() {
        let tracker = CostTracker::new();

        // An unknown model is billed at the default pricing.
        tracker.record_usage("unknown-model", 1000, 500).unwrap();

        let snapshot = tracker.get_current_usage();
        assert!(snapshot.model_costs.contains_key("unknown-model"));
        let cost = snapshot.model_costs.get("unknown-model").unwrap().cost_usd;
        assert!(cost > 0.0);

        // For the same token counts the default pricing sits between Haiku
        // and Opus.
        let haiku_cost = tracker
            .estimate_next_step("claude-haiku-4-5", 1000, 500)
            .estimated_cost_usd;
        let opus_cost = tracker
            .estimate_next_step("claude-opus-4-6", 1000, 500)
            .estimated_cost_usd;
        assert!(cost > haiku_cost, "default pricing should exceed Haiku");
        assert!(cost < opus_cost, "default pricing should be below Opus");
    }

    #[test]
    fn test_estimate_next_step() {
        let tracker = CostTracker::new();

        let estimate = tracker.estimate_next_step("claude-sonnet-4-6", 1000, 500);
        assert_eq!(estimate.min_tokens, 1000);
        assert_eq!(estimate.max_tokens, 1500);
        assert!(estimate.estimated_cost_usd > 0.0);
    }

    #[test]
    fn test_generate_report() {
        let tracker = CostTracker::new();

        tracker
            .record_usage("claude-sonnet-4-6", 1000, 500)
            .unwrap();

        let report = tracker.generate_report();
        assert_eq!(report.total_input_tokens, 1000);
        assert_eq!(report.total_output_tokens, 500);
        assert!(!report.timestamp.is_empty());
    }

    #[test]
    fn test_budget_remaining_calculation() {
        let tracker = CostTracker::new();
        tracker.set_budget_limit(1.0); // $1

        tracker
            .record_usage("claude-sonnet-4-6", 1000, 500)
            .unwrap();

        let snapshot = tracker.get_current_usage();
        assert!(snapshot.budget_remaining.is_some());
        let remaining = snapshot.budget_remaining.unwrap();

        // Some budget was spent, but most of it is still available.
        assert!(remaining < 1.0);
        assert!(remaining > 0.9);
    }
}