use std::future::Future;
use std::pin::Pin;
use std::{
any::TypeId,
collections::HashMap,
sync::{LazyLock, Mutex},
};
use futures_core::Stream;
use serde::{Deserialize, Serialize};
use zeph_common::ToolName;
pub use zeph_common::ToolDefinition;
use crate::embed::owned_strs;
use crate::error::LlmError;
/// Process-wide cache of generated JSON schemas, keyed by the `TypeId` of the schema type.
///
/// Each entry holds the schema as a `serde_json::Value` together with its pretty-printed
/// string form. The map is created lazily on first access and guarded by a `Mutex`.
static SCHEMA_CACHE: LazyLock<Mutex<HashMap<TypeId, (serde_json::Value, String)>>> =
    LazyLock::new(|| Mutex::new(HashMap::new()));
pub(crate) fn cached_schema<T: schemars::JsonSchema + 'static>()
-> Result<(serde_json::Value, String), crate::LlmError> {
let type_id = TypeId::of::<T>();
if let Ok(cache) = SCHEMA_CACHE.lock()
&& let Some(entry) = cache.get(&type_id)
{
return Ok(entry.clone());
}
let schema = schemars::schema_for!(T);
let value = serde_json::to_value(&schema)
.map_err(|e| crate::LlmError::StructuredParse(e.to_string()))?;
let pretty = serde_json::to_string_pretty(&schema)
.map_err(|e| crate::LlmError::StructuredParse(e.to_string()))?;
if let Ok(mut cache) = SCHEMA_CACHE.lock() {
cache.insert(type_id, (value.clone(), pretty.clone()));
}
Ok((value, pretty))
}
/// Returns the name of `T` with its leading module path removed.
///
/// The module path is only searched for in the part of the name that precedes the first
/// `<`, so generic arguments are kept intact: `std::collections::HashMap<u32, u32>`
/// becomes `HashMap<u32, u32>`, and a `Vec<String>` keeps its full argument list instead
/// of being cut down to the tail of the last `::` inside the brackets.
///
/// Falls back to `"Output"` if the resulting name would be empty.
pub(crate) fn short_type_name<T: ?Sized>() -> &'static str {
    let full = std::any::type_name::<T>();
    // Restrict the `::` search to the outer type's own path; anything from the first `<`
    // onwards belongs to generic arguments and must not be split.
    let head_len = full.find('<').unwrap_or(full.len());
    let start = full[..head_len].rfind("::").map_or(0, |idx| idx + 2);
    let short = &full[start..];
    if short.is_empty() { "Output" } else { short }
}
/// One item yielded by a [`ChatStream`].
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// A fragment of regular response text.
    Content(String),
    /// A fragment of "thinking" text, kept separate from regular content.
    Thinking(String),
    /// A compaction payload carried as a string.
    /// NOTE(review): presumably a context summary produced by the provider — confirm.
    Compaction(String),
    /// One or more tool invocations requested by the model.
    ToolUse(Vec<ToolUseRequest>),
}
/// Boxed, pinned, `Send` stream whose items are `Result<StreamChunk, LlmError>`.
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, LlmError>> + Send>>;
/// A single tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolUseRequest {
    /// Identifier of this invocation.
    /// NOTE(review): presumably echoed back in the matching tool result — confirm.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: ToolName,
    /// Arguments for the tool as arbitrary JSON.
    pub input: serde_json::Value,
}
/// A "thinking" block attached to a tool-use response (see `ChatResponse::ToolUse`).
#[derive(Debug, Clone)]
pub enum ThinkingBlock {
    /// Thinking text together with its accompanying signature string.
    Thinking { thinking: String, signature: String },
    /// A redacted block that only carries an opaque data string.
    Redacted { data: String },
}
/// Marker text for a response that hit the `max_tokens` limit.
/// NOTE(review): where this marker is emitted or matched is not visible in this file — confirm usage.
pub const MAX_TOKENS_TRUNCATION_MARKER: &str = "max_tokens limit reached";
/// Result of a tool-aware chat call (see `LlmProvider::chat_with_tools`).
#[derive(Debug, Clone)]
pub enum ChatResponse {
    /// A plain text answer.
    Text(String),
    /// The model requested one or more tool invocations.
    ToolUse {
        /// Optional text accompanying the tool calls.
        text: Option<String>,
        /// The requested tool invocations.
        tool_calls: Vec<ToolUseRequest>,
        /// Thinking blocks returned alongside the tool calls.
        thinking_blocks: Vec<ThinkingBlock>,
    },
}
/// Boxed, pinned, `Send` future resolving to an embedding vector or an [`LlmError`].
pub type EmbedFuture = Pin<Box<dyn Future<Output = Result<Vec<f32>, LlmError>> + Send>>;
/// Type-erased embedding function: takes the text to embed and returns an [`EmbedFuture`].
pub type EmbedFn = Box<dyn Fn(&str) -> EmbedFuture + Send + Sync>;
/// Unbounded tokio channel sender carrying `String` messages.
/// NOTE(review): presumably human-readable status updates — confirm against the receiver.
pub type StatusTx = tokio::sync::mpsc::UnboundedSender<String>;
#[must_use]
pub fn default_debug_request_json(
messages: &[Message],
tools: &[ToolDefinition],
) -> serde_json::Value {
serde_json::json!({
"model": serde_json::Value::Null,
"max_tokens": serde_json::Value::Null,
"messages": serde_json::to_value(messages).unwrap_or(serde_json::Value::Array(vec![])),
"tools": serde_json::to_value(tools).unwrap_or(serde_json::Value::Array(vec![])),
"temperature": serde_json::Value::Null,
"cache_control": serde_json::Value::Null,
})
}
/// Optional sampling parameters; the derived `Default` leaves every field as `None`.
/// NOTE(review): presumably `None` means "use the provider's own setting" — confirm at the use sites.
#[derive(Debug, Clone, Default)]
pub struct GenerationOverrides {
    /// Override for the sampling temperature.
    pub temperature: Option<f64>,
    /// Override for `top_p`.
    pub top_p: Option<f64>,
    /// Override for `top_k`.
    pub top_k: Option<usize>,
    /// Override for the frequency penalty.
    pub frequency_penalty: Option<f64>,
    /// Override for the presence penalty.
    pub presence_penalty: Option<f64>,
}
/// Author of a [`Message`]. Serialised as a lowercase string (`"system"`, `"user"`, `"assistant"`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// System role.
    System,
    /// User role.
    User,
    /// Assistant role.
    Assistant,
}
/// A structured piece of a [`Message`].
///
/// Serialised internally tagged: the variant name goes into a `"kind"` field in
/// `snake_case` (for example `"tool_output"`), next to the variant's own fields.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum MessagePart {
    /// Plain text.
    Text { text: String },
    /// Output of a tool run.
    ToolOutput {
        /// Name of the tool that produced the output.
        tool_name: zeph_common::ToolName,
        /// The output text; may be empty or a short notice once the part was compacted.
        body: String,
        /// Set once the output was compacted; omitted from serialised form when `None`.
        /// NOTE(review): looks like a timestamp — unit not visible here, confirm.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        compacted_at: Option<i64>,
    },
    /// Recalled text; flattened like plain text.
    Recall { text: String },
    /// Code context text; flattened like plain text.
    CodeContext { text: String },
    /// Summary text; flattened like plain text.
    Summary { text: String },
    /// Cross-session text; flattened like plain text.
    CrossSession { text: String },
    /// A tool invocation requested by the model.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a tool invocation, referring back to it by `tool_use_id`.
    ToolResult {
        tool_use_id: String,
        content: String,
        /// Defaults to `false` when absent from the serialised form.
        #[serde(default)]
        is_error: bool,
    },
    /// Image payload (boxed).
    Image(Box<ImageData>),
    /// Thinking text with its signature; contributes nothing to flattened content.
    ThinkingBlock { thinking: String, signature: String },
    /// Redacted thinking data; contributes nothing to flattened content.
    RedactedThinkingBlock { data: String },
    /// Compaction summary; contributes nothing to flattened content.
    Compaction { summary: String },
}
impl MessagePart {
    /// Returns the inner text of the text-like variants (`Text`, `Recall`, `CodeContext`,
    /// `Summary`, `CrossSession`) and `None` for every other variant.
    #[must_use]
    pub fn as_plain_text(&self) -> Option<&str> {
        let text = match self {
            Self::Text { text }
            | Self::Recall { text }
            | Self::CodeContext { text }
            | Self::Summary { text }
            | Self::CrossSession { text } => text,
            _ => return None,
        };
        Some(text.as_str())
    }

