use std::collections::{HashMap, VecDeque};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use agent_sdk_foundation::llm::{
ChatOutcome, ChatRequest, ChatResponse, ContentBlock, StopReason, Usage,
};
use anyhow::{Context, Result, anyhow, bail};
use async_trait::async_trait;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use crate::provider::LlmProvider;
use crate::streaming::{StreamBox, StreamDelta, StreamErrorKind};
/// Operating mode of a [`RecordReplayProvider`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecordReplayMode {
/// Requests are forwarded to the wrapped provider and the results are written to the cassette.
Record,
/// Requests are answered from a previously recorded cassette file.
Replay,
}
/// An [`LlmProvider`] that either records the interactions of a wrapped provider to a
/// cassette file or replays interactions from such a file.
pub struct RecordReplayProvider {
/// Wrapped provider: `Some` when constructed for recording, `None` when constructed for replay.
inner: Option<Arc<dyn LlmProvider>>,
/// Mode chosen at construction time; it is never changed afterwards.
mode: RecordReplayMode,
/// Cassette file location: written to while recording, read from when a replay is created.
path: PathBuf,
/// Interactions captured so far in record mode; the whole cassette is re-serialized on each persist.
recorded: Mutex<Cassette>,
/// Replay queues grouped by request key; each lookup consumes the oldest remaining interaction.
replay: Mutex<HashMap<String, VecDeque<CassetteInteraction>>>,
/// Model name reported by `model()`: taken from the inner provider or from the loaded cassette.
model: String,
}
impl RecordReplayProvider {
    /// Creates a provider in [`RecordReplayMode::Record`] that forwards every request to `inner`.
    ///
    /// The in-memory cassette starts empty and is tagged with `inner.model()`. Nothing is
    /// written to `path` until the first interaction is persisted.
    #[must_use]
    pub fn record(inner: Arc<dyn LlmProvider>, path: impl Into<PathBuf>) -> Self {
        let model = inner.model().to_owned();
        Self {
            inner: Some(inner),
            mode: RecordReplayMode::Record,
            path: path.into(),
            recorded: Mutex::new(Cassette {
                model: model.clone(),
                entries: Vec::new(),
            }),
            replay: Mutex::new(HashMap::new()),
            model,
        }
    }

    /// Creates a provider in [`RecordReplayMode::Replay`] backed by the cassette file at `path`.
    ///
    /// The reported model name is the one stored in the cassette.
    ///
    /// # Errors
    ///
    /// Returns an error when the cassette file cannot be read or parsed.
    pub fn replay(path: impl Into<PathBuf>) -> Result<Self> {
        let path = path.into();
        let cassette = load_cassette(&path)?;
        let model = cassette.model.clone();
        let replay = build_replay_map(&cassette);
        Ok(Self {
            inner: None,
            mode: RecordReplayMode::Replay,
            path,
            recorded: Mutex::new(Cassette::default()),
            replay: Mutex::new(replay),
            model,
        })
    }

    /// Returns the mode this provider was constructed with.
    #[must_use]
    pub const fn mode(&self) -> RecordReplayMode {
        self.mode
    }

    /// Persists the outcome of a non-streaming chat call under `key`.
    fn record_chat(&self, key: String, outcome: &ChatOutcome) -> Result<()> {
        self.persist(CassetteEntry {
            key,
            interaction: CassetteInteraction::Chat(CassetteOutcome::from_outcome(outcome)),
        })
    }

    /// Persists the captured deltas of a streaming call under `key`.
    fn record_stream(&self, key: String, deltas: Vec<CassetteDelta>) -> Result<()> {
        self.persist(CassetteEntry {
            key,
            interaction: CassetteInteraction::Stream(deltas),
        })
    }

    /// Appends `entry` to the in-memory cassette and rewrites the whole cassette file.
    ///
    /// # Errors
    ///
    /// Returns an error when the cassette lock is poisoned, serialization fails, or the file
    /// cannot be written.
    fn persist(&self, entry: CassetteEntry) -> Result<()> {
        let mut cassette = self
            .recorded
            .lock()
            .map_err(|_| anyhow!("record cassette lock poisoned"))?;
        cassette.entries.push(entry);
        let json = serde_json::to_string_pretty(&*cassette).context("serialize cassette")?;
        // Keep the guard alive until the write finishes. If the lock were released first, a
        // concurrent recording could serialize a newer snapshot and write it before this
        // older one, leaving the file without the latest entry. No `.await` happens while
        // the guard is held.
        std::fs::write(&self.path, json)
            .with_context(|| format!("write cassette to {}", self.path.display()))
    }

