pub mod anthropic;
pub mod openai;
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use crate::tools::ToolDefinition;
/// A single conversation entry: who produced it and what it contains.
///
/// Serialized as a JSON object with `role` and `content` keys.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Message {
    /// Author of this entry.
    pub role: Role,
    /// Either plain text or a list of typed content blocks.
    pub content: MessageContent,
}
/// Author of a [`Message`].
///
/// Serialized as the lowercase variant name: `"system"`, `"user"`,
/// `"assistant"` or `"tool"`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub enum Role {
    #[serde(rename = "system")]
    System,
    #[serde(rename = "user")]
    User,
    #[serde(rename = "assistant")]
    Assistant,
    #[serde(rename = "tool")]
    Tool,
}
/// Body of a [`Message`].
///
/// Untagged: `Text` round-trips as a bare JSON string and `Blocks` as a bare
/// JSON array of [`ContentBlock`] objects. serde tries untagged variants in
/// declaration order when deserializing, so the variant order is kept as-is.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain text body.
    Text(String),
    /// Structured body made of typed blocks.
    Blocks(Vec<ContentBlock>),
}
/// One typed element of a structured message body.
///
/// Internally tagged: every variant is a JSON object whose `type` field holds
/// the snake_case variant name (`text`, `tool_use`, `tool_result`, `thinking`,
/// `server_tool_use`). The only exception is `WebSearchResult`, which is
/// explicitly tagged `web_search_tool_result`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text.
    Text { text: String },
    /// A tool invocation with its JSON arguments.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// Output for the tool invocation whose id is `tool_use_id`.
    ToolResult { tool_use_id: String, content: String },
    /// Thinking text; `signature` is omitted from the JSON when `None`.
    Thinking {
        thinking: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    /// A server-side tool invocation with its JSON arguments.
    ServerToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// Web search results for the server tool call `tool_use_id`.
    #[serde(rename = "web_search_tool_result")]
    WebSearchResult {
        tool_use_id: String,
        content: WebSearchContent,
    },
}
/// Payload of a `ContentBlock::WebSearchResult` block.
///
/// NOTE(review): this expects `content` to be an object holding a `results`
/// array; confirm it matches the wire format the provider modules receive.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct WebSearchContent {
    /// Individual search hits, in the order received.
    pub results: Vec<WebSearchResultItem>,
}

/// A single web search hit.
///
/// Only `url` is required; the optional fields deserialize to `None` when
/// absent and serialize as `null` when `None`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct WebSearchResultItem {
    pub title: Option<String>,
    pub url: String,
    /// Opaque content blob; presumably must be echoed back unchanged — confirm.
    pub encrypted_content: Option<String>,
    pub snippet: Option<String>,
}
/// A tool that runs on the provider's side rather than in this process.
///
/// Serialized as `{"type": <tool_type>, "name": <name>, "max_uses": <n>}`;
/// `max_uses` is omitted entirely when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerTool {
    /// Versioned tool type identifier, written to JSON as `type`.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// Name under which the model invokes the tool.
    pub name: String,
    /// Optional cap on how many times the tool may be used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_uses: Option<u32>,
}

impl ServerTool {
    /// Type identifier of Anthropic's web search server tool.
    ///
    /// The API matches tools by this dated identifier; a generic string such as
    /// `"web_search_tool"` is not an accepted tool type.
    const WEB_SEARCH_TYPE: &'static str = "web_search_20250305";

