1use async_trait::async_trait;
54use serde::{Deserialize, Serialize};
55use serde_json::Value;
56
/// Static description of what a model can do and what it costs.
///
/// Providers return this from [`ModelProvider::capabilities`] so callers can
/// route requests (e.g. skip vision-less models for image inputs) without
/// probing the API.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCapabilities {
    /// Maximum context window, in tokens.
    pub max_tokens: u32,
    /// Maximum tokens the model may emit in a single response.
    pub max_output_tokens: u32,
    pub supports_streaming: bool,
    pub supports_tools: bool,
    pub supports_reasoning: bool,
    pub supports_vision: bool,
    pub supports_json_mode: bool,
    pub supports_embeddings: bool,
    pub supports_image_generation: bool,
    pub supports_audio_transcription: bool,
    pub supports_speech: bool,
    pub supports_video_generation: bool,
    /// Whether the provider is cleared for personally identifiable data.
    pub pii_safe: bool,
    /// USD per 1M input tokens, when pricing is known.
    pub cost_per_1m_input: Option<f64>,
    /// USD per 1M output tokens, when pricing is known.
    pub cost_per_1m_output: Option<f64>,
    /// USD per 1M generated pixels (image/video models), when known.
    pub cost_per_1m_pixels: Option<f64>,
}
101
impl Default for ModelCapabilities {
    /// Conservative baseline: a 4K-context, streaming, text-only chat model
    /// with no tool/vision/modality support and unknown pricing.
    fn default() -> Self {
        Self {
            max_tokens: 4096,
            max_output_tokens: 4096,
            supports_streaming: true,
            supports_tools: false,
            supports_reasoning: false,
            supports_vision: false,
            supports_json_mode: false,
            supports_embeddings: false,
            supports_image_generation: false,
            supports_audio_transcription: false,
            supports_speech: false,
            supports_video_generation: false,
            pii_safe: false,
            cost_per_1m_input: None,
            cost_per_1m_output: None,
            cost_per_1m_pixels: None,
        }
    }
}
124
125impl ModelCapabilities {
126 pub fn gpt4() -> Self {
128 Self {
129 max_tokens: 128_000,
130 max_output_tokens: 4096,
131 supports_streaming: true,
132 supports_tools: true,
133 supports_reasoning: false,
134 supports_vision: true,
135 supports_json_mode: true,
136 supports_embeddings: true,
137 supports_image_generation: false,
138 supports_audio_transcription: false,
139 supports_speech: false,
140 supports_video_generation: false,
141 pii_safe: false,
142 cost_per_1m_input: Some(0.03),
143 cost_per_1m_output: Some(0.06),
144 cost_per_1m_pixels: None,
145 }
146 }
147
148 pub fn claude3_opus() -> Self {
150 Self {
151 max_tokens: 200_000,
152 max_output_tokens: 4096,
153 supports_streaming: true,
154 supports_tools: true,
155 supports_reasoning: false,
156 supports_vision: true,
157 supports_json_mode: true,
158 supports_embeddings: true,
159 supports_image_generation: false,
160 supports_audio_transcription: false,
161 supports_speech: false,
162 supports_video_generation: false,
163 pii_safe: false,
164 cost_per_1m_input: Some(0.015),
165 cost_per_1m_output: Some(0.075),
166 cost_per_1m_pixels: None,
167 }
168 }
169
170 pub fn gemini_pro() -> Self {
172 Self {
173 max_tokens: 1_000_000,
174 max_output_tokens: 8192,
175 supports_streaming: true,
176 supports_tools: true,
177 supports_reasoning: false,
178 supports_vision: true,
179 supports_json_mode: true,
180 supports_embeddings: true,
181 supports_image_generation: false,
182 supports_audio_transcription: false,
183 supports_speech: false,
184 supports_video_generation: false,
185 pii_safe: false,
186 cost_per_1m_input: Some(0.00125),
187 cost_per_1m_output: Some(0.005),
188 cost_per_1m_pixels: None,
189 }
190 }
191}
192
/// A tool the model may invoke, in OpenAI wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatTool {
    /// Serialized as `"type"`; currently always `"function"` in OpenAI's API.
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: ChatToolFunction,
}
204
/// Function declaration attached to a [`ChatTool`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatToolFunction {
    pub name: String,
    pub description: String,
    /// JSON Schema describing the function's arguments.
    pub parameters: Value,
}
212
/// How the model should pick tools: either a mode string
/// (e.g. `"auto"`, `"none"`) or a specific named function.
/// `untagged` lets serde serialize whichever variant is present directly.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    String(String),
    Specific {
        #[serde(rename = "type")]
        choice_type: String,
        function: ToolChoiceFunction,
    },
}
224
/// Named function referenced by [`ToolChoice::Specific`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceFunction {
    pub name: String,
}
229
/// A tool call the assistant emitted in a message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageToolCall {
    /// Provider-assigned id, echoed back via [`ChatMessage::tool_result`].
    pub id: String,
    #[serde(rename = "type")]
    pub call_type: String,
    pub function: MessageToolCallFunction,
}
238
/// Function invocation details inside a [`MessageToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageToolCallFunction {
    pub name: String,
    /// Raw JSON-encoded argument string, as returned by the provider.
    pub arguments: String,
}
244
/// Image reference inside a multimodal message: an http(s) URL or a
/// `data:` URL carrying base64-encoded bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrlContent {
    pub url: String,
    /// Optional detail hint (e.g. `"low"`/`"high"` in OpenAI's API);
    /// omitted from the wire when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
