use crate::error::LlmError;
use async_trait::async_trait;
use futures_core::Stream;
use serde::{Deserialize, Serialize};
use std::pin::Pin;
use std::time::Duration;
/// A single message exchanged in a chat conversation.
///
/// Derives serde's `Serialize`/`Deserialize`; the two tool-related fields are
/// optional on input thanks to `#[serde(default)]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Role associated with this message.
pub role: Role,
/// Text content of the message.
pub content: String,
/// Tool calls attached to this message; deserializes to an empty list when the field is absent.
#[serde(default)]
pub tool_calls: Vec<ToolCall>,
/// Optional tool-call identifier; deserializes to `None` when the field is absent.
/// NOTE(review): presumably the `ToolCall::id` that a `Role::Tool` message answers — confirm.
#[serde(default)]
pub tool_call_id: Option<String>,
}
/// Role attached to a [`Message`].
///
/// Because of `rename_all = "lowercase"`, serde reads and writes the variants as
/// `"system"`, `"user"`, `"assistant"` and `"tool"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
/// Serialized as `"system"`.
System,
/// Serialized as `"user"`.
User,
/// Serialized as `"assistant"`.
Assistant,
/// Serialized as `"tool"`.
Tool,
}
/// A tool invocation carried by a [`Message`] or a [`ChatChunk`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Identifier of this call.
pub id: String,
/// Name of the tool being called.
pub name: String,
/// Arguments for the call, held as an arbitrary JSON value.
pub args: serde_json::Value,
}
/// Definition of a tool that can be listed in [`ChatRequest::tools`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDef {
/// Name of the tool.
pub name: String,
/// Free-text description of the tool.
pub description: String,
/// Arbitrary JSON value.
/// NOTE(review): presumably a JSON Schema describing the tool's arguments — confirm.
pub json_schema: serde_json::Value,
}
/// Input to [`LlmClient::complete`] and [`LlmClient::stream`].
///
/// Use [`ChatRequest::new`] to start from the messages alone and then adjust the
/// public fields as needed.
#[derive(Debug, Clone)]
pub struct ChatRequest {
/// Conversation messages, in the order supplied by the caller.
pub messages: Vec<Message>,
/// Tool definitions supplied with the request; may be empty.
pub tools: Vec<ToolDef>,
/// Optional sampling temperature; `None` leaves it unspecified.
pub temperature: Option<f32>,
/// Optional token limit; `None` leaves it unspecified.
/// NOTE(review): presumably caps completion tokens only — confirm against the clients.
pub max_tokens: Option<u32>,
/// Requested shape of the response.
pub response_format: ResponseFormat,
/// Stop strings; may be empty.
pub stop: Vec<String>,
/// Optional timeout; `None` leaves it unspecified.
/// NOTE(review): how a timeout is enforced is up to each client and is not visible here.
pub timeout: Option<Duration>,
}
impl ChatRequest {
    /// Creates a request that carries only `messages`.
    ///
    /// Everything else starts out neutral: `tools` and `stop` are empty,
    /// `temperature`, `max_tokens` and `timeout` are `None`, and
    /// `response_format` is [`ResponseFormat::Text`].
    pub fn new(messages: Vec<Message>) -> Self {
        // Optional knobs start unset; list-valued knobs start empty.
        let (temperature, max_tokens, timeout) = (None, None, None);
        let (tools, stop) = (Vec::new(), Vec::new());

        Self {
            messages,
            tools,
            temperature,
            max_tokens,
            response_format: ResponseFormat::Text,
            stop,
            timeout,
        }
    }
}
/// Shape of response requested in a [`ChatRequest`].
#[derive(Debug, Clone)]
pub enum ResponseFormat {
/// Plain text; this is what [`ChatRequest::new`] selects.
Text,
/// JSON response accompanied by a schema.
Json {
/// Schema supplied as an arbitrary JSON value.
schema: serde_json::Value,
},
/// Structured output accompanied by a schema.
/// NOTE(review): how this differs from `Json` is decided by the client implementations, not visible here.
StructuredOutput {
/// Schema supplied as an arbitrary JSON value.
schema: serde_json::Value,
},
}
/// Successful result of [`LlmClient::complete`].
#[derive(Debug, Clone)]
pub struct ChatResponse {
/// The message returned by the client.
pub message: Message,
/// Token counts reported for the call.
pub usage: Usage,
/// Finish reason reported for the call.
pub finish_reason: FinishReason,
}
/// Token counts attached to a [`ChatResponse`].
///
/// `Default` yields zero for both counters.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
/// Number of prompt tokens.
pub prompt_tokens: u32,
/// Number of completion tokens.
pub completion_tokens: u32,
}
/// Finish reason carried by [`ChatResponse`] and, optionally, by [`ChatChunk`].
///
/// No `rename_all` is applied, so serde reads and writes the variant names exactly
/// as spelled (`"Stop"`, `"ToolCalls"`, ...).
/// NOTE(review): the per-variant meanings below are inferred from the names — confirm against the clients.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FinishReason {
/// Presumably: generation stopped normally.
Stop,
/// Presumably: generation stopped to issue tool calls.
ToolCalls,
/// Presumably: a length limit was reached.
Length,
/// Presumably: output was withheld by a content filter.
ContentFilter,
/// Presumably: generation ended because of an error.
Error,
}
/// Feature flags and limits reported by [`LlmClient::capabilities`].
///
/// `Default` yields every flag `false` and `max_context_tokens == 0`. Use
/// [`Capabilities::builder`] to set only the fields that differ from that.
#[derive(Debug, Clone, Default)]
pub struct Capabilities {
    /// Tool-calling flag.
    pub tool_calling: bool,
    /// Streaming flag.
    pub streaming: bool,
    /// Structured-output flag.
    pub structured_output: bool,
    /// Embeddings flag.
    pub embeddings: bool,
    /// Context size limit, in tokens.
    pub max_context_tokens: u32,
    /// Vision flag.
    pub vision: bool,
}

