use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use std::pin::Pin;
use super::schema::ToolSchema;
use super::{ModelUri, ProviderConfig, ProviderType, ToolFormat};
/// Author of a [`Message`] within a conversation.
///
/// Serialized as the lowercase variant name: `"system"`, `"user"`,
/// `"assistant"`, `"tool"`, `"function"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
    Tool,
    Function,
}
/// One piece of a [`Message`]'s content.
///
/// Serialized with an internal `"type"` tag holding the snake_case variant
/// name (`"text"`, `"image"`, `"audio"`, `"tool_use"`, `"tool_result"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// Plain text.
    Text { text: String },
    /// An image; the [`ImageSource`] fields sit next to the `"type"` tag.
    Image {
        #[serde(flatten)]
        source: ImageSource,
    },
    /// Audio; the [`AudioSource`] fields sit next to the `"type"` tag.
    Audio {
        #[serde(flatten)]
        source: AudioSource,
    },
    /// A tool invocation: call `id`, tool `name`, and arbitrary JSON `input`.
    ToolUse {
        id: String,
        name: String,
        input: JsonValue,
    },
    /// The output of a tool, linked to its invocation by `tool_use_id`.
    ToolResult {
        tool_use_id: String,
        content: String,
        /// Defaults to `false` when the field is absent, so payloads that
        /// omit the flag for successful results still deserialize.
        #[serde(default)]
        is_error: bool,
    },
}

/// Where an image's bytes come from.
///
/// Untagged: an object with `media_type` + `data` is `Base64`; an object with
/// `url` is `Url`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ImageSource {
    Base64 { media_type: String, data: String },
    Url { url: String },
}

/// Where audio bytes come from; same untagged shape as [`ImageSource`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum AudioSource {
    Base64 { media_type: String, data: String },
    Url { url: String },
}
/// A single conversation turn: its author and an ordered list of content parts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who authored the message.
    pub role: Role,
    /// Ordered content parts; `text_content()` concatenates the text ones.
    pub content: Vec<ContentPart>,
    /// Optional participant name; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Message-level tool-call id (presumably for tool-role replies — confirm
    /// against provider transforms); omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
impl Message {
    /// Builds a message from `role` whose content is a single text part.
    pub fn text(role: Role, content: impl Into<String>) -> Self {
        Self {
            role,
            content: vec![ContentPart::Text {
                text: content.into(),
            }],
            name: None,
            tool_call_id: None,
        }
    }

    /// Builds a [`Role::System`] text message.
    pub fn system(content: impl Into<String>) -> Self {
        Self::text(Role::System, content)
    }

    /// Builds a [`Role::User`] text message.
    pub fn user(content: impl Into<String>) -> Self {
        Self::text(Role::User, content)
    }

    /// Builds a [`Role::Assistant`] text message.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::text(Role::Assistant, content)
    }

    /// Builds a user message containing `text` followed by an image referenced
    /// by `image_url`.
    pub fn user_with_image(text: impl Into<String>, image_url: impl Into<String>) -> Self {
        Self {
            role: Role::User,
            content: vec![
                ContentPart::Text { text: text.into() },
                ContentPart::Image {
                    source: ImageSource::Url {
                        url: image_url.into(),
                    },
                },
            ],
            name: None,
            tool_call_id: None,
        }
    }

    /// Builds a [`Role::Tool`] message carrying the result of a tool call.
    ///
    /// The id is stored both on the [`ContentPart::ToolResult`] part and on the
    /// message-level `tool_call_id`. Consumers that read either location
    /// therefore see the same id.
    pub fn tool_result(
        tool_use_id: impl Into<String>,
        content: impl Into<String>,
        is_error: bool,
    ) -> Self {
        let tool_use_id = tool_use_id.into();
        let parts = vec![ContentPart::ToolResult {
            tool_use_id: tool_use_id.clone(),
            content: content.into(),
            is_error,
        }];
        Self {
            role: Role::Tool,
            content: parts,
            name: None,
            tool_call_id: Some(tool_use_id),
        }
    }

    /// Concatenates every [`ContentPart::Text`] part in order, with no
    /// separator. Non-text parts are skipped; the result is empty if there are
    /// no text parts.
    pub fn text_content(&self) -> String {
        self.content
            .iter()
            .filter_map(|part| match part {
                ContentPart::Text { text } => Some(text.as_str()),
                _ => None,
            })
            .collect()
    }
}
/// A provider-agnostic chat completion request.
///
/// Only `model` and `messages` are required. Every other field is optional and
/// is left out of the serialized form when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatRequest {
    /// Target model.
    pub model: ModelUri,
    /// Conversation history, in order.
    pub messages: Vec<Message>,
    // Generation / sampling parameters below; their exact semantics are
    // defined by each provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// Stop sequences.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// Tools offered to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolSchema>>,
    /// Tool-selection policy.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Requested output format.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,
    /// Caller-supplied user identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Additional provider-specific parameters. Presumably merged into the
    /// provider payload by `transform_request` — confirm per provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub extra: Option<HashMap<String, JsonValue>>,
}
impl ChatRequest {
    /// Creates a request for `model` with `messages` and every optional
    /// parameter unset (`None`).
    #[must_use]
    pub fn new(model: ModelUri, messages: Vec<Message>) -> Self {
        Self {
            model,
            messages,
            max_tokens: None,
            temperature: None,
            top_p: None,
            frequency_penalty: None,
            presence_penalty: None,
            stop: None,
            tools: None,
            tool_choice: None,
            response_format: None,
            seed: None,
            user: None,
            extra: None,
        }
    }