258
/// One part of a multimodal message body. Tagged on the wire as
/// `{"type": "text", ...}` / `{"type": "image_url", ...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    Text { text: String },
    ImageUrl { image_url: ImageUrlContent },
}
268
269impl ContentPart {
270 pub fn text(text: impl Into<String>) -> Self {
272 ContentPart::Text { text: text.into() }
273 }
274
275 pub fn image_url(url: impl Into<String>) -> Self {
277 ContentPart::ImageUrl {
278 image_url: ImageUrlContent {
279 url: url.into(),
280 detail: None,
281 },
282 }
283 }
284
285 pub fn image_base64(base64_data: impl Into<String>, mime_type: impl Into<String>) -> Self {
287 let data_url = format!("data:{};base64,{}", mime_type.into(), base64_data.into());
288 ContentPart::ImageUrl {
289 image_url: ImageUrlContent {
290 url: data_url,
291 detail: None,
292 },
293 }
294 }
295}
296
/// Message body: either a bare string or a list of multimodal parts.
/// `untagged` matches the OpenAI wire format, where `content` may be
/// either a string or an array.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
    Text(String),
    Parts(Vec<ContentPart>),
}
306
307impl MessageContent {
308 pub fn has_images(&self) -> bool {
310 match self {
311 MessageContent::Text(_) => false,
312 MessageContent::Parts(parts) => parts
313 .iter()
314 .any(|p| matches!(p, ContentPart::ImageUrl { .. })),
315 }
316 }
317
318 pub fn as_text(&self) -> String {
320 match self {
321 MessageContent::Text(s) => s.clone(),
322 MessageContent::Parts(parts) => parts
323 .iter()
324 .filter_map(|p| match p {
325 ContentPart::Text { text } => Some(text.as_str()),
326 _ => None,
327 })
328 .collect::<Vec<_>>()
329 .join("\n"),
330 }
331 }
332}
333
334impl From<String> for MessageContent {
335 fn from(s: String) -> Self {
336 MessageContent::Text(s)
337 }
338}
339
340impl From<&str> for MessageContent {
341 fn from(s: &str) -> Self {
342 MessageContent::Text(s.to_string())
343 }
344}
345
/// One message in a chat conversation, OpenAI-style.
///
/// Exactly one of `content` / `multimodal_content` is normally set;
/// [`ChatMessage::effective_content`] resolves whichever is present.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// `"system"`, `"user"`, `"assistant"`, or `"tool"`.
    pub role: String,
    /// Plain-text body; omitted from the wire when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Multimodal body (text + images) used instead of `content`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub multimodal_content: Option<Vec<ContentPart>>,
    /// Tool calls emitted by an assistant message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<MessageToolCall>>,
    /// For `"tool"` messages: the id of the call this result answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
