use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Static description of what a model can do and what it costs.
///
/// Returned by `ModelProvider::capabilities`; the associated constructors
/// `gpt4`, `claude3_opus` and `gemini_pro` provide hard-coded presets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCapabilities {
    /// Token limit of the model; the presets set this to the context-window size
    /// (128_000, 200_000, 1_000_000).
    pub max_tokens: u32,
    /// Output-token limit of the model (presets use 4096 or 8192).
    pub max_output_tokens: u32,
    /// Whether streamed responses are supported.
    pub supports_streaming: bool,
    /// Whether tool / function calling is supported.
    pub supports_tools: bool,
    /// Whether the model exposes a reasoning mode.
    pub supports_reasoning: bool,
    /// Whether image inputs are accepted.
    pub supports_vision: bool,
    /// Whether a JSON output mode is available.
    pub supports_json_mode: bool,
    /// Whether embeddings can be requested.
    pub supports_embeddings: bool,
    /// Whether image generation is available.
    pub supports_image_generation: bool,
    /// Whether audio transcription is available.
    pub supports_audio_transcription: bool,
    /// Whether text-to-speech is available.
    pub supports_speech: bool,
    /// Whether video generation is available.
    pub supports_video_generation: bool,
    /// NOTE(review): meaning is defined by callers — presumably whether personally
    /// identifiable data may be sent to this model; confirm.
    pub pii_safe: bool,
    /// Price per one million input tokens, if known (currency not stated — presumably USD).
    pub cost_per_1m_input: Option<f64>,
    /// Price per one million output tokens, if known.
    pub cost_per_1m_output: Option<f64>,
    /// Price per one million pixels, if known (presumably for image/video output — confirm).
    pub cost_per_1m_pixels: Option<f64>,
}
impl Default for ModelCapabilities {
    /// Conservative baseline for a model nothing is known about.
    ///
    /// Both token limits are 4096 and streaming is enabled. Every other capability flag,
    /// including `pii_safe`, is off, and no pricing information is set.
    fn default() -> Self {
        const BASELINE_TOKENS: u32 = 4096;
        let unpriced: Option<f64> = None;
        Self {
            max_tokens: BASELINE_TOKENS,
            max_output_tokens: BASELINE_TOKENS,
            supports_streaming: true,
            // Everything below is opt-in; the presets enable what they need explicitly.
            supports_tools: false,
            supports_reasoning: false,
            supports_vision: false,
            supports_json_mode: false,
            supports_embeddings: false,
            supports_image_generation: false,
            supports_audio_transcription: false,
            supports_speech: false,
            supports_video_generation: false,
            pii_safe: false,
            cost_per_1m_input: unpriced,
            cost_per_1m_output: unpriced,
            cost_per_1m_pixels: unpriced,
        }
    }
}
impl ModelCapabilities {
    /// Preset for a GPT-4-class model.
    ///
    /// The preset has a 128K-token window and 4096 output tokens, and enables tools,
    /// vision and JSON mode. Prices are per one million tokens, as the `cost_per_1m_*`
    /// field names require.
    pub fn gpt4() -> Self {
        Self {
            max_tokens: 128_000,
            max_output_tokens: 4096,
            supports_streaming: true,
            supports_tools: true,
            supports_reasoning: false,
            supports_vision: true,
            supports_json_mode: true,
            // NOTE(review): embeddings flagged on a chat-model preset — confirm the provider
            // really serves embeddings for this model.
            supports_embeddings: true,
            supports_image_generation: false,
            supports_audio_transcription: false,
            supports_speech: false,
            supports_video_generation: false,
            pii_safe: false,
            // Per-1M-token prices. Per-1K list prices (0.03 / 0.06) must be scaled by 1000
            // to match the field unit. NOTE(review): a 128K window usually means the Turbo
            // variant, whose pricing differs — confirm against current provider pricing.
            cost_per_1m_input: Some(30.0),
            cost_per_1m_output: Some(60.0),
            cost_per_1m_pixels: None,
        }
    }

    /// Preset for Claude 3 Opus.
    ///
    /// The preset has a 200K-token window and 4096 output tokens, and enables tools,
    /// vision and JSON mode. Prices are per one million tokens.
    pub fn claude3_opus() -> Self {
        Self {
            max_tokens: 200_000,
            max_output_tokens: 4096,
            supports_streaming: true,
            supports_tools: true,
            supports_reasoning: false,
            supports_vision: true,
            supports_json_mode: true,
            // NOTE(review): embeddings flagged on a chat-model preset — confirm the provider
            // really serves embeddings for this model.
            supports_embeddings: true,
            supports_image_generation: false,
            supports_audio_transcription: false,
            supports_speech: false,
            supports_video_generation: false,
            pii_safe: false,
            // Per-1M-token prices (per-1K values 0.015 / 0.075 scaled by 1000).
            cost_per_1m_input: Some(15.0),
            cost_per_1m_output: Some(75.0),
            cost_per_1m_pixels: None,
        }
    }

