use std::sync::Arc;
use async_trait::async_trait;
use adk_core::Llm;
use crate::types::{ModelConfig, ModelRef, Provider};
#[derive(Debug, thiserror::Error)]
pub enum ResolverError {
#[error(
"cannot infer provider from model name \"{name}\". Expected prefix: gemini, gpt, claude, llama, mistral, or deepseek"
)]
UnknownProvider { name: String },
#[error("failed to construct model for provider {provider:?}: {reason}")]
ConstructionFailed { provider: Provider, reason: String },
}
pub type ResolverResult<T> = std::result::Result<T, ResolverError>;
/// Turns a [`ModelRef`] taken from configuration into a shared LLM handle.
///
/// The trait requires `Send + Sync`, so one resolver instance can be shared across
/// threads and tasks. `resolve` is async (through `async_trait`), which lets
/// implementations await while building the model.
#[async_trait]
pub trait ModelResolver: Send + Sync {
    /// Resolves `model_ref` into a reference-counted [`Llm`] instance.
    ///
    /// # Errors
    ///
    /// Returns a [`ResolverError`] when the provider cannot be determined or the
    /// model cannot be constructed.
    async fn resolve(&self, model_ref: &ModelRef) -> ResolverResult<Arc<dyn Llm>>;
}
/// Infers which [`Provider`] serves a shorthand model name.
///
/// Matching is a case-insensitive prefix test:
///
/// | prefix                         | provider                 |
/// |--------------------------------|--------------------------|
/// | `gemini`                       | [`Provider::Gemini`]     |
/// | `gpt`                          | [`Provider::Openai`]     |
/// | `claude`                       | [`Provider::Anthropic`]  |
/// | `llama`, `mistral`, `deepseek` | [`Provider::Ollama`]     |
///
/// # Errors
///
/// Returns [`ResolverError::UnknownProvider`] when no prefix matches. The error
/// carries `name` exactly as given, not lowercased.
pub fn infer_provider(name: &str) -> ResolverResult<Provider> {
    // Ordered lookup table: the first matching prefix wins. Keep this list in sync
    // with the prefixes named in the `UnknownProvider` error message.
    const PREFIX_TABLE: [(&str, Provider); 6] = [
        ("gemini", Provider::Gemini),
        ("gpt", Provider::Openai),
        ("claude", Provider::Anthropic),
        ("llama", Provider::Ollama),
        ("mistral", Provider::Ollama),
        ("deepseek", Provider::Ollama),
    ];

    let normalized = name.to_lowercase();
    PREFIX_TABLE
        .iter()
        .find(|(prefix, _)| normalized.starts_with(*prefix))
        .map(|&(_, provider)| provider)
        .ok_or_else(|| ResolverError::UnknownProvider {
            name: name.to_string(),
        })
}
/// Placeholder [`ModelResolver`] that identifies a model's provider but never
/// builds a real model.
///
/// `resolve` always fails:
/// - a shorthand name with no known prefix yields [`ResolverError::UnknownProvider`];
/// - every other reference yields [`ResolverError::ConstructionFailed`].
///
/// Use a platform-provided resolver with credentials to obtain working models.
#[derive(Debug, Clone, Default)]
pub struct DefaultModelResolver;

