use super::presets::{PROVIDERS, ProviderPreset, provider_short_name};
use crate::config::types::ProviderEntry;
/// Asks a provider's `/models` endpoint which model ids it offers.
///
/// The request goes to `base_url` (trailing slashes trimmed) followed by
/// `/models`. A non-empty `api_key` is sent as a bearer token. The response is
/// expected to be JSON of the form `{ "data": [ { "id": "..." }, ... ] }`.
///
/// Returns `None` when `base_url` is empty, the HTTP client cannot be built,
/// the request fails or exceeds the 5 second timeout, the status is not a
/// success, the body does not parse, the list is empty, or the worker thread
/// panics. Otherwise returns the ids in response order.
///
/// The request runs on a freshly spawned thread that is joined before
/// returning, so the caller still blocks until the result is available.
/// NOTE(review): presumably this keeps `reqwest::blocking` off a thread that
/// belongs to an async runtime — confirm against the callers.
pub(super) fn fetch_models_from_provider(base_url: &str, api_key: &str) -> Option<Vec<String>> {
    /// One element of the `data` array; only the `id` field is read.
    #[derive(serde::Deserialize)]
    struct ModelEntry {
        id: String,
    }

    /// Top-level shape of the `/models` response.
    #[derive(serde::Deserialize)]
    struct ModelList {
        data: Vec<ModelEntry>,
    }

    if base_url.is_empty() {
        return None;
    }

    // Owned copies are required because the worker closure must be `'static`.
    let endpoint = format!("{}/models", base_url.trim_end_matches('/'));
    let token = api_key.to_string();

    let worker = std::thread::spawn(move || -> Option<Vec<String>> {
        let http = reqwest::blocking::Client::builder()
            .timeout(std::time::Duration::from_secs(5))
            .build()
            .ok()?;

        // Only attach credentials when the caller actually supplied some.
        let request = if token.is_empty() {
            http.get(&endpoint)
        } else {
            http.get(&endpoint).bearer_auth(&token)
        };

        let response = request.send().ok()?;
        if !response.status().is_success() {
            return None;
        }

        let listing: ModelList = response.json().ok()?;
        let ids: Vec<String> = listing.data.into_iter().map(|entry| entry.id).collect();
        if ids.is_empty() {
            None
        } else {
            Some(ids)
        }
    });

    // A panicking worker is reported the same way as any other failure.
    worker.join().ok().flatten()
}
/// Model names selected for each role within a single provider.
///
/// Produced by `assign_models_by_role`, which always sets `default` to the
/// same model as `code`.
pub struct RoleModels {
/// Model name used for the architect role.
pub architect: String,
/// Model name used for the code role.
pub code: String,
/// Model name used for the ask role.
pub ask: String,
/// Model name used when no specific role applies.
pub default: String,
}
/// One `(provider, model)` pair together with its score for each role.
///
/// Built by `compute_optimal_assignments`; for every role the candidate with
/// the highest score is selected.
struct ModelCandidate {
/// Name of the provider entry that offers the model.
provider: String,
/// Model name as listed by the provider entry.
model: String,
/// Score for the architect role (higher is better).
architect_score: u8,
/// Score for the code role (higher is better).
code_score: u8,
/// Score for the ask role (higher is better).
ask_score: u8,
}
/// Best model found for each role across all providers.
///
/// Every field (`architect`, `code`, `ask`) is a `(provider name, model name)`
/// pair, in that order, as filled in by `compute_optimal_assignments`.
pub struct OptimalAssignment {
pub architect: (String, String), pub code: (String, String),
pub ask: (String, String),
}
/// Picks the best `(provider, model)` pair for the architect, code and ask
/// roles across all configured providers.
///
/// Entries whose `cli` field is set, and entries that list no models, are
/// skipped. Every remaining model is scored per role:
///
/// * a model equal to the matching preset's model for that role scores 10;
/// * otherwise the score is the model's size tier (1 = small, 2 = medium,
///   3 = large), except that medium models score 5 for the code role and
///   small models score 5 for the ask role.
///
/// The tier is derived from markers in the lowercased model name. Names
/// without any marker are treated as large when the registry lists a model
/// for that provider whose key ends with the model name and whose
/// `max_input_tokens` is at least 128 000, and as medium otherwise.
///
/// Returns `None` when `providers` is empty or yields no candidate at all.
/// When several candidates share the top score for a role, the one that was
/// collected last wins (the behaviour of `Iterator::max_by_key`).
pub fn compute_optimal_assignments(providers: &[ProviderEntry]) -> Option<OptimalAssignment> {
    /// Name fragments that mark a small / fast model (tier 1).
    /// `mini` is handled separately by `has_mini_marker`.
    const SMALL_MARKERS: [&str; 8] = [
        "flash", "haiku", "fast", "instant", "lite", "8b", "7b", "3b",
    ];
    /// Name fragments that mark a large / flagship model (tier 3).
    const LARGE_MARKERS: [&str; 11] = [
        "opus", "ultra", "pro", "large", "plus", "o3", "o1", "70b", "72b", "405b", "reasoner",
    ];

    /// True when `haystack` contains at least one of `markers`.
    fn mentions_any(haystack: &str, markers: &[&str]) -> bool {
        markers.iter().any(|marker| haystack.contains(*marker))
    }

    /// True when `lowered` contains `mini` as a size marker.
    ///
    /// An occurrence directly preceded by a letter is ignored so that the
    /// family name `gemini` does not push every Gemini model into the small
    /// tier. `gpt-4o-mini`, `o3-mini` and `ministral-…` still match.
    fn has_mini_marker(lowered: &str) -> bool {
        lowered.match_indices("mini").any(|(start, _)| {
            !lowered[..start].ends_with(|c: char| c.is_ascii_alphabetic())
        })
    }

    /// Tier implied by the (lowercased) name alone, or `None` when the name
    /// carries no marker. Small markers take precedence over large ones, so
    /// e.g. `o3-mini` is small.
    fn tier_from_name(lowered: &str) -> Option<u8> {
        if has_mini_marker(lowered) || mentions_any(lowered, &SMALL_MARKERS) {
            Some(1)
        } else if mentions_any(lowered, &LARGE_MARKERS) {
            Some(3)
        } else {
            None
        }
    }

    if providers.is_empty() {
        return None;
    }

    let registry = crate::registry::cache::load_or_fetch_blocking();
    let mut candidates: Vec<ModelCandidate> = Vec::new();

    for entry in providers {
        if entry.cli.is_some() {
            continue;
        }

        let preset = PROVIDERS.iter().find(|p| {
            provider_short_name(p.label) == entry.name || p.label.eq_ignore_ascii_case(&entry.name)
        });

        let all_models = entry.all_models();
        if all_models.is_empty() {
            continue;
        }

        // The preset's role models depend only on the provider, so resolve
        // them once per entry. Without a preset nothing can match.
        let (preset_architect, preset_code, preset_ask) = preset.map_or(("", "", ""), |p| {
            (p.architect_model, p.code_model, p.ask_model)
        });

        for model in all_models {
            // The registry is consulted only for names without any marker.
            let tier = tier_from_name(&model.to_lowercase()).unwrap_or_else(|| {
                let has_long_context = registry
                    .models_for_provider(&entry.name)
                    .into_iter()
                    .any(|r| r.key.ends_with(model) && r.max_input_tokens.unwrap_or(0) >= 128_000);
                if has_long_context {
                    3
                } else {
                    2
                }
            });

            let architect_score = if model == preset_architect { 10 } else { tier };
            // Medium models are preferred for coding over large ones.
            let code_score = if model == preset_code {
                10
            } else if tier == 2 {
                5
            } else {
                tier
            };
            // Small models are preferred for quick questions.
            let ask_score = if model == preset_ask {
                10
            } else if tier == 1 {
                5
            } else {
                tier
            };

            candidates.push(ModelCandidate {
                provider: entry.name.clone(),
                model: model.to_string(),
                architect_score,
                code_score,
                ask_score,
            });
        }
    }

    // `max_by_key` yields `None` for an empty list, which `?` below turns
    // into the function's own `None`; no panic path is needed.
    let pick = |score_fn: fn(&ModelCandidate) -> u8| -> Option<(String, String)> {
        candidates
            .iter()
            .max_by_key(|c| score_fn(c))
            .map(|c| (c.provider.clone(), c.model.clone()))
    };

    Some(OptimalAssignment {
        architect: pick(|c| c.architect_score)?,
        code: pick(|c| c.code_score)?,
        ask: pick(|c| c.ask_score)?,
    })
}
/// Scoring inputs for one model, blended with role-specific coefficients.
///
/// The coefficients of each role sum to 1.0, so a role score stays within the
/// range spanned by the four inputs.
struct ModelWeights {
    /// How much context the model can take in.
    context_score: f64,
    /// How well the model supports tool use.
    tool_score: f64,
    /// How cheap the model is; higher means cheaper.
    cost_efficiency: f64,
    /// Size tier guessed from the model name; higher means larger.
    name_tier: f64,
}