    /// Preset for Gemini Pro.
    ///
    /// The preset has a 1M-token window and 8192 output tokens, and enables tools,
    /// vision and JSON mode. Prices are per one million tokens.
    pub fn gemini_pro() -> Self {
        Self {
            max_tokens: 1_000_000,
            max_output_tokens: 8192,
            supports_streaming: true,
            supports_tools: true,
            supports_reasoning: false,
            supports_vision: true,
            supports_json_mode: true,
            supports_embeddings: true,
            supports_image_generation: false,
            supports_audio_transcription: false,
            supports_speech: false,
            supports_video_generation: false,
            pii_safe: false,
            // Per-1M-token prices (per-1K values 0.00125 / 0.005 scaled by 1000).
            cost_per_1m_input: Some(1.25),
            cost_per_1m_output: Some(5.0),
            cost_per_1m_pixels: None,
        }
    }
}
/// A tool definition offered to the model via `ChatRequest::tools`.
///
/// `tool_type` is (de)serialized under the JSON key `"type"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatTool {
    /// Tool kind (JSON key `"type"`); presumably `"function"` for OpenAI-style APIs — confirm.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function the model may call.
    pub function: ChatToolFunction,
}
/// Name, description and parameter description of a callable function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatToolFunction {
    /// Function name the model uses when calling it.
    pub name: String,
    /// Free-text description shown to the model.
    pub description: String,
    /// Arbitrary JSON describing the parameters; presumably a JSON Schema object — confirm.
    pub parameters: Value,
}
/// Tool-selection policy sent as `ChatRequest::tool_choice`.
///
/// `#[serde(untagged)]`: `String` (de)serializes as a bare JSON string, and `Specific` as an
/// object `{"type": ..., "function": {"name": ...}}`. On deserialization, serde tries the
/// variants in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    /// A policy keyword passed through verbatim.
    String(String),
    /// Forces a particular function.
    Specific {
        /// Kind of choice (JSON key `"type"`).
        #[serde(rename = "type")]
        choice_type: String,
        /// The function to call.
        function: ToolChoiceFunction,
    },
}
/// Names the function selected by `ToolChoice::Specific`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceFunction {
    /// Name of the function to call.
    pub name: String,
}
/// A tool invocation attached to an assistant message (`ChatMessage::tool_calls`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageToolCall {
    /// Call identifier; presumably echoed back in a later `"tool"` message's
    /// `tool_call_id` — confirm against the provider API.
    pub id: String,
    /// Kind of call (JSON key `"type"`).
    #[serde(rename = "type")]
    pub call_type: String,
    /// Which function was called and with what arguments.
    pub function: MessageToolCallFunction,
}
/// Function name and arguments of a tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageToolCallFunction {
    /// Name of the function being called.
    pub name: String,
    /// Arguments as a raw, unparsed string; presumably JSON-encoded — confirm.
    pub arguments: String,
}
/// Image reference carried by `ContentPart::ImageUrl`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrlContent {
    /// Remote URL or `data:` URL (the latter is built by `ContentPart::image_base64`).
    pub url: String,
    /// Optional detail hint; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
/// One element of a multi-part message body.
///
/// Internally tagged by `"type"` with snake_case variant names. Variants therefore
/// serialize as `{"type":"text","text":...}` and `{"type":"image_url","image_url":{...}}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// A text fragment.
    Text { text: String },
    /// An image reference.
    ImageUrl { image_url: ImageUrlContent },
}
impl ContentPart {
    /// Wraps a string as a text part.
    pub fn text(text: impl Into<String>) -> Self {
        Self::Text { text: text.into() }
    }

    /// Builds an image part referring to `url`, with no `detail` hint.
    pub fn image_url(url: impl Into<String>) -> Self {
        let image_url = ImageUrlContent {
            url: url.into(),
            detail: None,
        };
        Self::ImageUrl { image_url }
    }

