use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use tokio::fs;
use uuid::Uuid;
use crate::provider::{FinishReason, Message, TokenUsage};
use crate::runtime::error::{RuntimeError, RuntimeResult};
use crate::runtime::run::{RunId, RunRequest, RunStatus};
use crate::runtime::turn::TurnState;
/// Serializable checkpoint of a single run, persisted through a [`SnapshotStore`].
///
/// NOTE(review): the per-field descriptions below are partly inferred from names and
/// types; confirm them against the runtime code that builds snapshots. Serde's derive
/// serializes fields in declaration order, so reordering fields changes the JSON layout.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Snapshot {
    /// Run this snapshot belongs to; the file-backed store names its files after it.
    pub run_id: RunId,
    /// Presumably the session the run executes in — TODO confirm.
    pub session_id: Uuid,
    /// Status of the run at the time the snapshot was taken.
    pub status: RunStatus,
    /// Presumably the loop iteration counter of the run — TODO confirm.
    pub iteration: usize,
    /// Turn state the run was in when snapshotted.
    pub current_state: TurnState,
    /// Token usage accumulated so far (assumed cumulative across iterations — confirm).
    pub total_usage: TokenUsage,
    /// Finish reason reported by the last provider response, if any.
    pub last_finish: Option<FinishReason>,
    /// Assistant message being built, if any (assumed in-progress — confirm).
    pub assistant_message: Option<Message>,
    /// Id associated with `assistant_message`, if any.
    pub assistant_msg_id: Option<Uuid>,
    /// Original request that started the run.
    pub request: RunRequest,
    /// Counter of output recoveries. `#[serde(default)]` lets snapshots written before
    /// this field existed deserialize, with the value defaulting to `0`.
    #[serde(default)]
    pub output_recovery_count: u32,
    /// Timestamp (UTC) recorded in the snapshot; presumably when it was taken — TODO confirm.
    pub timestamp: DateTime<Utc>,
}
/// Async persistence backend for [`Snapshot`]s, addressed by [`RunId`].
///
/// Implementors must be `Send + Sync + 'static`, presumably so a single store can be
/// shared between tasks (e.g. behind an `Arc`) — confirm against callers.
#[async_trait]
pub trait SnapshotStore: Send + Sync + 'static {
    /// Persists `snapshot` under its `run_id`.
    ///
    /// NOTE(review): whether an existing snapshot for the same run is replaced is
    /// implementation-defined by this signature.
    async fn save(&self, snapshot: &Snapshot) -> RuntimeResult<()>;
    /// Returns the snapshot stored for `run_id`, or `Ok(None)` if there is none.
    async fn load(&self, run_id: RunId) -> RuntimeResult<Option<Snapshot>>;
    /// Removes the snapshot stored for `run_id`.
    ///
    /// NOTE(review): whether deleting a missing snapshot is an error is up to the
    /// implementation.
    async fn delete(&self, run_id: RunId) -> RuntimeResult<()>;
    /// Returns the stored snapshots.
    ///
    /// NOTE(review): no ordering is promised by this signature.
    async fn list(&self) -> RuntimeResult<Vec<Snapshot>>;
}
/// [`SnapshotStore`] that keeps each snapshot as a JSON file named
/// `snapshot_<run_id>.json` inside `base_dir`.
pub struct FileSnapshotStore {
    /// Directory holding the snapshot files; `save` creates it if it does not exist.
    base_dir: PathBuf,
}
impl FileSnapshotStore {
#[must_use]
pub fn new(base_dir: PathBuf) -> Self {
Self { base_dir }
}
fn path_for(&self, run_id: RunId) -> PathBuf {
self.base_dir.join(format!("snapshot_{run_id}.json"))
}
}
#[async_trait]
impl SnapshotStore for FileSnapshotStore {
async fn save(&self, snapshot: &Snapshot) -> RuntimeResult<()> {
if let Err(e) = fs::create_dir_all(&self.base_dir).await {
return Err(RuntimeError::RecoveryFailed(format!(
"failed to create snapshot dir: {e}"
)));
}
let serialized = serde_json::to_string_pretty(snapshot).map_err(|e| {
RuntimeError::RecoveryFailed(format!("failed to serialize snapshot: {e}"))
})?;
let temp_path = self
.base_dir
.join(format!("snapshot_{}.json.tmp", snapshot.run_id));
if let Err(e) = fs::write(&temp_path, serialized).await {
return Err(RuntimeError::RecoveryFailed(format!(
"failed to write snapshot: {e}"
)));
}
let final_path = self.path_for(snapshot.run_id);
fs::rename(temp_path, final_path).await.map_err(|e| {
RuntimeError::RecoveryFailed(format!("failed to finalize snapshot: {e}"))
})?;
Ok(())
}
async fn load(&self, run_id: RunId) -> RuntimeResult<Option<Snapshot>> {
let path = self.path_for(run_id);
if !path.exists() {
return Ok(None);
}
let content = fs::read_to_string(&path).await.map_err(|e| {
RuntimeError::RecoveryFailed(format!("failed to read snapshot file: {e}"))
})?;
let snapshot: Snapshot = serde_json::from_str(&content).map_err(|e| {
RuntimeError::RecoveryFailed(format!("failed to deserialize snapshot file: {e}"))
})?;
Ok(Some(snapshot))
}
async fn delete(&self, run_id: RunId) -> RuntimeResult<()> {
let path = self.path_for(run_id);
if path.exists() {
fs::remove_file(path).await.map_err(|e| {
RuntimeError::RecoveryFailed(format!("failed to delete snapshot file: {e}"))
})?;
}
Ok(())
}
async fn list(&self) -> RuntimeResult<Vec<Snapshot>> {
if !self.base_dir.exists() {
return Ok(Vec::new());
}
let mut snapshots = Vec::new();
let mut entries = fs::read_dir(&self.base_dir).await.map_err(|e| {
RuntimeError::RecoveryFailed(format!("failed to read snapshot dir: {e}"))
})?;
while let Some(entry) = entries.next_entry().await.map_err(|e| {
RuntimeError::RecoveryFailed(format!("failed to read snapshot directory entry: {e}"))
})? {
let path = entry.path();
if path.is_file() {
if let Some(ext) = path.extension() {
if ext == "json" {
if let Ok(content) = fs::read_to_string(&path).await {
if let Ok(snapshot) = serde_json::from_str::<Snapshot>(&content) {
snapshots.push(snapshot);
}
}
}
}
}
}
Ok(snapshots)
}
}