use crate::{Result, ServerlessError};
use arcanum_primitives::prelude::Blake3;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::time::Instant;
/// Configuration consumed by [`SnapshotManager`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotConfig {
/// Directory that holds the `<id>.meta.json` and `<id>.data` files of each snapshot.
pub snapshot_dir: PathBuf,
/// Whether snapshot data should be compressed.
/// NOTE(review): no compression step is visible in this file — confirm where this is applied.
pub compression: bool,
/// Compression level.
/// NOTE(review): not read anywhere in this file — presumably forwarded to a compressor; confirm.
pub compression_level: i32,
/// Whether incremental snapshots are requested.
/// NOTE(review): not read anywhere in this file — confirm intended semantics.
pub incremental: bool,
/// Age limit in seconds; `SnapshotManager::clear_old` deletes snapshots older than this.
pub max_age_seconds: u64,
/// When true, a checksum is computed on creation and verified on restore.
pub verify_checksum: bool,
}
impl Default for SnapshotConfig {
    /// Returns the stock configuration: snapshots under `/tmp/haagenti-snapshots`,
    /// compression enabled at level 3, incremental mode on, a one-hour maximum
    /// age, and checksum verification enabled.
    fn default() -> Self {
        let snapshot_dir = PathBuf::from("/tmp/haagenti-snapshots");
        // One hour, expressed in seconds.
        let max_age_seconds: u64 = 3600;

        Self {
            snapshot_dir,
            compression: true,
            compression_level: 3,
            incremental: true,
            max_age_seconds,
            verify_checksum: true,
        }
    }
}
/// Metadata describing one snapshot: its identity, the layout of its buffers,
/// and a checksum over the concatenated buffer data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuSnapshot {
/// Caller-supplied identifier; also used to build the on-disk file names.
pub id: String,
/// Format version; `GpuSnapshot::new` sets this to 1.
pub version: u32,
/// Creation time in milliseconds since the Unix epoch.
pub created_at: u64,
/// Sum of the sizes of all buffers added via `add_buffer`.
pub total_size: u64,
/// Buffer descriptors, in the order they were added.
pub buffers: Vec<BufferSnapshot>,
/// Initialised to an empty string by `new`; nothing in this file writes it afterwards.
/// NOTE(review): presumably a hash of the model weights — confirm against callers.
pub weights_hash: String,
/// Lowercase hex encoding of the Blake3 hash set by `compute_checksum`;
/// empty until that method is called.
pub checksum: String,
}
/// Descriptor of a single buffer inside a snapshot's concatenated data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BufferSnapshot {
/// Buffer name; `GpuSnapshot::get_buffer` looks buffers up by this value.
pub name: String,
/// Size of the buffer as passed to `GpuSnapshot::add_buffer`.
pub size: u64,
/// Value of the snapshot's `total_size` at the time this buffer was added,
/// i.e. the sum of the sizes of all buffers added before it.
pub offset: u64,
/// Always `None` when created through `GpuSnapshot::add_buffer`.
/// NOTE(review): presumably the post-compression size — nothing in this file sets it.
pub compressed_size: Option<u64>,
/// Category tag for this buffer.
pub buffer_type: BufferType,
}
/// Category tag attached to each [`BufferSnapshot`]; used by
/// `GpuSnapshot::get_buffers_by_type` to filter buffers.
///
/// NOTE(review): this file does not attach any behavior to individual variants;
/// their meaning is implied by name only.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BufferType {
Weights,
KvCache,
Activations,
Gradients,
OptimizerState,
Other,
}
impl GpuSnapshot {
    /// Creates an empty snapshot named `id`.
    ///
    /// The version is 1, the creation time is the current wall-clock time in
    /// milliseconds since the Unix epoch (0 if the clock reports a time before
    /// the epoch), and there are no buffers, no weights hash and no checksum.
    pub fn new(id: impl Into<String>) -> Self {
        let id = id.into();
        let since_epoch = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default();

        Self {
            id,
            version: 1,
            created_at: since_epoch.as_millis() as u64,
            total_size: 0,
            buffers: Vec::new(),
            weights_hash: String::new(),
            checksum: String::new(),
        }
    }

    /// Appends a buffer descriptor and grows `total_size` by `size`.
    ///
    /// The new buffer's offset is the snapshot's `total_size` before the call.
    pub fn add_buffer(&mut self, name: impl Into<String>, size: u64, buffer_type: BufferType) {
        let descriptor = BufferSnapshot {
            name: name.into(),
            size,
            offset: self.total_size,
            compressed_size: None,
            buffer_type,
        };
        self.buffers.push(descriptor);
        self.total_size += size;
    }