    /// Returns the image payload when this part is an `Image`, otherwise `None`.
    #[must_use]
    pub fn as_image(&self) -> Option<&ImageData> {
        let Self::Image(image) = self else {
            return None;
        };
        Some(image)
    }
}
/// Raw image bytes together with their MIME type.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ImageData {
    /// Image bytes; (de)serialised as a standard base64 string via `serde_bytes_base64`.
    #[serde(with = "serde_bytes_base64")]
    pub data: Vec<u8>,
    /// MIME type of the image, e.g. `image/png`.
    pub mime_type: String,
}
mod serde_bytes_base64 {
use base64::{Engine, engine::general_purpose::STANDARD};
use serde::{Deserialize, Deserializer, Serializer};
pub fn serialize<S>(bytes: &[u8], s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
s.serialize_str(&STANDARD.encode(bytes))
}
pub fn deserialize<'de, D>(d: D) -> Result<Vec<u8>, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(d)?;
STANDARD.decode(&s).map_err(serde::de::Error::custom)
}
}
/// Who a message is visible to. Serialised in `snake_case` (`"both"`, `"agent_only"`, `"user_only"`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageVisibility {
    /// Visible to both the agent and the user.
    Both,
    /// Visible to the agent only.
    AgentOnly,
    /// Visible to the user only.
    UserOnly,
}
impl MessageVisibility {
#[must_use]
pub fn is_agent_visible(self) -> bool {
matches!(self, MessageVisibility::Both | MessageVisibility::AgentOnly)
}
#[must_use]
pub fn is_user_visible(self) -> bool {
matches!(self, MessageVisibility::Both | MessageVisibility::UserOnly)
}
}
impl Default for MessageVisibility {
fn default() -> Self {
MessageVisibility::Both
}
}
impl MessageVisibility {
#[must_use]
pub fn as_db_str(self) -> &'static str {
match self {
MessageVisibility::Both => "both",
MessageVisibility::AgentOnly => "agent_only",
MessageVisibility::UserOnly => "user_only",
}
}
#[must_use]
pub fn from_db_str(s: &str) -> Self {
match s {
"agent_only" => MessageVisibility::AgentOnly,
"user_only" => MessageVisibility::UserOnly,
_ => MessageVisibility::Both,
}
}
}
/// Per-message metadata carried alongside the content.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MessageMetadata {
    /// Who the message is visible to.
    pub visibility: MessageVisibility,
    /// Set once the message was compacted; omitted from serialised form when `None`.
    /// NOTE(review): looks like a timestamp — unit not visible here, confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub compacted_at: Option<i64>,
    /// Optional deferred summary text; omitted from serialised form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub deferred_summary: Option<String>,
    /// Focus-pin flag; omitted from serialised form when `false`.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub focus_pinned: bool,
    /// Optional focus marker id; omitted from serialised form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub focus_marker_id: Option<uuid::Uuid>,
    /// Database row id; never serialised or deserialised (`#[serde(skip)]`).
    #[serde(skip)]
    pub db_id: Option<i64>,
}
/// Default metadata: visible to both sides, not compacted, not pinned, no ids.
impl Default for MessageMetadata {
    fn default() -> Self {
        let visibility = MessageVisibility::Both;
        Self {
            visibility,
            db_id: None,
            focus_marker_id: None,
            focus_pinned: false,
            deferred_summary: None,
            compacted_at: None,
        }
    }
}
impl MessageMetadata {
#[must_use]
pub fn agent_only() -> Self {
Self {
visibility: MessageVisibility::AgentOnly,
compacted_at: None,
deferred_summary: None,
focus_pinned: false,
focus_marker_id: None,
db_id: None,
}
}
#[must_use]
pub fn user_only() -> Self {
Self {
visibility: MessageVisibility::UserOnly,
compacted_at: None,
deferred_summary: None,
focus_pinned: false,
focus_marker_id: None,
db_id: None,
}
}
#[must_use]
pub fn focus_pinned() -> Self {
Self {
visibility: MessageVisibility::AgentOnly,
compacted_at: None,
deferred_summary: None,
focus_pinned: true,
focus_marker_id: None,
db_id: None,
}
}
}
/// A single chat message.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Message {
    /// Author of the message.
    pub role: Role,
    /// Flat text form of the message (what `to_llm_content` returns).
    pub content: String,
    /// Structured parts; defaults to an empty list when absent from the serialised form.
    #[serde(default)]
    pub parts: Vec<MessagePart>,
    /// Message metadata; defaults to `MessageMetadata::default()` when absent.
    #[serde(default)]
    pub metadata: MessageMetadata,
}
/// The default message is an empty user message without parts and with default metadata.
impl Default for Message {
    fn default() -> Self {
        Self::from_legacy(Role::User, String::new())
    }
}
impl Message {
    /// Creates a message from plain text only: no parts, default metadata.
    #[must_use]
    pub fn from_legacy(role: Role, content: impl Into<String>) -> Self {
        let content = content.into();
        Self {
            role,
            content,
            parts: Vec::new(),
            metadata: MessageMetadata::default(),
        }
    }

