1use std::collections::HashMap;
58
59use crate::runtime::ai::strict_validator::Mode;
60
/// Feature flags describing what a provider's API is treated as supporting.
///
/// Built-in rows come from [`Capabilities::for_provider`]; anything it does not
/// recognise gets [`Capabilities::conservative`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Capabilities {
    /// Whether the provider is relied on to emit citations. `Registry::evaluate_mode`
    /// only keeps a non-lenient mode when this is `true`.
    pub supports_citations: bool,
    /// Whether the provider is treated as supporting a sampling seed.
    pub supports_seed: bool,
    /// Whether the provider is treated as accepting a temperature of zero.
    pub supports_temperature_zero: bool,
    /// Whether the provider is treated as supporting streamed responses.
    pub supports_streaming: bool,
}

impl Capabilities {
    /// Fallback row for `"custom"` and every unrecognised token.
    ///
    /// Citations, seed and streaming are all off. Temperature zero is the only
    /// capability assumed to be available.
    pub const fn conservative() -> Self {
        Self {
            supports_citations: false,
            supports_seed: false,
            supports_temperature_zero: true,
            supports_streaming: false,
        }
    }

    /// Returns the built-in capability row for a provider token.
    ///
    /// The comparison is exact and case-sensitive against lowercase tokens.
    /// `Registry::capabilities` lowercases before calling this.
    /// `"custom"` and unknown tokens yield [`Capabilities::conservative`].
    pub fn for_provider(token: &str) -> Self {
        // Row shared by OpenAI and the OpenAI-compatible hosts; other rows
        // are expressed as deltas from it.
        let everything = Self {
            supports_citations: true,
            supports_seed: true,
            supports_temperature_zero: true,
            supports_streaming: true,
        };
        match token {
            "openai" | "groq" | "together" | "openrouter" | "venice" | "deepseek" => everything,
            "anthropic" => Self {
                supports_seed: false,
                ..everything
            },
            "ollama" => Self {
                supports_citations: false,
                ..everything
            },
            "huggingface" => Self {
                supports_citations: false,
                supports_seed: false,
                supports_temperature_zero: true,
                supports_streaming: false,
            },
            "local" => Self {
                supports_citations: false,
                supports_seed: false,
                supports_temperature_zero: false,
                supports_streaming: false,
            },
            // "custom" is deliberately routed here together with unknown tokens.
            _ => Self::conservative(),
        }
    }
}
137
/// A kind of work a provider can be asked to perform.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Modality {
    /// Embedding requests (token `embed`).
    Embed,
    /// Text generation requests (token `generate`).
    Generate,
    /// Image-input requests (token `vision`).
    Vision,
    /// Moderation requests (token `moderate`).
    Moderate,
}

impl Modality {
    /// Every modality, in declaration order.
    pub const ALL: [Self; 4] = [Self::Embed, Self::Generate, Self::Vision, Self::Moderate];

