use std::collections::HashMap;
use crate::runtime::ai::strict_validator::Mode;
/// Per-provider feature flags consulted when building requests and when
/// deciding which validation mode can be honoured.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Capabilities {
    /// Provider reliably emits citations; `Registry::evaluate_mode` requires
    /// this to grant strict mode.
    pub supports_citations: bool,
    /// Provider accepts a seed parameter.
    pub supports_seed: bool,
    /// Provider accepts `temperature = 0`.
    pub supports_temperature_zero: bool,
    /// Provider supports streamed responses.
    pub supports_streaming: bool,
}

impl Capabilities {
    /// Fallback row used for the `custom` token and for any unrecognised
    /// provider: only `temperature = 0` is assumed to work.
    pub const fn conservative() -> Self {
        Self {
            supports_citations: false,
            supports_seed: false,
            supports_temperature_zero: true,
            supports_streaming: false,
        }
    }

    /// Built-in capability row for a provider token.
    ///
    /// The token is matched after trimming surrounding whitespace and ASCII
    /// lowercasing, so `"OpenAI"` and `" openai "` resolve to the same row as
    /// `"openai"`. `"custom"` and unknown tokens get
    /// [`Capabilities::conservative`].
    pub fn for_provider(token: &str) -> Self {
        // Normalise here so callers that bypass `Registry` (which lowercases
        // its keys) don't silently fall back to the conservative row.
        match token.trim().to_ascii_lowercase().as_str() {
            "openai" => Self {
                supports_citations: true,
                supports_seed: true,
                supports_temperature_zero: true,
                supports_streaming: true,
            },
            "anthropic" => Self {
                supports_citations: true,
                supports_seed: false,
                supports_temperature_zero: true,
                supports_streaming: true,
            },
            // These hosted providers share a single row.
            "groq" | "together" | "openrouter" | "venice" | "deepseek" => Self {
                supports_citations: true,
                supports_seed: true,
                supports_temperature_zero: true,
                supports_streaming: true,
            },
            "ollama" => Self {
                supports_citations: false,
                supports_seed: true,
                supports_temperature_zero: true,
                supports_streaming: true,
            },
            "huggingface" => Self {
                supports_citations: false,
                supports_seed: false,
                supports_temperature_zero: true,
                supports_streaming: false,
            },
            // The local backend advertises nothing, not even temperature 0.
            "local" => Self {
                supports_citations: false,
                supports_seed: false,
                supports_temperature_zero: false,
                supports_streaming: false,
            },
            // `custom` is deliberately conservative, same as unknown tokens.
            "custom" => Self::conservative(),
            _ => Self::conservative(),
        }
    }
}
/// A kind of AI workload a provider can be asked to serve.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Modality {
    /// Embeddings (aliases: `embedding`, `embeddings`).
    Embed,
    /// Text generation (aliases: `generation`, `chat`, `completion`).
    Generate,
    /// Image input (aliases: `image`, `multimodal`).
    Vision,
    /// Content moderation (alias: `moderation`).
    Moderate,
}

impl Modality {
    /// Every modality, in declaration order.
    pub const ALL: [Self; 4] = [Self::Embed, Self::Generate, Self::Vision, Self::Moderate];

    /// Canonical lowercase token for this modality; [`Modality::parse`]
    /// maps it back to the same variant.
    pub fn token(self) -> &'static str {
        match self {
            Modality::Embed => "embed",
            Modality::Generate => "generate",
            Modality::Vision => "vision",
            Modality::Moderate => "moderate",
        }
    }

    /// Parses a modality name, ignoring surrounding whitespace and ASCII
    /// case. Accepts the canonical tokens plus common aliases; returns
    /// `None` for anything else.
    pub fn parse(token: &str) -> Option<Self> {
        let normalized = token.trim().to_ascii_lowercase();
        let parsed = match normalized.as_str() {
            "embed" | "embedding" | "embeddings" => Modality::Embed,
            "generate" | "generation" | "chat" | "completion" => Modality::Generate,
            "vision" | "image" | "multimodal" => Modality::Vision,
            "moderate" | "moderation" => Modality::Moderate,
            _ => return None,
        };
        Some(parsed)
    }
}
/// Which [`Modality`] values a provider can serve.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Modalities {
    pub embed: bool,
    pub generate: bool,
    pub vision: bool,
    pub moderate: bool,
}

