pub mod anthropic;
pub mod openai;
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use crate::tools::ToolDefinition;
/// One turn of a conversation: who produced it and what it contains.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Message {
    /// Author of this turn.
    pub role: Role,
    /// Either plain text or a list of structured content blocks.
    pub content: MessageContent,
}
/// Author of a message.
///
/// Serialized as the lowercase variant name: `"system"`, `"user"`,
/// `"assistant"` or `"tool"`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
    Tool,
}
/// Body of a message, represented without any wrapping tag.
///
/// On deserialization serde tries the variants in declaration order. A JSON
/// string becomes `Text`, and an array of content blocks becomes `Blocks`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain string content.
    Text(String),
    /// Structured content made of typed blocks.
    Blocks(Vec<ContentBlock>),
}
/// A single typed piece of message content.
///
/// The enum is internally tagged: the variant is carried in a `"type"` field.
/// Tags are the snake_case variant names, except `WebSearchResult`, which
/// overrides its tag to `"web_search_tool_result"`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// `"text"`: plain text.
    Text { text: String },

    /// `"tool_use"`: a tool invocation with its id, tool name and JSON input.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },

    /// `"tool_result"`: tool output, linked to the originating `tool_use` id.
    ToolResult {
        tool_use_id: String,
        content: String,
    },

    /// `"thinking"`: reasoning text. `signature` is left out of the output when `None`.
    Thinking {
        thinking: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },

    /// `"server_tool_use"`: a tool invocation executed on the server side.
    ServerToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },

    /// `"web_search_tool_result"`: results of a server-side web search.
    #[serde(rename = "web_search_tool_result")]
    WebSearchResult {
        tool_use_id: String,
        content: WebSearchContent,
    },
}
/// Payload of a web search result block: the list of returned hits.
///
/// NOTE(review): this expects `{ "results": [...] }`. Anthropic documents the
/// `web_search_tool_result` `content` as a bare array of results, so confirm
/// that the provider adapts between the two shapes.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct WebSearchContent {
    /// Individual search hits.
    pub results: Vec<WebSearchResultItem>,
}
/// One web search hit. Only `url` is mandatory.
///
/// The optional fields have no `skip_serializing_if`, so a `None` value is
/// written as JSON `null` rather than being omitted.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct WebSearchResultItem {
    /// Page title, if one was provided.
    pub title: Option<String>,
    /// Address of the result.
    pub url: String,
    /// Opaque content blob, if one was provided (presumably passed back to the provider verbatim; verify).
    pub encrypted_content: Option<String>,
    /// Short excerpt of the page, if one was provided.
    pub snippet: Option<String>,
}
/// Definition of a tool that runs on the provider's side (not locally).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ServerTool {
    /// Tool type identifier. It is written as the `"type"` key on the wire.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// Tool name.
    pub name: String,
    /// Optional usage cap. The key is omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_uses: Option<u32>,
}
impl ServerTool {
    /// Versioned type identifier of Anthropic's server-side web search tool.
    ///
    /// The API rejects unknown tool types. The previous value
    /// `"web_search_tool"` is not a documented type.
    const WEB_SEARCH_TYPE: &'static str = "web_search_20250305";