    /// Canonical lowercase token for this modality; [`Modality::parse`] maps it back.
    pub fn token(self) -> &'static str {
        match self {
            Self::Moderate => "moderate",
            Self::Vision => "vision",
            Self::Generate => "generate",
            Self::Embed => "embed",
        }
    }

    /// Parses a modality token or one of its aliases.
    ///
    /// Surrounding whitespace is ignored and matching is ASCII case-insensitive.
    /// Accepted aliases are `embedding`/`embeddings`, `generation`/`chat`/`completion`,
    /// `image`/`multimodal` and `moderation`. Returns `None` for anything else.
    pub fn parse(token: &str) -> Option<Self> {
        let normalized = token.trim().to_ascii_lowercase();
        let modality = match normalized.as_str() {
            "embed" | "embedding" | "embeddings" => Self::Embed,
            "generate" | "generation" | "chat" | "completion" => Self::Generate,
            "vision" | "image" | "multimodal" => Self::Vision,
            "moderate" | "moderation" => Self::Moderate,
            _ => return None,
        };
        Some(modality)
    }
}
185
186#[derive(Debug, Clone, Copy, PartialEq, Eq)]
189pub struct Modalities {
190 pub embed: bool,
191 pub generate: bool,
192 pub vision: bool,
193 pub moderate: bool,
194}
195
196impl Modalities {
197 pub const fn conservative() -> Self {
208 Self {
209 embed: true,
210 generate: true,
211 vision: false,
212 moderate: false,
213 }
214 }
215
216 pub fn supports(&self, modality: Modality) -> bool {
218 match modality {
219 Modality::Embed => self.embed,
220 Modality::Generate => self.generate,
221 Modality::Vision => self.vision,
222 Modality::Moderate => self.moderate,
223 }
224 }
225
226 pub fn for_provider(token: &str) -> Self {
249 match token {
250 "openai" => Self {
251 embed: true,
252 generate: true,
253 vision: true,
254 moderate: true,
255 },
256 "anthropic" => Self {
257 embed: false,
258 generate: true,
259 vision: true,
260 moderate: false,
261 },
262 "minimax" | "together" | "ollama" => Self {
263 embed: true,
264 generate: true,
265 vision: true,
266 moderate: false,
267 },
268 "groq" | "openrouter" | "venice" => Self {
269 embed: false,
270 generate: true,
271 vision: true,
272 moderate: false,
273 },
274 "deepseek" => Self {
275 embed: false,
276 generate: true,
277 vision: false,
278 moderate: false,
279 },
280 "huggingface" => Self {
281 embed: true,
282 generate: true,
283 vision: false,
284 moderate: false,
285 },
286 "local" => Self {
287 embed: true,
288 generate: false,
289 vision: true,
299 moderate: true,
300 },
301 "custom" => Self::conservative(),
302 _ => Self::conservative(),
303 }
304 }
305}
306
307#[derive(Debug, Clone, PartialEq, Eq)]
310pub struct ModalityValidationError {
311 pub provider: String,
313 pub model: String,
317 pub modality: Modality,
319 pub message: String,
321}
322
323impl std::fmt::Display for ModalityValidationError {
324 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325 f.write_str(&self.message)
326 }
327}
328
329impl std::error::Error for ModalityValidationError {}
330
331#[derive(Debug, Clone, PartialEq, Eq)]
334pub struct ModeWarning {
335 pub kind: ModeWarningKind,
337 pub detail: String,
339}
340
341#[derive(Debug, Clone, Copy, PartialEq, Eq)]
342pub enum ModeWarningKind {
343 ModeFallback,
346}
347
348#[derive(Debug, Clone, PartialEq, Eq)]
350pub enum ModeOutcome {
351 Allowed { effective: Mode },
353 Fallback {
357 effective: Mode,
358 warning: ModeWarning,
359 },
360}
361
362impl ModeOutcome {
363 pub fn effective(&self) -> Mode {
365 match self {
366 Self::Allowed { effective } | Self::Fallback { effective, .. } => *effective,
367 }
368 }
369
370 pub fn warning(&self) -> Option<&ModeWarning> {
372 match self {
373 Self::Allowed { .. } => None,
374 Self::Fallback { warning, .. } => Some(warning),
375 }
376 }
377}
378
379#[derive(Debug, Clone, Default)]
386pub struct Registry {
387 overrides: HashMap<String, Capabilities>,
388 modality_overrides: HashMap<String, Modalities>,
389}
390
391impl Registry {
392 pub fn new() -> Self {
396 Self {
397 overrides: HashMap::new(),
398 modality_overrides: HashMap::new(),
399 }
400 }
401
402 pub fn with_override(mut self, token: &str, caps: Capabilities) -> Self {
405 self.overrides.insert(token.to_ascii_lowercase(), caps);
406 self
407 }
408
409 pub fn capabilities(&self, token: &str) -> Capabilities {
412 let key = token.to_ascii_lowercase();
413 if let Some(c) = self.overrides.get(&key) {
414 return *c;
415 }
416 Capabilities::for_provider(&key)
417 }
418
419 pub fn with_modality_override(mut self, token: &str, modalities: Modalities) -> Self {
423 self.modality_overrides
424 .insert(token.to_ascii_lowercase(), modalities);
425 self
426 }
427
428 pub fn modalities(&self, token: &str) -> Modalities {
431 let key = token.to_ascii_lowercase();
432 if let Some(m) = self.modality_overrides.get(&key) {
433 return *m;
434 }
435 Modalities::for_provider(&key)
436 }
437
438 pub fn can_serve(&self, token: &str, _model: &str, modality: Modality) -> bool {
443 self.modalities(token).supports(modality)
444 }
445
446 pub fn validate_policy_modality(
455 &self,
456 provider: &str,
457 model: &str,
458 modality: Modality,
459 ) -> Result<(), ModalityValidationError> {
460 if self.can_serve(provider, model, modality) {
461 return Ok(());
462 }
463 Err(ModalityValidationError {
464 provider: provider.to_string(),
465 model: model.to_string(),
466 modality,
467 message: format!(
468 "AI policy is invalid: provider '{}' (model '{}') cannot serve the '{}' modality; \
469 declare a provider that supports it or register a modality override",
470 provider.to_ascii_lowercase(),
471 model,
472 modality.token()
473 ),
474 })
475 }
476
477 pub fn evaluate_mode(&self, token: &str, requested: Mode) -> ModeOutcome {
484 if requested == Mode::Lenient {
485 return ModeOutcome::Allowed {
486 effective: Mode::Lenient,
487 };
488 }
489 let caps = self.capabilities(token);
490 if caps.supports_citations {
491 return ModeOutcome::Allowed {
492 effective: Mode::Strict,
493 };
494 }
495 ModeOutcome::Fallback {
496 effective: Mode::Lenient,
497 warning: ModeWarning {
498 kind: ModeWarningKind::ModeFallback,
499 detail: format!(
500 "provider '{}' does not support reliable citation emission; \
501 strict mode downgraded to lenient",
502 token.to_ascii_lowercase()
503 ),
504 },
505 }
506 }
507}
508
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a capability row; arguments follow the struct's field order.
    fn caps(citations: bool, seed: bool, temperature_zero: bool, streaming: bool) -> Capabilities {
        Capabilities {
            supports_citations: citations,
            supports_seed: seed,
            supports_temperature_zero: temperature_zero,
            supports_streaming: streaming,
        }
    }

    /// Builds a modality row; arguments follow the struct's field order.
    fn mods(embed: bool, generate: bool, vision: bool, moderate: bool) -> Modalities {
        Modalities {
            embed,
            generate,
            vision,
            moderate,
        }
    }

    /// Support flags of `token`'s built-in modality row, in `Modality::ALL` order
    /// (Embed, Generate, Vision, Moderate).
    fn served(token: &str) -> Vec<bool> {
        let row = Modalities::for_provider(token);
        Modality::ALL.iter().map(|&m| row.supports(m)).collect()
    }

    // ---- capability table ----

    #[test]
    fn conservative_defaults_match_ac() {
        assert_eq!(Capabilities::conservative(), caps(false, false, true, false));
    }

    #[test]
    fn openai_supports_everything() {
        assert_eq!(Capabilities::for_provider("openai"), caps(true, true, true, true));
    }

    #[test]
    fn anthropic_no_seed() {
        assert_eq!(Capabilities::for_provider("anthropic"), caps(true, false, true, true));
    }

    #[test]
    fn openai_compatible_family_uniform() {
        let expected = caps(true, true, true, true);
        ["groq", "together", "openrouter", "venice", "deepseek"]
            .iter()
            .for_each(|token| {
                assert_eq!(Capabilities::for_provider(token), expected, "{token}");
            });
    }

    #[test]
    fn ollama_no_citations_but_seed_and_streaming() {
        assert_eq!(Capabilities::for_provider("ollama"), caps(false, true, true, true));
    }

    #[test]
    fn huggingface_inference_no_seed_no_streaming() {
        assert_eq!(
            Capabilities::for_provider("huggingface"),
            caps(false, false, true, false)
        );
    }

    #[test]
    fn local_backend_has_no_temperature() {
        assert_eq!(Capabilities::for_provider("local"), caps(false, false, false, false));
    }

    #[test]
    fn custom_is_conservative() {
        assert_eq!(
            Capabilities::for_provider("custom"),
            Capabilities::conservative()
        );
    }

    #[test]
    fn unknown_token_is_conservative() {
        assert_eq!(
            Capabilities::for_provider("totally-made-up"),
            Capabilities::conservative()
        );
    }

    #[test]
    fn token_lookup_is_case_insensitive_via_registry() {
        let registry = Registry::new();
        let canonical = Capabilities::for_provider("openai");
        for spelling in ["OPENAI", "OpenAi"] {
            assert_eq!(registry.capabilities(spelling), canonical, "{spelling}");
        }
    }

    #[test]
    fn override_completely_replaces_builtin_row() {
        let nothing = caps(false, false, false, false);
        let registry = Registry::new().with_override("openai", nothing);
        assert_eq!(registry.capabilities("openai"), nothing);
        // Tokens without an override keep their built-in row.
        assert_eq!(registry.capabilities("groq"), Capabilities::for_provider("groq"));
    }

    #[test]
    fn override_key_is_lowercased() {
        let everything = caps(true, true, true, true);
        let registry = Registry::new().with_override("CUSTOM-INTERNAL", everything);
        for spelling in ["custom-internal", "Custom-Internal"] {
            assert_eq!(registry.capabilities(spelling), everything, "{spelling}");
        }
    }

    // ---- strict / lenient negotiation ----

    #[test]
    fn lenient_always_allowed_regardless_of_provider() {
        let registry = Registry::new();
        let expected = ModeOutcome::Allowed {
            effective: Mode::Lenient,
        };
        for token in ["openai", "huggingface", "local", "totally-made-up"] {
            let outcome = registry.evaluate_mode(token, Mode::Lenient);
            assert_eq!(outcome, expected, "lenient should pass through for {token}");
            assert!(outcome.warning().is_none());
        }
    }

    #[test]
    fn strict_allowed_for_citing_provider() {
        let outcome = Registry::new().evaluate_mode("openai", Mode::Strict);
        assert_eq!(
            outcome,
            ModeOutcome::Allowed {
                effective: Mode::Strict
            }
        );
        assert!(outcome.warning().is_none());
    }

    #[test]
    fn strict_downgraded_for_non_citing_provider() {
        let outcome = Registry::new().evaluate_mode("huggingface", Mode::Strict);
        if let ModeOutcome::Fallback { effective, warning } = &outcome {
            assert_eq!(*effective, Mode::Lenient);
            assert_eq!(warning.kind, ModeWarningKind::ModeFallback);
            assert!(warning.detail.contains("huggingface"));
            assert!(warning.detail.contains("strict"));
        } else {
            panic!("expected Fallback, got {outcome:?}");
        }
        assert_eq!(outcome.effective(), Mode::Lenient);
        assert!(outcome.warning().is_some());
    }

    #[test]
    fn strict_downgraded_for_unknown_provider() {
        let outcome = Registry::new().evaluate_mode("brand-new-provider", Mode::Strict);
        assert_eq!(outcome.effective(), Mode::Lenient);
        if let ModeOutcome::Fallback { warning, .. } = &outcome {
            assert_eq!(warning.kind, ModeWarningKind::ModeFallback);
            assert!(warning.detail.contains("brand-new-provider"));
        } else {
            panic!("expected Fallback, got {outcome:?}");
        }
    }

    #[test]
    fn override_can_upgrade_non_citing_provider_to_citing() {
        let registry = Registry::new().with_override("ollama", caps(true, true, true, true));
        assert_eq!(
            registry.evaluate_mode("ollama", Mode::Strict),
            ModeOutcome::Allowed {
                effective: Mode::Strict
            }
        );
    }

    #[test]
    fn override_can_downgrade_citing_provider_to_non_citing() {
        let registry = Registry::new().with_override("openai", caps(false, false, true, false));
        let outcome = registry.evaluate_mode("openai", Mode::Strict);
        if let ModeOutcome::Fallback { effective, warning } = &outcome {
            assert_eq!(*effective, Mode::Lenient);
            assert_eq!(warning.kind, ModeWarningKind::ModeFallback);
            assert!(warning.detail.contains("openai"));
        } else {
            panic!("expected Fallback, got {outcome:?}");
        }
    }

    #[test]
    fn evaluate_mode_is_deterministic() {
        let registry = Registry::new();
        let strict_ok = ModeOutcome::Allowed {
            effective: Mode::Strict,
        };
        for _round in 0..16 {
            assert_eq!(registry.evaluate_mode("openai", Mode::Strict), strict_ok);
            assert_eq!(
                registry.evaluate_mode("huggingface", Mode::Strict).effective(),
                Mode::Lenient
            );
        }
    }

    #[test]
    fn all_eleven_provider_tokens_have_explicit_rows() {
        let citing = [
            "openai",
            "anthropic",
            "groq",
            "together",
            "openrouter",
            "venice",
            "deepseek",
        ];
        let non_citing = ["ollama", "huggingface", "local"];
        for token in citing {
            assert!(
                Capabilities::for_provider(token).supports_citations,
                "{token} should cite"
            );
        }
        for token in non_citing {
            assert!(
                !Capabilities::for_provider(token).supports_citations,
                "{token} should not cite"
            );
        }
        // The eleventh token, "custom", maps onto the conservative row.
        assert_eq!(
            Capabilities::for_provider("custom"),
            Capabilities::conservative()
        );
    }

    // ---- modalities ----

    #[test]
    fn modality_token_roundtrips_through_parse() {
        for m in Modality::ALL {
            assert_eq!(Modality::parse(m.token()), Some(m), "{m:?}");
        }
        let aliases = [
            ("EMBEDDING", Some(Modality::Embed)),
            ("chat", Some(Modality::Generate)),
            ("image", Some(Modality::Vision)),
            ("moderation", Some(Modality::Moderate)),
            ("nonsense", None),
        ];
        for (raw, expected) in aliases {
            assert_eq!(Modality::parse(raw), expected, "{raw}");
        }
    }

    #[test]
    fn unknown_provider_gets_conservative_modalities() {
        let row = Modalities::for_provider("totally-made-up");
        assert_eq!(row, Modalities::conservative());
        assert_eq!(row, mods(true, true, false, false));
        assert_eq!(
            Modalities::for_provider("custom"),
            Modalities::conservative()
        );
    }

    #[test]
    fn openai_serves_every_modality() {
        let row = Modalities::for_provider("openai");
        for m in Modality::ALL {
            assert!(row.supports(m), "openai should serve {m:?}");
        }
    }

    #[test]
    fn minimax_serves_embed_generate_vision_not_moderate() {
        assert_eq!(served("minimax"), [true, true, true, false]);
    }

    #[test]
    fn anthropic_cannot_embed() {
        let row = Modalities::for_provider("anthropic");
        assert!(!row.supports(Modality::Embed));
        assert!(row.supports(Modality::Generate));
    }

    #[test]
    fn local_serves_embed_vision_and_moderate() {
        assert_eq!(served("local"), [true, false, true, true]);
    }

    #[test]
    fn deepseek_is_generate_only() {
        assert_eq!(served("deepseek"), [false, true, false, false]);
    }

    #[test]
    fn can_serve_is_case_insensitive_and_deterministic() {
        let registry = Registry::new();
        for _round in 0..8 {
            assert!(registry.can_serve("OpenAI", "gpt-4o", Modality::Vision));
            assert!(!registry.can_serve("LOCAL", "all-MiniLM", Modality::Generate));
        }
    }

    #[test]
    fn validate_rejects_incapable_provider_modality() {
        let err = Registry::new()
            .validate_policy_modality("local", "all-MiniLM-L6-v2", Modality::Generate)
            .expect_err("local cannot generate");
        assert_eq!(err.provider, "local");
        assert_eq!(err.modality, Modality::Generate);
        let rendered = err.to_string();
        for needle in ["local", "generate", "all-MiniLM-L6-v2"] {
            assert!(rendered.contains(needle), "{rendered}");
        }
    }

    #[test]
    fn validate_accepts_capable_provider_modality() {
        let registry = Registry::new();
        assert!(registry
            .validate_policy_modality("openai", "text-embedding-3-small", Modality::Embed)
            .is_ok());
        assert!(registry
            .validate_policy_modality("minimax", "abab6.5s-chat", Modality::Vision)
            .is_ok());
    }

    #[test]
    fn modality_override_completely_replaces_builtin_row() {
        let upgraded = mods(true, true, false, false);
        let registry = Registry::new().with_modality_override("deepseek", upgraded);
        assert_eq!(registry.modalities("deepseek"), upgraded);
        assert!(registry
            .validate_policy_modality("deepseek", "deepseek-embed", Modality::Embed)
            .is_ok());
        // Tokens without an override keep their built-in row.
        assert_eq!(registry.modalities("openai"), Modalities::for_provider("openai"));
    }

    #[test]
    fn modality_override_can_revoke_a_builtin_capability() {
        let restricted = mods(true, true, false, false);
        let registry = Registry::new().with_modality_override("OpenAI", restricted);
        assert_eq!(registry.modalities("openai"), restricted);
        let err = registry
            .validate_policy_modality("openai", "gpt-4o", Modality::Vision)
            .expect_err("vision revoked by override");
        assert_eq!(err.modality, Modality::Vision);
    }
}