ai_agent/
session_memory.rs1use crate::constants::env::system;
10use crate::types::*;
11use crate::AgentError;
12use std::path::PathBuf;
13use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
14use std::sync::LazyLock;
15use std::sync::Mutex;
16
17pub const DEFAULT_SESSION_MEMORY_CONFIG: SessionMemoryConfig = SessionMemoryConfig {
19 minimum_message_tokens_to_init: 10000,
20 minimum_tokens_between_update: 5000,
21 tool_calls_between_updates: 3,
22};
23
24#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
26pub struct SessionMemoryConfig {
27 pub minimum_message_tokens_to_init: u32,
29 pub minimum_tokens_between_update: u32,
31 pub tool_calls_between_updates: u32,
33}
34
35impl Default for SessionMemoryConfig {
36 fn default() -> Self {
37 DEFAULT_SESSION_MEMORY_CONFIG
38 }
39}
40
41pub struct SessionMemoryState {
43 config: Mutex<SessionMemoryConfig>,
44 initialized: AtomicBool,
45 tokens_at_last_extraction: AtomicU64,
46 last_summarized_index: Mutex<Option<usize>>,
48 extraction_in_progress: AtomicBool,
49}
50
51impl SessionMemoryState {
52 pub fn new() -> Self {
53 Self {
54 config: Mutex::new(DEFAULT_SESSION_MEMORY_CONFIG),
55 initialized: AtomicBool::new(false),
56 tokens_at_last_extraction: AtomicU64::new(0),
57 last_summarized_index: Mutex::new(None),
58 extraction_in_progress: AtomicBool::new(false),
59 }
60 }
61
62 pub fn is_initialized(&self) -> bool {
63 self.initialized.load(Ordering::SeqCst)
64 }
65
66 pub fn mark_initialized(&self) {
67 self.initialized.store(true, Ordering::SeqCst);
68 }
69
70 pub fn get_config(&self) -> SessionMemoryConfig {
71 self.config.lock().unwrap().clone()
72 }
73
74 pub fn set_config(&self, config: SessionMemoryConfig) {
75 *self.config.lock().unwrap() = config;
76 }
77
78 pub fn get_tokens_at_last_extraction(&self) -> u64 {
79 self.tokens_at_last_extraction.load(Ordering::SeqCst)
80 }
81
82 pub fn set_tokens_at_last_extraction(&self, tokens: u64) {
83 self.tokens_at_last_extraction
84 .store(tokens, Ordering::SeqCst);
85 }
86
87 pub fn get_last_summarized_index(&self) -> Option<usize> {
88 *self.last_summarized_index.lock().unwrap()
89 }
90
91 pub fn set_last_summarized_index(&self, index: Option<usize>) {
92 *self.last_summarized_index.lock().unwrap() = index;
93 }
94
95 pub fn is_extraction_in_progress(&self) -> bool {
96 self.extraction_in_progress.load(Ordering::SeqCst)
97 }
98
99 pub fn start_extraction(&self) {
100 self.extraction_in_progress.store(true, Ordering::SeqCst);
101 }
102
103 pub fn end_extraction(&self) {
104 self.extraction_in_progress.store(false, Ordering::SeqCst);
105 }
106}
107
108impl Default for SessionMemoryState {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114static SESSION_MEMORY_STATE: LazyLock<SessionMemoryState> = LazyLock::new(SessionMemoryState::new);
116
117pub fn get_session_memory_state() -> &'static SessionMemoryState {
119 &SESSION_MEMORY_STATE
120}
121
122pub fn get_session_memory_dir() -> PathBuf {
124 let home = std::env::var(system::HOME)
125 .or_else(|_| std::env::var(system::USERPROFILE))
126 .unwrap_or_else(|_| "/tmp".to_string());
127 PathBuf::from(home)
128 .join(".open-agent-sdk")
129 .join("session_memory")
130}
131
132pub fn get_session_memory_path() -> PathBuf {
134 get_session_memory_dir().join("notes.md")
135}
136
137pub fn is_session_memory_initialized() -> bool {
139 SESSION_MEMORY_STATE.is_initialized()
140}
141
142pub fn mark_session_memory_initialized() {
144 SESSION_MEMORY_STATE.mark_initialized();
145}
146
147pub fn get_session_memory_config() -> SessionMemoryConfig {
149 SESSION_MEMORY_STATE.get_config()
150}
151
152pub fn set_session_memory_config(config: SessionMemoryConfig) {
154 SESSION_MEMORY_STATE.set_config(config);
155}
156
157pub fn get_last_summarized_message_id() -> Option<usize> {
159 SESSION_MEMORY_STATE.get_last_summarized_index()
160}
161
162pub fn set_last_summarized_message_id(message_id: Option<usize>) {
164 SESSION_MEMORY_STATE.set_last_summarized_index(message_id);
165}
166
167pub fn has_met_initialization_threshold(current_token_count: u64) -> bool {
169 let config = get_session_memory_config();
170 current_token_count >= config.minimum_message_tokens_to_init as u64
171}
172
173pub fn has_met_update_threshold(current_token_count: u64) -> bool {
175 let config = get_session_memory_config();
176 let tokens_at_last = SESSION_MEMORY_STATE.get_tokens_at_last_extraction();
177 let tokens_since_last = current_token_count.saturating_sub(tokens_at_last);
178 tokens_since_last >= config.minimum_tokens_between_update as u64
179}
180
181pub fn get_tool_calls_between_updates() -> u32 {
183 get_session_memory_config().tool_calls_between_updates
184}
185
186pub fn record_extraction_token_count(token_count: u64) {
188 SESSION_MEMORY_STATE.set_tokens_at_last_extraction(token_count);
189}
190
191pub fn count_tool_calls_since(messages: &[Message], since_index: Option<usize>) -> usize {
193 let mut tool_call_count = 0;
194 let start_idx = since_index.unwrap_or(0);
195
196 for (i, message) in messages.iter().enumerate() {
197 if i < start_idx {
198 continue;
199 }
200
201 if message.role == MessageRole::Assistant {
202 if message.content.contains("tool_use") || message.tool_calls.is_some() {
205 tool_call_count += 1;
206 }
207 }
208 }
209
210 tool_call_count
211}
212
213pub fn should_extract_memory(messages: &[Message]) -> bool {
215 let current_token_count = estimate_message_tokens(messages);
217
218 if !is_session_memory_initialized() {
220 if !has_met_initialization_threshold(current_token_count) {
221 return false;
222 }
223 mark_session_memory_initialized();
224 }
225
226 let has_met_token_threshold = has_met_update_threshold(current_token_count);
228
229 let last_index = get_last_summarized_message_id();
231 let tool_calls_since_last = count_tool_calls_since(messages, last_index);
232 let has_met_tool_call_threshold =
233 tool_calls_since_last >= get_tool_calls_between_updates() as usize;
234
235 let has_tool_calls_in_last_turn = has_tool_calls_in_last_assistant_turn(messages);
237
238 let should_extract = (has_met_token_threshold && has_met_tool_call_threshold)
242 || (has_met_token_threshold && !has_tool_calls_in_last_turn);
243
244 if should_extract {
245 if !messages.is_empty() {
247 set_last_summarized_message_id(Some(messages.len() - 1));
248 }
249 }
250
251 should_extract
252}
253
254fn has_tool_calls_in_last_assistant_turn(messages: &[Message]) -> bool {
256 for message in messages.iter().rev() {
258 if message.role == MessageRole::Assistant {
259 if message.tool_calls.is_some() {
261 return true;
262 }
263 if message.content.contains("tool_use") {
265 return true;
266 }
267 return false;
269 }
270 }
271 false
272}
273
274fn estimate_message_tokens(messages: &[Message]) -> u64 {
276 let total_chars: usize = messages.iter().map(|m| m.content.len()).sum();
278 (total_chars / 4) as u64
279}
280
281pub async fn get_session_memory_content() -> Result<Option<String>, AgentError> {
283 let path = get_session_memory_path();
284
285 if !path.exists() {
286 return Ok(None);
287 }
288
289 match tokio::fs::read_to_string(&path).await {
290 Ok(content) => Ok(Some(content)),
291 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
292 Err(e) => Err(AgentError::Io(e)),
293 }
294}
295
296pub async fn init_session_memory_file() -> Result<String, AgentError> {
298 let dir = get_session_memory_dir();
299 let path = get_session_memory_path();
300
301 tokio::fs::create_dir_all(&dir)
303 .await
304 .map_err(AgentError::Io)?;
305
306 if !path.exists() {
308 let template = get_session_memory_template();
310 tokio::fs::write(&path, template)
311 .await
312 .map_err(AgentError::Io)?;
313 }
314
315 match tokio::fs::read_to_string(&path).await {
317 Ok(content) => Ok(content),
318 Err(e) => Err(AgentError::Io(e)),
319 }
320}
321
/// Initial Markdown skeleton written to a newly created session notes file.
fn get_session_memory_template() -> String {
    const TEMPLATE: &str = r#"# Session Notes

This file contains automatically extracted notes about the current conversation.

## Key Points

-

## Decisions Made

-

## Open Items

-

## Context

"#;
    String::from(TEMPLATE)
}
345
346#[derive(Debug, serde::Serialize, serde::Deserialize)]
348pub struct ManualExtractionResult {
349 pub success: bool,
350 pub memory_path: Option<String>,
351 pub error: Option<String>,
352}
353
354pub async fn wait_for_session_memory_extraction() {
356 while SESSION_MEMORY_STATE.is_extraction_in_progress() {
359 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
360 }
361}
362
363pub fn reset_session_memory_state() {
365 SESSION_MEMORY_STATE.set_config(DEFAULT_SESSION_MEMORY_CONFIG);
366 SESSION_MEMORY_STATE.set_tokens_at_last_extraction(0);
367 SESSION_MEMORY_STATE.set_last_summarized_index(None);
368 }
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes tests that touch the process-wide `SESSION_MEMORY_STATE`.
    ///
    /// The test harness runs tests on parallel threads. Without this lock, a
    /// `reset_session_memory_state()` in one test can zero the token baseline
    /// that another test just recorded, making the update-threshold test flaky.
    static GLOBAL_STATE_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires [`GLOBAL_STATE_LOCK`], tolerating poisoning left by a failed test.
    fn lock_global_state() -> std::sync::MutexGuard<'static, ()> {
        GLOBAL_STATE_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    #[test]
    fn test_default_config() {
        let config = DEFAULT_SESSION_MEMORY_CONFIG;
        assert_eq!(config.minimum_message_tokens_to_init, 10000);
        assert_eq!(config.minimum_tokens_between_update, 5000);
        assert_eq!(config.tool_calls_between_updates, 3);
    }

    #[test]
    fn test_session_memory_state() {
        let state = SessionMemoryState::new();
        assert!(!state.is_initialized());

        state.mark_initialized();
        assert!(state.is_initialized());
    }

    #[test]
    fn test_session_memory_state_accessors() {
        // Uses a local instance, so no global lock is needed.
        let state = SessionMemoryState::new();

        assert_eq!(state.get_tokens_at_last_extraction(), 0);
        state.set_tokens_at_last_extraction(42);
        assert_eq!(state.get_tokens_at_last_extraction(), 42);

        assert_eq!(state.get_last_summarized_index(), None);
        state.set_last_summarized_index(Some(7));
        assert_eq!(state.get_last_summarized_index(), Some(7));

        assert!(!state.is_extraction_in_progress());
        state.start_extraction();
        assert!(state.is_extraction_in_progress());
        state.end_extraction();
        assert!(!state.is_extraction_in_progress());
    }

    #[test]
    fn test_template_has_expected_sections() {
        let template = get_session_memory_template();
        assert!(template.starts_with("# Session Notes"));
        assert!(template.contains("## Key Points"));
        assert!(template.contains("## Decisions Made"));
        assert!(template.contains("## Open Items"));
        assert!(template.contains("## Context"));
    }

    #[test]
    fn test_has_met_initialization_threshold() {
        let _guard = lock_global_state();
        reset_session_memory_state();
        assert!(has_met_initialization_threshold(10000));
        assert!(!has_met_initialization_threshold(9999));
    }

    #[test]
    fn test_has_met_update_threshold() {
        let _guard = lock_global_state();
        reset_session_memory_state();
        record_extraction_token_count(5000);
        assert!(has_met_update_threshold(10000));
        assert!(!has_met_update_threshold(7499));
    }
}