use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::error::{Result, ZeptoError};
use crate::providers::structured::OutputFormat;
use crate::session::Message;
/// Events delivered over the channel returned by `LLMProvider::chat_stream`.
#[derive(Debug)]
pub enum StreamEvent {
    /// An incremental chunk of generated text.
    Delta(String),
    /// Tool invocations requested by the model.
    ToolCalls(Vec<LLMToolCall>),
    /// Terminal event: the complete accumulated content plus token usage,
    /// when the provider reports it.
    Done {
        content: String,
        usage: Option<Usage>,
    },
    /// Terminal event: the stream failed with the given error.
    Error(ZeptoError),
}
/// Description of a tool the model may call, in the shape providers expect
/// (name, free-text description, and a JSON-Schema parameter object).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool identifier the model uses when invoking it.
    pub name: String,
    /// Natural-language description shown to the model.
    pub description: String,
    /// Parameter schema as raw JSON (typically a JSON-Schema object).
    pub parameters: serde_json::Value,
}
impl ToolDefinition {
    /// Builds a tool definition from borrowed name/description strings and a
    /// JSON parameter schema.
    pub fn new(name: &str, description: &str, parameters: serde_json::Value) -> Self {
        let name = name.to_owned();
        let description = description.to_owned();
        Self {
            name,
            description,
            parameters,
        }
    }
}
/// Common interface implemented by every LLM backend.
///
/// Only `chat`, `default_model`, and `name` are required; `chat_stream` and
/// `embed` have default implementations that providers may override.
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Sends a conversation (plus the tools the model may call) and waits
    /// for the complete response.
    ///
    /// `model: None` means "use `default_model()`" by convention of the
    /// implementors — TODO confirm each backend honors this.
    async fn chat(
        &self,
        messages: Vec<Message>,
        tools: Vec<ToolDefinition>,
        model: Option<&str>,
        options: ChatOptions,
    ) -> Result<LLMResponse>;

    /// Model identifier used when the caller does not specify one.
    fn default_model(&self) -> &str;

    /// Short provider name (e.g. for logging/diagnostics).
    fn name(&self) -> &str;

    /// Streaming variant of `chat`.
    ///
    /// The default implementation is a non-streaming fallback: it performs a
    /// regular `chat` call and delivers the entire result as one
    /// `StreamEvent::Done` over the returned channel. Providers with native
    /// streaming should override this to emit incremental `Delta` events.
    async fn chat_stream(
        &self,
        messages: Vec<Message>,
        tools: Vec<ToolDefinition>,
        model: Option<&str>,
        options: ChatOptions,
    ) -> Result<tokio::sync::mpsc::Receiver<StreamEvent>> {
        let response = self.chat(messages, tools, model, options).await?;
        // Capacity 1 suffices: exactly one event is ever sent here.
        let (tx, rx) = tokio::sync::mpsc::channel(1);
        // A send error only means the receiver was already dropped,
        // in which case nobody cares about the event — safe to ignore.
        let _ = tx
            .send(StreamEvent::Done {
                content: response.content,
                usage: response.usage,
            })
            .await;
        Ok(rx)
    }

    /// Computes embeddings for `texts`.
    ///
    /// The default implementation reports "not supported" so providers
    /// without an embedding endpoint need no extra code.
    async fn embed(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
        Err(ZeptoError::Provider(
            "Embedding not supported by this provider".into(),
        ))
    }
}
/// Per-request generation parameters, assembled with the builder methods
/// below. `None` fields fall back to the provider's own defaults.
#[derive(Debug, Clone, Default)]
pub struct ChatOptions {
    /// Hard cap on the number of generated tokens.
    pub max_tokens: Option<u32>,
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// Nucleus-sampling (top-p) threshold.
    pub top_p: Option<f32>,
    /// Sequences that terminate generation when produced.
    pub stop: Option<Vec<String>>,
    /// Requested response format; see `OutputFormat`.
    pub output_format: OutputFormat,
}
impl ChatOptions {
    /// Creates an option set with every field unset (provider defaults).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the maximum number of tokens to generate.
    pub fn with_max_tokens(self, max_tokens: u32) -> Self {
        Self {
            max_tokens: Some(max_tokens),
            ..self
        }
    }