    /// Builds an inline image part from already base64-encoded data.
    ///
    /// The payload is wrapped as `data:<mime_type>;base64,<base64_data>`. No validation of
    /// either argument is performed.
    pub fn image_base64(base64_data: impl Into<String>, mime_type: impl Into<String>) -> Self {
        let mime: String = mime_type.into();
        let payload: String = base64_data.into();
        Self::image_url(format!("data:{};base64,{}", mime, payload))
    }
}
/// A message body: either a plain string or a list of parts.
///
/// `#[serde(untagged)]`: a JSON string maps to `Text` and a JSON array to `Parts`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain text.
    Text(String),
    /// Ordered mix of text and image parts.
    Parts(Vec<ContentPart>),
}
impl MessageContent {
    /// Returns `true` when at least one part is an image.
    ///
    /// Plain-text content never contains images.
    pub fn has_images(&self) -> bool {
        if let MessageContent::Parts(parts) = self {
            parts
                .iter()
                .any(|part| matches!(part, ContentPart::ImageUrl { .. }))
        } else {
            false
        }
    }

    /// Flattens the content into one string.
    ///
    /// Plain text is returned as a copy. For multi-part content, the text parts are joined
    /// with `"\n"` in their original order, and image parts are skipped.
    pub fn as_text(&self) -> String {
        let parts = match self {
            MessageContent::Text(s) => return s.clone(),
            MessageContent::Parts(parts) => parts,
        };
        let texts: Vec<&str> = parts
            .iter()
            .filter_map(|part| match part {
                ContentPart::Text { text } => Some(text.as_str()),
                ContentPart::ImageUrl { .. } => None,
            })
            .collect();
        texts.join("\n")
    }
}
impl From<String> for MessageContent {
fn from(s: String) -> Self {
MessageContent::Text(s)
}
}
impl From<&str> for MessageContent {
fn from(s: &str) -> Self {
MessageContent::Text(s.to_string())
}
}
/// One conversation message, used both in `ChatRequest::messages` and in parsed
/// `ChatChoice::message` values.
///
/// Every optional field is omitted from serialized JSON when `None`.
///
/// NOTE(review): `multimodal_content` has no `#[serde(rename)]`, so it serializes under the
/// key `"multimodal_content"`, not inside `"content"`. OpenAI-style APIs expect image parts
/// in `"content"`. Presumably providers rebuild the body via `effective_content()` before
/// sending — verify, otherwise image parts would not reach the API.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Speaker role; the constructors use `"system"`, `"user"`, `"assistant"` and `"tool"`.
    pub role: String,
    /// Plain-text body. It is `None` for messages built by `user_with_images`, and may be
    /// `None` for assistant tool-call messages.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Multi-part body (text and images); preferred over `content` by `effective_content()`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub multimodal_content: Option<Vec<ContentPart>>,
    /// Tool calls requested by the assistant.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<MessageToolCall>>,
    /// For `"tool"` messages: the id of the call this message answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
