use crate::constants::env::system;
use crate::types::*;
use crate::AgentError;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::LazyLock;
use std::sync::Mutex;
/// Built-in thresholds used until a caller installs a different configuration.
pub const DEFAULT_SESSION_MEMORY_CONFIG: SessionMemoryConfig = SessionMemoryConfig {
    // Estimated tokens a conversation must reach before the initialization check passes.
    minimum_message_tokens_to_init: 10_000,
    // Estimated token growth required since the last recorded extraction.
    minimum_tokens_between_update: 5_000,
    // Number of tool-calling assistant messages compared against for the next update.
    tool_calls_between_updates: 3,
};
/// Thresholds that decide when session memory extraction should run.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SessionMemoryConfig {
/// Estimated token count the conversation must reach before the
/// initialization threshold is considered met.
pub minimum_message_tokens_to_init: u32,
/// Estimated token growth, relative to the last recorded extraction,
/// required before the update threshold is considered met.
pub minimum_tokens_between_update: u32,
/// Tool-call count compared against when deciding whether to extract again.
pub tool_calls_between_updates: u32,
}
impl Default for SessionMemoryConfig {
    /// Returns a configuration carrying every value of
    /// [`DEFAULT_SESSION_MEMORY_CONFIG`].
    fn default() -> Self {
        Self {
            ..DEFAULT_SESSION_MEMORY_CONFIG
        }
    }
}
/// Interior-mutable bookkeeping for session memory extraction.
///
/// Every field is wrapped in a `Mutex` or an atomic, so all accessors take
/// `&self`.
pub struct SessionMemoryState {
// Active thresholds; replaced wholesale by `set_config`.
config: Mutex<SessionMemoryConfig>,
// Set by `mark_initialized`; starts out `false`.
initialized: AtomicBool,
// Token count stored by `set_tokens_at_last_extraction`; starts at 0.
tokens_at_last_extraction: AtomicU64,
// Index stored by `set_last_summarized_index`; `None` until one is recorded.
last_summarized_index: Mutex<Option<usize>>,
// Raised by `start_extraction` and lowered by `end_extraction`.
extraction_in_progress: AtomicBool,
}
impl SessionMemoryState {
pub fn new() -> Self {
Self {
config: Mutex::new(DEFAULT_SESSION_MEMORY_CONFIG),
initialized: AtomicBool::new(false),
tokens_at_last_extraction: AtomicU64::new(0),
last_summarized_index: Mutex::new(None),
extraction_in_progress: AtomicBool::new(false),
}
}
pub fn is_initialized(&self) -> bool {
self.initialized.load(Ordering::SeqCst)
}
pub fn mark_initialized(&self) {
self.initialized.store(true, Ordering::SeqCst);
}
pub fn get_config(&self) -> SessionMemoryConfig {
self.config.lock().unwrap().clone()
}
pub fn set_config(&self, config: SessionMemoryConfig) {
*self.config.lock().unwrap() = config;
}
pub fn get_tokens_at_last_extraction(&self) -> u64 {
self.tokens_at_last_extraction.load(Ordering::SeqCst)
}
pub fn set_tokens_at_last_extraction(&self, tokens: u64) {
self.tokens_at_last_extraction
.store(tokens, Ordering::SeqCst);
}
pub fn get_last_summarized_index(&self) -> Option<usize> {
*self.last_summarized_index.lock().unwrap()
}
pub fn set_last_summarized_index(&self, index: Option<usize>) {
*self.last_summarized_index.lock().unwrap() = index;
}
pub fn is_extraction_in_progress(&self) -> bool {
self.extraction_in_progress.load(Ordering::SeqCst)
}
pub fn start_extraction(&self) {
self.extraction_in_progress.store(true, Ordering::SeqCst);
}
pub fn end_extraction(&self) {
self.extraction_in_progress.store(false, Ordering::SeqCst);
}
}
impl Default for SessionMemoryState {
fn default() -> Self {
Self::new()
}
}
/// Process-wide session memory state, built by `SessionMemoryState::new`
/// the first time it is accessed.
static SESSION_MEMORY_STATE: LazyLock<SessionMemoryState> = LazyLock::new(SessionMemoryState::new);
pub fn get_session_memory_state() -> &'static SessionMemoryState {
&SESSION_MEMORY_STATE
}
/// Returns the directory that holds session memory files:
/// `<home>/.open-agent-sdk/session_memory`.
///
/// `<home>` is read from the environment variable named by `system::HOME`,
/// then from the one named by `system::USERPROFILE`; if neither can be read,
/// `/tmp` is used.
pub fn get_session_memory_dir() -> PathBuf {
    let home = match std::env::var(system::HOME) {
        Ok(value) => value,
        Err(_) => match std::env::var(system::USERPROFILE) {
            Ok(value) => value,
            Err(_) => String::from("/tmp"),
        },
    };

    let mut dir = PathBuf::from(home);
    dir.push(".open-agent-sdk");
    dir.push("session_memory");
    dir
}
/// Returns the path of the notes file, `notes.md`, inside the session
/// memory directory.
pub fn get_session_memory_path() -> PathBuf {
    let mut notes_path = get_session_memory_dir();
    notes_path.push("notes.md");
    notes_path
}
pub fn is_session_memory_initialized() -> bool {
SESSION_MEMORY_STATE.is_initialized()
}
pub fn mark_session_memory_initialized() {
SESSION_MEMORY_STATE.mark_initialized();
}
pub fn get_session_memory_config() -> SessionMemoryConfig {
SESSION_MEMORY_STATE.get_config()
}
pub fn set_session_memory_config(config: SessionMemoryConfig) {
SESSION_MEMORY_STATE.set_config(config);
}
pub fn get_last_summarized_message_id() -> Option<usize> {
SESSION_MEMORY_STATE.get_last_summarized_index()
}
pub fn set_last_summarized_message_id(message_id: Option<usize>) {
SESSION_MEMORY_STATE.set_last_summarized_index(message_id);
}
pub fn has_met_initialization_threshold(current_token_count: u64) -> bool {
let config = get_session_memory_config();
current_token_count >= config.minimum_message_tokens_to_init as u64
}
/// Returns `true` when the conversation has grown by at least the configured
/// `minimum_tokens_between_update` since the last recorded extraction.
///
/// The growth is computed with a saturating subtraction, so a
/// `current_token_count` below the recorded value counts as zero growth.
pub fn has_met_update_threshold(current_token_count: u64) -> bool {
    let required_growth = u64::from(get_session_memory_config().minimum_tokens_between_update);
    let baseline = SESSION_MEMORY_STATE.get_tokens_at_last_extraction();
    current_token_count.saturating_sub(baseline) >= required_growth
}
pub fn get_tool_calls_between_updates() -> u32 {
get_session_memory_config().tool_calls_between_updates
}
pub fn record_extraction_token_count(token_count: u64) {
SESSION_MEMORY_STATE.set_tokens_at_last_extraction(token_count);
}
/// Counts assistant messages that carry tool calls and come *after*
/// `since_index`.
///
/// `since_index` is the index of the last message that has already been
/// accounted for (the value `should_extract_memory` records is
/// `messages.len() - 1` at extraction time). That message is excluded, so a
/// tool-calling assistant message is not counted again on the next check.
/// With `None`, every message is considered.
///
/// An assistant message counts when its `tool_calls` field is `Some` or its
/// content contains the substring `"tool_use"`. Each such message counts
/// once, regardless of how many calls it holds.
///
/// Returns 0 when `since_index` is at or beyond the last message.
pub fn count_tool_calls_since(messages: &[Message], since_index: Option<usize>) -> usize {
    // Start one past the already-summarized message; saturate so that
    // `usize::MAX` cannot overflow.
    let first_unseen = since_index.map_or(0, |index| index.saturating_add(1));

    messages
        .iter()
        .skip(first_unseen)
        .filter(|message| {
            message.role == MessageRole::Assistant
                && (message.content.contains("tool_use") || message.tool_calls.is_some())
        })
        .count()
}
/// Decides whether session memory should be extracted for `messages`.
///
/// Until the global state is marked initialized, the estimated token count
/// must reach the initialization threshold; the first time it does, the
/// state is marked initialized. After that, extraction is requested when the
/// token-growth threshold is met and either the tool-call threshold is met
/// or the most recent assistant message carries no tool calls.
///
/// Side effect: when the answer is `true` and `messages` is non-empty, the
/// index of the last message is stored as the last-summarized index.
pub fn should_extract_memory(messages: &[Message]) -> bool {
    let token_estimate = estimate_message_tokens(messages);

    if !is_session_memory_initialized() {
        if has_met_initialization_threshold(token_estimate) {
            mark_session_memory_initialized();
        } else {
            return false;
        }
    }

    let enough_new_tokens = has_met_update_threshold(token_estimate);
    let previous_index = get_last_summarized_message_id();
    let recent_tool_calls = count_tool_calls_since(messages, previous_index);
    let enough_tool_calls = recent_tool_calls >= get_tool_calls_between_updates() as usize;
    let last_turn_has_no_tool_calls = !has_tool_calls_in_last_assistant_turn(messages);

    // (tokens && tools) || (tokens && !last_turn_tools), with the shared
    // token condition factored out.
    let should_extract = enough_new_tokens && (enough_tool_calls || last_turn_has_no_tool_calls);

    if should_extract {
        // `checked_sub` yields `None` for an empty slice, leaving the stored index untouched.
        if let Some(last_position) = messages.len().checked_sub(1) {
            set_last_summarized_message_id(Some(last_position));
        }
    }

    should_extract
}
/// Reports whether the most recent assistant message in `messages` carries
/// tool calls: its `tool_calls` field is `Some`, or its content contains the
/// substring `"tool_use"`.
///
/// Returns `false` when there is no assistant message at all.
fn has_tool_calls_in_last_assistant_turn(messages: &[Message]) -> bool {
    let last_assistant = messages
        .iter()
        .rev()
        .find(|candidate| candidate.role == MessageRole::Assistant);

    match last_assistant {
        Some(turn) => turn.tool_calls.is_some() || turn.content.contains("tool_use"),
        None => false,
    }
}
/// Roughly estimates the token count of `messages` as the summed
/// `content.len()` of every message divided by four (integer division).
fn estimate_message_tokens(messages: &[Message]) -> u64 {
    let mut content_length: usize = 0;
    for message in messages {
        content_length += message.content.len();
    }
    (content_length / 4) as u64
}
/// Reads the session memory notes file.
///
/// Returns `Ok(Some(content))` when the file exists and `Ok(None)` when it
/// does not.
///
/// The file is read directly, with no `Path::exists` pre-check. That check
/// is a blocking filesystem call that does not belong in an async function,
/// it is redundant with the `NotFound` arm below, and it reports `false` for
/// *any* metadata error, which would hide real failures such as a permission
/// problem behind `Ok(None)`.
///
/// # Errors
///
/// Returns [`AgentError::Io`] for any read failure other than the file not
/// being found.
pub async fn get_session_memory_content() -> Result<Option<String>, AgentError> {
    let path = get_session_memory_path();
    match tokio::fs::read_to_string(&path).await {
        Ok(content) => Ok(Some(content)),
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(None),
        Err(err) => Err(AgentError::Io(err)),
    }
}
/// Ensures the session memory notes file exists and returns its content.
///
/// The session memory directory is created if needed. If the notes file is
/// already present, its content is returned unchanged. Otherwise the default
/// template is written to it and returned.
///
/// The file is probed by reading it asynchronously, not with the blocking
/// `Path::exists`, so the async executor thread is never stalled on a
/// synchronous filesystem call. A freshly written template is returned
/// directly, with no second read from disk.
///
/// # Errors
///
/// Returns [`AgentError::Io`] when the directory cannot be created, when the
/// file exists but cannot be read, or when the template cannot be written.
pub async fn init_session_memory_file() -> Result<String, AgentError> {
    let dir = get_session_memory_dir();
    let path = get_session_memory_path();

    tokio::fs::create_dir_all(&dir)
        .await
        .map_err(AgentError::Io)?;

    match tokio::fs::read_to_string(&path).await {
        Ok(content) => Ok(content),
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
            // First use: seed the file with the template.
            let template = get_session_memory_template();
            tokio::fs::write(&path, template.as_bytes())
                .await
                .map_err(AgentError::Io)?;
            Ok(template)
        }
        Err(err) => Err(AgentError::Io(err)),
    }
}
/// Returns the Markdown skeleton written to a new session notes file.
fn get_session_memory_template() -> String {
    // Assembled at compile time; every line, including the last, ends with `\n`.
    const TEMPLATE: &str = concat!(
        "# Session Notes\n",
        "This file contains automatically extracted notes about the current conversation.\n",
        "## Key Points\n",
        "-\n",
        "## Decisions Made\n",
        "-\n",
        "## Open Items\n",
        "-\n",
        "## Context\n",
    );
    TEMPLATE.to_owned()
}
/// Serializable outcome record for an extraction request.
///
/// NOTE(review): nothing in this part of the file constructs this type; the
/// field meanings below are read from the names only and should be confirmed
/// against the producer.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct ManualExtractionResult {
/// Whether the operation succeeded.
pub success: bool,
/// Presumably the path of the written memory file, when there is one.
pub memory_path: Option<String>,
/// Presumably a human-readable failure description, when there is one.
pub error: Option<String>,
}
/// Waits until the global extraction-in-progress flag is lowered, checking
/// it every 100 ms.
///
/// There is no timeout: the future completes only once the flag reads `false`.
pub async fn wait_for_session_memory_extraction() {
    let poll_interval = tokio::time::Duration::from_millis(100);
    loop {
        if !SESSION_MEMORY_STATE.is_extraction_in_progress() {
            break;
        }
        tokio::time::sleep(poll_interval).await;
    }
}
pub fn reset_session_memory_state() {
SESSION_MEMORY_STATE.set_config(DEFAULT_SESSION_MEMORY_CONFIG);
SESSION_MEMORY_STATE.set_tokens_at_last_extraction(0);
SESSION_MEMORY_STATE.set_last_summarized_index(None);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes the tests that read or write the process-wide state.
    ///
    /// The test harness runs tests on parallel threads by default, so without
    /// this lock a `reset_session_memory_state()` from one test can land in
    /// the middle of another test's record-then-assert sequence.
    static GLOBAL_STATE_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Takes the global-state lock, recovering it if an earlier test panicked
    /// while holding it so that one failure does not cascade.
    fn lock_global_state() -> std::sync::MutexGuard<'static, ()> {
        GLOBAL_STATE_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_default_config() {
        let config = DEFAULT_SESSION_MEMORY_CONFIG;
        assert_eq!(config.minimum_message_tokens_to_init, 10000);
        assert_eq!(config.minimum_tokens_between_update, 5000);
        assert_eq!(config.tool_calls_between_updates, 3);
    }

    #[test]
    fn test_session_memory_state() {
        // Uses a local instance, so no global lock is needed.
        let state = SessionMemoryState::new();
        assert!(!state.is_initialized());
        state.mark_initialized();
        assert!(state.is_initialized());
    }

    #[test]
    fn test_has_met_initialization_threshold() {
        let _guard = lock_global_state();
        reset_session_memory_state();
        assert!(has_met_initialization_threshold(10000));
        assert!(!has_met_initialization_threshold(9999));
    }

    #[test]
    fn test_has_met_update_threshold() {
        let _guard = lock_global_state();
        reset_session_memory_state();
        record_extraction_token_count(5000);
        assert!(has_met_update_threshold(10000));
        assert!(!has_met_update_threshold(9999));
        assert!(!has_met_update_threshold(7499));
        // A count below the recorded baseline saturates to zero growth.
        assert!(!has_met_update_threshold(0));
    }
}