Skip to main content

tt_shared/
capability_check.rs

1//! Capability and context-window guard for the routing / failover path.
2//!
3//! [`RequiredCapabilities`] is derived from a [`ChatCompletionRequest`] and
4//! checked against a candidate model's [`ModelInfo`] before a route rewrite or
5//! failover dispatch is committed.  The check is intentionally permissive:
6//!
7//! - When `ModelInfo` is **unknown** for a candidate (not in the registry
8//!   catalog) we allow it through — we only skip when we *positively know* a
9//!   capability is missing.
10//! - A capability that the request needs but the model info does **not** list
11//!   causes the candidate to be skipped (the caller emits a tracing event and
12//!   tries the next candidate or falls back to the original model).
13//!
//! # Token counting
//!
//! [`message_text_for_estimation`] concatenates the request's message text and
//! returns it; the caller feeds that string to its own tokenizer (e.g.
//! `tt_tokenize::estimate_tokens`, which is keyed on `provider_id` so tiktoken
//! is used for OpenAI/Anthropic and the char/4 heuristic elsewhere). Returning
//! text instead of a count keeps `tt-shared` free of a `tt-tokenize`
//! dependency. Image/audio bytes are not measured — the guard is a best-effort
//! floor, not an exact window-packing count.
21
22use crate::{
23    messages::{ContentPart, Message, MessageContent},
24    pricing::{Capability, ModelInfo},
25    ChatCompletionRequest,
26};
27
/// Capabilities a [`ChatCompletionRequest`] needs from whichever model serves it.
///
/// Every flag defaults to `false`, meaning "not required"; a routing candidate
/// is only rejected for a flag that is `true` here.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct RequiredCapabilities {
    /// A message in the request carries an `image_url` or `input_audio`
    /// content part.
    pub vision: bool,
    /// The request declares non-empty `tools`, or the conversation history
    /// already contains assistant `tool_calls`.
    pub tools: bool,
    /// `response_format.type` is either `"json_object"` or `"json_schema"`.
    pub json_mode: bool,
}
39
40impl RequiredCapabilities {
41    /// Derive the required capabilities from a chat completion request.
42    pub fn from_request(req: &ChatCompletionRequest) -> Self {
43        let mut caps = Self::default();
44
45        // tools / function-calling
46        if !req.tools.is_empty() {
47            caps.tools = true;
48        }
49
50        // response_format → json mode
51        if let Some(rf) = &req.response_format {
52            if rf.r#type == "json_object" || rf.r#type == "json_schema" {
53                caps.json_mode = true;
54            }
55        }
56
57        // scan messages for vision content and tool_calls
58        for msg in &req.messages {
59            match msg {
60                Message::User { content, .. } | Message::System { content } => {
61                    if let MessageContent::Parts(parts) = content {
62                        for part in parts {
63                            match part {
64                                ContentPart::ImageUrl { .. } | ContentPart::InputAudio { .. } => {
65                                    caps.vision = true;
66                                }
67                                ContentPart::Text { .. } => {}
68                            }
69                        }
70                    }
71                }
72                Message::Assistant { tool_calls, .. } => {
73                    if !tool_calls.is_empty() {
74                        caps.tools = true;
75                    }
76                }
77                Message::Tool { .. } => {
78                    // A Tool message in context means the conversation already
79                    // used tool-calling; the next turn may need it too.
80                    caps.tools = true;
81                }
82            }
83        }
84
85        caps
86    }
87
88    /// Returns `true` when all required capabilities are listed in
89    /// `info.capabilities` **and** `max_input_tokens >= estimated_tokens`.
90    ///
91    /// Pass `estimated_tokens = 0` to skip the context-window check.
92    #[must_use]
93    pub fn satisfied_by(&self, info: &ModelInfo, estimated_tokens: u64) -> bool {
94        if self.vision && !info.capabilities.contains(&Capability::Vision) {
95            return false;
96        }
97        if self.tools && !info.capabilities.contains(&Capability::Tools) {
98            return false;
99        }
100        if self.json_mode && !info.capabilities.contains(&Capability::JsonMode) {
101            return false;
102        }
103        if estimated_tokens > 0 && info.max_input_tokens < estimated_tokens {
104            return false;
105        }
106        true
107    }
108
109    /// Human-readable list of the reasons a candidate was skipped, for use in
110    /// the `route_skipped_capability` tracing event.
111    pub fn skip_reasons(&self, info: &ModelInfo, estimated_tokens: u64) -> Vec<&'static str> {
112        let mut reasons = Vec::new();
113        if self.vision && !info.capabilities.contains(&Capability::Vision) {
114            reasons.push("vision_not_supported");
115        }
116        if self.tools && !info.capabilities.contains(&Capability::Tools) {
117            reasons.push("tools_not_supported");
118        }
119        if self.json_mode && !info.capabilities.contains(&Capability::JsonMode) {
120            reasons.push("json_mode_not_supported");
121        }
122        if estimated_tokens > 0 && info.max_input_tokens < estimated_tokens {
123            reasons.push("context_window_too_small");
124        }
125        reasons
126    }
127}
128
129/// Concatenate all message text parts from a request for token estimation.
130///
131/// Image/audio bytes are excluded — the result is passed to the caller's
132/// tokenizer (e.g. `tt_tokenize::estimate_tokens`) so that `tt-shared` does
133/// not need to depend on `tt-tokenize`.
134pub fn message_text_for_estimation(req: &ChatCompletionRequest) -> String {
135    req.messages
136        .iter()
137        .map(|m| match m {
138            Message::User { content, .. } | Message::System { content } => extract_text(content),
139            Message::Assistant { content, .. } => {
140                content.as_ref().map(extract_text).unwrap_or_default()
141            }
142            Message::Tool { content, .. } => extract_text(content),
143        })
144        .collect()
145}
146
147fn extract_text(content: &MessageContent) -> String {
148    match content {
149        MessageContent::Text(s) => s.clone(),
150        MessageContent::Parts(parts) => parts
151            .iter()
152            .filter_map(|p| match p {
153                ContentPart::Text { text } => Some(text.as_str()),
154                _ => None,
155            })
156            .collect::<Vec<_>>()
157            .join(""),
158    }
159}
160
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::{
        messages::{ImageUrl, ResponseFormat, Tool, ToolCall, ToolCallFunction, ToolFunction},
        pricing::Capability,
        ModelInfo,
    };

    fn text_model() -> ModelInfo {
        ModelInfo {
            id: "text-only".into(),
            provider: "mock".into(),
            capabilities: vec![Capability::Text],
            max_input_tokens: 4096,
            max_output_tokens: 1024,
        }
    }

    fn vision_model() -> ModelInfo {
        ModelInfo {
            id: "vision-model".into(),
            provider: "mock".into(),
            capabilities: vec![Capability::Text, Capability::Vision, Capability::Tools],
            max_input_tokens: 128_000,
            max_output_tokens: 4096,
        }
    }

    fn small_model() -> ModelInfo {
        ModelInfo {
            id: "small-ctx".into(),
            provider: "mock".into(),
            capabilities: vec![Capability::Text],
            max_input_tokens: 100,
            max_output_tokens: 100,
        }
    }

    fn base_req() -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: "gpt-4o".into(),
            messages: vec![],
            temperature: None,
            top_p: None,
            max_tokens: None,
            stream: false,
            tools: vec![],
            tool_choice: None,
            response_format: None,
            stop: vec![],
            presence_penalty: None,
            frequency_penalty: None,
            n: None,
            seed: None,
            user: None,
            tt_extras: HashMap::new(),
        }
    }

    /// An inline PNG image content part.
    fn image_part() -> ContentPart {
        ContentPart::ImageUrl {
            image_url: ImageUrl {
                url: "data:image/png;base64,abc".into(),
                detail: None,
            },
        }
    }

    /// A user message consisting of a single image part.
    fn image_message() -> Message {
        Message::User {
            content: MessageContent::Parts(vec![image_part()]),
            name: None,
        }
    }

    #[test]
    fn plain_text_request_has_no_required_caps() {
        let req = base_req();
        let caps = RequiredCapabilities::from_request(&req);
        assert!(!caps.vision);
        assert!(!caps.tools);
        assert!(!caps.json_mode);
    }

    #[test]
    fn image_url_part_sets_vision() {
        let mut req = base_req();
        req.messages = vec![Message::User {
            content: MessageContent::Parts(vec![
                ContentPart::Text {
                    text: "describe this".into(),
                },
                image_part(),
            ]),
            name: None,
        }];
        let caps = RequiredCapabilities::from_request(&req);
        assert!(caps.vision);
        assert!(!caps.tools);
    }

    #[test]
    fn tools_field_sets_tools_cap() {
        let mut req = base_req();
        req.tools = vec![Tool {
            r#type: "function".into(),
            function: ToolFunction {
                name: "get_weather".into(),
                description: None,
                parameters: serde_json::json!({}),
            },
        }];
        let caps = RequiredCapabilities::from_request(&req);
        assert!(caps.tools);
    }

    #[test]
    fn assistant_tool_calls_in_history_sets_tools_cap() {
        let mut req = base_req();
        req.messages = vec![Message::Assistant {
            content: None,
            tool_calls: vec![ToolCall {
                id: "call_1".into(),
                r#type: "function".into(),
                function: ToolCallFunction {
                    name: "get_weather".into(),
                    arguments: "{}".into(),
                },
            }],
            name: None,
        }];
        let caps = RequiredCapabilities::from_request(&req);
        assert!(caps.tools);
    }

    #[test]
    fn json_object_response_format_sets_json_mode() {
        let mut req = base_req();
        req.response_format = Some(ResponseFormat {
            r#type: "json_object".into(),
            json_schema: None,
        });
        let caps = RequiredCapabilities::from_request(&req);
        assert!(caps.json_mode);
    }

    #[test]
    fn json_schema_response_format_sets_json_mode() {
        let mut req = base_req();
        req.response_format = Some(ResponseFormat {
            r#type: "json_schema".into(),
            json_schema: None,
        });
        assert!(RequiredCapabilities::from_request(&req).json_mode);
    }

    #[test]
    fn text_response_format_does_not_set_json_mode() {
        let mut req = base_req();
        req.response_format = Some(ResponseFormat {
            r#type: "text".into(),
            json_schema: None,
        });
        assert!(!RequiredCapabilities::from_request(&req).json_mode);
    }

    #[test]
    fn vision_request_not_satisfied_by_text_model() {
        let mut req = base_req();
        req.messages = vec![image_message()];
        let caps = RequiredCapabilities::from_request(&req);
        assert!(!caps.satisfied_by(&text_model(), 0));
    }

    #[test]
    fn vision_request_satisfied_by_vision_model() {
        let mut req = base_req();
        req.messages = vec![image_message()];
        let caps = RequiredCapabilities::from_request(&req);
        assert!(caps.satisfied_by(&vision_model(), 0));
    }

    #[test]
    fn json_mode_request_not_satisfied_without_json_capability() {
        let caps = RequiredCapabilities {
            json_mode: true,
            ..Default::default()
        };
        assert!(!caps.satisfied_by(&vision_model(), 0));
        assert_eq!(
            caps.skip_reasons(&vision_model(), 0),
            vec!["json_mode_not_supported"]
        );
    }

    #[test]
    fn exceeds_context_window_not_satisfied() {
        let caps = RequiredCapabilities::default();
        assert!(!caps.satisfied_by(&small_model(), 200));
    }

    #[test]
    fn within_context_window_satisfied() {
        let caps = RequiredCapabilities::default();
        assert!(caps.satisfied_by(&small_model(), 50));
    }

    #[test]
    fn zero_estimated_tokens_skips_window_check() {
        let caps = RequiredCapabilities::default();
        assert!(caps.satisfied_by(&small_model(), 0));
    }

    #[test]
    fn skip_reasons_lists_all_failures() {
        let caps = RequiredCapabilities {
            vision: true,
            tools: true,
            ..Default::default()
        };
        let reasons = caps.skip_reasons(&text_model(), 9999);
        assert!(reasons.contains(&"vision_not_supported"));
        assert!(reasons.contains(&"tools_not_supported"));
        assert!(reasons.contains(&"context_window_too_small"));
    }

    #[test]
    fn satisfied_model_has_no_skip_reasons() {
        let caps = RequiredCapabilities {
            vision: true,
            tools: true,
            ..Default::default()
        };
        assert!(caps.skip_reasons(&vision_model(), 1000).is_empty());
    }

    #[test]
    fn satisfied_by_agrees_with_skip_reasons() {
        let models = [text_model(), vision_model(), small_model()];
        for &vision in &[false, true] {
            for &tools in &[false, true] {
                for &json_mode in &[false, true] {
                    let caps = RequiredCapabilities {
                        vision,
                        tools,
                        json_mode,
                    };
                    for model in &models {
                        for &tokens in &[0u64, 50, 200, 200_000] {
                            assert_eq!(
                                caps.satisfied_by(model, tokens),
                                caps.skip_reasons(model, tokens).is_empty(),
                                "caps={caps:?} model={} tokens={tokens}",
                                model.id,
                            );
                        }
                    }
                }
            }
        }
    }

    #[test]
    fn message_text_concatenates_text_parts_only() {
        let mut req = base_req();
        req.messages = vec![
            Message::System {
                content: MessageContent::Text("be brief. ".into()),
            },
            Message::User {
                content: MessageContent::Parts(vec![
                    ContentPart::Text {
                        text: "describe ".into(),
                    },
                    image_part(),
                    ContentPart::Text {
                        text: "this".into(),
                    },
                ]),
                name: None,
            },
        ];
        assert_eq!(message_text_for_estimation(&req), "be brief. describe this");
    }
}