361
362impl ChatMessage {
363 pub fn system(content: impl Into<String>) -> Self {
364 Self {
365 role: "system".to_string(),
366 content: Some(content.into()),
367 multimodal_content: None,
368 tool_calls: None,
369 tool_call_id: None,
370 }
371 }
372
373 pub fn user(content: impl Into<String>) -> Self {
374 Self {
375 role: "user".to_string(),
376 content: Some(content.into()),
377 multimodal_content: None,
378 tool_calls: None,
379 tool_call_id: None,
380 }
381 }
382
383 pub fn user_with_images<S: Into<String>>(
396 text: S,
397 images: Vec<(Vec<u8>, String)>, ) -> Self {
399 use base64::Engine;
400 let mut parts = vec![ContentPart::text(text)];
401
402 for (data, mime_type) in images {
403 let b64 = base64::engine::general_purpose::STANDARD.encode(&data);
404 parts.push(ContentPart::image_base64(b64, mime_type));
405 }
406
407 Self {
408 role: "user".to_string(),
409 content: None, multimodal_content: Some(parts),
411 tool_calls: None,
412 tool_call_id: None,
413 }
414 }
415
416 pub fn assistant(content: impl Into<String>) -> Self {
417 Self {
418 role: "assistant".to_string(),
419 content: Some(content.into()),
420 multimodal_content: None,
421 tool_calls: None,
422 tool_call_id: None,
423 }
424 }
425
426 pub fn assistant_with_tool_calls(
428 content: Option<String>,
429 tool_calls: Vec<MessageToolCall>,
430 ) -> Self {
431 Self {
432 role: "assistant".to_string(),
433 content,
434 multimodal_content: None,
435 tool_calls: Some(tool_calls),
436 tool_call_id: None,
437 }
438 }
439
440 pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
442 Self {
443 role: "tool".to_string(),
444 content: Some(content.into()),
445 multimodal_content: None,
446 tool_calls: None,
447 tool_call_id: Some(tool_call_id.into()),
448 }
449 }
450
451 pub fn has_images(&self) -> bool {
453 self.multimodal_content
454 .as_ref()
455 .map(|parts| {
456 parts
457 .iter()
458 .any(|p| matches!(p, ContentPart::ImageUrl { .. }))
459 })
460 .unwrap_or(false)
461 }
462
463 pub fn effective_content(&self) -> MessageContent {
466 if let Some(parts) = &self.multimodal_content {
467 MessageContent::Parts(parts.clone())
468 } else if let Some(text) = &self.content {
469 MessageContent::Text(text.clone())
470 } else {
471 MessageContent::Text(String::new())
472 }
473 }
474}
475
/// Provider-agnostic chat completion request. Optional fields are omitted
/// from serialization when unset so each provider applies its own defaults.
#[derive(Debug, Clone, Serialize)]
pub struct ChatRequest {
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ChatTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}
489
/// Chat completion response, OpenAI-style.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatResponse {
    pub id: String,
    pub choices: Vec<ChatChoice>,
    /// Token accounting; not all providers return it.
    pub usage: Option<ChatUsage>,
}
497
/// One candidate completion inside a [`ChatResponse`].
#[derive(Debug, Clone, Deserialize)]
pub struct ChatChoice {
    pub index: u32,
    pub message: ChatMessage,
    /// e.g. `"stop"`, `"length"`, `"tool_calls"`; provider-dependent.
    pub finish_reason: Option<String>,
}
504
/// Token usage for a chat completion.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
511
/// Request for a single-input embedding.
#[derive(Debug, Clone, Serialize)]
pub struct EmbeddingRequest {
    pub input: String,
    /// Override the provider's default embedding model when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
}
525
/// Embedding response; `data` may contain multiple vectors for batch inputs.
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingResponse {
    pub data: Vec<EmbeddingData>,
    /// Model that actually served the request.
    pub model: String,
    pub usage: Option<EmbeddingUsage>,
}
536
/// A single embedding vector and its position in the input batch.
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingData {
    pub embedding: Vec<f32>,
    pub index: u32,
}
544
/// Token usage for an embedding request (no completion tokens).
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingUsage {
    pub prompt_tokens: u32,
    pub total_tokens: u32,
}
552
/// Image generation request, DALL-E-style. All optional fields are omitted
/// from serialization when unset so the provider's defaults apply.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageGenerationRequest {
    pub prompt: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Number of images to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    /// Dimensions string, e.g. `"1024x1024"` — presumably; verify per provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub size: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub quality: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub style: Option<String>,
    /// `"url"` or `"b64_json"` — controls which [`ImageData`] field is filled.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<String>,
    /// End-user identifier forwarded to the provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}