    /// Sets the sampling temperature.
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }

    /// Sets the nucleus-sampling (top-p) threshold.
    pub fn with_top_p(self, top_p: f32) -> Self {
        Self {
            top_p: Some(top_p),
            ..self
        }
    }

    /// Sets the stop sequences.
    pub fn with_stop(self, stop: Vec<String>) -> Self {
        Self {
            stop: Some(stop),
            ..self
        }
    }

    /// Sets the requested output format.
    pub fn with_output_format(self, output_format: OutputFormat) -> Self {
        Self {
            output_format,
            ..self
        }
    }
}
/// Complete (non-streaming) model response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    /// Generated text (may be empty when the model only calls tools).
    pub content: String,
    /// Tool invocations requested by the model; empty when none.
    pub tool_calls: Vec<LLMToolCall>,
    /// Token accounting, when the provider reports it.
    pub usage: Option<Usage>,
}
impl LLMResponse {
    /// Builds a plain-text response with no tool calls and no usage data.
    pub fn text(content: &str) -> Self {
        Self::with_tools(content, Vec::new())
    }

    /// Builds a response carrying text plus the given tool calls.
    pub fn with_tools(content: &str, tool_calls: Vec<LLMToolCall>) -> Self {
        Self {
            content: content.to_owned(),
            tool_calls,
            usage: None,
        }
    }

    /// Returns `true` when the model requested at least one tool call.
    pub fn has_tool_calls(&self) -> bool {
        self.tool_calls.first().is_some()
    }

    /// Attaches token-usage information (builder style).
    pub fn with_usage(self, usage: Usage) -> Self {
        Self {
            usage: Some(usage),
            ..self
        }
    }
}
/// A single tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMToolCall {
    /// Provider-assigned call identifier (used to correlate the tool result).
    pub id: String,
    /// Name of the tool being invoked.
    pub name: String,
    /// Raw JSON-encoded arguments, exactly as produced by the model.
    pub arguments: String,
}
impl LLMToolCall {
    /// Builds a tool call from borrowed id/name/argument strings.
    pub fn new(id: &str, name: &str, arguments: &str) -> Self {
        Self {
            id: id.to_owned(),
            name: name.to_owned(),
            arguments: arguments.to_owned(),
        }
    }