    /// Removes and returns the oldest not-yet-replayed interaction recorded for `key`.
    ///
    /// # Errors
    ///
    /// Returns an error when the replay lock is poisoned or when no interaction remains for
    /// `key`.
    fn take_replay(&self, key: &str) -> Result<CassetteInteraction> {
        let mut map = self
            .replay
            .lock()
            .map_err(|_| anyhow!("replay cassette lock poisoned"))?;
        map.get_mut(key)
            .and_then(VecDeque::pop_front)
            .with_context(|| format!("no recorded interaction for request key '{key}'"))
    }
}
#[async_trait]
impl LlmProvider for RecordReplayProvider {
    /// Answers a chat request from the cassette (replay) or from the inner provider (record).
    ///
    /// In record mode the live outcome is persisted before it is returned, so a persistence
    /// failure fails the call. In replay mode the next interaction stored under the request's
    /// key is consumed; it must have been recorded as a chat.
    async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
        let key = entry_key("chat", &request);
        match self.mode {
            RecordReplayMode::Replay => {
                let recorded = self.take_replay(&key)?;
                match recorded {
                    CassetteInteraction::Stream(_) => {
                        bail!("recorded interaction for '{key}' is a stream, not a chat")
                    }
                    CassetteInteraction::Chat(outcome) => Ok(outcome.into_outcome()),
                }
            }
            RecordReplayMode::Record => {
                let provider = self
                    .inner
                    .as_ref()
                    .context("record mode requires an inner provider")?;
                let live = provider.chat(request).await?;
                self.record_chat(key, &live)?;
                Ok(live)
            }
        }
    }

    /// Streams deltas from the cassette (replay) or from the inner provider (record).
    ///
    /// In record mode every delta is captured before it is forwarded. The capture is persisted
    /// only after the inner stream ends without yielding an `Err`; a persistence failure is
    /// logged rather than surfaced to the consumer. In replay mode the cassette lookup happens
    /// when the returned stream is first polled.
    fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
        let key = entry_key("stream", &request);
        match self.mode {
            RecordReplayMode::Replay => Box::pin(async_stream::stream! {
                match self.take_replay(&key) {
                    Err(error) => yield Err(error),
                    Ok(CassetteInteraction::Chat(_)) => {
                        yield Err(anyhow!(
                            "recorded interaction for '{key}' is a chat, not a stream"
                        ));
                    }
                    Ok(CassetteInteraction::Stream(recorded)) => {
                        for stored in recorded {
                            yield Ok(stored.into_delta());
                        }
                    }
                }
            }),
            RecordReplayMode::Record => {
                let upstream = self.inner.clone();
                Box::pin(async_stream::stream! {
                    let Some(upstream) = upstream else {
                        yield Err(anyhow!("record mode requires an inner provider"));
                        return;
                    };
                    let mut live = upstream.chat_stream(request);
                    let mut captured: Vec<CassetteDelta> = Vec::new();
                    while let Some(item) = live.next().await {
                        match item {
                            Err(error) => {
                                // A failed stream is forwarded but never written to the cassette.
                                yield Err(error);
                                return;
                            }
                            Ok(delta) => {
                                captured.push(CassetteDelta::from_delta(&delta));
                                yield Ok(delta);
                            }
                        }
                    }
                    if let Err(error) = self.record_stream(key, captured) {
                        log::warn!("record/replay: failed to persist stream cassette: {error}");
                    }
                })
            }
        }
    }

    /// Delegates model discovery to the inner provider; fails when there is none.
    async fn list_models(&self) -> Result<Vec<crate::provider::ModelInfo>> {
        let Some(provider) = self.inner.as_ref() else {
            bail!("list_models is not supported in replay mode (no inner provider)");
        };
        provider.list_models().await
    }

    /// Returns the model name captured at construction time.
    fn model(&self) -> &str {
        self.model.as_str()
    }

    /// Returns the fixed identifier of this wrapper.
    fn provider(&self) -> &'static str {
        "record-replay"
    }
}
/// Serialized cassette document: the recorded model name plus every persisted interaction.
#[derive(Default, Serialize, Deserialize)]
struct Cassette {
/// Model name of the recorded provider; falls back to an empty string when absent from the file.
#[serde(default)]
model: String,
/// Recorded interactions in the order they were persisted; empty when absent from the file.
#[serde(default)]
entries: Vec<CassetteEntry>,
}
/// One recorded interaction together with the request key it was recorded under.
#[derive(Serialize, Deserialize)]
struct CassetteEntry {
/// Request key, built from the method name and the request fingerprint.
key: String,
/// The recorded chat outcome or stream capture.
interaction: CassetteInteraction,
}
/// Payload of a cassette entry: either a single chat outcome or the deltas of one stream.
#[derive(Clone, Serialize, Deserialize)]
enum CassetteInteraction {
/// Outcome of a non-streaming chat call.
Chat(CassetteOutcome),
/// Deltas of a streaming call, in the order they were captured.
Stream(Vec<CassetteDelta>),
}
/// Serializable counterpart of [`ChatOutcome`] as stored in the cassette.
#[derive(Clone, Serialize, Deserialize)]
enum CassetteOutcome {
/// A successful response.
Success(CassetteResponse),
/// Rate-limited outcome; the optional retry-after delay is stored in milliseconds.
RateLimited(Option<u64>),
/// Invalid-request outcome with its message.
InvalidRequest(String),
/// Server-error outcome with its message.
ServerError(String),
}
impl CassetteOutcome {
    /// Captures a live [`ChatOutcome`] in its serializable cassette form.
    ///
    /// Retry-after delays are stored as milliseconds. An outcome variant not handled
    /// explicitly is recorded as a server error with a fixed placeholder message.
    fn from_outcome(outcome: &ChatOutcome) -> Self {
        match outcome {
            ChatOutcome::ServerError(message) => Self::ServerError(message.clone()),
            ChatOutcome::InvalidRequest(message) => Self::InvalidRequest(message.clone()),
            ChatOutcome::RateLimited(wait) => {
                let millis = wait.map(millis_from_duration);
                Self::RateLimited(millis)
            }
            ChatOutcome::Success(live) => {
                let stored = CassetteResponse::from_response(live);
                Self::Success(stored)
            }
            _ => Self::ServerError("unrecognized provider outcome".to_owned()),
        }
    }

