Skip to main content

codetether_agent/autochat/
model_rotation.rs

1use crate::provider::{ModelInfo, Provider, ProviderRegistry};
2use serde::{Deserialize, Serialize};
3use std::collections::HashSet;
4use std::sync::Arc;
5
/// Providers probed, in this order, when assembling a round-robin relay rotation.
///
/// Each registered provider contributes at most one model reference to the rotation.
const PRIORITY_PROVIDERS: [&str; 8] = [
    // Primary choices.
    "minimax",
    "zai",
    // Subscription-backed gateways.
    "github-copilot",
    "github-copilot-enterprise",
    "openai-codex",
    // Aggregators and metered fallbacks.
    "openrouter",
    "minimax-credits",
    "openai",
];

/// Providers tried, in this order, when a model reference carries no provider prefix.
///
/// The first eight entries are the same as [`PRIORITY_PROVIDERS`] (keep them in sync);
/// the final four are extra last-resort providers used only for resolution.
const RESOLUTION_FALLBACKS: [&str; 12] = [
    // Same order as PRIORITY_PROVIDERS.
    "minimax",
    "zai",
    "github-copilot",
    "github-copilot-enterprise",
    "openai-codex",
    "openrouter",
    "minimax-credits",
    "openai",
    // Resolution-only extras.
    "anthropic",
    "novita",
    "moonshotai",
    "google",
];
31
32#[derive(Debug, Clone, Default, Serialize, Deserialize)]
33pub struct RelayModelRotation {
34    #[serde(default)]
35    pub model_refs: Vec<String>,
36    #[serde(default)]
37    pub cursor: usize,
38}
39
40impl RelayModelRotation {
41    pub fn fallback(model_ref: &str) -> Self {
42        Self {
43            model_refs: vec![model_ref.to_string()],
44            cursor: 0,
45        }
46    }
47
48    pub fn next_model_ref(&mut self, fallback: &str) -> String {
49        if self.model_refs.is_empty() {
50            return fallback.to_string();
51        }
52        let idx = self.cursor % self.model_refs.len();
53        let next = self.model_refs[idx].clone();
54        self.cursor = self.cursor.saturating_add(1);
55        next
56    }
57}
58
59pub async fn build_round_robin_model_rotation(
60    registry: &Arc<ProviderRegistry>,
61    requested_model_ref: &str,
62) -> RelayModelRotation {
63    let mut rotation = RelayModelRotation::default();
64    let mut used_providers = HashSet::<String>::new();
65
66    if let Some(model_ref) = normalize_requested_model_ref(registry, requested_model_ref) {
67        add_unique_model_ref(&mut rotation.model_refs, &mut used_providers, &model_ref);
68    }
69
70    for provider_name in PRIORITY_PROVIDERS {
71        if let Some(model_ref) = preferred_model_ref_for_provider(registry, provider_name).await {
72            add_unique_model_ref(&mut rotation.model_refs, &mut used_providers, &model_ref);
73        }
74    }
75
76    if rotation.model_refs.is_empty() {
77        rotation.model_refs.push(requested_model_ref.to_string());
78    }
79
80    rotation
81}
82
83pub fn resolve_provider_for_model_autochat(
84    registry: &Arc<ProviderRegistry>,
85    model_ref: &str,
86) -> Option<(Arc<dyn Provider>, String)> {
87    let (provider_name, model_name) = crate::provider::parse_model_string(model_ref);
88    if let Some(provider_name) = provider_name {
89        let normalized = normalize_provider_name(provider_name);
90        return registry
91            .get(normalized)
92            .map(|provider| (provider, model_name.to_string()));
93    }
94
95    for provider_name in RESOLUTION_FALLBACKS {
96        if let Some(provider) = registry.get(provider_name) {
97            return Some((provider, model_ref.to_string()));
98        }
99    }
100
101    registry
102        .list()
103        .first()
104        .copied()
105        .and_then(|name| registry.get(name))
106        .map(|provider| (provider, model_ref.to_string()))
107}
108
109fn normalize_requested_model_ref(
110    registry: &Arc<ProviderRegistry>,
111    requested_model_ref: &str,
112) -> Option<String> {
113    let (provider_name, model_name) = crate::provider::parse_model_string(requested_model_ref);
114    if let Some(provider_name) = provider_name {
115        let normalized = normalize_provider_name(provider_name);
116        if registry.get(normalized).is_some() {
117            return Some(format!("{normalized}/{model_name}"));
118        }
119    }
120
121    resolve_provider_for_model_autochat(registry, requested_model_ref)
122        .map(|(provider, model)| format!("{}/{}", provider.name(), model))
123}
124
125fn add_unique_model_ref(
126    model_refs: &mut Vec<String>,
127    used_providers: &mut HashSet<String>,
128    model_ref: &str,
129) {
130    let (provider_name, _) = crate::provider::parse_model_string(model_ref);
131    let Some(provider_name) = provider_name else {
132        return;
133    };
134    if used_providers.insert(provider_name.to_string()) {
135        model_refs.push(model_ref.to_string());
136    }
137}
138
139async fn preferred_model_ref_for_provider(
140    registry: &Arc<ProviderRegistry>,
141    provider_name: &str,
142) -> Option<String> {
143    let provider = registry.get(provider_name)?;
144    let listed = provider.list_models().await.unwrap_or_default();
145    if let Some(model_id) = choose_model_from_list(provider_name, &listed) {
146        return Some(format!("{provider_name}/{model_id}"));
147    }
148    default_model_for_provider(provider_name).map(|model_id| format!("{provider_name}/{model_id}"))
149}
150
151fn choose_model_from_list(provider_name: &str, models: &[ModelInfo]) -> Option<String> {
152    if models.is_empty() {
153        return None;
154    }
155
156    if provider_name == "openrouter" {
157        return choose_openrouter_model(models).or_else(|| models.first().map(|m| m.id.clone()));
158    }
159
160    if provider_name == "github-copilot" || provider_name == "github-copilot-enterprise" {
161        return choose_copilot_model(models).or_else(|| models.first().map(|m| m.id.clone()));
162    }
163
164    let preferred = preferred_models_for_provider(provider_name);
165    for model_id in preferred {
166        if let Some(found) = models
167            .iter()
168            .find(|model| model.id.eq_ignore_ascii_case(model_id))
169        {
170            return Some(found.id.clone());
171        }
172    }
173
174    models.first().map(|m| m.id.clone())
175}
176
177fn choose_openrouter_model(models: &[ModelInfo]) -> Option<String> {
178    let free_models: Vec<&ModelInfo> = models
179        .iter()
180        .filter(|model| is_openrouter_free_model(&model.id))
181        .collect();
182    if !free_models.is_empty() {
183        return free_models
184            .into_iter()
185            .max_by_key(|model| score_openrouter_model(&model.id))
186            .map(|model| model.id.clone());
187    }
188    models
189        .iter()
190        .max_by_key(|model| score_openrouter_model(&model.id))
191        .map(|model| model.id.clone())
192}
193
194fn choose_copilot_model(models: &[ModelInfo]) -> Option<String> {
195    let preferred = ["gpt-5-mini", "gpt-4.1", "gpt-4o"];
196    for model_id in preferred {
197        if let Some(found) = models
198            .iter()
199            .find(|model| model.id.eq_ignore_ascii_case(model_id))
200        {
201            return Some(found.id.clone());
202        }
203    }
204
205    models
206        .iter()
207        .find(|model| {
208            model.input_cost_per_million == Some(0.0) && model.output_cost_per_million == Some(0.0)
209        })
210        .map(|model| model.id.clone())
211}
212
/// Returns `true` when an OpenRouter model id looks like a free variant.
///
/// The check is ASCII case-insensitive: the id matches if it contains `:free` anywhere
/// or ends with `-free`.
fn is_openrouter_free_model(model_id: &str) -> bool {
    let lowered = model_id.to_ascii_lowercase();
    if lowered.contains(":free") {
        return true;
    }
    lowered.ends_with("-free")
}
217
218fn score_openrouter_model(model_id: &str) -> i32 {
219    let id = model_id.to_ascii_lowercase();
220    let mut score = 0;
221    if is_openrouter_free_model(&id) {
222        score += 1000;
223    }
224    if id.contains("glm-5") {
225        score += 250;
226    }
227    if id.contains("minimax") {
228        score += 220;
229    }
230    if id.contains("gpt-5-mini") {
231        score += 180;
232    }
233    if id.contains("coder") {
234        score += 70;
235    }
236    score
237}
238
/// Ordered model preferences for a provider, most preferred first.
///
/// Matching on `provider_name` is exact (case-sensitive). Unknown providers get an
/// empty list.
fn preferred_models_for_provider(provider_name: &str) -> &'static [&'static str] {
    const MINIMAX: &[&str] = &["MiniMax-M2.5", "MiniMax-M2.1", "MiniMax-M2"];
    const MINIMAX_CREDITS: &[&str] = &["MiniMax-M2.5-highspeed", "MiniMax-M2.1-highspeed"];
    const ZAI: &[&str] = &["glm-5", "glm-4.7", "glm-4.7-flash"];
    const OPENAI_CODEX: &[&str] = &["gpt-5-mini", "gpt-5", "gpt-5.1-codex"];
    const COPILOT: &[&str] = &["gpt-5-mini", "gpt-4.1", "gpt-4o"];
    const OPENROUTER: &[&str] = &["z-ai/glm-5:free", "z-ai/glm-5", "z-ai/glm-4.7:free"];
    const NO_PREFERENCE: &[&str] = &[];

    match provider_name {
        "minimax" => MINIMAX,
        "minimax-credits" => MINIMAX_CREDITS,
        "zai" => ZAI,
        "openai-codex" => OPENAI_CODEX,
        "github-copilot" | "github-copilot-enterprise" => COPILOT,
        "openrouter" => OPENROUTER,
        _ => NO_PREFERENCE,
    }
}
250
251fn default_model_for_provider(provider_name: &str) -> Option<&'static str> {
252    let defaults = preferred_models_for_provider(provider_name);
253    defaults.first().copied()
254}
255
/// Maps a provider name, or one of its known aliases, onto its canonical registry name.
///
/// `zai`, `zhipuai`, and `z-ai` are all matched ASCII case-insensitively and mapped to
/// `"zai"`. Previously only the aliases were case-insensitive, so `ZAI` slipped through
/// un-normalized. Every other name is returned unchanged.
fn normalize_provider_name(provider_name: &str) -> &str {
    const ZAI_ALIASES: [&str; 3] = ["zai", "zhipuai", "z-ai"];

    if ZAI_ALIASES
        .iter()
        .any(|alias| provider_name.eq_ignore_ascii_case(alias))
    {
        "zai"
    } else {
        provider_name
    }
}
263
#[cfg(test)]
mod tests {
    use super::{
        RelayModelRotation, choose_copilot_model, choose_openrouter_model, normalize_provider_name,
    };
    use crate::provider::ModelInfo;

