Skip to main content

agent_teams/memory/
mod.rs

1//! Cross-turn conversation memory for stateless backends.
2//!
3//! Provides [`ConversationMemory`] for accumulating turn records and
4//! [`MemoryManager`] for file-based persistence. When enabled for an agent,
5//! previous conversation context is prepended to each new input, giving
6//! stateless backends (e.g., Gemini CLI) a form of multi-turn memory.
7
8use std::path::PathBuf;
9
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12
13use crate::error::{Error, Result};
14use crate::util::atomic_write::atomic_write_json;
15use crate::util::file_lock::FileLock;
16
17// ---------------------------------------------------------------------------
18// Types
19// ---------------------------------------------------------------------------
20
21/// Speaker role in a conversation turn.
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
23#[serde(rename_all = "snake_case")]
24pub enum Role {
25    User,
26    Assistant,
27}
28
29impl std::fmt::Display for Role {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        match self {
32            Role::User => write!(f, "User"),
33            Role::Assistant => write!(f, "Assistant"),
34        }
35    }
36}
37
38/// A single conversation turn.
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct TurnRecord {
41    pub role: Role,
42    pub content: String,
43    pub timestamp: DateTime<Utc>,
44}
45
46/// Configuration for conversation memory limits.
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct MemoryConfig {
49    /// Maximum number of turns to keep (oldest are evicted first).
50    pub max_turns: usize,
51    /// Maximum total characters in the formatted context string.
52    pub budget_chars: usize,
53}
54
55impl Default for MemoryConfig {
56    fn default() -> Self {
57        Self {
58            max_turns: 5,
59            budget_chars: 4000,
60        }
61    }
62}
63
64/// In-memory conversation history with configurable limits.
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct ConversationMemory {
67    turns: Vec<TurnRecord>,
68    config: MemoryConfig,
69}
70
71impl ConversationMemory {
72    /// Create a new memory with the given config.
73    pub fn new(config: MemoryConfig) -> Self {
74        Self {
75            turns: Vec::new(),
76            config,
77        }
78    }
79
80    /// Create a new memory with default configuration (5 turns, 4000 chars).
81    pub fn with_defaults() -> Self {
82        Self::new(MemoryConfig::default())
83    }
84
85    /// Record a new conversation turn, evicting the oldest if over `max_turns`.
86    pub fn record_turn(&mut self, role: Role, content: impl Into<String>) {
87        self.turns.push(TurnRecord {
88            role,
89            content: content.into(),
90            timestamp: Utc::now(),
91        });
92
93        // Evict oldest turns if we exceed max_turns
94        while self.turns.len() > self.config.max_turns {
95            self.turns.remove(0);
96        }
97    }
98
99    /// Format the stored turns as a context string for injection.
100    ///
101    /// Builds the context from most recent turns backward, stopping when
102    /// adding the next turn would exceed `budget_chars`. Returns the
103    /// context in chronological order.
104    pub fn format_context(&self) -> String {
105        if self.turns.is_empty() {
106            return String::new();
107        }
108
109        let mut selected: Vec<String> = Vec::new();
110        let mut total_chars = 0;
111        let header = "[Conversation Context]\n";
112        total_chars += header.len();
113
114        // Walk from most recent to oldest
115        for turn in self.turns.iter().rev() {
116            let line = format!("{}: {}\n", turn.role, turn.content);
117            if total_chars + line.len() > self.config.budget_chars {
118                break;
119            }
120            total_chars += line.len();
121            selected.push(line);
122        }
123
124        if selected.is_empty() {
125            return String::new();
126        }
127
128        // Reverse to get chronological order
129        selected.reverse();
130
131        let mut out = header.to_string();
132        for line in selected {
133            out.push_str(&line);
134        }
135        out
136    }
137
138    /// Remove all recorded turns.
139    pub fn clear(&mut self) {
140        self.turns.clear();
141    }
142
143    /// Number of recorded turns.
144    pub fn len(&self) -> usize {
145        self.turns.len()
146    }
147
148    /// Whether there are no recorded turns.
149    pub fn is_empty(&self) -> bool {
150        self.turns.is_empty()
151    }
152
153    /// Current configuration.
154    pub fn config(&self) -> &MemoryConfig {
155        &self.config
156    }
157}
158
159// ---------------------------------------------------------------------------
160// MemoryManager — file-based persistence
161// ---------------------------------------------------------------------------
162
163/// File-based persistence for agent conversation memories.
164///
165/// Stores memories at `{teams_base}/{team}/memory/{agent}.json`.
166#[derive(Debug, Clone)]
167pub struct MemoryManager {
168    teams_base: PathBuf,
169}
170
171impl MemoryManager {
172    /// Create a new manager rooted at the given teams base directory.
173    pub fn new(teams_base: PathBuf) -> Self {
174        Self { teams_base }
175    }
176
177    /// Directory for a team's memory files.
178    fn memory_dir(&self, team: &str) -> PathBuf {
179        self.teams_base.join(team).join("memory")
180    }
181
182    /// File path for a specific agent's memory.
183    fn memory_path(&self, team: &str, agent: &str) -> PathBuf {
184        self.memory_dir(team).join(format!("{agent}.json"))
185    }
186
187    /// Lock file path for a team's memory directory.
188    fn lock_path(&self, team: &str) -> PathBuf {
189        self.memory_dir(team).join(".lock")
190    }
191
192    /// Save a conversation memory to disk (atomic write with file lock).
193    pub fn save(&self, team: &str, agent: &str, memory: &ConversationMemory) -> Result<()> {
194        let dir = self.memory_dir(team);
195        std::fs::create_dir_all(&dir)?;
196
197        let lock_path = self.lock_path(team);
198        let _lock = FileLock::acquire(&lock_path)?;
199
200        let path = self.memory_path(team, agent);
201        atomic_write_json(&path, memory)?;
202        Ok(())
203    }
204
205    /// Load a conversation memory from disk. Returns `None` if not found.
206    pub fn load(&self, team: &str, agent: &str) -> Result<Option<ConversationMemory>> {
207        let path = self.memory_path(team, agent);
208        if !path.exists() {
209            return Ok(None);
210        }
211
212        let dir = self.memory_dir(team);
213        std::fs::create_dir_all(&dir)?;
214
215        let lock_path = self.lock_path(team);
216        let _lock = FileLock::acquire(&lock_path)?;
217
218        let data = std::fs::read_to_string(&path).map_err(|e| {
219            if e.kind() == std::io::ErrorKind::NotFound {
220                return Error::Other(format!("Memory file not found: {}", path.display()));
221            }
222            Error::Io(e)
223        })?;
224
225        let memory: ConversationMemory = serde_json::from_str(&data)?;
226        Ok(Some(memory))
227    }
228
229    /// Delete a conversation memory from disk.
230    pub fn delete(&self, team: &str, agent: &str) -> Result<()> {
231        let path = self.memory_path(team, agent);
232        if path.exists() {
233            let dir = self.memory_dir(team);
234            std::fs::create_dir_all(&dir)?;
235
236            let lock_path = self.lock_path(team);
237            let _lock = FileLock::acquire(&lock_path)?;
238
239            std::fs::remove_file(&path)?;
240        }
241        Ok(())
242    }
243}
244
245// ---------------------------------------------------------------------------
246// Tests
247// ---------------------------------------------------------------------------
248
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn record_and_format_context() {
        let mut mem = ConversationMemory::with_defaults();
        mem.record_turn(Role::User, "Hello");
        mem.record_turn(Role::Assistant, "Hi there!");

        let ctx = mem.format_context();
        assert!(ctx.starts_with("[Conversation Context]\n"));
        assert!(ctx.contains("User: Hello"));
        assert!(ctx.contains("Assistant: Hi there!"));
        assert_eq!(mem.len(), 2);
    }

    #[test]
    fn evicts_oldest_when_max_turns_exceeded() {
        let config = MemoryConfig {
            max_turns: 2,
            budget_chars: 10000,
        };
        let mut mem = ConversationMemory::new(config);

        mem.record_turn(Role::User, "first");
        mem.record_turn(Role::Assistant, "second");
        mem.record_turn(Role::User, "third");

        assert_eq!(mem.len(), 2);
        let ctx = mem.format_context();
        assert!(!ctx.contains("first"), "oldest turn should be evicted");
        assert!(ctx.contains("second"));
        assert!(ctx.contains("third"));
    }

    #[test]
    fn zero_max_turns_retains_nothing() {
        let config = MemoryConfig {
            max_turns: 0,
            budget_chars: 10000,
        };
        let mut mem = ConversationMemory::new(config);

        mem.record_turn(Role::User, "dropped");

        assert!(mem.is_empty());
        assert_eq!(mem.format_context(), "");
    }

    #[test]
    fn budget_truncation() {
        let config = MemoryConfig {
            max_turns: 100,
            budget_chars: 60, // very tight budget
        };
        let mut mem = ConversationMemory::new(config);

        mem.record_turn(Role::User, "AAAA BBBB CCCC DDDD");
        mem.record_turn(Role::Assistant, "EEEE FFFF GGGG HHHH");
        mem.record_turn(Role::User, "IIII JJJJ KKKK LLLL");

        let ctx = mem.format_context();
        // The header (23 chars) counts toward the budget. The newest line
        // "User: IIII JJJJ KKKK LLLL\n" (26) fits for a total of 49; the
        // next "Assistant: ..." line (31) would exceed 60, so selection stops.
        assert!(ctx.len() <= 60, "context must respect the budget: {ctx:?}");
        assert_eq!(ctx, "[Conversation Context]\nUser: IIII JJJJ KKKK LLLL\n");
    }

    #[test]
    fn budget_smaller_than_header_yields_empty() {
        let config = MemoryConfig {
            max_turns: 10,
            budget_chars: 10,
        };
        let mut mem = ConversationMemory::new(config);
        mem.record_turn(Role::User, "hi");

        assert_eq!(mem.format_context(), "");
        assert_eq!(mem.len(), 1, "formatting must not drop stored turns");
    }

    #[test]
    fn empty_memory_formats_to_empty_string() {
        let mem = ConversationMemory::with_defaults();
        assert_eq!(mem.format_context(), "");
        assert!(mem.is_empty());
        assert_eq!(mem.len(), 0);
    }

    #[test]
    fn clear_removes_all_turns() {
        let mut mem = ConversationMemory::with_defaults();
        mem.record_turn(Role::User, "hello");
        mem.record_turn(Role::Assistant, "world");
        assert_eq!(mem.len(), 2);

        mem.clear();
        assert!(mem.is_empty());
        assert_eq!(mem.format_context(), "");
    }

    #[test]
    fn role_serializes_snake_case_and_displays_title_case() {
        assert_eq!(serde_json::to_string(&Role::User).unwrap(), "\"user\"");
        assert_eq!(
            serde_json::from_str::<Role>("\"assistant\"").unwrap(),
            Role::Assistant
        );
        assert_eq!(Role::User.to_string(), "User");
        assert_eq!(Role::Assistant.to_string(), "Assistant");
    }

    #[test]
    fn serde_round_trip() {
        let mut mem = ConversationMemory::with_defaults();
        mem.record_turn(Role::User, "question");
        mem.record_turn(Role::Assistant, "answer");

        let json = serde_json::to_string_pretty(&mem).unwrap();
        let parsed: ConversationMemory = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed.config().max_turns, 5);
        let ctx = parsed.format_context();
        assert!(ctx.contains("question"));
        assert!(ctx.contains("answer"));
    }

    #[test]
    fn memory_manager_save_load_delete() {
        let dir = tempfile::tempdir().unwrap();
        let mgr = MemoryManager::new(dir.path().to_path_buf());

        let mut mem = ConversationMemory::with_defaults();
        mem.record_turn(Role::User, "ping");
        mem.record_turn(Role::Assistant, "pong");

        // Save
        mgr.save("team1", "agent1", &mem).unwrap();

        // Load
        let loaded = mgr.load("team1", "agent1").unwrap().unwrap();
        assert_eq!(loaded.len(), 2);
        assert!(loaded.format_context().contains("pong"));

        // Delete
        mgr.delete("team1", "agent1").unwrap();
        assert!(mgr.load("team1", "agent1").unwrap().is_none());
    }

    #[test]
    fn memory_manager_load_nonexistent() {
        let dir = tempfile::tempdir().unwrap();
        let mgr = MemoryManager::new(dir.path().to_path_buf());

        let result = mgr.load("no-team", "no-agent").unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn memory_manager_delete_nonexistent_is_noop() {
        let dir = tempfile::tempdir().unwrap();
        let mgr = MemoryManager::new(dir.path().to_path_buf());

        mgr.delete("ghost-team", "ghost-agent").unwrap();
        assert!(
            !dir.path().join("ghost-team").exists(),
            "deleting a missing memory must not create directories"
        );
    }
}