use std::collections::BTreeMap;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use serde_json::Value;
use super::cassette::{CassetteEntry, CassetteMethod, LlmCassette, input_hash, vision_hash};
use crate::error::LlmResult;
use crate::llm_trait::Llm;
use crate::types::{GenerationOptions, GenerationResponse, Message, MessageRole};
/// Upper bound, in `char`s, on the user-message excerpt that `user_input_preview`
/// stores in a cassette entry's `user_input_preview` field.
const PREVIEW_MAX_CHARS: usize = 120;
/// [`Llm`] wrapper that forwards every call to an inner LLM and remembers each
/// successful response as a [`CassetteEntry`], keyed by a hash of the request.
///
/// The accumulated entries are saved to `path` by [`RecordingLlm::flush`], which is
/// also attempted when the value is dropped.
pub struct RecordingLlm {
// The real LLM that all trait calls are delegated to.
inner: Arc<dyn Llm>,
// Recorded entries keyed by the string returned from `input_hash` / `vision_hash`.
// Behind a mutex because the trait methods only receive `&self`.
entries: Mutex<BTreeMap<String, CassetteEntry>>,
// Cassette location: read once in `new`, written by `flush`.
path: PathBuf,
}
impl RecordingLlm {
pub fn new(inner: Arc<dyn Llm>, path: impl Into<PathBuf>) -> Self {
let path = path.into();
let entries = match LlmCassette::load(&path) {
Ok(cassette) => cassette.entries,
Err(_) => BTreeMap::new(),
};
Self {
inner,
entries: Mutex::new(entries),
path,
}
}
#[allow(clippy::unwrap_used, reason = "lock poison is unrecoverable")]
fn record(&self, hash: String, entry: CassetteEntry) {
self.entries.lock().unwrap().insert(hash, entry);
}
#[allow(clippy::unwrap_used, reason = "lock poison is unrecoverable")]
pub fn flush(&self) -> LlmResult<()> {
let entries = self.entries.lock().unwrap().clone();
let cassette = LlmCassette {
version: 1,
model: self.inner.model().to_string(),
entries,
};
cassette.save(&self.path)
}
}
/// Returns the leading `PREVIEW_MAX_CHARS` characters of the most recent user
/// message in `messages`, or an empty string when there is no user message.
fn user_input_preview(messages: &[Message]) -> String {
    // Search from the back so the latest user turn wins.
    match messages.iter().rfind(|msg| msg.role == MessageRole::User) {
        Some(msg) => msg
            .content
            .as_str()
            .chars()
            .take(PREVIEW_MAX_CHARS)
            .collect(),
        None => String::new(),
    }
}
fn schema_name(schema: &Value) -> Option<String> {
schema
.get("title")
.and_then(Value::as_str)
.map(str::to_string)
}
#[async_trait]
impl Llm for RecordingLlm {
    /// Delegates to the inner LLM and, on success, records the response text under
    /// the hash of `messages` (no schema). A failed inner call is returned as-is and
    /// nothing is recorded.
    async fn generate(
        &self,
        messages: Vec<Message>,
        options: Option<GenerationOptions>,
    ) -> LlmResult<GenerationResponse> {
        // Derive key and preview first: `messages` is moved into the inner call.
        let key = input_hash(&messages, None);
        let snippet = user_input_preview(&messages);

        let reply = self.inner.generate(messages, options).await?;

        let entry = CassetteEntry {
            method: CassetteMethod::Generate,
            user_input_preview: snippet,
            schema_name: None,
            response: Value::String(reply.content.clone()),
        };
        self.record(key, entry);
        Ok(reply)
    }

    /// Delegates to the inner LLM and, on success, records the returned JSON value
    /// under the hash of `messages` combined with `json_schema`. The schema's
    /// `"title"`, when present, is stored as the entry's schema name.
    async fn create_structured_output_with_messages_raw(
        &self,
        messages: Vec<Message>,
        json_schema: &Value,
        options: Option<GenerationOptions>,
    ) -> LlmResult<Value> {
        let key = input_hash(&messages, Some(json_schema));
        let snippet = user_input_preview(&messages);
        let title = schema_name(json_schema);

        let value = self
            .inner
            .create_structured_output_with_messages_raw(messages, json_schema, options)
            .await?;

        let entry = CassetteEntry {
            method: CassetteMethod::StructuredOutput,
            user_input_preview: snippet,
            schema_name: title,
            response: value.clone(),
        };
        self.record(key, entry);
        Ok(value)
    }