    /// Converts a stored outcome back into the [`ChatOutcome`] handed to callers.
    fn into_outcome(self) -> ChatOutcome {
        match self {
            Self::ServerError(message) => ChatOutcome::ServerError(message),
            Self::InvalidRequest(message) => ChatOutcome::InvalidRequest(message),
            Self::RateLimited(millis) => {
                let wait = millis.map(Duration::from_millis);
                ChatOutcome::RateLimited(wait)
            }
            Self::Success(stored) => ChatOutcome::Success(stored.into_response()),
        }
    }
}
/// Serializable copy of a [`ChatResponse`]; each field is copied from the field of the same name.
#[derive(Clone, Serialize, Deserialize)]
struct CassetteResponse {
/// Copied from `ChatResponse::id`.
id: String,
/// Copied from `ChatResponse::content`.
content: Vec<ContentBlock>,
/// Copied from `ChatResponse::model`.
model: String,
/// Copied from `ChatResponse::stop_reason`.
stop_reason: Option<StopReason>,
/// Copied from `ChatResponse::usage`.
usage: Usage,
}
impl CassetteResponse {
    /// Clones every field of a live response into its cassette representation.
    fn from_response(response: &ChatResponse) -> Self {
        let ChatResponse {
            id,
            content,
            model,
            stop_reason,
            usage,
        } = response;
        Self {
            id: id.clone(),
            content: content.clone(),
            model: model.clone(),
            stop_reason: *stop_reason,
            usage: usage.clone(),
        }
    }

