1use tt_shared::pricing::{Capability, ModelInfo};
7
8use crate::{RouteAction, RouteConditions};
9
10#[derive(Debug, thiserror::Error, PartialEq, Eq)]
11pub enum ValidationError {
12 #[error("target_model `{target}` is missing the `{capability}` capability required by this route's content-type condition")]
13 MissingCapability {
14 target: String,
15 capability: &'static str,
16 },
17}
18
19pub fn validate_capability(
23 when: &RouteConditions,
24 then: &RouteAction,
25 lookup: impl Fn(&str) -> Option<ModelInfo>,
26) -> Result<(), ValidationError> {
27 let needs_vision = when.has_images == Some(true) || when.has_audio == Some(true);
28 if !needs_vision {
29 return Ok(());
30 }
31 if let Some(info) = lookup(&then.target_model) {
32 if !info.capabilities.contains(&Capability::Vision) {
33 return Err(ValidationError::MissingCapability {
34 target: then.target_model.clone(),
35 capability: "vision",
36 });
37 }
38 }
39 Ok(())
40}
41
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{RouteAction, RouteConditions};
    use tt_shared::pricing::{Capability, ModelInfo};

    /// Builds a route action targeting `target` with no fallbacks, caching enabled and no
    /// cost cap.
    fn action(target: &str) -> RouteAction {
        RouteAction {
            target_model: target.into(),
            fallbacks: vec![],
            disable_cache: false,
            max_cost_usd: None,
        }
    }

    /// Builds a test model with the given capabilities and fixed token limits.
    fn model_with(id: &str, capabilities: Vec<Capability>) -> ModelInfo {
        ModelInfo {
            id: id.into(),
            provider: "p".into(),
            capabilities,
            max_input_tokens: 1000,
            max_output_tokens: 1000,
        }
    }

    fn vision_model(id: &str) -> ModelInfo {
        model_with(id, vec![Capability::Text, Capability::Vision])
    }

    fn text_model(id: &str) -> ModelInfo {
        model_with(id, vec![Capability::Text])
    }

    /// Catalog stub: `vis` supports vision, `txt` is text-only, anything else is unknown.
    fn catalog(m: &str) -> Option<ModelInfo> {
        match m {
            "vis" => Some(vision_model("vis")),
            "txt" => Some(text_model("txt")),
            _ => None,
        }
    }

    #[test]
    fn has_images_requires_vision_target() {
        let when = RouteConditions {
            has_images: Some(true),
            ..Default::default()
        };
        assert!(validate_capability(&when, &action("vis"), catalog).is_ok());
        assert!(validate_capability(&when, &action("txt"), catalog).is_err());
        // Targets the catalog does not know about are accepted, not rejected.
        assert!(validate_capability(&when, &action("unknown"), catalog).is_ok());
    }

    #[test]
    fn has_audio_is_gated_on_vision_too() {
        // Pins current behavior: audio conditions require `Capability::Vision`, like images.
        let when = RouteConditions {
            has_audio: Some(true),
            ..Default::default()
        };
        assert!(validate_capability(&when, &action("vis"), catalog).is_ok());
        assert!(validate_capability(&when, &action("txt"), catalog).is_err());
        assert!(validate_capability(&when, &action("unknown"), catalog).is_ok());
    }

    #[test]
    fn missing_capability_error_names_target_and_capability() {
        let when = RouteConditions {
            has_images: Some(true),
            ..Default::default()
        };
        assert_eq!(
            validate_capability(&when, &action("txt"), catalog),
            Err(ValidationError::MissingCapability {
                target: "txt".to_string(),
                capability: "vision",
            }),
        );
    }

    #[test]
    fn explicit_false_modality_skips_capability_check() {
        let when = RouteConditions {
            has_images: Some(false),
            has_audio: Some(false),
            ..Default::default()
        };
        assert!(validate_capability(&when, &action("txt"), catalog).is_ok());
    }

    #[test]
    fn no_modality_condition_skips_capability_check() {
        let when = RouteConditions::default();
        let lookup = |_: &str| -> Option<ModelInfo> { None };
        assert!(validate_capability(&when, &action("anything"), lookup).is_ok());
    }
}