impl Capabilities {
    /// Starts a [`CapabilitiesBuilder`] seeded with [`Capabilities::default`].
    pub fn builder() -> CapabilitiesBuilder {
        CapabilitiesBuilder(Capabilities::default())
    }
}

/// Chainable, by-value builder for [`Capabilities`].
///
/// Every setter consumes the builder and returns the updated one, so a call
/// whose result is dropped has no effect; `#[must_use]` makes the compiler
/// warn about that mistake.
#[must_use = "builder methods return the updated builder; finish with `.build()`"]
#[derive(Debug, Clone, Default)]
pub struct CapabilitiesBuilder(Capabilities);

impl CapabilitiesBuilder {
    /// Sets [`Capabilities::tool_calling`].
    pub fn tool_calling(self, v: bool) -> Self {
        Self(Capabilities { tool_calling: v, ..self.0 })
    }

    /// Sets [`Capabilities::streaming`].
    pub fn streaming(self, v: bool) -> Self {
        Self(Capabilities { streaming: v, ..self.0 })
    }

    /// Sets [`Capabilities::structured_output`].
    pub fn structured_output(self, v: bool) -> Self {
        Self(Capabilities { structured_output: v, ..self.0 })
    }

    /// Sets [`Capabilities::embeddings`].
    pub fn embeddings(self, v: bool) -> Self {
        Self(Capabilities { embeddings: v, ..self.0 })
    }

    /// Sets [`Capabilities::max_context_tokens`].
    pub fn max_context_tokens(self, v: u32) -> Self {
        Self(Capabilities { max_context_tokens: v, ..self.0 })
    }

    /// Sets [`Capabilities::vision`].
    pub fn vision(self, v: bool) -> Self {
        Self(Capabilities { vision: v, ..self.0 })
    }

