1use std::sync::Arc;
13
14use async_trait::async_trait;
15
16use adk_core::Llm;
17
18use crate::types::{ModelConfig, ModelRef, Provider};
19
20#[derive(Debug, thiserror::Error)]
22pub enum ResolverError {
23 #[error(
25 "cannot infer provider from model name \"{name}\". Expected prefix: gemini, gpt, claude, llama, mistral, or deepseek"
26 )]
27 UnknownProvider { name: String },
28
29 #[error("failed to construct model for provider {provider:?}: {reason}")]
31 ConstructionFailed { provider: Provider, reason: String },
32}
33
34pub type ResolverResult<T> = std::result::Result<T, ResolverError>;
36
37#[async_trait]
54pub trait ModelResolver: Send + Sync {
55 async fn resolve(&self, model_ref: &ModelRef) -> ResolverResult<Arc<dyn Llm>>;
57}
58
59pub fn infer_provider(name: &str) -> ResolverResult<Provider> {
74 let lower = name.to_lowercase();
75 if lower.starts_with("gemini") {
76 Ok(Provider::Gemini)
77 } else if lower.starts_with("gpt") {
78 Ok(Provider::Openai)
79 } else if lower.starts_with("claude") {
80 Ok(Provider::Anthropic)
81 } else if lower.starts_with("llama")
82 || lower.starts_with("mistral")
83 || lower.starts_with("deepseek")
84 {
85 Ok(Provider::Ollama)
86 } else {
87 Err(ResolverError::UnknownProvider { name: name.to_string() })
88 }
89}
90
/// A credential-free [`ModelResolver`] that classifies model references but never builds a model.
///
/// Every call to `resolve` fails:
///
/// * [`ResolverError::UnknownProvider`] when a shorthand name has no recognised prefix;
/// * [`ResolverError::ConstructionFailed`] otherwise, with a reason naming the provider and model.
///
/// Platforms that hold API keys are expected to supply their own resolver instead.
#[derive(Debug, Clone, Default)]
pub struct DefaultModelResolver;

