Skip to main content

reddb_server/runtime/ai/
provider_capabilities.rs

//! `ProviderCapabilityRegistry` — pure provider capability lookup.
//!
//! Issue #396 (PRD #391): which LLM providers can reliably honor the
//! strict citation contract (#395), the deterministic seed (#400),
//! `temperature=0`, and streaming responses (#405)?
//!
//! This is a deep module: no I/O, no transport, no LLM calls. Given a
//! provider token (e.g. `"openai"`, `"ollama"`) and a caller-requested
//! [`Mode`] (strict / lenient), [`Registry::evaluate_mode`] returns a
//! [`ModeOutcome`] saying either "go ahead" or "the caller asked for
//! strict but this provider can't honor it — fall back to lenient and
//! surface a warning".
//!
//! The caller is responsible for:
//! - actually surfacing the [`ModeWarning`] in the response envelope,
//! - recording the *effective* mode (not the requested one) in the
//!   audit row.
//!
//! ## Defaults
//!
//! Built-in capabilities (see [`Capabilities::for_provider`]) follow
//! these rules of thumb:
//!
//! - **citations**: every provider that exposes a steerable chat
//!   completion API can emit `[^N]` markers when the system prompt
//!   asks for them. Raw-inference endpoints (HuggingFace Inference
//!   API, the embedded `local` embeddings backend) cannot, and
//!   small-model Ollama installs are not reliable either — those
//!   default to `false`.
//! - **seed**: any provider speaking the OpenAI-compatible
//!   `seed` field — OpenAI, Groq, Together, OpenRouter, Venice,
//!   DeepSeek, Ollama (≥0.1.30). Anthropic's API does not accept
//!   `seed`, so it's `false` there even though the model is otherwise
//!   capable.
//! - **temperature_zero**: every chat provider in the list. `false`
//!   only for the embedded `local` backend, which doesn't take a
//!   temperature.
//! - **streaming**: every chat provider that documents an SSE / event
//!   stream. HuggingFace Inference returns one shot; `local` is
//!   synchronous; `custom` is conservatively `false` since we cannot
//!   know what the operator pointed at.
//!
//! Unknown tokens — and the `custom` token — get the conservative
//! defaults from [`Capabilities::conservative`]: citations off, seed
//! off, temperature_zero on, streaming off. This is the safe baseline
//! described in the issue's AC ("Unknown provider returns conservative
//! defaults").
//!
//! ## Per-deployment overrides
//!
//! [`Registry`] holds a `HashMap` keyed by lower-cased token; lookups
//! lower-case the token too, so matching is ASCII case-insensitive. An
//! entry supplied via [`Registry::with_override`] completely replaces
//! the built-in row for that token — there is no partial merge, since
//! the settings surface in #401 uses one TOML table per provider.
//! Callers that want partial overrides should construct the merged
//! [`Capabilities`] themselves.

57use std::collections::HashMap;
58
59use crate::runtime::ai::strict_validator::Mode;
60
/// Capability flags for a single provider.
///
/// Every field is an independent yes/no answer, so each one can be
/// asserted on its own in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Capabilities {
    /// `true` when the provider emits `[^N]` citation markers reliably
    /// enough to honor the strict citation contract (#395).
    pub supports_citations: bool,
    /// `true` when the provider honors a `seed` parameter; together with
    /// `temperature=0` this yields reproducible completions (#400).
    pub supports_seed: bool,
    /// `true` when the provider accepts `temperature=0`. Backends that take
    /// no temperature at all (e.g. embedded embeddings) report `false`.
    pub supports_temperature_zero: bool,
    /// `true` when the provider can stream its response (SSE / chunked /
    /// ws); `false` for synchronous inference endpoints.
    pub supports_streaming: bool,
}

impl Capabilities {
    /// Row used for any provider token without a built-in entry.
    ///
    /// Satisfies the AC "Unknown provider returns conservative defaults
    /// (no citation support, no seed)": only `temperature=0` is assumed,
    /// and streaming is off.
    pub const fn conservative() -> Self {
        Self {
            supports_citations: false,
            supports_seed: false,
            supports_temperature_zero: true,
            supports_streaming: false,
        }
    }

