Skip to main content

ai_agent/services/session_memory/
session_memory_utils.rs

// Source: ~/claudecode/openclaudecode/src/services/SessionMemory/sessionMemoryUtils.ts
//! Session memory utility functions — state management, thresholds, config.
3
4use std::sync::Mutex;
5use std::time::{Duration, Instant};
6
/// Longest time a caller keeps polling for an in-flight extraction before
/// giving up: 15 seconds, expressed in milliseconds.
const EXTRACTION_WAIT_TIMEOUT: Duration = Duration::from_millis(15_000);
9
/// Age beyond which a recorded extraction start is treated as stale and is no
/// longer waited on: 1 minute, expressed in milliseconds.
const EXTRACTION_STALE_THRESHOLD: Duration = Duration::from_millis(60_000);
12
13/// Configuration for session memory extraction thresholds
14#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
15pub struct SessionMemoryConfig {
16    /// Minimum context window tokens before initializing session memory
17    pub minimum_message_tokens_to_init: u64,
18    /// Minimum context window growth between session memory updates
19    pub minimum_tokens_between_update: u64,
20    /// Number of tool calls between session memory updates
21    pub tool_calls_between_updates: u64,
22}
23
24/// Default configuration for session memory (backward compat constant)
25pub const DEFAULT_SESSION_MEMORY_CONFIG: SessionMemoryConfig = SessionMemoryConfig {
26    minimum_message_tokens_to_init: 10_000,
27    minimum_tokens_between_update: 5_000,
28    tool_calls_between_updates: 3,
29};
30
31impl Default for SessionMemoryConfig {
32    fn default() -> Self {
33        Self {
34            minimum_message_tokens_to_init: 10_000,
35            minimum_tokens_between_update: 5_000,
36            tool_calls_between_updates: 3,
37        }
38    }
39}
40
41struct SessionMemoryState {
42    config: SessionMemoryConfig,
43    /// UUID of the last message summarized by session memory
44    last_summarized_message_id: Option<String>,
45    /// When the current extraction started (if any)
46    extraction_started_at: Option<Instant>,
47    /// Token count at last extraction (for measuring context growth)
48    tokens_at_last_extraction: u64,
49    /// Whether session memory has been initialized
50    session_memory_initialized: bool,
51    /// File path to the session memory notes file
52    memory_path: Option<String>,
53}
54
55impl Default for SessionMemoryState {
56    fn default() -> Self {
57        Self {
58            config: SessionMemoryConfig::default(),
59            last_summarized_message_id: None,
60            extraction_started_at: None,
61            tokens_at_last_extraction: 0,
62            session_memory_initialized: false,
63            memory_path: None,
64        }
65    }
66}
67
68static STATE: std::sync::LazyLock<Mutex<SessionMemoryState>> =
69    std::sync::LazyLock::new(|| Mutex::new(SessionMemoryState::default()));
70
71/// Get the message ID up to which the session memory is current
72pub fn get_last_summarized_message_id() -> Option<String> {
73    STATE.lock().unwrap().last_summarized_message_id.clone()
74}
75
76/// Set the last summarized message ID (called from session_memory.rs)
77pub fn set_last_summarized_message_id(id: Option<&str>) {
78    let mut state = STATE.lock().unwrap();
79    state.last_summarized_message_id = id.map(str::to_string);
80}
81
82/// Mark extraction as started (called from session_memory.rs)
83pub fn mark_extraction_started() {
84    STATE.lock().unwrap().extraction_started_at = Some(Instant::now());
85}
86
87/// Mark extraction as completed (called from session_memory.rs)
88pub fn mark_extraction_completed() {
89    STATE.lock().unwrap().extraction_started_at = None;
90}
91
92/// Wait for any in-progress session memory extraction to complete (with timeout).
93/// Returns immediately if no extraction is in progress or if extraction is stale.
94pub async fn wait_for_session_memory_extraction() {
95    let start = Instant::now();
96    loop {
97        let started = { STATE.lock().unwrap().extraction_started_at };
98        match started {
99            None => return,
100            Some(t) if t.elapsed() > EXTRACTION_STALE_THRESHOLD => return,
101            _ => {}
102        }
103        if start.elapsed() > EXTRACTION_WAIT_TIMEOUT {
104            return;
105        }
106        tokio::time::sleep(Duration::from_millis(1000)).await;
107    }
108}
109
110/// Set the session memory configuration
111pub fn set_session_memory_config(partial: SessionMemoryConfig) {
112    let mut state = STATE.lock().unwrap();
113    if partial.minimum_message_tokens_to_init > 0 {
114        state.config.minimum_message_tokens_to_init = partial.minimum_message_tokens_to_init;
115    }
116    if partial.minimum_tokens_between_update > 0 {
117        state.config.minimum_tokens_between_update = partial.minimum_tokens_between_update;
118    }
119    if partial.tool_calls_between_updates > 0 {
120        state.config.tool_calls_between_updates = partial.tool_calls_between_updates;
121    }
122}
123
124/// Get the current session memory configuration
125pub fn get_session_memory_config() -> SessionMemoryConfig {
126    STATE.lock().unwrap().config.clone()
127}
128
129/// Record the context size at the time of extraction.
130/// Used to measure context growth for minimumTokensBetweenUpdate threshold.
131pub fn record_extraction_token_count(current_token_count: u64) {
132    STATE.lock().unwrap().tokens_at_last_extraction = current_token_count;
133}
134
135/// Check if session memory has been initialized (met minimumTokensToInit threshold)
136pub fn is_session_memory_initialized() -> bool {
137    STATE.lock().unwrap().session_memory_initialized
138}
139
140/// Mark session memory as initialized
141pub fn mark_session_memory_initialized() {
142    STATE.lock().unwrap().session_memory_initialized = true;
143}
144
145/// Check if we've met the threshold to initialize session memory.
146/// Uses total context window tokens (same as autocompact) for consistent behavior.
147pub fn has_met_initialization_threshold(current_token_count: u64) -> bool {
148    let state = STATE.lock().unwrap();
149    current_token_count >= state.config.minimum_message_tokens_to_init
150}
151
152/// Check if we've met the threshold for the next update.
153/// Measures actual context window growth since last extraction.
154pub fn has_met_update_threshold(current_token_count: u64) -> bool {
155    let state = STATE.lock().unwrap();
156    let tokens_since = current_token_count.saturating_sub(state.tokens_at_last_extraction);
157    tokens_since >= state.config.minimum_tokens_between_update
158}
159
160/// Get the configured number of tool calls between updates
161pub fn get_tool_calls_between_updates() -> u64 {
162    STATE.lock().unwrap().config.tool_calls_between_updates
163}
164
165/// Get the session memory file path
166pub fn get_session_memory_path() -> Option<String> {
167    STATE.lock().unwrap().memory_path.clone()
168}
169
170/// Set the session memory file path (called during file setup)
171pub fn set_session_memory_path(path: String) {
172    STATE.lock().unwrap().memory_path = Some(path);
173}
174
175/// Reset session memory state (useful for testing)
176pub fn reset_session_memory_state() {
177    *STATE.lock().unwrap() = SessionMemoryState::default();
178}
179
180/// Get the extraction wait timeout
181pub fn get_extraction_wait_timeout() -> Duration {
182    EXTRACTION_WAIT_TIMEOUT
183}
184
185/// Get the extraction stale threshold
186pub fn get_extraction_stale_threshold() -> Duration {
187    EXTRACTION_STALE_THRESHOLD
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes the tests in this module.
    ///
    /// Every test resets and mutates the process-wide `STATE`, while the test
    /// harness runs tests on multiple threads by default. Without this lock a
    /// reset in one test can land between another test's write and its
    /// assertion, making the suite flaky.
    static SERIAL: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Takes the serialization lock and resets the global state. The returned
    /// guard must be kept alive for the whole test body.
    fn isolated() -> std::sync::MutexGuard<'static, ()> {
        // A failed assertion in another test poisons `SERIAL`; the `()`
        // payload cannot be left inconsistent, so recovering is safe and
        // avoids cascading failures.
        let guard = SERIAL
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        reset_session_memory_state();
        guard
    }

    #[test]
    fn test_default_config() {
        let _serial = isolated();
        let config = get_session_memory_config();
        assert_eq!(config.minimum_message_tokens_to_init, 10_000);
        assert_eq!(config.minimum_tokens_between_update, 5_000);
        assert_eq!(config.tool_calls_between_updates, 3);
    }

    #[test]
    fn test_set_config_ignores_zero_fields() {
        let _serial = isolated();
        set_session_memory_config(SessionMemoryConfig {
            minimum_message_tokens_to_init: 20_000,
            minimum_tokens_between_update: 0,
            tool_calls_between_updates: 0,
        });
        let config = get_session_memory_config();
        assert_eq!(config.minimum_message_tokens_to_init, 20_000);
        // Zero means "keep the current value".
        assert_eq!(config.minimum_tokens_between_update, 5_000);
        assert_eq!(config.tool_calls_between_updates, 3);
    }

    #[test]
    fn test_initialization_threshold() {
        let _serial = isolated();
        assert!(!is_session_memory_initialized());
        assert!(!has_met_initialization_threshold(5_000));
        assert!(has_met_initialization_threshold(10_000));
        mark_session_memory_initialized();
        assert!(is_session_memory_initialized());
    }

    #[test]
    fn test_update_threshold() {
        let _serial = isolated();
        record_extraction_token_count(10_000);
        assert!(!has_met_update_threshold(12_000));
        assert!(has_met_update_threshold(15_000));
        // A count below the baseline saturates to zero growth.
        assert!(!has_met_update_threshold(1_000));
    }

    #[test]
    fn test_extraction_tracking() {
        let _serial = isolated();
        assert!(get_last_summarized_message_id().is_none());
        set_last_summarized_message_id(Some("msg_123"));
        assert_eq!(get_last_summarized_message_id(), Some("msg_123".to_string()));
        set_last_summarized_message_id(None);
        assert!(get_last_summarized_message_id().is_none());
    }

    #[test]
    fn test_mark_extraction() {
        let _serial = isolated();

        mark_extraction_started();
        // Copy the value out first so the STATE guard is released before asserting.
        let started = STATE.lock().unwrap().extraction_started_at;
        assert!(started.is_some());

        mark_extraction_completed();
        let cleared = STATE.lock().unwrap().extraction_started_at;
        assert!(cleared.is_none());
    }
}
234}