ai_agent/services/session_memory/
session_memory_utils.rs1use std::sync::Mutex;
5use std::time::{Duration, Instant};
6
/// Upper bound on how long [`wait_for_session_memory_extraction`] keeps
/// polling before it gives up and returns.
const EXTRACTION_WAIT_TIMEOUT: Duration = Duration::from_secs(15);

/// Age after which a recorded extraction start is treated as stale;
/// [`wait_for_session_memory_extraction`] stops waiting on an extraction that
/// started longer ago than this.
const EXTRACTION_STALE_THRESHOLD: Duration = Duration::from_secs(60);
12
13#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
15pub struct SessionMemoryConfig {
16 pub minimum_message_tokens_to_init: u64,
18 pub minimum_tokens_between_update: u64,
20 pub tool_calls_between_updates: u64,
22}
23
24pub const DEFAULT_SESSION_MEMORY_CONFIG: SessionMemoryConfig = SessionMemoryConfig {
26 minimum_message_tokens_to_init: 10_000,
27 minimum_tokens_between_update: 5_000,
28 tool_calls_between_updates: 3,
29};
30
31impl Default for SessionMemoryConfig {
32 fn default() -> Self {
33 Self {
34 minimum_message_tokens_to_init: 10_000,
35 minimum_tokens_between_update: 5_000,
36 tool_calls_between_updates: 3,
37 }
38 }
39}
40
41struct SessionMemoryState {
42 config: SessionMemoryConfig,
43 last_summarized_message_id: Option<String>,
45 extraction_started_at: Option<Instant>,
47 tokens_at_last_extraction: u64,
49 session_memory_initialized: bool,
51 memory_path: Option<String>,
53}
54
55impl Default for SessionMemoryState {
56 fn default() -> Self {
57 Self {
58 config: SessionMemoryConfig::default(),
59 last_summarized_message_id: None,
60 extraction_started_at: None,
61 tokens_at_last_extraction: 0,
62 session_memory_initialized: false,
63 memory_path: None,
64 }
65 }
66}
67
68static STATE: std::sync::LazyLock<Mutex<SessionMemoryState>> =
69 std::sync::LazyLock::new(|| Mutex::new(SessionMemoryState::default()));
70
71pub fn get_last_summarized_message_id() -> Option<String> {
73 STATE.lock().unwrap().last_summarized_message_id.clone()
74}
75
76pub fn set_last_summarized_message_id(id: Option<&str>) {
78 let mut state = STATE.lock().unwrap();
79 state.last_summarized_message_id = id.map(str::to_string);
80}
81
82pub fn mark_extraction_started() {
84 STATE.lock().unwrap().extraction_started_at = Some(Instant::now());
85}
86
87pub fn mark_extraction_completed() {
89 STATE.lock().unwrap().extraction_started_at = None;
90}
91
92pub async fn wait_for_session_memory_extraction() {
95 let start = Instant::now();
96 loop {
97 let started = { STATE.lock().unwrap().extraction_started_at };
98 match started {
99 None => return,
100 Some(t) if t.elapsed() > EXTRACTION_STALE_THRESHOLD => return,
101 _ => {}
102 }
103 if start.elapsed() > EXTRACTION_WAIT_TIMEOUT {
104 return;
105 }
106 tokio::time::sleep(Duration::from_millis(1000)).await;
107 }
108}
109
110pub fn set_session_memory_config(partial: SessionMemoryConfig) {
112 let mut state = STATE.lock().unwrap();
113 if partial.minimum_message_tokens_to_init > 0 {
114 state.config.minimum_message_tokens_to_init = partial.minimum_message_tokens_to_init;
115 }
116 if partial.minimum_tokens_between_update > 0 {
117 state.config.minimum_tokens_between_update = partial.minimum_tokens_between_update;
118 }
119 if partial.tool_calls_between_updates > 0 {
120 state.config.tool_calls_between_updates = partial.tool_calls_between_updates;
121 }
122}
123
124pub fn get_session_memory_config() -> SessionMemoryConfig {
126 STATE.lock().unwrap().config.clone()
127}
128
129pub fn record_extraction_token_count(current_token_count: u64) {
132 STATE.lock().unwrap().tokens_at_last_extraction = current_token_count;
133}
134
135pub fn is_session_memory_initialized() -> bool {
137 STATE.lock().unwrap().session_memory_initialized
138}
139
140pub fn mark_session_memory_initialized() {
142 STATE.lock().unwrap().session_memory_initialized = true;
143}
144
145pub fn has_met_initialization_threshold(current_token_count: u64) -> bool {
148 let state = STATE.lock().unwrap();
149 current_token_count >= state.config.minimum_message_tokens_to_init
150}
151
152pub fn has_met_update_threshold(current_token_count: u64) -> bool {
155 let state = STATE.lock().unwrap();
156 let tokens_since = current_token_count.saturating_sub(state.tokens_at_last_extraction);
157 tokens_since >= state.config.minimum_tokens_between_update
158}
159
160pub fn get_tool_calls_between_updates() -> u64 {
162 STATE.lock().unwrap().config.tool_calls_between_updates
163}
164
165pub fn get_session_memory_path() -> Option<String> {
167 STATE.lock().unwrap().memory_path.clone()
168}
169
170pub fn set_session_memory_path(path: String) {
172 STATE.lock().unwrap().memory_path = Some(path);
173}
174
175pub fn reset_session_memory_state() {
177 *STATE.lock().unwrap() = SessionMemoryState::default();
178}
179
180pub fn get_extraction_wait_timeout() -> Duration {
182 EXTRACTION_WAIT_TIMEOUT
183}
184
185pub fn get_extraction_stale_threshold() -> Duration {
187 EXTRACTION_STALE_THRESHOLD
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes the tests in this module. They all reset and mutate the
    /// process-wide `STATE`, and the test harness runs `#[test]` functions on
    /// parallel threads by default, so without this lock one test's reset can
    /// land in the middle of another test's assertions.
    static TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires the test lock and resets the global state. Keep the returned
    /// guard alive for the whole test.
    fn isolated_state() -> std::sync::MutexGuard<'static, ()> {
        // A failed assertion in another test poisons the lock; the `()`
        // payload cannot be left inconsistent, so recovering is safe.
        let guard = TEST_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        reset_session_memory_state();
        guard
    }

    #[test]
    fn test_default_config() {
        let _guard = isolated_state();
        let config = get_session_memory_config();
        assert_eq!(config.minimum_message_tokens_to_init, 10_000);
        assert_eq!(config.minimum_tokens_between_update, 5_000);
        assert_eq!(config.tool_calls_between_updates, 3);
    }

    #[test]
    fn test_partial_config_ignores_zero_fields() {
        let _guard = isolated_state();
        set_session_memory_config(SessionMemoryConfig {
            minimum_message_tokens_to_init: 0,
            minimum_tokens_between_update: 7_000,
            tool_calls_between_updates: 0,
        });
        let config = get_session_memory_config();
        // Zero fields mean "leave unchanged"; only the non-zero one applies.
        assert_eq!(config.minimum_message_tokens_to_init, 10_000);
        assert_eq!(config.minimum_tokens_between_update, 7_000);
        assert_eq!(config.tool_calls_between_updates, 3);
    }

    #[test]
    fn test_initialization_threshold() {
        let _guard = isolated_state();
        assert!(!is_session_memory_initialized());
        assert!(!has_met_initialization_threshold(5_000));
        assert!(has_met_initialization_threshold(10_000));
    }

    #[test]
    fn test_update_threshold() {
        let _guard = isolated_state();
        record_extraction_token_count(10_000);
        assert!(!has_met_update_threshold(12_000));
        assert!(has_met_update_threshold(15_000));
    }

    #[test]
    fn test_extraction_tracking() {
        let _guard = isolated_state();
        assert!(get_last_summarized_message_id().is_none());
        set_last_summarized_message_id(Some("msg_123"));
        assert_eq!(get_last_summarized_message_id(), Some("msg_123".to_string()));
    }

    #[test]
    fn test_mark_extraction() {
        let _guard = isolated_state();

        mark_extraction_started();
        // Copy the field out first so the state lock is released before the
        // assertion can panic.
        let started_at = STATE.lock().unwrap().extraction_started_at;
        assert!(started_at.is_some());

        mark_extraction_completed();
        let started_at = STATE.lock().unwrap().extraction_started_at;
        assert!(started_at.is_none());
    }
}