    /// Built-in row for a canonical provider token, i.e. the
    /// `AiProvider::token()` form such as `"openai"` or `"anthropic"`.
    ///
    /// Matching is exact (case-sensitive). [`Registry::capabilities`]
    /// lower-cases the token before calling this. `"custom"` and every
    /// unrecognized token resolve to [`Capabilities::conservative`].
    pub fn for_provider(token: &str) -> Self {
        // The fully capable chat-API row; the other rows below are
        // written as deviations from it.
        const FULL: Capabilities = Capabilities {
            supports_citations: true,
            supports_seed: true,
            supports_temperature_zero: true,
            supports_streaming: true,
        };

        match token {
            // OpenAI itself plus the OpenAI-compatible family.
            "openai" | "groq" | "together" | "openrouter" | "venice" | "deepseek" => FULL,
            // The Anthropic API has no `seed` field.
            "anthropic" => Self {
                supports_seed: false,
                ..FULL
            },
            // Small local models are not trusted to emit citations.
            "ollama" => Self {
                supports_citations: false,
                ..FULL
            },
            // One-shot raw inference: no citations, seed, or streaming.
            "huggingface" => Self {
                supports_citations: false,
                supports_seed: false,
                supports_temperature_zero: true,
                supports_streaming: false,
            },
            // Embedded backend: synchronous and takes no temperature.
            "local" => Self {
                supports_citations: false,
                supports_seed: false,
                supports_temperature_zero: false,
                supports_streaming: false,
            },
            // `custom` could point at anything, so it is deliberately
            // treated exactly like an unknown provider.
            "custom" => Self::conservative(),
            _ => Self::conservative(),
        }
    }
}

/// Structured explanation attached when the effective mode differs from
/// the requested one. The caller surfaces it as a warning entry on the
/// ASK response.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ModeWarning {
    /// Machine-readable category; drivers can branch on it.
    pub kind: ModeWarningKind,
    /// Explanation for humans, naming the provider token involved.
    pub detail: String,
}

/// Categories of [`ModeWarning`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModeWarningKind {
    /// Strict mode was requested, but the provider reports
    /// `supports_citations == false`, so the effective mode became
    /// [`Mode::Lenient`].
    ModeFallback,
}

