use std::pin::Pin;
use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Author of a single [`LlmMessage`] in a conversation.
///
/// Serialised in lowercase (`"user"`, `"assistant"`, `"tool"`). There is no
/// system variant: the system prompt travels separately in
/// [`LlmRequest::system`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LlmRole {
    /// A turn written by the end user.
    User,
    /// A turn produced by the model.
    Assistant,
    /// A turn carrying tool output (presumably a tool/function result — confirm with adapters).
    Tool,
}
/// One conversational turn: who said it and what was said.
///
/// Both fields participate in [`LlmRequest::prompt_hash`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct LlmMessage {
    /// Author of this turn.
    pub role: LlmRole,
    /// Text of this turn.
    pub content: String,
}
/// Token accounting attached to an [`LlmResponse`].
///
/// NOTE(review): presumably copied from the provider's usage report — confirm
/// per adapter.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt (system + messages).
    pub prompt_tokens: u32,
    /// Tokens generated in the completion.
    pub completion_tokens: u32,
}
/// A provider-agnostic completion request handed to an [`LlmAdapter`].
///
/// Only `system`, `messages`, `temperature`, `max_tokens` and `json_schema`
/// feed into [`LlmRequest::prompt_hash`]; `model` and `timeout_ms` are
/// excluded from it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    /// Identifier of the model to run.
    pub model: String,
    /// System prompt, kept apart from the turn-by-turn `messages`.
    pub system: String,
    /// Conversation turns, in order.
    pub messages: Vec<LlmMessage>,
    /// Sampling temperature.
    pub temperature: f32,
    /// Upper bound on generated tokens (presumably completion tokens only — confirm per provider).
    pub max_tokens: u32,
    /// Optional JSON Schema; presumably asks the adapter for structured output.
    pub json_schema: Option<serde_json::Value>,
    /// Timeout in milliseconds (presumably surfaced as `LlmError::Timeout`).
    pub timeout_ms: u64,
}
impl LlmRequest {
    /// Returns a stable hex digest identifying the prompt content of this request.
    ///
    /// The digest covers `system`, `messages`, `temperature`, `max_tokens` and
    /// `json_schema`. It deliberately excludes `model` and `timeout_ms`, so the
    /// same prompt sent to another model or with another timeout hashes
    /// identically. Callers needing per-model keys pair the hash with `model`,
    /// as `LlmError::NoFixture` does.
    ///
    /// A temperature of `-0.0` is canonicalised to `0.0` before hashing. The two
    /// compare equal and mean the same thing, but serialise to different JSON
    /// text and would otherwise produce different digests.
    #[must_use]
    pub fn prompt_hash(&self) -> String {
        // `-0.0 == 0.0` holds, so this maps negative zero to positive zero and
        // leaves every other value untouched.
        let temperature = if self.temperature == 0.0 {
            0.0
        } else {
            self.temperature
        };
        let canonical = CanonicalPrompt {
            system: &self.system,
            messages: &self.messages,
            temperature,
            max_tokens: self.max_tokens,
            json_schema: self.json_schema.as_ref(),
        };
        let bytes = serde_json::to_vec(&canonical).expect("CanonicalPrompt is always serializable");
        blake3_hex(&bytes)
    }
}
/// Borrowed view of the hash-relevant parts of an [`LlmRequest`]. It is
/// serialised to JSON and hashed by [`LlmRequest::prompt_hash`].
///
/// serde emits struct fields in declaration order, so the order below is part
/// of the digest. Reordering, renaming or adding fields changes every existing
/// prompt hash.
#[derive(Serialize)]
struct CanonicalPrompt<'a> {
    system: &'a str,
    messages: &'a [LlmMessage],
    temperature: f32,
    max_tokens: u32,
    // Left out of the JSON entirely when absent, rather than written as `null`.
    #[serde(skip_serializing_if = "Option::is_none")]
    json_schema: Option<&'a serde_json::Value>,
}
/// Result of a completed [`LlmRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
    /// Full generated text.
    pub text: String,
    /// Parsed JSON body (presumably populated when `json_schema` was requested — confirm).
    pub parsed_json: Option<serde_json::Value>,
    /// Model that produced the response, as reported by the adapter.
    pub model: String,
    /// Token accounting, when the adapter provides it.
    pub usage: Option<TokenUsage>,
    /// NOTE(review): presumably a `blake3_hex` digest of the raw provider payload — confirm.
    pub raw_hash: String,
}
/// Failures an [`LlmAdapter`] can report.
#[derive(Debug, Error)]
pub enum LlmError {
    /// The request could not be delivered (network / connection level).
    #[error("transport: {0}")]
    Transport(String),

    /// The provider answered with an error.
    #[error("upstream: {0}")]
    Upstream(String),

    /// The request was rejected before being sent.
    #[error("invalid request: {0}")]
    InvalidRequest(String),

    /// No response arrived within `timeout_ms` milliseconds.
    #[error("timeout after {timeout_ms} ms")]
    Timeout { timeout_ms: u64 },

    /// The provider's response could not be decoded.
    #[error("response parse: {0}")]
    Parse(String),

    /// Replay mode found no recorded fixture for this model / prompt hash pair.
    #[error("no replay fixture for model={model} prompt_hash={prompt_hash}")]
    NoFixture { model: String, prompt_hash: String },

    /// A replay fixture was found but failed its integrity check.
    #[error("fixture integrity failed: {0}")]
    FixtureIntegrityFailed(String),

