Skip to main content

adk_model/
provider.rs

1use std::fmt::{Display, Formatter};
2use std::str::FromStr;
3
/// Identifies a model provider supported by ADK.
///
/// This is the canonical provider identifier; the metadata attached to it
/// (machine names, default models, key variables) is shared across ADK crates.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ModelProvider {
    /// Google's Gemini model family.
    Gemini,
    /// OpenAI's models (GPT and the o-series).
    Openai,
    /// Anthropic's Claude models.
    Anthropic,
    /// DeepSeek's models.
    Deepseek,
    /// Models served through Groq's ultra-fast inference.
    Groq,
    /// Locally hosted models run through Ollama.
    Ollama,
}
20
21impl ModelProvider {
22    /// All providers in UI/display order.
23    pub const ALL: [Self; 6] =
24        [Self::Gemini, Self::Openai, Self::Anthropic, Self::Deepseek, Self::Groq, Self::Ollama];
25
26    /// Return all providers in a stable order.
27    pub const fn all() -> &'static [Self] {
28        &Self::ALL
29    }
30
31    /// Machine identifier used in CLIs and configs.
32    pub const fn as_str(self) -> &'static str {
33        match self {
34            Self::Gemini => "gemini",
35            Self::Openai => "openai",
36            Self::Anthropic => "anthropic",
37            Self::Deepseek => "deepseek",
38            Self::Groq => "groq",
39            Self::Ollama => "ollama",
40        }
41    }
42
43    /// Default model for the provider.
44    pub const fn default_model(self) -> &'static str {
45        match self {
46            Self::Gemini => "gemini-3.1-flash-lite-preview",
47            Self::Openai => "gpt-5-mini",
48            Self::Anthropic => "claude-sonnet-4-5-20250929",
49            Self::Deepseek => "deepseek-v4-flash",
50            Self::Groq => "llama-3.3-70b-versatile",
51            Self::Ollama => "qwen3.5",
52        }
53    }
54
55    /// Primary environment variable used for the provider API key.
56    pub const fn env_var(self) -> &'static str {
57        match self {
58            Self::Gemini => "GOOGLE_API_KEY",
59            Self::Openai => "OPENAI_API_KEY",
60            Self::Anthropic => "ANTHROPIC_API_KEY",
61            Self::Deepseek => "DEEPSEEK_API_KEY",
62            Self::Groq => "GROQ_API_KEY",
63            Self::Ollama => "",
64        }
65    }
66
67    /// Alternate environment variable used for the provider API key.
68    pub const fn alt_env_var(self) -> Option<&'static str> {
69        match self {
70            Self::Gemini => Some("GEMINI_API_KEY"),
71            _ => None,
72        }
73    }
74
75    /// Whether the provider requires an API key.
76    pub const fn requires_key(self) -> bool {
77        !matches!(self, Self::Ollama)
78    }
79
80    /// Display name for interactive prompts and help text.
81    pub const fn display_name(self) -> &'static str {
82        match self {
83            Self::Gemini => "Gemini (Google)",
84            Self::Openai => "OpenAI",
85            Self::Anthropic => "Anthropic (Claude)",
86            Self::Deepseek => "DeepSeek",
87            Self::Groq => "Groq",
88            Self::Ollama => "Ollama (local, no key needed)",
89        }
90    }
91}
92
93impl Display for ModelProvider {
94    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
95        f.write_str(self.as_str())
96    }
97}
98
99impl FromStr for ModelProvider {
100    type Err = String;
101
102    fn from_str(value: &str) -> Result<Self, Self::Err> {
103        match value {
104            "gemini" => Ok(Self::Gemini),
105            "openai" => Ok(Self::Openai),
106            "anthropic" => Ok(Self::Anthropic),
107            "deepseek" => Ok(Self::Deepseek),
108            "groq" => Ok(Self::Groq),
109            "ollama" => Ok(Self::Ollama),
110            other => Err(format!("unsupported provider: {other}")),
111        }
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::ModelProvider;
    use std::collections::HashSet;
    use std::str::FromStr;

    #[test]
    fn provider_roundtrips_from_machine_name() {
        for provider in ModelProvider::all() {
            let parsed = ModelProvider::from_str(provider.as_str()).expect("provider should parse");
            assert_eq!(*provider, parsed);
        }
    }

    #[test]
    fn display_matches_machine_name() {
        for provider in ModelProvider::all() {
            assert_eq!(provider.to_string(), provider.as_str());
        }
    }

    #[test]
    fn unknown_or_miscased_names_are_rejected() {
        // Parsing is exact: no case folding and no whitespace trimming.
        for input in ["", "Gemini", "OPENAI", " groq", "ollama ", "mistral"] {
            let err = ModelProvider::from_str(input).expect_err("input should not parse");
            assert_eq!(err, format!("unsupported provider: {input}"));
        }
    }

    #[test]
    fn all_lists_each_provider_once() {
        let unique: HashSet<ModelProvider> = ModelProvider::all().iter().copied().collect();
        assert_eq!(unique.len(), ModelProvider::all().len());
    }

    #[test]
    fn key_requirement_matches_env_var() {
        for provider in ModelProvider::all() {
            assert_eq!(provider.requires_key(), !provider.env_var().is_empty());
            // An alternate key variable only makes sense for a provider that takes a key.
            if provider.alt_env_var().is_some() {
                assert!(provider.requires_key());
            }
        }
    }
}
127}