Skip to main content

haagenti_serverless/
snapshot.rs

1//! GPU memory snapshot and restore for fast recovery
2
3use crate::{Result, ServerlessError};
4use arcanum_primitives::prelude::Blake3;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::PathBuf;
8use std::time::Instant;
9
10/// Snapshot configuration
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct SnapshotConfig {
13    /// Snapshot directory
14    pub snapshot_dir: PathBuf,
15    /// Enable compression
16    pub compression: bool,
17    /// Compression level (1-22 for zstd)
18    pub compression_level: i32,
19    /// Enable incremental snapshots
20    pub incremental: bool,
21    /// Maximum snapshot age in seconds
22    pub max_age_seconds: u64,
23    /// Enable checksum verification
24    pub verify_checksum: bool,
25}
26
27impl Default for SnapshotConfig {
28    fn default() -> Self {
29        Self {
30            snapshot_dir: PathBuf::from("/tmp/haagenti-snapshots"),
31            compression: true,
32            compression_level: 3,
33            incremental: true,
34            max_age_seconds: 3600, // 1 hour
35            verify_checksum: true,
36        }
37    }
38}
39
40/// GPU memory snapshot
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct GpuSnapshot {
43    /// Snapshot ID
44    pub id: String,
45    /// Version
46    pub version: u32,
47    /// Creation timestamp (unix ms)
48    pub created_at: u64,
49    /// Total size in bytes
50    pub total_size: u64,
51    /// Buffer snapshots
52    pub buffers: Vec<BufferSnapshot>,
53    /// Model weights hash
54    pub weights_hash: String,
55    /// Checksum
56    pub checksum: String,
57}
58
59/// Individual buffer snapshot
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct BufferSnapshot {
62    /// Buffer name
63    pub name: String,
64    /// Buffer size
65    pub size: u64,
66    /// Offset in snapshot file
67    pub offset: u64,
68    /// Compressed size (if compression enabled)
69    pub compressed_size: Option<u64>,
70    /// Buffer type
71    pub buffer_type: BufferType,
72}
73
74/// Buffer type
75#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
76pub enum BufferType {
77    /// Model weights (read-only)
78    Weights,
79    /// KV cache
80    KvCache,
81    /// Activations
82    Activations,
83    /// Gradients
84    Gradients,
85    /// Optimizer state
86    OptimizerState,
87    /// Other
88    Other,
89}
90
91impl GpuSnapshot {
92    /// Create new snapshot
93    pub fn new(id: impl Into<String>) -> Self {
94        Self {
95            id: id.into(),
96            version: 1,
97            created_at: std::time::SystemTime::now()
98                .duration_since(std::time::UNIX_EPOCH)
99                .unwrap_or_default()
100                .as_millis() as u64,
101            total_size: 0,
102            buffers: Vec::new(),
103            weights_hash: String::new(),
104            checksum: String::new(),
105        }
106    }
107
108    /// Add buffer to snapshot
109    pub fn add_buffer(&mut self, name: impl Into<String>, size: u64, buffer_type: BufferType) {
110        let offset = self.total_size;
111        self.buffers.push(BufferSnapshot {
112            name: name.into(),
113            size,
114            offset,
115            compressed_size: None,
116            buffer_type,
117        });
118        self.total_size += size;
119    }
120
121    /// Get buffer by name
122    pub fn get_buffer(&self, name: &str) -> Option<&BufferSnapshot> {
123        self.buffers.iter().find(|b| b.name == name)
124    }
125
126    /// Get all buffers of a type
127    pub fn get_buffers_by_type(&self, buffer_type: BufferType) -> Vec<&BufferSnapshot> {
128        self.buffers
129            .iter()
130            .filter(|b| b.buffer_type == buffer_type)
131            .collect()
132    }
133
134    /// Compute checksum
135    pub fn compute_checksum(&mut self, data: &[u8]) {
136        let hash = Blake3::hash(data);
137        self.checksum = hash.iter().map(|b| format!("{:02x}", b)).collect();
138    }
139
140    /// Verify checksum
141    pub fn verify_checksum(&self, data: &[u8]) -> bool {
142        let hash = Blake3::hash(data);
143        let computed: String = hash.iter().map(|b| format!("{:02x}", b)).collect();
144        computed == self.checksum
145    }
146}
147
148/// Snapshot manager
149#[derive(Debug)]
150pub struct SnapshotManager {
151    /// Configuration
152    config: SnapshotConfig,
153    /// Cached snapshots
154    snapshots: HashMap<String, GpuSnapshot>,
155    /// Statistics
156    stats: SnapshotStats,
157}
158
/// Counters and timing averages accumulated by `SnapshotManager`.
///
/// `Clone` lets callers take a point-in-time copy of the borrowed stats;
/// `PartialEq` lets them compare copies.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct SnapshotStats {
    /// Total snapshots created
    pub created: u64,
    /// Total snapshots restored
    pub restored: u64,
    /// Total bytes saved (sum of created snapshots' `total_size`)
    pub bytes_saved: u64,
    /// Total bytes restored (sum of restored snapshots' `total_size`)
    pub bytes_restored: u64,
    /// Mean wall-clock time of `create_snapshot`, in milliseconds
    pub avg_save_ms: f64,
    /// Mean wall-clock time of `restore_snapshot`, in milliseconds
    pub avg_restore_ms: f64,
}
175
176impl SnapshotManager {
177    /// Create new manager
178    pub fn new(config: SnapshotConfig) -> Self {
179        Self {
180            config,
181            snapshots: HashMap::new(),
182            stats: SnapshotStats::default(),
183        }
184    }
185
186    /// Create a snapshot
187    pub async fn create_snapshot(
188        &mut self,
189        id: impl Into<String>,
190        buffers: Vec<(String, Vec<u8>, BufferType)>,
191    ) -> Result<GpuSnapshot> {
192        let start = Instant::now();
193        let id = id.into();
194
195        let mut snapshot = GpuSnapshot::new(&id);
196        let mut data = Vec::new();
197
198        for (name, buffer_data, buffer_type) in buffers {
199            snapshot.add_buffer(&name, buffer_data.len() as u64, buffer_type);
200
201            if self.config.compression {
202                // In real implementation, compress with zstd
203                data.extend_from_slice(&buffer_data);
204            } else {
205                data.extend_from_slice(&buffer_data);
206            }
207        }
208
209        if self.config.verify_checksum {
210            snapshot.compute_checksum(&data);
211        }
212
213        // Save to disk
214        self.save_to_disk(&snapshot, &data).await?;
215
216        self.snapshots.insert(id, snapshot.clone());
217        self.stats.created += 1;
218        self.stats.bytes_saved += snapshot.total_size;
219
220        let elapsed = start.elapsed().as_millis() as f64;
221        self.stats.avg_save_ms = (self.stats.avg_save_ms * (self.stats.created - 1) as f64
222            + elapsed)
223            / self.stats.created as f64;
224
225        Ok(snapshot)
226    }
227
228    /// Restore a snapshot
229    pub async fn restore_snapshot(&mut self, id: &str) -> Result<Vec<(String, Vec<u8>)>> {
230        let start = Instant::now();
231
232        // Try to load from cache
233        let snapshot = if let Some(s) = self.snapshots.get(id) {
234            s.clone()
235        } else {
236            // Load from disk
237            self.load_from_disk(id).await?
238        };
239
240        // Load data from disk
241        let data = self.load_data(id).await?;
242
243        // Verify checksum
244        if self.config.verify_checksum && !snapshot.verify_checksum(&data) {
245            return Err(ServerlessError::SnapshotError(
246                "Checksum verification failed".into(),
247            ));
248        }
249
250        // Extract buffers
251        let mut buffers = Vec::new();
252        for buffer in &snapshot.buffers {
253            let start = buffer.offset as usize;
254            let end = start + buffer.size as usize;
255            let buffer_data = data[start..end].to_vec();
256            buffers.push((buffer.name.clone(), buffer_data));
257        }
258
259        self.stats.restored += 1;
260        self.stats.bytes_restored += snapshot.total_size;
261
262        let elapsed = start.elapsed().as_millis() as f64;
263        self.stats.avg_restore_ms = (self.stats.avg_restore_ms * (self.stats.restored - 1) as f64
264            + elapsed)
265            / self.stats.restored as f64;
266
267        Ok(buffers)
268    }
269
270    /// Save snapshot to disk
271    async fn save_to_disk(&self, snapshot: &GpuSnapshot, data: &[u8]) -> Result<()> {
272        let dir = &self.config.snapshot_dir;
273        std::fs::create_dir_all(dir)?;
274
275        // Save metadata
276        let meta_path = dir.join(format!("{}.meta.json", snapshot.id));
277        let meta_json = serde_json::to_string_pretty(snapshot)
278            .map_err(|e| ServerlessError::SerializationError(e.to_string()))?;
279        std::fs::write(&meta_path, meta_json)?;
280
281        // Save data
282        let data_path = dir.join(format!("{}.data", snapshot.id));
283        std::fs::write(&data_path, data)?;
284
285        Ok(())
286    }
287
288    /// Load snapshot from disk
289    async fn load_from_disk(&mut self, id: &str) -> Result<GpuSnapshot> {
290        let meta_path = self.config.snapshot_dir.join(format!("{}.meta.json", id));
291        let meta_json = std::fs::read_to_string(&meta_path)?;
292        let snapshot: GpuSnapshot = serde_json::from_str(&meta_json)
293            .map_err(|e| ServerlessError::DeserializationError(e.to_string()))?;
294
295        self.snapshots.insert(id.to_string(), snapshot.clone());
296        Ok(snapshot)
297    }
298
299    /// Load data from disk
300    async fn load_data(&self, id: &str) -> Result<Vec<u8>> {
301        let data_path = self.config.snapshot_dir.join(format!("{}.data", id));
302        let data = std::fs::read(&data_path)?;
303        Ok(data)
304    }
305
306    /// List available snapshots
307    pub fn list_snapshots(&self) -> Vec<&str> {
308        self.snapshots.keys().map(|s| s.as_str()).collect()
309    }
310
311    /// Delete snapshot
312    pub fn delete_snapshot(&mut self, id: &str) -> Result<()> {
313        self.snapshots.remove(id);
314
315        let meta_path = self.config.snapshot_dir.join(format!("{}.meta.json", id));
316        let data_path = self.config.snapshot_dir.join(format!("{}.data", id));
317
318        if meta_path.exists() {
319            std::fs::remove_file(meta_path)?;
320        }
321        if data_path.exists() {
322            std::fs::remove_file(data_path)?;
323        }
324
325        Ok(())
326    }
327
328    /// Clear old snapshots
329    pub fn clear_old(&mut self) -> Result<usize> {
330        let now = std::time::SystemTime::now()
331            .duration_since(std::time::UNIX_EPOCH)
332            .unwrap_or_default()
333            .as_millis() as u64;
334
335        let max_age_ms = self.config.max_age_seconds * 1000;
336        let mut to_delete = Vec::new();
337
338        for (id, snapshot) in &self.snapshots {
339            if now - snapshot.created_at > max_age_ms {
340                to_delete.push(id.clone());
341            }
342        }
343
344        for id in &to_delete {
345            self.delete_snapshot(id)?;
346        }
347
348        Ok(to_delete.len())
349    }
350
351    /// Get statistics
352    pub fn stats(&self) -> &SnapshotStats {
353        &self.stats
354    }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a snapshot with the given `(name, size, type)` buffers appended in order.
    fn snapshot_with(id: &str, buffers: &[(&str, u64, BufferType)]) -> GpuSnapshot {
        let mut snapshot = GpuSnapshot::new(id);
        for &(name, size, buffer_type) in buffers {
            snapshot.add_buffer(name, size, buffer_type);
        }
        snapshot
    }

    #[test]
    fn test_config_default() {
        let config = SnapshotConfig::default();
        assert!(config.compression);
        assert!(config.verify_checksum);
        assert_eq!(config.compression_level, 3);
        assert_eq!(config.max_age_seconds, 3600);
    }

    #[test]
    fn test_snapshot_creation() {
        let snapshot = snapshot_with(
            "test-snapshot",
            &[
                ("weights", 1024, BufferType::Weights),
                ("kv_cache", 512, BufferType::KvCache),
            ],
        );

        assert_eq!(snapshot.buffers.len(), 2);
        assert_eq!(snapshot.total_size, 1536);
        assert_eq!(snapshot.version, 1);
    }

    #[test]
    fn test_buffer_offsets_are_contiguous() {
        let snapshot = snapshot_with(
            "offsets",
            &[
                ("a", 10, BufferType::Weights),
                ("b", 20, BufferType::KvCache),
                ("c", 5, BufferType::Other),
            ],
        );

        let offsets: Vec<u64> = snapshot.buffers.iter().map(|b| b.offset).collect();
        assert_eq!(offsets, vec![0, 10, 30]);
        assert_eq!(snapshot.total_size, 35);
        assert!(snapshot.buffers.iter().all(|b| b.compressed_size.is_none()));
    }

    #[test]
    fn test_buffer_lookup() {
        let snapshot = snapshot_with(
            "test",
            &[
                ("weights", 1024, BufferType::Weights),
                ("cache", 512, BufferType::KvCache),
            ],
        );

        assert!(snapshot.get_buffer("weights").is_some());
        assert_eq!(snapshot.get_buffer("cache").map(|b| b.size), Some(512));
        assert!(snapshot.get_buffer("nonexistent").is_none());

        assert_eq!(snapshot.get_buffers_by_type(BufferType::Weights).len(), 1);
        assert!(snapshot.get_buffers_by_type(BufferType::Gradients).is_empty());
    }

    #[test]
    fn test_checksum() {
        let mut snapshot = GpuSnapshot::new("test");
        let data = vec![1u8, 2, 3, 4, 5];

        // Nothing can verify before a checksum has been computed.
        assert!(!snapshot.verify_checksum(&data));

        snapshot.compute_checksum(&data);
        assert!(!snapshot.checksum.is_empty());
        assert!(snapshot.verify_checksum(&data));
        assert!(!snapshot.verify_checksum(&[1, 2, 3]));

        // Deterministic: recomputing yields the same digest.
        let first = snapshot.checksum.clone();
        snapshot.compute_checksum(&data);
        assert_eq!(snapshot.checksum, first);
    }

    #[test]
    fn test_manager_creation() {
        let manager = SnapshotManager::new(SnapshotConfig::default());

        assert!(manager.list_snapshots().is_empty());
        assert_eq!(manager.stats().created, 0);
        assert_eq!(manager.stats().restored, 0);
    }

    #[test]
    fn test_clear_old_on_empty_manager() {
        let mut manager = SnapshotManager::new(SnapshotConfig::default());
        assert!(matches!(manager.clear_old(), Ok(0)));
    }

    #[test]
    fn test_delete_missing_snapshot_is_ok() {
        let config = SnapshotConfig {
            snapshot_dir: std::env::temp_dir()
                .join(format!("haagenti-snapshot-test-{}", std::process::id())),
            ..SnapshotConfig::default()
        };
        let mut manager = SnapshotManager::new(config);

        assert!(manager.delete_snapshot("does-not-exist").is_ok());
    }
}