Skip to main content

lean_ctx/core/
agent_budget.rs

1use std::collections::HashMap;
2use std::sync::Mutex;
3
4use serde::{Deserialize, Serialize};
5
6static BUDGETS: Mutex<Option<HashMap<String, AgentBudget>>> = Mutex::new(None);
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct AgentBudget {
10    pub agent_id: String,
11    pub token_limit: usize,
12    pub tokens_consumed: usize,
13    pub reads_count: u32,
14    pub last_reset: String,
15}
16
/// Outcome of a budget check for a prospective token spend.
#[derive(Clone, Debug, PartialEq)]
pub enum BudgetCheckResult {
    /// The spend fits. `remaining` is the headroom left after it
    /// (`usize::MAX` when the agent is unlimited).
    Allowed { remaining: usize },
    /// The spend would overrun `limit`. `consumed` is the usage recorded
    /// before this request, not including it.
    Exceeded { limit: usize, consumed: usize },
    /// The spend fits but reaches the warning threshold. Despite its name,
    /// `percent_used` is a fraction of the limit (e.g. `0.85`), not 0-100.
    Warning { remaining: usize, percent_used: f32 },
}
23
/// Fraction of the token limit (measured after the pending request) at or
/// above which a budget check reports `Warning` rather than `Allowed`.
const WARNING_THRESHOLD: f32 = 0.8;
25
26fn with_budgets<F, R>(f: F) -> R
27where
28    F: FnOnce(&mut HashMap<String, AgentBudget>) -> R,
29{
30    let mut guard = BUDGETS
31        .lock()
32        .unwrap_or_else(std::sync::PoisonError::into_inner);
33    let map = guard.get_or_insert_with(HashMap::new);
34    f(map)
35}
36
37fn ensure_entry<'a>(
38    map: &'a mut HashMap<String, AgentBudget>,
39    agent_id: &str,
40) -> &'a mut AgentBudget {
41    map.entry(agent_id.to_string())
42        .or_insert_with(|| AgentBudget {
43            agent_id: agent_id.to_string(),
44            token_limit: usize::MAX,
45            tokens_consumed: 0,
46            reads_count: 0,
47            last_reset: chrono::Utc::now().to_rfc3339(),
48        })
49}
50
51pub fn check_budget(agent_id: &str, tokens_to_consume: usize) -> BudgetCheckResult {
52    with_budgets(|map| {
53        let budget = ensure_entry(map, agent_id);
54        if budget.token_limit == usize::MAX || budget.token_limit == 0 {
55            return BudgetCheckResult::Allowed {
56                remaining: usize::MAX,
57            };
58        }
59
60        let projected = budget.tokens_consumed.saturating_add(tokens_to_consume);
61        if projected > budget.token_limit {
62            return BudgetCheckResult::Exceeded {
63                limit: budget.token_limit,
64                consumed: budget.tokens_consumed,
65            };
66        }
67
68        let percent_used = projected as f32 / budget.token_limit as f32;
69        let remaining = budget.token_limit.saturating_sub(projected);
70
71        if percent_used >= WARNING_THRESHOLD {
72            BudgetCheckResult::Warning {
73                remaining,
74                percent_used,
75            }
76        } else {
77            BudgetCheckResult::Allowed { remaining }
78        }
79    })
80}
81
82pub fn record_consumption(agent_id: &str, tokens: usize) {
83    with_budgets(|map| {
84        let budget = ensure_entry(map, agent_id);
85        budget.tokens_consumed = budget.tokens_consumed.saturating_add(tokens);
86        budget.reads_count += 1;
87    });
88}
89
90pub fn get_status(agent_id: &str) -> AgentBudget {
91    with_budgets(|map| ensure_entry(map, agent_id).clone())
92}
93
94pub fn reset(agent_id: &str) {
95    with_budgets(|map| {
96        let budget = ensure_entry(map, agent_id);
97        budget.tokens_consumed = 0;
98        budget.reads_count = 0;
99        budget.last_reset = chrono::Utc::now().to_rfc3339();
100    });
101}
102
103pub fn set_limit(agent_id: &str, limit: usize) {
104    with_budgets(|map| {
105        let budget = ensure_entry(map, agent_id);
106        budget.token_limit = if limit == 0 { usize::MAX } else { limit };
107    });
108}
109
110pub fn init_from_config() {
111    let cfg_limit = crate::core::config::Config::load().agent_token_budget;
112    if cfg_limit > 0 {
113        with_budgets(|map| {
114            for budget in map.values_mut() {
115                if budget.token_limit == usize::MAX {
116                    budget.token_limit = cfg_limit;
117                }
118            }
119        });
120    }
121}
122
123pub fn default_limit_from_config() -> usize {
124    let cfg_limit = crate::core::config::Config::load().agent_token_budget;
125    if cfg_limit == 0 {
126        usize::MAX
127    } else {
128        cfg_limit
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an agent id from the test's `name` and the current thread id,
    /// so concurrently running tests do not collide in the shared global table.
    fn test_agent(name: &str) -> String {
        format!("test_agent_{name}_{:?}", std::thread::current().id())
    }

    #[test]
    fn unlimited_budget_always_allows() {
        let id = test_agent("unlimited");
        let result = check_budget(&id, 1_000_000);
        assert_eq!(
            result,
            BudgetCheckResult::Allowed {
                remaining: usize::MAX
            }
        );
    }

    #[test]
    fn set_limit_and_exceed() {
        let id = test_agent("exceed");
        set_limit(&id, 1000);
        record_consumption(&id, 800);
        let result = check_budget(&id, 300);
        assert!(matches!(
            result,
            BudgetCheckResult::Exceeded {
                limit: 1000,
                consumed: 800
            }
        ));
    }

    #[test]
    fn warning_at_80_percent() {
        let id = test_agent("warning");
        set_limit(&id, 1000);
        record_consumption(&id, 700);
        let result = check_budget(&id, 100);
        assert!(matches!(result, BudgetCheckResult::Warning { .. }));
    }

    #[test]
    fn warning_reports_remaining_and_fraction() {
        let id = test_agent("warning_fields");
        set_limit(&id, 1000);
        record_consumption(&id, 700);
        match check_budget(&id, 150) {
            BudgetCheckResult::Warning {
                remaining,
                percent_used,
            } => {
                assert_eq!(remaining, 150);
                assert!((percent_used - 0.85).abs() < 1e-6);
            }
            other => panic!("expected Warning, got {other:?}"),
        }
    }

    #[test]
    fn allowed_below_warning_reports_remaining() {
        let id = test_agent("allowed");
        set_limit(&id, 1000);
        record_consumption(&id, 100);
        assert_eq!(
            check_budget(&id, 100),
            BudgetCheckResult::Allowed { remaining: 800 }
        );
    }

    #[test]
    fn check_budget_does_not_record_consumption() {
        let id = test_agent("check_only");
        set_limit(&id, 1000);
        let _ = check_budget(&id, 500);
        let status = get_status(&id);
        assert_eq!(status.tokens_consumed, 0);
        assert_eq!(status.reads_count, 0);
    }

    #[test]
    fn reset_clears_consumption() {
        let id = test_agent("reset");
        set_limit(&id, 1000);
        record_consumption(&id, 900);
        reset(&id);
        let status = get_status(&id);
        assert_eq!(status.tokens_consumed, 0);
        assert_eq!(status.reads_count, 0);
    }

    #[test]
    fn zero_limit_means_unlimited() {
        let id = test_agent("zero");
        set_limit(&id, 0);
        let result = check_budget(&id, 1_000_000);
        assert!(matches!(result, BudgetCheckResult::Allowed { .. }));
    }

    #[test]
    fn zero_limit_is_stored_as_unlimited() {
        let id = test_agent("zero_stored");
        set_limit(&id, 0);
        assert_eq!(get_status(&id).token_limit, usize::MAX);
    }

    #[test]
    fn record_increments_reads_count() {
        let id = test_agent("reads");
        record_consumption(&id, 100);
        record_consumption(&id, 200);
        let status = get_status(&id);
        assert_eq!(status.reads_count, 2);
        assert_eq!(status.tokens_consumed, 300);
    }
}