lean_ctx/core/
agent_budget.rs1use std::collections::HashMap;
2use std::sync::Mutex;
3
4use serde::{Deserialize, Serialize};
5
/// Process-wide registry of per-agent budgets, keyed by agent id.
///
/// Starts out as `None`; `with_budgets` creates the map on first access and is
/// the only place in this module that takes the lock.
static BUDGETS: Mutex<Option<HashMap<String, AgentBudget>>> = Mutex::new(None);
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct AgentBudget {
10 pub agent_id: String,
11 pub token_limit: usize,
12 pub tokens_consumed: usize,
13 pub reads_count: u32,
14 pub last_reset: String,
15}
16
/// Outcome of `check_budget` for a prospective consumption.
#[derive(Clone, Debug, PartialEq)]
pub enum BudgetCheckResult {
    /// The request fits; `remaining` is what would be left after it
    /// (`usize::MAX` when the agent has no limit).
    Allowed { remaining: usize },
    /// The request would push consumption past `limit`; `consumed` is the
    /// amount already recorded, not including the rejected request.
    Exceeded { limit: usize, consumed: usize },
    /// The request fits, but projected usage is at or above the warning
    /// threshold; `percent_used` is a fraction of the limit (e.g. `0.8`).
    Warning { remaining: usize, percent_used: f32 },
}
23
/// Fraction of the token limit at or above which `check_budget` reports
/// `Warning` instead of `Allowed` (0.80 = 80 % of the limit).
const WARNING_THRESHOLD: f32 = 0.80;
25
26fn with_budgets<F, R>(f: F) -> R
27where
28 F: FnOnce(&mut HashMap<String, AgentBudget>) -> R,
29{
30 let mut guard = BUDGETS
31 .lock()
32 .unwrap_or_else(std::sync::PoisonError::into_inner);
33 let map = guard.get_or_insert_with(HashMap::new);
34 f(map)
35}
36
37fn ensure_entry<'a>(
38 map: &'a mut HashMap<String, AgentBudget>,
39 agent_id: &str,
40) -> &'a mut AgentBudget {
41 map.entry(agent_id.to_string())
42 .or_insert_with(|| AgentBudget {
43 agent_id: agent_id.to_string(),
44 token_limit: usize::MAX,
45 tokens_consumed: 0,
46 reads_count: 0,
47 last_reset: chrono::Utc::now().to_rfc3339(),
48 })
49}
50
51pub fn check_budget(agent_id: &str, tokens_to_consume: usize) -> BudgetCheckResult {
52 with_budgets(|map| {
53 let budget = ensure_entry(map, agent_id);
54 if budget.token_limit == usize::MAX || budget.token_limit == 0 {
55 return BudgetCheckResult::Allowed {
56 remaining: usize::MAX,
57 };
58 }
59
60 let projected = budget.tokens_consumed.saturating_add(tokens_to_consume);
61 if projected > budget.token_limit {
62 return BudgetCheckResult::Exceeded {
63 limit: budget.token_limit,
64 consumed: budget.tokens_consumed,
65 };
66 }
67
68 let percent_used = projected as f32 / budget.token_limit as f32;
69 let remaining = budget.token_limit.saturating_sub(projected);
70
71 if percent_used >= WARNING_THRESHOLD {
72 BudgetCheckResult::Warning {
73 remaining,
74 percent_used,
75 }
76 } else {
77 BudgetCheckResult::Allowed { remaining }
78 }
79 })
80}
81
82pub fn record_consumption(agent_id: &str, tokens: usize) {
83 with_budgets(|map| {
84 let budget = ensure_entry(map, agent_id);
85 budget.tokens_consumed = budget.tokens_consumed.saturating_add(tokens);
86 budget.reads_count += 1;
87 });
88}
89
90pub fn get_status(agent_id: &str) -> AgentBudget {
91 with_budgets(|map| ensure_entry(map, agent_id).clone())
92}
93
94pub fn reset(agent_id: &str) {
95 with_budgets(|map| {
96 let budget = ensure_entry(map, agent_id);
97 budget.tokens_consumed = 0;
98 budget.reads_count = 0;
99 budget.last_reset = chrono::Utc::now().to_rfc3339();
100 });
101}
102
103pub fn set_limit(agent_id: &str, limit: usize) {
104 with_budgets(|map| {
105 let budget = ensure_entry(map, agent_id);
106 budget.token_limit = if limit == 0 { usize::MAX } else { limit };
107 });
108}
109
110pub fn init_from_config() {
111 let cfg_limit = crate::core::config::Config::load().agent_token_budget;
112 if cfg_limit > 0 {
113 with_budgets(|map| {
114 for budget in map.values_mut() {
115 if budget.token_limit == usize::MAX {
116 budget.token_limit = cfg_limit;
117 }
118 }
119 });
120 }
121}
122
123pub fn default_limit_from_config() -> usize {
124 let cfg_limit = crate::core::config::Config::load().agent_token_budget;
125 if cfg_limit == 0 {
126 usize::MAX
127 } else {
128 cfg_limit
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an agent id from the test's `name` and the current thread id, so
    /// tests sharing the global registry work on separate entries.
    fn test_agent(name: &str) -> String {
        let thread_id = std::thread::current().id();
        format!("test_agent_{name}_{:?}", thread_id)
    }

    // An agent nobody configured has no limit.
    #[test]
    fn unlimited_budget_always_allows() {
        let agent = test_agent("unlimited");
        assert!(matches!(
            check_budget(&agent, 1_000_000),
            BudgetCheckResult::Allowed { .. }
        ));
    }

    // 800 consumed + 300 requested > 1000 limit.
    #[test]
    fn set_limit_and_exceed() {
        let agent = test_agent("exceed");
        set_limit(&agent, 1000);
        record_consumption(&agent, 800);

        let expected = BudgetCheckResult::Exceeded {
            limit: 1000,
            consumed: 800,
        };
        assert_eq!(check_budget(&agent, 300), expected);
    }

    // 700 consumed + 100 requested lands exactly on the 80 % threshold.
    #[test]
    fn warning_at_80_percent() {
        let agent = test_agent("warning");
        set_limit(&agent, 1000);
        record_consumption(&agent, 700);
        assert!(matches!(
            check_budget(&agent, 100),
            BudgetCheckResult::Warning { .. }
        ));
    }

    #[test]
    fn reset_clears_consumption() {
        let agent = test_agent("reset");
        set_limit(&agent, 1000);
        record_consumption(&agent, 900);
        reset(&agent);

        let after = get_status(&agent);
        assert_eq!((after.tokens_consumed, after.reads_count), (0, 0));
    }

    // A limit of 0 is stored as "unlimited".
    #[test]
    fn zero_limit_means_unlimited() {
        let agent = test_agent("zero");
        set_limit(&agent, 0);
        assert!(matches!(
            check_budget(&agent, 1_000_000),
            BudgetCheckResult::Allowed { .. }
        ));
    }

    #[test]
    fn record_increments_reads_count() {
        let agent = test_agent("reads");
        for tokens in [100, 200] {
            record_consumption(&agent, tokens);
        }

        let after = get_status(&agent);
        assert_eq!((after.reads_count, after.tokens_consumed), (2, 300));
    }
}