ai_agents_memory/
token_budget.rs1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct MemoryTokenBudget {
9 #[serde(default = "default_total_budget")]
10 pub total: u32,
11
12 #[serde(default)]
13 pub allocation: TokenAllocation,
14
15 #[serde(default)]
16 pub overflow_strategy: OverflowStrategy,
17
18 #[serde(default = "default_warn_percent")]
19 pub warn_at_percent: u8,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct TokenAllocation {
24 #[serde(default = "default_summary_tokens")]
25 pub summary: u32,
26
27 #[serde(default = "default_recent_tokens")]
28 pub recent_messages: u32,
29
30 #[serde(default = "default_facts_tokens")]
31 pub facts: u32,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
35#[serde(rename_all = "snake_case")]
36pub enum OverflowStrategy {
37 #[default]
38 TruncateOldest,
39 SummarizeMore,
40 Error,
41}
42
/// Fallback for [`MemoryTokenBudget::total`].
const DEFAULT_TOTAL_BUDGET: u32 = 4096;
/// Fallback for [`TokenAllocation::summary`].
const DEFAULT_SUMMARY_TOKENS: u32 = 1024;
/// Fallback for [`TokenAllocation::recent_messages`].
const DEFAULT_RECENT_TOKENS: u32 = 2048;
/// Fallback for [`TokenAllocation::facts`].
const DEFAULT_FACTS_TOKENS: u32 = 512;
/// Fallback for [`MemoryTokenBudget::warn_at_percent`].
const DEFAULT_WARN_PERCENT: u8 = 80;

// `#[serde(default = "path")]` expects a function callable as `fn() -> T`,
// not a constant, so each constant is exposed through a thin accessor.

fn default_total_budget() -> u32 {
    DEFAULT_TOTAL_BUDGET
}

fn default_summary_tokens() -> u32 {
    DEFAULT_SUMMARY_TOKENS
}

fn default_recent_tokens() -> u32 {
    DEFAULT_RECENT_TOKENS
}

fn default_facts_tokens() -> u32 {
    DEFAULT_FACTS_TOKENS
}

fn default_warn_percent() -> u8 {
    DEFAULT_WARN_PERCENT
}
61}
62
63impl Default for TokenAllocation {
64 fn default() -> Self {
65 Self {
66 summary: default_summary_tokens(),
67 recent_messages: default_recent_tokens(),
68 facts: default_facts_tokens(),
69 }
70 }
71}
72
73impl Default for MemoryTokenBudget {
74 fn default() -> Self {
75 Self {
76 total: default_total_budget(),
77 allocation: TokenAllocation::default(),
78 overflow_strategy: OverflowStrategy::default(),
79 warn_at_percent: default_warn_percent(),
80 }
81 }
82}
83
84impl MemoryTokenBudget {
85 pub fn new(total: u32) -> Self {
86 Self {
87 total,
88 ..Default::default()
89 }
90 }
91
92 pub fn with_allocation(mut self, allocation: TokenAllocation) -> Self {
93 self.allocation = allocation;
94 self
95 }
96
97 pub fn with_overflow_strategy(mut self, strategy: OverflowStrategy) -> Self {
98 self.overflow_strategy = strategy;
99 self
100 }
101
102 pub fn with_warn_at_percent(mut self, percent: u8) -> Self {
103 self.warn_at_percent = percent.min(100);
104 self
105 }
106
107 pub fn warn_threshold(&self) -> u32 {
108 (self.total as f64 * (self.warn_at_percent as f64 / 100.0)) as u32
109 }
110
111 pub fn is_over_warn_threshold(&self, used: u32) -> bool {
112 used >= self.warn_threshold()
113 }
114}
115
116#[derive(Debug, Clone, Serialize, Deserialize, Default)]
117pub struct MemoryBudgetState {
118 pub total_tokens_used: u32,
119 pub summary_tokens: u32,
120 pub recent_tokens: u32,
121 pub facts_tokens: u32,
122 pub last_warning_at: Option<chrono::DateTime<chrono::Utc>>,
123}
124
125impl MemoryBudgetState {
126 pub fn new() -> Self {
127 Self::default()
128 }
129
130 pub fn usage_percent(&self, budget: &MemoryTokenBudget) -> f64 {
131 if budget.total == 0 {
132 return 0.0;
133 }
134 (self.total_tokens_used as f64 / budget.total as f64) * 100.0
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_token_budget_default() {
        let budget = MemoryTokenBudget::default();
        assert_eq!(budget.total, 4096);
        assert_eq!(budget.allocation.summary, 1024);
        assert_eq!(budget.allocation.recent_messages, 2048);
        assert_eq!(budget.allocation.facts, 512);
        assert_eq!(budget.overflow_strategy, OverflowStrategy::TruncateOldest);
        assert_eq!(budget.warn_at_percent, 80);
    }

    #[test]
    fn test_memory_token_budget_deserialize_fills_defaults() {
        // Every field (including nested allocation fields) has a serde
        // fallback, so missing keys take their default values.
        let yaml = "total: 2000\nallocation:\n  facts: 100\n";
        let budget: MemoryTokenBudget = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(budget.total, 2000);
        assert_eq!(budget.allocation.summary, 1024);
        assert_eq!(budget.allocation.recent_messages, 2048);
        assert_eq!(budget.allocation.facts, 100);
        assert_eq!(budget.overflow_strategy, OverflowStrategy::TruncateOldest);
        assert_eq!(budget.warn_at_percent, 80);

        let empty: MemoryTokenBudget = serde_yaml::from_str("{}").unwrap();
        assert_eq!(empty.total, 4096);
        assert_eq!(empty.allocation.summary, 1024);
        assert_eq!(empty.warn_at_percent, 80);
    }

    #[test]
    fn test_warn_threshold() {
        let budget = MemoryTokenBudget::new(1000).with_warn_at_percent(75);
        assert_eq!(budget.warn_threshold(), 750);
        assert!(!budget.is_over_warn_threshold(700));
        assert!(budget.is_over_warn_threshold(750));
        assert!(budget.is_over_warn_threshold(800));
    }

    #[test]
    fn test_with_warn_at_percent_clamps_to_100() {
        let budget = MemoryTokenBudget::new(1000).with_warn_at_percent(150);
        assert_eq!(budget.warn_at_percent, 100);
        assert_eq!(budget.warn_threshold(), 1000);
        assert!(!budget.is_over_warn_threshold(999));
        assert!(budget.is_over_warn_threshold(1000));
    }

    #[test]
    fn test_budget_state_usage() {
        let budget = MemoryTokenBudget::new(1000);
        let mut state = MemoryBudgetState::new();
        state.total_tokens_used = 500;
        assert!((state.usage_percent(&budget) - 50.0).abs() < 0.01);

        // Usage beyond the budget is reported above 100 %, not capped.
        state.total_tokens_used = 1500;
        assert!((state.usage_percent(&budget) - 150.0).abs() < 0.01);

        // A zero budget reports 0 % instead of dividing by zero.
        let zero = MemoryTokenBudget::new(0);
        assert_eq!(state.usage_percent(&zero), 0.0);
    }

    #[test]
    fn test_overflow_strategy_deserialize() {
        let cases = [
            ("truncate_oldest", OverflowStrategy::TruncateOldest),
            ("summarize_more", OverflowStrategy::SummarizeMore),
            ("error", OverflowStrategy::Error),
        ];
        for (yaml, expected) in &cases {
            let strategy: OverflowStrategy = serde_yaml::from_str(yaml).unwrap();
            assert_eq!(&strategy, expected);

            // Serialization must round-trip through the snake_case names.
            let serialized = serde_yaml::to_string(expected).unwrap();
            let back: OverflowStrategy = serde_yaml::from_str(&serialized).unwrap();
            assert_eq!(&back, expected);
        }
    }
}