//! Known OpenAI-compatible provider registry.
//!
//! Maps provider names to their base URLs and default settings.
/// Information about a known OpenAI-compatible provider.
///
/// Entries are static metadata consumed by the registry lookups below;
/// no network access or validation happens here.
#[derive(Debug, Clone)]
pub struct ProviderInfo {
    /// Display name for the provider (human-facing, e.g. "Groq").
    pub name: &'static str,
    /// Base URL for the API endpoint.
    /// NOTE(review): this should be the API root only — the client is
    /// presumably responsible for appending paths like `/chat/completions`.
    pub base_url: &'static str,
    /// Default context window size (presumably in tokens — TODO confirm
    /// against the consumers of this field).
    pub context_limit: i32,
    /// Default model to use if none specified.
    pub default_model: &'static str,
    /// Environment variable for API key.
    /// For local providers this may instead be a host override variable.
    pub env_var: &'static str,
    /// Environment variable for model override.
    pub model_env_var: &'static str,
    /// Whether this provider requires an API key.
    /// False for local providers like Ollama and LM Studio.
    pub requires_api_key: bool,
}
24
25/// Known OpenAI-compatible providers.
26///
27/// These providers use the OpenAI API format but with different base URLs.
28/// Add a provider here to enable using it by name in the config (e.g., `provider: groq`).
29pub const KNOWN_PROVIDERS: &[(&str, ProviderInfo)] = &[
30    (
31        "groq",
32        ProviderInfo {
33            name: "Groq",
34            base_url: "https://api.groq.com/openai/v1",
35            context_limit: 131_072,
36            default_model: "llama-3.3-70b-versatile",
37            env_var: "GROQ_API_KEY",
38            model_env_var: "GROQ_MODEL",
39            requires_api_key: true,
40        },
41    ),
42    (
43        "together",
44        ProviderInfo {
45            name: "Together AI",
46            base_url: "https://api.together.xyz/v1",
47            context_limit: 131_072,
48            default_model: "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
49            env_var: "TOGETHER_API_KEY",
50            model_env_var: "TOGETHER_MODEL",
51            requires_api_key: true,
52        },
53    ),
54    (
55        "fireworks",
56        ProviderInfo {
57            name: "Fireworks AI",
58            base_url: "https://api.fireworks.ai/inference/v1",
59            context_limit: 131_072,
60            default_model: "accounts/fireworks/models/llama-v3p1-70b-instruct",
61            env_var: "FIREWORKS_API_KEY",
62            model_env_var: "FIREWORKS_MODEL",
63            requires_api_key: true,
64        },
65    ),
66    (
67        "mistral",
68        ProviderInfo {
69            name: "Mistral AI",
70            base_url: "https://api.mistral.ai/v1",
71            context_limit: 128_000,
72            default_model: "mistral-large-latest",
73            env_var: "MISTRAL_API_KEY",
74            model_env_var: "MISTRAL_MODEL",
75            requires_api_key: true,
76        },
77    ),
78    (
79        "perplexity",
80        ProviderInfo {
81            name: "Perplexity",
82            base_url: "https://api.perplexity.ai/chat/completions",
83            context_limit: 128_000,
84            default_model: "llama-3.1-sonar-large-128k-online",
85            env_var: "PERPLEXITY_API_KEY",
86            model_env_var: "PERPLEXITY_MODEL",
87            requires_api_key: true,
88        },
89    ),
90    (
91        "deepseek",
92        ProviderInfo {
93            name: "DeepSeek",
94            base_url: "https://api.deepseek.com/v1",
95            context_limit: 64_000,
96            default_model: "deepseek-chat",
97            env_var: "DEEPSEEK_API_KEY",
98            model_env_var: "DEEPSEEK_MODEL",
99            requires_api_key: true,
100        },
101    ),
102    (
103        "openrouter",
104        ProviderInfo {
105            name: "OpenRouter",
106            base_url: "https://openrouter.ai/api/v1",
107            context_limit: 200_000,
108            default_model: "anthropic/claude-3.5-sonnet",
109            env_var: "OPENROUTER_API_KEY",
110            model_env_var: "OPENROUTER_MODEL",
111            requires_api_key: true,
112        },
113    ),
114    (
115        "ollama",
116        ProviderInfo {
117            name: "Ollama",
118            base_url: "http://localhost:11434/v1",
119            context_limit: 128_000,
120            default_model: "llama3.1",
121            env_var: "OLLAMA_HOST", // Not an API key - just signals to enable Ollama
122            model_env_var: "OLLAMA_MODEL",
123            requires_api_key: false,
124        },
125    ),
126    (
127        "lmstudio",
128        ProviderInfo {
129            name: "LM Studio",
130            base_url: "http://localhost:1234/v1",
131            context_limit: 128_000,
132            default_model: "local-model",
133            env_var: "LMSTUDIO_HOST", // Not an API key - just signals to enable LM Studio
134            model_env_var: "LMSTUDIO_MODEL",
135            requires_api_key: false,
136        },
137    ),
138    (
139        "anyscale",
140        ProviderInfo {
141            name: "Anyscale",
142            base_url: "https://api.endpoints.anyscale.com/v1",
143            context_limit: 128_000,
144            default_model: "meta-llama/Meta-Llama-3.1-70B-Instruct",
145            env_var: "ANYSCALE_API_KEY",
146            model_env_var: "ANYSCALE_MODEL",
147            requires_api_key: true,
148        },
149    ),
150    (
151        "cerebras",
152        ProviderInfo {
153            name: "Cerebras",
154            base_url: "https://api.cerebras.ai/v1",
155            context_limit: 128_000,
156            default_model: "llama3.1-70b",
157            env_var: "CEREBRAS_API_KEY",
158            model_env_var: "CEREBRAS_MODEL",
159            requires_api_key: true,
160        },
161    ),
162    (
163        "sambanova",
164        ProviderInfo {
165            name: "SambaNova",
166            base_url: "https://api.sambanova.ai/v1",
167            context_limit: 128_000,
168            default_model: "Meta-Llama-3.1-70B-Instruct",
169            env_var: "SAMBANOVA_API_KEY",
170            model_env_var: "SAMBANOVA_MODEL",
171            requires_api_key: true,
172        },
173    ),
174    (
175        "xai",
176        ProviderInfo {
177            name: "xAI",
178            base_url: "https://api.x.ai/v1",
179            context_limit: 131_072,
180            default_model: "grok-2-latest",
181            env_var: "XAI_API_KEY",
182            model_env_var: "XAI_MODEL",
183            requires_api_key: true,
184        },
185    ),
186    (
187        "ai21",
188        ProviderInfo {
189            name: "AI21 Labs",
190            base_url: "https://api.ai21.com/studio/v1",
191            context_limit: 256_000,
192            default_model: "jamba-1.5-large",
193            env_var: "AI21_API_KEY",
194            model_env_var: "AI21_MODEL",
195            requires_api_key: true,
196        },
197    ),
198];
199
200/// Returns provider info for a known provider name.
201///
202/// Provider names are case-insensitive.
203pub fn get_provider_info(name: &str) -> Option<&'static ProviderInfo> {
204    let name_lower = name.to_lowercase();
205    KNOWN_PROVIDERS
206        .iter()
207        .find(|(key, _)| *key == name_lower)
208        .map(|(_, info)| info)
209}
210
211/// Returns true if this is a known OpenAI-compatible provider.
212pub fn is_known_provider(name: &str) -> bool {
213    get_provider_info(name).is_some()
214}
215
216/// Returns a list of all known provider names.
217pub fn list_providers() -> Vec<&'static str> {
218    KNOWN_PROVIDERS.iter().map(|(name, _)| *name).collect()
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    // Lookup by exact lowercase key returns the matching entry.
    #[test]
    fn known_provider_lookup_returns_groq_info() {
        let info = get_provider_info("groq").expect("groq should be registered");
        assert_eq!(info.name, "Groq");
        assert!(info.base_url.contains("groq.com"));
    }

    // Any capitalization of a registered name resolves.
    #[test]
    fn lookup_ignores_case() {
        for variant in ["GROQ", "Groq", "groq"] {
            assert!(get_provider_info(variant).is_some(), "{variant} should resolve");
        }
    }

    // Unregistered names produce no entry.
    #[test]
    fn unknown_name_yields_none() {
        assert!(get_provider_info("unknown").is_none());
    }

    // The listing exposes the registry keys.
    #[test]
    fn listing_includes_core_providers() {
        let providers = list_providers();
        for expected in ["groq", "together", "fireworks"] {
            assert!(providers.contains(&expected), "missing provider {expected}");
        }
    }
}