155/// Result of consulting the registry for a strict-mode request.
156#[derive(Debug, Clone, PartialEq, Eq)]
157pub enum ModeOutcome {
158    /// Caller's requested mode is honored verbatim.
159    Allowed { effective: Mode },
160    /// Strict was downgraded to lenient. The caller MUST record the
161    /// `effective` mode (not the requested one) and surface
162    /// `warning`.
163    Fallback {
164        effective: Mode,
165        warning: ModeWarning,
166    },
167}
168
169impl ModeOutcome {
170    /// The mode the caller should actually run with.
171    pub fn effective(&self) -> Mode {
172        match self {
173            Self::Allowed { effective } | Self::Fallback { effective, .. } => *effective,
174        }
175    }
176
177    /// Convenience for the audit log / response builder.
178    pub fn warning(&self) -> Option<&ModeWarning> {
179        match self {
180            Self::Allowed { .. } => None,
181            Self::Fallback { warning, .. } => Some(warning),
182        }
183    }
184}
185
186/// Capability registry with optional per-deployment overrides.
187///
188/// Construct via [`Registry::new`] for built-ins only, then layer
189/// overrides with [`Registry::with_override`]. Lookups go through
190/// [`Registry::capabilities`] (raw row) and [`Registry::evaluate_mode`]
191/// (strict-fallback policy).
192#[derive(Debug, Clone, Default)]
193pub struct Registry {
194    overrides: HashMap<String, Capabilities>,
195}
196
197impl Registry {
198    /// Empty registry. Built-in defaults are still applied to every
199    /// lookup — this constructor just means "no per-deployment
200    /// overrides yet".
201    pub fn new() -> Self {
202        Self {
203            overrides: HashMap::new(),
204        }
205    }
206
207    /// Replace the capability row for `token` (lower-cased before
208    /// storage). Returns `self` for builder-style chaining in tests.
209    pub fn with_override(mut self, token: &str, caps: Capabilities) -> Self {
210        self.overrides.insert(token.to_ascii_lowercase(), caps);
211        self
212    }
213
214    /// Look up the capability row for a provider token, applying any
215    /// override on top of the built-in row.
216    pub fn capabilities(&self, token: &str) -> Capabilities {
217        let key = token.to_ascii_lowercase();
218        if let Some(c) = self.overrides.get(&key) {
219            return *c;
220        }
221        Capabilities::for_provider(&key)
222    }
223
224    /// Decide what mode the caller should actually run in, given the
225    /// requested mode and this provider's capabilities.
226    ///
227    /// Strict against a non-citing provider transparently degrades to
228    /// lenient with a `mode_fallback` warning. Lenient is always
229    /// allowed.
230    pub fn evaluate_mode(&self, token: &str, requested: Mode) -> ModeOutcome {
231        if requested == Mode::Lenient {
232            return ModeOutcome::Allowed {
233                effective: Mode::Lenient,
234            };
235        }
236        let caps = self.capabilities(token);
237        if caps.supports_citations {
238            return ModeOutcome::Allowed {
239                effective: Mode::Strict,
240            };
241        }
242        ModeOutcome::Fallback {
243            effective: Mode::Lenient,
244            warning: ModeWarning {
245                kind: ModeWarningKind::ModeFallback,
246                detail: format!(
247                    "provider '{}' does not support reliable citation emission; \
248                     strict mode downgraded to lenient",
249                    token.to_ascii_lowercase()
250                ),
251            },
252        }
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn conservative_defaults_match_ac() {
        let c = Capabilities::conservative();
        assert!(!c.supports_citations);
        assert!(!c.supports_seed);
        assert!(c.supports_temperature_zero);
        assert!(!c.supports_streaming);
    }

    #[test]
    fn openai_supports_everything() {
        let c = Capabilities::for_provider("openai");
        assert!(c.supports_citations);
        assert!(c.supports_seed);
        assert!(c.supports_temperature_zero);
        assert!(c.supports_streaming);
    }

    #[test]
    fn anthropic_no_seed() {
        let c = Capabilities::for_provider("anthropic");
        assert!(c.supports_citations);
        assert!(!c.supports_seed);
        assert!(c.supports_temperature_zero);
        assert!(c.supports_streaming);
    }

    #[test]
    fn openai_compatible_family_uniform() {
        for token in ["groq", "together", "openrouter", "venice", "deepseek"] {
            let c = Capabilities::for_provider(token);
            assert!(c.supports_citations, "{token} citations");
            assert!(c.supports_seed, "{token} seed");
            assert!(c.supports_temperature_zero, "{token} temp0");
            assert!(c.supports_streaming, "{token} streaming");
        }
    }

    #[test]
    fn ollama_no_citations_but_seed_and_streaming() {
        let c = Capabilities::for_provider("ollama");
        assert!(!c.supports_citations);
        assert!(c.supports_seed);
        assert!(c.supports_temperature_zero);
        assert!(c.supports_streaming);
    }

    #[test]
    fn huggingface_inference_no_seed_no_streaming() {
        let c = Capabilities::for_provider("huggingface");
        assert!(!c.supports_citations);
        assert!(!c.supports_seed);
        assert!(c.supports_temperature_zero);
        assert!(!c.supports_streaming);
    }

    #[test]
    fn local_backend_has_no_temperature() {
        let c = Capabilities::for_provider("local");
        assert!(!c.supports_citations);
        assert!(!c.supports_seed);
        assert!(!c.supports_temperature_zero);
        assert!(!c.supports_streaming);
    }

    #[test]
    fn custom_is_conservative() {
        assert_eq!(
            Capabilities::for_provider("custom"),
            Capabilities::conservative()
        );
    }

    #[test]
    fn unknown_token_is_conservative() {
        assert_eq!(
            Capabilities::for_provider("totally-made-up"),
            Capabilities::conservative()
        );
    }

    #[test]
    fn token_lookup_is_case_insensitive_via_registry() {
        let r = Registry::new();
        // `Registry::capabilities` lower-cases the token before consulting
        // the built-in match arms, so OPENAI / OpenAi / openai all resolve.
        assert_eq!(
            r.capabilities("OPENAI"),
            Capabilities::for_provider("openai")
        );
        assert_eq!(
            r.capabilities("OpenAi"),
            Capabilities::for_provider("openai")
        );
    }

    #[test]
    fn override_completely_replaces_builtin_row() {
        let overridden = Capabilities {
            supports_citations: false,
            supports_seed: false,
            supports_temperature_zero: false,
            supports_streaming: false,
        };
        let r = Registry::new().with_override("openai", overridden);
        assert_eq!(r.capabilities("openai"), overridden);
        // Unrelated providers are untouched.
        assert_eq!(r.capabilities("groq"), Capabilities::for_provider("groq"));
    }

    #[test]
    fn override_key_is_lowercased() {
        let custom_caps = Capabilities {
            supports_citations: true,
            supports_seed: true,
            supports_temperature_zero: true,
            supports_streaming: true,
        };
        let r = Registry::new().with_override("CUSTOM-INTERNAL", custom_caps);
        // Stored lower-cased so lookup with any case finds it.
        assert_eq!(r.capabilities("custom-internal"), custom_caps);
        assert_eq!(r.capabilities("Custom-Internal"), custom_caps);
    }

    #[test]
    fn lenient_always_allowed_regardless_of_provider() {
        let r = Registry::new();
        for token in ["openai", "huggingface", "local", "totally-made-up"] {
            let outcome = r.evaluate_mode(token, Mode::Lenient);
            assert_eq!(
                outcome,
                ModeOutcome::Allowed {
                    effective: Mode::Lenient
                },
                "lenient should pass through for {token}"
            );
            assert!(outcome.warning().is_none());
        }
    }

    #[test]
    fn strict_allowed_for_citing_provider() {
        let r = Registry::new();
        let outcome = r.evaluate_mode("openai", Mode::Strict);
        assert_eq!(
            outcome,
            ModeOutcome::Allowed {
                effective: Mode::Strict
            }
        );
        assert!(outcome.warning().is_none());
    }

    #[test]
    fn strict_downgraded_for_non_citing_provider() {
        let r = Registry::new();
        let outcome = r.evaluate_mode("huggingface", Mode::Strict);
        match outcome {
            ModeOutcome::Fallback {
                effective,
                ref warning,
            } => {
                assert_eq!(effective, Mode::Lenient);
                assert_eq!(warning.kind, ModeWarningKind::ModeFallback);
                assert!(warning.detail.contains("huggingface"));
                assert!(warning.detail.contains("strict"));
            }
            other => panic!("expected Fallback, got {other:?}"),
        }
        assert_eq!(outcome.effective(), Mode::Lenient);
        assert!(outcome.warning().is_some());
    }

    #[test]
    fn fallback_detail_uses_lowercased_token() {
        // The warning names the normalized token, not the caller's casing,
        // so drivers and audit rows see one canonical spelling.
        let r = Registry::new();
        let outcome = r.evaluate_mode("HuggingFace", Mode::Strict);
        let warning = outcome
            .warning()
            .expect("strict on huggingface must fall back");
        assert!(warning.detail.contains("'huggingface'"));
        assert!(!warning.detail.contains("HuggingFace"));
    }

    #[test]
    fn strict_downgraded_for_unknown_provider() {
        let r = Registry::new();
        let outcome = r.evaluate_mode("brand-new-provider", Mode::Strict);
        assert_eq!(outcome.effective(), Mode::Lenient);
        match outcome {
            ModeOutcome::Fallback { warning, .. } => {
                assert_eq!(warning.kind, ModeWarningKind::ModeFallback);
                assert!(warning.detail.contains("brand-new-provider"));
            }
            other => panic!("expected Fallback, got {other:?}"),
        }
    }

    #[test]
    fn override_can_upgrade_non_citing_provider_to_citing() {
        let r = Registry::new().with_override(
            "ollama",
            Capabilities {
                supports_citations: true,
                supports_seed: true,
                supports_temperature_zero: true,
                supports_streaming: true,
            },
        );
        let outcome = r.evaluate_mode("ollama", Mode::Strict);
        assert_eq!(
            outcome,
            ModeOutcome::Allowed {
                effective: Mode::Strict
            }
        );
    }

    #[test]
    fn override_can_downgrade_citing_provider_to_non_citing() {
        let r = Registry::new().with_override(
            "openai",
            Capabilities {
                supports_citations: false,
                supports_seed: false,
                supports_temperature_zero: true,
                supports_streaming: false,
            },
        );
        let outcome = r.evaluate_mode("openai", Mode::Strict);
        match outcome {
            ModeOutcome::Fallback {
                effective,
                ref warning,
            } => {
                assert_eq!(effective, Mode::Lenient);
                assert_eq!(warning.kind, ModeWarningKind::ModeFallback);
                assert!(warning.detail.contains("openai"));
            }
            other => panic!("expected Fallback, got {other:?}"),
        }
    }

    #[test]
    fn evaluate_mode_is_deterministic() {
        let r = Registry::new();
        for _ in 0..16 {
            assert_eq!(
                r.evaluate_mode("openai", Mode::Strict),
                ModeOutcome::Allowed {
                    effective: Mode::Strict
                }
            );
            assert_eq!(
                r.evaluate_mode("huggingface", Mode::Strict).effective(),
                Mode::Lenient
            );
        }
    }

    #[test]
    fn all_eleven_provider_tokens_have_explicit_rows() {
        // Pin the citation verdict for every built-in provider. Ten tokens
        // have their own match arm, and `custom` (the 11th) deliberately
        // maps to the conservative row. "Explicit" means "has its own arm",
        // not "differs from conservative": `huggingface`'s row happens to
        // equal `Capabilities::conservative()`. Pinning these makes
        // adding or removing a provider in `AiProvider` a deliberate
        // decision.
        let citing = [
            "openai",
            "anthropic",
            "groq",
            "together",
            "openrouter",
            "venice",
            "deepseek",
        ];
        let non_citing = ["ollama", "huggingface", "local"];
        for t in citing {
            assert!(
                Capabilities::for_provider(t).supports_citations,
                "{t} should cite"
            );
        }
        for t in non_citing {
            assert!(
                !Capabilities::for_provider(t).supports_citations,
                "{t} should not cite"
            );
        }
        // Custom is the 11th, and explicitly conservative.
        assert_eq!(
            Capabilities::for_provider("custom"),
            Capabilities::conservative()
        );
    }
}