    /// Deserializes the raw JSON `arguments` into a typed value.
    pub fn parse_arguments<T: serde::de::DeserializeOwned>(&self) -> serde_json::Result<T> {
        serde_json::from_str(self.arguments.as_str())
    }
}
/// Token accounting for a single request/response round-trip.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the prompt (input side).
    pub prompt_tokens: u32,
    /// Tokens generated by the model (output side).
    pub completion_tokens: u32,
    /// Sum of prompt and completion tokens.
    pub total_tokens: u32,
}
impl Usage {
    /// Creates a usage record; `total_tokens` is derived as the sum of the
    /// two parts.
    ///
    /// Uses `saturating_add` so a malformed provider response with absurd
    /// token counts cannot trigger a debug-build overflow panic; the total
    /// clamps at `u32::MAX` instead.
    pub fn new(prompt_tokens: u32, completion_tokens: u32) -> Self {
        Self {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens.saturating_add(completion_tokens),
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the provider abstractions: response/option builders,
    //! tool definitions, serde round-trips, stream events, and the trait's
    //! default `chat_stream`/`embed` implementations.
    use super::*;

    // Direct struct-literal construction of a response.
    #[test]
    fn test_llm_response_creation() {
        let response = LLMResponse {
            content: "Hello".to_string(),
            tool_calls: vec![],
            usage: None,
        };
        assert_eq!(response.content, "Hello");
        assert!(!response.has_tool_calls());
    }

    // `text` constructor: no tool calls, no usage.
    #[test]
    fn test_llm_response_text() {
        let response = LLMResponse::text("Hello, world!");
        assert_eq!(response.content, "Hello, world!");
        assert!(!response.has_tool_calls());
        assert!(response.usage.is_none());
    }

    // `with_tools` constructor carries the given calls through.
    #[test]
    fn test_llm_response_with_tools() {
        let tool_call = LLMToolCall::new("call_1", "search", r#"{"query": "rust"}"#);
        let response = LLMResponse::with_tools("Searching...", vec![tool_call]);
        assert_eq!(response.content, "Searching...");
        assert!(response.has_tool_calls());
        assert_eq!(response.tool_calls.len(), 1);
        assert_eq!(response.tool_calls[0].name, "search");
    }

    // `with_usage` attaches usage; `Usage::new` derives the total.
    #[test]
    fn test_llm_response_with_usage() {
        let usage = Usage::new(100, 50);
        let response = LLMResponse::text("Hello").with_usage(usage);
        assert!(response.usage.is_some());
        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    // Builder methods chain and set only the requested fields.
    #[test]
    fn test_chat_options_builder() {
        let options = ChatOptions::new()
            .with_max_tokens(1000)
            .with_temperature(0.7);
        assert_eq!(options.max_tokens, Some(1000));
        assert_eq!(options.temperature, Some(0.7));
    }

    // Full builder chain covering every optional sampling field.
    #[test]
    fn test_chat_options_all_fields() {
        let options = ChatOptions::new()
            .with_max_tokens(2000)
            .with_temperature(0.5)
            .with_top_p(0.9)
            .with_stop(vec!["END".to_string(), "STOP".to_string()]);
        assert_eq!(options.max_tokens, Some(2000));
        assert_eq!(options.temperature, Some(0.5));
        assert_eq!(options.top_p, Some(0.9));
        assert!(options.stop.is_some());
        let stop = options.stop.unwrap();
        assert_eq!(stop.len(), 2);
        assert_eq!(stop[0], "END");
    }

    // Default options leave every optional field unset.
    #[test]
    fn test_chat_options_default() {
        let options = ChatOptions::default();
        assert!(options.max_tokens.is_none());
        assert!(options.temperature.is_none());
        assert!(options.top_p.is_none());
        assert!(options.stop.is_none());
    }

    // Direct struct-literal construction of a tool definition.
    #[test]
    fn test_tool_definition() {
        let tool = ToolDefinition {
            name: "search".to_string(),
            description: "Search the web".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        };
        assert_eq!(tool.name, "search");
    }

    // `ToolDefinition::new` copies the strings and stores the schema as-is.
    #[test]
    fn test_tool_definition_new() {
        let tool = ToolDefinition::new(
            "web_search",
            "Search the web for information",
            serde_json::json!({
                "type": "object",
                "properties": {
                    "query": { "type": "string" }
                },
                "required": ["query"]
            }),
        );
        assert_eq!(tool.name, "web_search");
        assert_eq!(tool.description, "Search the web for information");
        assert!(tool.parameters.is_object());
    }

    // `LLMToolCall::new` keeps the raw argument string untouched.
    #[test]
    fn test_llm_tool_call_new() {
        let call = LLMToolCall::new("call_123", "web_search", r#"{"query": "rust"}"#);
        assert_eq!(call.id, "call_123");
        assert_eq!(call.name, "web_search");
        assert_eq!(call.arguments, r#"{"query": "rust"}"#);
    }

    // `parse_arguments` deserializes into a caller-defined type.
    #[test]
    fn test_llm_tool_call_parse_arguments() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct SearchArgs {
            query: String,
        }
        let call = LLMToolCall::new("call_1", "search", r#"{"query": "rust programming"}"#);
        let args: SearchArgs = call.parse_arguments().unwrap();
        assert_eq!(args.query, "rust programming");
    }

    // `Usage::new` sums prompt + completion into `total_tokens`.
    #[test]
    fn test_usage_new() {
        let usage = Usage::new(100, 50);
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    // serde round-trip preserves an LLMResponse.
    #[test]
    fn test_llm_response_serialization() {
        let response = LLMResponse::text("Hello");
        let json = serde_json::to_string(&response).unwrap();
        let parsed: LLMResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.content, "Hello");
        assert!(!parsed.has_tool_calls());
    }

    // serde round-trip preserves a ToolDefinition.
    #[test]
    fn test_tool_definition_serialization() {
        let tool = ToolDefinition::new(
            "search",
            "Search the web",
            serde_json::json!({"type": "object"}),
        );
        let json = serde_json::to_string(&tool).unwrap();
        let parsed: ToolDefinition = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.name, "search");
        assert_eq!(parsed.description, "Search the web");
    }

    // StreamEvent::Done carries both content and usage.
    #[tokio::test]
    async fn test_stream_event_done_carries_content() {
        let event = StreamEvent::Done {
            content: "hello".to_string(),
            usage: Some(Usage::new(10, 5)),
        };
        match event {
            StreamEvent::Done { content, usage } => {
                assert_eq!(content, "hello");
                assert!(usage.is_some());
            }
            _ => panic!("Expected Done event"),
        }
    }

    // StreamEvent::Delta carries a text chunk.
    #[tokio::test]
    async fn test_stream_event_delta() {
        let event = StreamEvent::Delta("chunk".to_string());
        match event {
            StreamEvent::Delta(text) => assert_eq!(text, "chunk"),
            _ => panic!("Expected Delta event"),
        }
    }

    // StreamEvent::ToolCalls carries the requested invocations.
    #[tokio::test]
    async fn test_stream_event_tool_calls() {
        let tc = LLMToolCall::new("call_1", "search", r#"{"q":"rust"}"#);
        let event = StreamEvent::ToolCalls(vec![tc]);
        match event {
            StreamEvent::ToolCalls(calls) => {
                assert_eq!(calls.len(), 1);
                assert_eq!(calls[0].name, "search");
            }
            _ => panic!("Expected ToolCalls event"),
        }
    }

    // StreamEvent::Error wraps a provider error.
    #[tokio::test]
    async fn test_stream_event_error() {
        let event = StreamEvent::Error(ZeptoError::Provider("fail".into()));
        assert!(matches!(event, StreamEvent::Error(_)));
    }

    // The default `chat_stream` falls back to `chat` and emits exactly one
    // Done event containing the full response.
    #[tokio::test]
    async fn test_chat_stream_default_impl() {
        struct FakeProvider;
        #[async_trait]
        impl LLMProvider for FakeProvider {
            async fn chat(
                &self,
                _messages: Vec<Message>,
                _tools: Vec<ToolDefinition>,
                _model: Option<&str>,
                _options: ChatOptions,
            ) -> Result<LLMResponse> {
                Ok(LLMResponse::text("hello from fake"))
            }
            fn default_model(&self) -> &str {
                "fake"
            }
            fn name(&self) -> &str {
                "fake"
            }
        }
        let provider = FakeProvider;
        let mut rx = provider
            .chat_stream(vec![], vec![], None, ChatOptions::default())
            .await
            .unwrap();
        let event = rx.recv().await.unwrap();
        match event {
            StreamEvent::Done { content, .. } => {
                assert_eq!(content, "hello from fake");
            }
            _ => panic!("Expected Done event from default chat_stream"),
        }
    }

    // A provider implementing only the required methods still exposes
    // `embed`, which errors by default.
    #[tokio::test]
    async fn test_embed_method_exists_on_trait() {
        struct MinimalProvider;
        #[async_trait]
        impl LLMProvider for MinimalProvider {
            async fn chat(
                &self,
                _messages: Vec<Message>,
                _tools: Vec<ToolDefinition>,
                _model: Option<&str>,
                _options: ChatOptions,
            ) -> Result<LLMResponse> {
                Ok(LLMResponse::text("ok"))
            }
            fn default_model(&self) -> &str {
                "minimal"
            }
            fn name(&self) -> &str {
                "minimal"
            }
        }
        let provider = MinimalProvider;
        let result = provider.embed(&["hello".to_string()]).await;
        assert!(result.is_err());
    }

    // The default `embed` error message identifies the unsupported feature.
    #[tokio::test]
    async fn test_embed_default_returns_error() {
        struct DefaultProvider;
        #[async_trait]
        impl LLMProvider for DefaultProvider {
            async fn chat(
                &self,
                _messages: Vec<Message>,
                _tools: Vec<ToolDefinition>,
                _model: Option<&str>,
                _options: ChatOptions,
            ) -> Result<LLMResponse> {
                Ok(LLMResponse::text("ok"))
            }
            fn default_model(&self) -> &str {
                "default"
            }
            fn name(&self) -> &str {
                "default"
            }
        }
        let provider = DefaultProvider;
        let result = provider.embed(&["text".to_string()]).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string()
                .contains("Embedding not supported by this provider"),
            "Expected 'Embedding not supported' error, got: {}",
            err
        );
    }

    // The default `embed` errors even for an empty input slice.
    #[tokio::test]
    async fn test_embed_default_empty_input_returns_error() {
        struct DefaultProvider;
        #[async_trait]
        impl LLMProvider for DefaultProvider {
            async fn chat(
                &self,
                _messages: Vec<Message>,
                _tools: Vec<ToolDefinition>,
                _model: Option<&str>,
                _options: ChatOptions,
            ) -> Result<LLMResponse> {
                Ok(LLMResponse::text("ok"))
            }
            fn default_model(&self) -> &str {
                "default"
            }
            fn name(&self) -> &str {
                "default"
            }
        }
        let provider = DefaultProvider;
        let result = provider.embed(&[]).await;
        assert!(result.is_err());
    }
}