Skip to main content

ai_agents_memory/
token_budget.rs

//! Token budget management for memory: the budget configuration and a record of current token usage.
2
3use serde::{Deserialize, Serialize};
4
5// !!
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct MemoryTokenBudget {
9    #[serde(default = "default_total_budget")]
10    pub total: u32,
11
12    #[serde(default)]
13    pub allocation: TokenAllocation,
14
15    #[serde(default)]
16    pub overflow_strategy: OverflowStrategy,
17
18    #[serde(default = "default_warn_percent")]
19    pub warn_at_percent: u8,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct TokenAllocation {
24    #[serde(default = "default_summary_tokens")]
25    pub summary: u32,
26
27    #[serde(default = "default_recent_tokens")]
28    pub recent_messages: u32,
29
30    #[serde(default = "default_facts_tokens")]
31    pub facts: u32,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
35#[serde(rename_all = "snake_case")]
36pub enum OverflowStrategy {
37    #[default]
38    TruncateOldest,
39    SummarizeMore,
40    Error,
41}
42
/// Serde fallback for `MemoryTokenBudget::total` when the field is absent.
fn default_total_budget() -> u32 {
    const TOTAL_TOKENS: u32 = 4096;
    TOTAL_TOKENS
}
46
/// Serde fallback for `TokenAllocation::summary` when the field is absent.
fn default_summary_tokens() -> u32 {
    const SUMMARY_TOKENS: u32 = 1024;
    SUMMARY_TOKENS
}
50
/// Serde fallback for `TokenAllocation::recent_messages` when the field is absent.
fn default_recent_tokens() -> u32 {
    const RECENT_TOKENS: u32 = 2048;
    RECENT_TOKENS
}
54
/// Serde fallback for `TokenAllocation::facts` when the field is absent.
fn default_facts_tokens() -> u32 {
    const FACTS_TOKENS: u32 = 512;
    FACTS_TOKENS
}
58
/// Serde fallback for `MemoryTokenBudget::warn_at_percent` when the field is absent.
fn default_warn_percent() -> u8 {
    const WARN_PERCENT: u8 = 80;
    WARN_PERCENT
}
62
63impl Default for TokenAllocation {
64    fn default() -> Self {
65        Self {
66            summary: default_summary_tokens(),
67            recent_messages: default_recent_tokens(),
68            facts: default_facts_tokens(),
69        }
70    }
71}
72
73impl Default for MemoryTokenBudget {
74    fn default() -> Self {
75        Self {
76            total: default_total_budget(),
77            allocation: TokenAllocation::default(),
78            overflow_strategy: OverflowStrategy::default(),
79            warn_at_percent: default_warn_percent(),
80        }
81    }
82}
83
84impl MemoryTokenBudget {
85    pub fn new(total: u32) -> Self {
86        Self {
87            total,
88            ..Default::default()
89        }
90    }
91
92    pub fn with_allocation(mut self, allocation: TokenAllocation) -> Self {
93        self.allocation = allocation;
94        self
95    }
96
97    pub fn with_overflow_strategy(mut self, strategy: OverflowStrategy) -> Self {
98        self.overflow_strategy = strategy;
99        self
100    }
101
102    pub fn with_warn_at_percent(mut self, percent: u8) -> Self {
103        self.warn_at_percent = percent.min(100);
104        self
105    }
106
107    pub fn warn_threshold(&self) -> u32 {
108        (self.total as f64 * (self.warn_at_percent as f64 / 100.0)) as u32
109    }
110
111    pub fn is_over_warn_threshold(&self, used: u32) -> bool {
112        used >= self.warn_threshold()
113    }
114}
115
116#[derive(Debug, Clone, Serialize, Deserialize, Default)]
117pub struct MemoryBudgetState {
118    pub total_tokens_used: u32,
119    pub summary_tokens: u32,
120    pub recent_tokens: u32,
121    pub facts_tokens: u32,
122    pub last_warning_at: Option<chrono::DateTime<chrono::Utc>>,
123}
124
125impl MemoryBudgetState {
126    pub fn new() -> Self {
127        Self::default()
128    }
129
130    pub fn usage_percent(&self, budget: &MemoryTokenBudget) -> f64 {
131        if budget.total == 0 {
132            return 0.0;
133        }
134        (self.total_tokens_used as f64 / budget.total as f64) * 100.0
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` must agree with the serde fallback values.
    #[test]
    fn test_memory_token_budget_default() {
        let budget = MemoryTokenBudget::default();
        assert_eq!(budget.total, 4096);
        assert_eq!(budget.allocation.summary, 1024);
        assert_eq!(budget.allocation.recent_messages, 2048);
        assert_eq!(budget.allocation.facts, 512);
        assert_eq!(budget.overflow_strategy, OverflowStrategy::TruncateOldest);
        assert_eq!(budget.warn_at_percent, 80);
    }

    /// The threshold is a percentage of the total and the comparison is inclusive.
    #[test]
    fn test_warn_threshold() {
        let budget = MemoryTokenBudget::new(1000).with_warn_at_percent(75);
        assert_eq!(budget.warn_threshold(), 750);
        assert!(!budget.is_over_warn_threshold(700));
        assert!(budget.is_over_warn_threshold(750));
        assert!(budget.is_over_warn_threshold(800));
    }

    /// The builder clamps percentages above 100.
    #[test]
    fn test_warn_at_percent_is_clamped() {
        let budget = MemoryTokenBudget::new(1000).with_warn_at_percent(150);
        assert_eq!(budget.warn_at_percent, 100);
        assert_eq!(budget.warn_threshold(), 1000);
    }

    #[test]
    fn test_budget_state_usage() {
        let budget = MemoryTokenBudget::new(1000);
        let mut state = MemoryBudgetState::new();
        state.total_tokens_used = 500;
        assert!((state.usage_percent(&budget) - 50.0).abs() < 0.01);
    }

    /// A zero budget must report 0% instead of dividing by zero.
    #[test]
    fn test_budget_state_usage_zero_total() {
        let budget = MemoryTokenBudget::new(0);
        let mut state = MemoryBudgetState::new();
        state.total_tokens_used = 10;
        assert_eq!(state.usage_percent(&budget), 0.0);
    }

    /// Every variant must deserialize from its snake_case name.
    #[test]
    fn test_overflow_strategy_deserialize() {
        let yaml = r#"truncate_oldest"#;
        let strategy: OverflowStrategy = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(strategy, OverflowStrategy::TruncateOldest);

        let yaml = r#"summarize_more"#;
        let strategy: OverflowStrategy = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(strategy, OverflowStrategy::SummarizeMore);

        let yaml = r#"error"#;
        let strategy: OverflowStrategy = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(strategy, OverflowStrategy::Error);
    }
}