use crate::error::{ LlmError, Result };
use crate::tools::ToolCall;
use async_trait::async_trait;
use serde::{ Deserialize, Serialize };
use std::collections::HashMap;
use super::message::LlmMessage;
/// Provider-agnostic interface to a chat-completion LLM backend.
///
/// The `Send + Sync` supertrait bound means a client can be shared between threads.
#[async_trait]
pub trait LlmClient: Send + Sync {
    /// Sends `messages` (with optional tool definitions and request options) and
    /// returns the complete, non-streamed response.
    async fn chat_completion(
        &self,
        messages: Vec<LlmMessage>,
        tools: Option<Vec<ToolDefinition>>,
        options: Option<ChatOptions>,
    ) -> Result<LlmResponse>;

    /// Name of the model this client targets.
    fn model_name(&self) -> &str;

    /// Name of the provider backing this client.
    fn provider_name(&self) -> &str;

    /// Reports whether `chat_completion_stream` is usable. The default is `false`.
    fn supports_streaming(&self) -> bool {
        false
    }

    /// Streaming counterpart of `chat_completion`, yielding [`LlmStreamChunk`]s.
    ///
    /// # Errors
    ///
    /// The default implementation always returns `LlmError::InvalidRequest`.
    /// Implementors that return `true` from `supports_streaming` are expected to
    /// override this method.
    async fn chat_completion_stream(
        &self,
        _messages: Vec<LlmMessage>,
        _tools: Option<Vec<ToolDefinition>>,
        _options: Option<ChatOptions>,
    ) -> Result<Box<dyn futures::Stream<Item = Result<LlmStreamChunk>> + Send + Unpin + '_>> {
        let unsupported = LlmError::InvalidRequest {
            message: String::from("Streaming not supported by this client"),
        };
        Err(unsupported.into())
    }
}
/// A complete (non-streamed) result of [`LlmClient::chat_completion`].
#[derive(Clone, Debug)]
#[derive(Deserialize, Serialize)]
pub struct LlmResponse {
    /// Message returned by the model.
    pub message: LlmMessage,
    /// Token accounting for the request, if available.
    pub usage: Option<Usage>,
    /// Name of the model that produced this response.
    pub model: String,
    /// Why generation ended, if known.
    pub finish_reason: Option<FinishReason>,
    /// Free-form extra data (presumably provider-specific), if any.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
/// One item of the stream returned by [`LlmClient::chat_completion_stream`].
#[derive(Clone, Debug)]
#[derive(Deserialize, Serialize)]
pub struct LlmStreamChunk {
    /// Text carried by this chunk, if any.
    pub delta: Option<String>,
    /// Tool calls carried by this chunk, if any.
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Why generation ended, if this chunk reports it.
    pub finish_reason: Option<FinishReason>,
    /// Token accounting, if this chunk reports it.
    pub usage: Option<Usage>,
}
/// Token counts attached to a response or stream chunk.
#[derive(Clone, Debug)]
#[derive(Deserialize, Serialize)]
pub struct Usage {
    /// Tokens attributed to the prompt.
    pub prompt_tokens: u32,
    /// Tokens attributed to the generated completion.
    pub completion_tokens: u32,
    /// Total token count as supplied. It is presumably prompt + completion,
    /// but this type does not compute or enforce that.
    pub total_tokens: u32,
}
/// Reason a generation stopped.
///
/// Unit variants serialize as snake_case strings: `"stop"`, `"length"`,
/// `"tool_calls"`, `"content_filter"`.
///
/// NOTE(review): with serde's default externally-tagged representation,
/// `Other(s)` serializes as `{"other": s}`, not as the bare string `s`.
/// Likewise, an unrecognized bare string (e.g. `"end_turn"`) fails to
/// deserialize rather than becoming `Other`. Confirm callers map provider
/// strings explicitly before relying on serde here.
#[derive(Clone, Debug, PartialEq)]
#[derive(Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    /// Generation ended normally.
    Stop,
    /// Generation hit a length limit.
    Length,
    /// Generation ended in order to issue tool calls.
    ToolCalls,
    /// Output was stopped by a content filter.
    ContentFilter,
    /// Any other reason, carried verbatim.
    Other(String),
}
/// A tool offered to the model, wrapping a [`FunctionDefinition`].
#[derive(Clone, Debug)]
#[derive(Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool kind. Serialized under the JSON key `"type"` because `type` is a
    /// Rust keyword.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The callable function this tool exposes.
    pub function: FunctionDefinition,
}
/// Name, description, and argument schema of a callable function.
#[derive(Clone, Debug)]
#[derive(Deserialize, Serialize)]
pub struct FunctionDefinition {
    /// Function name the model uses to call it.
    pub name: String,
    /// Human-readable description shown to the model.
    pub description: String,
    /// Argument description as arbitrary JSON. It is presumably a JSON Schema
    /// object; this type does not validate it.
    pub parameters: serde_json::Value,
}
/// Per-request generation settings passed to [`LlmClient::chat_completion`].
///
/// Every field is optional. How a client treats `None` is up to the client
/// implementation. `ChatOptions::default()` is not all-`None`; it fills in
/// concrete values for most fields.
#[derive(Clone, Debug)]
#[derive(Deserialize, Serialize)]
pub struct ChatOptions {
    /// Upper bound on generated tokens.
    pub max_tokens: Option<u32>,
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// Nucleus-sampling probability mass.
    pub top_p: Option<f32>,
    /// Top-k sampling cutoff.
    pub top_k: Option<u32>,
    /// Stop sequences.
    pub stop: Option<Vec<String>>,
    /// Whether a streamed response is requested.
    pub stream: Option<bool>,
    /// Tool-selection policy.
    pub tool_choice: Option<ToolChoice>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
Auto,
None,
Required {
name: String,
},
}
impl Default for ChatOptions {
fn default() -> Self {
Self {
max_tokens: Some(8192),
temperature: Some(0.7),
top_p: Some(1.0),
top_k: None,
stop: None,
stream: Some(false),
tool_choice: Some(ToolChoice::Auto),
}
}
}