1use crate::{Result, ServerlessError};
4use arcanum_primitives::prelude::Blake3;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::PathBuf;
8use std::time::Instant;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct SnapshotConfig {
13 pub snapshot_dir: PathBuf,
15 pub compression: bool,
17 pub compression_level: i32,
19 pub incremental: bool,
21 pub max_age_seconds: u64,
23 pub verify_checksum: bool,
25}
26
27impl Default for SnapshotConfig {
28 fn default() -> Self {
29 Self {
30 snapshot_dir: PathBuf::from("/tmp/haagenti-snapshots"),
31 compression: true,
32 compression_level: 3,
33 incremental: true,
34 max_age_seconds: 3600, verify_checksum: true,
36 }
37 }
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct GpuSnapshot {
43 pub id: String,
45 pub version: u32,
47 pub created_at: u64,
49 pub total_size: u64,
51 pub buffers: Vec<BufferSnapshot>,
53 pub weights_hash: String,
55 pub checksum: String,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct BufferSnapshot {
62 pub name: String,
64 pub size: u64,
66 pub offset: u64,
68 pub compressed_size: Option<u64>,
70 pub buffer_type: BufferType,
72}
73
74#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
76pub enum BufferType {
77 Weights,
79 KvCache,
81 Activations,
83 Gradients,
85 OptimizerState,
87 Other,
89}
90
91impl GpuSnapshot {
92 pub fn new(id: impl Into<String>) -> Self {
94 Self {
95 id: id.into(),
96 version: 1,
97 created_at: std::time::SystemTime::now()
98 .duration_since(std::time::UNIX_EPOCH)
99 .unwrap_or_default()
100 .as_millis() as u64,
101 total_size: 0,
102 buffers: Vec::new(),
103 weights_hash: String::new(),
104 checksum: String::new(),
105 }
106 }
107
108 pub fn add_buffer(&mut self, name: impl Into<String>, size: u64, buffer_type: BufferType) {
110 let offset = self.total_size;
111 self.buffers.push(BufferSnapshot {
112 name: name.into(),
113 size,
114 offset,
115 compressed_size: None,
116 buffer_type,
117 });
118 self.total_size += size;
119 }
120
121 pub fn get_buffer(&self, name: &str) -> Option<&BufferSnapshot> {
123 self.buffers.iter().find(|b| b.name == name)
124 }
125
126 pub fn get_buffers_by_type(&self, buffer_type: BufferType) -> Vec<&BufferSnapshot> {
128 self.buffers
129 .iter()
130 .filter(|b| b.buffer_type == buffer_type)
131 .collect()
132 }
133
134 pub fn compute_checksum(&mut self, data: &[u8]) {
136 let hash = Blake3::hash(data);
137 self.checksum = hash.iter().map(|b| format!("{:02x}", b)).collect();
138 }
139
140 pub fn verify_checksum(&self, data: &[u8]) -> bool {
142 let hash = Blake3::hash(data);
143 let computed: String = hash.iter().map(|b| format!("{:02x}", b)).collect();
144 computed == self.checksum
145 }
146}
147
148#[derive(Debug)]
150pub struct SnapshotManager {
151 config: SnapshotConfig,
153 snapshots: HashMap<String, GpuSnapshot>,
155 stats: SnapshotStats,
157}
158
/// Counters and running averages collected by a `SnapshotManager`.
#[derive(Default, Debug)]
pub struct SnapshotStats {
    /// Number of successful `create_snapshot` calls.
    pub created: u64,
    /// Number of successful `restore_snapshot` calls.
    pub restored: u64,
    /// Total uncompressed bytes (`total_size`) of all created snapshots.
    pub bytes_saved: u64,
    /// Total uncompressed bytes (`total_size`) of all restored snapshots.
    pub bytes_restored: u64,
    /// Running mean duration of `create_snapshot`, in milliseconds.
    pub avg_save_ms: f64,
    /// Running mean duration of `restore_snapshot`, in milliseconds.
    pub avg_restore_ms: f64,
}
175
176impl SnapshotManager {
177 pub fn new(config: SnapshotConfig) -> Self {
179 Self {
180 config,
181 snapshots: HashMap::new(),
182 stats: SnapshotStats::default(),
183 }
184 }
185
186 pub async fn create_snapshot(
188 &mut self,
189 id: impl Into<String>,
190 buffers: Vec<(String, Vec<u8>, BufferType)>,
191 ) -> Result<GpuSnapshot> {
192 let start = Instant::now();
193 let id = id.into();
194
195 let mut snapshot = GpuSnapshot::new(&id);
196 let mut data = Vec::new();
197
198 for (name, buffer_data, buffer_type) in buffers {
199 snapshot.add_buffer(&name, buffer_data.len() as u64, buffer_type);
200
201 if self.config.compression {
202 data.extend_from_slice(&buffer_data);
204 } else {
205 data.extend_from_slice(&buffer_data);
206 }
207 }
208
209 if self.config.verify_checksum {
210 snapshot.compute_checksum(&data);
211 }
212
213 self.save_to_disk(&snapshot, &data).await?;
215
216 self.snapshots.insert(id, snapshot.clone());
217 self.stats.created += 1;
218 self.stats.bytes_saved += snapshot.total_size;
219
220 let elapsed = start.elapsed().as_millis() as f64;
221 self.stats.avg_save_ms = (self.stats.avg_save_ms * (self.stats.created - 1) as f64
222 + elapsed)
223 / self.stats.created as f64;
224
225 Ok(snapshot)
226 }
227
228 pub async fn restore_snapshot(&mut self, id: &str) -> Result<Vec<(String, Vec<u8>)>> {
230 let start = Instant::now();
231
232 let snapshot = if let Some(s) = self.snapshots.get(id) {
234 s.clone()
235 } else {
236 self.load_from_disk(id).await?
238 };
239
240 let data = self.load_data(id).await?;
242
243 if self.config.verify_checksum && !snapshot.verify_checksum(&data) {
245 return Err(ServerlessError::SnapshotError(
246 "Checksum verification failed".into(),
247 ));
248 }
249
250 let mut buffers = Vec::new();
252 for buffer in &snapshot.buffers {
253 let start = buffer.offset as usize;
254 let end = start + buffer.size as usize;
255 let buffer_data = data[start..end].to_vec();
256 buffers.push((buffer.name.clone(), buffer_data));
257 }
258
259 self.stats.restored += 1;
260 self.stats.bytes_restored += snapshot.total_size;
261
262 let elapsed = start.elapsed().as_millis() as f64;
263 self.stats.avg_restore_ms = (self.stats.avg_restore_ms * (self.stats.restored - 1) as f64
264 + elapsed)
265 / self.stats.restored as f64;
266
267 Ok(buffers)
268 }
269
270 async fn save_to_disk(&self, snapshot: &GpuSnapshot, data: &[u8]) -> Result<()> {
272 let dir = &self.config.snapshot_dir;
273 std::fs::create_dir_all(dir)?;
274
275 let meta_path = dir.join(format!("{}.meta.json", snapshot.id));
277 let meta_json = serde_json::to_string_pretty(snapshot)
278 .map_err(|e| ServerlessError::SerializationError(e.to_string()))?;
279 std::fs::write(&meta_path, meta_json)?;
280
281 let data_path = dir.join(format!("{}.data", snapshot.id));
283 std::fs::write(&data_path, data)?;
284
285 Ok(())
286 }
287
288 async fn load_from_disk(&mut self, id: &str) -> Result<GpuSnapshot> {
290 let meta_path = self.config.snapshot_dir.join(format!("{}.meta.json", id));
291 let meta_json = std::fs::read_to_string(&meta_path)?;
292 let snapshot: GpuSnapshot = serde_json::from_str(&meta_json)
293 .map_err(|e| ServerlessError::DeserializationError(e.to_string()))?;
294
295 self.snapshots.insert(id.to_string(), snapshot.clone());
296 Ok(snapshot)
297 }
298
299 async fn load_data(&self, id: &str) -> Result<Vec<u8>> {
301 let data_path = self.config.snapshot_dir.join(format!("{}.data", id));
302 let data = std::fs::read(&data_path)?;
303 Ok(data)
304 }
305
306 pub fn list_snapshots(&self) -> Vec<&str> {
308 self.snapshots.keys().map(|s| s.as_str()).collect()
309 }
310
311 pub fn delete_snapshot(&mut self, id: &str) -> Result<()> {
313 self.snapshots.remove(id);
314
315 let meta_path = self.config.snapshot_dir.join(format!("{}.meta.json", id));
316 let data_path = self.config.snapshot_dir.join(format!("{}.data", id));
317
318 if meta_path.exists() {
319 std::fs::remove_file(meta_path)?;
320 }
321 if data_path.exists() {
322 std::fs::remove_file(data_path)?;
323 }
324
325 Ok(())
326 }
327
328 pub fn clear_old(&mut self) -> Result<usize> {
330 let now = std::time::SystemTime::now()
331 .duration_since(std::time::UNIX_EPOCH)
332 .unwrap_or_default()
333 .as_millis() as u64;
334
335 let max_age_ms = self.config.max_age_seconds * 1000;
336 let mut to_delete = Vec::new();
337
338 for (id, snapshot) in &self.snapshots {
339 if now - snapshot.created_at > max_age_ms {
340 to_delete.push(id.clone());
341 }
342 }
343
344 for id in &to_delete {
345 self.delete_snapshot(id)?;
346 }
347
348 Ok(to_delete.len())
349 }
350
351 pub fn stats(&self) -> &SnapshotStats {
353 &self.stats
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration enables compression and checksum verification.
    #[test]
    fn test_config_default() {
        let SnapshotConfig {
            compression,
            verify_checksum,
            ..
        } = SnapshotConfig::default();
        assert!(compression);
        assert!(verify_checksum);
    }

    /// Adding buffers records each entry and accumulates the total size.
    #[test]
    fn test_snapshot_creation() {
        let mut snapshot = GpuSnapshot::new("test-snapshot");
        let layout = [
            ("weights", 1024, BufferType::Weights),
            ("kv_cache", 512, BufferType::KvCache),
        ];
        for (name, size, kind) in layout {
            snapshot.add_buffer(name, size, kind);
        }

        assert_eq!(snapshot.buffers.len(), 2);
        assert_eq!(snapshot.total_size, 1536);
    }

    /// Buffers can be found by name and filtered by type.
    #[test]
    fn test_buffer_lookup() {
        let mut snapshot = GpuSnapshot::new("test");
        snapshot.add_buffer("weights", 1024, BufferType::Weights);
        snapshot.add_buffer("cache", 512, BufferType::KvCache);

        let found = snapshot.get_buffer("weights");
        let missing = snapshot.get_buffer("nonexistent");
        assert!(found.is_some());
        assert!(missing.is_none());

        let weight_buffers = snapshot.get_buffers_by_type(BufferType::Weights);
        assert_eq!(weight_buffers.len(), 1);
    }

    /// A computed checksum verifies the same bytes and rejects different ones.
    #[test]
    fn test_checksum() {
        let payload: Vec<u8> = vec![1, 2, 3, 4, 5];
        let mut snapshot = GpuSnapshot::new("test");

        snapshot.compute_checksum(&payload);

        assert!(!snapshot.checksum.is_empty());
        assert!(snapshot.verify_checksum(&payload));
        assert!(!snapshot.verify_checksum(&payload[..3]));
    }

    /// A freshly constructed manager knows no snapshots.
    #[test]
    fn test_manager_creation() {
        let manager = SnapshotManager::new(SnapshotConfig::default());
        assert_eq!(manager.list_snapshots().len(), 0);
    }
}
411}