Skip to main content

mur_common/muragent/
model_class.rs

//! Static model-class table: map a (provider, name) binding to a `ModelHint`
//! (tier, RAM estimate, local capability). Classification is heuristic and
//! deliberately conservative for unknown providers. See spec §7.1.
4
5use crate::muragent::manifest::{ModelHint, ModelTier};
6
/// Provider ids (lowercase) whose models run on the user's own machine.
/// Bindings for these are classified by size markers in the model name.
const LOCAL_PROVIDERS: &[&str] = &[
    "ollama",
    "mlx",
    // Both spellings are seen in the wild for llama.cpp.
    "llamacpp",
    "llama_cpp",
    "localai",
    "lmstudio",
];

/// Provider ids (lowercase) of hosted APIs. Bindings for these are treated as
/// frontier unless the model name carries a small-variant marker.
const CLOUD_PROVIDERS: &[&str] = &[
    "anthropic",
    "openai",
    "google",
    "gemini",
    "mistral",
    "groq",
    "cohere",
    "deepseek",
    "xai",
    "openrouter",
];

/// Name markers that identify a small model or a cheap/small variant.
const SMALL_MARKERS: &[&str] = &[
    // Parameter-count tags.
    "1b", "1.5b", "2b", "3b",
    // Vendor variant names for small tiers.
    "mini", "nano", "haiku", "flash", "small", "tiny",
];

/// Name markers that identify a large model when served locally.
const LARGE_LOCAL_MARKERS: &[&str] = &[
    "70b",
    "72b",
    "405b",
    "mixtral",
    "command-r-plus",
];
31
32/// Map a model binding to a `ModelHint`. Local providers classify by size
33/// markers in the name; known cloud providers are frontier unless a "small"
34/// variant marker is present; unknown providers fall back to a conservative
35/// local-capable mid tier (the wizard still offers all options).
36pub fn classify(provider: &str, name: &str) -> ModelHint {
37    let p = provider.to_ascii_lowercase();
38    let n = name.to_ascii_lowercase();
39    let has = |markers: &[&str]| markers.iter().any(|m| n.contains(m));
40
41    if LOCAL_PROVIDERS.contains(&p.as_str()) {
42        let (tier, min_ram_gb) = if has(SMALL_MARKERS) {
43            (ModelTier::Small, 8)
44        } else if has(LARGE_LOCAL_MARKERS) {
45            (ModelTier::Frontier, 64)
46        } else {
47            (ModelTier::Mid, 16)
48        };
49        return ModelHint {
50            provider: provider.to_string(),
51            name: name.to_string(),
52            tier,
53            min_ram_gb,
54            local_capable: true,
55        };
56    }
57
58    if CLOUD_PROVIDERS.contains(&p.as_str()) {
59        let tier = if has(SMALL_MARKERS) {
60            ModelTier::Mid
61        } else {
62            ModelTier::Frontier
63        };
64        return ModelHint {
65            provider: provider.to_string(),
66            name: name.to_string(),
67            tier,
68            min_ram_gb: 0,
69            local_capable: false,
70        };
71    }
72
73    // Unknown provider: conservative — assume a local-capable mid model.
74    ModelHint {
75        provider: provider.to_string(),
76        name: name.to_string(),
77        tier: ModelTier::Mid,
78        min_ram_gb: 16,
79        local_capable: true,
80    }
81}
82
#[cfg(test)]
mod tests {
    use super::*;

    // Each test pulls the fields it checks straight out of the returned hint.

    #[test]
    fn local_small_model() {
        let ModelHint {
            tier,
            min_ram_gb,
            local_capable,
            ..
        } = classify("ollama", "llama3.2:3b");
        assert_eq!(tier, ModelTier::Small);
        assert!(local_capable);
        assert_eq!(min_ram_gb, 8);
    }

    #[test]
    fn cloud_frontier_model() {
        let ModelHint {
            tier,
            min_ram_gb,
            local_capable,
            ..
        } = classify("anthropic", "claude-opus-4-7");
        assert_eq!(tier, ModelTier::Frontier);
        assert!(!local_capable);
        assert_eq!(min_ram_gb, 0);
    }

    #[test]
    fn cloud_small_variant_is_mid() {
        let ModelHint {
            tier,
            local_capable,
            ..
        } = classify("openai", "gpt-4o-mini");
        assert_eq!(tier, ModelTier::Mid);
        assert!(!local_capable);
    }

    #[test]
    fn unknown_provider_is_conservative_local_mid() {
        let ModelHint {
            tier,
            min_ram_gb,
            local_capable,
            ..
        } = classify("acme-llm", "whatever-v2");
        assert_eq!(tier, ModelTier::Mid);
        assert!(local_capable);
        assert_eq!(min_ram_gb, 16);
    }
}