impl DefaultModelResolver {
    /// Creates a new resolver.
    ///
    /// This is a `const fn`, so it can initialise `const` and `static` items.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
#[async_trait]
impl ModelResolver for DefaultModelResolver {
    /// Determines the provider for `model_ref` and then always fails, because this
    /// resolver cannot build real clients.
    ///
    /// # Errors
    ///
    /// * [`ResolverError::UnknownProvider`]: a shorthand name matched no known
    ///   provider prefix (see [`infer_provider`]).
    /// * [`ResolverError::ConstructionFailed`]: every other case. The reason names
    ///   the provider and the model, plus the base URL for OpenAI-compatible
    ///   configs. The API key is not included.
    async fn resolve(&self, model_ref: &ModelRef) -> ResolverResult<Arc<dyn Llm>> {
        match model_ref {
            ModelRef::Shorthand(name) => {
                let provider = infer_provider(name)?;
                Err(ResolverError::ConstructionFailed {
                    provider,
                    reason: format!(
                        "DefaultModelResolver cannot construct real models. \
                         Use a platform-provided resolver with credentials. \
                         Resolved provider: {provider:?}, model: {name}"
                    ),
                })
            }
            // The remaining config fields (including the API key) are skipped with
            // `..` so they never end up in the error text.
            ModelRef::Structured {
                provider,
                model: ModelConfig::Compatible { model, base_url, .. },
                ..
            } => Err(ResolverError::ConstructionFailed {
                provider: *provider,
                reason: format!(
                    "DefaultModelResolver cannot construct OpenAI-compatible \
                     client. Model: {model}, base_url: {base_url}. \
                     Use a platform-provided resolver with credentials."
                ),
            }),
            // Borrow the name directly; no owned copy is needed just to format it.
            ModelRef::Structured {
                provider,
                model: ModelConfig::Name(name),
                ..
            } => Err(ResolverError::ConstructionFailed {
                provider: *provider,
                reason: format!(
                    "DefaultModelResolver cannot construct real models. \
                     Use a platform-provided resolver with credentials. \
                     Provider: {provider:?}, model: {name}"
                ),
            }),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a fresh `DefaultModelResolver` on `model_ref` and returns the error it produces.
    async fn resolve_err(model_ref: ModelRef) -> ResolverError {
        DefaultModelResolver::new()
            .resolve(&model_ref)
            .await
            .err()
            .expect("expected an error")
    }

    /// Unpacks a `ConstructionFailed` error into `(provider, reason)`, panicking otherwise.
    fn expect_construction_failed(err: ResolverError) -> (Provider, String) {
        match err {
            ResolverError::ConstructionFailed { provider, reason } => (provider, reason),
            e => panic!("expected ConstructionFailed, got: {e}"),
        }
    }

    #[test]
    fn test_infer_gemini_from_shorthand() {
        assert_eq!(infer_provider("gemini-2.5-flash").unwrap(), Provider::Gemini);
        assert_eq!(infer_provider("gemini-2.5-pro").unwrap(), Provider::Gemini);
        assert_eq!(infer_provider("gemini-3.1-flash-lite-preview").unwrap(), Provider::Gemini);
    }

    #[test]
    fn test_infer_openai_from_shorthand() {
        assert_eq!(infer_provider("gpt-4.1").unwrap(), Provider::Openai);
        assert_eq!(infer_provider("gpt-4o").unwrap(), Provider::Openai);
        assert_eq!(infer_provider("gpt-4.1-mini").unwrap(), Provider::Openai);
    }

    #[test]
    fn test_infer_anthropic_from_shorthand() {
        assert_eq!(infer_provider("claude-3.5-sonnet").unwrap(), Provider::Anthropic);
        assert_eq!(infer_provider("claude-4-opus").unwrap(), Provider::Anthropic);
    }

    #[test]
    fn test_infer_ollama_from_llama() {
        assert_eq!(infer_provider("llama-3.2-70b").unwrap(), Provider::Ollama);
    }

    #[test]
    fn test_infer_ollama_from_mistral() {
        assert_eq!(infer_provider("mistral-7b").unwrap(), Provider::Ollama);
        assert_eq!(infer_provider("mistral-large").unwrap(), Provider::Ollama);
    }

    #[test]
    fn test_infer_ollama_from_deepseek() {
        assert_eq!(infer_provider("deepseek-chat").unwrap(), Provider::Ollama);
        assert_eq!(infer_provider("deepseek-coder").unwrap(), Provider::Ollama);
    }

    #[test]
    fn test_infer_unknown_returns_error() {
        match infer_provider("some-random-model") {
            Err(ResolverError::UnknownProvider { name }) => {
                assert_eq!(name, "some-random-model");
            }
            other => panic!("expected UnknownProvider error, got: {other:?}"),
        }
    }

    #[test]
    fn test_infer_empty_name_is_unknown() {
        assert!(matches!(
            infer_provider(""),
            Err(ResolverError::UnknownProvider { name }) if name.is_empty()
        ));
    }

    #[test]
    fn test_unknown_provider_keeps_original_casing() {
        // Matching lowercases internally, but the error must report the caller's input verbatim.
        let err = infer_provider("MyModel").unwrap_err();
        assert!(err.to_string().contains("\"MyModel\""));
    }

    #[test]
    fn test_infer_case_insensitive() {
        assert_eq!(infer_provider("Gemini-2.5-flash").unwrap(), Provider::Gemini);
        assert_eq!(infer_provider("GPT-4.1").unwrap(), Provider::Openai);
        assert_eq!(infer_provider("Claude-3.5-sonnet").unwrap(), Provider::Anthropic);
        assert_eq!(infer_provider("LLAMA-3.2").unwrap(), Provider::Ollama);
        assert_eq!(infer_provider("DeepSeek-V3").unwrap(), Provider::Ollama);
    }

    #[tokio::test]
    async fn test_resolver_shorthand_gemini_infers_provider() {
        let err = resolve_err(ModelRef::Shorthand("gemini-2.5-flash".to_string())).await;
        let (provider, reason) = expect_construction_failed(err);
        assert_eq!(provider, Provider::Gemini);
        assert!(reason.contains("gemini-2.5-flash"));
    }

    #[tokio::test]
    async fn test_resolver_shorthand_openai_infers_provider() {
        let err = resolve_err(ModelRef::Shorthand("gpt-4.1".to_string())).await;
        let (provider, _) = expect_construction_failed(err);
        assert_eq!(provider, Provider::Openai);
    }

    #[tokio::test]
    async fn test_resolver_shorthand_anthropic_infers_provider() {
        let err = resolve_err(ModelRef::Shorthand("claude-3.5-sonnet".to_string())).await;
        let (provider, _) = expect_construction_failed(err);
        assert_eq!(provider, Provider::Anthropic);
    }

    #[tokio::test]
    async fn test_resolver_shorthand_unknown_returns_unknown_provider() {
        let err = resolve_err(ModelRef::Shorthand("totally-unknown-model".to_string())).await;
        match err {
            ResolverError::UnknownProvider { name } => {
                assert_eq!(name, "totally-unknown-model");
            }
            e => panic!("expected UnknownProvider, got: {e}"),
        }
    }

    #[tokio::test]
    async fn test_resolver_structured_uses_provider_field() {
        let err = resolve_err(ModelRef::Structured {
            provider: Provider::Anthropic,
            model: ModelConfig::Name("claude-3.5-sonnet".to_string()),
            speed: None,
        })
        .await;
        let (provider, reason) = expect_construction_failed(err);
        assert_eq!(provider, Provider::Anthropic);
        assert!(reason.contains("claude-3.5-sonnet"));
    }

    #[tokio::test]
    async fn test_resolver_structured_openai_compatible() {
        let err = resolve_err(ModelRef::Structured {
            provider: Provider::OpenaiCompatible,
            model: ModelConfig::Compatible {
                model: "deepseek-chat".to_string(),
                base_url: "https://api.deepseek.com/v1".to_string(),
                api_key: "sk-test-key".to_string(),
            },
            speed: None,
        })
        .await;
        let (provider, reason) = expect_construction_failed(err);
        assert_eq!(provider, Provider::OpenaiCompatible);
        assert!(reason.contains("deepseek-chat"));
        assert!(reason.contains("https://api.deepseek.com/v1"));
        // Credentials must never appear in error messages (which commonly end up in logs).
        assert!(
            !reason.contains("sk-test-key"),
            "API key leaked into error reason: {reason}"
        );
    }

    #[tokio::test]
    async fn test_resolver_structured_with_speed_hint() {
        let err = resolve_err(ModelRef::Structured {
            provider: Provider::Gemini,
            model: ModelConfig::Name("gemini-2.5-flash".to_string()),
            speed: Some("fast".to_string()),
        })
        .await;
        let (provider, _) = expect_construction_failed(err);
        assert_eq!(provider, Provider::Gemini);
    }
}