Skip to main content

llm_kernel/provider/
capability.rs

1use super::catalog::ServiceDescriptor;
2
/// The mechanism a provider uses to authenticate outgoing API requests.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthStrategy {
    /// The provider needs no credentials at all (e.g. local Ollama).
    None,
    /// A fixed token is embedded in the catalog entry (e.g. a free-tier key).
    Literal,
    /// The credential is looked up from an environment variable.
    Secret,
    /// The authentication mode is not recognised.
    Unknown,
}
15
16/// Capability profile for a provider — determines auth strategy and feature support.
17pub trait CapabilityProfile {
18    /// How the provider authenticates API requests.
19    fn auth_strategy(&self) -> AuthStrategy;
20    /// Whether the Anthropic API key should be cleared before calling this provider.
21    fn clears_anthropic_api_key(&self) -> bool;
22    /// Whether the provider supports model tiers (e.g. fast vs. powerful model aliases).
23    fn supports_model_tiers(&self) -> bool;
24    /// Returns `true` if any model offered by this provider supports tool/function calling.
25    fn supports_tool_calling(&self) -> bool {
26        false
27    }
28    /// Returns `true` if any model offered by this provider accepts image input.
29    fn supports_vision(&self) -> bool {
30        false
31    }
32    /// Returns `true` if the provider supports streaming completions.
33    fn supports_streaming(&self) -> bool {
34        true
35    }
36    /// Maximum context window in tokens across all models, or `None` if unknown.
37    fn context_limit(&self) -> Option<u64> {
38        None
39    }
40}
41
42fn auth_mode_to_strategy(value: &str) -> AuthStrategy {
43    match value {
44        "none" => AuthStrategy::None,
45        "literal" => AuthStrategy::Literal,
46        "secret" => AuthStrategy::Secret,
47        _ => AuthStrategy::Unknown,
48    }
49}
50
/// Returns `true` when `family` is one of the provider families for which the
/// API key is cleared: `openrouter`, `local`, or `custom_unknown`.
///
/// Matching is exact and case-sensitive; every other family yields `false`.
fn clears_api_key_for_family(family: &str) -> bool {
    const KEY_CLEARING_FAMILIES: [&str; 3] = ["openrouter", "local", "custom_unknown"];

    KEY_CLEARING_FAMILIES.iter().any(|known| *known == family)
}
54
/// Returns `true` for every provider family except `claude_strict`, which is
/// the only family reported as not supporting model tiers.
///
/// Matching is exact and case-sensitive.
fn supports_tiers_for_family(family: &str) -> bool {
    family != "claude_strict"
}
58
59impl CapabilityProfile for ServiceDescriptor {
60    fn auth_strategy(&self) -> AuthStrategy {
61        auth_mode_to_strategy(&self.auth_mode)
62    }
63
64    fn clears_anthropic_api_key(&self) -> bool {
65        clears_api_key_for_family(&self.family)
66    }
67
68    fn supports_model_tiers(&self) -> bool {
69        supports_tiers_for_family(&self.family)
70    }
71
72    fn supports_tool_calling(&self) -> bool {
73        self.models
74            .iter()
75            .any(|m| m.capabilities.as_ref().is_some_and(|c| c.tool_call))
76    }
77
78    fn supports_vision(&self) -> bool {
79        self.models.iter().any(|m| {
80            m.modalities
81                .as_ref()
82                .is_some_and(|md| md.input.iter().any(|i| i == "image"))
83        })
84    }
85
86    fn supports_streaming(&self) -> bool {
87        self.models
88            .iter()
89            .any(|m| m.capabilities.as_ref().is_none_or(|c| c.streaming))
90    }
91
92    fn context_limit(&self) -> Option<u64> {
93        self.models
94            .iter()
95            .filter_map(|m| m.limit.as_ref().map(|l| l.context))
96            .max()
97    }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a minimal descriptor for the tests below. Only `auth_mode` and
    /// `family` vary; every other field is empty or a fixed placeholder.
    fn make_descriptor(auth_mode: &str, family: &str) -> ServiceDescriptor {
        ServiceDescriptor {
            id: String::from("test"),
            display_name: String::from("Test"),
            description: String::new(),
            category: String::from("test"),
            family: String::from(family),
            auth_mode: String::from(auth_mode),
            key_var: String::new(),
            literal_auth_token: String::new(),
            base_url: String::new(),
            default_model: String::new(),
            model_tiers: HashMap::new(),
            model_choices: Vec::new(),
            test_url: String::new(),
            setup: Vec::new(),
            usage: Vec::new(),
            api_base_url: None,
            npm_package: None,
            doc_url: None,
            models: Vec::new(),
        }
    }

    /// Each known mode string maps to its strategy; anything else is `Unknown`.
    #[test]
    fn test_auth_mode_mapping() {
        let expectations = [
            ("none", AuthStrategy::None),
            ("literal", AuthStrategy::Literal),
            ("secret", AuthStrategy::Secret),
            ("other", AuthStrategy::Unknown),
        ];
        for &(mode, expected) in &expectations {
            assert_eq!(auth_mode_to_strategy(mode), expected);
        }
    }

    /// A secret-authenticated OpenRouter descriptor exposes all three
    /// family/auth-derived capabilities.
    #[test]
    fn test_secret_provider_capability() {
        let descriptor = make_descriptor("secret", "openrouter");
        let strategy = descriptor.auth_strategy();
        assert_eq!(strategy, AuthStrategy::Secret);
        assert!(descriptor.clears_anthropic_api_key());
        assert!(descriptor.supports_model_tiers());
    }

    /// `claude_strict` neither clears the API key nor supports tiers.
    #[test]
    fn test_claude_strict_invariants() {
        let family = "claude_strict";
        assert!(!clears_api_key_for_family(family));
        assert!(!supports_tiers_for_family(family));
    }

    /// `openrouter` both clears the API key and supports tiers.
    #[test]
    fn test_openrouter_invariants() {
        let family = "openrouter";
        assert!(clears_api_key_for_family(family));
        assert!(supports_tiers_for_family(family));
    }

    /// The `local` family clears the API key.
    #[test]
    fn test_local_family_clears_api_key() {
        let family = "local";
        assert!(clears_api_key_for_family(family));
    }

    /// An `auth_mode` of `none` resolves to `AuthStrategy::None`.
    #[test]
    fn test_none_auth_strategy() {
        let descriptor = make_descriptor("none", "local");
        assert_eq!(descriptor.auth_strategy(), AuthStrategy::None);
    }
}