Skip to main content

adk_managed/
resolver.rs

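//! Model resolution — maps a [`ModelRef`] to a live `Arc<dyn Llm>`.
//!
//! The [`ModelResolver`] trait abstracts over model construction so that
//! the runtime can resolve any provider declaration into a callable model.
//! [`DefaultModelResolver`] implements the standard resolution logic:
//!
//! - **Shorthand** names are mapped to providers by case-insensitive prefix
//!   (`gemini-*` → Gemini, `gpt-*` → OpenAI, `claude-*` → Anthropic, and
//!   `llama-*` / `mistral-*` / `deepseek-*` → Ollama).
//! - **Structured** refs use the explicit `provider` field directly.
//! - **OpenAI-compatible** refs carry their own `model`, `base_url`, and `api_key`.
//!
//! `DefaultModelResolver` stops after resolution: it reports the resolved
//! provider via [`ResolverError::ConstructionFailed`] rather than building a
//! client. Constructing live models is left to a platform-provided resolver
//! that has access to credentials.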
11
12use std::sync::Arc;
13
14use async_trait::async_trait;
15
16use adk_core::Llm;
17
18use crate::types::{ModelConfig, ModelRef, Provider};
19
20/// Errors that can occur during model resolution.
21#[derive(Debug, thiserror::Error)]
22pub enum ResolverError {
23    /// The model name prefix could not be mapped to a known provider.
24    #[error(
25        "cannot infer provider from model name \"{name}\". Expected prefix: gemini, gpt, claude, llama, mistral, or deepseek"
26    )]
27    UnknownProvider { name: String },
28
29    /// The provider is recognized but model construction failed.
30    #[error("failed to construct model for provider {provider:?}: {reason}")]
31    ConstructionFailed { provider: Provider, reason: String },
32}
33
34/// Result alias for resolver operations.
35pub type ResolverResult<T> = std::result::Result<T, ResolverError>;
36
37/// Resolves a [`ModelRef`] into a live `Arc<dyn Llm>`.
38///
39/// Implementations may construct real provider clients or return pre-built
40/// instances. The trait is async because construction may involve network
41/// calls (e.g., verifying API keys or fetching model metadata).
42///
43/// # Example
44///
45/// ```rust,ignore
46/// use adk_managed::resolver::{ModelResolver, DefaultModelResolver};
47/// use adk_managed::types::ModelRef;
48///
49/// let resolver = DefaultModelResolver::new();
50/// let model_ref = ModelRef::Shorthand("gemini-2.5-flash".to_string());
51/// let llm = resolver.resolve(&model_ref).await?;
52/// ```
53#[async_trait]
54pub trait ModelResolver: Send + Sync {
55    /// Resolve a model reference into a callable LLM instance.
56    async fn resolve(&self, model_ref: &ModelRef) -> ResolverResult<Arc<dyn Llm>>;
57}
58
59/// Infers the [`Provider`] from a shorthand model name by prefix matching.
60///
61/// # Mapping
62///
63/// | Prefix | Provider |
64/// |--------|----------|
65/// | `gemini` | Gemini |
66/// | `gpt` | OpenAI |
67/// | `claude` | Anthropic |
68/// | `llama` | Ollama |
69/// | `mistral` | Ollama |
70/// | `deepseek` | Ollama |
71///
72/// Returns `Err(ResolverError::UnknownProvider)` if no prefix matches.
73pub fn infer_provider(name: &str) -> ResolverResult<Provider> {
74    let lower = name.to_lowercase();
75    if lower.starts_with("gemini") {
76        Ok(Provider::Gemini)
77    } else if lower.starts_with("gpt") {
78        Ok(Provider::Openai)
79    } else if lower.starts_with("claude") {
80        Ok(Provider::Anthropic)
81    } else if lower.starts_with("llama")
82        || lower.starts_with("mistral")
83        || lower.starts_with("deepseek")
84    {
85        Ok(Provider::Ollama)
86    } else {
87        Err(ResolverError::UnknownProvider { name: name.to_string() })
88    }
89}
90
/// Resolver that performs provider *resolution* only, never construction.
///
/// Shorthand names are routed through [`infer_provider`]; structured refs use
/// their explicit `provider` field. Every reference whose provider is resolved
/// still ends in [`ResolverError::ConstructionFailed`], because building a real
/// model needs API keys and network access that this type does not have.
/// Unknown shorthand names surface as [`ResolverError::UnknownProvider`].
///
/// The platform layer is expected to inject its own [`ModelResolver`] that can
/// construct models with credentials.
///
/// # Example
///
/// ```rust,ignore
/// use adk_managed::resolver::DefaultModelResolver;
///
/// // Useful for checking which provider a ref maps to; production code
/// // should use a resolver that has access to credentials.
/// let resolver = DefaultModelResolver::new();
/// ```
#[derive(Debug, Clone, Default)]
pub struct DefaultModelResolver;

impl DefaultModelResolver {
    /// Returns a new `DefaultModelResolver` (equivalent to `Default::default()`).
    pub fn new() -> Self {
        DefaultModelResolver
    }
}
120
121#[async_trait]
122impl ModelResolver for DefaultModelResolver {
123    async fn resolve(&self, model_ref: &ModelRef) -> ResolverResult<Arc<dyn Llm>> {
124        match model_ref {
125            ModelRef::Shorthand(name) => {
126                let provider = infer_provider(name)?;
127                // In a real implementation, this would construct the appropriate
128                // provider client using API keys from the environment or a
129                // credential provider. For now, we return an error indicating
130                // that actual construction is not yet wired up.
131                Err(ResolverError::ConstructionFailed {
132                    provider,
133                    reason: format!(
134                        "DefaultModelResolver cannot construct real models. \
135                         Use a platform-provided resolver with credentials. \
136                         Resolved provider: {provider:?}, model: {name}"
137                    ),
138                })
139            }
140            ModelRef::Structured { provider, model, .. } => {
141                let model_name = match model {
142                    ModelConfig::Name(name) => name.clone(),
143                    ModelConfig::Compatible { model, base_url, .. } => {
144                        // For OpenAI-compatible, we have all we need to construct
145                        // a client (model + base_url + api_key), but actual
146                        // construction is deferred to a credentialed resolver.
147                        return Err(ResolverError::ConstructionFailed {
148                            provider: *provider,
149                            reason: format!(
150                                "DefaultModelResolver cannot construct OpenAI-compatible \
151                                 client. Model: {model}, base_url: {base_url}. \
152                                 Use a platform-provided resolver with credentials."
153                            ),
154                        });
155                    }
156                };
157
158                Err(ResolverError::ConstructionFailed {
159                    provider: *provider,
160                    reason: format!(
161                        "DefaultModelResolver cannot construct real models. \
162                         Use a platform-provided resolver with credentials. \
163                         Provider: {provider:?}, model: {model_name}"
164                    ),
165                })
166            }
167        }
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    // --- infer_provider tests ---

    #[test]
    fn test_infer_gemini_from_shorthand() {
        assert_eq!(infer_provider("gemini-2.5-flash").unwrap(), Provider::Gemini);
        assert_eq!(infer_provider("gemini-2.5-pro").unwrap(), Provider::Gemini);
        assert_eq!(infer_provider("gemini-3.1-flash-lite-preview").unwrap(), Provider::Gemini);
    }

    #[test]
    fn test_infer_openai_from_shorthand() {
        assert_eq!(infer_provider("gpt-4.1").unwrap(), Provider::Openai);
        assert_eq!(infer_provider("gpt-4o").unwrap(), Provider::Openai);
        assert_eq!(infer_provider("gpt-4.1-mini").unwrap(), Provider::Openai);
    }

    #[test]
    fn test_infer_anthropic_from_shorthand() {
        assert_eq!(infer_provider("claude-3.5-sonnet").unwrap(), Provider::Anthropic);
        assert_eq!(infer_provider("claude-4-opus").unwrap(), Provider::Anthropic);
    }

    #[test]
    fn test_infer_ollama_from_llama() {
        assert_eq!(infer_provider("llama-3.2-70b").unwrap(), Provider::Ollama);
    }

    #[test]
    fn test_infer_ollama_from_mistral() {
        assert_eq!(infer_provider("mistral-7b").unwrap(), Provider::Ollama);
        assert_eq!(infer_provider("mistral-large").unwrap(), Provider::Ollama);
    }

    #[test]
    fn test_infer_ollama_from_deepseek() {
        assert_eq!(infer_provider("deepseek-chat").unwrap(), Provider::Ollama);
        assert_eq!(infer_provider("deepseek-coder").unwrap(), Provider::Ollama);
    }

    #[test]
    fn test_infer_unknown_returns_error() {
        let result = infer_provider("some-random-model");
        assert!(result.is_err());
        match result.unwrap_err() {
            ResolverError::UnknownProvider { name } => {
                assert_eq!(name, "some-random-model");
            }
            _ => panic!("expected UnknownProvider error"),
        }
    }

    #[test]
    fn test_infer_empty_name_returns_error() {
        match infer_provider("") {
            Err(ResolverError::UnknownProvider { name }) => assert!(name.is_empty()),
            other => panic!("expected UnknownProvider error, got: {other:?}"),
        }
    }

    #[test]
    fn test_infer_requires_leading_prefix() {
        // A provider name appearing later in the string must not match.
        match infer_provider("my-gemini-finetune") {
            Err(ResolverError::UnknownProvider { name }) => {
                assert_eq!(name, "my-gemini-finetune");
            }
            other => panic!("expected UnknownProvider error, got: {other:?}"),
        }
    }

    #[test]
    fn test_infer_case_insensitive() {
        assert_eq!(infer_provider("Gemini-2.5-flash").unwrap(), Provider::Gemini);
        assert_eq!(infer_provider("GPT-4.1").unwrap(), Provider::Openai);
        assert_eq!(infer_provider("Claude-3.5-sonnet").unwrap(), Provider::Anthropic);
        assert_eq!(infer_provider("LLAMA-3.2").unwrap(), Provider::Ollama);
        assert_eq!(infer_provider("Mistral-Large").unwrap(), Provider::Ollama);
        assert_eq!(infer_provider("DeepSeek-V3").unwrap(), Provider::Ollama);
    }

    // --- DefaultModelResolver tests ---

    #[tokio::test]
    async fn test_resolver_shorthand_gemini_infers_provider() {
        let resolver = DefaultModelResolver::new();
        let model_ref = ModelRef::Shorthand("gemini-2.5-flash".to_string());
        let result = resolver.resolve(&model_ref).await;

        // We expect ConstructionFailed (not UnknownProvider) because the
        // provider was successfully inferred but construction is stubbed.
        let err = result.err().expect("expected an error");
        match err {
            ResolverError::ConstructionFailed { provider, reason } => {
                assert_eq!(provider, Provider::Gemini);
                assert!(reason.contains("gemini-2.5-flash"));
            }
            e => panic!("expected ConstructionFailed, got: {e}"),
        }
    }

    #[tokio::test]
    async fn test_resolver_shorthand_openai_infers_provider() {
        let resolver = DefaultModelResolver::new();
        let model_ref = ModelRef::Shorthand("gpt-4.1".to_string());
        let result = resolver.resolve(&model_ref).await;

        let err = result.err().expect("expected an error");
        match err {
            ResolverError::ConstructionFailed { provider, .. } => {
                assert_eq!(provider, Provider::Openai);
            }
            e => panic!("expected ConstructionFailed, got: {e}"),
        }
    }

    #[tokio::test]
    async fn test_resolver_shorthand_anthropic_infers_provider() {
        let resolver = DefaultModelResolver::new();
        let model_ref = ModelRef::Shorthand("claude-3.5-sonnet".to_string());
        let result = resolver.resolve(&model_ref).await;

        let err = result.err().expect("expected an error");
        match err {
            ResolverError::ConstructionFailed { provider, .. } => {
                assert_eq!(provider, Provider::Anthropic);
            }
            e => panic!("expected ConstructionFailed, got: {e}"),
        }
    }

    #[tokio::test]
    async fn test_resolver_shorthand_unknown_returns_unknown_provider() {
        let resolver = DefaultModelResolver::new();
        let model_ref = ModelRef::Shorthand("totally-unknown-model".to_string());
        let result = resolver.resolve(&model_ref).await;

        let err = result.err().expect("expected an error");
        match err {
            ResolverError::UnknownProvider { name } => {
                assert_eq!(name, "totally-unknown-model");
            }
            e => panic!("expected UnknownProvider, got: {e}"),
        }
    }

    #[tokio::test]
    async fn test_resolver_structured_uses_provider_field() {
        let resolver = DefaultModelResolver::new();
        let model_ref = ModelRef::Structured {
            provider: Provider::Anthropic,
            model: ModelConfig::Name("claude-3.5-sonnet".to_string()),
            speed: None,
        };
        let result = resolver.resolve(&model_ref).await;

        let err = result.err().expect("expected an error");
        match err {
            ResolverError::ConstructionFailed { provider, reason } => {
                assert_eq!(provider, Provider::Anthropic);
                assert!(reason.contains("claude-3.5-sonnet"));
            }
            e => panic!("expected ConstructionFailed, got: {e}"),
        }
    }

    #[tokio::test]
    async fn test_resolver_structured_openai_compatible() {
        let resolver = DefaultModelResolver::new();
        let model_ref = ModelRef::Structured {
            provider: Provider::OpenaiCompatible,
            model: ModelConfig::Compatible {
                model: "deepseek-chat".to_string(),
                base_url: "https://api.deepseek.com/v1".to_string(),
                api_key: "sk-test-key".to_string(),
            },
            speed: None,
        };
        let result = resolver.resolve(&model_ref).await;

        let err = result.err().expect("expected an error");
        // Error messages end up in logs; the credential must never be rendered.
        assert!(
            !err.to_string().contains("sk-test-key"),
            "API key leaked into error message: {err}"
        );
        match err {
            ResolverError::ConstructionFailed { provider, reason } => {
                assert_eq!(provider, Provider::OpenaiCompatible);
                assert!(reason.contains("deepseek-chat"));
                assert!(reason.contains("https://api.deepseek.com/v1"));
                assert!(!reason.contains("sk-test-key"));
            }
            e => panic!("expected ConstructionFailed, got: {e}"),
        }
    }

    #[tokio::test]
    async fn test_resolver_structured_with_speed_hint() {
        let resolver = DefaultModelResolver::new();
        let model_ref = ModelRef::Structured {
            provider: Provider::Gemini,
            model: ModelConfig::Name("gemini-2.5-flash".to_string()),
            speed: Some("fast".to_string()),
        };
        let result = resolver.resolve(&model_ref).await;

        // Speed hint doesn't affect provider resolution
        let err = result.err().expect("expected an error");
        match err {
            ResolverError::ConstructionFailed { provider, .. } => {
                assert_eq!(provider, Provider::Gemini);
            }
            e => panic!("expected ConstructionFailed, got: {e}"),
        }
    }
}