Skip to main content

harn_vm/llm/
capabilities.rs

1//! Data-driven provider capabilities.
2//!
3//! The per-(provider, model) capability matrix (native tools, deferred
4//! tool loading, tool-search variants, prompt caching, extended thinking,
5//! max tool count) lives in the shipped `capabilities.toml` and is
6//! overridable per-project via `[[capabilities.provider.<name>]]` blocks
7//! in `harn.toml`. This module owns:
8//!
9//! - loading the built-in TOML (compiled in via `include_str!`);
10//! - merging user overrides on top;
//! - matching a `(provider, model)` pair against the rule list with
//!   `*`-glob model patterns plus an optional `[major, minor]` version floor;
13//! - exposing a stable `Capabilities` struct that the `LlmProvider`
14//!   trait delegates to as the single source of truth.
15//!
16//! Before this module the Anthropic / OpenAI gates were spread across
17//! `providers/anthropic.rs` (`claude_generation`, `claude_model_supports_tool_search`)
18//! and `providers/openai_compat.rs` (`gpt_generation`, `gpt_model_supports_tool_search`).
19//! Those parsers are still used here — they supply the version extractor —
20//! but the boolean gates that used to live alongside them are now data.
21
22use std::cell::RefCell;
23use std::collections::BTreeMap;
24use std::sync::OnceLock;
25
26use serde::Deserialize;
27
28use super::providers::anthropic::claude_generation;
29use super::providers::openai_compat::gpt_generation;
30
/// Shipped default rules. Compiled into the binary at build time via
/// `include_str!` from the sibling `capabilities.toml`; the text is parsed
/// lazily (once per process) by `builtin()`.
const BUILTIN_TOML: &str = include_str!("capabilities.toml");
33
/// Parsed on-disk capabilities schema. Public so harn-cli can
/// construct one directly when wiring harn.toml overrides.
///
/// The same shape is used both for the built-in `capabilities.toml` and
/// for user overrides (top-level via `set_user_overrides_toml`, or nested
/// under `[capabilities]` via `set_user_overrides_from_manifest_toml`).
#[derive(Debug, Clone, Deserialize, Default)]
pub struct CapabilitiesFile {
    /// Per-provider ordered rule lists, keyed by provider name (the `X`
    /// in `[[provider.X]]`). First matching rule wins.
    #[serde(default)]
    pub provider: BTreeMap<String, Vec<ProviderRule>>,
    /// Sibling → canonical family mapping. Providers with no matching
    /// rule of their own fall through to the named family (recursively).
    /// When both user and built-in files map the same provider, the user
    /// mapping is used.
    #[serde(default)]
    pub provider_family: BTreeMap<String, String>,
}
46
/// One row of the capability matrix.
///
/// Every capability field is optional. When the row is selected,
/// `rule_to_caps` turns unset boolean gates into `false`, an unset
/// `tool_search` into an empty list, and copies `max_tools` as-is.
#[derive(Debug, Clone, Deserialize)]
pub struct ProviderRule {
    /// Glob pattern using `*` wildcards (see `glob_match` for exactly
    /// which shapes are supported). Matched case-insensitively against
    /// the model ID: pattern and ID are both lowercased before matching.
    pub model_match: String,
    /// Optional `[major, minor]` lower bound. When set, a version must be
    /// extracted from the model ID via `extract_version` (Claude parser
    /// first, then GPT parser, regardless of which provider is being
    /// looked up) AND compare ≥ this tuple (lexicographic comparison of
    /// `(major, minor)`). If the model ID cannot be parsed, or this list
    /// does not have exactly two elements, the rule simply does not
    /// match; it is skipped, not merged, and later rules still get a
    /// chance.
    #[serde(default)]
    pub version_min: Option<Vec<u32>>,
    /// Native tools gate. Unset → `false`.
    #[serde(default)]
    pub native_tools: Option<bool>,
    /// Deferred tool loading gate. Unset → `false`.
    #[serde(default)]
    pub defer_loading: Option<bool>,
    /// Tool-search variant names (the tests use e.g. `bm25`, `regex`,
    /// `hosted`, `client`). Unset → empty list.
    #[serde(default)]
    pub tool_search: Option<Vec<String>>,
    /// Max tool count; `None` when the rule does not set one.
    #[serde(default)]
    pub max_tools: Option<u32>,
    /// Prompt caching gate. Unset → `false`.
    #[serde(default)]
    pub prompt_caching: Option<bool>,
    /// Extended thinking gate. Unset → `false`.
    #[serde(default)]
    pub thinking: Option<bool>,
}
72
/// Resolved capabilities for a `(provider, model)` pair. Unset rule
/// fields resolve to `false` / empty / `None` so callers never have to
/// unwrap an `Option<bool>` for what are really boolean gates.
///
/// `Capabilities::default()` (everything off, empty, or `None`) is what
/// `lookup` returns when no rule matches anywhere along the provider's
/// family chain.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Capabilities {
    /// Native tools gate.
    pub native_tools: bool,
    /// Deferred tool loading gate.
    pub defer_loading: bool,
    /// Tool-search variants; empty when the matched rule lists none.
    pub tool_search: Vec<String>,
    /// Max tool count, copied verbatim from the matched rule.
    pub max_tools: Option<u32>,
    /// Prompt caching gate.
    pub prompt_caching: bool,
    /// Extended thinking gate.
    pub thinking: bool,
}
85
thread_local! {
    /// Per-thread user overrides installed by the CLI at startup. Kept
    /// thread-local (not process-static) to match the rest of the VM
    /// state model — the VM is !Send and each VM thread owns its own
    /// configuration.
    ///
    /// `None` means "no overrides": `lookup` then consults only the
    /// built-in rules. Written by `set_user_overrides` and its wrappers.
    static USER_OVERRIDES: RefCell<Option<CapabilitiesFile>> = const { RefCell::new(None) };
}
93
94/// Lazily-parsed built-in rules. The `include_str!` content is a static
95/// constant; parsing it once per process is safe and free of ordering
96/// hazards.
97static BUILTIN: OnceLock<CapabilitiesFile> = OnceLock::new();
98
99fn builtin() -> &'static CapabilitiesFile {
100    BUILTIN.get_or_init(|| {
101        toml::from_str::<CapabilitiesFile>(BUILTIN_TOML)
102            .expect("capabilities.toml must parse at build time")
103    })
104}
105
106/// Install project-level overrides for the current thread. Usually
107/// called once at CLI bootstrap after reading `harn.toml`. Passing
108/// `None` clears any prior override.
109pub fn set_user_overrides(file: Option<CapabilitiesFile>) {
110    USER_OVERRIDES.with(|cell| *cell.borrow_mut() = file);
111}
112
113/// Clear any thread-local user overrides. Used between test runs.
114pub fn clear_user_overrides() {
115    set_user_overrides(None);
116}
117
118/// Parse a TOML string containing the capabilities section's own shape
119/// (i.e. top-level `[[provider.X]]` + optional `[provider_family]`, the
120/// same layout used by the built-in `capabilities.toml`) and install as
121/// the current thread's override.
122pub fn set_user_overrides_toml(src: &str) -> Result<(), String> {
123    let parsed: CapabilitiesFile = toml::from_str(src).map_err(|e| e.to_string())?;
124    set_user_overrides(Some(parsed));
125    Ok(())
126}
127
128/// Extract the `[capabilities]` section from a full `harn.toml` source
129/// and install it as the current thread's override. The schema inside
130/// that section mirrors `CapabilitiesFile` but with every key prefixed
131/// by `capabilities.`:
132///
133/// ```toml
134/// [[capabilities.provider.my-proxy]]
135/// model_match = "*"
136/// native_tools = true
137/// tool_search = ["hosted"]
138/// ```
139pub fn set_user_overrides_from_manifest_toml(src: &str) -> Result<(), String> {
140    #[derive(Deserialize)]
141    struct Manifest {
142        #[serde(default)]
143        capabilities: Option<CapabilitiesFile>,
144    }
145    let parsed: Manifest = toml::from_str(src).map_err(|e| e.to_string())?;
146    set_user_overrides(parsed.capabilities);
147    Ok(())
148}
149
150/// Look up effective capabilities for a `(provider, model)` pair.
151/// Walks the provider_family chain until it finds a rule list that
152/// matches. Within any one provider's rule list, user overrides are
153/// consulted before the built-in rules. The first matching rule wins —
154/// later rules (and later layers in the family chain) are ignored.
155pub fn lookup(provider: &str, model: &str) -> Capabilities {
156    let user = USER_OVERRIDES.with(|cell| cell.borrow().clone());
157    lookup_with(provider, model, builtin(), user.as_ref())
158}
159
160fn lookup_with(
161    provider: &str,
162    model: &str,
163    builtin: &CapabilitiesFile,
164    user: Option<&CapabilitiesFile>,
165) -> Capabilities {
166    // Special case: mock spoofs either shape. Try anthropic first
167    // (Claude-shape model strings) so `mock` + `claude-opus-4-7`
168    // resolves to the Anthropic capability row — the same behaviour
169    // the hardcoded dispatch gave before this refactor.
170    if provider == "mock" {
171        if let Some(caps) = try_match_layer(user, builtin, "anthropic", model, provider) {
172            return caps;
173        }
174        if let Some(caps) = try_match_layer(user, builtin, "openai", model, provider) {
175            return caps;
176        }
177        return Capabilities::default();
178    }
179
180    // Normal chain: walk provider → family(provider) → ... with a
181    // visited-guard to avoid cycles in malformed user overrides.
182    let mut current = provider.to_string();
183    let mut visited: std::collections::HashSet<String> = std::collections::HashSet::new();
184    while visited.insert(current.clone()) {
185        if let Some(caps) = try_match_layer(user, builtin, &current, model, provider) {
186            return caps;
187        }
188        let next = user
189            .and_then(|f| f.provider_family.get(&current))
190            .or_else(|| builtin.provider_family.get(&current))
191            .cloned();
192        match next {
193            Some(parent) => current = parent,
194            None => break,
195        }
196    }
197    Capabilities::default()
198}
199
200/// Try the ordered rule list for `layer_provider` (user rules first,
201/// then built-in rules). Returns `Some(caps)` on the first match, else
202/// `None`. `original_provider` is threaded through only for diagnostics.
203fn try_match_layer(
204    user: Option<&CapabilitiesFile>,
205    builtin: &CapabilitiesFile,
206    layer_provider: &str,
207    model: &str,
208    _original_provider: &str,
209) -> Option<Capabilities> {
210    if let Some(user) = user {
211        if let Some(rules) = user.provider.get(layer_provider) {
212            for rule in rules {
213                if rule_matches(rule, model) {
214                    return Some(rule_to_caps(rule));
215                }
216            }
217        }
218    }
219    if let Some(rules) = builtin.provider.get(layer_provider) {
220        for rule in rules {
221            if rule_matches(rule, model) {
222                return Some(rule_to_caps(rule));
223            }
224        }
225    }
226    None
227}
228
229fn rule_to_caps(rule: &ProviderRule) -> Capabilities {
230    Capabilities {
231        native_tools: rule.native_tools.unwrap_or(false),
232        defer_loading: rule.defer_loading.unwrap_or(false),
233        tool_search: rule.tool_search.clone().unwrap_or_default(),
234        max_tools: rule.max_tools,
235        prompt_caching: rule.prompt_caching.unwrap_or(false),
236        thinking: rule.thinking.unwrap_or(false),
237    }
238}
239
240fn rule_matches(rule: &ProviderRule, model: &str) -> bool {
241    let lower = model.to_lowercase();
242    if !glob_match(&rule.model_match.to_lowercase(), &lower) {
243        return false;
244    }
245    if let Some(version_min) = &rule.version_min {
246        if version_min.len() != 2 {
247            return false;
248        }
249        let want = (version_min[0], version_min[1]);
250        let have = match extract_version(model) {
251            Some(v) => v,
252            // `version_min` was set but the model ID can't be parsed.
253            // Fail closed: skip this rule so more permissive catch-all
254            // rules below can still match.
255            None => return false,
256        };
257        if have < want {
258            return false;
259        }
260    }
261    true
262}
263
264/// Extract `(major, minor)` from a model ID by trying the Anthropic
265/// parser first (for `claude-*` shapes) then the OpenAI parser (`gpt-*`).
266/// Both parsers return `None` for shapes they don't recognise so this
267/// never mis-parses across families.
268fn extract_version(model: &str) -> Option<(u32, u32)> {
269    claude_generation(model).or_else(|| gpt_generation(model))
270}
271
/// Glob matching where `*` matches any (possibly empty) run of
/// characters. Every other character is literal. Any number of `*` is
/// supported in any position (`*foo*`, `claude-*-4-*`, `a*b*c`, …).
///
/// The pattern is split on `*`:
/// - the first segment must be a prefix of `input`;
/// - the last segment must be a suffix of what remains;
/// - every middle segment must appear, in order, in between.
///
/// Middle segments are matched at their leftmost occurrence. This is
/// always safe for `*`-only globs, because it leaves the longest possible
/// remainder for the later segments. A pattern with no `*` must equal
/// `input` exactly.
///
/// NOTE(review): a similar helper lives in `llm_config.rs`. If it still
/// has the older prefix/suffix-only semantics (which mis-handled
/// `claude-*-4-*` and let `gpt-5*5` match `gpt-5`), align the two.
fn glob_match(pattern: &str, input: &str) -> bool {
    let mut segments = pattern.split('*');
    // `split` always yields at least one segment.
    let head = segments.next().unwrap_or("");
    let Some(mut rest) = input.strip_prefix(head) else {
        return false;
    };
    let tail: Vec<&str> = segments.collect();
    let Some((&last, middle)) = tail.split_last() else {
        // No `*` at all: the pattern is a literal and must match exactly.
        return rest.is_empty();
    };
    for &segment in middle {
        match rest.find(segment) {
            Some(at) => rest = &rest[at + segment.len()..],
            None => return false,
        }
    }
    // The suffix must fit entirely inside what is left, so it can never
    // overlap the prefix or a middle segment.
    rest.ends_with(last)
}
295
// Unit tests for capability lookup. Unless a test installs its own
// overrides, expectations come straight from the shipped
// `capabilities.toml`: if a row there changes, update the matching
// assertion here.
#[cfg(test)]
mod tests {
    use super::*;

