Skip to main content

ai_agents_storage/
storage.rs

1use std::path::{Path, PathBuf};
2
3use async_trait::async_trait;
4
5use ai_agents_core::{AgentSnapshot, AgentStorage, Result};
6
/// Persists [`AgentSnapshot`]s as pretty-printed JSON files inside a single base directory,
/// one `<session_id>.json` file per session.
#[derive(Debug, Clone)]
pub struct FileStorage {
    /// Directory that holds the per-session JSON files.
    base_path: PathBuf,
}
10
11impl FileStorage {
12    pub fn new(base_path: impl AsRef<Path>) -> Self {
13        Self {
14            base_path: base_path.as_ref().to_path_buf(),
15        }
16    }
17
18    fn session_path(&self, session_id: &str) -> PathBuf {
19        self.base_path.join(format!("{}.json", session_id))
20    }
21}
22
23#[async_trait]
24impl AgentStorage for FileStorage {
25    async fn save(&self, session_id: &str, snapshot: &AgentSnapshot) -> Result<()> {
26        tokio::fs::create_dir_all(&self.base_path).await?;
27        let path = self.session_path(session_id);
28        let json = serde_json::to_string_pretty(snapshot)?;
29        tokio::fs::write(path, json).await?;
30        Ok(())
31    }
32
33    async fn load(&self, session_id: &str) -> Result<Option<AgentSnapshot>> {
34        let path = self.session_path(session_id);
35        if !path.exists() {
36            return Ok(None);
37        }
38        let json = tokio::fs::read_to_string(path).await?;
39        let snapshot = serde_json::from_str(&json)?;
40        Ok(Some(snapshot))
41    }
42
43    async fn delete(&self, session_id: &str) -> Result<()> {
44        let path = self.session_path(session_id);
45        if path.exists() {
46            tokio::fs::remove_file(path).await?;
47        }
48        Ok(())
49    }
50
51    async fn list_sessions(&self) -> Result<Vec<String>> {
52        let mut sessions = Vec::new();
53        if !self.base_path.exists() {
54            return Ok(sessions);
55        }
56
57        let mut entries = tokio::fs::read_dir(&self.base_path).await?;
58        while let Some(entry) = entries.next_entry().await? {
59            let path = entry.path();
60            if path.extension().map(|e| e == "json").unwrap_or(false) {
61                if let Some(name) = path.file_stem() {
62                    sessions.push(name.to_string_lossy().to_string());
63                }
64            }
65        }
66        Ok(sessions)
67    }
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a `FileStorage` rooted in a fresh temporary directory.
    ///
    /// Keep the returned `TempDir` guard alive for the whole test: dropping it removes the
    /// directory.
    fn temp_storage() -> (TempDir, FileStorage) {
        let temp_dir = TempDir::new().unwrap();
        let storage = FileStorage::new(temp_dir.path());
        (temp_dir, storage)
    }

    #[tokio::test]
    async fn test_save_and_load() {
        let (_temp_dir, storage) = temp_storage();

        let snapshot = AgentSnapshot::new("test-agent".into());
        storage.save("session-1", &snapshot).await.unwrap();

        let loaded = storage.load("session-1").await.unwrap();
        assert_eq!(loaded.expect("snapshot was saved").agent_id, "test-agent");
    }

    #[tokio::test]
    async fn test_save_overwrites_previous_snapshot() {
        let (_temp_dir, storage) = temp_storage();

        storage
            .save("session-1", &AgentSnapshot::new("first".into()))
            .await
            .unwrap();
        storage
            .save("session-1", &AgentSnapshot::new("second".into()))
            .await
            .unwrap();

        let loaded = storage.load("session-1").await.unwrap();
        assert_eq!(loaded.expect("snapshot was saved").agent_id, "second");
    }

    #[tokio::test]
    async fn test_load_nonexistent() {
        let (_temp_dir, storage) = temp_storage();

        let loaded = storage.load("nonexistent").await.unwrap();
        assert!(loaded.is_none());
    }

    #[tokio::test]
    async fn test_delete() {
        let (_temp_dir, storage) = temp_storage();

        let snapshot = AgentSnapshot::new("test-agent".into());
        storage.save("session-1", &snapshot).await.unwrap();
        assert!(storage.load("session-1").await.unwrap().is_some());

        storage.delete("session-1").await.unwrap();
        assert!(storage.load("session-1").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_delete_nonexistent_is_ok() {
        let (_temp_dir, storage) = temp_storage();

        storage.delete("never-saved").await.unwrap();
    }

    #[tokio::test]
    async fn test_list_sessions() {
        let (_temp_dir, storage) = temp_storage();

        storage
            .save("session-1", &AgentSnapshot::new("agent".into()))
            .await
            .unwrap();
        storage
            .save("session-2", &AgentSnapshot::new("agent".into()))
            .await
            .unwrap();

        let mut sessions = storage.list_sessions().await.unwrap();
        // Directory iteration order is unspecified, so compare in sorted order.
        sessions.sort();
        assert_eq!(
            sessions,
            vec!["session-1".to_string(), "session-2".to_string()]
        );
    }

    #[tokio::test]
    async fn test_list_sessions_missing_dir_is_empty() {
        let temp_dir = TempDir::new().unwrap();
        let storage = FileStorage::new(temp_dir.path().join("missing"));

        let sessions = storage.list_sessions().await.unwrap();
        assert!(sessions.is_empty());
    }

    #[tokio::test]
    async fn test_list_sessions_ignores_non_json_files() {
        let (temp_dir, storage) = temp_storage();

        storage
            .save("session-1", &AgentSnapshot::new("agent".into()))
            .await
            .unwrap();
        std::fs::write(temp_dir.path().join("notes.txt"), "not a session").unwrap();

        let sessions = storage.list_sessions().await.unwrap();
        assert_eq!(sessions, vec!["session-1".to_string()]);
    }
}