Skip to main content

ai_agents_memory/
token_budget.rs

1//! Token budget management for memory
2
3use serde::{Deserialize, Serialize};
4
5// !!
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct MemoryTokenBudget {
9    #[serde(default = "default_total_budget")]
10    pub total: u32,
11
12    #[serde(default)]
13    pub allocation: TokenAllocation,
14
15    #[serde(default)]
16    pub overflow_strategy: OverflowStrategy,
17
18    #[serde(default = "default_warn_percent")]
19    pub warn_at_percent: u8,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct TokenAllocation {
24    #[serde(default = "default_summary_tokens")]
25    pub summary: u32,
26
27    #[serde(default = "default_recent_tokens")]
28    pub recent_messages: u32,
29
30    #[serde(default = "default_facts_tokens")]
31    pub facts: u32,
32
33    #[serde(default = "default_relationships_tokens")]
34    pub relationships: u32,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
38#[serde(rename_all = "snake_case")]
39pub enum OverflowStrategy {
40    #[default]
41    TruncateOldest,
42    SummarizeMore,
43    Error,
44}
45
/// Default for `MemoryTokenBudget::total`.
const DEFAULT_TOTAL_BUDGET: u32 = 4096;
/// Default for `TokenAllocation::summary`.
const DEFAULT_SUMMARY_TOKENS: u32 = 1024;
/// Default for `TokenAllocation::recent_messages`.
const DEFAULT_RECENT_TOKENS: u32 = 2048;
/// Default for `TokenAllocation::facts`.
const DEFAULT_FACTS_TOKENS: u32 = 512;
/// Default for `TokenAllocation::relationships`: nothing is allotted unless configured.
const DEFAULT_RELATIONSHIPS_TOKENS: u32 = 0;
/// Default for `MemoryTokenBudget::warn_at_percent`.
const DEFAULT_WARN_PERCENT: u8 = 80;

// Serde's `#[serde(default = "...")]` needs a function path, so each constant is
// exposed through a small function. The `Default` impls call these same
// functions, which keeps programmatic and deserialized defaults in sync.

fn default_total_budget() -> u32 {
    DEFAULT_TOTAL_BUDGET
}

fn default_summary_tokens() -> u32 {
    DEFAULT_SUMMARY_TOKENS
}

fn default_recent_tokens() -> u32 {
    DEFAULT_RECENT_TOKENS
}

fn default_facts_tokens() -> u32 {
    DEFAULT_FACTS_TOKENS
}

fn default_relationships_tokens() -> u32 {
    DEFAULT_RELATIONSHIPS_TOKENS
}

fn default_warn_percent() -> u8 {
    DEFAULT_WARN_PERCENT
}
69
70impl Default for TokenAllocation {
71    fn default() -> Self {
72        Self {
73            summary: default_summary_tokens(),
74            recent_messages: default_recent_tokens(),
75            facts: default_facts_tokens(),
76            relationships: default_relationships_tokens(),
77        }
78    }
79}
80
81impl Default for MemoryTokenBudget {
82    fn default() -> Self {
83        Self {
84            total: default_total_budget(),
85            allocation: TokenAllocation::default(),
86            overflow_strategy: OverflowStrategy::default(),
87            warn_at_percent: default_warn_percent(),
88        }
89    }
90}
91
92impl MemoryTokenBudget {
93    pub fn new(total: u32) -> Self {
94        Self {
95            total,
96            ..Default::default()
97        }
98    }
99
100    pub fn with_allocation(mut self, allocation: TokenAllocation) -> Self {
101        self.allocation = allocation;
102        self
103    }
104
105    pub fn with_overflow_strategy(mut self, strategy: OverflowStrategy) -> Self {
106        self.overflow_strategy = strategy;
107        self
108    }
109
110    pub fn with_warn_at_percent(mut self, percent: u8) -> Self {
111        self.warn_at_percent = percent.min(100);
112        self
113    }
114
115    pub fn warn_threshold(&self) -> u32 {
116        (self.total as f64 * (self.warn_at_percent as f64 / 100.0)) as u32
117    }
118
119    pub fn is_over_warn_threshold(&self, used: u32) -> bool {
120        used >= self.warn_threshold()
121    }
122}
123
124#[derive(Debug, Clone, Serialize, Deserialize, Default)]
125pub struct MemoryBudgetState {
126    pub total_tokens_used: u32,
127    pub summary_tokens: u32,
128    pub recent_tokens: u32,
129    pub facts_tokens: u32,
130    pub relationships_tokens: u32,
131    pub last_warning_at: Option<chrono::DateTime<chrono::Utc>>,
132}
133
134impl MemoryBudgetState {
135    pub fn new() -> Self {
136        Self::default()
137    }
138
139    pub fn usage_percent(&self, budget: &MemoryTokenBudget) -> f64 {
140        if budget.total == 0 {
141            return 0.0;
142        }
143        (self.total_tokens_used as f64 / budget.total as f64) * 100.0
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_token_budget_default() {
        let budget = MemoryTokenBudget::default();
        assert_eq!(budget.total, 4096);
        assert_eq!(budget.allocation.summary, 1024);
        assert_eq!(budget.allocation.recent_messages, 2048);
        assert_eq!(budget.allocation.facts, 512);
        assert_eq!(budget.allocation.relationships, 0);
        assert_eq!(budget.overflow_strategy, OverflowStrategy::TruncateOldest);
        assert_eq!(budget.warn_at_percent, 80);
    }

    #[test]
    fn test_warn_threshold() {
        let budget = MemoryTokenBudget::new(1000).with_warn_at_percent(75);
        assert_eq!(budget.warn_threshold(), 750);
        assert!(!budget.is_over_warn_threshold(700));
        assert!(budget.is_over_warn_threshold(750));
        assert!(budget.is_over_warn_threshold(800));
    }

    #[test]
    fn test_with_warn_at_percent_clamps_to_100() {
        let budget = MemoryTokenBudget::new(1000).with_warn_at_percent(150);
        assert_eq!(budget.warn_at_percent, 100);
        assert_eq!(budget.warn_threshold(), 1000);
    }

    #[test]
    fn test_budget_state_usage() {
        let budget = MemoryTokenBudget::new(1000);
        let mut state = MemoryBudgetState::new();
        state.total_tokens_used = 500;
        assert!((state.usage_percent(&budget) - 50.0).abs() < 0.01);
    }

    #[test]
    fn test_budget_state_usage_zero_total() {
        // A zero-token budget must not divide by zero.
        let budget = MemoryTokenBudget::new(0);
        let mut state = MemoryBudgetState::new();
        state.total_tokens_used = 10;
        assert!(state.usage_percent(&budget).abs() < f64::EPSILON);
    }

    #[test]
    fn test_overflow_strategy_deserialize() {
        let cases = [
            ("truncate_oldest", OverflowStrategy::TruncateOldest),
            ("summarize_more", OverflowStrategy::SummarizeMore),
            ("error", OverflowStrategy::Error),
        ];
        for (yaml, expected) in cases {
            let strategy: OverflowStrategy = serde_yaml::from_str(yaml).unwrap();
            assert_eq!(strategy, expected, "input: {}", yaml);
        }
        // Names outside the snake_case variant set are rejected.
        assert!(serde_yaml::from_str::<OverflowStrategy>("drop_newest").is_err());
    }

    #[test]
    fn test_budget_deserialize_fills_missing_fields() {
        let budget: MemoryTokenBudget = serde_yaml::from_str("{}").unwrap();
        assert_eq!(budget.total, 4096);
        assert_eq!(budget.allocation.recent_messages, 2048);
        assert_eq!(budget.overflow_strategy, OverflowStrategy::TruncateOldest);
        assert_eq!(budget.warn_at_percent, 80);

        // Partially specified nested allocation keeps per-field defaults.
        let yaml = "total: 2000\nallocation:\n  facts: 100\noverflow_strategy: error\n";
        let budget: MemoryTokenBudget = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(budget.total, 2000);
        assert_eq!(budget.allocation.facts, 100);
        assert_eq!(budget.allocation.summary, 1024);
        assert_eq!(budget.allocation.recent_messages, 2048);
        assert_eq!(budget.overflow_strategy, OverflowStrategy::Error);
        assert_eq!(budget.warn_at_percent, 80);
    }
}