Skip to main content

venice_e2ee_proxy/
tools.rs

1//! OpenAI-style tool-call emulation.
2//!
3//! Venice E2EE responses do not expose native function calls, so this module
4//! parses tool calls client-side from decrypted assistant text using a vendored
5//! subset of vLLM's Rust tool parser, validates the calls against the request's OpenAI
6//! `tools`, and builds prompt text for the encrypted controller/correction
7//! requests.
8//!
9//! Venice E2EE cannot render tools into the server-side chat template (tools
10//! arrive encrypted), so the proxy prompt must instruct a model-visible output
11//! shape. Use model-family-specific shapes where the bundled vLLM Rust parser
12//! has a matching format, and fall back to Hermes JSON otherwise.
13
14use std::{collections::HashSet, time::Duration};
15
16use crate::{
17    config::{ToolMode, ToolsConfig},
18    openai::chat::{
19        ChatCompletionRequest, ChatRequestError, ChatToolChoice, ChatToolDefinition,
20        NormalizedChatMessage,
21    },
22    vllm_tool_parser::{
23        Glm47MoeToolParser, HermesToolParser, Qwen3XmlToolParser, Result as ToolParserResult, Tool,
24        ToolCallDelta, ToolParseResult, ToolParser,
25    },
26};
27use serde_json::{Map, Value};
28use thiserror::Error;
29use tracing::warn;
30
/// Opening tool-call marker used by the Hermes/Qwen/GLM prompt instructions.
const TOOL_CALL_START: &str = "<tool_call>";
/// Closing tool-call marker used by the Hermes/Qwen/GLM prompt instructions.
const TOOL_CALL_END: &str = "</tool_call>";
34
35/// Generates an OpenAI-style tool-call ID.
36pub fn generate_tool_call_id() -> String {
37    format!("call_{}", uuid::Uuid::new_v4().simple())
38}
39
/// Maximum bytes of invalid assistant output echoed back in a correction
/// prompt; oversized output would otherwise grow each encrypted retry request
/// by the full output size.
///
/// The limit applies to the echoed output itself; the short marker appended
/// when output is cut is added on top of it.
const CORRECTION_INVALID_OUTPUT_MAX_BYTES: usize = 4_096;
44
/// Per-request tool-emulation state derived from config, tools, and tool choice.
#[derive(Debug, Clone)]
pub struct ToolEmulationContext {
    /// Copy of the tool-emulation settings taken when the context is built.
    config: ToolsConfig,
    /// Tools the model may call: every request tool, or only the single tool
    /// named by a function-specific `tool_choice`.
    tools: Vec<ChatToolDefinition>,
    /// `tools` serialized to JSON once, for embedding in controller and
    /// correction prompts.
    tool_schemas_json: String,
    /// Whether the assistant must emit a tool call (`tool_choice` of
    /// required or a specific function).
    require_tool_call: bool,
    /// Prompt/parser format chosen from the request's model id.
    prompt_format: ToolPromptFormat,
}
54
/// Tool-call output shape the model is instructed to emit; the same value
/// selects the parser used to read that output back.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ToolPromptFormat {
    /// One JSON object with `name` and `arguments` between the tool-call markers.
    HermesJson,
    /// Tool name followed by `<arg_key>`/`<arg_value>` pairs between the markers.
    GlmXml,
    /// Format used for model ids that mention `qwen`.
    QwenXml,
}