584
/// Image generation response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageGenerationResponse {
    /// Creation timestamp (Unix seconds, per OpenAI convention — confirm).
    pub created: u64,
    pub data: Vec<ImageData>,
}
593
/// One generated image: either a URL or inline base64, per the request's
/// `response_format`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageData {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub url: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub b64_json: Option<String>,
    /// The prompt as rewritten by the model, when the provider reports it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub revised_prompt: Option<String>,
}
607
/// Audio transcription request. Not serde-derived: the audio bytes are sent
/// as multipart form data rather than JSON.
#[derive(Debug, Clone)]
pub struct AudioTranscriptionRequest {
    /// Raw audio file contents.
    pub file: Vec<u8>,
    /// Original filename (the extension hints the audio format).
    pub filename: String,
    pub model: Option<String>,
    /// ISO language hint for the transcriber.
    pub language: Option<String>,
    /// Optional priming text to guide the transcription.
    pub prompt: Option<String>,
    pub response_format: Option<String>,
    pub temperature: Option<f32>,
}
630
/// Transcription result. Only `text` is guaranteed; the rest appear when the
/// provider returns verbose output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioTranscriptionResponse {
    pub text: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,
    /// Audio duration in seconds, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration: Option<f64>,
    /// Word-level timestamps, when requested/supported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub words: Option<Vec<TranscriptionWord>>,
    /// Segment-level breakdown, when requested/supported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub segments: Option<Vec<TranscriptionSegment>>,
}
652
/// A single word with start/end timestamps (seconds).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionWord {
    pub word: String,
    pub start: f64,
    pub end: f64,
}
663
/// A contiguous transcript segment with start/end timestamps (seconds).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionSegment {
    pub id: u32,
    pub start: f64,
    pub end: f64,
    pub text: String,
}
676
/// Text-to-speech request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpeechRequest {
    /// Text to synthesize.
    pub input: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Voice identifier; required (no provider default assumed here).
    pub voice: String,
    /// Output audio format, e.g. `"mp3"` — presumably; verify per provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<String>,
    /// Playback speed multiplier, when supported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub speed: Option<f32>,
}
698
/// Synthesized speech. Not serde-derived: the payload is raw audio bytes,
/// not JSON.
#[derive(Debug, Clone)]
pub struct SpeechResponse {
    pub audio: Vec<u8>,
    /// MIME type of `audio`, e.g. as reported by the provider's response.
    pub content_type: String,
}
707
/// Video generation request. Optional fields are omitted from serialization
/// when unset so the provider's defaults apply.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VideoGenerationRequest {
    pub prompt: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Clip duration in seconds, when supported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub size: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fps: Option<u32>,
    /// Optional conditioning image (image-to-video) — format is
    /// provider-defined (URL or base64); confirm against the caller.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub negative_prompt: Option<String>,
    /// RNG seed for reproducible generations, when supported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,
}
739
/// Video generation response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VideoGenerationResponse {
    /// Creation timestamp (Unix seconds, per the image API's convention — confirm).
    pub created: u64,
    pub data: Vec<VideoData>,
}
748
/// One generated video: either a URL or inline base64.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VideoData {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub url: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub b64_json: Option<String>,
    /// The prompt as rewritten by the model, when the provider reports it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub revised_prompt: Option<String>,
}
762
/// Uniform async interface over model backends.
///
/// Only [`chat`](ModelProvider::chat) is required; every other modality has a
/// default implementation that bails with a "not supported" error, so
/// providers implement just what they support. Callers should consult
/// [`capabilities`](ModelProvider::capabilities) before invoking optional
/// methods.
#[async_trait]
pub trait ModelProvider: Send + Sync {
    /// Human-readable provider name (e.g. for logging/routing).
    fn name(&self) -> &str;

    /// Identifier of the underlying model; defaults to `"default"`.
    fn model(&self) -> &str {
        "default"
    }

    /// Capability flags and pricing; defaults to the conservative baseline.
    fn capabilities(&self) -> ModelCapabilities {
        ModelCapabilities::default()
    }

    /// Whether calls reach out over the network (true for hosted APIs;
    /// local/on-device providers would override to false).
    fn requires_network(&self) -> bool {
        true
    }

    /// Run a chat completion. The one method every provider must implement.
    async fn chat(&self, request: ChatRequest) -> anyhow::Result<ChatResponse>;

    /// Compute embeddings; default bails for providers without support.
    async fn embed(&self, _request: EmbeddingRequest) -> anyhow::Result<EmbeddingResponse> {
        anyhow::bail!("Embeddings not supported by this provider")
    }

    /// Generate images; default bails for providers without support.
    async fn generate_image(
        &self,
        _request: ImageGenerationRequest,
    ) -> anyhow::Result<ImageGenerationResponse> {
        anyhow::bail!("Image generation not supported by this provider")
    }

    /// Transcribe audio; default bails for providers without support.
    async fn transcribe(
        &self,
        _request: AudioTranscriptionRequest,
    ) -> anyhow::Result<AudioTranscriptionResponse> {
        anyhow::bail!("Audio transcription not supported by this provider")
    }

    /// Synthesize speech; default bails for providers without support.
    async fn speak(&self, _request: SpeechRequest) -> anyhow::Result<SpeechResponse> {
        anyhow::bail!("Text-to-speech not supported by this provider")
    }

    /// Generate video; default bails for providers without support.
    async fn generate_video(
        &self,
        _request: VideoGenerationRequest,
    ) -> anyhow::Result<VideoGenerationResponse> {
        anyhow::bail!("Video generation not supported by this provider")
    }
}