use async_trait::async_trait;
use futures_util::stream::BoxStream;
use serde::{Deserialize, Serialize};
use super::capabilities::Capabilities;
/// Author of a [`ChatMessage`].
///
/// Uses serde's default enum representation, so each role serializes as its variant
/// name (e.g. `"User"`). Also derives `Eq` and `Hash` so roles can be compared
/// exhaustively and used as map/set keys.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum MessageRole {
    /// System-level instructions.
    System,
    /// A message from the user.
    User,
    /// A message produced by the model.
    Assistant,
    /// A tool result fed back to the model (presumably answering a prior [`ToolCall`]).
    Tool,
}
/// One message in a chat conversation: its author, its text, and any attached images.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Who authored this message.
    pub role: MessageRole,

    /// Text body of the message.
    pub content: String,

    /// Images attached to the message.
    ///
    /// A missing field deserializes as an empty list, and an empty list is omitted
    /// from serialized output.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub images: Vec<ChatImage>,
}
/// An inline image attached to a [`ChatMessage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatImage {
    /// Image payload as a base64 string (per the field name; the encoding is not validated here).
    pub data_base64: String,

    /// Media type of the image. Presumably a MIME type such as `image/png`; confirm
    /// against the provider adapters.
    pub media_type: String,
}
/// Definition of a tool that is offered to the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDef {
    /// Tool name (presumably what comes back in [`ToolCall::name`]).
    pub name: String,

    /// Human-readable description of what the tool does.
    pub description: String,

    /// Argument schema as raw JSON. Presumably a JSON Schema object; it is not validated here.
    pub schema: serde_json::Value,
}
/// Input to [`LlmProvider::chat`] and [`LlmProvider::chat_stream`].
///
/// Derives [`Default`]: no messages, no tools, and every option `None`. Callers can
/// therefore set only the fields they care about:
///
/// ```ignore
/// let req = ChatRequest { messages, ..Default::default() };
/// ```
#[derive(Clone, Debug, Default)]
pub struct ChatRequest {
    /// Conversation history.
    pub messages: Vec<ChatMessage>,
    /// Tools offered to the model; empty when no tools are offered.
    pub tools: Vec<ToolDef>,
    /// Provider-specific tool-choice directive. The accepted values are not defined here.
    pub tool_choice: Option<String>,
    /// Optional cap on response length in tokens. `None` presumably defers to the provider.
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature. `None` presumably defers to the provider.
    pub temperature: Option<f32>,
}
/// A tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of this call. Presumably used to match the later `MessageRole::Tool` result.
    pub id: String,

    /// Name of the tool to invoke (presumably one of the offered [`ToolDef`]s).
    pub name: String,

    /// Arguments for the call, as a JSON value.
    pub arguments: serde_json::Value,
}
/// Complete result of a non-streaming [`LlmProvider::chat`] call.
#[derive(Debug, Clone)]
pub struct ChatResponse {
    /// Text produced by the model. It may be empty when only tool calls are returned
    /// (TODO confirm per provider).
    pub content: String,

    /// Tool invocations requested by the model; empty when there are none.
    pub tool_calls: Vec<ToolCall>,

    /// Why generation stopped.
    pub finish_reason: FinishReason,
}
/// Why the model stopped generating.
///
/// Every payload is a `String`, so the type derives full `Eq` and `Hash` as well as
/// `PartialEq`. Reasons can therefore be used as map/set keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum FinishReason {
    /// Output ended normally.
    Stop,
    /// The model stopped in order to request tool calls.
    ToolCalls,
    /// Output hit a length limit.
    Length,
    /// Output was stopped by a content filter.
    ContentFilter,
    /// Any reason not covered above. Presumably holds the provider's raw reason string.
    Other(String),
}
/// Incremental event yielded by a streaming chat response.
///
/// Tool-call events share an `id`. Presumably a provider emits `ToolCallStart`, then zero
/// or more `ToolCallArgs`, then `ToolCallEnd` for the same id. That ordering is up to each
/// implementation and is not enforced by this type.
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// A fragment of the model's text output.
    TextChunk(String),
    /// A tool call named `name`, identified by `id`, has begun.
    ToolCallStart {
        id: String,
        name: String,
    },
    /// A raw text fragment of the arguments for tool call `id`.
    ToolCallArgs {
        id: String,
        args_delta: String,
    },
    /// Tool call `id` has finished; `args` holds its arguments as a JSON value.
    ToolCallEnd {
        id: String,
        args: serde_json::Value,
    },
    /// End of the response, carrying the reason generation stopped.
    Done {
        finish_reason: FinishReason,
    },
}
/// Stream returned by [`LlmProvider::chat_stream`].
///
/// `BoxStream` is futures-util's pinned, boxed, `Send` trait-object stream. The `'static`
/// lifetime means the stream borrows nothing from the caller. Each item is either a
/// [`StreamEvent`] or an [`LlmError`].
pub type ChatStream =
    BoxStream<'static, Result<StreamEvent, LlmError>>;
#[derive(Debug, thiserror::Error)]
pub enum LlmError {
#[error("provider HTTP error: {0}")]
Transport(String),
#[error("provider returned status {status}: {body}")]
Status { status: u16, body: String },
#[error("provider response could not be parsed: {0}")]
Parse(String),
#[error("capability '{0}' not supported by this provider")]
UnsupportedCapability(&'static str),
#[error(transparent)]
Other(#[from] anyhow::Error),
}
/// A chat-completion backend.
///
/// Implementors must be `Send + Sync`. The async methods are declared through
/// `#[async_trait]`, which rewrites them to return boxed futures.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Capability set reported by this provider.
    fn capabilities(&self) -> Capabilities;

    /// Static name identifying the provider implementation.
    fn provider_name(&self) -> &'static str;

    /// Identifier of the model this instance targets.
    fn model(&self) -> &str;

    /// Sends `req` and resolves to the complete response.
    async fn chat(&self, req: ChatRequest) -> Result<ChatResponse, LlmError>;

    /// Sends `req` and resolves to a stream of incremental [`StreamEvent`]s.
    ///
    /// An outer `Err` presumably means the stream could not be started. Failures after
    /// that arrive as `Err` items of the returned stream.
    async fn chat_stream(&self, req: ChatRequest) -> Result<ChatStream, LlmError>;
}