impl Modalities {
    /// Fallback row for the `custom` token and unrecognised providers:
    /// embed and generate only; no vision or moderation.
    pub const fn conservative() -> Self {
        Self {
            embed: true,
            generate: true,
            vision: false,
            moderate: false,
        }
    }

    /// Returns the flag corresponding to `modality`.
    pub fn supports(&self, modality: Modality) -> bool {
        match modality {
            Modality::Embed => self.embed,
            Modality::Generate => self.generate,
            Modality::Vision => self.vision,
            Modality::Moderate => self.moderate,
        }
    }

    /// Built-in modality row for a provider token.
    ///
    /// The token is matched after trimming surrounding whitespace and ASCII
    /// lowercasing, so `"OpenAI"` resolves to the same row as `"openai"`.
    /// `"custom"` and unknown tokens get [`Modalities::conservative`].
    pub fn for_provider(token: &str) -> Self {
        // Normalise here so callers that bypass `Registry` (which lowercases
        // its keys) don't silently fall back to the conservative row.
        match token.trim().to_ascii_lowercase().as_str() {
            "openai" => Self {
                embed: true,
                generate: true,
                vision: true,
                moderate: true,
            },
            "anthropic" => Self {
                embed: false,
                generate: true,
                vision: true,
                moderate: false,
            },
            "minimax" | "together" | "ollama" => Self {
                embed: true,
                generate: true,
                vision: true,
                moderate: false,
            },
            "groq" | "openrouter" | "venice" => Self {
                embed: false,
                generate: true,
                vision: true,
                moderate: false,
            },
            "deepseek" => Self {
                embed: false,
                generate: true,
                vision: false,
                moderate: false,
            },
            "huggingface" => Self {
                embed: true,
                generate: true,
                vision: false,
                moderate: false,
            },
            // The local backend serves everything except generation.
            "local" => Self {
                embed: true,
                generate: false,
                vision: true,
                moderate: true,
            },
            // `custom` is deliberately conservative, same as unknown tokens.
            "custom" => Self::conservative(),
            _ => Self::conservative(),
        }
    }
}
/// Error reported when an AI policy pairs a provider with a modality that
/// provider cannot serve.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModalityValidationError {
    /// Provider token as passed to the validator.
    pub provider: String,
    /// Model name from the policy.
    pub model: String,
    /// The modality the provider cannot serve.
    pub modality: Modality,
    /// Rendered, human-readable explanation; this is exactly what `Display` prints.
    pub message: String,
}

impl std::fmt::Display for ModalityValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.message)
    }
}

impl std::error::Error for ModalityValidationError {}
/// Category of a [`ModeWarning`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModeWarningKind {
    /// The requested validation mode was replaced by a weaker one.
    ModeFallback,
}

/// Non-fatal diagnostic produced while resolving the effective validation mode.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModeWarning {
    /// What kind of warning this is.
    pub kind: ModeWarningKind,
    /// Human-readable explanation of the warning.
    pub detail: String,
}
/// Result of resolving a requested validation mode against a provider's
/// capabilities.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ModeOutcome {
    /// The mode was granted; `effective` is the mode that will apply.
    Allowed { effective: Mode },
    /// The request was replaced by `effective`; `warning` explains why.
    Fallback {
        effective: Mode,
        warning: ModeWarning,
    },
}

impl ModeOutcome {
    /// The mode that will actually be applied, whichever variant this is.
    pub fn effective(&self) -> Mode {
        match self {
            Self::Allowed { effective } => *effective,
            Self::Fallback { effective, .. } => *effective,
        }
    }

    /// The fallback warning, if the requested mode was not granted as-is.
    pub fn warning(&self) -> Option<&ModeWarning> {
        if let Self::Fallback { warning, .. } = self {
            Some(warning)
        } else {
            None
        }
    }
}
/// Provider capability and modality lookup with optional per-provider overrides.
///
/// Provider tokens are normalised (surrounding whitespace trimmed, ASCII
/// lowercased) both when overrides are registered and when rows are looked
/// up, so `"OpenAI"`, `" openai "` and `"openai"` address the same row. An
/// override replaces the built-in row for that provider entirely.
#[derive(Debug, Clone, Default)]
pub struct Registry {
    /// Rows that replace the built-in [`Capabilities::for_provider`] result.
    overrides: HashMap<String, Capabilities>,
    /// Rows that replace the built-in [`Modalities::for_provider`] result.
    modality_overrides: HashMap<String, Modalities>,
}

