use std::collections::HashMap;
use std::sync::Mutex;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use super::executor::InFlightTool;
use super::message::ToolCall;
/// Snapshot of an in-progress streaming run, sufficient to resume or
/// inspect the stream after an interruption.
///
/// Persisted and retrieved through a [`StreamCheckpointStore`], keyed by
/// `run_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamCheckpoint {
// Unique id of the run; used as the storage key (see `put`/`get`/`delete`).
pub run_id: String,
// Id of the conversation thread this run belongs to.
pub thread_id: String,
// Model identifier of the upstream provider serving this stream.
pub upstream_model: String,
// Text accumulated from the stream so far.
pub partial_text: String,
// Tool calls that have already finished for this run.
pub completed_tool_calls: Vec<ToolCall>,
// Tool call currently executing, if any.
pub in_flight_tool: Option<InFlightTool>,
// Last-update timestamp in milliseconds — presumably Unix epoch ms; the
// unit's origin is not established here, confirm against the writer.
pub updated_at_ms: u64,
}
/// Error returned by [`StreamCheckpointStore`] operations.
///
/// Wraps a backend-specific message; `Display` renders it as
/// "stream checkpoint backend error: {message}".
#[derive(Debug, Error)]
#[error("stream checkpoint backend error: {0}")]
pub struct StreamCheckpointError(pub String);
/// Async persistence backend for [`StreamCheckpoint`]s, keyed by run id.
///
/// Implementations must be shareable across tasks (`Send + Sync`).
#[async_trait]
pub trait StreamCheckpointStore: Send + Sync {
/// Inserts or replaces the checkpoint stored under `checkpoint.run_id`.
///
/// # Errors
/// Returns a [`StreamCheckpointError`] if the backend fails.
async fn put(&self, checkpoint: StreamCheckpoint) -> Result<(), StreamCheckpointError>;
/// Fetches the checkpoint for `run_id`, or `Ok(None)` if absent.
///
/// # Errors
/// Returns a [`StreamCheckpointError`] if the backend fails.
async fn get(&self, run_id: &str) -> Result<Option<StreamCheckpoint>, StreamCheckpointError>;
/// Removes the checkpoint for `run_id`. Deleting a missing entry is not
/// an error (see the in-memory implementation and its tests).
///
/// # Errors
/// Returns a [`StreamCheckpointError`] if the backend fails.
async fn delete(&self, run_id: &str) -> Result<(), StreamCheckpointError>;
}
/// In-memory [`StreamCheckpointStore`] backed by a mutex-guarded map.
///
/// Suitable for tests and single-process use; contents are lost on drop.
/// The hand-written `Default` impl was byte-for-byte what the derive
/// produces (`Mutex<HashMap<..>>` is `Default`), so it is derived instead
/// (clippy `derivable_impls`).
#[derive(Default)]
pub struct InMemoryStreamCheckpointStore {
    // run_id -> latest checkpoint for that run.
    data: Mutex<HashMap<String, StreamCheckpoint>>,
}
impl InMemoryStreamCheckpointStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Number of checkpoints currently held.
    ///
    /// Panics if the internal mutex is poisoned (a prior panic while
    /// holding the lock).
    pub fn len(&self) -> usize {
        let map = self.data.lock().unwrap();
        map.len()
    }

    /// Returns `true` when no checkpoints are stored.
    pub fn is_empty(&self) -> bool {
        self.data.lock().unwrap().is_empty()
    }
}
#[async_trait]
impl StreamCheckpointStore for InMemoryStreamCheckpointStore {
    /// Inserts or replaces the entry keyed by the checkpoint's `run_id`.
    async fn put(&self, checkpoint: StreamCheckpoint) -> Result<(), StreamCheckpointError> {
        let key = checkpoint.run_id.clone();
        self.data.lock().unwrap().insert(key, checkpoint);
        Ok(())
    }

    /// Returns a clone of the stored checkpoint, or `None` if absent.
    async fn get(&self, run_id: &str) -> Result<Option<StreamCheckpoint>, StreamCheckpointError> {
        let map = self.data.lock().unwrap();
        Ok(map.get(run_id).cloned())
    }

    /// Removes the entry for `run_id`; removing a missing entry succeeds.
    async fn delete(&self, run_id: &str) -> Result<(), StreamCheckpointError> {
        let mut map = self.data.lock().unwrap();
        map.remove(run_id);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checkpoint fixture with fixed fields and the given run id.
    fn sample(run_id: &str) -> StreamCheckpoint {
        StreamCheckpoint {
            run_id: run_id.to_string(),
            thread_id: String::from("thread-1"),
            upstream_model: String::from("test-model"),
            partial_text: String::from("hello"),
            completed_tool_calls: Vec::new(),
            in_flight_tool: None,
            updated_at_ms: 1_000,
        }
    }

    #[tokio::test]
    async fn in_memory_put_get_delete_roundtrip() {
        let store = InMemoryStreamCheckpointStore::new();
        assert!(store.is_empty());

        for id in ["run-a", "run-b"] {
            store.put(sample(id)).await.unwrap();
        }
        assert_eq!(store.len(), 2);

        let fetched = store
            .get("run-a")
            .await
            .unwrap()
            .expect("run-a was just stored");
        assert_eq!(fetched.run_id, "run-a");
        assert_eq!(fetched.partial_text, "hello");

        store.delete("run-a").await.unwrap();
        assert_eq!(store.len(), 1);
        assert!(store.get("run-a").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn delete_nonexistent_is_not_an_error() {
        let store = InMemoryStreamCheckpointStore::new();
        assert!(store.delete("no-such-run").await.is_ok());
    }

    #[tokio::test]
    async fn put_overwrites_existing_entry() {
        let store = InMemoryStreamCheckpointStore::new();
        store.put(sample("run-a")).await.unwrap();

        // Second put under the same run id must replace the first entry.
        let mut replacement = sample("run-a");
        replacement.partial_text = "updated text".into();
        replacement.updated_at_ms = 2_000;
        store.put(replacement).await.unwrap();

        let fetched = store.get("run-a").await.unwrap().unwrap();
        assert_eq!(fetched.partial_text, "updated text");
        assert_eq!(fetched.updated_at_ms, 2_000);
    }

    #[test]
    fn checkpoint_serde_roundtrip() {
        let original = sample("run-1");
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: StreamCheckpoint = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.run_id, "run-1");
        assert_eq!(decoded.partial_text, "hello");
    }
}