impl ToolPromptFormat {
    /// Picks the prompt/parser format for a model id.
    ///
    /// Matching is a case-insensitive (ASCII) substring test. Families are
    /// checked in table order, so an id mentioning both `glm` and `qwen`
    /// resolves to GLM.
    fn for_model(model: &str) -> Self {
        const FAMILY_FORMATS: [(&str, ToolPromptFormat); 2] = [
            ("glm", ToolPromptFormat::GlmXml),
            ("qwen", ToolPromptFormat::QwenXml),
        ];

        let lowered = model.to_ascii_lowercase();
        // Gemma, GPT-OSS, Venice uncensored, and unknown E2EE models have
        // live-tested successfully with the prompt-instructed Hermes JSON
        // format. Their native parser formats either need tokenizer special
        // tokens or are not exposed in the Rust parser crate.
        FAMILY_FORMATS
            .iter()
            .find(|&&(family, _)| lowered.contains(family))
            .map_or(Self::HermesJson, |&(_, format)| format)
    }
}
80
81impl ToolEmulationContext {
82    /// Builds tool-emulation context for a request, or returns `None` when tools are disabled or unused.
83    pub fn from_request(
84        config: &ToolsConfig,
85        request: &ChatCompletionRequest,
86    ) -> Result<Option<Self>, ChatRequestError> {
87        if !config.enabled || config.mode == ToolMode::None {
88            return Ok(None);
89        }
90
91        if matches!(request.tool_choice, ChatToolChoice::None) {
92            return Ok(None);
93        }
94
95        if request.tools.is_empty() {
96            if matches!(
97                request.tool_choice,
98                ChatToolChoice::Required | ChatToolChoice::Function { .. }
99            ) {
100                return Err(ChatRequestError::invalid_field(
101                    "tool_choice",
102                    "tool_choice requires at least one function tool",
103                ));
104            }
105            return Ok(None);
106        }
107
108        let mut seen_names = HashSet::new();
109        for tool in &request.tools {
110            if !seen_names.insert(tool.name()) {
111                return Err(ChatRequestError::invalid_field(
112                    "tools",
113                    format!("duplicate function tool name {:?}", tool.name()),
114                ));
115            }
116
117            if config.validate_json_schema
118                && let Some(schema) = tool.parameters_schema()
119            {
120                validate_schema_shape(schema).map_err(|message| {
121                    ChatRequestError::invalid_field(
122                        "tools",
123                        format!(
124                            "tool {:?} has an unsupported or invalid parameters schema: {message}",
125                            tool.name()
126                        ),
127                    )
128                })?;
129            }
130        }
131
132        let (tools, require_tool_call) = match &request.tool_choice {
133            ChatToolChoice::Auto => (request.tools.clone(), false),
134            ChatToolChoice::Required => (request.tools.clone(), true),
135            ChatToolChoice::Function { name } => {
136                let selected = request
137                    .tools
138                    .iter()
139                    .find(|tool| tool.name() == name)
140                    .cloned()
141                    .ok_or_else(|| {
142                        ChatRequestError::invalid_field(
143                            "tool_choice",
144                            format!("requested function tool {name:?} is not present in tools"),
145                        )
146                    })?;
147                (vec![selected], true)
148            }
149            ChatToolChoice::None => unreachable!("tool_choice none returned above"),
150        };
151
152        let tool_schemas_json = serde_json::to_string(&tools).map_err(|source| {
153            ChatRequestError::invalid_field(
154                "tools",
155                format!("tool schemas could not be serialized for the controller prompt: {source}"),
156            )
157        })?;
158
159        Ok(Some(Self {
160            config: config.clone(),
161            tools,
162            tool_schemas_json,
163            require_tool_call,
164            prompt_format: ToolPromptFormat::for_model(&request.model),
165        }))
166    }
167
168    /// Returns the tool-emulation configuration used by this context.
169    pub fn config(&self) -> &ToolsConfig {
170        &self.config
171    }
172
173    /// Returns the maximum number of correction retries allowed for invalid tool calls.
174    pub fn max_retries(&self) -> u32 {
175        self.config.max_retries
176    }
177
178    /// Returns the maximum time to wait for a non-streamed tool-call marker response.
179    pub fn marker_timeout(&self) -> Duration {
180        self.config.tool_call_marker_timeout
181    }
182
183    /// Creates the preferred tool parser for one assistant turn.
184    pub fn create_parser(&self) -> Result<Box<dyn ToolParser>, ToolCallValidationError> {
185        self.create_parser_for_format(self.prompt_format)
186    }
187
188    /// Creates a parser for a specific prompt format.
189    fn create_parser_for_format(
190        &self,
191        format: ToolPromptFormat,
192    ) -> Result<Box<dyn ToolParser>, ToolCallValidationError> {
193        let parser = match format {
194            ToolPromptFormat::HermesJson => LenientToolParser::create(&[]),
195            ToolPromptFormat::GlmXml => Glm47MoeToolParser::create(&self.vllm_tools()),
196            ToolPromptFormat::QwenXml => Qwen3XmlToolParser::create(&[]),
197        };
198        parser.map_err(|error| {
199            ToolCallValidationError::new(format!("tool parser could not be created: {error}"))
200        })
201    }
202
203    /// Converts OpenAI function tools into the vLLM parser tool representation.
204    fn vllm_tools(&self) -> Vec<Tool> {
205        self.tools
206            .iter()
207            .map(|tool| {
208                let function = tool.function();
209                Tool {
210                    name: function.name.clone(),
211                    description: function.description.clone(),
212                    parameters: function
213                        .parameters
214                        .as_ref()
215                        .map(|schema| Value::Object(schema.as_map().clone()))
216                        .unwrap_or_else(|| Value::Object(Map::new())),
217                    strict: None,
218                }
219            })
220            .collect()
221    }
222
223    /// Builds the system/controller prompt message that instructs the model to emit tool calls.
224    pub fn controller_message(&self) -> NormalizedChatMessage {
225        let requirement = if self.require_tool_call {
226            "You must call at least one tool. Do not answer the user directly. Output each tool call using this format and nothing else:"
227        } else {
228            "If tools are required, do not answer the user directly. Output each tool call using this format and nothing else:"
229        };
230        let optional_rule = if self.require_tool_call {
231            String::new()
232        } else {
233            format!("\n- If no tool is needed, answer normally and do not use {TOOL_CALL_START}.")
234        };
235
236        let content = match self.prompt_format {
237            ToolPromptFormat::HermesJson => {
238                self.hermes_controller_content(requirement, &optional_rule)
239            }
240            ToolPromptFormat::GlmXml => {
241                self.glm_xml_controller_content(requirement, &optional_rule)
242            }
243            ToolPromptFormat::QwenXml => {
244                self.qwen_xml_controller_content(requirement, &optional_rule)
245            }
246        };
247
248        // The HTTP layer appends this content to the request's system prompt
249        // for tool-emulated requests. Keep the role as a harmless container for
250        // callers/tests that inspect the standalone message.
251        NormalizedChatMessage::new("user", content)
252    }
253
254    /// Builds a correction prompt from the previous validation error and assistant output.
255    pub fn correction_message(
256        &self,
257        validation_error: &str,
258        invalid_output: &str,
259    ) -> NormalizedChatMessage {
260        let invalid_output =
261            truncate_at_char_boundary(invalid_output, CORRECTION_INVALID_OUTPUT_MAX_BYTES);
262        let content = match self.prompt_format {
263            ToolPromptFormat::HermesJson => {
264                self.hermes_correction_content(validation_error, &invalid_output)
265            }
266            ToolPromptFormat::GlmXml => {
267                self.glm_xml_correction_content(validation_error, &invalid_output)
268            }
269            ToolPromptFormat::QwenXml => {
270                self.qwen_xml_correction_content(validation_error, &invalid_output)
271            }
272        };
273        NormalizedChatMessage::new("system", content)
274    }
275
276    /// Builds Hermes-style JSON tool-call controller instructions.
277    fn hermes_controller_content(&self, requirement: &str, optional_rule: &str) -> String {
278        format!(
279            "You have access to tools.\n\n{requirement}\n\nRequired tool-call format:\n\n{TOOL_CALL_START}\n{}\n{TOOL_CALL_END}\n\nInside each {TOOL_CALL_START} block, output ONLY one valid JSON object with exactly these top-level keys:\n- \"name\": the tool name as a JSON string.\n- \"arguments\": a JSON object containing the tool arguments.\n\nValid single-call example:\n\n{TOOL_CALL_START}\n{}\n{TOOL_CALL_END}\n\nValid multi-call example:\n\n{TOOL_CALL_START}\n{}\n{TOOL_CALL_END}\n{TOOL_CALL_START}\n{}\n{TOOL_CALL_END}\n\nInvalid formats. NEVER use these:\n- {TOOL_CALL_START}TOOL_NAME({{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}){TOOL_CALL_END}\n- {TOOL_CALL_START}TOOL_NAME{{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}{TOOL_CALL_END}\n- TOOL_NAME({{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}})\n- {{\"tool\":\"TOOL_NAME\",\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}\n\nRules:\n- TOOL_NAME must exactly match one available tool name.\n- Always put the tool name in the JSON \"name\" field.\n- Always put tool arguments inside the JSON \"arguments\" object.\n- Do not put arguments directly after the tool name.\n- Do not use function-call syntax like TOOL_NAME(...).\n- arguments must be valid JSON and must satisfy the tool schema.\n- Emit one marker block per tool call.\n- Do not include markdown fences.\n- Do not include explanations.{optional_rule}\n\nAvailable tools:\n{}",
280            r#"{"name":"TOOL_NAME","arguments":{...}}"#,
281            r#"{"name":"TOOL_NAME","arguments":{"ARGUMENT_NAME":"ARGUMENT_VALUE"}}"#,
282            r#"{"name":"TOOL_NAME_1","arguments":{"ARGUMENT_NAME":"ARGUMENT_VALUE"}}"#,
283            r#"{"name":"TOOL_NAME_2","arguments":{"ARGUMENT_NAME":"ARGUMENT_VALUE"}}"#,
284            self.tool_schemas_json,
285        )
286    }
287
288    /// Builds Qwen XML-wrapped JSON tool-call controller instructions.
289    fn qwen_xml_controller_content(&self, requirement: &str, optional_rule: &str) -> String {
290        format!(
291            "You have access to tools.\n\n{requirement}\n\nRequired Qwen XML-wrapped JSON tool-call format:\n\n{TOOL_CALL_START}\n{}\n{TOOL_CALL_END}\n\nThere MUST be a newline immediately after {TOOL_CALL_START}. Inside each block, output ONLY one valid JSON object with exactly these top-level keys:\n- \"name\": the tool name as a JSON string.\n- \"arguments\": a JSON object containing the tool arguments.\n\nValid example:\n\n{TOOL_CALL_START}\n{}\n{TOOL_CALL_END}\n\nRules:\n- TOOL_NAME must exactly match one available tool name.\n- Always put the tool name in the JSON \"name\" field.\n- Always put tool arguments inside the JSON \"arguments\" object.\n- Do not use function-call syntax like TOOL_NAME(...).\n- arguments must be valid JSON and must satisfy the tool schema.\n- Emit one marker block per tool call.\n- Do not include markdown fences.\n- Do not include explanations.{optional_rule}\n\nAvailable tools:\n{}",
292            r#"{"name":"TOOL_NAME","arguments":{...}}"#,
293            r#"{"name":"TOOL_NAME","arguments":{"ARGUMENT_NAME":"ARGUMENT_VALUE"}}"#,
294            self.tool_schemas_json,
295        )
296    }
297
298    /// Builds GLM XML tool-call controller instructions.
299    fn glm_xml_controller_content(&self, requirement: &str, optional_rule: &str) -> String {
300        format!(
301            "You have access to tools.\n\n{requirement}\n\nRequired GLM XML tool-call format:\n\n{TOOL_CALL_START}TOOL_NAME\n<arg_key>ARGUMENT_NAME</arg_key>\n<arg_value>ARGUMENT_VALUE</arg_value>\n{TOOL_CALL_END}\n\nInside each {TOOL_CALL_START} block:\n- Start with the exact tool name as plain text.\n- Then output one <arg_key>/<arg_value> pair for each argument.\n- Put only the raw argument name inside <arg_key>.\n- Put only the raw argument value inside <arg_value>.\n- If an argument value is an object or array, put compact valid JSON inside <arg_value>.\n\nValid single-call example:\n\n{TOOL_CALL_START}TOOL_NAME\n<arg_key>ARGUMENT_NAME</arg_key>\n<arg_value>ARGUMENT_VALUE</arg_value>\n{TOOL_CALL_END}\n\nValid multi-call example:\n\n{TOOL_CALL_START}TOOL_NAME_1\n<arg_key>ARGUMENT_NAME</arg_key>\n<arg_value>ARGUMENT_VALUE</arg_value>\n{TOOL_CALL_END}\n{TOOL_CALL_START}TOOL_NAME_2\n<arg_key>ARGUMENT_NAME</arg_key>\n<arg_value>ARGUMENT_VALUE</arg_value>\n{TOOL_CALL_END}\n\nInvalid formats. NEVER use these:\n- {TOOL_CALL_START}TOOL_NAME({{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}){TOOL_CALL_END}\n- {TOOL_CALL_START}TOOL_NAME{{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}{TOOL_CALL_END}\n- {TOOL_CALL_START}{{\"name\":\"TOOL_NAME\",\"arguments\":{{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}}}{TOOL_CALL_END}\n- TOOL_NAME({{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}})\n\nRules:\n- TOOL_NAME must exactly match one available tool name.\n- Do not output JSON inside {TOOL_CALL_START} except for object/array values inside <arg_value>.\n- Do not use function-call syntax like TOOL_NAME(...).\n- Do not use the Hermes JSON format with \"name\" and \"arguments\" keys.\n- Argument names and values must satisfy the tool schema.\n- Emit one marker block per tool call.\n- Do not include markdown fences.\n- Do not include explanations.{optional_rule}\n\nAvailable tools:\n{}",
302            self.tool_schemas_json,
303        )
304    }
305
306    /// Builds Hermes-style correction instructions after invalid tool-call output.
307    fn hermes_correction_content(&self, validation_error: &str, invalid_output: &str) -> String {
308        format!(
309            "Your previous response attempted a tool call, but it was invalid.\n\nValidation error:\n{validation_error}\n\nInvalid output:\n{invalid_output}\n\nYou must now return only valid tool calls and nothing else.\n\nUse this exact format:\n\n{TOOL_CALL_START}\n{}\n{TOOL_CALL_END}\n\nInside each {TOOL_CALL_START} block, output ONLY one valid JSON object with exactly these top-level keys:\n- \"name\": the tool name as a JSON string.\n- \"arguments\": a JSON object containing the tool arguments.\n\nValid example:\n\n{TOOL_CALL_START}\n{}\n{TOOL_CALL_END}\n\nInvalid formats. NEVER use these:\n- {TOOL_CALL_START}TOOL_NAME({{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}){TOOL_CALL_END}\n- {TOOL_CALL_START}TOOL_NAME{{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}{TOOL_CALL_END}\n- TOOL_NAME({{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}})\n\nRules:\n- TOOL_NAME must exactly match one of the available tools.\n- Always put the tool name in the JSON \"name\" field.\n- Always put tool arguments inside the JSON \"arguments\" object.\n- Do not put arguments directly after the tool name.\n- Do not use function-call syntax like TOOL_NAME(...).\n- arguments must be a JSON object.\n- arguments must satisfy the tool schema.\n- Do not include markdown fences.\n- Do not include explanations.\n- Do not answer the user directly.\n\nAvailable tools:\n{}",
310            r#"{"name":"TOOL_NAME","arguments":{...}}"#,
311            r#"{"name":"TOOL_NAME","arguments":{"ARGUMENT_NAME":"ARGUMENT_VALUE"}}"#,
312            self.tool_schemas_json,
313        )
314    }
315
316    /// Builds Qwen XML-wrapped JSON correction instructions after invalid tool-call output.
317    fn qwen_xml_correction_content(&self, validation_error: &str, invalid_output: &str) -> String {
318        format!(
319            "Your previous response attempted a tool call, but it was invalid.\n\nValidation error:\n{validation_error}\n\nInvalid output:\n{invalid_output}\n\nYou must now return only valid tool calls and nothing else.\n\nUse this exact Qwen XML-wrapped JSON format:\n\n{TOOL_CALL_START}\n{}\n{TOOL_CALL_END}\n\nThere MUST be a newline immediately after {TOOL_CALL_START}. Inside each block, output ONLY one valid JSON object with \"name\" and \"arguments\" top-level keys.\n\nAvailable tools:\n{}",
320            r#"{"name":"TOOL_NAME","arguments":{...}}"#, self.tool_schemas_json,
321        )
322    }
323
324    /// Builds GLM XML correction instructions after invalid tool-call output.
325    fn glm_xml_correction_content(&self, validation_error: &str, invalid_output: &str) -> String {
326        format!(
327            "Your previous response attempted a tool call, but it was invalid.\n\nValidation error:\n{validation_error}\n\nInvalid output:\n{invalid_output}\n\nYou must now return only valid tool calls and nothing else.\n\nUse this exact GLM XML format:\n\n{TOOL_CALL_START}TOOL_NAME\n<arg_key>ARGUMENT_NAME</arg_key>\n<arg_value>ARGUMENT_VALUE</arg_value>\n{TOOL_CALL_END}\n\nInside each {TOOL_CALL_START} block:\n- Start with the exact tool name as plain text.\n- Then output one <arg_key>/<arg_value> pair for each argument.\n- Put only the raw argument name inside <arg_key>.\n- Put only the raw argument value inside <arg_value>.\n\nValid example:\n\n{TOOL_CALL_START}TOOL_NAME\n<arg_key>ARGUMENT_NAME</arg_key>\n<arg_value>ARGUMENT_VALUE</arg_value>\n{TOOL_CALL_END}\n\nInvalid formats. NEVER use these:\n- {TOOL_CALL_START}TOOL_NAME({{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}){TOOL_CALL_END}\n- {TOOL_CALL_START}TOOL_NAME{{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}{TOOL_CALL_END}\n- {TOOL_CALL_START}{{\"name\":\"TOOL_NAME\",\"arguments\":{{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}}}{TOOL_CALL_END}\n- TOOL_NAME({{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}})\n\nRules:\n- TOOL_NAME must exactly match one of the available tools.\n- Do not output JSON inside {TOOL_CALL_START} except for object/array values inside <arg_value>.\n- Do not use function-call syntax like TOOL_NAME(...).\n- Do not use the Hermes JSON format with \"name\" and \"arguments\" keys.\n- Argument names and values must satisfy the tool schema.\n- Do not include markdown fences.\n- Do not include explanations.\n- Do not answer the user directly.\n\nAvailable tools:\n{}",
328            self.tool_schemas_json,
329        )
330    }
331
332    /// Classifies buffered assistant output into normal text, validated tool
333    /// calls, or an invalid tool call that feeds the retry/correction loop.
334    ///
335    /// Mixed text + tool calls classifies as tool calls; the surrounding text
336    /// is dropped from the OpenAI message, matching previous behavior.
337    pub fn classify_assistant_output(&self, output: &str) -> ToolOutputClassification {
338        if output.len() > self.config.tool_call_max_bytes {
339            return ToolOutputClassification::InvalidToolCall {
340                error: ToolCallValidationError::new(format!(
341                    "assistant output exceeded the tool call max size of {} bytes",
342                    self.config.tool_call_max_bytes
343                )),
344                invalid_output: output.to_owned(),
345            };
346        }
347
348        let result = self.parse_tool_calls(output);
349
350        match result {
351            Ok(tool_calls) if tool_calls.is_empty() => {
352                if self.require_tool_call {
353                    ToolOutputClassification::InvalidToolCall {
354                        error: ToolCallValidationError::new(
355                            "expected the assistant response to include a tool call",
356                        ),
357                        invalid_output: output.to_owned(),
358                    }
359                } else {
360                    ToolOutputClassification::NormalText
361                }
362            }
363            Ok(tool_calls) => ToolOutputClassification::ToolCalls(tool_calls),
364            Err(error) => ToolOutputClassification::InvalidToolCall {
365                error,
366                invalid_output: output.to_owned(),
367            },
368        }
369    }
370
371    /// Parses and validates assistant output, using Hermes as a compatibility fallback when needed.
372    fn parse_tool_calls(
373        &self,
374        output: &str,
375    ) -> Result<Vec<ValidatedToolCall>, ToolCallValidationError> {
376        let result = self.parse_tool_calls_with_format(self.prompt_format, output);
377        if self.prompt_format == ToolPromptFormat::HermesJson {
378            return result;
379        }
380
381        match result {
382            Ok(tool_calls) if tool_calls.is_empty() && output.contains(TOOL_CALL_START) => {
383                if let Some(fallback_calls) = self.hermes_fallback_tool_calls(output) {
384                    Ok(fallback_calls)
385                } else {
386                    Ok(tool_calls)
387                }
388            }
389            Err(error) => {
390                if let Some(fallback_calls) = self.hermes_fallback_tool_calls(output) {
391                    Ok(fallback_calls)
392                } else {
393                    Err(error)
394                }
395            }
396            result => result,
397        }
398    }
399
400    /// Attempts to parse non-empty Hermes JSON tool calls from model output.
401    fn hermes_fallback_tool_calls(&self, output: &str) -> Option<Vec<ValidatedToolCall>> {
402        match self.parse_tool_calls_with_format(ToolPromptFormat::HermesJson, output) {
403            Ok(tool_calls) if !tool_calls.is_empty() => Some(tool_calls),
404            _ => None,
405        }
406    }
407
408    /// Parses and validates assistant output using a specific prompt/parser format.
409    fn parse_tool_calls_with_format(
410        &self,
411        format: ToolPromptFormat,
412        output: &str,
413    ) -> Result<Vec<ValidatedToolCall>, ToolCallValidationError> {
414        self.create_parser_for_format(format)
415            .and_then(|mut parser| {
416                parser.parse_complete(output).map_err(|error| {
417                    ToolCallValidationError::new(format!("tool call parsing failed: {error}"))
418                })
419            })
420            .and_then(|result| {
421                result
422                    .calls
423                    .iter()
424                    .map(|call| self.validate_tool_call(call))
425                    .collect::<Result<Vec<_>, _>>()
426            })
427    }
428
429    /// Validates one coalesced parser tool call against the request's tools.
430    fn validate_tool_call(
431        &self,
432        call: &ToolCallDelta,
433    ) -> Result<ValidatedToolCall, ToolCallValidationError> {
434        let name = call.name.as_deref().unwrap_or_default();
435        if name.trim().is_empty() {
436            return Err(ToolCallValidationError::new(
437                "tool call name must not be empty",
438            ));
439        }
440        let tool = self
441            .tools
442            .iter()
443            .find(|tool| tool.name() == name)
444            .ok_or_else(|| ToolCallValidationError::new(format!("unknown tool name {name:?}")))?;
445
446        let arguments: Value = serde_json::from_str(&call.arguments).map_err(|source| {
447            ToolCallValidationError::new(format!("tool call arguments JSON is invalid: {source}"))
448        })?;
449        if !arguments.is_object() {
450            return Err(ToolCallValidationError::new(
451                "tool call arguments must be a JSON object",
452            ));
453        }
454
455        if self.config.validate_json_schema
456            && let Some(schema) = tool.parameters_schema()
457        {
458            validate_value_against_schema(&arguments, schema, "arguments").map_err(|message| {
459                ToolCallValidationError::new(format!(
460                    "tool call arguments do not satisfy schema: {message}"
461                ))
462            })?;
463        }
464
465        let arguments_json = serde_json::to_string(&arguments).map_err(|source| {
466            ToolCallValidationError::new(format!(
467                "tool call arguments could not be serialized as JSON: {source}"
468            ))
469        })?;
470
471        Ok(ValidatedToolCall {
472            id: generate_tool_call_id(),
473            name: name.to_owned(),
474            arguments_json,
475        })
476    }
477}
478
/// Caps `text` at `max_bytes` bytes of original content, cutting on a char boundary.
///
/// Text that already fits is returned borrowed and untouched. Otherwise the
/// longest prefix of at most `max_bytes` bytes that ends on a char boundary is
/// kept and the marker ` [output truncated]` is appended after it, so the
/// returned string can be longer than `max_bytes` by the marker's length.
fn truncate_at_char_boundary(text: &str, max_bytes: usize) -> std::borrow::Cow<'_, str> {
    use std::borrow::Cow;

    if text.len() <= max_bytes {
        return Cow::Borrowed(text);
    }

    // Walk back from the limit to the nearest boundary; index 0 always
    // qualifies, so the search cannot come up empty.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&index| text.is_char_boundary(index))
        .unwrap_or(0);

    let kept = &text[..cut];
    Cow::Owned(format!("{kept} [output truncated]"))
}
490
491/// Lenient wrapper around the strict Hermes parser, tolerating model
492/// deviations observed live against Venice:
493/// - A parse that fails to finish (the model or an upstream stop sequence
494///   dropped `</tool_call>`; seen with `e2ee-glm-4-7-flash-p`) is retried
495///   once with the closing marker appended.
496/// - Trailing garbage after a tool call with complete JSON arguments (e.g.
497///   `e2ee-glm-5-1` closing a call with a stray `</arg_value>`) drains the
498///   rest of the output instead of failing. Input is split before each `<`
499///   so such garbage reaches the parser in its own push and cannot take
500///   already-parsed deltas down with it.
501struct LenientToolParser {
502    inner: Box<dyn ToolParser>,
503    /// Tracks whether the most recent call's argument JSON is complete, to
504    /// distinguish trailing garbage from a truncated call.
505    args_scanner: ArgsCompletenessScanner,
506    /// Set once trailing garbage was detected; all further input is ignored.
507    drained: bool,
508}
509
510impl ToolParser for LenientToolParser {
511    /// Creates a lenient Hermes parser for the supplied tools.
512    fn create(tools: &[Tool]) -> ToolParserResult<Box<dyn ToolParser>> {
513        Ok(Box::new(Self {
514            inner: HermesToolParser::create(tools)?,
515            args_scanner: ArgsCompletenessScanner::default(),
516            drained: false,
517        }))
518    }
519
520    /// Pushes one assistant output chunk through the parser, appending parsed deltas to `output`.
521    fn parse_into(&mut self, chunk: &str, output: &mut ToolParseResult) -> ToolParserResult<()> {
522        output.append(self.push(chunk)?);
523        Ok(())
524    }
525
526    /// Pushes one assistant output chunk through the parser and returns parsed deltas.
527    fn push(&mut self, chunk: &str) -> ToolParserResult<ToolParseResult> {
528        let mut merged = ToolParseResult::default();
529        if self.drained {
530            return Ok(merged);
531        }
532        for piece in split_before_tag_starts(chunk) {
533            match self.inner.push(piece) {
534                Ok(result) => {
535                    self.args_scanner.track(&result);
536                    merged.normal_text.push_str(&result.normal_text);
537                    merged.calls.extend(result.calls);
538                }
539                Err(error) => {
540                    if !self.args_scanner.complete() {
541                        return Err(error);
542                    }
543                    // Some live GLM outputs append a native closing tag after complete
544                    // Hermes arguments; keep the parsed call rather than failing on that
545                    // incompatible tail.
546                    warn!(%error, "ignoring trailing output after a complete tool call");
547                    self.drained = true;
548                    break;
549                }
550            }
551        }
552        Ok(merged)
553    }
554
555    /// Finishes parsing and returns any recovered complete tool calls.
556    fn finish(&mut self) -> ToolParserResult<ToolParseResult> {
557        if self.drained {
558            return Ok(ToolParseResult::default());
559        }
560        let error = match self.inner.finish() {
561            Ok(result) => return Ok(result),
562            Err(error) => error,
563        };
564        // Venice sometimes cuts only the closing marker; keep the original
565        // parser error when appending that marker still cannot recover a complete call.
566        let Ok(mut recovered) = self.inner.push(TOOL_CALL_END) else {
567            return Err(error);
568        };
569        let Ok(finished) = self.inner.finish() else {
570            return Err(error);
571        };
572        recovered.normal_text.push_str(&finished.normal_text);
573        recovered.calls.extend(finished.calls);
574        Ok(recovered)
575    }
576
577    /// Clears parser state and returns uncommitted buffered text.
578    fn reset(&mut self) -> String {
579        self.args_scanner = ArgsCompletenessScanner::default();
580        self.drained = false;
581        self.inner.reset()
582    }
583}
584
/// Cuts `text` immediately before every `<`, so each potential tag (marker,
/// native-format tag, or garbage) begins its own piece and can be pushed to
/// the parser separately.
///
/// Pieces are returned in order, are never empty, and concatenate back to
/// `text`. An empty input yields an empty vector.
fn split_before_tag_starts(text: &str) -> Vec<&str> {
    let mut pieces = Vec::new();
    let mut rest = text;
    while !rest.is_empty() {
        // Skip the first character so a leading `<` stays attached to the
        // piece it opens; the cut lands on the next `<`, or at the end.
        let cut = rest
            .char_indices()
            .skip(1)
            .find_map(|(offset, ch)| (ch == '<').then_some(offset))
            .unwrap_or(rest.len());
        let (piece, tail) = rest.split_at(cut);
        pieces.push(piece);
        rest = tail;
    }
    pieces
}
601
602/// Minimal JSON scanner tracking whether the most recent tool call's
603/// argument text forms a complete JSON value (balanced braces/brackets
604/// outside strings).
605#[derive(Debug, Default)]
606struct ArgsCompletenessScanner {
607    started: bool,
608    depth: u32,
609    in_string: bool,
610    escaped: bool,
611}
612
613impl ArgsCompletenessScanner {
614    /// Updates scanner state from parser deltas emitted for tool calls.
615    fn track(&mut self, result: &ToolParseResult) {
616        for call in &result.calls {
617            if call.name.is_some() {
618                *self = Self::default();
619            }
620            self.feed(&call.arguments);
621        }
622    }
623
624    /// Consumes an argument fragment and updates JSON completeness state.
625    fn feed(&mut self, fragment: &str) {
626        for ch in fragment.chars() {
627            if self.escaped {
628                self.escaped = false;
629                continue;
630            }
631            if self.in_string {
632                match ch {
633                    '\\' => self.escaped = true,
634                    '"' => self.in_string = false,
635                    _ => {}
636                }
637                continue;
638            }
639            match ch {
640                '"' => {
641                    self.in_string = true;
642                    self.started = true;
643                }
644                '{' | '[' => {
645                    self.depth += 1;
646                    self.started = true;
647                }
648                '}' | ']' => self.depth = self.depth.saturating_sub(1),
649                ch if !ch.is_whitespace() => self.started = true,
650                _ => {}
651            }
652        }
653    }
654
655    /// Returns whether the most recently observed argument text is a complete JSON value.
656    fn complete(&self) -> bool {
657        self.started && self.depth == 0 && !self.in_string
658    }
659}
660
/// Classification of a decrypted assistant response under tool emulation.
///
/// Exactly one variant describes each classified response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolOutputClassification {
    /// The response is treated as ordinary assistant text.
    // NOTE(review): the unit tests in this file show plain text without
    // tool-call markers landing here when no tool call is required — confirm
    // the exact rule against `classify_assistant_output`.
    NormalText,
    /// Tool calls that were parsed from the response and accepted by validation.
    ToolCalls(Vec<ValidatedToolCall>),
    /// The response was rejected as a tool-call answer.
    InvalidToolCall {
        /// Reason the response was rejected.
        error: ToolCallValidationError,
        /// Text associated with the rejection.
        // NOTE(review): presumably the rejected assistant output that is
        // later echoed in the correction prompt — confirm at the call site.
        invalid_output: String,
    },
}
671
672/// OpenAI-compatible tool call validated against the request's available tools.
673#[derive(Debug, Clone, PartialEq, Eq)]
674pub struct ValidatedToolCall {
675    pub id: String,
676    pub name: String,
677    pub arguments_json: String,
678}
679
680impl ValidatedToolCall {
681    /// Converts the validated call into an OpenAI `tool_calls` JSON object.
682    pub fn to_openai_value(&self) -> Value {
683        serde_json::json!({
684            "id": self.id,
685            "type": "function",
686            "function": {
687                "name": self.name,
688                "arguments": self.arguments_json,
689            },
690        })
691    }
692}
693
694/// Validation error for parsed assistant tool-call output.
695#[derive(Debug, Clone, PartialEq, Eq, Error)]
696#[error("{message}")]
697pub struct ToolCallValidationError {
698    message: String,
699}
700
701impl ToolCallValidationError {
702    /// Creates a validation error with a client-facing message.
703    fn new(message: impl Into<String>) -> Self {
704        Self {
705            message: message.into(),
706        }
707    }
708
709    /// Returns the validation error message.
710    pub fn message(&self) -> &str {
711        &self.message
712    }
713}
714
715/// Validates that a tool parameters schema uses the supported JSON Schema subset.
716fn validate_schema_shape(schema: &Map<String, Value>) -> Result<(), String> {
717    validate_schema_object_shape(schema, "schema")
718}
719
720/// Validates one schema object and nested supported schema objects.
721fn validate_schema_object_shape(object: &Map<String, Value>, path: &str) -> Result<(), String> {
722    if let Some(kind) = object.get("type") {
723        validate_schema_type_shape(kind, &format!("{path}.type"))?;
724    }
725    if let Some(required) = object.get("required") {
726        let required = required
727            .as_array()
728            .ok_or_else(|| format!("{path}.required must be an array"))?;
729        if required.iter().any(|value| !value.is_string()) {
730            return Err(format!("{path}.required must contain only strings"));
731        }
732    }
733    if let Some(properties) = object.get("properties") {
734        let properties = properties
735            .as_object()
736            .ok_or_else(|| format!("{path}.properties must be an object"))?;
737        for (name, schema) in properties {
738            let schema = schema
739                .as_object()
740                .ok_or_else(|| format!("{path}.properties.{name} must be an object"))?;
741            validate_schema_object_shape(schema, &format!("{path}.properties.{name}"))?;
742        }
743    }
744    if let Some(items) = object.get("items") {
745        let items = items
746            .as_object()
747            .ok_or_else(|| format!("{path}.items must be an object"))?;
748        validate_schema_object_shape(items, &format!("{path}.items"))?;
749    }
750    if let Some(additional) = object.get("additionalProperties") {
751        match additional {
752            Value::Bool(_) => {}
753            Value::Object(additional) => {
754                validate_schema_object_shape(additional, &format!("{path}.additionalProperties"))?
755            }
756            _ => {
757                return Err(format!(
758                    "{path}.additionalProperties must be a boolean or object"
759                ));
760            }
761        }
762    }
763    if let Some(enum_values) = object.get("enum")
764        && !enum_values.is_array()
765    {
766        return Err(format!("{path}.enum must be an array"));
767    }
768    Ok(())
769}
770
771/// Validates a JSON Schema `type` value in string or string-array form.
772fn validate_schema_type_shape(value: &Value, path: &str) -> Result<(), String> {
773    match value {
774        Value::String(kind) => validate_schema_type_name(kind, path),
775        Value::Array(kinds) => {
776            if kinds.is_empty() {
777                return Err(format!("{path} must not be an empty array"));
778            }
779            for kind in kinds {
780                let kind = kind
781                    .as_str()
782                    .ok_or_else(|| format!("{path} array must contain only strings"))?;
783                validate_schema_type_name(kind, path)?;
784            }
785            Ok(())
786        }
787        _ => Err(format!("{path} must be a string or array of strings")),
788    }
789}
790
/// Accepts a type name only if it is one of the JSON Schema primitive types
/// this module supports.
fn validate_schema_type_name(kind: &str, path: &str) -> Result<(), String> {
    const SUPPORTED_TYPES: [&str; 7] = [
        "object", "array", "string", "integer", "number", "boolean", "null",
    ];
    if SUPPORTED_TYPES.contains(&kind) {
        Ok(())
    } else {
        Err(format!(
            "{path} contains unsupported JSON schema type {kind:?}"
        ))
    }
}
800
801/// Validates a JSON value against the supported JSON Schema subset.
802fn validate_value_against_schema(
803    value: &Value,
804    schema: &Map<String, Value>,
805    path: &str,
806) -> Result<(), String> {
807    if let Some(enum_values) = schema.get("enum").and_then(Value::as_array)
808        && !enum_values.iter().any(|enum_value| enum_value == value)
809    {
810        return Err(format!("{path} is not one of the allowed enum values"));
811    }
812
813    if let Some(kind) = schema.get("type")
814        && !schema_type_matches(value, kind)
815    {
816        return Err(format!(
817            "{path} expected type {}, got {}",
818            schema_type_description(kind),
819            value_kind(value)
820        ));
821    }
822
823    if schema_implies_object(schema) {
824        let object = value
825            .as_object()
826            .ok_or_else(|| format!("{path} expected object, got {}", value_kind(value)))?;
827        if let Some(required) = schema.get("required").and_then(Value::as_array) {
828            for field in required.iter().filter_map(Value::as_str) {
829                if !object.contains_key(field) {
830                    return Err(format!("{path}.{field} is required"));
831                }
832            }
833        }
834        let properties = schema.get("properties").and_then(Value::as_object);
835        if let Some(properties) = properties {
836            for (field, property_schema) in properties {
837                if let Some(property_value) = object.get(field) {
838                    let property_path = format!("{path}.{field}");
839                    let property_schema = schema_value_as_object(property_schema, &property_path)?;
840                    validate_value_against_schema(property_value, property_schema, &property_path)?;
841                }
842            }
843        }
844        if let Some(additional) = schema.get("additionalProperties") {
845            match additional {
846                Value::Bool(false) => {
847                    for field in object.keys() {
848                        if properties.is_none_or(|properties| !properties.contains_key(field)) {
849                            return Err(format!("{path}.{field} is not allowed by schema"));
850                        }
851                    }
852                }
853                Value::Object(additional_schema) => {
854                    for (field, additional_value) in object {
855                        if properties.is_none_or(|properties| !properties.contains_key(field)) {
856                            validate_value_against_schema(
857                                additional_value,
858                                additional_schema,
859                                &format!("{path}.{field}"),
860                            )?;
861                        }
862                    }
863                }
864                _ => {}
865            }
866        }
867    }
868
869    if schema_implies_array(schema) {
870        let array = value
871            .as_array()
872            .ok_or_else(|| format!("{path} expected array, got {}", value_kind(value)))?;
873        if let Some(items_schema) = schema.get("items") {
874            for (index, item) in array.iter().enumerate() {
875                let item_path = format!("{path}[{index}]");
876                let items_schema = schema_value_as_object(items_schema, &item_path)?;
877                validate_value_against_schema(item, items_schema, &item_path)?;
878            }
879        }
880    }
881
882    Ok(())
883}
884
885/// Reads a nested schema value as an object for recursive validation.
886fn schema_value_as_object<'a>(
887    schema: &'a Value,
888    path: &str,
889) -> Result<&'a Map<String, Value>, String> {
890    schema
891        .as_object()
892        .ok_or_else(|| format!("{path} schema must be an object"))
893}
894
895/// Returns whether a schema requires object-specific validation.
896fn schema_implies_object(schema: &Map<String, Value>) -> bool {
897    schema
898        .get("type")
899        .is_some_and(|kind| schema_type_includes(kind, "object"))
900        || schema.contains_key("properties")
901        || schema.contains_key("required")
902        || schema.contains_key("additionalProperties")
903}
904
905/// Returns whether a schema requires array-specific validation.
906fn schema_implies_array(schema: &Map<String, Value>) -> bool {
907    schema
908        .get("type")
909        .is_some_and(|kind| schema_type_includes(kind, "array"))
910        || schema.contains_key("items")
911}
912
913/// Returns whether a JSON value matches a schema `type` value.
914fn schema_type_matches(value: &Value, kind: &Value) -> bool {
915    match kind {
916        Value::String(kind) => value_matches_schema_type(value, kind),
917        Value::Array(kinds) => kinds
918            .iter()
919            .filter_map(Value::as_str)
920            .any(|kind| value_matches_schema_type(value, kind)),
921        _ => true,
922    }
923}
924
925/// Returns whether a schema `type` value includes a named type.
926fn schema_type_includes(kind: &Value, expected: &str) -> bool {
927    match kind {
928        Value::String(kind) => kind == expected,
929        Value::Array(kinds) => kinds
930            .iter()
931            .filter_map(Value::as_str)
932            .any(|kind| kind == expected),
933        _ => false,
934    }
935}
936
937/// Returns whether a JSON value matches one supported schema type name.
938fn value_matches_schema_type(value: &Value, kind: &str) -> bool {
939    match kind {
940        "object" => value.is_object(),
941        "array" => value.is_array(),
942        "string" => value.is_string(),
943        "integer" => value.as_i64().is_some() || value.as_u64().is_some(),
944        "number" => value.is_number(),
945        "boolean" => value.is_boolean(),
946        "null" => value.is_null(),
947        _ => true,
948    }
949}
950
951/// Formats a schema `type` value for validation error messages.
952fn schema_type_description(kind: &Value) -> String {
953    match kind {
954        Value::String(kind) => kind.clone(),
955        Value::Array(kinds) => kinds
956            .iter()
957            .filter_map(Value::as_str)
958            .collect::<Vec<_>>()
959            .join(" or "),
960        _ => "unknown".to_owned(),
961    }
962}
963
964/// Returns a human-readable kind name for a JSON value.
965fn value_kind(value: &Value) -> &'static str {
966    match value {
967        Value::Null => "null",
968        Value::Bool(_) => "boolean",
969        Value::Number(_) => "number",
970        Value::String(_) => "string",
971        Value::Array(_) => "array",
972        Value::Object(_) => "object",
973    }
974}
975
// Unit tests for tool-output classification, schema validation of parsed
// calls, lenient recovery of malformed output, and prompt construction.
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;
    use crate::config::ToolsConfig;

    /// Builds a request for model `e2ee-test` offering a single `search_web`
    /// tool whose parameters schema is `arguments_schema`.
    fn request_with_tool(arguments_schema: Value) -> ChatCompletionRequest {
        request_with_tool_for_model("e2ee-test", arguments_schema)
    }

    /// Builds a request for `model` offering a single `search_web` tool whose
    /// parameters schema is `arguments_schema`.
    fn request_with_tool_for_model(model: &str, arguments_schema: Value) -> ChatCompletionRequest {
        ChatCompletionRequest::parse(&json!({
            "model": model,
            "messages": [{"role":"user", "content":"hi"}],
            "tools": [{
                "type": "function",
                "function": {
                    "name": "search_web",
                    "description": "Search the web",
                    "parameters": arguments_schema
                }
            }]
        }))
        .expect("request should parse")
    }

    /// Builds the tool-emulation context for `request` with default tool
    /// config, panicking if it cannot be built or tools do not activate.
    fn context_for_request(request: &ChatCompletionRequest) -> ToolEmulationContext {
        ToolEmulationContext::from_request(&ToolsConfig::default(), request)
            .expect("tool context should build")
            .expect("tools should activate")
    }

    /// A Hermes-style marker block naming the declared tool yields one
    /// validated call with a `call_`-prefixed id and compact argument JSON.
    #[test]
    fn classifies_valid_hermes_tool_call() {
        let request = request_with_tool(json!({
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
            "additionalProperties": false
        }));
        let context = context_for_request(&request);

        let classification = context.classify_assistant_output(
            "\n<tool_call>\n{\"name\":\"search_web\",\"arguments\":{\"query\":\"Venice\"}}\n</tool_call>\n",
        );

        let ToolOutputClassification::ToolCalls(tool_calls) = classification else {
            panic!("expected valid tool call");
        };
        assert_eq!(tool_calls.len(), 1);
        assert!(tool_calls[0].id.starts_with("call_"));
        assert_eq!(tool_calls[0].name, "search_web");
        assert_eq!(tool_calls[0].arguments_json, "{\"query\":\"Venice\"}");
    }

    /// For a GLM model name, the `<arg_key>`/`<arg_value>` call shape is
    /// parsed and the controller prompt mentions `<arg_key>`.
    #[test]
    fn classifies_glm_xml_tool_call_for_glm_models() {
        let request = request_with_tool_for_model(
            "e2ee-glm-5-1",
            json!({
                "type": "object",
                "properties": {"query": {"type": "string"}},
                "required": ["query"],
                "additionalProperties": false
            }),
        );
        let context = context_for_request(&request);

        let classification = context.classify_assistant_output(
            "<tool_call>search_web\n<arg_key>query</arg_key><arg_value>Venice</arg_value></tool_call>",
        );

        let ToolOutputClassification::ToolCalls(tool_calls) = classification else {
            panic!("expected valid GLM XML tool call");
        };
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].name, "search_web");
        assert_eq!(tool_calls[0].arguments_json, "{\"query\":\"Venice\"}");
        assert!(context.controller_message().content.contains("<arg_key>"));
    }

    /// For a Qwen model name, a JSON payload wrapped in `<tool_call>` markers
    /// is parsed and the controller prompt names the Qwen format.
    #[test]
    fn classifies_qwen_xml_tool_call_for_qwen_models() {
        let request = request_with_tool_for_model(
            "e2ee-qwen3-30b-a3b-p",
            json!({
                "type": "object",
                "properties": {"query": {"type": "string"}},
                "required": ["query"],
                "additionalProperties": false
            }),
        );
        let context = context_for_request(&request);

        let classification = context.classify_assistant_output(
            "<tool_call>\n{\"name\":\"search_web\",\"arguments\":{\"query\":\"Venice\"}}\n</tool_call>",
        );

        let ToolOutputClassification::ToolCalls(tool_calls) = classification else {
            panic!("expected valid Qwen XML-wrapped JSON tool call");
        };
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].name, "search_web");
        assert_eq!(tool_calls[0].arguments_json, "{\"query\":\"Venice\"}");
        assert!(
            context
                .controller_message()
                .content
                .contains("Qwen XML-wrapped JSON")
        );
    }

    /// Malformed argument JSON, an undeclared tool name, and arguments that
    /// miss a required field are each reported as an invalid tool call.
    #[test]
    fn rejects_invalid_json_unknown_tool_and_schema_mismatch() {
        let request = request_with_tool(json!({
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
            "additionalProperties": false
        }));
        let context = context_for_request(&request);

        // Hermes passes argument text through raw; invalid JSON is caught by
        // our validation layer.
        let ToolOutputClassification::InvalidToolCall { error, .. } = context
            .classify_assistant_output(
                "<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"x\",}}</tool_call>",
            )
        else {
            panic!("expected invalid JSON to be rejected");
        };
        assert!(error.message().contains("JSON is invalid"));

        let ToolOutputClassification::InvalidToolCall { error, .. } = context
            .classify_assistant_output(
                "<tool_call>{\"name\":\"unknown\",\"arguments\":{\"query\":\"x\"}}</tool_call>",
            )
        else {
            panic!("expected unknown tool to be rejected");
        };
        assert!(error.message().contains("unknown tool name"));

        let ToolOutputClassification::InvalidToolCall { error, .. } = context
            .classify_assistant_output(
                "<tool_call>{\"name\":\"search_web\",\"arguments\":{\"q\":\"x\"}}</tool_call>",
            )
        else {
            panic!("expected schema mismatch to be rejected");
        };
        assert!(error.message().contains("arguments.query is required"));
    }

    /// Output that lacks only the closing `</tool_call>` marker still yields
    /// the tool call.
    #[test]
    fn recovers_tool_call_with_truncated_closing_marker() {
        // Observed live: Venice cuts `</tool_call>` for some models (likely a
        // stop sequence). A complete call missing only the closing marker is
        // recovered leniently.
        let request = request_with_tool(json!({"type": "object"}));
        let context = context_for_request(&request);

        let classification = context.classify_assistant_output(
            "<tool_call>\n{\"name\":\"search_web\",\"arguments\":{\"query\":\"a\"}}\n",
        );

        let ToolOutputClassification::ToolCalls(tool_calls) = classification else {
            panic!("expected truncated closing marker to be recovered, got {classification:?}");
        };
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].name, "search_web");
        assert_eq!(tool_calls[0].arguments_json, "{\"query\":\"a\"}");
    }

    /// A complete Hermes-shaped call followed by a stray `</arg_value>` tag
    /// (with or without the proper closing marker) still yields the call.
    #[test]
    fn ignores_trailing_garbage_after_complete_tool_call() {
        // Exact outputs observed live from `e2ee-glm-5-1`: a Hermes-shaped
        // call "closed" with a stray GLM-native tag.
        let request = request_with_tool_for_model("e2ee-glm-5-1", json!({"type": "object"}));
        let context = context_for_request(&request);

        for output in [
            "<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"a\"}}</arg_value>",
            "<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"a\"}}</arg_value></tool_call>",
        ] {
            let classification = context.classify_assistant_output(output);
            let ToolOutputClassification::ToolCalls(tool_calls) = classification else {
                panic!(
                    "expected trailing garbage to be ignored for {output:?}, got {classification:?}"
                );
            };
            assert_eq!(tool_calls.len(), 1);
            assert_eq!(tool_calls[0].name, "search_web");
            assert_eq!(tool_calls[0].arguments_json, "{\"query\":\"a\"}");
        }
    }

    /// Output cut off in the middle of the call's JSON is reported as a
    /// parsing failure rather than recovered.
    #[test]
    fn classifies_output_truncated_mid_json_as_invalid_tool_call() {
        let request = request_with_tool(json!({"type": "object"}));
        let context = context_for_request(&request);

        let classification =
            context.classify_assistant_output("<tool_call>{\"name\":\"search_web\",\"argu");

        let ToolOutputClassification::InvalidToolCall { error, .. } = classification else {
            panic!("expected mid-JSON truncation to be invalid, got {classification:?}");
        };
        assert!(error.message().contains("tool call parsing failed"));
    }

    /// Plain text is normal output by default, but is invalid when the
    /// request sets `tool_choice` to `required`.
    #[test]
    fn classifies_plain_text_and_enforces_required_tool_call() {
        let request = request_with_tool(json!({"type": "object"}));
        let context = context_for_request(&request);
        assert_eq!(
            context.classify_assistant_output("Hello, world!"),
            ToolOutputClassification::NormalText
        );

        let request = ChatCompletionRequest::parse(&json!({
            "model": "e2ee-test",
            "messages": [{"role":"user", "content":"hi"}],
            "tool_choice": "required",
            "tools": [{"type":"function", "function":{"name":"search_web", "parameters":{"type":"object"}}}]
        }))
        .expect("request should parse");
        let context = context_for_request(&request);

        let ToolOutputClassification::InvalidToolCall { error, .. } =
            context.classify_assistant_output("Hello, world!")
        else {
            panic!("expected missing required tool call to be invalid");
        };
        assert!(error.message().contains("expected the assistant response"));
    }

    /// Leading prose before a tool-call block does not prevent the output
    /// from being classified as tool calls.
    #[test]
    fn classifies_mixed_text_and_tool_call_as_tool_calls() {
        let request = request_with_tool(json!({"type": "object"}));
        let context = context_for_request(&request);

        let classification = context.classify_assistant_output(
            "Let me check.\n<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"a\"}}</tool_call>",
        );

        let ToolOutputClassification::ToolCalls(tool_calls) = classification else {
            panic!("expected mixed output to classify as tool calls");
        };
        assert_eq!(tool_calls.len(), 1);
    }

    /// Two marker blocks yield two calls, in order and with distinct ids,
    /// even when the request sets `parallel_tool_calls` to `false`.
    #[test]
    fn classifies_multiple_tool_calls_regardless_of_parallel_tool_calls() {
        let request = ChatCompletionRequest::parse(&json!({
            "model": "e2ee-test",
            "messages": [{"role":"user", "content":"hi"}],
            "parallel_tool_calls": false,
            "tools": [{"type":"function", "function":{"name":"search_web", "parameters":{"type":"object"}}}]
        }))
        .expect("request should parse");
        let context = context_for_request(&request);

        // `parallel_tool_calls` is accepted for OpenAI compatibility but
        // ignored; all parsed tool calls are returned.
        let classification = context.classify_assistant_output(
            "<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"a\"}}</tool_call>\n<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"b\"}}</tool_call>",
        );
        let ToolOutputClassification::ToolCalls(tool_calls) = classification else {
            panic!("expected two valid tool calls");
        };
        assert_eq!(tool_calls.len(), 2);
        assert_eq!(tool_calls[0].arguments_json, "{\"query\":\"a\"}");
        assert_eq!(tool_calls[1].arguments_json, "{\"query\":\"b\"}");
        assert_ne!(tool_calls[0].id, tool_calls[1].id);
    }

    /// Output one byte over the configured `tool_call_max_bytes` is rejected
    /// with a message naming the limit.
    #[test]
    fn rejects_oversized_assistant_output() {
        let request = request_with_tool(json!({"type": "object"}));
        let config = ToolsConfig {
            tool_call_max_bytes: 32,
            ..ToolsConfig::default()
        };
        let context = ToolEmulationContext::from_request(&config, &request)
            .expect("tool context should build")
            .expect("tools should activate");

        let ToolOutputClassification::InvalidToolCall { error, .. } =
            context.classify_assistant_output(&"x".repeat(33))
        else {
            panic!("expected oversized output to be invalid");
        };
        assert!(error.message().contains("max size of 32 bytes"));
    }

    /// With `validate_json_schema` turned off, arguments that miss a required
    /// field are accepted.
    #[test]
    fn can_disable_schema_validation_explicitly() {
        let request = request_with_tool(json!({
            "type": "object",
            "required": ["query"]
        }));
        let config = ToolsConfig {
            validate_json_schema: false,
            ..ToolsConfig::default()
        };
        let context = ToolEmulationContext::from_request(&config, &request)
            .expect("tool context should build")
            .expect("tools should activate");

        let classification = context.classify_assistant_output(
            "<tool_call>{\"name\":\"search_web\",\"arguments\":{}}</tool_call>",
        );
        let ToolOutputClassification::ToolCalls(tool_calls) = classification else {
            panic!("schema mismatch should be allowed when validation is disabled");
        };
        assert_eq!(tool_calls[0].arguments_json, "{}");
    }

    /// Array-valued arguments are rejected both through the full
    /// classification path and by `validate_tool_call` directly.
    #[test]
    fn rejects_non_object_arguments() {
        let request = request_with_tool(json!({"type": "object"}));
        let context = context_for_request(&request);

        // The Hermes parser itself rejects non-object argument payloads, so
        // this surfaces as a parser failure rather than reaching our
        // arguments-must-be-an-object validation.
        let ToolOutputClassification::InvalidToolCall { error, .. } = context
            .classify_assistant_output(
                "<tool_call>{\"name\":\"search_web\",\"arguments\":[]}</tool_call>",
            )
        else {
            panic!("expected non-object arguments to be rejected");
        };
        assert!(error.message().contains("tool call parsing failed"));

        // Our validation layer still rejects non-object arguments that a
        // parser passes through (defense in depth for other families).
        let error = context
            .validate_tool_call(&ToolCallDelta {
                tool_index: 0,
                name: Some("search_web".to_owned()),
                arguments: "[]".to_owned(),
            })
            .unwrap_err();
        assert!(error.message().contains("arguments must be a JSON object"));
    }

    /// Pins the role and key phrases of the controller prompt (required and
    /// optional tool choice) and of the correction prompt.
    #[test]
    fn builds_controller_and_retry_prompts() {
        let request = ChatCompletionRequest::parse(&json!({
            "model": "e2ee-test",
            "messages": [{"role":"user", "content":"hi"}],
            "tool_choice": "required",
            "tools": [{"type":"function", "function":{"name":"search_web", "parameters":{"type":"object"}}}]
        }))
        .expect("request should parse");
        let context = context_for_request(&request);

        let controller = context.controller_message();
        assert_eq!(controller.role, "user");
        assert!(
            controller
                .content
                .contains("You must call at least one tool")
        );
        assert!(
            controller
                .content
                .contains("Emit one marker block per tool call")
        );
        assert!(controller.content.contains("<tool_call>"));
        assert!(controller.content.contains("search_web"));

        let correction = context.correction_message("bad name", "<tool_call>{}</tool_call>");
        assert_eq!(correction.role, "system");
        assert!(correction.content.contains("Validation error:\nbad name"));
        assert!(
            correction
                .content
                .contains("Invalid output:\n<tool_call>{}</tool_call>")
        );
        assert!(
            correction
                .content
                .contains("You must now return only valid tool calls")
        );

        let optional_request = ChatCompletionRequest::parse(&json!({
            "model": "e2ee-test",
            "messages": [{"role":"user", "content":"hi"}],
            "tools": [{"type":"function", "function":{"name":"search_web", "parameters":{"type":"object"}}}]
        }))
        .expect("request should parse");
        let optional = context_for_request(&optional_request);
        assert!(
            optional
                .controller_message()
                .content
                .contains("If no tool is needed, answer normally")
        );
    }

    /// Invalid output longer than `CORRECTION_INVALID_OUTPUT_MAX_BYTES` is
    /// not echoed in full and is flagged as truncated; short output is not.
    #[test]
    fn correction_prompt_truncates_oversized_invalid_output() {
        let request = request_with_tool(json!({"type": "object"}));
        let context = context_for_request(&request);

        let oversized = "x".repeat(CORRECTION_INVALID_OUTPUT_MAX_BYTES + 1);
        let correction = context.correction_message("error", &oversized);
        assert!(correction.content.contains("[output truncated]"));
        assert!(!correction.content.contains(&oversized));

        let short = context.correction_message("error", "<tool_call>{}</tool_call>");
        assert!(!short.content.contains("[output truncated]"));
    }

    /// Naming one function in `tool_choice` leaves only that tool in the
    /// controller prompt.
    #[test]
    fn specific_tool_choice_filters_available_tools() {
        let request = ChatCompletionRequest::parse(&json!({
            "model": "e2ee-test",
            "messages": [{"role":"user", "content":"hi"}],
            "tool_choice": {"type":"function", "function":{"name":"search_web"}},
            "tools": [
                {"type":"function", "function":{"name":"search_web", "parameters":{"type":"object"}}},
                {"type":"function", "function":{"name":"other", "parameters":{"type":"object"}}}
            ]
        }))
        .expect("request should parse");
        let context = context_for_request(&request);

        assert!(context.controller_message().content.contains("search_web"));
        // NOTE(review): "other" is also a common English substring, so this
        // negative assertion would fail if the prompt prose ever contains it
        // (e.g. "otherwise") — consider a more distinctive tool name.
        assert!(!context.controller_message().content.contains("other"));
    }
}