Skip to main content

ai_agent/
session_memory.rs

1//! Session Memory - automatic conversation summarization
2//!
3//! Ported from ~/claudecode/openclaudecode/src/services/SessionMemory/sessionMemory.ts
4//!
5//! Session memory automatically maintains a markdown file with notes about the current conversation.
6//! It runs periodically in the background using a forked subagent to extract key information
7//! without interrupting the main conversation flow.
8
9use crate::constants::env::system;
10use crate::types::*;
11use crate::AgentError;
12use std::path::PathBuf;
13use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
14use std::sync::LazyLock;
15use std::sync::Mutex;
16
17/// Default configuration for session memory
18pub const DEFAULT_SESSION_MEMORY_CONFIG: SessionMemoryConfig = SessionMemoryConfig {
19    minimum_message_tokens_to_init: 10000,
20    minimum_tokens_between_update: 5000,
21    tool_calls_between_updates: 3,
22};
23
24/// Session memory configuration
25#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
26pub struct SessionMemoryConfig {
27    /// Minimum context window tokens before initializing session memory
28    pub minimum_message_tokens_to_init: u32,
29    /// Minimum context window growth (in tokens) between updates
30    pub minimum_tokens_between_update: u32,
31    /// Number of tool calls between session memory updates
32    pub tool_calls_between_updates: u32,
33}
34
35impl Default for SessionMemoryConfig {
36    fn default() -> Self {
37        DEFAULT_SESSION_MEMORY_CONFIG
38    }
39}
40
41/// Session memory state
42pub struct SessionMemoryState {
43    config: Mutex<SessionMemoryConfig>,
44    initialized: AtomicBool,
45    tokens_at_last_extraction: AtomicU64,
46    /// Last summarized message index (not UUID since Message lacks id field)
47    last_summarized_index: Mutex<Option<usize>>,
48    extraction_in_progress: AtomicBool,
49}
50
51impl SessionMemoryState {
52    pub fn new() -> Self {
53        Self {
54            config: Mutex::new(DEFAULT_SESSION_MEMORY_CONFIG),
55            initialized: AtomicBool::new(false),
56            tokens_at_last_extraction: AtomicU64::new(0),
57            last_summarized_index: Mutex::new(None),
58            extraction_in_progress: AtomicBool::new(false),
59        }
60    }
61
62    pub fn is_initialized(&self) -> bool {
63        self.initialized.load(Ordering::SeqCst)
64    }
65
66    pub fn mark_initialized(&self) {
67        self.initialized.store(true, Ordering::SeqCst);
68    }
69
70    pub fn get_config(&self) -> SessionMemoryConfig {
71        self.config.lock().unwrap().clone()
72    }
73
74    pub fn set_config(&self, config: SessionMemoryConfig) {
75        *self.config.lock().unwrap() = config;
76    }
77
78    pub fn get_tokens_at_last_extraction(&self) -> u64 {
79        self.tokens_at_last_extraction.load(Ordering::SeqCst)
80    }
81
82    pub fn set_tokens_at_last_extraction(&self, tokens: u64) {
83        self.tokens_at_last_extraction
84            .store(tokens, Ordering::SeqCst);
85    }
86
87    pub fn get_last_summarized_index(&self) -> Option<usize> {
88        *self.last_summarized_index.lock().unwrap()
89    }
90
91    pub fn set_last_summarized_index(&self, index: Option<usize>) {
92        *self.last_summarized_index.lock().unwrap() = index;
93    }
94
95    pub fn is_extraction_in_progress(&self) -> bool {
96        self.extraction_in_progress.load(Ordering::SeqCst)
97    }
98
99    pub fn start_extraction(&self) {
100        self.extraction_in_progress.store(true, Ordering::SeqCst);
101    }
102
103    pub fn end_extraction(&self) {
104        self.extraction_in_progress.store(false, Ordering::SeqCst);
105    }
106}
107
108impl Default for SessionMemoryState {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114/// Global session memory state
115static SESSION_MEMORY_STATE: LazyLock<SessionMemoryState> = LazyLock::new(SessionMemoryState::new);
116
117/// Get the session memory state
118pub fn get_session_memory_state() -> &'static SessionMemoryState {
119    &SESSION_MEMORY_STATE
120}
121
122/// Get the session memory directory
123pub fn get_session_memory_dir() -> PathBuf {
124    let home = std::env::var(system::HOME)
125        .or_else(|_| std::env::var(system::USERPROFILE))
126        .unwrap_or_else(|_| "/tmp".to_string());
127    PathBuf::from(home)
128        .join(".open-agent-sdk")
129        .join("session_memory")
130}
131
132/// Get the session memory file path
133pub fn get_session_memory_path() -> PathBuf {
134    get_session_memory_dir().join("notes.md")
135}
136
137/// Check if session memory has been initialized
138pub fn is_session_memory_initialized() -> bool {
139    SESSION_MEMORY_STATE.is_initialized()
140}
141
142/// Mark session memory as initialized
143pub fn mark_session_memory_initialized() {
144    SESSION_MEMORY_STATE.mark_initialized();
145}
146
147/// Get current session memory configuration
148pub fn get_session_memory_config() -> SessionMemoryConfig {
149    SESSION_MEMORY_STATE.get_config()
150}
151
152/// Set session memory configuration
153pub fn set_session_memory_config(config: SessionMemoryConfig) {
154    SESSION_MEMORY_STATE.set_config(config);
155}
156
157/// Get the last summarized message index
158pub fn get_last_summarized_message_id() -> Option<usize> {
159    SESSION_MEMORY_STATE.get_last_summarized_index()
160}
161
162/// Set the last summarized message index
163pub fn set_last_summarized_message_id(message_id: Option<usize>) {
164    SESSION_MEMORY_STATE.set_last_summarized_index(message_id);
165}
166
167/// Check if we've met the initialization threshold
168pub fn has_met_initialization_threshold(current_token_count: u64) -> bool {
169    let config = get_session_memory_config();
170    current_token_count >= config.minimum_message_tokens_to_init as u64
171}
172
173/// Check if we've met the update threshold
174pub fn has_met_update_threshold(current_token_count: u64) -> bool {
175    let config = get_session_memory_config();
176    let tokens_at_last = SESSION_MEMORY_STATE.get_tokens_at_last_extraction();
177    let tokens_since_last = current_token_count.saturating_sub(tokens_at_last);
178    tokens_since_last >= config.minimum_tokens_between_update as u64
179}
180
181/// Get tool calls between updates
182pub fn get_tool_calls_between_updates() -> u32 {
183    get_session_memory_config().tool_calls_between_updates
184}
185
186/// Record token count at extraction time
187pub fn record_extraction_token_count(token_count: u64) {
188    SESSION_MEMORY_STATE.set_tokens_at_last_extraction(token_count);
189}
190
191/// Count tool calls since a given message index
192pub fn count_tool_calls_since(messages: &[Message], since_index: Option<usize>) -> usize {
193    let mut tool_call_count = 0;
194    let start_idx = since_index.unwrap_or(0);
195
196    for (i, message) in messages.iter().enumerate() {
197        if i < start_idx {
198            continue;
199        }
200
201        if message.role == MessageRole::Assistant {
202            // Count tool calls in this message
203            // In Rust we store content as string, so we approximate
204            if message.content.contains("tool_use") || message.tool_calls.is_some() {
205                tool_call_count += 1;
206            }
207        }
208    }
209
210    tool_call_count
211}
212
213/// Check if we should extract memory based on thresholds
214pub fn should_extract_memory(messages: &[Message]) -> bool {
215    // Estimate token count
216    let current_token_count = estimate_message_tokens(messages);
217
218    // Check initialization threshold
219    if !is_session_memory_initialized() {
220        if !has_met_initialization_threshold(current_token_count) {
221            return false;
222        }
223        mark_session_memory_initialized();
224    }
225
226    // Check token threshold
227    let has_met_token_threshold = has_met_update_threshold(current_token_count);
228
229    // Check tool call threshold
230    let last_index = get_last_summarized_message_id();
231    let tool_calls_since_last = count_tool_calls_since(messages, last_index);
232    let has_met_tool_call_threshold =
233        tool_calls_since_last >= get_tool_calls_between_updates() as usize;
234
235    // Check if last assistant turn has tool calls (unsafe to extract)
236    let has_tool_calls_in_last_turn = has_tool_calls_in_last_assistant_turn(messages);
237
238    // Trigger extraction when:
239    // 1. Both thresholds are met (tokens AND tool calls), OR
240    // 2. No tool calls in last turn AND token threshold is met
241    let should_extract = (has_met_token_threshold && has_met_tool_call_threshold)
242        || (has_met_token_threshold && !has_tool_calls_in_last_turn);
243
244    if should_extract {
245        // Store the last message index
246        if !messages.is_empty() {
247            set_last_summarized_message_id(Some(messages.len() - 1));
248        }
249    }
250
251    should_extract
252}
253
254/// Check if last assistant turn has tool calls
255fn has_tool_calls_in_last_assistant_turn(messages: &[Message]) -> bool {
256    // Find last assistant message and check for tool calls
257    for message in messages.iter().rev() {
258        if message.role == MessageRole::Assistant {
259            // Check for tool calls
260            if message.tool_calls.is_some() {
261                return true;
262            }
263            // Also check content for tool_use blocks
264            if message.content.contains("tool_use") {
265                return true;
266            }
267            // Found last assistant message without tool calls
268            return false;
269        }
270    }
271    false
272}
273
274/// Estimate token count for messages
275fn estimate_message_tokens(messages: &[Message]) -> u64 {
276    // Simple estimation: ~4 characters per token
277    let total_chars: usize = messages.iter().map(|m| m.content.len()).sum();
278    (total_chars / 4) as u64
279}
280
281/// Get session memory content from file
282pub async fn get_session_memory_content() -> Result<Option<String>, AgentError> {
283    let path = get_session_memory_path();
284
285    if !path.exists() {
286        return Ok(None);
287    }
288
289    match tokio::fs::read_to_string(&path).await {
290        Ok(content) => Ok(Some(content)),
291        Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
292        Err(e) => Err(AgentError::Io(e)),
293    }
294}
295
296/// Initialize session memory file with template
297pub async fn init_session_memory_file() -> Result<String, AgentError> {
298    let dir = get_session_memory_dir();
299    let path = get_session_memory_path();
300
301    // Create directory
302    tokio::fs::create_dir_all(&dir)
303        .await
304        .map_err(AgentError::Io)?;
305
306    // Check if file already exists
307    if !path.exists() {
308        // Create with template
309        let template = get_session_memory_template();
310        tokio::fs::write(&path, template)
311            .await
312            .map_err(AgentError::Io)?;
313    }
314
315    // Return current content
316    match tokio::fs::read_to_string(&path).await {
317        Ok(content) => Ok(content),
318        Err(e) => Err(AgentError::Io(e)),
319    }
320}
321
/// Initial contents written to a brand-new session notes file.
///
/// It contains a "Session Notes" title, a one-line description, and empty
/// "Key Points", "Decisions Made", "Open Items" and "Context" sections.
fn get_session_memory_template() -> String {
    let lines = [
        "# Session Notes",
        "",
        "This file contains automatically extracted notes about the current conversation.",
        "",
        "## Key Points",
        "",
        "-",
        "",
        "## Decisions Made",
        "",
        "-",
        "",
        "## Open Items",
        "",
        "-",
        "",
        "## Context",
        "",
    ];
    let mut template = lines.join("\n");
    template.push('\n');
    template
}
345
346/// Manual extraction result
347#[derive(Debug, serde::Serialize, serde::Deserialize)]
348pub struct ManualExtractionResult {
349    pub success: bool,
350    pub memory_path: Option<String>,
351    pub error: Option<String>,
352}
353
354/// Wait for any in-progress extraction to complete
355pub async fn wait_for_session_memory_extraction() {
356    // In Rust, this would need async coordination
357    // For now, simplified implementation
358    while SESSION_MEMORY_STATE.is_extraction_in_progress() {
359        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
360    }
361}
362
363/// Reset session memory state (for testing)
364pub fn reset_session_memory_state() {
365    SESSION_MEMORY_STATE.set_config(DEFAULT_SESSION_MEMORY_CONFIG);
366    SESSION_MEMORY_STATE.set_tokens_at_last_extraction(0);
367    SESSION_MEMORY_STATE.set_last_summarized_index(None);
368    // Note: can't reset atomic bool without interior mutability
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes tests that read or write the process-wide `SESSION_MEMORY_STATE`.
    ///
    /// The test harness runs tests on parallel threads. Without this lock, one
    /// test's `reset_session_memory_state()` could run between another test's
    /// `record_extraction_token_count(5000)` and its assertions, which makes
    /// the latter flaky.
    static GLOBAL_STATE_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Take [`GLOBAL_STATE_LOCK`], ignoring poisoning so a single failing test
    /// does not make every later test fail as well.
    fn lock_global_state() -> std::sync::MutexGuard<'static, ()> {
        GLOBAL_STATE_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    #[test]
    fn test_default_config() {
        let config = DEFAULT_SESSION_MEMORY_CONFIG;
        assert_eq!(config.minimum_message_tokens_to_init, 10000);
        assert_eq!(config.minimum_tokens_between_update, 5000);
        assert_eq!(config.tool_calls_between_updates, 3);
    }

    #[test]
    fn test_session_memory_state() {
        // Uses a local instance, so the global lock is not needed.
        let state = SessionMemoryState::new();
        assert!(!state.is_initialized());

        state.mark_initialized();
        assert!(state.is_initialized());
    }

    #[test]
    fn test_has_met_initialization_threshold() {
        let _guard = lock_global_state();
        reset_session_memory_state();
        assert!(has_met_initialization_threshold(10000));
        assert!(!has_met_initialization_threshold(9999));
    }

    #[test]
    fn test_has_met_update_threshold() {
        let _guard = lock_global_state();
        reset_session_memory_state();
        record_extraction_token_count(5000);
        assert!(has_met_update_threshold(10000));
        assert!(!has_met_update_threshold(7499));
    }

    #[test]
    fn test_count_tool_calls_since_empty() {
        assert_eq!(count_tool_calls_since(&[], None), 0);
        assert_eq!(count_tool_calls_since(&[], Some(3)), 0);
    }

    #[test]
    fn test_should_extract_memory_empty_conversation() {
        let _guard = lock_global_state();
        reset_session_memory_state();
        // An empty conversation is far below every threshold.
        assert!(!should_extract_memory(&[]));
        assert_eq!(get_last_summarized_message_id(), None);
    }

    #[test]
    fn test_session_memory_path_is_notes_file() {
        let path = get_session_memory_path();
        assert_eq!(path.file_name().and_then(|n| n.to_str()), Some("notes.md"));
        assert!(path
            .parent()
            .is_some_and(|dir| dir.ends_with(".open-agent-sdk/session_memory")));
    }
}