    /// Returns the request with `max_tokens` set to `tokens`.
    ///
    /// Consumes `self`; the returned value must be used or the setting is lost.
    #[must_use]
    pub fn with_max_tokens(mut self, tokens: u32) -> Self {
        self.max_tokens = Some(tokens);
        self
    }

    /// Returns the request with `temperature` set to `temp`.
    #[must_use]
    pub fn with_temperature(mut self, temp: f32) -> Self {
        self.temperature = Some(temp);
        self
    }

    /// Returns the request with `tools` set, replacing any previous list.
    #[must_use]
    pub fn with_tools(mut self, tools: Vec<ToolSchema>) -> Self {
        self.tools = Some(tools);
        self
    }

    /// Returns the request with `tool_choice` set to `choice`.
    #[must_use]
    pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
        self.tool_choice = Some(choice);
        self
    }
}
/// Tool-selection policy for a [`ChatRequest`].
///
/// JSON form: `Auto`, `None` and `Required` serialize as the strings `"auto"`,
/// `"none"` and `"required"`. `Tool { name }` serializes as
/// `{"name": "<name>"}`. Deserialization accepts the same shapes.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(from = "ToolChoiceRepr", into = "ToolChoiceRepr")]
pub enum ToolChoice {
    /// Serialized as `"auto"`.
    Auto,
    /// Serialized as `"none"`.
    None,
    /// A specific tool, serialized as `{"name": "<name>"}`.
    Tool { name: String },
    /// Serialized as `"required"`.
    Required,
}

/// Serde representation of [`ToolChoice`].
///
/// Putting `#[serde(untagged)]` directly on `ToolChoice` would serialize every
/// unit variant as `null`. `Auto`, `None` and `Required` would then be
/// indistinguishable on the wire, and all would deserialize as `Auto`. This
/// wrapper keeps the `{"name": ...}` object for `Tool` while giving the modes
/// distinct string names.
#[derive(Serialize, Deserialize)]
#[serde(untagged)]
enum ToolChoiceRepr {
    Mode(ToolChoiceMode),
    Tool { name: String },
}

/// The string-valued tool-choice modes.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum ToolChoiceMode {
    Auto,
    None,
    Required,
}

impl From<ToolChoice> for ToolChoiceRepr {
    fn from(choice: ToolChoice) -> Self {
        match choice {
            ToolChoice::Auto => Self::Mode(ToolChoiceMode::Auto),
            ToolChoice::None => Self::Mode(ToolChoiceMode::None),
            ToolChoice::Required => Self::Mode(ToolChoiceMode::Required),
            ToolChoice::Tool { name } => Self::Tool { name },
        }
    }
}

impl From<ToolChoiceRepr> for ToolChoice {
    fn from(repr: ToolChoiceRepr) -> Self {
        match repr {
            ToolChoiceRepr::Mode(ToolChoiceMode::Auto) => Self::Auto,
            ToolChoiceRepr::Mode(ToolChoiceMode::None) => Self::None,
            ToolChoiceRepr::Mode(ToolChoiceMode::Required) => Self::Required,
            ToolChoiceRepr::Tool { name } => Self::Tool { name },
        }
    }
}
/// Requested output format for a chat response.
///
/// Internally tagged on `"type"`: `{"type":"text"}`, `{"type":"json_object"}`,
/// or `{"type":"json_schema","schema":<JSON>}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseFormat {
    Text,
    JsonObject,
    JsonSchema { schema: JsonValue },
}
/// A provider-agnostic chat completion response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatResponse {
    /// Response id as reported by the provider.
    pub id: String,
    /// Model name as reported by the provider.
    pub model: String,
    /// The generated message.
    pub message: Message,
    /// Token accounting for this call.
    pub usage: TokenUsage,
    /// Why generation stopped.
    pub finish_reason: FinishReason,
    /// Tool calls requested in this response; omitted from serialized output
    /// when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Free-form extra data; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, JsonValue>>,
}
impl ChatResponse {
    /// The reply's text: every text part of `message`, concatenated in order.
    pub fn text(&self) -> String {
        let reply = &self.message;
        reply.text_content()
    }