impl DefaultModelResolver {
    /// Creates a resolver.
    ///
    /// Equivalent to [`Default::default`], but also usable in `const` and `static` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
120
121#[async_trait]
122impl ModelResolver for DefaultModelResolver {
123 async fn resolve(&self, model_ref: &ModelRef) -> ResolverResult<Arc<dyn Llm>> {
124 match model_ref {
125 ModelRef::Shorthand(name) => {
126 let provider = infer_provider(name)?;
127 Err(ResolverError::ConstructionFailed {
132 provider,
133 reason: format!(
134 "DefaultModelResolver cannot construct real models. \
135 Use a platform-provided resolver with credentials. \
136 Resolved provider: {provider:?}, model: {name}"
137 ),
138 })
139 }
140 ModelRef::Structured { provider, model, .. } => {
141 let model_name = match model {
142 ModelConfig::Name(name) => name.clone(),
143 ModelConfig::Compatible { model, base_url, .. } => {
144 return Err(ResolverError::ConstructionFailed {
148 provider: *provider,
149 reason: format!(
150 "DefaultModelResolver cannot construct OpenAI-compatible \
151 client. Model: {model}, base_url: {base_url}. \
152 Use a platform-provided resolver with credentials."
153 ),
154 });
155 }
156 };
157
158 Err(ResolverError::ConstructionFailed {
159 provider: *provider,
160 reason: format!(
161 "DefaultModelResolver cannot construct real models. \
162 Use a platform-provided resolver with credentials. \
163 Provider: {provider:?}, model: {model_name}"
164 ),
165 })
166 }
167 }
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a fresh `DefaultModelResolver` on `model_ref` and returns the error it yields.
    async fn resolve_err(model_ref: ModelRef) -> ResolverError {
        DefaultModelResolver::new()
            .resolve(&model_ref)
            .await
            .err()
            .expect("expected an error")
    }

    /// Unwraps a `ConstructionFailed` error into its `(provider, reason)` fields.
    fn construction_failed(err: ResolverError) -> (Provider, String) {
        match err {
            ResolverError::ConstructionFailed { provider, reason } => (provider, reason),
            e => panic!("expected ConstructionFailed, got: {e}"),
        }
    }

    #[test]
    fn test_infer_gemini_from_shorthand() {
        assert_eq!(infer_provider("gemini-2.5-flash").unwrap(), Provider::Gemini);
        assert_eq!(infer_provider("gemini-2.5-pro").unwrap(), Provider::Gemini);
        assert_eq!(infer_provider("gemini-3.1-flash-lite-preview").unwrap(), Provider::Gemini);
    }

    #[test]
    fn test_infer_openai_from_shorthand() {
        assert_eq!(infer_provider("gpt-4.1").unwrap(), Provider::Openai);
        assert_eq!(infer_provider("gpt-4o").unwrap(), Provider::Openai);
        assert_eq!(infer_provider("gpt-4.1-mini").unwrap(), Provider::Openai);
    }

    #[test]
    fn test_infer_anthropic_from_shorthand() {
        assert_eq!(infer_provider("claude-3.5-sonnet").unwrap(), Provider::Anthropic);
        assert_eq!(infer_provider("claude-4-opus").unwrap(), Provider::Anthropic);
    }

    #[test]
    fn test_infer_ollama_from_llama() {
        assert_eq!(infer_provider("llama-3.2-70b").unwrap(), Provider::Ollama);
    }

    #[test]
    fn test_infer_ollama_from_mistral() {
        assert_eq!(infer_provider("mistral-7b").unwrap(), Provider::Ollama);
        assert_eq!(infer_provider("mistral-large").unwrap(), Provider::Ollama);
    }

    #[test]
    fn test_infer_ollama_from_deepseek() {
        assert_eq!(infer_provider("deepseek-chat").unwrap(), Provider::Ollama);
        assert_eq!(infer_provider("deepseek-coder").unwrap(), Provider::Ollama);
    }

    #[test]
    fn test_infer_unknown_returns_error() {
        let err = infer_provider("some-random-model").unwrap_err();
        // The user-facing message should quote the offending name.
        assert!(err.to_string().contains("\"some-random-model\""));
        match err {
            ResolverError::UnknownProvider { name } => assert_eq!(name, "some-random-model"),
            e => panic!("expected UnknownProvider error, got: {e}"),
        }
    }

    #[test]
    fn test_infer_empty_name_returns_error() {
        match infer_provider("") {
            Err(ResolverError::UnknownProvider { name }) => assert!(name.is_empty()),
            other => panic!("expected UnknownProvider error, got: {other:?}"),
        }
    }

    #[test]
    fn test_infer_case_insensitive() {
        assert_eq!(infer_provider("Gemini-2.5-flash").unwrap(), Provider::Gemini);
        assert_eq!(infer_provider("GPT-4.1").unwrap(), Provider::Openai);
        assert_eq!(infer_provider("Claude-3.5-sonnet").unwrap(), Provider::Anthropic);
        assert_eq!(infer_provider("LLAMA-3.2").unwrap(), Provider::Ollama);
        assert_eq!(infer_provider("DeepSeek-V3").unwrap(), Provider::Ollama);
    }

    #[tokio::test]
    async fn test_resolver_shorthand_gemini_infers_provider() {
        let err = resolve_err(ModelRef::Shorthand("gemini-2.5-flash".to_string())).await;
        let (provider, reason) = construction_failed(err);
        assert_eq!(provider, Provider::Gemini);
        assert!(reason.contains("gemini-2.5-flash"));
    }

    #[tokio::test]
    async fn test_resolver_shorthand_openai_infers_provider() {
        let err = resolve_err(ModelRef::Shorthand("gpt-4.1".to_string())).await;
        let (provider, _) = construction_failed(err);
        assert_eq!(provider, Provider::Openai);
    }

    #[tokio::test]
    async fn test_resolver_shorthand_anthropic_infers_provider() {
        let err = resolve_err(ModelRef::Shorthand("claude-3.5-sonnet".to_string())).await;
        let (provider, _) = construction_failed(err);
        assert_eq!(provider, Provider::Anthropic);
    }

    #[tokio::test]
    async fn test_resolver_shorthand_unknown_returns_unknown_provider() {
        let err = resolve_err(ModelRef::Shorthand("totally-unknown-model".to_string())).await;
        match err {
            ResolverError::UnknownProvider { name } => assert_eq!(name, "totally-unknown-model"),
            e => panic!("expected UnknownProvider, got: {e}"),
        }
    }

    #[tokio::test]
    async fn test_resolver_structured_uses_provider_field() {
        let model_ref = ModelRef::Structured {
            provider: Provider::Anthropic,
            model: ModelConfig::Name("claude-3.5-sonnet".to_string()),
            speed: None,
        };
        let (provider, reason) = construction_failed(resolve_err(model_ref).await);
        assert_eq!(provider, Provider::Anthropic);
        assert!(reason.contains("claude-3.5-sonnet"));
    }

    #[tokio::test]
    async fn test_resolver_structured_provider_overrides_name_prefix() {
        // A structured reference must not run name-based inference: the explicit
        // provider wins even when the model name looks like another provider's.
        let model_ref = ModelRef::Structured {
            provider: Provider::Ollama,
            model: ModelConfig::Name("gpt-4o".to_string()),
            speed: None,
        };
        let (provider, reason) = construction_failed(resolve_err(model_ref).await);
        assert_eq!(provider, Provider::Ollama);
        assert!(reason.contains("gpt-4o"));
    }

    #[tokio::test]
    async fn test_resolver_structured_openai_compatible() {
        let model_ref = ModelRef::Structured {
            provider: Provider::OpenaiCompatible,
            model: ModelConfig::Compatible {
                model: "deepseek-chat".to_string(),
                base_url: "https://api.deepseek.com/v1".to_string(),
                api_key: "sk-test-key".to_string(),
            },
            speed: None,
        };
        let (provider, reason) = construction_failed(resolve_err(model_ref).await);
        assert_eq!(provider, Provider::OpenaiCompatible);
        assert!(reason.contains("deepseek-chat"));
        assert!(reason.contains("https://api.deepseek.com/v1"));
        // Error messages are likely to be logged or shown to users; the credential
        // must never be echoed back in them.
        assert!(!reason.contains("sk-test-key"), "API key leaked into error: {reason}");
    }

    #[tokio::test]
    async fn test_resolver_structured_with_speed_hint() {
        let model_ref = ModelRef::Structured {
            provider: Provider::Gemini,
            model: ModelConfig::Name("gemini-2.5-flash".to_string()),
            speed: Some("fast".to_string()),
        };
        let (provider, _) = construction_failed(resolve_err(model_ref).await);
        assert_eq!(provider, Provider::Gemini);
    }
}