Skip to main content

aster/ratelimit/
budget.rs

//! Budget management.
//!
//! Tracks API call costs and budget limits.
4
5use parking_lot::RwLock;
6use std::collections::HashMap;
7use std::time::Instant;
8
/// Cost tracker: a running total plus per-model and per-session breakdowns.
#[derive(Clone, Debug)]
pub struct CostTracker {
    /// Total cost.
    pub total_cost: f64,
    /// Cost keyed by model name.
    pub cost_per_model: HashMap<String, f64>,
    /// Cost keyed by session id.
    pub cost_per_session: HashMap<String, f64>,
    /// Budget limit, if one is set.
    pub budget_limit: Option<f64>,
    /// Time of the last reset.
    pub last_reset: Instant,
}
23
24impl Default for CostTracker {
25    fn default() -> Self {
26        Self {
27            total_cost: 0.0,
28            cost_per_model: HashMap::new(),
29            cost_per_session: HashMap::new(),
30            budget_limit: None,
31            last_reset: Instant::now(),
32        }
33    }
34}
35
36/// 预算管理器
37pub struct BudgetManager {
38    tracker: RwLock<CostTracker>,
39    budget_limit: RwLock<Option<f64>>,
40}
41
42impl BudgetManager {
43    /// 创建新的预算管理器
44    pub fn new(budget_limit: Option<f64>) -> Self {
45        Self {
46            tracker: RwLock::new(CostTracker {
47                budget_limit,
48                last_reset: Instant::now(),
49                ..Default::default()
50            }),
51            budget_limit: RwLock::new(budget_limit),
52        }
53    }
54
55    /// 添加成本
56    pub fn add_cost(&self, cost: f64, model: Option<&str>, session_id: Option<&str>) {
57        let mut tracker = self.tracker.write();
58        tracker.total_cost += cost;
59
60        if let Some(m) = model {
61            *tracker.cost_per_model.entry(m.to_string()).or_insert(0.0) += cost;
62        }
63
64        if let Some(s) = session_id {
65            *tracker.cost_per_session.entry(s.to_string()).or_insert(0.0) += cost;
66        }
67    }
68
69    /// 检查是否在预算内
70    pub fn is_within_budget(&self) -> bool {
71        let limit = self.budget_limit.read();
72        match *limit {
73            Some(l) => self.tracker.read().total_cost < l,
74            None => true,
75        }
76    }
77
78    /// 获取剩余预算
79    pub fn get_remaining_budget(&self) -> Option<f64> {
80        let limit = self.budget_limit.read();
81        limit.map(|l| (l - self.tracker.read().total_cost).max(0.0))
82    }
83
84    /// 获取追踪器状态
85    pub fn get_tracker(&self) -> CostTracker {
86        self.tracker.read().clone()
87    }
88
89    /// 重置追踪器
90    pub fn reset(&self) {
91        let mut tracker = self.tracker.write();
92        tracker.total_cost = 0.0;
93        tracker.cost_per_model.clear();
94        tracker.cost_per_session.clear();
95        tracker.last_reset = Instant::now();
96    }
97
98    /// 设置预算限制
99    pub fn set_budget_limit(&self, limit: Option<f64>) {
100        *self.budget_limit.write() = limit;
101        self.tracker.write().budget_limit = limit;
102    }
103
104    /// 获取总成本
105    pub fn get_total_cost(&self) -> f64 {
106        self.tracker.read().total_cost
107    }
108
109    /// 获取模型成本
110    pub fn get_model_cost(&self, model: &str) -> f64 {
111        self.tracker
112            .read()
113            .cost_per_model
114            .get(model)
115            .copied()
116            .unwrap_or(0.0)
117    }
118
119    /// 获取会话成本
120    pub fn get_session_cost(&self, session_id: &str) -> f64 {
121        self.tracker
122            .read()
123            .cost_per_session
124            .get(session_id)
125            .copied()
126            .unwrap_or(0.0)
127    }
128}
129
130impl Default for BudgetManager {
131    fn default() -> Self {
132        Self::new(None)
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// With no limit, any spend is within budget and no remaining amount is reported.
    #[test]
    fn test_budget_manager_no_limit() {
        let unlimited = BudgetManager::new(None);
        unlimited.add_cost(100.0, None, None);

        assert_eq!(unlimited.get_remaining_budget(), None);
        assert!(unlimited.is_within_budget());
    }

    /// With a limit, the remaining amount shrinks and is clamped at zero once overspent.
    #[test]
    fn test_budget_manager_with_limit() {
        let capped = BudgetManager::new(Some(100.0));

        // First charge stays under the limit.
        capped.add_cost(50.0, None, None);
        assert_eq!(capped.get_remaining_budget(), Some(50.0));
        assert!(capped.is_within_budget());

        // Second charge pushes the total past the limit.
        capped.add_cost(60.0, None, None);
        assert_eq!(capped.get_remaining_budget(), Some(0.0));
        assert!(!capped.is_within_budget());
    }

    /// Costs are accumulated in total, per model, and per session.
    #[test]
    fn test_cost_tracking() {
        let manager = BudgetManager::new(None);
        let charges = [
            (10.0, "gpt-4", "session-1"),
            (20.0, "claude-3", "session-1"),
            (15.0, "gpt-4", "session-2"),
        ];
        for (amount, model, session) in charges {
            manager.add_cost(amount, Some(model), Some(session));
        }

        assert_eq!(manager.get_total_cost(), 45.0);
        assert_eq!(manager.get_model_cost("gpt-4"), 25.0);
        assert_eq!(manager.get_model_cost("claude-3"), 20.0);
        assert_eq!(manager.get_session_cost("session-1"), 30.0);
        assert_eq!(manager.get_session_cost("session-2"), 15.0);
    }

    /// Resetting clears the total and the per-model breakdown.
    #[test]
    fn test_reset() {
        let manager = BudgetManager::new(Some(100.0));
        manager.add_cost(50.0, Some("gpt-4"), None);

        manager.reset();

        assert_eq!(manager.get_model_cost("gpt-4"), 0.0);
        assert_eq!(manager.get_total_cost(), 0.0);
    }
}