    /// Rebuilds the [`ChatResponse`] from the stored fields, consuming `self`.
    fn into_response(self) -> ChatResponse {
        let Self {
            id,
            content,
            model,
            stop_reason,
            usage,
        } = self;
        ChatResponse {
            id,
            content,
            model,
            stop_reason,
            usage,
        }
    }
}
/// Serializable mirror of [`StreamDelta`]: every variant and field corresponds to the
/// `StreamDelta` variant and field of the same name.
#[derive(Clone, Serialize, Deserialize)]
enum CassetteDelta {
/// Mirror of `StreamDelta::TextDelta`.
TextDelta {
delta: String,
block_index: usize,
},
/// Mirror of `StreamDelta::ThinkingDelta`.
ThinkingDelta {
delta: String,
block_index: usize,
},
/// Mirror of `StreamDelta::ToolUseStart`.
ToolUseStart {
id: String,
name: String,
block_index: usize,
thought_signature: Option<String>,
},
/// Mirror of `StreamDelta::ToolInputDelta`.
ToolInputDelta {
id: String,
delta: String,
block_index: usize,
},
/// Mirror of `StreamDelta::SignatureDelta`.
SignatureDelta {
delta: String,
block_index: usize,
},
/// Mirror of `StreamDelta::RedactedThinking`.
RedactedThinking {
data: String,
block_index: usize,
},
/// Mirror of `StreamDelta::Usage`.
Usage(Usage),
/// Mirror of `StreamDelta::Done`.
Done {
stop_reason: Option<StopReason>,
},
/// Mirror of `StreamDelta::Error`; the error kind is stored as a [`CassetteErrorKind`].
Error {
message: String,
kind: CassetteErrorKind,
},
}
impl CassetteDelta {
fn from_delta(delta: &StreamDelta) -> Self {
match delta {
StreamDelta::TextDelta { delta, block_index } => Self::TextDelta {
delta: delta.clone(),
block_index: *block_index,
},
StreamDelta::ThinkingDelta { delta, block_index } => Self::ThinkingDelta {
delta: delta.clone(),
block_index: *block_index,
},
StreamDelta::ToolUseStart {
id,
name,
block_index,
thought_signature,
} => Self::ToolUseStart {
id: id.clone(),
name: name.clone(),
block_index: *block_index,
thought_signature: thought_signature.clone(),
},
StreamDelta::ToolInputDelta {
id,
delta,
block_index,
} => Self::ToolInputDelta {
id: id.clone(),
delta: delta.clone(),
block_index: *block_index,
},
StreamDelta::SignatureDelta { delta, block_index } => Self::SignatureDelta {
delta: delta.clone(),
block_index: *block_index,
},
StreamDelta::RedactedThinking { data, block_index } => Self::RedactedThinking {
data: data.clone(),
block_index: *block_index,
},
StreamDelta::Usage(usage) => Self::Usage(usage.clone()),
StreamDelta::Done { stop_reason } => Self::Done {
stop_reason: *stop_reason,
},
StreamDelta::Error { message, kind } => Self::Error {
message: message.clone(),
kind: CassetteErrorKind::from_kind(*kind),
},
}
}
fn into_delta(self) -> StreamDelta {
match self {
Self::TextDelta { delta, block_index } => StreamDelta::TextDelta { delta, block_index },
Self::ThinkingDelta { delta, block_index } => {
StreamDelta::ThinkingDelta { delta, block_index }
}
Self::ToolUseStart {
id,
name,
block_index,
thought_signature,
} => StreamDelta::ToolUseStart {
id,
name,
block_index,
thought_signature,
},
Self::ToolInputDelta {
id,
delta,
block_index,
} => StreamDelta::ToolInputDelta {
id,
delta,
block_index,
},
Self::SignatureDelta { delta, block_index } => {
StreamDelta::SignatureDelta { delta, block_index }
}
Self::RedactedThinking { data, block_index } => {
StreamDelta::RedactedThinking { data, block_index }
}
Self::Usage(usage) => StreamDelta::Usage(usage),
Self::Done { stop_reason } => StreamDelta::Done { stop_reason },
Self::Error { message, kind } => StreamDelta::Error {
message,
kind: kind.into_kind(),
},
}
}
}
/// Serializable counterpart of [`StreamErrorKind`] as stored in the cassette.
#[derive(Clone, Copy, Serialize, Deserialize)]
enum CassetteErrorKind {
/// Stored form of `StreamErrorKind::RateLimited`.
RateLimited,
/// Stored form of `StreamErrorKind::ServerError`.
ServerError,
/// Stored form of `StreamErrorKind::InvalidRequest`.
InvalidRequest,
/// Stored form of every other stream error kind; replays as `StreamErrorKind::Unknown`.
Unknown,
}
impl CassetteErrorKind {
    /// Maps a live stream error kind to its stored form.
    ///
    /// Kinds without a dedicated stored variant collapse into [`Self::Unknown`].
    const fn from_kind(kind: StreamErrorKind) -> Self {
        match kind {
            StreamErrorKind::InvalidRequest => Self::InvalidRequest,
            StreamErrorKind::ServerError => Self::ServerError,
            StreamErrorKind::RateLimited => Self::RateLimited,
            _ => Self::Unknown,
        }
    }

