Skip to main content

ai_session/persistence/
mod.rs

1//! Session state persistence and recovery
2
3use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::path::PathBuf;
7use tokio::fs;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9
10use crate::context::SessionContext;
11use crate::core::{SessionConfig, SessionId, SessionStatus};
12
13/// Manages persistent storage of session state
14pub struct PersistenceManager {
15    /// Base directory for session storage
16    storage_path: PathBuf,
17    /// Compression enabled
18    enable_compression: bool,
19    /// Encryption key (optional)
20    encryption_key: Option<Vec<u8>>,
21}
22
23impl PersistenceManager {
24    /// Create a new persistence manager
25    pub fn new(storage_path: PathBuf) -> Self {
26        Self {
27            storage_path,
28            enable_compression: true,
29            encryption_key: None,
30        }
31    }
32
33    /// Enable encryption with key
34    pub fn with_encryption(mut self, key: Vec<u8>) -> Self {
35        self.encryption_key = Some(key);
36        self
37    }
38
39    /// Save session state
40    pub async fn save_session(&self, session_id: &SessionId, state: &SessionState) -> Result<()> {
41        let session_dir = self.session_directory(session_id);
42        fs::create_dir_all(&session_dir).await?;
43
44        // Serialize state
45        let data = serde_json::to_vec_pretty(state)?;
46
47        // Optionally compress
48        let data = if self.enable_compression {
49            self.compress_data(&data)?
50        } else {
51            data
52        };
53
54        // Optionally encrypt
55        let data = if let Some(key) = &self.encryption_key {
56            self.encrypt_data(&data, key)?
57        } else {
58            data
59        };
60
61        // Write to file
62        let state_file = session_dir.join("state.json");
63        let mut file = fs::File::create(&state_file).await?;
64        file.write_all(&data).await?;
65        file.sync_all().await?;
66
67        Ok(())
68    }
69
70    /// Load session state
71    pub async fn load_session(&self, session_id: &SessionId) -> Result<SessionState> {
72        let state_file = self.session_directory(session_id).join("state.json");
73
74        // Read file
75        let mut file = fs::File::open(&state_file).await?;
76        let mut data = Vec::new();
77        file.read_to_end(&mut data).await?;
78
79        // Optionally decrypt
80        let data = if let Some(key) = &self.encryption_key {
81            self.decrypt_data(&data, key)?
82        } else {
83            data
84        };
85
86        // Optionally decompress
87        let data = if self.enable_compression {
88            self.decompress_data(&data)?
89        } else {
90            data
91        };
92
93        // Deserialize
94        let state: SessionState = serde_json::from_slice(&data)?;
95        Ok(state)
96    }
97
98    /// List all saved sessions
99    pub async fn list_sessions(&self) -> Result<Vec<SessionId>> {
100        let mut sessions = Vec::new();
101
102        let mut entries = fs::read_dir(&self.storage_path).await?;
103        while let Some(entry) = entries.next_entry().await? {
104            if entry.file_type().await?.is_dir()
105                && let Ok(name) = entry.file_name().into_string()
106                && let Ok(id) = SessionId::parse_str(&name)
107            {
108                sessions.push(id);
109            }
110        }
111
112        Ok(sessions)
113    }
114
115    /// Delete session data
116    pub async fn delete_session(&self, session_id: &SessionId) -> Result<()> {
117        let session_dir = self.session_directory(session_id);
118        if session_dir.exists() {
119            fs::remove_dir_all(&session_dir).await?;
120        }
121        Ok(())
122    }
123
124    /// Get session directory
125    fn session_directory(&self, session_id: &SessionId) -> PathBuf {
126        self.storage_path.join(session_id.to_string())
127    }
128
129    /// Compress data
130    fn compress_data(&self, data: &[u8]) -> Result<Vec<u8>> {
131        use zstd::stream::encode_all;
132
133        encode_all(data, 3).map_err(|e| anyhow::anyhow!("Failed to compress data: {}", e))
134    }
135
136    /// Decompress data
137    fn decompress_data(&self, data: &[u8]) -> Result<Vec<u8>> {
138        use zstd::stream::decode_all;
139
140        decode_all(data).map_err(|e| anyhow::anyhow!("Failed to decompress data: {}", e))
141    }
142
143    /// Encrypt data (simplified - use proper crypto in production)
144    fn encrypt_data(&self, data: &[u8], _key: &[u8]) -> Result<Vec<u8>> {
145        // TODO: Implement proper encryption
146        Ok(data.to_vec())
147    }
148
149    /// Decrypt data (simplified - use proper crypto in production)
150    fn decrypt_data(&self, data: &[u8], _key: &[u8]) -> Result<Vec<u8>> {
151        // TODO: Implement proper decryption
152        Ok(data.to_vec())
153    }
154}
155
156/// Persistent session state
157#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct SessionState {
159    /// Session ID
160    pub session_id: SessionId,
161    /// Session configuration
162    #[serde(default)]
163    pub config: SessionConfig,
164    /// Current status
165    #[serde(default)]
166    pub status: SessionStatus,
167    /// Session context
168    pub context: SessionContext,
169    /// Command history
170    #[serde(default)]
171    pub command_history: Vec<CommandRecord>,
172    /// Session metadata
173    pub metadata: SessionMetadata,
174}
175
176/// Command execution record
177#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct CommandRecord {
179    /// Command text
180    pub command: String,
181    /// Execution timestamp
182    pub timestamp: chrono::DateTime<chrono::Utc>,
183    /// Exit code
184    pub exit_code: Option<i32>,
185    /// Output preview
186    pub output_preview: String,
187    /// Execution duration
188    pub duration_ms: u64,
189}
190
191/// Session metadata
192#[derive(Debug, Clone, Serialize, Deserialize)]
193pub struct SessionMetadata {
194    /// Creation time
195    pub created_at: chrono::DateTime<chrono::Utc>,
196    /// Last accessed time
197    pub last_accessed: chrono::DateTime<chrono::Utc>,
198    /// Total commands executed
199    pub command_count: usize,
200    /// Total tokens used
201    pub total_tokens: usize,
202    /// Custom metadata
203    pub custom: HashMap<String, serde_json::Value>,
204}
205
206impl Default for SessionMetadata {
207    fn default() -> Self {
208        let now = chrono::Utc::now();
209        Self {
210            created_at: now,
211            last_accessed: now,
212            command_count: 0,
213            total_tokens: 0,
214            custom: HashMap::new(),
215        }
216    }
217}
218
219/// Session snapshot for quick restore
220#[derive(Debug, Clone, Serialize, Deserialize)]
221pub struct SessionSnapshot {
222    /// Snapshot ID
223    pub id: String,
224    /// Creation timestamp
225    pub created_at: chrono::DateTime<chrono::Utc>,
226    /// Session state at snapshot time
227    pub state: SessionState,
228    /// Snapshot description
229    pub description: Option<String>,
230}
231
232/// Manages session snapshots
233pub struct SnapshotManager {
234    /// Base directory for snapshots
235    snapshot_path: PathBuf,
236}
237
238impl SnapshotManager {
239    /// Create new snapshot manager
240    pub fn new(snapshot_path: PathBuf) -> Self {
241        Self { snapshot_path }
242    }
243
244    /// Create a snapshot
245    pub async fn create_snapshot(
246        &self,
247        session_id: &SessionId,
248        state: &SessionState,
249        description: Option<String>,
250    ) -> Result<String> {
251        let snapshot = SessionSnapshot {
252            id: uuid::Uuid::new_v4().to_string(),
253            created_at: chrono::Utc::now(),
254            state: state.clone(),
255            description,
256        };
257
258        let snapshot_dir = self.snapshot_path.join(session_id.to_string());
259        fs::create_dir_all(&snapshot_dir).await?;
260
261        let snapshot_file = snapshot_dir.join(format!("{}.json", snapshot.id));
262        let data = serde_json::to_vec_pretty(&snapshot)?;
263        fs::write(&snapshot_file, data).await?;
264
265        Ok(snapshot.id)
266    }
267
268    /// List snapshots for a session
269    pub async fn list_snapshots(&self, session_id: &SessionId) -> Result<Vec<SessionSnapshot>> {
270        let snapshot_dir = self.snapshot_path.join(session_id.to_string());
271        if !snapshot_dir.exists() {
272            return Ok(Vec::new());
273        }
274
275        let mut snapshots = Vec::new();
276        let mut entries = fs::read_dir(&snapshot_dir).await?;
277
278        while let Some(entry) = entries.next_entry().await? {
279            if entry
280                .path()
281                .extension()
282                .map(|e| e == "json")
283                .unwrap_or(false)
284            {
285                let data = fs::read(entry.path()).await?;
286                if let Ok(snapshot) = serde_json::from_slice::<SessionSnapshot>(&data) {
287                    snapshots.push(snapshot);
288                }
289            }
290        }
291
292        // Sort by creation time
293        snapshots.sort_by(|a, b| b.created_at.cmp(&a.created_at));
294        Ok(snapshots)
295    }
296
297    /// Restore from snapshot
298    pub async fn restore_snapshot(
299        &self,
300        session_id: &SessionId,
301        snapshot_id: &str,
302    ) -> Result<SessionState> {
303        let snapshot_file = self
304            .snapshot_path
305            .join(session_id.to_string())
306            .join(format!("{}.json", snapshot_id));
307
308        let data = fs::read(&snapshot_file).await?;
309        let snapshot: SessionSnapshot = serde_json::from_slice(&data)?;
310
311        Ok(snapshot.state)
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Build a minimal state for `session_id` with an empty history.
    fn sample_state(session_id: &SessionId) -> SessionState {
        SessionState {
            session_id: session_id.clone(),
            config: SessionConfig::default(),
            status: SessionStatus::Running,
            context: SessionContext::new(session_id.clone()),
            command_history: vec![],
            metadata: SessionMetadata::default(),
        }
    }

    #[tokio::test]
    async fn test_persistence_manager() {
        let temp_dir = TempDir::new().unwrap();
        let manager = PersistenceManager::new(temp_dir.path().to_path_buf());

        let session_id = SessionId::new_v4();
        let state = sample_state(&session_id);

        // Save and load
        manager.save_session(&session_id, &state).await.unwrap();
        let loaded = manager.load_session(&session_id).await.unwrap();

        assert_eq!(loaded.session_id, state.session_id);
    }

    #[tokio::test]
    async fn test_list_and_delete_sessions() {
        let temp_dir = TempDir::new().unwrap();
        let manager = PersistenceManager::new(temp_dir.path().to_path_buf());
        let session_id = SessionId::new_v4();

        manager
            .save_session(&session_id, &sample_state(&session_id))
            .await
            .unwrap();
        assert_eq!(
            manager.list_sessions().await.unwrap(),
            vec![session_id.clone()]
        );

        manager.delete_session(&session_id).await.unwrap();
        assert!(manager.list_sessions().await.unwrap().is_empty());
        assert!(manager.load_session(&session_id).await.is_err());

        // Deleting a session that no longer exists is a no-op.
        manager.delete_session(&session_id).await.unwrap();
    }

    #[tokio::test]
    async fn test_snapshot_round_trip() {
        let temp_dir = TempDir::new().unwrap();
        let snapshots = SnapshotManager::new(temp_dir.path().to_path_buf());
        let session_id = SessionId::new_v4();

        // No snapshot directory yet: listing is empty rather than an error.
        assert!(snapshots.list_snapshots(&session_id).await.unwrap().is_empty());

        let snapshot_id = snapshots
            .create_snapshot(
                &session_id,
                &sample_state(&session_id),
                Some("checkpoint".to_string()),
            )
            .await
            .unwrap();

        let listed = snapshots.list_snapshots(&session_id).await.unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].id, snapshot_id);
        assert_eq!(listed[0].description.as_deref(), Some("checkpoint"));

        let restored = snapshots
            .restore_snapshot(&session_id, &snapshot_id)
            .await
            .unwrap();
        assert_eq!(restored.session_id, session_id);
    }
}