    /// Finishes the builder and returns the accumulated [`Capabilities`].
    #[must_use]
    pub fn build(self) -> Capabilities {
        self.0
    }
}
/// One item yielded by a [`ChunkStream`].
#[derive(Debug, Clone)]
pub struct ChatChunk {
/// Text carried by this chunk.
/// NOTE(review): presumably an incremental fragment to append to earlier deltas — confirm.
pub delta: String,
/// Tool calls carried by this chunk; may be empty.
pub tool_calls: Vec<ToolCall>,
/// Finish reason, when this chunk carries one.
pub finish_reason: Option<FinishReason>,
}
/// Pinned, boxed, `Send + 'static` stream whose items are `Result<ChatChunk, LlmError>`;
/// this is the success type of [`LlmClient::stream`].
pub type ChunkStream = Pin<Box<dyn Stream<Item = Result<ChatChunk, LlmError>> + Send + 'static>>;
/// A single embedding vector, stored as `f32` components.
pub type Embedding = Vec<f32>;
/// Common interface implemented by LLM clients.
///
/// Implementors must be `Send + Sync`. The async methods are provided through the
/// `async_trait` macro, and the trait is meant to stay usable as `dyn LlmClient`
/// (the test module contains a compile-time check for that).
#[async_trait]
pub trait LlmClient: Send + Sync {
/// Returns the client's name.
fn name(&self) -> &str;
/// Returns the capabilities this client reports.
fn capabilities(&self) -> &Capabilities;
/// Runs `req` and resolves to the full [`ChatResponse`], or an [`LlmError`] on failure.
async fn complete(&self, req: ChatRequest) -> Result<ChatResponse, LlmError>;
/// Runs `req` and resolves to a [`ChunkStream`] of [`ChatChunk`]s, or an [`LlmError`]
/// if the stream cannot be produced; items of the stream can themselves be errors.
async fn stream(&self, req: ChatRequest) -> Result<ChunkStream, LlmError>;
/// Resolves to embeddings for `texts`, or an [`LlmError`] on failure.
/// NOTE(review): presumably one [`Embedding`] per input, in input order — confirm.
async fn embed(&self, texts: &[String]) -> Result<Vec<Embedding>, LlmError>;
}
#[cfg(test)]
mod tests {
    use super::*;

    // Never called: it only has to type-check, which fails if `LlmClient`
    // stops being usable as a trait object. Hence the `dead_code` allowance.
    #[allow(dead_code)]
    fn _assert_dyn_compatible(_: &dyn LlmClient) {}

    /// A freshly constructed request has no tools and asks for plain text.
    #[test]
    fn chat_request_default_has_no_tools() {
        let ChatRequest {
            tools,
            response_format,
            ..
        } = ChatRequest::new(Vec::new());
        assert!(tools.is_empty());
        assert!(matches!(response_format, ResponseFormat::Text));
    }

    /// A builder with no setters applied yields the same values as `Default`.
    #[test]
    fn capabilities_builder_default_matches_struct_default() {
        // `Capabilities` has no `PartialEq`, so compare a tuple of all fields.
        let snapshot = |caps: &Capabilities| {
            (
                caps.tool_calling,
                caps.streaming,
                caps.structured_output,
                caps.embeddings,
                caps.max_context_tokens,
                caps.vision,
            )
        };
        let from_builder = Capabilities::builder().build();
        let from_default = Capabilities::default();
        assert_eq!(snapshot(&from_builder), snapshot(&from_default));
    }

    #[test]
    fn capabilities_builder_sets_tool_calling() {
        let caps = Capabilities::builder().tool_calling(true).build();
        assert!(caps.tool_calling);
    }

    #[test]
    fn capabilities_builder_sets_streaming() {
        let caps = Capabilities::builder().streaming(true).build();
        assert!(caps.streaming);
    }

    #[test]
    fn capabilities_builder_sets_structured_output() {
        let caps = Capabilities::builder().structured_output(true).build();
        assert!(caps.structured_output);
    }

    #[test]
    fn capabilities_builder_sets_embeddings() {
        let caps = Capabilities::builder().embeddings(true).build();
        assert!(caps.embeddings);
    }

    #[test]
    fn capabilities_builder_sets_max_context_tokens() {
        let caps = Capabilities::builder().max_context_tokens(8000).build();
        assert_eq!(caps.max_context_tokens, 8000);
    }

    #[test]
    fn capabilities_builder_sets_vision() {
        let caps = Capabilities::builder().vision(true).build();
        assert!(caps.vision);
    }

    /// All setters can be chained, and each one lands in its own field.
    #[test]
    fn capabilities_builder_chains_all_setters() {
        let builder = Capabilities::builder()
            .tool_calling(true)
            .streaming(true)
            .structured_output(true)
            .embeddings(true)
            .max_context_tokens(32_000)
            .vision(false);
        let caps = builder.build();

        assert!(caps.tool_calling);
        assert!(caps.streaming);
        assert!(caps.structured_output);
        assert!(caps.embeddings);
        assert_eq!(caps.max_context_tokens, 32_000);
        assert!(!caps.vision);
    }
}