use crate::llm::{Message, TokenUsage};
use crate::verification::VerificationReport;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// Schema version stamped into newly built [`LoopCheckpoint`]s.
///
/// Payloads that predate the `schema_version` field deserialize with version `0`
/// (via `#[serde(default)]`), so any value below this constant marks an older format.
pub const LOOP_CHECKPOINT_SCHEMA_VERSION: u32 = 1;
/// Serializable snapshot of an agent loop's progress for a single run.
///
/// Field declaration order and the serde attributes below define the persisted JSON
/// format; do not reorder or rename fields without bumping
/// [`LOOP_CHECKPOINT_SCHEMA_VERSION`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoopCheckpoint {
/// Format version of this checkpoint. Missing in older payloads, in which case
/// `#[serde(default)]` yields `0`.
#[serde(default)]
pub schema_version: u32,
/// Identifier of the run this checkpoint belongs to; also the key used by
/// `SessionStore::save_loop_checkpoint` / `load_loop_checkpoint`.
pub run_id: String,
/// Identifier of the session the run is part of.
pub session_id: String,
/// Loop turn counter at the time of the snapshot
/// (NOTE(review): whether this is 0- or 1-based is not visible here — confirm with the writer).
pub turn: usize,
/// Conversation history captured at this point.
pub messages: Vec<Message>,
/// Token usage accumulated so far.
pub total_usage: TokenUsage,
/// Number of tool calls made so far.
pub tool_calls_count: usize,
/// Verification reports collected so far; absent in older payloads, in which case
/// `#[serde(default)]` yields an empty list.
#[serde(default)]
pub verification_reports: Vec<VerificationReport>,
/// Time the checkpoint was taken; presumably Unix epoch milliseconds
/// (the test fixture uses `1_700_000_000_000`) — TODO confirm against the writer.
pub checkpoint_ms: u64,
}
/// Destination that loop checkpoints are written to and restored from.
///
/// The trait requires `Send + Sync`, so a sink can be shared across threads
/// (e.g. behind an `Arc`).
#[async_trait]
pub trait LoopCheckpointSink: Send + Sync {
/// Persists `checkpoint`. The method returns `()`, so implementations have no way to
/// report failure to the caller and must handle (e.g. log) errors themselves.
async fn save_checkpoint(&self, checkpoint: &LoopCheckpoint);
/// Returns the stored checkpoint for `run_id` (presumably the most recently saved one),
/// or `None` if there is none or it could not be loaded.
async fn load_latest(&self, run_id: &str) -> Option<LoopCheckpoint>;
}
/// [`LoopCheckpointSink`] that delegates persistence to a shared
/// `crate::store::SessionStore`.
pub struct SessionStoreCheckpointSink {
/// Backing store that checkpoint saves and loads are forwarded to.
inner: std::sync::Arc<dyn crate::store::SessionStore>,
}
impl SessionStoreCheckpointSink {
    /// Creates a sink that reads and writes checkpoints through `store`.
    ///
    /// The `Arc` is held as-is, so the same store may keep being used elsewhere.
    pub fn new(store: std::sync::Arc<dyn crate::store::SessionStore>) -> Self {
        let inner = store;
        Self { inner }
    }
}
#[async_trait]
impl LoopCheckpointSink for SessionStoreCheckpointSink {
    /// Forwards the checkpoint to the store, keyed by its `run_id`.
    ///
    /// This is best-effort: a store error is logged at `warn` level and otherwise
    /// ignored, so a failed save never interrupts the live run.
    async fn save_checkpoint(&self, checkpoint: &LoopCheckpoint) {
        let run_id = &checkpoint.run_id;
        let outcome = self.inner.save_loop_checkpoint(run_id, checkpoint).await;
        match outcome {
            Ok(_) => {}
            Err(e) => {
                tracing::warn!(
                    run_id = %run_id,
                    error = %e,
                    "Loop checkpoint save failed; live run continues"
                );
            }
        }
    }

    /// Fetches the stored checkpoint for `run_id` from the store.
    ///
    /// A store error is logged at `warn` level and reported as `None`, so callers
    /// cannot tell "no checkpoint" apart from "load failed".
    async fn load_latest(&self, run_id: &str) -> Option<LoopCheckpoint> {
        self.inner
            .load_loop_checkpoint(run_id)
            .await
            .unwrap_or_else(|e| {
                tracing::warn!(
                    run_id = %run_id,
                    error = %e,
                    "Loop checkpoint load failed"
                );
                None
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal current-version checkpoint for `run_id` at `turn`,
    /// holding a single user message.
    fn sample(run_id: &str, turn: usize) -> LoopCheckpoint {
        LoopCheckpoint {
            schema_version: LOOP_CHECKPOINT_SCHEMA_VERSION,
            run_id: run_id.to_string(),
            session_id: "session-1".to_string(),
            turn,
            messages: vec![Message::user("hi")],
            total_usage: TokenUsage::default(),
            tool_calls_count: 0,
            verification_reports: Vec::new(),
            checkpoint_ms: 1_700_000_000_000,
        }
    }

    /// Every field must survive a serialize/deserialize cycle, not just a few.
    #[test]
    fn checkpoint_round_trips_through_json() {
        let cp = sample("run-1", 3);
        let json = serde_json::to_string(&cp).unwrap();
        let back: LoopCheckpoint = serde_json::from_str(&json).unwrap();
        assert_eq!(back.schema_version, LOOP_CHECKPOINT_SCHEMA_VERSION);
        assert_eq!(back.run_id, "run-1");
        assert_eq!(back.session_id, "session-1");
        assert_eq!(back.turn, 3);
        assert_eq!(back.messages.len(), 1);
        assert_eq!(back.tool_calls_count, 0);
        assert!(back.verification_reports.is_empty());
        assert_eq!(back.checkpoint_ms, 1_700_000_000_000);
    }

    /// A legacy payload without the `#[serde(default)]` fields must still load,
    /// with those fields taking their `Default` values.
    #[test]
    fn missing_schema_version_defaults_to_zero() {
        let json = r#"{
            "run_id": "run-1",
            "session_id": "s",
            "turn": 1,
            "messages": [],
            "total_usage": {"prompt_tokens":0,"completion_tokens":0,"total_tokens":0},
            "tool_calls_count": 0,
            "checkpoint_ms": 0
        }"#;
        let cp: LoopCheckpoint = serde_json::from_str(json).unwrap();
        assert_eq!(cp.schema_version, 0);
        // The payload also omits `verification_reports`; it must default to empty.
        assert!(cp.verification_reports.is_empty());
        assert_eq!(cp.run_id, "run-1");
        assert_eq!(cp.turn, 1);
    }

    /// Fields without a serde default (here `run_id`) are required; a payload that
    /// lacks one must be rejected rather than silently producing an empty value.
    #[test]
    fn missing_required_field_is_rejected() {
        let json = r#"{
            "session_id": "s",
            "turn": 1,
            "messages": [],
            "total_usage": {"prompt_tokens":0,"completion_tokens":0,"total_tokens":0},
            "tool_calls_count": 0,
            "checkpoint_ms": 0
        }"#;
        assert!(serde_json::from_str::<LoopCheckpoint>(json).is_err());
    }
}