rustic_ai/providers/
mod.rs1use std::sync::Arc;
2
3use thiserror::Error;
4
5use crate::model::{Model, ModelSettings};
6
7pub mod anthropic;
8pub mod gemini;
9pub mod grok;
10pub mod openai;
11
12pub trait Provider: Send + Sync {
13 fn name(&self) -> &str;
14 fn model(&self, model: &str, settings: Option<ModelSettings>) -> Arc<dyn Model>;
15}
16
17#[derive(Debug, Error)]
18pub enum ProviderError {
19 #[error("unknown provider: {0}")]
20 UnknownProvider(String),
21 #[error("missing API key for provider: {0}")]
22 MissingApiKey(String),
23 #[error("invalid model string: {0}")]
24 InvalidModel(String),
25}
26
27pub fn infer_provider(name: &str) -> Result<Box<dyn Provider>, ProviderError> {
28 match name {
29 "openai" => openai::OpenAIProvider::from_env().map(|p| Box::new(p) as Box<dyn Provider>),
30 "grok" => grok::GrokProvider::from_env().map(|p| Box::new(p) as Box<dyn Provider>),
31 "anthropic" => {
32 anthropic::AnthropicProvider::from_env().map(|p| Box::new(p) as Box<dyn Provider>)
33 }
34 "gemini" => gemini::GeminiProvider::from_env().map(|p| Box::new(p) as Box<dyn Provider>),
35 other => Err(ProviderError::UnknownProvider(other.to_string())),
36 }
37}
38
39pub fn infer_model(
40 model: impl AsRef<str>,
41 provider_factory: impl Fn(&str) -> Result<Box<dyn Provider>, ProviderError>,
42) -> Result<Arc<dyn Model>, ProviderError> {
43 let model = model.as_ref();
44 let (provider_name, model_name) = match model.split_once(':') {
45 Some((provider, name)) => (provider, name),
46 None => (infer_provider_from_model(model)?, model),
47 };
48
49 let provider = provider_factory(provider_name)?;
50 Ok(provider.model(model_name, None))
51}
52
53fn infer_provider_from_model(model: &str) -> Result<&'static str, ProviderError> {
54 let lowered = model.to_lowercase();
55 if lowered.starts_with("gpt") || lowered.starts_with("o1") || lowered.starts_with("o3") {
56 return Ok("openai");
57 }
58 if lowered.starts_with("claude") {
59 return Ok("anthropic");
60 }
61 if lowered.starts_with("gemini") {
62 return Ok("gemini");
63 }
64 if lowered.starts_with("grok") {
65 return Ok("grok");
66 }
67 Err(ProviderError::InvalidModel(model.to_string()))
68}