Skip to main content

cortexai_agents/
memory.rs

1//! Agent memory management
2
3use cortexai_core::{errors::MemoryError, MemoryConfig, Message, RetentionPolicy};
4use std::sync::Arc;
5use tokio::sync::RwLock;
6
/// Agent memory.
///
/// A conversation buffer with a byte-size budget. The stored messages and
/// the tracked byte total live behind separate `tokio` `RwLock`s wrapped in
/// `Arc`, so the memory can be shared and mutated across async tasks.
pub struct AgentMemory {
    // Size budget (`max_size`) and retention policy for this memory.
    config: MemoryConfig,
    // Stored messages, oldest first (new messages are pushed to the back,
    // cleanup drains from the front).
    messages: Arc<RwLock<Vec<Message>>>,
    // Estimated total size in bytes of `messages`; kept in sync by
    // `add_message`, `clear`, and `cleanup` via `estimate_message_size`.
    current_size: Arc<RwLock<usize>>,
}
13
14impl AgentMemory {
15    pub fn new(config: MemoryConfig) -> Self {
16        Self {
17            config,
18            messages: Arc::new(RwLock::new(Vec::new())),
19            current_size: Arc::new(RwLock::new(0)),
20        }
21    }
22
23    /// Add a message to memory
24    pub async fn add_message(&self, message: Message) -> Result<(), MemoryError> {
25        let message_size = self.estimate_message_size(&message);
26
27        // Check if we need to free space
28        {
29            let current_size = self.current_size.read().await;
30            if *current_size + message_size > self.config.max_size {
31                drop(current_size); // Release lock before calling cleanup
32                self.cleanup().await?;
33            }
34        }
35
36        // Re-check after cleanup - only add if we have space
37        {
38            let current_size = self.current_size.read().await;
39            if *current_size + message_size > self.config.max_size {
40                // Still not enough space even after cleanup
41                // This can happen if the message itself is larger than max_size
42                // or if we couldn't free enough space
43                return Err(MemoryError::CapacityExceeded {
44                    current: *current_size,
45                    required: message_size,
46                    max: self.config.max_size,
47                });
48            }
49        }
50
51        // Add message
52        {
53            let mut messages = self.messages.write().await;
54            messages.push(message);
55
56            let mut current_size = self.current_size.write().await;
57            *current_size += message_size;
58        }
59
60        Ok(())
61    }
62
63    /// Get message history
64    pub async fn get_history(&self) -> Result<Vec<Message>, MemoryError> {
65        let messages = self.messages.read().await;
66        Ok(messages.clone())
67    }
68
69    /// Get last N messages
70    pub async fn get_last_n(&self, n: usize) -> Result<Vec<Message>, MemoryError> {
71        let messages = self.messages.read().await;
72        let start = messages.len().saturating_sub(n);
73        Ok(messages[start..].to_vec())
74    }
75
76    /// Clear memory
77    pub async fn clear(&self) -> Result<(), MemoryError> {
78        let mut messages = self.messages.write().await;
79        messages.clear();
80
81        let mut current_size = self.current_size.write().await;
82        *current_size = 0;
83
84        Ok(())
85    }
86
87    /// Get memory size
88    pub async fn size(&self) -> usize {
89        *self.current_size.read().await
90    }
91
92    /// Get message count
93    pub async fn count(&self) -> usize {
94        self.messages.read().await.len()
95    }
96
97    /// Cleanup old messages based on retention policy
98    async fn cleanup(&self) -> Result<(), MemoryError> {
99        let mut messages = self.messages.write().await;
100
101        let original_len = messages.len();
102
103        match &self.config.retention_policy {
104            RetentionPolicy::KeepAll => {
105                // Don't remove anything, but this means we're over limit
106                return Err(MemoryError::LimitExceeded(
107                    *self.current_size.read().await,
108                    self.config.max_size,
109                ));
110            }
111            RetentionPolicy::KeepRecent(n) => {
112                if messages.len() > *n {
113                    let remove_count = messages.len() - n;
114                    messages.drain(0..remove_count);
115                }
116            }
117            RetentionPolicy::KeepImportant(_threshold) => {
118                // For now, just keep recent messages
119                // TODO: Implement importance scoring
120                if messages.len() > 100 {
121                    let remove_count = messages.len() - 100;
122                    messages.drain(0..remove_count);
123                }
124            }
125            RetentionPolicy::Custom => {
126                // Keep last 50% of messages
127                if messages.len() > 1 {
128                    let half = messages.len() / 2;
129                    messages.drain(0..half);
130                }
131            }
132        }
133
134        // Recalculate size
135        let new_size: usize = messages.iter().map(|m| self.estimate_message_size(m)).sum();
136
137        let mut current_size = self.current_size.write().await;
138        *current_size = new_size;
139
140        tracing::debug!(
141            "Memory cleanup: removed {} messages, size: {} -> {} bytes",
142            original_len - messages.len(),
143            original_len,
144            new_size
145        );
146
147        Ok(())
148    }
149
150    /// Estimate message size in bytes
151    fn estimate_message_size(&self, message: &Message) -> usize {
152        // Rough estimation: JSON serialization size
153        serde_json::to_string(message)
154            .map(|s| s.len())
155            .unwrap_or(1024) // Default to 1KB if serialization fails
156    }
157}