    /// Builds the web search server tool.
    ///
    /// `max_uses` optionally limits how many searches may run; `None` leaves
    /// the field out of the request.
    pub fn web_search(max_uses: Option<u32>) -> Self {
        Self {
            tool_type: Self::WEB_SEARCH_TYPE.to_string(),
            name: "web_search".to_string(),
            max_uses,
        }
    }
}
/// Provider-agnostic input to [`Provider::chat`] / [`Provider::chat_stream`].
#[derive(Clone, Debug)]
pub struct ChatRequest {
    /// Conversation so far.
    pub messages: Vec<Message>,
    /// Tool definitions offered to the model.
    pub tools: Vec<ToolDefinition>,
    /// Optional system prompt.
    pub system: Option<String>,
    /// Whether to ask for thinking output; how it is honored is up to each
    /// provider implementation — confirm per provider.
    pub think: bool,
    /// Token limit forwarded to the provider (presumably caps output tokens).
    pub max_tokens: u32,
    /// Provider-side tools, e.g. [`ServerTool::web_search`].
    pub server_tools: Vec<ServerTool>,
    /// Whether to request prompt caching; presumably ignored by providers
    /// without caching support — confirm.
    pub enable_caching: bool,
}
/// Complete result of one model call.
#[derive(Clone, Debug)]
pub struct ChatResponse {
    /// Content blocks produced by the model, in order.
    pub content: Vec<ContentBlock>,
    /// Why generation ended.
    pub stop_reason: StopReason,
    /// Token accounting reported for this call.
    pub usage: Usage,
}
/// Token counts for a single call; `Usage::default()` is all zeros.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Usage {
    /// Input (prompt) tokens.
    pub input_tokens: u32,
    /// Output (generated) tokens.
    pub output_tokens: u32,
    /// Input tokens reported as written to a prompt cache (0 if unreported).
    pub cache_creation_input_tokens: u32,
    /// Input tokens reported as read from a prompt cache (0 if unreported).
    pub cache_read_input_tokens: u32,
}
/// Normalized reason a [`ChatResponse`] ended.
#[derive(Clone, Debug, PartialEq)]
pub enum StopReason {
    /// The model finished its turn.
    EndTurn,
    /// The model stopped to request tool calls.
    ToolUse,
    /// Generation hit the token limit.
    MaxTokens,
}
/// Events delivered on the channel returned by [`Provider::chat_stream`].
#[derive(Clone, Debug)]
pub enum StreamEvent {
    /// Signals that the response has started arriving.
    FirstByte,
    /// A fragment of thinking text.
    ThinkingDelta(String),
    /// A fragment of response text.
    TextDelta(String),
    /// A tool call has begun.
    ToolUseStart { id: String, name: String },
    /// Progress while tool input streams in; presumably a running byte count.
    ToolInputDelta { bytes_so_far: usize },
    /// Intermediate output-token count.
    Usage { output_tokens: u32 },
    /// The finished response; expected to be the final successful event.
    Done(ChatResponse),
    /// A failure described as text.
    Error(String),
}
/// A chat-model backend.
///
/// Implementors must provide [`Provider::chat`] and [`Provider::clone_box`];
/// [`Provider::context_size`] and [`Provider::chat_stream`] have defaults.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Sends `request` and waits for the complete response.
    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse>;

    /// Context window size in tokens, if known. Defaults to `None`.
    fn context_size(&self) -> Option<u32> {
        None
    }

    /// Returns a receiver of [`StreamEvent`]s for `request`.
    ///
    /// The default implementation is not incremental. It awaits
    /// [`Provider::chat`] and replays the finished response as `FirstByte`,
    /// then one `TextDelta` per `ContentBlock::Text` block, then `Done`.
    ///
    /// # Errors
    /// Propagates any error from [`Provider::chat`]; no receiver is returned
    /// in that case.
    async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
        let response = self.chat(request).await?;

        let deltas: Vec<StreamEvent> = response
            .content
            .iter()
            .filter_map(|block| match block {
                ContentBlock::Text { text } => Some(StreamEvent::TextDelta(text.clone())),
                _ => None,
            })
            .collect();

        // Every event is queued before the receiver is handed back, so nothing
        // can drain the channel while it is being filled. Size it to hold all
        // events (FirstByte + deltas + Done); a fixed capacity made `send` wait
        // forever once a response had more text blocks than that capacity.
        let (tx, rx) = mpsc::channel(deltas.len() + 2);
        // Sends cannot fail here: `rx` is still alive and capacity suffices.
        let _ = tx.send(StreamEvent::FirstByte).await;
        for delta in deltas {
            let _ = tx.send(delta).await;
        }
        let _ = tx.send(StreamEvent::Done(response)).await;
        Ok(rx)
    }

    /// Clones this provider into a new box; backs `Clone for Box<dyn Provider>`.
    fn clone_box(&self) -> Box<dyn Provider>;
}
impl Clone for Box<dyn Provider> {
fn clone(&self) -> Self {
self.clone_box()
}
}
/// Backend vendors that [`create_provider`] can construct.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ProviderType {
    /// Built as `anthropic::AnthropicProvider`.
    Anthropic,
    /// Built as `openai::OpenAIProvider`.
    OpenAI,
}
/// Builds a boxed [`Provider`] for `provider_type`.
///
/// When `base_url` is `None`, the vendor's public endpoint is used:
/// `https://api.anthropic.com` for Anthropic, `https://api.openai.com/v1`
/// for OpenAI.
///
/// # Errors
/// Never fails at present; every arm returns `Ok`.
pub fn create_provider(
    provider_type: ProviderType,
    api_key: String,
    model: String,
    base_url: Option<String>,
) -> Result<Box<dyn Provider>> {
    const ANTHROPIC_DEFAULT_URL: &str = "https://api.anthropic.com";
    const OPENAI_DEFAULT_URL: &str = "https://api.openai.com/v1";

    let boxed: Box<dyn Provider> = match provider_type {
        ProviderType::Anthropic => {
            let endpoint = base_url.unwrap_or_else(|| ANTHROPIC_DEFAULT_URL.to_string());
            Box::new(anthropic::AnthropicProvider::new(api_key, model, endpoint))
        }
        ProviderType::OpenAI => {
            let endpoint = base_url.unwrap_or_else(|| OPENAI_DEFAULT_URL.to_string());
            Box::new(openai::OpenAIProvider::new(api_key, model, endpoint))
        }
    };
    Ok(boxed)
}
/// Guesses the vendor for `model` by case-insensitive substring matching.
///
/// Anthropic markers (`claude`, `opus`, `sonnet`, `haiku`) are checked first,
/// then OpenAI markers (`gpt`, `o1`, `o3`, `o4`). Names matching neither fall
/// back to [`ProviderType::Anthropic`].
///
/// Matching is plain `contains`, so any name that merely includes a marker
/// (e.g. `o1` anywhere in the string) is classified accordingly.
pub fn infer_provider_type(model: &str) -> ProviderType {
    const ANTHROPIC_MARKERS: &[&str] = &["claude", "opus", "sonnet", "haiku"];
    const OPENAI_MARKERS: &[&str] = &["gpt", "o1", "o3", "o4"];

    let name = model.to_lowercase();
    let mentions_any = |markers: &[&str]| markers.iter().any(|&marker| name.contains(marker));

    if mentions_any(ANTHROPIC_MARKERS) {
        return ProviderType::Anthropic;
    }
    if mentions_any(OPENAI_MARKERS) {
        return ProviderType::OpenAI;
    }
    ProviderType::Anthropic
}