use anyhow::Result;
use serde::Deserialize;
/// Top-level document deserialized from the map endpoint response
/// (see `fetch_map`, which decodes the response body as this type).
#[derive(Debug, Deserialize)]
pub struct MapState {
/// One report per model listed in the map. This field has no serde
/// default, so deserialization fails if it is absent.
pub model_reports: Vec<MapModelReport>,
}
/// A single model entry from the map document.
///
/// Every field except `short_name` carries `#[serde(default)]`, so a missing
/// key deserializes to the type's default (empty string, `0`, or empty vec)
/// instead of failing.
#[derive(Debug, Deserialize)]
pub struct MapModelReport {
/// Model name as reported by the map. `match_score` also compares against
/// the segment after the last `/`, so a slash-separated name is handled.
pub short_name: String,
/// DHT prefix string; empty when the key is absent.
#[serde(default)]
pub dht_prefix: String,
/// Repository string; empty when the key is absent.
#[serde(default)]
pub repository: String,
/// Block count as reported by the map; `0` when the key is absent.
#[serde(default)]
// Not read anywhere in the visible code; the lint is silenced so the field
// can stay in the deserialized schema.
#[allow(dead_code)]
pub num_blocks: u32,
/// Raw per-server rows, kept as untyped JSON. Only the number of rows is
/// used in the visible code (see `server_count`).
#[serde(default)]
pub server_rows: Vec<serde_json::Value>,
}
impl MapModelReport {
/// Returns the number of entries in `server_rows`.
pub fn server_count(&self) -> usize {
self.server_rows.len()
}
}
/// Fetches the map document from `endpoint` and decodes it as a [`MapState`].
///
/// A fresh HTTP client with an 8-second timeout is built for each call.
///
/// # Errors
///
/// Returns an error if the client cannot be built, the request fails or times
/// out, the server answers with a 4xx/5xx status, or the body is not valid
/// JSON for [`MapState`].
pub async fn fetch_map(endpoint: &str) -> Result<MapState> {
    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(8))
        .build()?;
    let state = client
        .get(endpoint)
        .send()
        .await?
        // Surface HTTP failures as status errors instead of letting an error
        // page fall through to the JSON decoder and fail with a misleading
        // parse error.
        .error_for_status()?
        .json::<MapState>()
        .await?;
    Ok(state)
}
/// Reduces `s` to a comparison key: every non-alphanumeric character is
/// dropped and the remainder is lowercased.
fn normalize(s: &str) -> String {
    let mut kept = String::new();
    for ch in s.chars() {
        if ch.is_alphanumeric() {
            kept.push(ch);
        }
    }
    // Lowercase the string as a whole (not char by char) so context-sensitive
    // mappings behave exactly like `str::to_lowercase`.
    kept.to_lowercase()
}
/// Scores how well a local Ollama model reference matches a map report.
///
/// Both sides are compared in normalized form (alphanumerics only,
/// lowercased). The report contributes three candidate strings: its full
/// `short_name`, its `dht_prefix`, and the part of `short_name` after the
/// last `/`. A candidate "matches" a reference string when either one
/// contains the other.
///
/// Returns:
/// * `2` if any candidate matches the full reference,
/// * `1` if no candidate does, but one matches the part of the reference
///   before the first `:`,
/// * `0` otherwise, including when the reference contains no alphanumeric
///   characters at all.
pub fn match_score(ollama_ref: &str, map_model: &MapModelReport) -> u32 {
    let full_norm = normalize(ollama_ref);
    // `str::contains("")` is always true, so an empty normalized reference
    // would otherwise "match" every model with the top score.
    if full_norm.is_empty() {
        return 0;
    }
    let base_norm = normalize(ollama_ref.split(':').next().unwrap_or(ollama_ref));
    let model_part = map_model
        .short_name
        .split('/')
        .next_back()
        .unwrap_or(&map_model.short_name);
    let candidates = [
        normalize(&map_model.short_name),
        normalize(&map_model.dht_prefix),
        normalize(model_part),
    ];

    let mut best = 0;
    for candidate in candidates.iter().filter(|c| !c.is_empty()) {
        if candidate.contains(full_norm.as_str()) || full_norm.contains(candidate.as_str()) {
            return 2;
        }
        // Remember a base-only match but keep scanning: a later candidate may
        // still match the full reference and deserve the higher score.
        if !base_norm.is_empty()
            && (candidate.contains(base_norm.as_str()) || base_norm.contains(candidate.as_str()))
        {
            best = 1;
        }
    }
    best
}
/// The outcome of `pick_best_model`: a local model paired with the map report
/// it was matched to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelChoice {
    /// The selected local model reference.
    pub ollama_ref: String,
    /// `short_name` of the matched map report, if any.
    pub map_name: Option<String>,
    /// `dht_prefix` of the matched report; `None` when it is empty.
    pub dht_prefix: Option<String>,
    /// `repository` of the matched report; `None` when it is empty.
    pub repository: Option<String>,
    /// Number of server rows in the matched report.
    pub server_count: usize,
}
pub fn pick_best_model(
local_models: &[String],
map: &MapState,
current_model: &str,
) -> Option<ModelChoice> {
if local_models.is_empty() || map.model_reports.is_empty() {
return None;
}
let mut scored: Vec<(String, u32, usize, &MapModelReport)> = local_models
.iter()
.filter_map(|local| {
map.model_reports
.iter()
.filter_map(|r| {
let s = match_score(local, r);
if s > 0 {
Some((s, r.server_count(), r))
} else {
None
}
})
.max_by_key(|&(score, count, _)| score as usize * 100_000 + count)
.map(|(score, count, report)| (local.clone(), score, count, report))
})
.collect();
if scored.is_empty() {
return None;
}
scored.sort_by(|a, b| b.1.cmp(&a.1).then(b.2.cmp(&a.2)));
if let Some(pos) = scored.iter().position(|(m, _, _, _)| m == current_model) {
let (_, cur_score, cur_count, cur_report) = &scored[pos];
let (_, best_score, best_count, _) = &scored[0];
if cur_score == best_score && best_count <= &(cur_count + 2) {
return Some(ModelChoice {
ollama_ref: current_model.to_string(),
map_name: Some(cur_report.short_name.clone()),
dht_prefix: non_empty(&cur_report.dht_prefix),
repository: non_empty(&cur_report.repository),
server_count: *cur_count,
});
}
}
let (ref best_local, _, best_count, best_report) = scored[0];
Some(ModelChoice {
ollama_ref: best_local.clone(),
map_name: Some(best_report.short_name.clone()),
dht_prefix: non_empty(&best_report.dht_prefix),
repository: non_empty(&best_report.repository),
server_count: best_count,
})
}
/// Converts `s` to an owned string, mapping the empty string to `None`.
fn non_empty(s: &str) -> Option<String> {
    match s {
        "" => None,
        text => Some(text.to_owned()),
    }
}