use std::collections::HashMap;
use crate::runtime::ai::strict_validator::Mode;
/// Feature flags recording what a model provider is treated as supporting.
///
/// Built-in rows come from [`Capabilities::for_provider`]; a [`Registry`]
/// may shadow a row entirely with an override.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Capabilities {
    /// Whether the provider is trusted to emit citations. Strict mode is only
    /// honoured by [`Registry::evaluate_mode`] when this is `true`.
    pub supports_citations: bool,
    /// Whether seeding is treated as supported for this provider.
    pub supports_seed: bool,
    /// Whether a temperature of zero is treated as supported.
    pub supports_temperature_zero: bool,
    /// Whether streamed responses are treated as supported.
    pub supports_streaming: bool,
}
impl Capabilities {
    /// Fallback row used for `custom` and for any unrecognised provider token.
    ///
    /// Only temperature zero is assumed to work; every other flag is off.
    pub const fn conservative() -> Self {
        Self {
            supports_temperature_zero: true,
            supports_citations: false,
            supports_seed: false,
            supports_streaming: false,
        }
    }

    /// Returns the built-in capability row for a provider `token`.
    ///
    /// Matching is exact and case-sensitive; [`Registry::capabilities`]
    /// lowercases tokens before delegating here. `custom` and any token not
    /// listed below resolve to [`Capabilities::conservative`].
    pub fn for_provider(token: &str) -> Self {
        // Columns: (citations, seed, temperature_zero, streaming)
        let (citations, seed, temperature_zero, streaming) = match token {
            "openai" => (true, true, true, true),
            "anthropic" => (true, false, true, true),
            // OpenAI-compatible family: one shared row.
            "groq" | "together" | "openrouter" | "venice" | "deepseek" => {
                (true, true, true, true)
            }
            "ollama" => (false, true, true, true),
            "huggingface" => (false, false, true, false),
            "local" => (false, false, false, false),
            // `custom` is named explicitly so its conservative row is a
            // deliberate choice rather than an accident of the wildcard.
            "custom" => return Self::conservative(),
            _ => return Self::conservative(),
        };
        Self {
            supports_citations: citations,
            supports_seed: seed,
            supports_temperature_zero: temperature_zero,
            supports_streaming: streaming,
        }
    }
}
/// Diagnostic carried by a [`ModeOutcome::Fallback`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModeWarning {
    /// Category of the warning.
    pub kind: ModeWarningKind,
    /// Human-readable explanation. Fallbacks produced by
    /// [`Registry::evaluate_mode`] name the lowercased provider token.
    pub detail: String,
}
/// Category of a [`ModeWarning`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModeWarningKind {
    /// The requested mode could not be honoured for the provider, so a
    /// fallback mode was used instead.
    ModeFallback,
}
/// Result of [`Registry::evaluate_mode`]: the mode to run with, plus a
/// warning when the request had to be replaced.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ModeOutcome {
    /// The request is honoured.
    Allowed {
        /// Mode to run with.
        effective: Mode,
    },
    /// The request could not be honoured and a fallback mode is used.
    Fallback {
        /// Mode to run with instead of the requested one.
        effective: Mode,
        /// Why the fallback happened.
        warning: ModeWarning,
    },
}
impl ModeOutcome {
    /// The mode that should actually be used, whichever variant this is.
    pub fn effective(&self) -> Mode {
        match *self {
            Self::Allowed { effective } => effective,
            Self::Fallback { effective, .. } => effective,
        }
    }