    /// Delegates to the inner LLM and, on success, records the transcription under
    /// the hash of the image bytes and MIME type. The preview is the bracketed MIME
    /// type, since there is no user text to excerpt.
    async fn transcribe_image(
        &self,
        image_bytes: &[u8],
        mime_type: &str,
        options: Option<GenerationOptions>,
    ) -> LlmResult<String> {
        let key = vision_hash(image_bytes, mime_type);

        let text = self
            .inner
            .transcribe_image(image_bytes, mime_type, options)
            .await?;

        let entry = CassetteEntry {
            method: CassetteMethod::TranscribeImage,
            user_input_preview: format!("[{mime_type}]"),
            schema_name: None,
            response: Value::String(text.clone()),
        };
        self.record(key, entry);
        Ok(text)
    }

    // The remaining methods are pure pass-throughs: the recorder reports exactly
    // what the wrapped LLM reports.

    /// Model name of the wrapped LLM.
    fn model(&self) -> &str {
        self.inner.model()
    }

    /// Whether the wrapped LLM supports streaming.
    fn supports_streaming(&self) -> bool {
        self.inner.supports_streaming()
    }

    /// Whether the wrapped LLM supports function calling.
    fn supports_function_calling(&self) -> bool {
        self.inner.supports_function_calling()
    }

    /// Maximum context length of the wrapped LLM.
    fn max_context_length(&self) -> u32 {
        self.inner.max_context_length()
    }

