use std::collections::BTreeMap;
use std::path::Path;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sha2::{Digest, Sha256};
use crate::error::{LlmError, LlmResult};
use crate::types::Message;
/// Tag stored in [`CassetteEntry::method`] naming the kind of call an entry
/// belongs to.
///
/// Variants are (de)serialized using `snake_case` names, e.g.
/// `structured_output`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CassetteMethod {
/// NOTE(review): presumably a plain text-generation call — confirm against
/// the code that records cassettes.
Generate,
/// NOTE(review): presumably a schema-constrained (structured) call —
/// confirm against the code that records cassettes.
StructuredOutput,
/// NOTE(review): presumably an image-transcription call; the tests in this
/// file pair it with a `vision_hash` key — confirm against the recorder.
TranscribeImage,
}
/// One value stored in [`LlmCassette::entries`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CassetteEntry {
/// Which kind of call this entry belongs to.
pub method: CassetteMethod,
/// Free-form preview text. The tests in this file store either the user
/// prompt or a `[mime/type]` placeholder here; nothing in this file reads
/// it back for lookup.
pub user_input_preview: String,
/// Optional schema name.
/// NOTE(review): presumably only meaningful for structured-output
/// entries — confirm against the recorder.
pub schema_name: Option<String>,
/// The stored response as arbitrary JSON.
pub response: Value,
}
/// A JSON-serializable collection of [`CassetteEntry`] values keyed by string.
///
/// Use [`LlmCassette::load`] and [`LlmCassette::save`] to move it to and from
/// disk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmCassette {
/// Format version number. NOTE(review): `load` in this file does not
/// validate it — confirm whether a caller checks it.
pub version: u32,
/// Model identifier string (the tests in this file use `"gpt-4o-mini"`).
pub model: String,
/// Entries by key. The tests in this file use the hex strings returned by
/// `input_hash` / `vision_hash` as keys. A `BTreeMap` keeps the keys in
/// sorted order.
pub entries: BTreeMap<String, CassetteEntry>,
}
impl LlmCassette {
pub fn load(path: impl AsRef<Path>) -> LlmResult<Self> {
let path = path.as_ref();
let contents = std::fs::read_to_string(path).map_err(|e| {
LlmError::ConfigError(format!("failed to read cassette {}: {e}", path.display()))
})?;
serde_json::from_str(&contents).map_err(|e| {
LlmError::DeserializationError(format!(
"failed to parse cassette {}: {e}",
path.display()
))
})
}
pub fn save(&self, path: impl AsRef<Path>) -> LlmResult<()> {
let path = path.as_ref();
let contents = serde_json::to_string_pretty(self).map_err(|e| {
LlmError::SerializationError(format!("failed to serialize cassette: {e}"))
})?;
std::fs::write(path, contents).map_err(|e| {
LlmError::ConfigError(format!("failed to write cassette {}: {e}", path.display()))
})
}
}
pub fn input_hash(messages: &[Message], schema: Option<&Value>) -> String {
let mut buf = String::new();
for message in messages {
buf.push_str(role_str(&message.role));
buf.push(':');
buf.push_str(&message.content);
buf.push('\n');
}
if let Some(schema) = schema {
buf.push_str(&canonicalize(schema));
}
hex_digest(buf.as_bytes())
}
/// Computes the lookup key for an image request.
///
/// Returns the lowercase hex SHA-256 digest of `mime_type` immediately
/// followed by `image_bytes`.
///
/// NOTE(review): no separator is written between the two parts, so inputs
/// that concatenate to the same byte stream produce the same key. Adding one
/// would change every stored key, so it is left as is.
pub fn vision_hash(image_bytes: &[u8], mime_type: &str) -> String {
    let mut state = Sha256::new();
    // Order matters: MIME type first, then the raw image bytes.
    for part in [mime_type.as_bytes(), image_bytes] {
        state.update(part);
    }
    hex(state.finalize())
}
/// Maps a message role to the fixed lowercase label used in `input_hash`.
fn role_str(role: &crate::types::MessageRole) -> &'static str {
    use crate::types::MessageRole as Role;

    // No catch-all arm: adding a role variant must force a decision here.
    let label: &'static str = match role {
        Role::System => "system",
        Role::User => "user",
        Role::Assistant => "assistant",
    };
    label
}
/// Renders `value` as compact JSON text that does not depend on object key
/// insertion order.
///
/// * Objects: members are emitted sorted by key, as `{"k":v,...}`.
/// * Arrays: elements keep their original order, as `[v,...]`.
/// * Anything else: whatever `serde_json::to_string` produces, falling back
///   to the `Debug` rendering if serialization fails.
fn canonicalize(value: &Value) -> String {
    match value {
        Value::Object(map) => {
            // Keys are unique, so sorting by key alone gives a total order.
            let mut members: Vec<(&String, &Value)> = map.iter().collect();
            members.sort_unstable_by(|a, b| a.0.cmp(b.0));

            let rendered: Vec<String> = members
                .into_iter()
                .map(|(key, val)| {
                    let key_str =
                        serde_json::to_string(key).unwrap_or_else(|_| format!("{key:?}"));
                    format!("{key_str}:{}", canonicalize(val))
                })
                .collect();
            format!("{{{}}}", rendered.join(","))
        }
        Value::Array(items) => {
            let rendered: Vec<String> = items.iter().map(canonicalize).collect();
            format!("[{}]", rendered.join(","))
        }
        scalar => serde_json::to_string(scalar).unwrap_or_else(|_| format!("{scalar:?}")),
    }
}
/// Returns the SHA-256 digest of `bytes` as a lowercase hex string.
fn hex_digest(bytes: &[u8]) -> String {
    // One-shot hashing: equivalent to new + update + finalize.
    hex(Sha256::digest(bytes))
}
/// Encodes `digest` as lowercase hexadecimal, two characters per byte.
fn hex(digest: impl AsRef<[u8]>) -> String {
    use std::fmt::Write;

    let bytes = digest.as_ref();
    bytes
        .iter()
        .fold(String::with_capacity(bytes.len() * 2), |mut encoded, byte| {
            // The write result is discarded, as in the original implementation.
            let _ = write!(encoded, "{byte:02x}");
            encoded
        })
}
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        reason = "test code — panics are acceptable"
    )]

    use super::*;
    use serde_json::json;

    /// Two-message conversation shared by the hashing tests.
    fn msgs() -> Vec<Message> {
        let system = Message::system("You are a helpful assistant.");
        let user = Message::user("Extract entities from: Alice met Bob.");
        vec![system, user]
    }

    /// Same messages and schema hash to the same 64-char lowercase hex key.
    #[test]
    fn input_hash_is_stable_for_same_input() {
        let schema = json!({"type": "object", "properties": {"a": {"type": "string"}}});

        let first = input_hash(&msgs(), Some(&schema));
        let second = input_hash(&msgs(), Some(&schema));

        assert_eq!(first, second);
        assert_eq!(first.len(), 64);
        let is_lower_hex = |c: char| matches!(c, '0'..='9' | 'a'..='f');
        assert!(first.chars().all(is_lower_hex));
    }

    /// Changing the messages or adding a schema changes the key.
    #[test]
    fn input_hash_differs_when_content_differs() {
        let baseline = input_hash(&msgs(), None);

        let other_prompt = [Message::user("totally different prompt")];
        assert_ne!(baseline, input_hash(&other_prompt, None));

        let schema = json!({"type": "object"});
        assert_ne!(baseline, input_hash(&msgs(), Some(&schema)));
    }

    /// Object key order, at any nesting depth, does not affect the result.
    #[test]
    fn canonicalize_is_order_independent() {
        let shuffled = json!({
            "b": [1, {"y": 2, "x": 1}],
            "a": {"k2": "v2", "k1": "v1"}
        });
        let ordered = json!({
            "a": {"k1": "v1", "k2": "v2"},
            "b": [1, {"x": 1, "y": 2}]
        });

        assert_eq!(canonicalize(&shuffled), canonicalize(&ordered));
        assert_eq!(
            input_hash(&msgs(), Some(&shuffled)),
            input_hash(&msgs(), Some(&ordered))
        );
    }

    /// Array element order is significant.
    #[test]
    fn canonicalize_distinguishes_array_order() {
        let ascending = json!([1, 2, 3]);
        let descending = json!([3, 2, 1]);
        assert_ne!(canonicalize(&ascending), canonicalize(&descending));
    }

    /// The image key is repeatable and reacts to both MIME type and bytes.
    #[test]
    fn vision_hash_is_stable_and_sensitive() {
        let reference = vision_hash(b"\x89PNG\r\n", "image/png");
        let repeat = vision_hash(b"\x89PNG\r\n", "image/png");
        assert_eq!(reference, repeat);

        assert_ne!(reference, vision_hash(b"\x89PNG\r\n", "image/jpeg"));
        assert_ne!(reference, vision_hash(b"different", "image/png"));
    }

    /// A cassette written with `save` comes back unchanged from `load`.
    #[test]
    fn cassette_round_trips_through_save_load() {
        let tmp = tempfile::tempdir().expect("create tempdir");
        let file = tmp.path().join("cassette.json");

        let text_entry = CassetteEntry {
            method: CassetteMethod::Generate,
            user_input_preview: "Extract entities from: Alice met Bob.".to_string(),
            schema_name: None,
            response: json!({"content": "Alice, Bob"}),
        };
        let image_entry = CassetteEntry {
            method: CassetteMethod::TranscribeImage,
            user_input_preview: "[image/png]".to_string(),
            schema_name: Some("KnowledgeGraph".to_string()),
            response: json!({"text": "a cat"}),
        };
        let original = LlmCassette {
            version: 1,
            model: "gpt-4o-mini".to_string(),
            entries: BTreeMap::from([
                (input_hash(&msgs(), None), text_entry),
                (vision_hash(b"img", "image/png"), image_entry),
            ]),
        };

        original.save(&file).expect("save cassette");
        let reloaded = LlmCassette::load(&file).expect("load cassette");

        assert_eq!(reloaded.version, original.version);
        assert_eq!(reloaded.model, original.model);
        assert_eq!(reloaded.entries.len(), original.entries.len());

        // Compare field by field: the entry type does not derive `PartialEq`.
        for (hash, expected) in &original.entries {
            let actual = reloaded
                .entries
                .get(hash)
                .expect("entry present after round-trip");
            assert_eq!(actual.method, expected.method);
            assert_eq!(actual.user_input_preview, expected.user_input_preview);
            assert_eq!(actual.schema_name, expected.schema_name);
            assert_eq!(actual.response, expected.response);
        }
    }
}