    /// The fallback warning, or `None` when the request was allowed as-is.
    pub fn warning(&self) -> Option<&ModeWarning> {
        if let Self::Fallback { warning, .. } = self {
            Some(warning)
        } else {
            None
        }
    }
}
/// Provider capability lookup with optional per-token overrides.
///
/// Overrides are keyed by the ASCII-lowercased token and replace the
/// built-in [`Capabilities::for_provider`] row entirely when present.
#[derive(Debug, Clone, Default)]
pub struct Registry {
    /// Lowercased provider token -> capability row shadowing the built-in one.
    overrides: HashMap<String, Capabilities>,
}
impl Registry {
pub fn new() -> Self {
Self {
overrides: HashMap::new(),
}
}
pub fn with_override(mut self, token: &str, caps: Capabilities) -> Self {
self.overrides.insert(token.to_ascii_lowercase(), caps);
self
}
pub fn capabilities(&self, token: &str) -> Capabilities {
let key = token.to_ascii_lowercase();
if let Some(c) = self.overrides.get(&key) {
return *c;
}
Capabilities::for_provider(&key)
}
pub fn evaluate_mode(&self, token: &str, requested: Mode) -> ModeOutcome {
if requested == Mode::Lenient {
return ModeOutcome::Allowed {
effective: Mode::Lenient,
};
}
let caps = self.capabilities(token);
if caps.supports_citations {
return ModeOutcome::Allowed {
effective: Mode::Strict,
};
}
ModeOutcome::Fallback {
effective: Mode::Lenient,
warning: ModeWarning {
kind: ModeWarningKind::ModeFallback,
detail: format!(
"provider '{}' does not support reliable citation emission; \
strict mode downgraded to lenient",
token.to_ascii_lowercase()
),
},
}
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for the provider capability table, the override registry,
    //! and strict/lenient mode evaluation.
    use super::*;

    #[test]
    fn conservative_defaults_match_ac() {
        let c = Capabilities::conservative();
        assert!(!c.supports_citations);
        assert!(!c.supports_seed);
        assert!(c.supports_temperature_zero);
        assert!(!c.supports_streaming);
    }

    #[test]
    fn openai_supports_everything() {
        let c = Capabilities::for_provider("openai");
        assert!(c.supports_citations);
        assert!(c.supports_seed);
        assert!(c.supports_temperature_zero);
        assert!(c.supports_streaming);
    }

    #[test]
    fn anthropic_no_seed() {
        let c = Capabilities::for_provider("anthropic");
        assert!(c.supports_citations);
        assert!(!c.supports_seed);
        assert!(c.supports_temperature_zero);
        assert!(c.supports_streaming);
    }

    #[test]
    fn openai_compatible_family_uniform() {
        for token in ["groq", "together", "openrouter", "venice", "deepseek"] {
            let c = Capabilities::for_provider(token);
            assert!(c.supports_citations, "{token} citations");
            assert!(c.supports_seed, "{token} seed");
            assert!(c.supports_temperature_zero, "{token} temp0");
            assert!(c.supports_streaming, "{token} streaming");
        }
    }

    #[test]
    fn ollama_no_citations_but_seed_and_streaming() {
        let c = Capabilities::for_provider("ollama");
        assert!(!c.supports_citations);
        assert!(c.supports_seed);
        assert!(c.supports_temperature_zero);
        assert!(c.supports_streaming);
    }

    #[test]
    fn huggingface_inference_no_seed_no_streaming() {
        let c = Capabilities::for_provider("huggingface");
        assert!(!c.supports_citations);
        assert!(!c.supports_seed);
        assert!(c.supports_temperature_zero);
        assert!(!c.supports_streaming);
    }

    #[test]
    fn local_backend_has_no_temperature() {
        let c = Capabilities::for_provider("local");
        assert!(!c.supports_citations);
        assert!(!c.supports_seed);
        assert!(!c.supports_temperature_zero);
        assert!(!c.supports_streaming);
    }

    #[test]
    fn custom_is_conservative() {
        assert_eq!(
            Capabilities::for_provider("custom"),
            Capabilities::conservative()
        );
    }

    #[test]
    fn unknown_token_is_conservative() {
        assert_eq!(
            Capabilities::for_provider("totally-made-up"),
            Capabilities::conservative()
        );
    }

    #[test]
    fn token_lookup_is_case_insensitive_via_registry() {
        let r = Registry::new();
        assert_eq!(r.capabilities("OPENAI"), Capabilities::for_provider("openai"));
        assert_eq!(r.capabilities("OpenAi"), Capabilities::for_provider("openai"));
    }

    #[test]
    fn new_and_default_registries_agree() {
        let (a, b) = (Registry::new(), Registry::default());
        for token in ["openai", "ollama", "custom", "totally-made-up"] {
            assert_eq!(a.capabilities(token), b.capabilities(token), "{token}");
        }
    }

    #[test]
    fn override_completely_replaces_builtin_row() {
        let overridden = Capabilities {
            supports_citations: false,
            supports_seed: false,
            supports_temperature_zero: false,
            supports_streaming: false,
        };
        let r = Registry::new().with_override("openai", overridden);
        assert_eq!(r.capabilities("openai"), overridden);
        assert_eq!(r.capabilities("groq"), Capabilities::for_provider("groq"));
    }

    #[test]
    fn override_key_is_lowercased() {
        let custom_caps = Capabilities {
            supports_citations: true,
            supports_seed: true,
            supports_temperature_zero: true,
            supports_streaming: true,
        };
        let r = Registry::new().with_override("CUSTOM-INTERNAL", custom_caps);
        assert_eq!(r.capabilities("custom-internal"), custom_caps);
        assert_eq!(r.capabilities("Custom-Internal"), custom_caps);
    }

    #[test]
    fn lenient_always_allowed_regardless_of_provider() {
        let r = Registry::new();
        for token in ["openai", "huggingface", "local", "totally-made-up"] {
            let outcome = r.evaluate_mode(token, Mode::Lenient);
            assert_eq!(
                outcome,
                ModeOutcome::Allowed {
                    effective: Mode::Lenient
                },
                "lenient should pass through for {token}"
            );
            assert!(outcome.warning().is_none());
        }
    }

    #[test]
    fn strict_allowed_for_citing_provider() {
        let r = Registry::new();
        let outcome = r.evaluate_mode("openai", Mode::Strict);
        assert_eq!(
            outcome,
            ModeOutcome::Allowed {
                effective: Mode::Strict
            }
        );
        assert!(outcome.warning().is_none());
    }

    #[test]
    fn evaluate_mode_is_case_insensitive() {
        let r = Registry::new();
        assert_eq!(
            r.evaluate_mode("OpenAI", Mode::Strict),
            ModeOutcome::Allowed {
                effective: Mode::Strict
            }
        );
        assert_eq!(r.evaluate_mode("OLLAMA", Mode::Strict).effective(), Mode::Lenient);
    }

    #[test]
    fn strict_downgraded_for_non_citing_provider() {
        let r = Registry::new();
        let outcome = r.evaluate_mode("huggingface", Mode::Strict);
        match outcome {
            ModeOutcome::Fallback {
                effective,
                ref warning,
            } => {
                assert_eq!(effective, Mode::Lenient);
                assert_eq!(warning.kind, ModeWarningKind::ModeFallback);
                assert!(warning.detail.contains("huggingface"));
                assert!(warning.detail.contains("strict"));
            }
            other => panic!("expected Fallback, got {other:?}"),
        }
        assert_eq!(outcome.effective(), Mode::Lenient);
        assert!(outcome.warning().is_some());
    }

    #[test]
    fn fallback_detail_is_exact_and_uses_lowercased_token() {
        let outcome = Registry::new().evaluate_mode("LoCaL", Mode::Strict);
        let warning = outcome
            .warning()
            .expect("local cannot cite, so strict must fall back");
        assert_eq!(
            warning.detail,
            "provider 'local' does not support reliable citation emission; \
             strict mode downgraded to lenient"
        );
        assert!(!warning.detail.contains("LoCaL"));
    }

    #[test]
    fn strict_downgraded_for_unknown_provider() {
        let r = Registry::new();
        let outcome = r.evaluate_mode("brand-new-provider", Mode::Strict);
        assert_eq!(outcome.effective(), Mode::Lenient);
        match outcome {
            ModeOutcome::Fallback { warning, .. } => {
                assert_eq!(warning.kind, ModeWarningKind::ModeFallback);
                assert!(warning.detail.contains("brand-new-provider"));
            }
            other => panic!("expected Fallback, got {other:?}"),
        }
    }

    #[test]
    fn override_on_unknown_token_enables_strict() {
        let r = Registry::new().with_override(
            "acme-llm",
            Capabilities {
                supports_citations: true,
                ..Capabilities::conservative()
            },
        );
        assert_eq!(
            r.evaluate_mode("ACME-LLM", Mode::Strict),
            ModeOutcome::Allowed {
                effective: Mode::Strict
            }
        );
        // Without the override the same token still falls back.
        assert_eq!(
            Registry::new().evaluate_mode("acme-llm", Mode::Strict).effective(),
            Mode::Lenient
        );
    }

    #[test]
    fn override_can_upgrade_non_citing_provider_to_citing() {
        let r = Registry::new().with_override(
            "ollama",
            Capabilities {
                supports_citations: true,
                supports_seed: true,
                supports_temperature_zero: true,
                supports_streaming: true,
            },
        );
        let outcome = r.evaluate_mode("ollama", Mode::Strict);
        assert_eq!(
            outcome,
            ModeOutcome::Allowed {
                effective: Mode::Strict
            }
        );
    }

    #[test]
    fn override_can_downgrade_citing_provider_to_non_citing() {
        let r = Registry::new().with_override(
            "openai",
            Capabilities {
                supports_citations: false,
                supports_seed: false,
                supports_temperature_zero: true,
                supports_streaming: false,
            },
        );
        let outcome = r.evaluate_mode("openai", Mode::Strict);
        match outcome {
            ModeOutcome::Fallback {
                effective,
                ref warning,
            } => {
                assert_eq!(effective, Mode::Lenient);
                assert_eq!(warning.kind, ModeWarningKind::ModeFallback);
                assert!(warning.detail.contains("openai"));
            }
            other => panic!("expected Fallback, got {other:?}"),
        }
    }

    #[test]
    fn evaluate_mode_is_deterministic() {
        let r = Registry::new();
        for _ in 0..16 {
            assert_eq!(
                r.evaluate_mode("openai", Mode::Strict),
                ModeOutcome::Allowed {
                    effective: Mode::Strict
                }
            );
            assert_eq!(
                r.evaluate_mode("huggingface", Mode::Strict).effective(),
                Mode::Lenient
            );
        }
    }

    #[test]
    fn all_eleven_provider_tokens_have_explicit_rows() {
        let citing = [
            "openai",
            "anthropic",
            "groq",
            "together",
            "openrouter",
            "venice",
            "deepseek",
        ];
        let non_citing = ["ollama", "huggingface", "local"];
        for t in citing {
            assert!(
                Capabilities::for_provider(t).supports_citations,
                "{t} should cite"
            );
        }
        for t in non_citing {
            assert!(
                !Capabilities::for_provider(t).supports_citations,
                "{t} should not cite"
            );
        }
        assert_eq!(
            Capabilities::for_provider("custom"),
            Capabilities::conservative()
        );
    }
}