    /// `true` only when `tool_calls` is present *and* non-empty.
    pub fn has_tool_calls(&self) -> bool {
        matches!(self.tool_calls.as_deref(), Some(calls) if !calls.is_empty())
    }
}
/// Token accounting for a request.
///
/// `Default` yields all counters at `0` and `cached_tokens: None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
}
/// Why the model stopped generating.
///
/// Serialized as snake_case strings: `"stop"`, `"length"`, `"tool_calls"`,
/// `"content_filter"`, `"error"`, `"unknown"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    Stop,
    Length,
    ToolCalls,
    ContentFilter,
    Error,
    /// Also the deserialization fallback: any string not listed above maps
    /// here instead of failing deserialization of the enclosing response.
    /// This matters because providers use their own reason names.
    #[serde(other)]
    Unknown,
}
/// A tool call returned in a [`ChatResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Call id.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Call arguments as arbitrary JSON.
    pub arguments: JsonValue,
}
/// One incremental piece of a streamed chat response.
///
/// Unlike most types in this module, the `Option` fields here have no
/// `skip_serializing_if`, so `None` serializes as `null`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    /// Response id, when the provider includes one on this chunk.
    pub id: Option<String>,
    /// Incremental content carried by this chunk.
    pub delta: StreamDelta,
    /// Finish reason, when reported on this chunk.
    pub finish_reason: Option<FinishReason>,
    /// Token usage, when reported on this chunk.
    pub usage: Option<TokenUsage>,
}

/// Incremental content within a [`StreamChunk`].
///
/// Untagged: a JSON string deserializes as `Text`; an object with an `index`
/// field deserializes as `ToolCall`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum StreamDelta {
    /// A fragment of text.
    Text(String),
    /// A fragment of a tool call. `index` presumably identifies which call the
    /// fragment extends, and `arguments` is partial argument text to
    /// accumulate — confirm against provider implementations.
    ToolCall {
        index: usize,
        id: Option<String>,
        name: Option<String>,
        arguments: Option<String>,
    },
}

/// Return type of `LLMProvider::chat_stream`: a pinned, boxed, `Send` stream
/// yielding chunks or [`ProviderError`]s.
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, ProviderError>> + Send>>;
/// A request to embed one or more strings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
    /// Embedding model to use.
    pub model: ModelUri,
    /// Strings to embed.
    pub input: Vec<String>,
    /// Requested encoding; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub encoding_format: Option<EncodingFormat>,
    /// Requested vector dimensionality; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dimensions: Option<u32>,
}

/// Embedding encoding, serialized as `"float"` or `"base64"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EncodingFormat {
    Float,
    Base64,
}