    /// Builds the server-side web search tool definition.
    ///
    /// `max_uses` is passed through unchanged. When it is `None`, the
    /// `max_uses` key is omitted from the serialized tool. Anthropic documents
    /// this field as a per-request limit on the number of searches.
    pub fn web_search(max_uses: Option<u32>) -> Self {
        Self {
            tool_type: Self::WEB_SEARCH_TYPE.to_string(),
            name: "web_search".to_string(),
            max_uses,
        }
    }
}
/// Everything a [`Provider`] needs to produce one chat completion.
#[derive(Clone, Debug)]
pub struct ChatRequest {
    /// Conversation history, in order.
    pub messages: Vec<Message>,
    /// Locally executed tools offered to the model.
    pub tools: Vec<ToolDefinition>,
    /// Optional system prompt.
    pub system: Option<String>,
    /// Presumably toggles extended thinking / reasoning output. Confirm in the providers.
    pub think: bool,
    /// Presumably the maximum number of tokens to generate. Confirm in the providers.
    pub max_tokens: u32,
    /// Tools executed by the provider itself (for example web search).
    pub server_tools: Vec<ServerTool>,
    /// Presumably enables provider-side prompt caching. Confirm in the providers.
    pub enable_caching: bool,
}
/// A complete (non-incremental) reply from a [`Provider`].
#[derive(Clone, Debug)]
pub struct ChatResponse {
    /// Content blocks produced by the model, in order.
    pub content: Vec<ContentBlock>,
    /// Why generation stopped.
    pub stop_reason: StopReason,
    /// Token accounting for this request.
    pub usage: Usage,
}
/// Token counts reported for a request. `Default` yields all zeros.
#[derive(Default, Clone, Debug, Eq, PartialEq)]
pub struct Usage {
    /// Tokens consumed as input.
    pub input_tokens: u32,
    /// Tokens generated as output.
    pub output_tokens: u32,
    /// Presumably the input tokens written to the prompt cache. Confirm against the provider.
    pub cache_creation_input_tokens: u32,
    /// Presumably the input tokens served from the prompt cache. Confirm against the provider.
    pub cache_read_input_tokens: u32,
}
/// Reason the model stopped generating.
#[derive(Clone, PartialEq, Debug)]
pub enum StopReason {
    /// The model finished its turn.
    EndTurn,
    /// The model stopped in order to call a tool.
    ToolUse,
    /// Generation hit the token limit.
    MaxTokens,
}
/// One incremental event delivered through the receiver returned by
/// [`Provider::chat_stream`].
#[derive(Clone, Debug)]
pub enum StreamEvent {
    /// The response has started arriving.
    FirstByte,
    /// A fragment of thinking text.
    ThinkingDelta(String),
    /// A fragment of answer text.
    TextDelta(String),
    /// A tool call has begun.
    ToolUseStart { id: String, name: String },
    /// Progress on tool input. Presumably a cumulative byte count; verify in the providers.
    ToolInputDelta { bytes_so_far: usize },
    /// Running output-token count.
    Usage { output_tokens: u32 },
    /// Final event, carrying the full assembled response.
    Done(ChatResponse),
    /// The stream failed, with a description of the failure.
    Error(String),
}
/// A chat-completion backend.
///
/// Implementors must provide [`Provider::chat`] and [`Provider::clone_box`].
/// Streaming and context-size reporting have default implementations.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Sends `request` and waits for the complete response.
    ///
    /// # Errors
    /// Implementations return an error when the request cannot be completed.
    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse>;

    /// Context window size in tokens, if the provider knows it. Defaults to `None`.
    fn context_size(&self) -> Option<u32> {
        None
    }

    /// Streams a response as a sequence of [`StreamEvent`]s.
    ///
    /// The default implementation is a non-streaming fallback. It awaits
    /// [`Provider::chat`] and then replays the result as `FirstByte`, one
    /// `TextDelta` per text block (other block kinds produce no delta), and a
    /// final `Done` that carries the whole response.
    ///
    /// # Errors
    /// Propagates any error from [`Provider::chat`]. In that case no receiver is returned.
    async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
        let response = self.chat(request).await?;

        let mut events: Vec<StreamEvent> = Vec::with_capacity(response.content.len() + 2);
        events.push(StreamEvent::FirstByte);
        events.extend(response.content.iter().filter_map(|block| {
            if let ContentBlock::Text { text } = block {
                Some(StreamEvent::TextDelta(text.clone()))
            } else {
                None
            }
        }));
        events.push(StreamEvent::Done(response));

        // Every event is queued before the receiver is handed back, so nothing
        // drains the channel while we send. A fixed capacity (previously 32)
        // made `send().await` wait forever once a response had enough text
        // blocks. Sizing the channel to the event count (always >= 2, so never
        // the invalid 0) guarantees every send completes immediately.
        let (tx, rx) = mpsc::channel(events.len());
        for event in events {
            // Cannot block (capacity == number of events) and cannot fail (`rx` is still alive).
            let _ = tx.send(event).await;
        }
        Ok(rx)
    }

    /// Returns a boxed clone of this provider. This is what makes
    /// `Box<dyn Provider>` cloneable.
    fn clone_box(&self) -> Box<dyn Provider>;
}
impl Clone for Box<dyn Provider> {
fn clone(&self) -> Self {
self.clone_box()
}
}