aster/ratelimit/
budget.rs1use parking_lot::RwLock;
6use std::collections::HashMap;
7use std::time::Instant;
8
/// Running tally of recorded spend, with per-model and per-session breakdowns.
///
/// The struct derives `Clone`, so a copy is an independent snapshot of the
/// totals at the moment it was taken.
#[derive(Debug, Clone)]
pub struct CostTracker {
    /// Sum of every cost recorded since creation or the last reset.
    pub total_cost: f64,
    /// Accumulated cost keyed by model name.
    pub cost_per_model: HashMap<String, f64>,
    /// Accumulated cost keyed by session identifier.
    pub cost_per_session: HashMap<String, f64>,
    /// Optional spending cap; `None` means no cap is enforced.
    pub budget_limit: Option<f64>,
    /// Instant at which the counters were created or last cleared.
    pub last_reset: Instant,
}

impl Default for CostTracker {
    /// Builds an empty tracker: zero spend, no per-model or per-session
    /// entries, no budget cap, and `last_reset` stamped with the current time.
    fn default() -> Self {
        let now = Instant::now();
        CostTracker {
            total_cost: f64::default(),
            cost_per_model: HashMap::default(),
            cost_per_session: HashMap::default(),
            budget_limit: Option::default(),
            last_reset: now,
        }
    }
}
35
36pub struct BudgetManager {
38 tracker: RwLock<CostTracker>,
39 budget_limit: RwLock<Option<f64>>,
40}
41
42impl BudgetManager {
43 pub fn new(budget_limit: Option<f64>) -> Self {
45 Self {
46 tracker: RwLock::new(CostTracker {
47 budget_limit,
48 last_reset: Instant::now(),
49 ..Default::default()
50 }),
51 budget_limit: RwLock::new(budget_limit),
52 }
53 }
54
55 pub fn add_cost(&self, cost: f64, model: Option<&str>, session_id: Option<&str>) {
57 let mut tracker = self.tracker.write();
58 tracker.total_cost += cost;
59
60 if let Some(m) = model {
61 *tracker.cost_per_model.entry(m.to_string()).or_insert(0.0) += cost;
62 }
63
64 if let Some(s) = session_id {
65 *tracker.cost_per_session.entry(s.to_string()).or_insert(0.0) += cost;
66 }
67 }
68
69 pub fn is_within_budget(&self) -> bool {
71 let limit = self.budget_limit.read();
72 match *limit {
73 Some(l) => self.tracker.read().total_cost < l,
74 None => true,
75 }
76 }
77
78 pub fn get_remaining_budget(&self) -> Option<f64> {
80 let limit = self.budget_limit.read();
81 limit.map(|l| (l - self.tracker.read().total_cost).max(0.0))
82 }
83
84 pub fn get_tracker(&self) -> CostTracker {
86 self.tracker.read().clone()
87 }
88
89 pub fn reset(&self) {
91 let mut tracker = self.tracker.write();
92 tracker.total_cost = 0.0;
93 tracker.cost_per_model.clear();
94 tracker.cost_per_session.clear();
95 tracker.last_reset = Instant::now();
96 }
97
98 pub fn set_budget_limit(&self, limit: Option<f64>) {
100 *self.budget_limit.write() = limit;
101 self.tracker.write().budget_limit = limit;
102 }
103
104 pub fn get_total_cost(&self) -> f64 {
106 self.tracker.read().total_cost
107 }
108
109 pub fn get_model_cost(&self, model: &str) -> f64 {
111 self.tracker
112 .read()
113 .cost_per_model
114 .get(model)
115 .copied()
116 .unwrap_or(0.0)
117 }
118
119 pub fn get_session_cost(&self, session_id: &str) -> f64 {
121 self.tracker
122 .read()
123 .cost_per_session
124 .get(session_id)
125 .copied()
126 .unwrap_or(0.0)
127 }
128}
129
130impl Default for BudgetManager {
131 fn default() -> Self {
132 Self::new(None)
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_budget_manager_no_limit() {
        let manager = BudgetManager::new(None);
        manager.add_cost(100.0, None, None);
        assert!(manager.is_within_budget());
        assert_eq!(manager.get_remaining_budget(), None);
    }

    #[test]
    fn test_budget_manager_with_limit() {
        let manager = BudgetManager::new(Some(100.0));
        manager.add_cost(50.0, None, None);
        assert!(manager.is_within_budget());
        assert_eq!(manager.get_remaining_budget(), Some(50.0));

        manager.add_cost(60.0, None, None);
        assert!(!manager.is_within_budget());
        assert_eq!(manager.get_remaining_budget(), Some(0.0));
    }

    /// Spending exactly the limit counts as exhausted (strict `<` comparison).
    #[test]
    fn test_budget_boundary_is_exhausted() {
        let manager = BudgetManager::new(Some(100.0));
        manager.add_cost(100.0, None, None);
        assert!(!manager.is_within_budget());
        assert_eq!(manager.get_remaining_budget(), Some(0.0));
    }

    #[test]
    fn test_cost_tracking() {
        let manager = BudgetManager::new(None);
        manager.add_cost(10.0, Some("gpt-4"), Some("session-1"));
        manager.add_cost(20.0, Some("claude-3"), Some("session-1"));
        manager.add_cost(15.0, Some("gpt-4"), Some("session-2"));

        assert_eq!(manager.get_total_cost(), 45.0);
        assert_eq!(manager.get_model_cost("gpt-4"), 25.0);
        assert_eq!(manager.get_model_cost("claude-3"), 20.0);
        assert_eq!(manager.get_session_cost("session-1"), 30.0);
        assert_eq!(manager.get_session_cost("session-2"), 15.0);
    }

    #[test]
    fn test_unknown_keys_cost_zero() {
        let manager = BudgetManager::new(None);
        assert_eq!(manager.get_model_cost("missing"), 0.0);
        assert_eq!(manager.get_session_cost("missing"), 0.0);
    }

    #[test]
    fn test_reset() {
        let manager = BudgetManager::new(Some(100.0));
        manager.add_cost(50.0, Some("gpt-4"), Some("session-1"));
        manager.reset();

        assert_eq!(manager.get_total_cost(), 0.0);
        assert_eq!(manager.get_model_cost("gpt-4"), 0.0);
        assert_eq!(manager.get_session_cost("session-1"), 0.0);
        // The budget limit survives a reset.
        assert_eq!(manager.get_remaining_budget(), Some(100.0));
        assert_eq!(manager.get_tracker().budget_limit, Some(100.0));
    }

    /// Changing the limit takes effect immediately and is visible through
    /// `get_tracker` as well as the query methods.
    #[test]
    fn test_set_budget_limit() {
        let manager = BudgetManager::new(None);
        manager.add_cost(30.0, None, None);

        manager.set_budget_limit(Some(20.0));
        assert!(!manager.is_within_budget());
        assert_eq!(manager.get_tracker().budget_limit, Some(20.0));

        manager.set_budget_limit(Some(50.0));
        assert!(manager.is_within_budget());
        assert_eq!(manager.get_remaining_budget(), Some(20.0));

        manager.set_budget_limit(None);
        assert!(manager.is_within_budget());
        assert_eq!(manager.get_remaining_budget(), None);
        assert_eq!(manager.get_tracker().budget_limit, None);
    }

    #[test]
    fn test_default_has_no_limit() {
        let manager = BudgetManager::default();
        assert!(manager.is_within_budget());
        assert_eq!(manager.get_total_cost(), 0.0);
        assert_eq!(manager.get_tracker().budget_limit, None);
    }
}