/// Embeddings returned by a provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    /// One vector per returned embedding. Presumably in `input` order —
    /// confirm per provider.
    pub embeddings: Vec<Vec<f32>>,
    /// Model name as reported by the provider.
    pub model: String,
    /// Token accounting for the call.
    pub usage: TokenUsage,
}
/// Descriptive metadata about a model offered by a provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    pub id: String,
    /// Display name.
    pub name: String,
    /// Provider that serves the model.
    pub provider: ProviderType,
    /// Context window size (presumably in tokens — confirm).
    pub context_length: u32,
    /// Output-token cap, if known.
    pub max_output_tokens: Option<u32>,
    // Feature flags.
    pub supports_tools: bool,
    pub supports_vision: bool,
    pub supports_audio: bool,
    pub supports_json_mode: bool,
    /// Presumably price per million input tokens; currency not established here.
    pub input_cost_per_million: Option<f64>,
    /// Presumably price per million output tokens; currency not established here.
    pub output_cost_per_million: Option<f64>,
    /// Additional free-form capability names.
    pub capabilities: Vec<String>,
}
/// Errors produced by [`LLMProvider`] operations.
///
/// Human-readable text comes from the `Display` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ProviderError {
    AuthenticationFailed { message: String },
    /// `retry_after` is presumably a delay in seconds — confirm with providers.
    RateLimited {
        message: String,
        retry_after: Option<u64>,
    },
    InvalidRequest { message: String },
    ModelNotFound { model: String },
    /// The request needed `requested_tokens` but the limit is `max_tokens`.
    ContextLengthExceeded {
        max_tokens: u32,
        requested_tokens: u32,
    },
    ContentFiltered { message: String },
    Unavailable { message: String },
    NetworkError { message: String },
    Timeout { seconds: u64 },
    /// Catch-all; also used by `ProviderFactory` for providers not yet wired up.
    Unknown { message: String },
}
impl std::fmt::Display for ProviderError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::AuthenticationFailed { message } => {
write!(f, "Authentication failed: {}", message)
}
Self::RateLimited { message, .. } => write!(f, "Rate limited: {}", message),
Self::InvalidRequest { message } => write!(f, "Invalid request: {}", message),
Self::ModelNotFound { model } => write!(f, "Model not found: {}", model),
Self::ContextLengthExceeded {
max_tokens,
requested_tokens,
} => {
write!(
f,
"Context length exceeded: {} > {}",
requested_tokens, max_tokens
)
}
Self::ContentFiltered { message } => write!(f, "Content filtered: {}", message),
Self::Unavailable { message } => write!(f, "Provider unavailable: {}", message),
Self::NetworkError { message } => write!(f, "Network error: {}", message),
Self::Timeout { seconds } => write!(f, "Timeout after {} seconds", seconds),
Self::Unknown { message } => write!(f, "Unknown error: {}", message),
}
}
}
impl std::error::Error for ProviderError {}
/// The interface every concrete LLM backend implements.
///
/// Methods are grouped as follows:
/// - identity and configuration;
/// - connectivity and model discovery;
/// - inference;
/// - wire-level hooks that translate between this module's types and the
///   provider's JSON.
#[async_trait]
pub trait LLMProvider: Send + Sync {
    // ---- identity & configuration ----

    /// Which provider family this implementation belongs to.
    fn provider_type(&self) -> ProviderType;
    /// The provider's configuration; the `LLMProviderExt` helpers take the
    /// target model from `config().model_uri()`.
    fn config(&self) -> &ProviderConfig;
    /// The tool-schema dialect this provider uses.
    fn tool_format(&self) -> ToolFormat;
    /// Whether the provider reports support for the named `capability`.
    fn supports(&self, capability: &str) -> bool;

    // ---- connectivity & model discovery ----

    /// Availability probe.
    async fn is_available(&self) -> bool;
    /// Checks the configured credentials, returning `Ok(())` or an error.
    async fn validate_credentials(&self) -> Result<(), ProviderError>;
    /// Lists models offered by the provider.
    async fn list_models(&self) -> Result<Vec<ModelInfo>, ProviderError>;
    /// Looks up metadata for a single `model`.
    async fn get_model_info(&self, model: &str) -> Result<ModelInfo, ProviderError>;

    // ---- inference ----

    /// Runs a chat completion and returns the full response.
    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError>;
    /// Runs a chat completion and returns a [`ChatStream`] of incremental chunks.
    async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream, ProviderError>;
    /// Computes embeddings for the request's inputs.
    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse, ProviderError>;

    // ---- wire format ----

    /// HTTP header name/value pairs used for authentication.
    fn auth_headers(&self) -> HashMap<String, String>;
    /// Converts a [`ChatRequest`] into the provider's JSON request body.
    fn transform_request(&self, request: &ChatRequest) -> Result<JsonValue, ProviderError>;
    /// Converts the provider's JSON response body into a [`ChatResponse`].
    fn parse_response(&self, response: &JsonValue) -> Result<ChatResponse, ProviderError>;
}
/// Convenience helpers layered on top of [`LLMProvider`].
///
/// Every method has a default body built from the core trait methods. The
/// blanket impl at the end of this block makes the helpers available on every
/// `LLMProvider`.
#[async_trait]
pub trait LLMProviderExt: LLMProvider {
    /// Sends `prompt` as a single user message to the model from
    /// `self.config().model_uri()`. Returns the reply's concatenated text parts.
    ///
    /// # Errors
    /// Propagates any error from [`LLMProvider::chat`].
    async fn complete(&self, prompt: &str) -> Result<String, ProviderError> {
        let request = ChatRequest::new(self.config().model_uri(), vec![Message::user(prompt)]);
        let response = self.chat(request).await?;
        Ok(response.text())
    }