    /// Whether the wrapped LLM supports vision input.
    fn supports_vision(&self) -> bool {
        self.inner.supports_vision()
    }
}
/// Best-effort persistence: attempt a final flush when the recorder goes away.
impl Drop for RecordingLlm {
    fn drop(&mut self) {
        // `drop` cannot return an error, so a failed flush is only logged.
        let Err(err) = self.flush() else {
            return;
        };
        tracing::warn!(
            error = %err,
            path = %self.path.display(),
            "RecordingLlm: failed to flush cassette on drop"
        );
    }
}
// Unit tests for `RecordingLlm`, driven by an in-memory stub LLM and a cassette
// file inside a temporary directory.
#[cfg(test)]
mod tests {
#![allow(
clippy::unwrap_used,
clippy::expect_used,
reason = "test code — panics are acceptable"
)]
use super::*;
use serde_json::json;
use std::collections::VecDeque;
use crate::error::LlmError;
/// Stub LLM that hands out pre-queued response strings in FIFO order.
struct StubLlm {
// Remaining canned responses; behind a mutex because the trait methods take `&self`.
responses: Mutex<VecDeque<String>>,
}
impl StubLlm {
/// Builds a stub that will return `responses` front to back.
fn new(responses: Vec<String>) -> Self {
Self {
responses: Mutex::new(VecDeque::from(responses)),
}
}
/// Takes the next queued response, or an empty-graph JSON document once the
/// queue has run dry.
fn pop(&self) -> String {
self.responses
.lock()
.unwrap()
.pop_front()
.unwrap_or_else(|| r#"{"nodes":[],"relationships":[]}"#.to_string())
}
}
#[async_trait]
impl Llm for StubLlm {
// Ignores its inputs and returns the next queued string verbatim as the content.
async fn generate(
&self,
_messages: Vec<Message>,
_options: Option<GenerationOptions>,
) -> LlmResult<GenerationResponse> {
Ok(GenerationResponse {
content: self.pop(),
model: "stub-llm".to_string(),
usage: None,
finish_reason: Some("stop".to_string()),
})
}
// Ignores its inputs and parses the next queued string as JSON; a parse failure
// is reported as `LlmError::DeserializationError`.
async fn create_structured_output_with_messages_raw(
&self,
_messages: Vec<Message>,
_json_schema: &Value,
_options: Option<GenerationOptions>,
) -> LlmResult<Value> {
let raw = self.pop();
serde_json::from_str(&raw)
.map_err(|e| LlmError::DeserializationError(format!("StubLlm: invalid JSON: {e}")))
}
fn model(&self) -> &str {
"stub-llm"
}
}
/// Shared request fixture: one system message followed by one user message.
fn graph_msgs() -> Vec<Message> {
vec![
Message::system("Extract a knowledge graph."),
Message::user("Alice met Bob."),
]
}
// A structured-output call is passed through unchanged and, after an explicit
// flush, appears in the saved cassette under the expected hash with the schema
// title, the response value and the user-message preview.
#[tokio::test]
async fn records_structured_output_entry() {
let dir = tempfile::tempdir().expect("create tempdir");
let path = dir.path().join("cassette.json");
let graph = json!({"nodes": [{"name": "Alice"}], "relationships": []});
let mock: Arc<dyn Llm> = Arc::new(StubLlm::new(vec![graph.to_string()]));
let recorder = RecordingLlm::new(mock, &path);
let schema = json!({"title": "KnowledgeGraph", "type": "object"});
let value = recorder
.create_structured_output_with_messages_raw(graph_msgs(), &schema, None)
.await
.expect("structured output");
assert_eq!(value, graph);
recorder.flush().expect("flush");
let cassette = LlmCassette::load(&path).expect("load cassette");
assert_eq!(cassette.entries.len(), 1);
// The entry must be stored under the same key the recorder computes.
let key = input_hash(&graph_msgs(), Some(&schema));
let entry = cassette.entries.get(&key).expect("entry present");
assert_eq!(entry.method, CassetteMethod::StructuredOutput);
assert_eq!(entry.schema_name.as_deref(), Some("KnowledgeGraph"));
assert_eq!(entry.response, graph);
assert_eq!(entry.user_input_preview, "Alice met Bob.");
}
// Two calls with the same messages and schema hash to the same key, so the
// cassette ends up with a single entry.
#[tokio::test]
async fn identical_inputs_dedup_to_one_entry() {
let dir = tempfile::tempdir().expect("create tempdir");
let path = dir.path().join("cassette.json");
let graph = json!({"nodes": [], "relationships": []});
let mock: Arc<dyn Llm> = Arc::new(StubLlm::new(vec![graph.to_string(), graph.to_string()]));
let recorder = RecordingLlm::new(mock, &path);
let schema = json!({"title": "KnowledgeGraph", "type": "object"});
recorder
.create_structured_output_with_messages_raw(graph_msgs(), &schema, None)
.await
.expect("first call");
recorder
.create_structured_output_with_messages_raw(graph_msgs(), &schema, None)
.await
.expect("second call");
recorder.flush().expect("flush");
let cassette = LlmCassette::load(&path).expect("load cassette");
assert_eq!(cassette.entries.len(), 1, "identical inputs must dedup");
}
// Without any explicit `flush`, dropping the recorder must still write the cassette.
#[tokio::test]
async fn drop_flushes_without_explicit_flush() {
let dir = tempfile::tempdir().expect("create tempdir");
let path = dir.path().join("cassette.json");
// Inner scope: the recorder is dropped at the closing brace, before the assertions.
{
let mock: Arc<dyn Llm> = Arc::new(StubLlm::new(vec!["\"hello\"".to_string()]));
let recorder = RecordingLlm::new(mock, &path);
recorder
.generate(graph_msgs(), None)
.await
.expect("generate");
}
assert!(path.exists(), "cassette should exist after drop");
let cassette = LlmCassette::load(&path).expect("cassette parses after drop");
assert_eq!(cassette.entries.len(), 1);
let entry = cassette
.entries
.values()
.next()
.expect("one recorded entry");
assert_eq!(entry.method, CassetteMethod::Generate);
// `generate` records the content string verbatim, quotes included.
assert_eq!(entry.response, Value::String("\"hello\"".to_string()));
}
// A second recorder opened on an existing cassette keeps the earlier entries and
// adds its own.
#[tokio::test]
async fn re_recording_merges_existing_entries() {
let dir = tempfile::tempdir().expect("create tempdir");
let path = dir.path().join("cassette.json");
let schema = json!({"title": "KnowledgeGraph", "type": "object"});
// First session: record one entry and write it out.
{
let graph = json!({"nodes": [{"name": "Alice"}], "relationships": []});
let mock: Arc<dyn Llm> = Arc::new(StubLlm::new(vec![graph.to_string()]));
let recorder = RecordingLlm::new(mock, &path);
recorder
.create_structured_output_with_messages_raw(graph_msgs(), &schema, None)
.await
.expect("first session");
recorder.flush().expect("flush first");
}
// Second session: a different user message, hence a different key, on the same path.
{
let other_msgs = vec![
Message::system("Extract a knowledge graph."),
Message::user("Carol knows Dave."),
];
let graph = json!({"nodes": [{"name": "Carol"}], "relationships": []});
let mock: Arc<dyn Llm> = Arc::new(StubLlm::new(vec![graph.to_string()]));
let recorder = RecordingLlm::new(mock, &path);
recorder
.create_structured_output_with_messages_raw(other_msgs, &schema, None)
.await
.expect("second session");
recorder.flush().expect("flush second");
}
let cassette = LlmCassette::load(&path).expect("load merged cassette");
assert_eq!(cassette.entries.len(), 2, "re-recording must merge");
}
}