impl Registry {
    /// Empty registry: every lookup falls through to the built-in tables.
    pub fn new() -> Self {
        Self::default()
    }

    /// Canonical map key for a provider token (trimmed, ASCII-lowercased).
    fn normalize_token(token: &str) -> String {
        token.trim().to_ascii_lowercase()
    }

    /// Registers a capability row that fully replaces the built-in row for
    /// `token`. Consumes and returns `self` for chaining.
    #[must_use]
    pub fn with_override(mut self, token: &str, caps: Capabilities) -> Self {
        self.overrides.insert(Self::normalize_token(token), caps);
        self
    }

    /// Effective capabilities for `token`: the override if one is
    /// registered, otherwise the built-in row.
    pub fn capabilities(&self, token: &str) -> Capabilities {
        let key = Self::normalize_token(token);
        self.overrides
            .get(&key)
            .copied()
            .unwrap_or_else(|| Capabilities::for_provider(&key))
    }

    /// Registers a modality row that fully replaces the built-in row for
    /// `token`. Consumes and returns `self` for chaining.
    #[must_use]
    pub fn with_modality_override(mut self, token: &str, modalities: Modalities) -> Self {
        self.modality_overrides
            .insert(Self::normalize_token(token), modalities);
        self
    }

    /// Effective modalities for `token`: the override if one is registered,
    /// otherwise the built-in row.
    pub fn modalities(&self, token: &str) -> Modalities {
        let key = Self::normalize_token(token);
        self.modality_overrides
            .get(&key)
            .copied()
            .unwrap_or_else(|| Modalities::for_provider(&key))
    }

    /// Whether provider `token` can serve `modality`. The model name is
    /// currently ignored: support is decided per provider only.
    pub fn can_serve(&self, token: &str, _model: &str, modality: Modality) -> bool {
        self.modalities(token).supports(modality)
    }

    /// Checks that a policy's provider can serve the requested modality.
    ///
    /// # Errors
    ///
    /// Returns [`ModalityValidationError`] when [`Registry::can_serve`] is
    /// false. The error's `provider` and `model` fields hold the arguments
    /// as passed; the message names the normalised provider token.
    pub fn validate_policy_modality(
        &self,
        provider: &str,
        model: &str,
        modality: Modality,
    ) -> Result<(), ModalityValidationError> {
        if self.can_serve(provider, model, modality) {
            return Ok(());
        }
        Err(ModalityValidationError {
            provider: provider.to_string(),
            model: model.to_string(),
            modality,
            message: format!(
                "AI policy is invalid: provider '{}' (model '{}') cannot serve the '{}' modality; \
                 declare a provider that supports it or register a modality override",
                Self::normalize_token(provider),
                model,
                modality.token()
            ),
        })
    }