impl ChatMessage {
pub fn system(content: impl Into<String>) -> Self {
Self {
role: "system".to_string(),
content: Some(content.into()),
multimodal_content: None,
tool_calls: None,
tool_call_id: None,
}
}
pub fn user(content: impl Into<String>) -> Self {
Self {
role: "user".to_string(),
content: Some(content.into()),
multimodal_content: None,
tool_calls: None,
tool_call_id: None,
}
}
pub fn user_with_images<S: Into<String>>(
text: S,
images: Vec<(Vec<u8>, String)>, ) -> Self {
use base64::Engine;
let mut parts = vec![ContentPart::text(text)];
for (data, mime_type) in images {
let b64 = base64::engine::general_purpose::STANDARD.encode(&data);
parts.push(ContentPart::image_base64(b64, mime_type));
}
Self {
role: "user".to_string(),
content: None, multimodal_content: Some(parts),
tool_calls: None,
tool_call_id: None,
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: "assistant".to_string(),
content: Some(content.into()),
multimodal_content: None,
tool_calls: None,
tool_call_id: None,
}
}
pub fn assistant_with_tool_calls(
content: Option<String>,
tool_calls: Vec<MessageToolCall>,
) -> Self {
Self {
role: "assistant".to_string(),
content,
multimodal_content: None,
tool_calls: Some(tool_calls),
tool_call_id: None,
}
}
pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
Self {
role: "tool".to_string(),
content: Some(content.into()),
multimodal_content: None,
tool_calls: None,
tool_call_id: Some(tool_call_id.into()),
}
}
pub fn has_images(&self) -> bool {
self.multimodal_content
.as_ref()
.map(|parts| {
parts
.iter()
.any(|p| matches!(p, ContentPart::ImageUrl { .. }))
})
.unwrap_or(false)
}
pub fn effective_content(&self) -> MessageContent {
if let Some(parts) = &self.multimodal_content {
MessageContent::Parts(parts.clone())
} else if let Some(text) = &self.content {
MessageContent::Text(text.clone())
} else {
MessageContent::Text(String::new())
}
}
}
/// Body of a chat-completion call (serialize-only).
///
/// `None` options are omitted from the JSON. There is no model field here; presumably the
/// provider fills in the model itself — confirm.
#[derive(Debug, Clone, Serialize)]
pub struct ChatRequest {
    /// Conversation history, in order.
    pub messages: Vec<ChatMessage>,
    /// Upper bound on generated tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Tools the model may call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ChatTool>>,
    /// Tool-selection policy.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}
/// Parsed chat-completion response (deserialize-only).
///
/// NOTE(review): `id` is a required field, so a response without an `"id"` key fails to
/// deserialize; confirm every backend returns one.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatResponse {
    /// Response identifier.
    pub id: String,
    /// Candidate completions.
    pub choices: Vec<ChatChoice>,
    /// Token accounting, when reported.
    pub usage: Option<ChatUsage>,
}
/// One candidate completion within a `ChatResponse`.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatChoice {
    /// Position of this choice in the response.
    pub index: u32,
    /// The generated message (may carry `tool_calls`).
    pub message: ChatMessage,
    /// Why generation stopped, when reported.
    pub finish_reason: Option<String>,
}
/// Token counts reported for a chat completion.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatUsage {
    /// Tokens in the prompt.
    pub prompt_tokens: u32,
    /// Tokens generated.
    pub completion_tokens: u32,
    /// Total as reported by the backend (not recomputed here).
    pub total_tokens: u32,
}
/// Request for embedding a single input string (serialize-only).
#[derive(Debug, Clone, Serialize)]
pub struct EmbeddingRequest {
    /// Text to embed.
    pub input: String,
    /// Model override; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
}
/// Parsed embedding response (deserialize-only).
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingResponse {
    /// One entry per embedded input.
    pub data: Vec<EmbeddingData>,
    /// Model that produced the embeddings.
    pub model: String,
    /// Token accounting, when reported.
    pub usage: Option<EmbeddingUsage>,
}
/// A single embedding vector.
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingData {
    /// The vector components.
    pub embedding: Vec<f32>,
    /// Position of the corresponding input.
    pub index: u32,
}
/// Token counts reported for an embedding call.
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingUsage {
    /// Tokens in the input.
    pub prompt_tokens: u32,
    /// Total as reported by the backend.
    pub total_tokens: u32,
}
/// Request for image generation; every `None` option is omitted from the JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageGenerationRequest {
    /// Text description of the desired image.
    pub prompt: String,
    /// Model override.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Number of images requested.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    /// Size string, passed through verbatim (format is provider-defined).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub size: Option<String>,
    /// Quality setting, passed through verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub quality: Option<String>,
    /// Style setting, passed through verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub style: Option<String>,
    /// Requested output encoding; presumably selects `url` vs `b64_json` in `ImageData` — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<String>,
    /// End-user identifier forwarded to the provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}
