ai_lib/model/
resolver.rs

1use std::borrow::Cow;
2
3use crate::{client::Provider, transport::TransportError, types::AiLibError};
4
5use super::catalog;
6
/// Stateless helper that chooses which chat model to use for a [`Provider`]
/// and builds recovery hints when a provider rejects a model name.
///
/// Model choices come from the provider profile in the `catalog` module plus
/// an optional `<PREFIX>_MODEL` environment variable, where `<PREFIX>` is
/// `provider.env_prefix()`.
#[derive(Debug, Clone, Default)]
pub struct ModelResolver;
9
/// Outcome of model resolution: the chosen model name, the rule that chose
/// it, and the documentation URL from the provider's catalog profile.
#[derive(Debug, Clone)]
pub struct ModelResolution {
    /// Resolved model name.
    pub model: String,
    /// Which resolution rule produced `model`.
    pub source: ModelResolutionSource,
    /// Documentation link copied from the provider's catalog profile.
    pub doc_url: &'static str,
}
16
17impl ModelResolution {
18    pub fn new(
19        model: impl Into<String>,
20        source: ModelResolutionSource,
21        doc_url: &'static str,
22    ) -> Self {
23        Self {
24            model: model.into(),
25            source,
26            doc_url,
27        }
28    }
29}
30
/// Records which rule produced the model name in a `ModelResolution`.
///
/// The enum is fieldless, so it is `Copy` and comparable; callers can test
/// `resolution.source == ModelResolutionSource::Explicit` directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelResolutionSource {
    /// The caller supplied the model name directly.
    Explicit,
    /// A custom default model. Not produced by `ModelResolver` in this module;
    /// presumably set by callers that carry their own default — TODO confirm.
    CustomDefault,
    /// Taken from the provider's `<PREFIX>_MODEL` environment variable.
    EnvOverride,
    /// The provider profile's default chat model.
    ProviderDefault,
    /// One of the provider profile's fallback models, chosen after a failure.
    ProfileFallback,
}
39
40impl ModelResolver {
41    pub fn new() -> Self {
42        Self
43    }
44
45    pub fn resolve_chat_model(
46        &self,
47        provider: Provider,
48        requested: Option<&str>,
49    ) -> ModelResolution {
50        let profile = catalog::profile(provider);
51        let doc_url = profile.doc_url;
52        if let Some(model) = requested {
53            return ModelResolution::new(
54                model.to_string(),
55                ModelResolutionSource::Explicit,
56                doc_url,
57            );
58        }
59
60        if let Some(env_model) = self.env_override(provider) {
61            return ModelResolution::new(env_model, ModelResolutionSource::EnvOverride, doc_url);
62        }
63
64        ModelResolution::new(
65            profile.default_chat_model(),
66            ModelResolutionSource::ProviderDefault,
67            doc_url,
68        )
69    }
70
71    pub fn fallback_after_invalid(
72        &self,
73        provider: Provider,
74        failed_model: &str,
75    ) -> Option<ModelResolution> {
76        let profile = catalog::profile(provider);
77        let doc_url = profile.doc_url;
78
79        if let Some(env_model) = self.env_override(provider) {
80            if !equals_ignore_case(&env_model, failed_model) {
81                return Some(ModelResolution::new(
82                    env_model,
83                    ModelResolutionSource::EnvOverride,
84                    doc_url,
85                ));
86            }
87        }
88
89        for candidate in profile.fallback_models {
90            if !equals_ignore_case(candidate, failed_model) {
91                return Some(ModelResolution::new(
92                    *candidate,
93                    ModelResolutionSource::ProfileFallback,
94                    doc_url,
95                ));
96            }
97        }
98
99        if !equals_ignore_case(profile.default_chat_model(), failed_model) {
100            return Some(ModelResolution::new(
101                profile.default_chat_model(),
102                ModelResolutionSource::ProviderDefault,
103                doc_url,
104            ));
105        }
106
107        None
108    }
109
110    pub fn doc_url(&self, provider: Provider) -> &'static str {
111        catalog::profile(provider).doc_url
112    }
113
114    pub fn suggestions(&self, provider: Provider) -> Vec<String> {
115        let mut list = Vec::new();
116
117        if let Some(env_override) = self.env_override(provider) {
118            push_unique(&mut list, env_override);
119        }
120
121        let profile = catalog::profile(provider);
122        push_unique(&mut list, profile.default_chat_model().to_string());
123        for candidate in profile.fallback_models {
124            push_unique(&mut list, (*candidate).to_string());
125        }
126
127        list
128    }
129
130    pub fn looks_like_invalid_model(&self, err: &AiLibError) -> bool {
131        match err {
132            AiLibError::ModelNotFound(_) => true,
133            AiLibError::InvalidRequest(msg)
134            | AiLibError::ProviderError(msg)
135            | AiLibError::InvalidModelResponse(msg) => contains_invalid_keyword(msg),
136            AiLibError::TransportError(TransportError::ClientError { status, message })
137            | AiLibError::TransportError(TransportError::ServerError { status, message }) => {
138                (*status == 400 || *status == 404) && contains_invalid_keyword(message)
139            }
140            _ => false,
141        }
142    }
143
144    pub fn decorate_invalid_model_error(
145        &self,
146        provider: Provider,
147        requested_model: &str,
148        err: AiLibError,
149    ) -> AiLibError {
150        let doc_url = self.doc_url(provider);
151        let suggestions = self.suggestions(provider);
152        let provider_name = format!("{provider:?}");
153        let suggestion_text = if suggestions.is_empty() {
154            Cow::Borrowed("no known fallback models configured")
155        } else {
156            Cow::Owned(suggestions.join(", "))
157        };
158
159        AiLibError::ModelNotFound(format!(
160            "Model `{}` is not available for provider {}. Try: {}. Docs: {}. Original error: {}",
161            requested_model, provider_name, suggestion_text, doc_url, err
162        ))
163    }
164
165    fn env_override(&self, provider: Provider) -> Option<String> {
166        let var = format!("{}_MODEL", provider.env_prefix());
167        std::env::var(var)
168            .ok()
169            .filter(|value| !value.trim().is_empty())
170    }
171}
172
173fn push_unique(list: &mut Vec<String>, value: String) {
174    if !list
175        .iter()
176        .any(|existing| equals_ignore_case(existing, &value))
177    {
178        list.push(value);
179    }
180}
181
/// Compares two model names, ignoring ASCII letter case only. Non-ASCII
/// characters must match byte-for-byte.
fn equals_ignore_case(a: &str, b: &str) -> bool {
    <[u8]>::eq_ignore_ascii_case(a.as_bytes(), b.as_bytes())
}
185
/// Reports whether a provider error message contains one of the known
/// phrases signalling an unknown or invalid model (ASCII case-insensitive).
fn contains_invalid_keyword(message: &str) -> bool {
    // `"code":"1500"` is matched verbatim in compact-JSON form. It looks like a
    // provider-specific error code; its meaning is not documented here — TODO confirm.
    const NEEDLES: [&str; 6] = [
        "invalid model",
        "model_not_found",
        "model not found",
        "unknown model",
        "unsupported model",
        "\"code\":\"1500\"",
    ];

    let haystack = message.to_ascii_lowercase();
    NEEDLES.iter().any(|&needle| haystack.contains(needle))
}