    /// Maps a stored error kind back to the live [`StreamErrorKind`].
    const fn into_kind(self) -> StreamErrorKind {
        match self {
            Self::Unknown => StreamErrorKind::Unknown,
            Self::InvalidRequest => StreamErrorKind::InvalidRequest,
            Self::ServerError => StreamErrorKind::ServerError,
            Self::RateLimited => StreamErrorKind::RateLimited,
        }
    }
}
/// Converts `duration` to whole milliseconds, saturating at `u64::MAX` when the value does
/// not fit in a `u64`.
fn millis_from_duration(duration: Duration) -> u64 {
    let wide_millis: u128 = duration.as_millis();
    match u64::try_from(wide_millis) {
        Ok(millis) => millis,
        Err(_) => u64::MAX,
    }
}
/// Derives a stable 16-hex-digit fingerprint for `request`.
///
/// Only `system`, `messages`, `tools`, `max_tokens` and `response_format` take part: they
/// are serialized to JSON and the bytes are hashed with 64-bit FNV-1a. If serialization
/// fails, the hash of an empty byte sequence is returned.
fn fingerprint(request: &ChatRequest) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let canonical = serde_json::json!({
        "system": request.system,
        "messages": request.messages,
        "tools": request.tools,
        "max_tokens": request.max_tokens,
        "response_format": request.response_format,
    });
    let encoded = serde_json::to_vec(&canonical).unwrap_or_default();
    let digest = encoded.iter().fold(FNV_OFFSET_BASIS, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });
    format!("{digest:016x}")
}
/// Builds the cassette key for `request`: the method name, a colon, then the request
/// fingerprint.
fn entry_key(method: &str, request: &ChatRequest) -> String {
    let digest = fingerprint(request);
    format!("{method}:{digest}")
}
fn load_cassette(path: &Path) -> Result<Cassette> {
let data = std::fs::read_to_string(path)
.with_context(|| format!("read cassette {}", path.display()))?;
serde_json::from_str(&data).with_context(|| format!("parse cassette {}", path.display()))
}
/// Groups the cassette's interactions by request key.
///
/// Interactions sharing a key keep their recorded order, so repeated identical requests
/// replay in the sequence they were captured.
fn build_replay_map(cassette: &Cassette) -> HashMap<String, VecDeque<CassetteInteraction>> {
    cassette.entries.iter().fold(
        HashMap::<String, VecDeque<CassetteInteraction>>::new(),
        |mut grouped, recorded| {
            let queue = grouped.entry(recorded.key.clone()).or_default();
            queue.push_back(recorded.interaction.clone());
            grouped
        },
    )
}
// Round-trip tests: record against a canned inner provider, then replay from the written file.
#[cfg(test)]
mod tests {
use super::*;
use agent_sdk_foundation::llm::Message;
// Canned provider: always returns the configured chat outcome and stream deltas.
struct InnerProvider {
model: &'static str,
chat_outcome: ChatOutcome,
deltas: Vec<StreamDelta>,
}
#[async_trait]
impl LlmProvider for InnerProvider {
// Ignores the request and returns a clone of the configured outcome.
async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
Ok(self.chat_outcome.clone())
}
// Reports a single fixed model so delegation can be observed by the caller.
async fn list_models(&self) -> Result<Vec<crate::provider::ModelInfo>> {
Ok(vec![crate::provider::ModelInfo {
id: "inner-discovered-model".to_owned(),
display_name: None,
context_window: None,
max_output_tokens: None,
}])
}
// Ignores the request and yields the configured deltas in order.
fn chat_stream(&self, _request: ChatRequest) -> StreamBox<'_> {
let deltas = self.deltas.clone();
Box::pin(async_stream::stream! {
for delta in deltas {
yield Ok(delta);
}
})
}
fn model(&self) -> &str {
self.model
}
fn provider(&self) -> &'static str {
"inner"
}
}
// Builds a successful outcome with one text block and fixed usage numbers.
fn success_outcome(text: &str) -> ChatOutcome {
ChatOutcome::Success(ChatResponse {
id: "resp-1".to_owned(),
content: vec![ContentBlock::Text {
text: text.to_owned(),
}],
model: "inner-model".to_owned(),
stop_reason: Some(StopReason::EndTurn),
usage: Usage {
input_tokens: 7,
output_tokens: 3,
cached_input_tokens: 0,
cache_creation_input_tokens: 0,
},
})
}
// Unique cassette path in the system temp directory, so tests do not share files.
fn temp_cassette_path() -> PathBuf {
std::env::temp_dir().join(format!("agent-sdk-cassette-{}.json", uuid::Uuid::new_v4()))
}
// The request used for both recording and replaying; identical requests yield the same key.
fn request() -> ChatRequest {
ChatRequest::new("sys", vec![Message::user("hello")])
}
// A recorded chat outcome is replayed for the same request; an unrecorded request errors.
#[tokio::test]
async fn chat_round_trips_through_cassette() -> Result<()> {
let path = temp_cassette_path();
let inner = Arc::new(InnerProvider {
model: "inner-model",
chat_outcome: success_outcome("recorded answer"),
deltas: Vec::new(),
});
let recorder = RecordReplayProvider::record(inner, &path);
let live_outcome = recorder.chat(request()).await?;
assert!(
matches!(&live_outcome, ChatOutcome::Success(r) if r.first_text() == Some("recorded answer"))
);
// Replay from the file the recorder just wrote.
let player = RecordReplayProvider::replay(&path)?;
assert_eq!(player.mode(), RecordReplayMode::Replay);
assert_eq!(player.model(), "inner-model");
let replayed = player.chat(request()).await?;
match replayed {
ChatOutcome::Success(response) => {
assert_eq!(response.first_text(), Some("recorded answer"));
assert_eq!(response.usage.input_tokens, 7);
assert_eq!(response.stop_reason, Some(StopReason::EndTurn));
}
other => panic!("expected Success, got {other:?}"),
}
// A request that was never recorded has no cassette entry.
let missing = player
.chat(ChatRequest::new("other", vec![Message::user("nope")]))
.await;
assert!(missing.is_err());
// Best-effort cleanup; a removal failure is ignored.
let _ = std::fs::remove_file(&path);
Ok(())
}
// `list_models` reaches the inner provider while recording and fails in replay mode.
#[tokio::test]
async fn list_models_delegates_in_record_and_errors_in_replay() -> Result<()> {
let path = temp_cassette_path();
let inner = Arc::new(InnerProvider {
model: "inner-model",
chat_outcome: success_outcome("x"),
deltas: Vec::new(),
});
let recorder = RecordReplayProvider::record(inner, &path);
let models = recorder.list_models().await?;
assert_eq!(models.len(), 1);
assert_eq!(models[0].id, "inner-discovered-model");
// Record one chat so the cassette file exists before a replay provider is created.
recorder.chat(request()).await?;
let player = RecordReplayProvider::replay(&path)?;
assert!(player.list_models().await.is_err());
let _ = std::fs::remove_file(&path);
Ok(())
}
// Stream deltas, including the final `Done`, are recorded and replayed in order.
#[tokio::test]
async fn stream_round_trips_through_cassette() -> Result<()> {
let path = temp_cassette_path();
let inner = Arc::new(InnerProvider {
model: "inner-model",
chat_outcome: success_outcome("unused"),
deltas: vec![
StreamDelta::TextDelta {
delta: "hel".to_owned(),
block_index: 0,
},
StreamDelta::TextDelta {
delta: "lo".to_owned(),
block_index: 0,
},
StreamDelta::Done {
stop_reason: Some(StopReason::EndTurn),
},
],
});
let recorder = RecordReplayProvider::record(inner, &path);
let mut text = String::new();
let mut stream = recorder.chat_stream(request());
// Drain the stream fully: the cassette is only written after the inner stream ends.
while let Some(item) = stream.next().await {
if let StreamDelta::TextDelta { delta, .. } = item? {
text.push_str(&delta);
}
}
drop(stream);
assert_eq!(text, "hello");
let player = RecordReplayProvider::replay(&path)?;
let mut replayed = String::new();
let mut stop_seen = false;
let mut stream = player.chat_stream(request());
while let Some(item) = stream.next().await {
match item? {
StreamDelta::TextDelta { delta, .. } => replayed.push_str(&delta),
StreamDelta::Done { .. } => stop_seen = true,
_ => {}
}
}
assert_eq!(replayed, "hello");
assert!(stop_seen, "Done delta should replay");
let _ = std::fs::remove_file(&path);
Ok(())
}
}