    /// Creates a message from structured parts; `content` is the flattened text of `parts`.
    #[must_use]
    pub fn from_parts(role: Role, parts: Vec<MessagePart>) -> Self {
        Self {
            content: Self::flatten_parts(&parts),
            role,
            parts,
            metadata: MessageMetadata::default(),
        }
    }

    /// Returns the flat text form of the message.
    #[must_use]
    pub fn to_llm_content(&self) -> &str {
        self.content.as_str()
    }

    /// Regenerates `content` from `parts`. Messages without parts keep their content as is.
    pub fn rebuild_content(&mut self) {
        if self.parts.is_empty() {
            return;
        }
        self.content = Self::flatten_parts(&self.parts);
    }

    /// Concatenates the textual rendering of every part, without separators.
    ///
    /// Text-like parts are appended verbatim, tool and image parts get a bracketed
    /// label, and thinking/compaction parts contribute nothing.
    fn flatten_parts(parts: &[MessagePart]) -> String {
        use std::fmt::Write;

        let mut flat = String::new();
        for part in parts {
            // Writing into a `String` cannot fail, so the `fmt::Result`s are discarded.
            match part {
                MessagePart::Text { text }
                | MessagePart::Recall { text }
                | MessagePart::CodeContext { text }
                | MessagePart::Summary { text }
                | MessagePart::CrossSession { text } => flat.push_str(text),
                // Not compacted: full body inside a fenced block.
                MessagePart::ToolOutput {
                    tool_name,
                    body,
                    compacted_at: None,
                } => {
                    let _ = write!(flat, "[tool output: {tool_name}]\n```\n{body}\n```");
                }
                // Compacted with nothing left to show.
                MessagePart::ToolOutput {
                    tool_name,
                    body,
                    compacted_at: Some(_),
                } if body.is_empty() => {
                    let _ = write!(flat, "[tool output: {tool_name}] (pruned)");
                }
                // Compacted, but the body still carries text to show inline.
                MessagePart::ToolOutput {
                    tool_name,
                    body,
                    compacted_at: Some(_),
                } => {
                    let _ = write!(flat, "[tool output: {tool_name}] {body}");
                }
                MessagePart::ToolUse { id, name, .. } => {
                    let _ = write!(flat, "[tool_use: {name}({id})]");
                }
                MessagePart::ToolResult {
                    tool_use_id,
                    content,
                    ..
                } => {
                    let _ = write!(flat, "[tool_result: {tool_use_id}]\n{content}");
                }
                MessagePart::Image(image) => {
                    let _ = write!(
                        flat,
                        "[image: {}, {} bytes]",
                        image.mime_type,
                        image.data.len()
                    );
                }
                MessagePart::ThinkingBlock { .. }
                | MessagePart::RedactedThinkingBlock { .. }
                | MessagePart::Compaction { .. } => {}
            }
        }
        flat
    }
}
/// Common interface implemented by every LLM backend.
///
/// Implementors must provide `chat`, `chat_stream`, `embed`, `name` and the
/// `supports_streaming` / `supports_embeddings` flags; every other method has a default.
pub trait LlmProvider: Send + Sync {
    /// Size of the model's context window, if known. Default: `None`.
    fn context_window(&self) -> Option<usize> {
        None
    }
    /// Sends `messages` to the model and resolves to the full response text.
    fn chat(&self, messages: &[Message]) -> impl Future<Output = Result<String, LlmError>> + Send;
    /// Sends `messages` to the model and resolves to a stream of response chunks.
    fn chat_stream(
        &self,
        messages: &[Message],
    ) -> impl Future<Output = Result<ChatStream, LlmError>> + Send;
    /// Whether the provider supports streaming responses.
    fn supports_streaming(&self) -> bool;
    /// Computes the embedding vector for `text`.
    fn embed(&self, text: &str) -> impl Future<Output = Result<Vec<f32>, LlmError>> + Send;
    /// Embeds several texts, returning one vector per input in input order.
    ///
    /// The default implementation calls `embed` once per text, one after another, and
    /// stops at the first error.
    fn embed_batch(
        &self,
        texts: &[&str],
    ) -> impl Future<Output = Result<Vec<Vec<f32>>, LlmError>> + Send {
        // Convert the inputs before building the future so the `async move` block
        // captures `owned` rather than the borrowed `texts` slice.
        let owned = owned_strs(texts);
        async move {
            let mut results = Vec::with_capacity(owned.len());
            for text in &owned {
                results.push(self.embed(text).await?);
            }
            Ok(results)
        }
    }
    /// Whether the provider supports embeddings.
    fn supports_embeddings(&self) -> bool;
    /// Name of the provider.
    fn name(&self) -> &str;
    /// Identifier of the model in use. Default: the empty string.
    #[allow(clippy::unnecessary_literal_bound)]
    fn model_identifier(&self) -> &str {
        ""
    }
    /// Whether the provider accepts image input. Default: `false`.
    fn supports_vision(&self) -> bool {
        false
    }
    /// Whether the provider supports tool use. Default: `true`.
    fn supports_tool_use(&self) -> bool {
        true
    }
    /// Chat call that may return tool invocations.
    ///
    /// The default implementation ignores the tool definitions, delegates to `chat`
    /// and wraps the text in `ChatResponse::Text`.
    #[allow(async_fn_in_trait)]
    async fn chat_with_tools(
        &self,
        messages: &[Message],
        _tools: &[ToolDefinition],
    ) -> Result<ChatResponse, LlmError> {
        Ok(ChatResponse::Text(self.chat(messages).await?))
    }
    /// Cache usage counters of the last call, if tracked. Default: `None`.
    /// NOTE(review): the meaning of the two `u64` values is not visible here — confirm with implementors.
    fn last_cache_usage(&self) -> Option<(u64, u64)> {
        None
    }
    /// Usage counters of the last call, if tracked. Default: `None`.
    /// NOTE(review): the meaning of the two `u64` values is not visible here — confirm with implementors.
    fn last_usage(&self) -> Option<(u64, u64)> {
        None
    }
    /// Returns a pending compaction summary, if any. Default: `None`.
    fn take_compaction_summary(&self) -> Option<String> {
        None
    }
    /// Hook for reporting the outcome of a call. The default implementation does nothing.
    fn record_quality_outcome(&self, _provider_name: &str, _success: bool) {}
    /// JSON description of the request for debugging.
    ///
    /// The default implementation ignores `_stream` and returns
    /// `default_debug_request_json(messages, tools)`.
    #[must_use]
    fn debug_request_json(
        &self,
        messages: &[Message],
        tools: &[ToolDefinition],
        _stream: bool,
    ) -> serde_json::Value {
        default_debug_request_json(messages, tools)
    }
    /// Names of the available models. Default: an empty list.
    fn list_models(&self) -> Vec<String> {
        vec![]
    }
    /// Whether the provider supports structured output natively. Default: `false`.
    fn supports_structured_output(&self) -> bool {
        false
    }
    /// Asks the model for JSON matching the schema of `T` and deserialises the answer.
    ///
    /// A system message containing the schema is inserted in front of `messages`. If the
    /// first answer does not parse, the answer and the parse error are appended to the
    /// conversation and the model is asked once more.
    ///
    /// # Errors
    ///
    /// Propagates errors from `cached_schema` and `chat`, and returns
    /// `LlmError::StructuredParse` when the retried answer does not parse either.
    #[allow(async_fn_in_trait)]
    async fn chat_typed<T>(&self, messages: &[Message]) -> Result<T, LlmError>
    where
        T: serde::de::DeserializeOwned + schemars::JsonSchema + 'static,
        Self: Sized,
    {
        let (_, schema_json) = cached_schema::<T>()?;
        let type_name = short_type_name::<T>();
        let mut augmented = messages.to_vec();
        let instruction = format!(
            "Respond with a valid JSON object matching this schema. \
Output ONLY the JSON, no markdown fences or extra text.\n\n\
Type: {type_name}\nSchema:\n```json\n{schema_json}\n```"
        );
        // The schema instruction goes first, ahead of the caller's messages.
        augmented.insert(0, Message::from_legacy(Role::System, instruction));
        let raw = self.chat(&augmented).await?;
        let cleaned = strip_json_fences(&raw);
        match serde_json::from_str::<T>(cleaned) {
            Ok(val) => Ok(val),
            Err(first_err) => {
                // Single retry: show the model its own answer and the parse error.
                augmented.push(Message::from_legacy(Role::Assistant, &raw));
                augmented.push(Message::from_legacy(
                    Role::User,
                    format!(
                        "Your response was not valid JSON. Error: {first_err}. \
Please output ONLY valid JSON matching the schema."
                    ),
                ));
                let retry_raw = self.chat(&augmented).await?;
                let retry_cleaned = strip_json_fences(&retry_raw);
                serde_json::from_str::<T>(retry_cleaned).map_err(|e| {
                    LlmError::StructuredParse(format!("parse failed after retry: {e}"))
                })
            }
        }
    }
}
/// Removes surrounding whitespace and Markdown code fences from a model response.
///
/// After trimming, every leading "```json" and then every leading "```" is removed, every
/// trailing "```" is removed, and the remainder is trimmed again. Input without fences is
/// returned trimmed.
fn strip_json_fences(s: &str) -> &str {
    const FENCE: &str = "```";
    const JSON_FENCE: &str = "```json";

    let body = s.trim();
    let body = body.trim_start_matches(JSON_FENCE);
    let body = body.trim_start_matches(FENCE);
    let body = body.trim_end_matches(FENCE);
    body.trim()
}
/// Unit tests for the message types, the `LlmProvider` defaults and the helper functions.
#[cfg(test)]
mod tests {
    use tokio_stream::StreamExt;
    use super::*;
    /// Provider stub whose `chat` always returns the configured `response`.
    struct StubProvider {
        response: String,
    }
    impl LlmProvider for StubProvider {
        async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
            Ok(self.response.clone())
        }
        // Wraps the `chat` result into a stream with a single `Content` chunk.
        async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
            let response = self.chat(messages).await?;
            Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
                response,
            )))))
        }
        fn supports_streaming(&self) -> bool {
            false
        }
        // Always returns the same fixed three-element vector.
        async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
            Ok(vec![0.1, 0.2, 0.3])
        }
        fn supports_embeddings(&self) -> bool {
            false
        }
        fn name(&self) -> &'static str {
            "stub"
        }
    }
    /// The trait's default `context_window` returns `None`.
    #[test]
    fn context_window_default_returns_none() {
        let provider = StubProvider {
            response: String::new(),
        };
        assert!(provider.context_window().is_none());
    }
    /// The stub reports that it does not support streaming.
    #[test]
    fn supports_streaming_default_returns_false() {
        let provider = StubProvider {
            response: String::new(),
        };
        assert!(!provider.supports_streaming());
    }
    /// The stub's `chat_stream` yields exactly one `Content` chunk with the chat response.
    #[tokio::test]
    async fn chat_stream_default_yields_single_chunk() {
        let provider = StubProvider {
            response: "hello world".into(),
        };
        let messages = vec![Message {
            role: Role::User,
            content: "test".into(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }];
        let mut stream = provider.chat_stream(&messages).await.unwrap();
        let chunk = stream.next().await.unwrap().unwrap();
        assert!(matches!(chunk, StreamChunk::Content(s) if s == "hello world"));
        assert!(stream.next().await.is_none());
    }
    /// An error returned by `chat` surfaces from a `chat_stream` built on top of it.
    #[tokio::test]
    async fn chat_stream_default_propagates_chat_error() {
        // Local provider whose `chat` and `embed` always fail with `Unavailable`.
        struct FailProvider;
        impl LlmProvider for FailProvider {
            async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
                Err(LlmError::Unavailable)
            }
            async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
                let response = self.chat(messages).await?;
                Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
                    response,
                )))))
            }
            fn supports_streaming(&self) -> bool {
                false
            }
            async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
                Err(LlmError::Unavailable)
            }
            fn supports_embeddings(&self) -> bool {
                false
            }
            fn name(&self) -> &'static str {
                "fail"
            }
        }
        let provider = FailProvider;
        let messages = vec![Message {
            role: Role::User,
            content: "test".into(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }];
        let result = provider.chat_stream(&messages).await;
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("provider unavailable"));
        }
    }
    /// The stub's `embed` returns its fixed vector.
    #[tokio::test]
    async fn stub_provider_embed_returns_vector() {
        let provider = StubProvider {
            response: String::new(),
        };
        let embedding = provider.embed("test").await.unwrap();
        assert_eq!(embedding, vec![0.1, 0.2, 0.3]);
    }
    /// An `EmbedUnsupported` error from `embed` reaches the caller with its display text.
    #[tokio::test]
    async fn fail_provider_embed_propagates_error() {
        // Local provider whose `embed` always fails with `EmbedUnsupported`.
        struct FailProvider;
        impl LlmProvider for FailProvider {
            async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
                Err(LlmError::Unavailable)
            }
            async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
                let response = self.chat(messages).await?;
                Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
                    response,
                )))))
            }
            fn supports_streaming(&self) -> bool {
                false
            }
            async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
                Err(LlmError::EmbedUnsupported {
                    provider: "fail".into(),
                })
            }
            fn supports_embeddings(&self) -> bool {
                false
            }
            fn name(&self) -> &'static str {
                "fail"
            }
        }
        let provider = FailProvider;
        let result = provider.embed("test").await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("embedding not supported")
        );
    }
    /// Roles serialise to lowercase JSON strings.
    #[test]
    fn role_serialization() {
        let system = Role::System;
        let user = Role::User;
        let assistant = Role::Assistant;
        assert_eq!(serde_json::to_string(&system).unwrap(), "\"system\"");
        assert_eq!(serde_json::to_string(&user).unwrap(), "\"user\"");
        assert_eq!(serde_json::to_string(&assistant).unwrap(), "\"assistant\"");
    }
    /// Lowercase JSON strings deserialise to the matching roles.
    #[test]
    fn role_deserialization() {
        let system: Role = serde_json::from_str("\"system\"").unwrap();
        let user: Role = serde_json::from_str("\"user\"").unwrap();
        let assistant: Role = serde_json::from_str("\"assistant\"").unwrap();
        assert_eq!(system, Role::System);
        assert_eq!(user, Role::User);
        assert_eq!(assistant, Role::Assistant);
    }
    /// Cloning a message keeps role and content.
    #[test]
    fn message_clone() {
        let msg = Message {
            role: Role::User,
            content: "test".into(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        };
        let cloned = msg.clone();
        assert_eq!(cloned.role, msg.role);
        assert_eq!(cloned.content, msg.content);
    }
    /// The `Debug` output of a message mentions its role and content.
    #[test]
    fn message_debug() {
        let msg = Message {
            role: Role::Assistant,
            content: "response".into(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        };
        let debug = format!("{msg:?}");
        assert!(debug.contains("Assistant"));
        assert!(debug.contains("response"));
    }
    /// A serialised message contains the `role` and `content` fields.
    #[test]
    fn message_serialization() {
        let msg = Message {
            role: Role::User,
            content: "hello".into(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"role\":\"user\""));
        assert!(json.contains("\"content\":\"hello\""));
    }
    /// A list of text-like and tool-output parts survives a JSON round trip.
    #[test]
    fn message_part_serde_round_trip() {
        let parts = vec![
            MessagePart::Text {
                text: "hello".into(),
            },
            MessagePart::ToolOutput {
                tool_name: "bash".into(),
                body: "output".into(),
                compacted_at: None,
            },
            MessagePart::Recall {
                text: "recall".into(),
            },
            MessagePart::CodeContext {
                text: "code".into(),
            },
            MessagePart::Summary {
                text: "summary".into(),
            },
        ];
        let json = serde_json::to_string(&parts).unwrap();
        let deserialized: Vec<MessagePart> = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.len(), 5);
    }
    /// `from_legacy` stores the text as content and leaves `parts` empty.
    #[test]
    fn from_legacy_creates_empty_parts() {
        let msg = Message::from_legacy(Role::User, "hello");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, "hello");
        assert!(msg.parts.is_empty());
        assert_eq!(msg.to_llm_content(), "hello");
    }
    /// `from_parts` derives the content from the parts and keeps the parts.
    #[test]
    fn from_parts_flattens_content() {
        let msg = Message::from_parts(
            Role::System,
            vec![MessagePart::Recall {
                text: "recalled data".into(),
            }],
        );
        assert_eq!(msg.content, "recalled data");
        assert_eq!(msg.to_llm_content(), "recalled data");
        assert_eq!(msg.parts.len(), 1);
    }
    /// A non-compacted tool output is rendered with its label and body.
    #[test]
    fn from_parts_tool_output_format() {
        let msg = Message::from_parts(
            Role::User,
            vec![MessagePart::ToolOutput {
                tool_name: "bash".into(),
                body: "hello world".into(),
                compacted_at: None,
            }],
        );
        assert!(msg.content.contains("[tool output: bash]"));
        assert!(msg.content.contains("hello world"));
    }
    /// A message without a `parts` field deserialises with empty parts.
    #[test]
    fn message_deserializes_without_parts() {
        let json = r#"{"role":"user","content":"hello"}"#;
        let msg: Message = serde_json::from_str(json).unwrap();
        assert_eq!(msg.content, "hello");
        assert!(msg.parts.is_empty());
    }
    /// A compacted tool output with an empty body is rendered as "(pruned)" between its neighbours.
    #[test]
    fn flatten_skips_compacted_tool_output_empty_body() {
        let msg = Message::from_parts(
            Role::User,
            vec![
                MessagePart::Text {
                    text: "prefix ".into(),
                },
                MessagePart::ToolOutput {
                    tool_name: "bash".into(),
                    body: String::new(),
                    compacted_at: Some(1234),
                },
                MessagePart::Text {
                    text: " suffix".into(),
                },
            ],
        );
        assert!(msg.content.contains("(pruned)"));
        assert!(msg.content.contains("prefix "));
        assert!(msg.content.contains(" suffix"));
    }
    /// A compacted tool output with a non-empty body renders that body instead of "(pruned)".
    #[test]
    fn flatten_compacted_tool_output_with_reference_renders_body() {
        let ref_notice = "[tool output pruned; full content at /tmp/overflow/big.txt]";
        let msg = Message::from_parts(
            Role::User,
            vec![MessagePart::ToolOutput {
                tool_name: "bash".into(),
                body: ref_notice.into(),
                compacted_at: Some(1234),
            }],
        );
        assert!(msg.content.contains(ref_notice));
        assert!(!msg.content.contains("(pruned)"));
    }
    /// After mutating a part in place, `rebuild_content` brings `content` back in sync.
    #[test]
    fn rebuild_content_syncs_after_mutation() {
        let mut msg = Message::from_parts(
            Role::User,
            vec![MessagePart::ToolOutput {
                tool_name: "bash".into(),
                body: "original".into(),
                compacted_at: None,
            }],
        );
        assert!(msg.content.contains("original"));
        // Mark the tool output as compacted and drop its body.
        if let MessagePart::ToolOutput {
            ref mut compacted_at,
            ref mut body,
            ..
        } = msg.parts[0]
        {
            *compacted_at = Some(999);
            body.clear(); }
        msg.rebuild_content();
        assert!(msg.content.contains("(pruned)"));
        assert!(!msg.content.contains("original"));
    }
    /// A `ToolUse` part keeps id, name and input across a JSON round trip.
    #[test]
    fn message_part_tool_use_serde_round_trip() {
        let part = MessagePart::ToolUse {
            id: "toolu_123".into(),
            name: "bash".into(),
            input: serde_json::json!({"command": "ls"}),
        };
        let json = serde_json::to_string(&part).unwrap();
        let deserialized: MessagePart = serde_json::from_str(&json).unwrap();
        if let MessagePart::ToolUse { id, name, input } = deserialized {
            assert_eq!(id, "toolu_123");
            assert_eq!(name, "bash");
            assert_eq!(input["command"], "ls");
        } else {
            panic!("expected ToolUse");
        }
    }
    /// A `ToolResult` part keeps all of its fields across a JSON round trip.
    #[test]
    fn message_part_tool_result_serde_round_trip() {
        let part = MessagePart::ToolResult {
            tool_use_id: "toolu_123".into(),
            content: "file1.rs\nfile2.rs".into(),
            is_error: false,
        };
        let json = serde_json::to_string(&part).unwrap();
        let deserialized: MessagePart = serde_json::from_str(&json).unwrap();
        if let MessagePart::ToolResult {
            tool_use_id,
            content,
            is_error,
        } = deserialized
        {
            assert_eq!(tool_use_id, "toolu_123");
            assert_eq!(content, "file1.rs\nfile2.rs");
            assert!(!is_error);
        } else {
            panic!("expected ToolResult");
        }
    }
    /// `is_error` defaults to `false` when missing from the JSON.
    #[test]
    fn message_part_tool_result_is_error_default() {
        let json = r#"{"kind":"tool_result","tool_use_id":"id","content":"err"}"#;
        let part: MessagePart = serde_json::from_str(json).unwrap();
        if let MessagePart::ToolResult { is_error, .. } = part {
            assert!(!is_error);
        } else {
            panic!("expected ToolResult");
        }
    }
    /// Both `ChatResponse` variants can be constructed and matched.
    #[test]
    fn chat_response_construction() {
        let text = ChatResponse::Text("hello".into());
        assert!(matches!(text, ChatResponse::Text(s) if s == "hello"));
        let tool_use = ChatResponse::ToolUse {
            text: Some("I'll run that".into()),
            tool_calls: vec![ToolUseRequest {
                id: "1".into(),
                name: "bash".into(),
                input: serde_json::json!({}),
            }],
            thinking_blocks: vec![],
        };
        assert!(matches!(tool_use, ChatResponse::ToolUse { .. }));
    }
    /// A `ToolUse` part is flattened to `[tool_use: name(id)]`.
    #[test]
    fn flatten_parts_tool_use() {
        let msg = Message::from_parts(
            Role::Assistant,
            vec![MessagePart::ToolUse {
                id: "t1".into(),
                name: "bash".into(),
                input: serde_json::json!({"command": "ls"}),
            }],
        );
        assert!(msg.content.contains("[tool_use: bash(t1)]"));
    }
    /// A `ToolResult` part is flattened to its label followed by the content.
    #[test]
    fn flatten_parts_tool_result() {
        let msg = Message::from_parts(
            Role::User,
            vec![MessagePart::ToolResult {
                tool_use_id: "t1".into(),
                content: "output here".into(),
                is_error: false,
            }],
        );
        assert!(msg.content.contains("[tool_result: t1]"));
        assert!(msg.content.contains("output here"));
    }
    /// A `ToolDefinition` keeps name and description across a JSON round trip.
    #[test]
    fn tool_definition_serde_round_trip() {
        let def = ToolDefinition {
            name: "bash".into(),
            description: "Execute a shell command".into(),
            parameters: serde_json::json!({"type": "object"}),
        };
        let json = serde_json::to_string(&def).unwrap();
        let deserialized: ToolDefinition = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.name, "bash");
        assert_eq!(deserialized.description, "Execute a shell command");
    }
    /// The default `chat_with_tools` returns the `chat` text as `ChatResponse::Text`.
    #[tokio::test]
    async fn chat_with_tools_default_delegates_to_chat() {
        let provider = StubProvider {
            response: "hello".into(),
        };
        let messages = vec![Message::from_legacy(Role::User, "test")];
        let result = provider.chat_with_tools(&messages, &[]).await.unwrap();
        assert!(matches!(result, ChatResponse::Text(s) if s == "hello"));
    }
    /// `compacted_at` defaults to `None` when missing from the JSON.
    #[test]
    fn tool_output_compacted_at_serde_default() {
        let json = r#"{"kind":"tool_output","tool_name":"bash","body":"out"}"#;
        let part: MessagePart = serde_json::from_str(json).unwrap();
        if let MessagePart::ToolOutput { compacted_at, .. } = part {
            assert!(compacted_at.is_none());
        } else {
            panic!("expected ToolOutput");
        }
    }
    /// Unfenced JSON passes through unchanged.
    #[test]
    fn strip_json_fences_plain_json() {
        assert_eq!(strip_json_fences(r#"{"a": 1}"#), r#"{"a": 1}"#);
    }
    /// A "```json" fence pair is removed.
    #[test]
    fn strip_json_fences_with_json_fence() {
        assert_eq!(strip_json_fences("```json\n{\"a\": 1}\n```"), r#"{"a": 1}"#);
    }
    /// A plain "```" fence pair is removed.
    #[test]
    fn strip_json_fences_with_plain_fence() {
        assert_eq!(strip_json_fences("```\n{\"a\": 1}\n```"), r#"{"a": 1}"#);
    }
    /// Whitespace-only input becomes the empty string.
    #[test]
    fn strip_json_fences_whitespace() {
        assert_eq!(strip_json_fences(" \n "), "");
    }
    /// Empty input stays empty.
    #[test]
    fn strip_json_fences_empty() {
        assert_eq!(strip_json_fences(""), "");
    }
    /// Whitespace around the fences is removed as well.
    #[test]
    fn strip_json_fences_outer_whitespace() {
        assert_eq!(
            strip_json_fences(" ```json\n{\"a\": 1}\n``` "),
            r#"{"a": 1}"#
        );
    }
    /// An opening fence without a closing one is still removed.
    #[test]
    fn strip_json_fences_only_opening_fence() {
        assert_eq!(strip_json_fences("```json\n{\"a\": 1}"), r#"{"a": 1}"#);
    }
    /// Target type for the `chat_typed` tests.
    #[derive(Debug, serde::Deserialize, schemars::JsonSchema, PartialEq)]
    struct TestOutput {
        value: String,
    }
    /// Provider stub that hands out a preset list of `chat` results, one per call.
    struct SequentialStub {
        responses: std::sync::Mutex<Vec<Result<String, LlmError>>>,
    }
    impl SequentialStub {
        /// Creates a stub that returns `responses` front to back.
        fn new(responses: Vec<Result<String, LlmError>>) -> Self {
            Self {
                responses: std::sync::Mutex::new(responses),
            }
        }
    }
    impl LlmProvider for SequentialStub {
        // Pops the next preset result; fails with `LlmError::Other` once the list is empty.
        async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
            let mut responses = self.responses.lock().unwrap();
            if responses.is_empty() {
                return Err(LlmError::Other("no more responses".into()));
            }
            responses.remove(0)
        }
        async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
            let response = self.chat(messages).await?;
            Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
                response,
            )))))
        }
        fn supports_streaming(&self) -> bool {
            false
        }
        async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
            Err(LlmError::EmbedUnsupported {
                provider: "sequential-stub".into(),
            })
        }
        fn supports_embeddings(&self) -> bool {
            false
        }
        fn name(&self) -> &'static str {
            "sequential-stub"
        }
    }
    /// `chat_typed` parses a valid JSON answer on the first attempt.
    #[tokio::test]
    async fn chat_typed_happy_path() {
        let provider = StubProvider {
            response: r#"{"value": "hello"}"#.into(),
        };
        let messages = vec![Message::from_legacy(Role::User, "test")];
        let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
        assert_eq!(
            result,
            TestOutput {
                value: "hello".into()
            }
        );
    }
    /// `chat_typed` retries once and succeeds when the second answer is valid JSON.
    #[tokio::test]
    async fn chat_typed_retry_succeeds() {
        let provider = SequentialStub::new(vec![
            Ok("not valid json".into()),
            Ok(r#"{"value": "ok"}"#.into()),
        ]);
        let messages = vec![Message::from_legacy(Role::User, "test")];
        let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
        assert_eq!(result, TestOutput { value: "ok".into() });
    }
    /// `chat_typed` reports a parse failure when both answers are invalid.
    #[tokio::test]
    async fn chat_typed_both_fail() {
        let provider = SequentialStub::new(vec![Ok("bad json".into()), Ok("still bad".into())]);
        let messages = vec![Message::from_legacy(Role::User, "test")];
        let result = provider.chat_typed::<TestOutput>(&messages).await;
        let err = result.unwrap_err();
        assert!(err.to_string().contains("parse failed after retry"));
    }
    /// An error from `chat` is passed through `chat_typed` unchanged.
    #[tokio::test]
    async fn chat_typed_chat_error_propagates() {
        let provider = SequentialStub::new(vec![Err(LlmError::Unavailable)]);
        let messages = vec![Message::from_legacy(Role::User, "test")];
        let result = provider.chat_typed::<TestOutput>(&messages).await;
        assert!(matches!(result, Err(LlmError::Unavailable)));
    }
    /// `chat_typed` accepts an answer wrapped in a "```json" fence.
    #[tokio::test]
    async fn chat_typed_strips_fences() {
        let provider = StubProvider {
            response: "```json\n{\"value\": \"fenced\"}\n```".into(),
        };
        let messages = vec![Message::from_legacy(Role::User, "test")];
        let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
        assert_eq!(
            result,
            TestOutput {
                value: "fenced".into()
            }
        );
    }
    /// The trait's default `supports_structured_output` returns `false`.
    #[test]
    fn supports_structured_output_default_false() {
        let provider = StubProvider {
            response: String::new(),
        };
        assert!(!provider.supports_structured_output());
    }
    /// `LlmError::StructuredParse` displays with its fixed prefix.
    #[test]
    fn structured_parse_error_display() {
        let err = LlmError::StructuredParse("test error".into());
        assert_eq!(
            err.to_string(),
            "structured output parse failed: test error"
        );
    }
    /// An `Image` part keeps its bytes and MIME type across a JSON round trip.
    #[test]
    fn message_part_image_roundtrip_json() {
        let part = MessagePart::Image(Box::new(ImageData {
            data: vec![1, 2, 3, 4],
            mime_type: "image/jpeg".into(),
        }));
        let json = serde_json::to_string(&part).unwrap();
        let decoded: MessagePart = serde_json::from_str(&json).unwrap();
        match decoded {
            MessagePart::Image(img) => {
                assert_eq!(img.data, vec![1, 2, 3, 4]);
                assert_eq!(img.mime_type, "image/jpeg");
            }
            _ => panic!("expected Image variant"),
        }
    }
    /// An `Image` part is flattened to a placeholder that names its MIME type.
    #[test]
    fn flatten_parts_includes_image_placeholder() {
        let msg = Message::from_parts(
            Role::User,
            vec![
                MessagePart::Text {
                    text: "see this".into(),
                },
                MessagePart::Image(Box::new(ImageData {
                    data: vec![0u8; 100],
                    mime_type: "image/png".into(),
                })),
            ],
        );
        let content = msg.to_llm_content();
        assert!(content.contains("see this"));
        assert!(content.contains("[image: image/png"));
    }
    /// The trait's default `supports_vision` returns `false`.
    #[test]
    fn supports_vision_default_false() {
        let provider = StubProvider {
            response: String::new(),
        };
        assert!(!provider.supports_vision());
    }
    /// Default metadata is visible to both sides and not compacted.
    #[test]
    fn message_metadata_default_both_visible() {
        let m = MessageMetadata::default();
        assert!(m.visibility.is_agent_visible());
        assert!(m.visibility.is_user_visible());
        assert_eq!(m.visibility, MessageVisibility::Both);
        assert!(m.compacted_at.is_none());
    }
    /// `agent_only` metadata is visible to the agent but not to the user.
    #[test]
    fn message_metadata_agent_only() {
        let m = MessageMetadata::agent_only();
        assert!(m.visibility.is_agent_visible());
        assert!(!m.visibility.is_user_visible());
        assert_eq!(m.visibility, MessageVisibility::AgentOnly);
    }
    /// `user_only` metadata is visible to the user but not to the agent.
    #[test]
    fn message_metadata_user_only() {
        let m = MessageMetadata::user_only();
        assert!(!m.visibility.is_agent_visible());
        assert!(m.visibility.is_user_visible());
        assert_eq!(m.visibility, MessageVisibility::UserOnly);
    }
    /// A message without a `metadata` field deserialises with default visibility.
    #[test]
    fn message_metadata_serde_default() {
        let json = r#"{"role":"user","content":"hello"}"#;
        let msg: Message = serde_json::from_str(json).unwrap();
        assert!(msg.metadata.visibility.is_agent_visible());
        assert!(msg.metadata.visibility.is_user_visible());
    }
    /// Agent-only visibility survives a JSON round trip of the whole message.
    #[test]
    fn message_metadata_round_trip() {
        let msg = Message {
            role: Role::User,
            content: "test".into(),
            parts: vec![],
            metadata: MessageMetadata::agent_only(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let decoded: Message = serde_json::from_str(&json).unwrap();
        assert!(decoded.metadata.visibility.is_agent_visible());
        assert!(!decoded.metadata.visibility.is_user_visible());
        assert_eq!(decoded.metadata.visibility, MessageVisibility::AgentOnly);
    }
    /// A `Compaction` part keeps its summary across a JSON round trip.
    #[test]
    fn message_part_compaction_round_trip() {
        let part = MessagePart::Compaction {
            summary: "Context was summarized.".to_owned(),
        };
        let json = serde_json::to_string(&part).unwrap();
        let decoded: MessagePart = serde_json::from_str(&json).unwrap();
        assert!(
            matches!(decoded, MessagePart::Compaction { summary } if summary == "Context was summarized.")
        );
    }
    /// A `Compaction` part adds nothing to the flattened content.
    #[test]
    fn flatten_parts_compaction_contributes_no_text() {
        let parts = vec![
            MessagePart::Text {
                text: "Hello".to_owned(),
            },
            MessagePart::Compaction {
                summary: "Summary".to_owned(),
            },
        ];
        let msg = Message::from_parts(Role::Assistant, parts);
        assert_eq!(msg.content.trim(), "Hello");
    }
    /// The `StreamChunk::Compaction` variant can be constructed and matched.
    #[test]
    fn stream_chunk_compaction_variant() {
        let chunk = StreamChunk::Compaction("A summary".to_owned());
        assert!(matches!(chunk, StreamChunk::Compaction(s) if s == "A summary"));
    }
    /// `short_type_name` returns only the last path segment of a local type.
    #[test]
    fn short_type_name_extracts_last_segment() {
        struct MyOutput;
        assert_eq!(short_type_name::<MyOutput>(), "MyOutput");
    }
    /// Primitive type names have no path and are returned as they are.
    #[test]
    fn short_type_name_primitive_returns_full_name() {
        assert_eq!(short_type_name::<u32>(), "u32");
        assert_eq!(short_type_name::<bool>(), "bool");
    }
    /// For a generic std type the module path is dropped and the arguments are kept.
    #[test]
    fn short_type_name_nested_path_returns_last() {
        assert_eq!(
            short_type_name::<std::collections::HashMap<u32, u32>>(),
            "HashMap<u32, u32>"
        );
    }
    /// A `Summary` part uses the internally tagged `"kind"` form and round-trips.
    #[test]
    fn summary_roundtrip() {
        let part = MessagePart::Summary {
            text: "hello".to_string(),
        };
        let json = serde_json::to_string(&part).expect("serialization must not fail");
        assert!(
            json.contains("\"kind\":\"summary\""),
            "must use internally-tagged format, got: {json}"
        );
        assert!(
            !json.contains("\"Summary\""),
            "must not use externally-tagged format, got: {json}"
        );
        let decoded: MessagePart =
            serde_json::from_str(&json).expect("deserialization must not fail");
        match decoded {
            MessagePart::Summary { text } => assert_eq!(text, "hello"),
            other => panic!("expected MessagePart::Summary, got {other:?}"),
        }
    }
    /// The default `embed_batch` returns an empty list for empty input.
    #[tokio::test]
    async fn embed_batch_default_empty_returns_empty() {
        let provider = StubProvider {
            response: String::new(),
        };
        let result = provider.embed_batch(&[]).await.unwrap();
        assert!(result.is_empty());
    }
    /// The default `embed_batch` returns one `embed` result per input text.
    #[tokio::test]
    async fn embed_batch_default_calls_embed_sequentially() {
        let provider = StubProvider {
            response: String::new(),
        };
        let texts = ["hello", "world", "foo"];
        let result = provider.embed_batch(&texts).await.unwrap();
        assert_eq!(result.len(), 3);
        for vec in &result {
            assert_eq!(vec, &[0.1_f32, 0.2, 0.3]);
        }
    }
    /// `Both` maps to `"both"` and back.
    #[test]
    fn message_visibility_db_roundtrip_both() {
        assert_eq!(MessageVisibility::Both.as_db_str(), "both");
        assert_eq!(
            MessageVisibility::from_db_str("both"),
            MessageVisibility::Both
        );
    }
    /// `AgentOnly` maps to `"agent_only"` and back.
    #[test]
    fn message_visibility_db_roundtrip_agent_only() {
        assert_eq!(MessageVisibility::AgentOnly.as_db_str(), "agent_only");
        assert_eq!(
            MessageVisibility::from_db_str("agent_only"),
            MessageVisibility::AgentOnly
        );
    }
    /// `UserOnly` maps to `"user_only"` and back.
    #[test]
    fn message_visibility_db_roundtrip_user_only() {
        assert_eq!(MessageVisibility::UserOnly.as_db_str(), "user_only");
        assert_eq!(
            MessageVisibility::from_db_str("user_only"),
            MessageVisibility::UserOnly
        );
    }
    /// Unknown and empty strings fall back to `Both`.
    #[test]
    fn message_visibility_from_db_str_unknown_defaults_to_both() {
        assert_eq!(
            MessageVisibility::from_db_str("unknown_future_value"),
            MessageVisibility::Both
        );
        assert_eq!(MessageVisibility::from_db_str(""), MessageVisibility::Both);
    }
}