pub mod mock;
#[cfg(any(feature = "openai", feature = "ollama", feature = "all-providers"))]
pub mod siumai_adapter;
use async_trait::async_trait;
use futures::stream::BoxStream;
use crate::error::LlmError;
/// Common interface implemented by every LLM backend in this module.
///
/// Implementors must be `Send + Sync`. Chat is exposed as a stream of
/// [`ChatEvent`]s; embedding is exposed as a single async call.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Returns this provider's identifier.
    fn id(&self) -> &str;

    /// Returns the model name this provider reports as its default.
    fn default_model(&self) -> &str;

    /// Consumes `req` and returns a boxed `'static` stream whose items are
    /// either a [`ChatEvent`] or an [`LlmError`].
    fn chat_stream(&self, req: ChatRequest) -> BoxStream<'static, Result<ChatEvent, LlmError>>;

    /// Consumes `req` and asynchronously resolves to an [`EmbedResponse`],
    /// or to an [`LlmError`] on failure.
    async fn embed(&self, req: EmbedRequest) -> Result<EmbedResponse, LlmError>;

    /// Reports whether [`embed`](Self::embed) is supported.
    ///
    /// The provided implementation returns `true`; presumably providers
    /// without an embedding endpoint override it to return `false`.
    fn supports_embed(&self) -> bool {
        true
    }
}
/// One item produced by [`LlmProvider::chat_stream`].
///
/// Marked `#[non_exhaustive]`, so code in other crates must keep a wildcard
/// arm when matching on it.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum ChatEvent {
    /// A piece of text (presumably an incremental chunk of the reply).
    Delta {
        /// The text carried by this event.
        text: String,
    },
    /// Summary information (presumably emitted once the reply is complete).
    Finished {
        /// Token counts, when available.
        usage: Option<LlmUsage>,
        /// Model name, when available.
        model: Option<String>,
        /// Why generation ended, when available.
        finish_reason: Option<FinishReason>,
        /// Free-form JSON key/value data attached to the event.
        metadata: serde_json::Map<String, serde_json::Value>,
    },
    /// A tool invocation; the fields mirror those of [`EmittedToolCall`].
    ToolCall {
        /// Identifier of the call.
        id: String,
        /// Name of the tool being called.
        name: String,
        /// Call arguments as a string (the unit tests in this file use JSON text).
        arguments: String,
    },
}
/// Token counters attached to chat and embedding results.
///
/// `Default` yields all-zero counters. Nothing in this type enforces a
/// relationship between the three fields.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct LlmUsage {
    /// Number of prompt tokens.
    pub prompt_tokens: u32,
    /// Number of completion tokens.
    pub completion_tokens: u32,
    /// Total token count (presumably prompt + completion; not enforced here).
    pub total_tokens: u32,
}
/// Reason reported for the end of a generation.
///
/// Derives `PartialEq`/`Eq` so values can be compared with `==` and used in
/// `assert_eq!`, consistent with the other small value types in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FinishReason {
    /// Generation stopped (presumably a natural stop or a stop sequence).
    Stop,
    /// Generation ended because of a length limit (presumably `max_tokens`).
    Length,
    /// Generation ended with a tool call.
    ToolCall,
    /// Generation was cut off by a content filter.
    ContentFilter,
    /// Any other reason, carried as a string.
    Other(String),
}
/// Description of a tool that can be attached to a [`ChatRequest`].
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ToolDefinition {
    /// Tool name (e.g. `"get_weather"` in this file's tests).
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Parameter description as a JSON object (presumably a JSON Schema — confirm with callers).
    pub parameters: serde_json::Map<String, serde_json::Value>,
}
/// Tool-selection setting carried by [`ChatRequest::tool_choice`].
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum ToolChoice {
    /// Presumably: let the model decide whether to call a tool.
    Auto,
    /// Presumably: do not call any tool.
    None,
    /// A specific tool, identified by the contained string (presumably its name).
    Specific(String),
}
/// A tool call stored on a [`ChatMessage`] via its `tool_calls` field.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct EmittedToolCall {
    /// Identifier of the call (e.g. `"call_123"` in this file's tests).
    pub id: String,
    /// Name of the tool that was called.
    pub name: String,
    /// Call arguments as a string (the tests in this file use JSON text).
    pub arguments: String,
}
/// Input to [`LlmProvider::chat_stream`].
///
/// All fields are public; [`ChatRequest::new`] fills everything except
/// `model` and `messages` with empty or `None` values.
#[derive(Clone, Debug)]
pub struct ChatRequest {
    /// Name of the model to use.
    pub model: String,
    /// Conversation messages, in order.
    pub messages: Vec<ChatMessage>,
    /// Optional temperature setting.
    pub temperature: Option<f64>,
    /// Optional token limit.
    pub max_tokens: Option<u32>,
    /// Optional list of stop strings.
    pub stop: Option<Vec<String>>,
    /// Optional system prompt, kept separate from `messages`.
    pub system_prompt: Option<String>,
    /// Tool definitions attached to the request; may be empty.
    pub tools: Vec<ToolDefinition>,
    /// Optional tool-selection setting.
    pub tool_choice: Option<ToolChoice>,
    /// Additional free-form JSON options (presumably provider-specific).
    pub extra: serde_json::Map<String, serde_json::Value>,
}
impl ChatRequest {
pub fn new(model: impl Into<String>, messages: Vec<ChatMessage>) -> Self {
Self {
model: model.into(),
messages,
temperature: None,
max_tokens: None,
stop: None,
system_prompt: None,
tools: Vec::new(),
tool_choice: None,
extra: serde_json::Map::new(),
}
}
}
/// A single message in a [`ChatRequest`].
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ChatMessage {
    /// Role of the message author.
    pub role: ChatRole,
    /// Text content of the message.
    pub content: String,
    /// Tool calls attached to the message, if any (the tests in this file set
    /// it on an assistant message).
    pub tool_calls: Option<Vec<EmittedToolCall>>,
}
impl ChatMessage {
pub fn user(content: impl Into<String>) -> Self {
Self {
role: ChatRole::User,
content: content.into(),
tool_calls: None,
}
}
}
/// Author role of a [`ChatMessage`].
///
/// Marked `#[non_exhaustive]`, so code in other crates must keep a wildcard
/// arm when matching on it.
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
pub enum ChatRole {
    /// System role.
    System,
    /// User role; this is what [`ChatMessage::user`] sets.
    User,
    /// Assistant role.
    Assistant,
    /// Tool role (presumably a tool's result sent back to the model).
    Tool {
        /// Identifier of the tool call this message relates to.
        tool_call_id: String,
    },
}
/// Input to [`LlmProvider::embed`].
#[derive(Clone, Debug)]
pub struct EmbedRequest {
    /// Name of the model to use.
    pub model: String,
    /// Input strings to embed.
    pub inputs: Vec<String>,
    /// Additional free-form JSON options (presumably provider-specific).
    pub extra: serde_json::Map<String, serde_json::Value>,
}
impl EmbedRequest {
pub fn new(model: impl Into<String>, inputs: Vec<String>) -> Self {
Self {
model: model.into(),
inputs,
extra: serde_json::Map::new(),
}
}
}
/// Successful result of [`LlmProvider::embed`].
#[derive(Clone, Debug)]
pub struct EmbedResponse {
    /// Embedding vectors (presumably one per request input, in the same
    /// order — confirm with the provider implementations).
    pub embeddings: Vec<Vec<f32>>,
    /// Token counts, when available.
    pub usage: Option<LlmUsage>,
    /// Model name associated with the response.
    pub model: String,
    /// Free-form JSON key/value data attached to the response.
    pub metadata: serde_json::Map<String, serde_json::Value>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `ChatRequest::new` stores the model name and the supplied messages.
    #[test]
    fn chat_request_builder() {
        let messages = vec![ChatMessage::user("hello")];
        let req = ChatRequest::new("gpt-4o", messages);

        assert_eq!(req.model, "gpt-4o");
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.messages[0].role, ChatRole::User);
    }

    /// A request can carry tool definitions and a tool-choice setting.
    #[test]
    fn chat_request_accepts_tools() {
        let weather_tool = ToolDefinition {
            name: "get_weather".into(),
            description: "Get weather for a city".into(),
            parameters: serde_json::Map::new(),
        };

        // Start from the constructor's defaults and override only the tool fields.
        let req = ChatRequest {
            tools: vec![weather_tool],
            tool_choice: Some(ToolChoice::Auto),
            ..ChatRequest::new("gpt-4o", vec![ChatMessage::user("what's the weather?")])
        };

        assert_eq!(req.tools.len(), 1);
        assert_eq!(req.tools[0].name, "get_weather");
        assert_eq!(req.tool_choice, Some(ToolChoice::Auto));
    }

    /// A `Tool`-role message keeps the id of the call it relates to.
    #[test]
    fn tool_message_carries_tool_call_id() {
        let role = ChatRole::Tool {
            tool_call_id: "call_123".into(),
        };
        let msg = ChatMessage {
            role,
            content: "weather result".into(),
            tool_calls: None,
        };

        if let ChatRole::Tool { tool_call_id } = &msg.role {
            assert_eq!(tool_call_id, "call_123");
        } else {
            panic!("expected Tool role");
        }
    }

    /// An assistant message can carry the tool calls it emitted earlier.
    #[test]
    fn assistant_message_carries_prior_tool_calls() {
        let prior_call = EmittedToolCall {
            id: "call_123".into(),
            name: "get_weather".into(),
            arguments: r#"{"city":"London"}"#.into(),
        };
        let msg = ChatMessage {
            role: ChatRole::Assistant,
            content: "I'll check the weather".into(),
            tool_calls: Some(vec![prior_call]),
        };

        let calls = msg.tool_calls.as_ref().unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "call_123");
    }
}