    /// Resolves the validation mode actually used for provider `token`.
    ///
    /// `Mode::Lenient` is always honoured. Any other request is granted as
    /// `Mode::Strict` only when the provider's effective capabilities report
    /// `supports_citations`. Otherwise it falls back to `Mode::Lenient` with a
    /// `ModeWarningKind::ModeFallback` warning naming the provider.
    pub fn evaluate_mode(&self, token: &str, requested: Mode) -> ModeOutcome {
        if requested == Mode::Lenient {
            return ModeOutcome::Allowed {
                effective: Mode::Lenient,
            };
        }
        if self.capabilities(token).supports_citations {
            return ModeOutcome::Allowed {
                effective: Mode::Strict,
            };
        }
        ModeOutcome::Fallback {
            effective: Mode::Lenient,
            warning: ModeWarning {
                kind: ModeWarningKind::ModeFallback,
                detail: format!(
                    "provider '{}' does not support reliable citation emission; \
                     strict mode downgraded to lenient",
                    Self::normalize_token(token)
                ),
            },
        }
    }
}
// Unit tests pinning the built-in capability/modality tables and the
// Registry's override, case-insensitivity, validation and mode-resolution
// behaviour.
#[cfg(test)]
mod tests {
use super::*;
// --- Built-in capability rows -------------------------------------------
#[test]
fn conservative_defaults_match_ac() {
let c = Capabilities::conservative();
assert!(!c.supports_citations);
assert!(!c.supports_seed);
assert!(c.supports_temperature_zero);
assert!(!c.supports_streaming);
}
#[test]
fn openai_supports_everything() {
let c = Capabilities::for_provider("openai");
assert!(c.supports_citations);
assert!(c.supports_seed);
assert!(c.supports_temperature_zero);
assert!(c.supports_streaming);
}
#[test]
fn anthropic_no_seed() {
let c = Capabilities::for_provider("anthropic");
assert!(c.supports_citations);
assert!(!c.supports_seed);
assert!(c.supports_temperature_zero);
assert!(c.supports_streaming);
}
// Every provider in the shared multi-token arm must get the identical row.
#[test]
fn openai_compatible_family_uniform() {
for token in ["groq", "together", "openrouter", "venice", "deepseek"] {
let c = Capabilities::for_provider(token);
assert!(c.supports_citations, "{token} citations");
assert!(c.supports_seed, "{token} seed");
assert!(c.supports_temperature_zero, "{token} temp0");
assert!(c.supports_streaming, "{token} streaming");
}
}
#[test]
fn ollama_no_citations_but_seed_and_streaming() {
let c = Capabilities::for_provider("ollama");
assert!(!c.supports_citations);
assert!(c.supports_seed);
assert!(c.supports_temperature_zero);
assert!(c.supports_streaming);
}
#[test]
fn huggingface_inference_no_seed_no_streaming() {
let c = Capabilities::for_provider("huggingface");
assert!(!c.supports_citations);
assert!(!c.supports_seed);
assert!(c.supports_temperature_zero);
assert!(!c.supports_streaming);
}
// `local` is the only row stricter than the conservative default.
#[test]
fn local_backend_has_no_temperature() {
let c = Capabilities::for_provider("local");
assert!(!c.supports_citations);
assert!(!c.supports_seed);
assert!(!c.supports_temperature_zero);
assert!(!c.supports_streaming);
}
#[test]
fn custom_is_conservative() {
assert_eq!(
Capabilities::for_provider("custom"),
Capabilities::conservative()
);
}
#[test]
fn unknown_token_is_conservative() {
assert_eq!(
Capabilities::for_provider("totally-made-up"),
Capabilities::conservative()
);
}
// --- Registry capability lookups and overrides --------------------------
#[test]
fn token_lookup_is_case_insensitive_via_registry() {
let r = Registry::new();
assert_eq!(
r.capabilities("OPENAI"),
Capabilities::for_provider("openai")
);
assert_eq!(
r.capabilities("OpenAi"),
Capabilities::for_provider("openai")
);
}
#[test]
fn override_completely_replaces_builtin_row() {
let overridden = Capabilities {
supports_citations: false,
supports_seed: false,
supports_temperature_zero: false,
supports_streaming: false,
};
let r = Registry::new().with_override("openai", overridden);
assert_eq!(r.capabilities("openai"), overridden);
// Providers without an override must still resolve to their built-in row.
assert_eq!(r.capabilities("groq"), Capabilities::for_provider("groq"));
}
#[test]
fn override_key_is_lowercased() {
let custom_caps = Capabilities {
supports_citations: true,
supports_seed: true,
supports_temperature_zero: true,
supports_streaming: true,
};
let r = Registry::new().with_override("CUSTOM-INTERNAL", custom_caps);
assert_eq!(r.capabilities("custom-internal"), custom_caps);
assert_eq!(r.capabilities("Custom-Internal"), custom_caps);
}
// --- Strict/lenient mode resolution --------------------------------------
#[test]
fn lenient_always_allowed_regardless_of_provider() {
let r = Registry::new();
for token in ["openai", "huggingface", "local", "totally-made-up"] {
let outcome = r.evaluate_mode(token, Mode::Lenient);
assert_eq!(
outcome,
ModeOutcome::Allowed {
effective: Mode::Lenient
},
"lenient should pass through for {token}"
);
assert!(outcome.warning().is_none());
}
}
#[test]
fn strict_allowed_for_citing_provider() {
let r = Registry::new();
let outcome = r.evaluate_mode("openai", Mode::Strict);
assert_eq!(
outcome,
ModeOutcome::Allowed {
effective: Mode::Strict
}
);
assert!(outcome.warning().is_none());
}
// The fallback warning must name the provider and mention the downgrade.
#[test]
fn strict_downgraded_for_non_citing_provider() {
let r = Registry::new();
let outcome = r.evaluate_mode("huggingface", Mode::Strict);
match outcome {
ModeOutcome::Fallback {
effective,
ref warning,
} => {
assert_eq!(effective, Mode::Lenient);
assert_eq!(warning.kind, ModeWarningKind::ModeFallback);
assert!(warning.detail.contains("huggingface"));
assert!(warning.detail.contains("strict"));
}
other => panic!("expected Fallback, got {other:?}"),
}
assert_eq!(outcome.effective(), Mode::Lenient);
assert!(outcome.warning().is_some());
}
#[test]
fn strict_downgraded_for_unknown_provider() {
let r = Registry::new();
let outcome = r.evaluate_mode("brand-new-provider", Mode::Strict);
assert_eq!(outcome.effective(), Mode::Lenient);
match outcome {
ModeOutcome::Fallback { warning, .. } => {
assert_eq!(warning.kind, ModeWarningKind::ModeFallback);
assert!(warning.detail.contains("brand-new-provider"));
}
other => panic!("expected Fallback, got {other:?}"),
}
}
// Overrides feed into mode resolution in both directions.
#[test]
fn override_can_upgrade_non_citing_provider_to_citing() {
let r = Registry::new().with_override(
"ollama",
Capabilities {
supports_citations: true,
supports_seed: true,
supports_temperature_zero: true,
supports_streaming: true,
},
);
let outcome = r.evaluate_mode("ollama", Mode::Strict);
assert_eq!(
outcome,
ModeOutcome::Allowed {
effective: Mode::Strict
}
);
}
#[test]
fn override_can_downgrade_citing_provider_to_non_citing() {
let r = Registry::new().with_override(
"openai",
Capabilities {
supports_citations: false,
supports_seed: false,
supports_temperature_zero: true,
supports_streaming: false,
},
);
let outcome = r.evaluate_mode("openai", Mode::Strict);
match outcome {
ModeOutcome::Fallback {
effective,
ref warning,
} => {
assert_eq!(effective, Mode::Lenient);
assert_eq!(warning.kind, ModeWarningKind::ModeFallback);
assert!(warning.detail.contains("openai"));
}
other => panic!("expected Fallback, got {other:?}"),
}
}
// Repeated calls must give the same answer (no hidden state).
#[test]
fn evaluate_mode_is_deterministic() {
let r = Registry::new();
for _ in 0..16 {
assert_eq!(
r.evaluate_mode("openai", Mode::Strict),
ModeOutcome::Allowed {
effective: Mode::Strict
}
);
assert_eq!(
r.evaluate_mode("huggingface", Mode::Strict).effective(),
Mode::Lenient
);
}
}
// Guard for the whole capability table: ten named providers split into
// citing/non-citing, plus `custom` mapping to the conservative row.
#[test]
fn all_eleven_provider_tokens_have_explicit_rows() {
let citing = [
"openai",
"anthropic",
"groq",
"together",
"openrouter",
"venice",
"deepseek",
];
let non_citing = ["ollama", "huggingface", "local"];
for t in citing {
assert!(
Capabilities::for_provider(t).supports_citations,
"{t} should cite"
);
}
for t in non_citing {
assert!(
!Capabilities::for_provider(t).supports_citations,
"{t} should not cite"
);
}
assert_eq!(
Capabilities::for_provider("custom"),
Capabilities::conservative()
);
}
// --- Modality parsing and built-in modality rows -------------------------
#[test]
fn modality_token_roundtrips_through_parse() {
for m in Modality::ALL {
assert_eq!(Modality::parse(m.token()), Some(m), "{m:?}");
}
assert_eq!(Modality::parse("EMBEDDING"), Some(Modality::Embed));
assert_eq!(Modality::parse("chat"), Some(Modality::Generate));
assert_eq!(Modality::parse("image"), Some(Modality::Vision));
assert_eq!(Modality::parse("moderation"), Some(Modality::Moderate));
assert_eq!(Modality::parse("nonsense"), None);
}
#[test]
fn unknown_provider_gets_conservative_modalities() {
let c = Modalities::for_provider("totally-made-up");
assert_eq!(c, Modalities::conservative());
assert!(c.embed);
assert!(c.generate);
assert!(!c.vision);
assert!(!c.moderate);
assert_eq!(
Modalities::for_provider("custom"),
Modalities::conservative()
);
}
#[test]
fn openai_serves_every_modality() {
let c = Modalities::for_provider("openai");
for m in Modality::ALL {
assert!(c.supports(m), "openai should serve {m:?}");
}
}
#[test]
fn minimax_serves_embed_generate_vision_not_moderate() {
let c = Modalities::for_provider("minimax");
assert!(c.supports(Modality::Embed));
assert!(c.supports(Modality::Generate));
assert!(c.supports(Modality::Vision));
assert!(!c.supports(Modality::Moderate));
}
#[test]
fn anthropic_cannot_embed() {
assert!(!Modalities::for_provider("anthropic").supports(Modality::Embed));
assert!(Modalities::for_provider("anthropic").supports(Modality::Generate));
}
#[test]
fn local_serves_embed_vision_and_moderate() {
let c = Modalities::for_provider("local");
assert!(c.supports(Modality::Embed));
assert!(c.supports(Modality::Vision));
assert!(c.supports(Modality::Moderate));
assert!(!c.supports(Modality::Generate));
}
#[test]
fn deepseek_is_generate_only() {
let c = Modalities::for_provider("deepseek");
assert!(c.supports(Modality::Generate));
assert!(!c.supports(Modality::Embed));
assert!(!c.supports(Modality::Vision));
assert!(!c.supports(Modality::Moderate));
}
// --- Registry modality lookups, policy validation and overrides ----------
#[test]
fn can_serve_is_case_insensitive_and_deterministic() {
let r = Registry::new();
for _ in 0..8 {
assert!(r.can_serve("OpenAI", "gpt-4o", Modality::Vision));
assert!(!r.can_serve("LOCAL", "all-MiniLM", Modality::Generate));
}
}
// The error message must name the provider, the modality and the model.
#[test]
fn validate_rejects_incapable_provider_modality() {
let r = Registry::new();
let err = r
.validate_policy_modality("local", "all-MiniLM-L6-v2", Modality::Generate)
.expect_err("local cannot generate");
assert_eq!(err.provider, "local");
assert_eq!(err.modality, Modality::Generate);
let msg = err.to_string();
assert!(msg.contains("local"), "{msg}");
assert!(msg.contains("generate"), "{msg}");
assert!(msg.contains("all-MiniLM-L6-v2"), "{msg}");
}
#[test]
fn validate_accepts_capable_provider_modality() {
let r = Registry::new();
assert!(r
.validate_policy_modality("openai", "text-embedding-3-small", Modality::Embed)
.is_ok());
assert!(r
.validate_policy_modality("minimax", "abab6.5s-chat", Modality::Vision)
.is_ok());
}
#[test]
fn modality_override_completely_replaces_builtin_row() {
let upgraded = Modalities {
embed: true,
generate: true,
vision: false,
moderate: false,
};
let r = Registry::new().with_modality_override("deepseek", upgraded);
assert_eq!(r.modalities("deepseek"), upgraded);
assert!(r
.validate_policy_modality("deepseek", "deepseek-embed", Modality::Embed)
.is_ok());
// Providers without an override keep their built-in modality row.
assert_eq!(r.modalities("openai"), Modalities::for_provider("openai"));
}
// Mixed-case override key must still be found via the lowercase token.
#[test]
fn modality_override_can_revoke_a_builtin_capability() {
let restricted = Modalities {
embed: true,
generate: true,
vision: false,
moderate: false,
};
let r = Registry::new().with_modality_override("OpenAI", restricted);
assert_eq!(r.modalities("openai"), restricted);
let err = r
.validate_policy_modality("openai", "gpt-4o", Modality::Vision)
.expect_err("vision revoked by override");
assert_eq!(err.modality, Modality::Vision);
}
}