    /// Zero-cost, tool-capable fixture whose `id` and `name` are both `id`.
    fn model(id: &str) -> ModelInfo {
        ModelInfo {
            id: id.to_string(),
            name: id.to_string(),
            provider: "test".to_string(),
            context_window: 128_000,
            max_output_tokens: Some(16_384),
            supports_vision: false,
            supports_tools: true,
            supports_streaming: true,
            input_cost_per_million: Some(0.0),
            output_cost_per_million: Some(0.0),
        }
    }

    #[test]
    fn relay_rotation_cycles_models() {
        let mut rotation = RelayModelRotation {
            model_refs: vec!["a/x".into(), "b/y".into()],
            cursor: 0,
        };
        let picks: Vec<String> = (0..3)
            .map(|_| rotation.next_model_ref("fallback"))
            .collect();
        assert_eq!(picks, ["a/x", "b/y", "a/x"]);
    }

    #[test]
    fn provider_alias_normalizes_zhipu_and_z_ai() {
        let cases = [
            ("zhipuai", "zai"),
            ("z-ai", "zai"),
            ("openrouter", "openrouter"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_provider_name(input), expected, "input: {input}");
        }
    }

    #[test]
    fn openrouter_prefers_free_glm_models() {
        let models = ["openai/gpt-4.1", "z-ai/glm-5:free", "moonshot/kimi-k2:free"].map(model);
        assert_eq!(
            choose_openrouter_model(&models).as_deref(),
            Some("z-ai/glm-5:free")
        );
    }

    #[test]
    fn copilot_prefers_gpt_5_mini_when_available() {
        let models = ["gpt-4o", "gpt-5-mini", "claude-sonnet-4"].map(model);
        assert_eq!(choose_copilot_model(&models).as_deref(), Some("gpt-5-mini"));
    }
}