    /// Returns the first buffer whose name equals `name`, if any.
    pub fn get_buffer(&self, name: &str) -> Option<&BufferSnapshot> {
        for candidate in &self.buffers {
            if candidate.name == name {
                return Some(candidate);
            }
        }
        None
    }

    /// Returns every buffer tagged with `buffer_type`, in insertion order.
    pub fn get_buffers_by_type(&self, buffer_type: BufferType) -> Vec<&BufferSnapshot> {
        let mut matching = Vec::new();
        for candidate in &self.buffers {
            if candidate.buffer_type == buffer_type {
                matching.push(candidate);
            }
        }
        matching
    }

    /// Stores the lowercase hex Blake3 digest of `data` in `self.checksum`.
    pub fn compute_checksum(&mut self, data: &[u8]) {
        self.checksum = Self::hex_digest(data);
    }

    /// Returns true when the hex Blake3 digest of `data` equals `self.checksum`.
    pub fn verify_checksum(&self, data: &[u8]) -> bool {
        Self::hex_digest(data) == self.checksum
    }

    /// Hashes `data` with Blake3 and renders each digest byte as two lowercase hex digits.
    fn hex_digest(data: &[u8]) -> String {
        let digest = Blake3::hash(data);
        digest.iter().map(|byte| format!("{:02x}", byte)).collect()
    }
}
/// Creates, restores, lists and deletes snapshots stored under
/// `config.snapshot_dir`, keeping an in-memory index of snapshot metadata.
#[derive(Debug)]
pub struct SnapshotManager {
// Settings supplied at construction time.
config: SnapshotConfig,
// Metadata of snapshots created or loaded through this manager, keyed by snapshot id.
snapshots: HashMap<String, GpuSnapshot>,
// Running counters and averages, exposed through `stats()`.
stats: SnapshotStats,
}
/// Counters and running averages maintained by [`SnapshotManager`].
#[derive(Debug, Default)]
pub struct SnapshotStats {
/// Number of snapshots successfully created.
pub created: u64,
/// Number of snapshots successfully restored.
pub restored: u64,
/// Sum of `total_size` over all created snapshots.
pub bytes_saved: u64,
/// Sum of `total_size` over all restored snapshots.
pub bytes_restored: u64,
/// Running mean of the wall-clock time of `create_snapshot`, in milliseconds.
pub avg_save_ms: f64,
/// Running mean of the wall-clock time of `restore_snapshot`, in milliseconds.
pub avg_restore_ms: f64,
}
impl SnapshotManager {
pub fn new(config: SnapshotConfig) -> Self {
Self {
config,
snapshots: HashMap::new(),
stats: SnapshotStats::default(),
}
}
pub async fn create_snapshot(
&mut self,
id: impl Into<String>,
buffers: Vec<(String, Vec<u8>, BufferType)>,
) -> Result<GpuSnapshot> {
let start = Instant::now();
let id = id.into();
let mut snapshot = GpuSnapshot::new(&id);
let mut data = Vec::new();
for (name, buffer_data, buffer_type) in buffers {
snapshot.add_buffer(&name, buffer_data.len() as u64, buffer_type);
if self.config.compression {
data.extend_from_slice(&buffer_data);
} else {
data.extend_from_slice(&buffer_data);
}
}
if self.config.verify_checksum {
snapshot.compute_checksum(&data);
}
self.save_to_disk(&snapshot, &data).await?;
self.snapshots.insert(id, snapshot.clone());
self.stats.created += 1;
self.stats.bytes_saved += snapshot.total_size;
let elapsed = start.elapsed().as_millis() as f64;
self.stats.avg_save_ms = (self.stats.avg_save_ms * (self.stats.created - 1) as f64
+ elapsed)
/ self.stats.created as f64;
Ok(snapshot)
}
pub async fn restore_snapshot(&mut self, id: &str) -> Result<Vec<(String, Vec<u8>)>> {
let start = Instant::now();
let snapshot = if let Some(s) = self.snapshots.get(id) {
s.clone()
} else {
self.load_from_disk(id).await?
};
let data = self.load_data(id).await?;
if self.config.verify_checksum && !snapshot.verify_checksum(&data) {
return Err(ServerlessError::SnapshotError(
"Checksum verification failed".into(),
));
}
let mut buffers = Vec::new();
for buffer in &snapshot.buffers {
let start = buffer.offset as usize;
let end = start + buffer.size as usize;
let buffer_data = data[start..end].to_vec();
buffers.push((buffer.name.clone(), buffer_data));
}
self.stats.restored += 1;
self.stats.bytes_restored += snapshot.total_size;
let elapsed = start.elapsed().as_millis() as f64;
self.stats.avg_restore_ms = (self.stats.avg_restore_ms * (self.stats.restored - 1) as f64
+ elapsed)
/ self.stats.restored as f64;
Ok(buffers)
}
async fn save_to_disk(&self, snapshot: &GpuSnapshot, data: &[u8]) -> Result<()> {
let dir = &self.config.snapshot_dir;
std::fs::create_dir_all(dir)?;
let meta_path = dir.join(format!("{}.meta.json", snapshot.id));
let meta_json = serde_json::to_string_pretty(snapshot)
.map_err(|e| ServerlessError::SerializationError(e.to_string()))?;
std::fs::write(&meta_path, meta_json)?;
let data_path = dir.join(format!("{}.data", snapshot.id));
std::fs::write(&data_path, data)?;
Ok(())
}
async fn load_from_disk(&mut self, id: &str) -> Result<GpuSnapshot> {
let meta_path = self.config.snapshot_dir.join(format!("{}.meta.json", id));
let meta_json = std::fs::read_to_string(&meta_path)?;
let snapshot: GpuSnapshot = serde_json::from_str(&meta_json)
.map_err(|e| ServerlessError::DeserializationError(e.to_string()))?;
self.snapshots.insert(id.to_string(), snapshot.clone());
Ok(snapshot)
}
async fn load_data(&self, id: &str) -> Result<Vec<u8>> {
let data_path = self.config.snapshot_dir.join(format!("{}.data", id));
let data = std::fs::read(&data_path)?;
Ok(data)
}
pub fn list_snapshots(&self) -> Vec<&str> {
self.snapshots.keys().map(|s| s.as_str()).collect()
}
pub fn delete_snapshot(&mut self, id: &str) -> Result<()> {
self.snapshots.remove(id);
let meta_path = self.config.snapshot_dir.join(format!("{}.meta.json", id));
let data_path = self.config.snapshot_dir.join(format!("{}.data", id));
if meta_path.exists() {
std::fs::remove_file(meta_path)?;
}
if data_path.exists() {
std::fs::remove_file(data_path)?;
}
Ok(())
}
pub fn clear_old(&mut self) -> Result<usize> {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64;
let max_age_ms = self.config.max_age_seconds * 1000;
let mut to_delete = Vec::new();
for (id, snapshot) in &self.snapshots {
if now - snapshot.created_at > max_age_ms {
to_delete.push(id.clone());
}
}
for id in &to_delete {
self.delete_snapshot(id)?;
}
Ok(to_delete.len())
}
pub fn stats(&self) -> &SnapshotStats {
&self.stats
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration enables both compression and checksum verification.
    #[test]
    fn test_config_default() {
        let defaults = SnapshotConfig::default();
        assert!(defaults.compression);
        assert!(defaults.verify_checksum);
    }

    /// Adding buffers records each descriptor and accumulates the total size.
    #[test]
    fn test_snapshot_creation() {
        let mut snap = GpuSnapshot::new("test-snapshot");

        snap.add_buffer("weights", 1024, BufferType::Weights);
        snap.add_buffer("kv_cache", 512, BufferType::KvCache);

        assert_eq!(snap.buffers.len(), 2);
        assert_eq!(snap.total_size, 1536);
    }

    /// Buffers can be found by name and filtered by type.
    #[test]
    fn test_buffer_lookup() {
        let mut snap = GpuSnapshot::new("test");
        snap.add_buffer("weights", 1024, BufferType::Weights);
        snap.add_buffer("cache", 512, BufferType::KvCache);

        assert!(snap.get_buffer("weights").is_some());
        assert!(snap.get_buffer("nonexistent").is_none());

        let weight_buffers = snap.get_buffers_by_type(BufferType::Weights);
        assert_eq!(weight_buffers.len(), 1);
    }

    /// A computed checksum verifies against the same data and rejects different data.
    #[test]
    fn test_checksum() {
        let payload: Vec<u8> = vec![1, 2, 3, 4, 5];
        let mut snap = GpuSnapshot::new("test");

        snap.compute_checksum(&payload);

        assert!(!snap.checksum.is_empty());
        assert!(snap.verify_checksum(&payload));
        assert!(!snap.verify_checksum(&[1, 2, 3]));
    }

    /// A freshly constructed manager has no snapshots in its index.
    #[test]
    fn test_manager_creation() {
        let manager = SnapshotManager::new(SnapshotConfig::default());
        assert!(manager.list_snapshots().is_empty());
    }
}