    /// A filesystem or other I/O failure, stringified.
    #[error("io: {0}")]
    Io(String),
}
/// One increment of a streamed completion.
#[derive(Debug, Clone)]
pub struct StreamChunk {
    /// Text added by this chunk.
    pub delta: String,
    /// Why generation ended. The default `stream_boxed` fallback sets `"stop"`;
    /// presumably `None` on non-final chunks from native streaming adapters.
    pub finish_reason: Option<String>,
}
/// Pinned, heap-allocated, `Send` stream of chunk results that may borrow for `'a`.
///
/// Return type of the dyn-compatible [`LlmAdapter::stream_boxed`].
pub type BoxStream<'a> =
    Pin<Box<dyn Stream<Item = Result<StreamChunk, LlmError>> + Send + 'a>>;
/// Provider-agnostic interface for running [`LlmRequest`]s.
///
/// Implementors must supply [`LlmAdapter::adapter_id`] and
/// [`LlmAdapter::complete`]. Both streaming methods have defaults that wrap the
/// one-shot completion in a single-chunk stream.
#[async_trait]
pub trait LlmAdapter: Send + Sync {
    /// Static identifier of this adapter implementation.
    fn adapter_id(&self) -> &'static str;

    /// Runs `req` to completion and returns the whole response.
    async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError>;

    /// Streams the response to `req` as a sequence of [`StreamChunk`] results.
    ///
    /// Requires `Self: Sized`, so it is unavailable on trait objects; use
    /// [`LlmAdapter::stream_boxed`] there. The default simply delegates to it.
    fn stream(&self, req: LlmRequest) -> impl Stream<Item = Result<StreamChunk, LlmError>> + Send
    where
        Self: Sized,
    {
        Self::stream_boxed(self, req)
    }

    /// Dyn-compatible streaming entry point.
    ///
    /// The default awaits [`LlmAdapter::complete`] and yields exactly one item:
    /// either the full text as a chunk with `finish_reason = "stop"`, or the
    /// error. Presumably overridden by adapters with native streaming support.
    fn stream_boxed(&self, req: LlmRequest) -> BoxStream<'_> {
        let pending = self.complete(req);
        Box::pin(async_stream::stream! {
            // Success becomes a single terminal chunk; failure passes through as-is.
            yield pending.await.map(|response| StreamChunk {
                delta: response.text,
                finish_reason: Some("stop".into()),
            });
        })
    }
}
/// Hashes `bytes` with BLAKE3 and returns the 32-byte digest as a
/// 64-character lowercase hex string.
#[must_use]
pub fn blake3_hex(bytes: &[u8]) -> String {
    let digest = blake3::hash(bytes);
    digest.to_hex().as_str().to_owned()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request around `messages` with fixed model, system prompt,
    /// sampling settings and timeout.
    fn req_for(messages: &[(LlmRole, &str)]) -> LlmRequest {
        LlmRequest {
            model: "test-model".into(),
            system: "be precise".into(),
            messages: messages
                .iter()
                .map(|(r, c)| LlmMessage {
                    role: *r,
                    content: (*c).to_string(),
                })
                .collect(),
            temperature: 0.0,
            max_tokens: 256,
            json_schema: None,
            timeout_ms: 30_000,
        }
    }

    #[test]
    fn prompt_hash_is_stable_across_calls() {
        let r = req_for(&[(LlmRole::User, "hello")]);
        assert_eq!(r.prompt_hash(), r.prompt_hash());
    }

    #[test]
    fn prompt_hash_ignores_model_and_timeout() {
        let a = req_for(&[(LlmRole::User, "hello")]);
        let mut b = a.clone();
        b.model = "other-model".into();
        b.timeout_ms = 1;
        assert_eq!(a.prompt_hash(), b.prompt_hash());
    }

    #[test]
    fn prompt_hash_changes_with_sampling_params() {
        let base = req_for(&[(LlmRole::User, "hello")]);

        let mut warmer = base.clone();
        warmer.temperature = 0.5;
        assert_ne!(base.prompt_hash(), warmer.prompt_hash());

        let mut longer = base.clone();
        longer.max_tokens = 512;
        assert_ne!(base.prompt_hash(), longer.prompt_hash());
    }

    #[test]
    fn prompt_hash_changes_with_message_content() {
        let a = req_for(&[(LlmRole::User, "hello")]);
        let b = req_for(&[(LlmRole::User, "world")]);
        assert_ne!(a.prompt_hash(), b.prompt_hash());
    }

    #[test]
    fn prompt_hash_changes_with_role_and_order() {
        let user = req_for(&[(LlmRole::User, "hello")]);
        let assistant = req_for(&[(LlmRole::Assistant, "hello")]);
        assert_ne!(user.prompt_hash(), assistant.prompt_hash());

        let forward = req_for(&[(LlmRole::User, "a"), (LlmRole::Assistant, "b")]);
        let reversed = req_for(&[(LlmRole::Assistant, "b"), (LlmRole::User, "a")]);
        assert_ne!(forward.prompt_hash(), reversed.prompt_hash());
    }

    #[test]
    fn prompt_hash_changes_with_system_and_schema() {
        let base = req_for(&[(LlmRole::User, "hello")]);

        let mut other_system = base.clone();
        other_system.system = "be brief".into();
        assert_ne!(base.prompt_hash(), other_system.prompt_hash());

        let mut with_schema = base.clone();
        with_schema.json_schema = Some(serde_json::json!({ "type": "object" }));
        assert_ne!(base.prompt_hash(), with_schema.prompt_hash());
    }

    #[test]
    fn blake3_hex_is_64_chars() {
        let hex = blake3_hex(b"abc");
        assert_eq!(hex.len(), 64);
        assert!(hex
            .bytes()
            .all(|b| b.is_ascii_digit() || (b'a'..=b'f').contains(&b)));
    }
}