    // Clears any override left on this thread so each test starts from the
    // built-in rules only. Call it at the top of every test that uses `lookup`.
    fn reset() {
        clear_user_overrides();
    }

    #[test]
    fn anthropic_opus_47_gets_full_capabilities() {
        reset();
        let caps = lookup("anthropic", "claude-opus-4-7");
        assert!(caps.native_tools);
        assert!(caps.defer_loading);
        assert_eq!(caps.tool_search, vec!["bm25", "regex"]);
        assert!(caps.prompt_caching);
        assert!(caps.thinking);
        assert_eq!(caps.max_tools, Some(10000));
    }

    #[test]
    fn anthropic_haiku_44_has_no_tool_search() {
        reset();
        let caps = lookup("anthropic", "claude-haiku-4-4");
        // Haiku 4.4 falls through to the `claude-*` catch-all row.
        assert!(caps.native_tools);
        assert!(caps.prompt_caching);
        assert!(!caps.defer_loading);
        assert!(caps.tool_search.is_empty());
    }

    // Version floor boundary: 4.5 is the first Haiku expected to get tool search.
    #[test]
    fn anthropic_haiku_45_supports_tool_search() {
        reset();
        let caps = lookup("anthropic", "claude-haiku-4-5");
        assert!(caps.defer_loading);
        assert_eq!(caps.tool_search, vec!["bm25", "regex"]);
    }