/// Result of an image-generation call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageGenerationResponse {
    /// Creation timestamp as reported by the provider (presumably Unix seconds — confirm).
    pub created: u64,
    /// Generated images.
    pub data: Vec<ImageData>,
}
/// One generated image, delivered either by URL or inline as base64.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageData {
    /// Location of the image, when delivered by URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub url: Option<String>,
    /// Base64-encoded image, when delivered inline.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub b64_json: Option<String>,
    /// Prompt as rewritten by the provider, if it reports one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub revised_prompt: Option<String>,
}
/// Request for audio transcription.
///
/// Not serde-serializable. It carries raw file bytes, presumably sent by the provider as a
/// multipart upload — confirm.
#[derive(Debug, Clone)]
pub struct AudioTranscriptionRequest {
    /// Raw audio file contents.
    pub file: Vec<u8>,
    /// File name to report alongside the bytes.
    pub filename: String,
    /// Model override.
    pub model: Option<String>,
    /// Language hint.
    pub language: Option<String>,
    /// Optional prompt to guide transcription.
    pub prompt: Option<String>,
    /// Requested output format, passed through verbatim.
    pub response_format: Option<String>,
    /// Sampling temperature.
    pub temperature: Option<f32>,
}
/// Result of a transcription; optional detail fields are omitted from JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioTranscriptionResponse {
    /// Full transcript.
    pub text: String,
    /// Task label reported by the provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task: Option<String>,
    /// Detected or requested language.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,
    /// Audio duration (presumably seconds — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration: Option<f64>,
    /// Word-level timings, when requested.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub words: Option<Vec<TranscriptionWord>>,
    /// Segment-level timings, when requested.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub segments: Option<Vec<TranscriptionSegment>>,
}
/// A transcribed word with its time span.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionWord {
    /// The word.
    pub word: String,
    /// Start time (presumably seconds).
    pub start: f64,
    /// End time (presumably seconds).
    pub end: f64,
}
/// A transcribed segment with its time span.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionSegment {
    /// Segment identifier.
    pub id: u32,
    /// Start time (presumably seconds).
    pub start: f64,
    /// End time (presumably seconds).
    pub end: f64,
    /// Segment text.
    pub text: String,
}
/// Text-to-speech request; `None` options are omitted from the JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpeechRequest {
    /// Text to speak.
    pub input: String,
    /// Model override.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Voice name (required), passed through verbatim.
    pub voice: String,
    /// Requested audio format, passed through verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<String>,
    /// Playback speed multiplier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub speed: Option<f32>,
}
/// Synthesized audio returned by a text-to-speech call.
#[derive(Debug, Clone)]
pub struct SpeechResponse {
    /// Raw audio bytes.
    pub audio: Vec<u8>,
    /// Content type of `audio` (presumably a MIME type from the HTTP response — confirm).
    pub content_type: String,
}
/// Video-generation request; every `None` option is omitted from the JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VideoGenerationRequest {
    /// Text description of the desired video.
    pub prompt: String,
    /// Model override.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Clip length (presumably seconds — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration: Option<f32>,
    /// Frame size string, passed through verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub size: Option<String>,
    /// Frames per second.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fps: Option<u32>,
    /// Source image for image-to-video; presumably a URL or base64 string — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image: Option<String>,
    /// Content to steer away from.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub negative_prompt: Option<String>,
    /// Seed for reproducible generation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,
}
/// Result of a video-generation call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VideoGenerationResponse {
    /// Creation timestamp as reported by the provider (presumably Unix seconds — confirm).
    pub created: u64,
    /// Generated videos.
    pub data: Vec<VideoData>,
}
/// One generated video, delivered either by URL or inline as base64.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VideoData {
    /// Location of the video, when delivered by URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub url: Option<String>,
    /// Base64-encoded video, when delivered inline.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub b64_json: Option<String>,
    /// Prompt as rewritten by the provider, if it reports one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub revised_prompt: Option<String>,
}
/// A backend that serves chat completions and, optionally, other generative workloads.
///
/// Implementors must provide `name` and `chat`; every other method has a default.
/// The metadata methods return generic values. Each optional workload (embeddings, image
/// generation, transcription, speech, video) returns an "unsupported" error by default,
/// so a provider overrides only what it actually offers.
///
/// The `Send + Sync` supertraits let a provider be shared across threads.
#[async_trait]
pub trait ModelProvider: Send + Sync {
    /// Identifier of this provider (format is implementation-defined).
    fn name(&self) -> &str;

    /// Model identifier; `"default"` unless overridden.
    fn model(&self) -> &str {
        "default"
    }

    /// Capabilities of the configured model; `ModelCapabilities::default()` unless overridden.
    fn capabilities(&self) -> ModelCapabilities {
        ModelCapabilities::default()
    }

    /// Whether calls need network access; `true` unless overridden.
    fn requires_network(&self) -> bool {
        true
    }

    /// Runs a chat completion. Every provider must implement this.
    async fn chat(&self, request: ChatRequest) -> anyhow::Result<ChatResponse>;

    /// Computes embeddings. Default: always fails as unsupported.
    async fn embed(&self, _request: EmbeddingRequest) -> anyhow::Result<EmbeddingResponse> {
        Err(anyhow::anyhow!("Embeddings not supported by this provider"))
    }

    /// Generates images. Default: always fails as unsupported.
    async fn generate_image(
        &self,
        _request: ImageGenerationRequest,
    ) -> anyhow::Result<ImageGenerationResponse> {
        Err(anyhow::anyhow!("Image generation not supported by this provider"))
    }

    /// Transcribes audio. Default: always fails as unsupported.
    async fn transcribe(
        &self,
        _request: AudioTranscriptionRequest,
    ) -> anyhow::Result<AudioTranscriptionResponse> {
        Err(anyhow::anyhow!("Audio transcription not supported by this provider"))
    }

    /// Synthesizes speech. Default: always fails as unsupported.
    async fn speak(&self, _request: SpeechRequest) -> anyhow::Result<SpeechResponse> {
        Err(anyhow::anyhow!("Text-to-speech not supported by this provider"))
    }

    /// Generates video. Default: always fails as unsupported.
    async fn generate_video(
        &self,
        _request: VideoGenerationRequest,
    ) -> anyhow::Result<VideoGenerationResponse> {
        Err(anyhow::anyhow!("Video generation not supported by this provider"))
    }
}