impl ModelWeights {
    /// Weighted sum of the four inputs, added strictly left to right in the
    /// order context, tool, cost, name tier.
    fn blend(&self, w_context: f64, w_tool: f64, w_cost: f64, w_tier: f64) -> f64 {
        self.context_score * w_context
            + self.tool_score * w_tool
            + self.cost_efficiency * w_cost
            + self.name_tier * w_tier
    }

    /// Score for the architect role: context and tool support dominate.
    fn architect_score(&self) -> f64 {
        self.blend(0.40, 0.35, 0.10, 0.15)
    }

    /// Score for the code role: tool support carries half of the weight.
    fn code_score(&self) -> f64 {
        self.blend(0.20, 0.50, 0.15, 0.15)
    }

    /// Score for the ask role: cost efficiency carries half of the weight.
    fn ask_score(&self) -> f64 {
        self.blend(0.10, 0.10, 0.50, 0.30)
    }
}
/// Derives the four scoring inputs for `model`, each within `0.0..=1.0` for
/// non-negative costs.
///
/// Registry data for (`provider_name`, `model`) is used where available and
/// the model name is used as a fallback:
///
/// * `context_score`: `log2(max_input_tokens) / 20` clamped to `[0, 1]`, or,
///   without registry data, 0.8 / 0.4 / 0.6 for names that look large / small
///   / neither. A bonus of 0.05 is added when `max_output_tokens` is at least
///   8192, and the sum is clamped again.
/// * `tool_score`: 1.0 with function calling, 0.0 without, 0.5 when unknown,
///   plus 0.1 for vision support, clamped to `[0, 1]`.
/// * `cost_efficiency`: `1 / (1 + output_cost_per_million / 10)`, or 0.5 when
///   the cost is unknown.
/// * `name_tier`: 0.0 for small-looking names, 1.0 for large-looking names,
///   0.5 otherwise. Small markers take precedence over large ones.
fn model_weights(
    model: &str,
    registry: &crate::registry::model::ModelRegistry,
    provider_name: &str,
) -> ModelWeights {
    /// Fragments marking a small model for `name_tier` (besides `mini`).
    const TIER_SMALL: [&str; 8] = [
        "flash", "haiku", "fast", "instant", "lite", "8b", "7b", "3b",
    ];
    /// Fragments marking a large model for `name_tier`.
    const TIER_LARGE: [&str; 11] = [
        "opus", "ultra", "pro", "large", "plus", "o3", "o1", "70b", "72b", "405b", "reasoner",
    ];
    /// Fragments suggesting a large context when the registry has no data.
    const CONTEXT_LARGE: [&str; 5] = ["70b", "72b", "405b", "opus", "ultra"];
    /// Fragments suggesting a small context when the registry has no data
    /// (besides `mini`).
    const CONTEXT_SMALL: [&str; 3] = ["haiku", "8b", "7b"];

    /// True when `haystack` contains at least one of `markers`.
    fn mentions_any(haystack: &str, markers: &[&str]) -> bool {
        markers.iter().any(|marker| haystack.contains(*marker))
    }

    /// True when `lowered` contains `mini` as a size marker.
    ///
    /// An occurrence directly preceded by a letter is ignored so that the
    /// family name `gemini` is not mistaken for a "mini" model.
    fn has_mini_marker(lowered: &str) -> bool {
        lowered.match_indices("mini").any(|(start, _)| {
            !lowered[..start].ends_with(|c: char| c.is_ascii_alphabetic())
        })
    }

    // Lowercase once; both the context fallback and the name tier need it.
    let lowered = model.to_lowercase();
    let rm = registry.find_model(provider_name, model);

    let output_bonus = rm
        .and_then(|r| r.max_output_tokens)
        .map(|t| if t >= 8192 { 0.05 } else { 0.0 })
        .unwrap_or(0.0);

    let base_context = rm
        .and_then(|r| r.max_input_tokens)
        // log2(2^20) / 20 == 1.0, so roughly a million tokens saturates.
        .map(|t| ((t as f64).log2() / 20.0_f64).clamp(0.0, 1.0))
        .unwrap_or_else(|| {
            // No registry data: guess from the name, large markers first.
            if mentions_any(&lowered, &CONTEXT_LARGE) {
                0.8
            } else if has_mini_marker(&lowered) || mentions_any(&lowered, &CONTEXT_SMALL) {
                0.4
            } else {
                0.6
            }
        });
    let context_score = (base_context + output_bonus).clamp(0.0, 1.0);

    let vision_bonus = rm
        .map(|r| if r.supports_vision { 0.1 } else { 0.0 })
        .unwrap_or(0.0);
    let tool_score = (rm
        .map(|r| if r.supports_function_calling { 1.0_f64 } else { 0.0_f64 })
        .unwrap_or(0.5_f64)
        + vision_bonus)
        .clamp(0.0, 1.0);

    let cost_efficiency = rm
        .and_then(|r| r.output_cost_per_million)
        .map(|c| 1.0 / (1.0 + c / 10.0))
        .unwrap_or(0.5);

    let name_tier = if has_mini_marker(&lowered) || mentions_any(&lowered, &TIER_SMALL) {
        0.0
    } else if mentions_any(&lowered, &TIER_LARGE) {
        1.0
    } else {
        0.5
    };

    ModelWeights {
        context_score,
        tool_score,
        cost_efficiency,
        name_tier,
    }
}
/// Chooses the architect, code and ask models for one provider from the
/// models it actually offers.
///
/// * If all three preset role models are non-empty and present in
///   `available`, the preset is used as is.
/// * Otherwise each role is decided independently: the preset's model for
///   that role is taken when it is non-empty and offered; failing that, the
///   offered model with the highest role score (see `ModelWeights`) is taken.
///   When several models tie, the later one in `available` wins.
/// * With nothing to score (`available` is empty) a role falls back to
///   `preset.default_model`.
///
/// `default` in the result is always the same model as `code`.
pub(super) fn assign_models_by_role(
    available: &[String],
    registry: &crate::registry::model::ModelRegistry,
    preset: &ProviderPreset,
    provider_name: &str,
) -> RoleModels {
    let offered: std::collections::HashSet<&str> =
        available.iter().map(String::as_str).collect();
    // A preset model is usable only when it is set and the provider lists it.
    let is_offered = |name: &str| !name.is_empty() && offered.contains(name);

    // Fast path: the whole preset is usable.
    if is_offered(preset.architect_model)
        && is_offered(preset.code_model)
        && is_offered(preset.ask_model)
    {
        let code_model = preset.code_model.to_string();
        return RoleModels {
            default: code_model.clone(),
            architect: preset.architect_model.to_string(),
            code: code_model,
            ask: preset.ask_model.to_string(),
        };
    }

    // Score every offered model once; the three roles share the result.
    let ranked: Vec<(&str, ModelWeights)> = available
        .iter()
        .map(|name| (name.as_str(), model_weights(name, registry, provider_name)))
        .collect();

    let choose = |score_of: fn(&ModelWeights) -> f64, hint: &'static str| -> String {
        if is_offered(hint) {
            return hint.to_string();
        }
        // Incomparable scores (NaN) are treated as equal.
        let winner = ranked.iter().max_by(|(_, lhs), (_, rhs)| {
            score_of(lhs)
                .partial_cmp(&score_of(rhs))
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        match winner {
            Some((name, _)) => name.to_string(),
            None => preset.default_model.to_string(),
        }
    };

    let architect = choose(ModelWeights::architect_score, preset.architect_model);
    let code = choose(ModelWeights::code_score, preset.code_model);
    let ask = choose(ModelWeights::ask_score, preset.ask_model);

    RoleModels {
        default: code.clone(),
        architect,
        code,
        ask,
    }
}