    #[test]
    fn old_claude_gets_catchall() {
        reset();
        let caps = lookup("anthropic", "claude-opus-3-5");
        assert!(caps.native_tools);
        assert!(caps.prompt_caching);
        assert!(!caps.defer_loading);
        assert!(caps.tool_search.is_empty());
    }

    #[test]
    fn openai_gpt_54_supports_tool_search() {
        reset();
        let caps = lookup("openai", "gpt-5.4");
        assert!(caps.defer_loading);
        assert_eq!(caps.tool_search, vec!["hosted", "client"]);
    }

    // One minor version below the 5.4 row: native tools, but no deferred loading or search.
    #[test]
    fn openai_gpt_53_has_native_tools_only() {
        reset();
        let caps = lookup("openai", "gpt-5.3");
        assert!(caps.native_tools);
        assert!(!caps.defer_loading);
        assert!(caps.tool_search.is_empty());
    }

    // Exercises the `provider_family` fall-through from openrouter to openai rows.
    #[test]
    fn openrouter_inherits_openai() {
        reset();
        let caps = lookup("openrouter", "gpt-5.4");
        assert!(caps.defer_loading);
        assert_eq!(caps.tool_search, vec!["hosted", "client"]);
    }

    // NOTE(review): presumably groq resolves through `provider_family` to openai
    // (per the test name); confirm against capabilities.toml.
    #[test]
    fn groq_inherits_openai_family_only() {
        reset();
        let caps = lookup("groq", "gpt-5.5-preview");
        assert!(caps.defer_loading);
    }