    /// Sends `messages` to the configured model with `tools` attached. No
    /// `tool_choice` is set.
    ///
    /// # Errors
    /// Propagates any error from [`LLMProvider::chat`].
    async fn chat_with_tools(
        &self,
        messages: Vec<Message>,
        tools: Vec<ToolSchema>,
    ) -> Result<ChatResponse, ProviderError> {
        let request = ChatRequest::new(self.config().model_uri(), messages).with_tools(tools);
        self.chat(request).await
    }

    /// Embeds a single string, requesting float encoding, and returns the
    /// first vector the provider sends back.
    ///
    /// # Errors
    /// Propagates any error from [`LLMProvider::embed`]. Returns
    /// [`ProviderError::Unknown`] if the response contains no embeddings.
    async fn embed_single(&self, text: &str) -> Result<Vec<f32>, ProviderError> {
        let request = EmbeddingRequest {
            model: self.config().model_uri(),
            input: vec![text.to_string()],
            encoding_format: Some(EncodingFormat::Float),
            dimensions: None,
        };
        let response = self.embed(request).await?;
        response
            .embeddings
            .into_iter()
            .next()
            .ok_or_else(|| ProviderError::Unknown {
                message: "No embedding returned".to_string(),
            })
    }

    /// Rough token estimate: UTF-8 byte length divided by four, rounded down.
    ///
    /// This is a heuristic, not a tokenizer; implementors with a real tokenizer
    /// should override it. The result saturates at `u32::MAX`.
    fn count_tokens(&self, text: &str) -> u32 {
        // A bare `as u32` would silently drop the high bits for enormous
        // inputs; clamp first so the estimate saturates instead of wrapping.
        (text.len() / 4).min(u32::MAX as usize) as u32
    }
}

impl<T: LLMProvider> LLMProviderExt for T {}
/// Constructs boxed [`LLMProvider`] implementations.
///
/// No concrete providers are wired in yet, so both constructors currently
/// return an error for every provider type.
pub struct ProviderFactory;

impl ProviderFactory {
    /// Builds a provider for the provider named in `uri`.
    ///
    /// # Errors
    /// Currently always returns [`ProviderError::Unknown`]. The message is
    /// `"Provider implementation pending"` for OpenAI and Anthropic, and
    /// `"Provider <type> not implemented"` otherwise.
    pub fn from_uri(uri: &ModelUri) -> Result<Box<dyn LLMProvider>, ProviderError> {
        Self::not_yet_available(&uri.provider)
    }

    /// Builds a provider of `provider_type` from an explicit configuration.
    ///
    /// # Errors
    /// Currently always returns [`ProviderError::Unknown`], with the same
    /// per-provider message as [`ProviderFactory::from_uri`].
    pub fn from_config(
        provider_type: ProviderType,
        _config: ProviderConfig,
    ) -> Result<Box<dyn LLMProvider>, ProviderError> {
        Self::not_yet_available(&provider_type)
    }

    /// Shared error path, so both constructors report the same status for a
    /// given provider type.
    fn not_yet_available(provider: &ProviderType) -> Result<Box<dyn LLMProvider>, ProviderError> {
        let message = match provider {
            ProviderType::OpenAI | ProviderType::Anthropic => {
                "Provider implementation pending".to_string()
            }
            other => format!("Provider {:?} not implemented", other),
        };
        Err(ProviderError::Unknown { message })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal successful response whose assistant reply is `reply`.
    fn stop_response(reply: &str) -> ChatResponse {
        ChatResponse {
            id: "test".to_string(),
            model: "gpt-4o".to_string(),
            message: Message::assistant(reply),
            usage: TokenUsage::default(),
            finish_reason: FinishReason::Stop,
            tool_calls: None,
            metadata: None,
        }
    }

    #[test]
    fn test_message_creation() {
        let greeting = Message::user("Hello!");
        assert_eq!(greeting.role, Role::User);
        assert_eq!(greeting.text_content(), "Hello!");
    }

    #[test]
    fn test_system_message() {
        let prompt = Message::system("You are a helpful assistant.");
        assert_eq!(prompt.role, Role::System);
    }

    #[test]
    fn test_chat_request_builder() {
        let model = ModelUri::parse("openai:gpt-4o").unwrap();
        let built = ChatRequest::new(model, vec![Message::user("Hi")])
            .with_max_tokens(100)
            .with_temperature(0.7);
        assert_eq!(built.max_tokens, Some(100));
        assert_eq!(built.temperature, Some(0.7));
    }

    #[test]
    fn test_response_text() {
        assert_eq!(stop_response("Hello there!").text(), "Hello there!");
    }
}