    // `mock` tries the anthropic rows first, so Claude-shaped IDs land there.
    #[test]
    fn mock_with_claude_model_routes_to_anthropic() {
        reset();
        let caps = lookup("mock", "claude-sonnet-4-7");
        assert!(caps.defer_loading);
        assert_eq!(caps.tool_search, vec!["bm25", "regex"]);
    }

    // ...and falls back to the openai rows for GPT-shaped IDs.
    #[test]
    fn mock_with_gpt_model_routes_to_openai() {
        reset();
        let caps = lookup("mock", "gpt-5.4-preview");
        assert!(caps.defer_loading);
        assert_eq!(caps.tool_search, vec!["hosted", "client"]);
    }

    // No rows and no family mapping → `Capabilities::default()`.
    #[test]
    fn unknown_provider_has_no_capabilities() {
        reset();
        let caps = lookup("my-custom-proxy", "foo-bar-1");
        assert!(!caps.native_tools);
        assert!(!caps.defer_loading);
        assert!(caps.tool_search.is_empty());
    }

    #[test]
    fn user_override_adds_new_provider() {
        reset();
        let toml_src = r#"
[[provider.my-proxy]]
model_match = "*"
native_tools = true
tool_search = ["hosted"]
"#;
        set_user_overrides_toml(toml_src).unwrap();
        let caps = lookup("my-proxy", "anything");
        assert!(caps.native_tools);
        assert_eq!(caps.tool_search, vec!["hosted"]);
        clear_user_overrides();
    }

    // A user row for the same provider shadows the built-in opus row entirely
    // (no field-level merge): the unset/false fields win.
    #[test]
    fn user_override_takes_precedence_over_builtin() {
        reset();
        let toml_src = r#"
[[provider.anthropic]]
model_match = "claude-opus-*"
native_tools = true
defer_loading = false
tool_search = []
"#;
        set_user_overrides_toml(toml_src).unwrap();
        let caps = lookup("anthropic", "claude-opus-4-7");
        assert!(caps.native_tools);
        assert!(!caps.defer_loading);
        assert!(caps.tool_search.is_empty());
        clear_user_overrides();
    }

    // Full harn.toml input: unrelated tables such as `[package]` are ignored.
    #[test]
    fn user_override_from_manifest_toml() {
        reset();
        let manifest = r#"
[package]
name = "demo"

[[capabilities.provider.my-proxy]]
model_match = "*"
native_tools = true
tool_search = ["hosted"]
"#;
        set_user_overrides_from_manifest_toml(manifest).unwrap();
        let caps = lookup("my-proxy", "foo");
        assert!(caps.native_tools);
        assert_eq!(caps.tool_search, vec!["hosted"]);
        clear_user_overrides();
    }

    #[test]
    fn version_min_requires_parseable_model() {
        reset();
        let toml_src = r#"
[[provider.custom]]
model_match = "*"
version_min = [5, 4]
native_tools = true
"#;
        set_user_overrides_toml(toml_src).unwrap();
        // Unparseable model ID + version_min → rule doesn't match.
        let caps = lookup("custom", "mystery-model");
        assert!(!caps.native_tools);
        clear_user_overrides();
    }

    #[test]
    fn glob_match_substring() {
        assert!(glob_match("*gpt*", "openai/gpt-5.4"));
        assert!(glob_match("*claude*", "anthropic/claude-opus-4-7"));
        assert!(!glob_match("*xyz*", "openai/gpt-5.4"));
    }

    // Vendor-namespaced IDs (`anthropic/…`) must still hit the Anthropic rows.
    #[test]
    fn openrouter_namespaced_anthropic_model() {
        reset();
        let caps = lookup("anthropic", "anthropic/claude-opus-4-7");
        assert!(caps.defer_loading);
    }
}