Skip to main content

nika_engine/tools/
submit_tool.rs

1//! Dynamic Submit Tool - Structured Output Enforcement
2//!
//! Based on rig's Extractor pattern: injects the JSON Schema as a tool
//! definition so the LLM understands the expected output format.
5//!
6//! # Usage
7//!
8//! ```rust,ignore
9//! use nika::tools::DynamicSubmitTool;
10//! use serde_json::json;
11//!
12//! let schema = json!({
13//!     "type": "object",
14//!     "properties": {
15//!         "keywords": { "type": "array", "items": { "type": "string" } }
16//!     },
17//!     "required": ["keywords"]
18//! });
19//!
20//! let tool = DynamicSubmitTool::new(schema);
21//! let definition = tool.to_claude_tool();
22//! ```
23
24use serde::{Deserialize, Serialize};
25use serde_json::Value;
26
27use crate::error::NikaError;
28use crate::runtime::output::validate_inline_schema;
29
30/// Tool definition for provider-agnostic tool calling
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ToolDefinition {
33    pub name: String,
34    pub description: String,
35    pub parameters: Value,
36}
37
38/// Dynamic submit tool that enforces structured output via tool calling
39///
40/// When added to a chat completion request, the LLM is instructed to "submit"
41/// its response using this tool, which includes the expected JSON Schema.
42#[derive(Debug, Clone)]
43pub struct DynamicSubmitTool {
44    schema: Value,
45    description: Option<String>,
46}
47
48impl DynamicSubmitTool {
49    /// Create a new submit tool with the given JSON Schema
50    pub fn new(schema: Value) -> Self {
51        Self {
52            schema,
53            description: None,
54        }
55    }
56
57    /// Set a custom description for the tool
58    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
59        self.description = Some(desc.into());
60        self
61    }
62
63    /// Get the tool definition for the LLM
64    pub fn definition(&self) -> ToolDefinition {
65        ToolDefinition {
66            name: "submit".to_string(),
67            description: self.description.clone().unwrap_or_else(|| {
68                "Submit your response in the required structured format. \
69                 Use this tool to provide your final answer. The response \
70                 MUST match the schema exactly."
71                    .to_string()
72            }),
73            parameters: self.schema.clone(),
74        }
75    }
76
77    /// Validate and return the submitted data
78    pub fn validate(&self, input: &Value) -> Result<(), NikaError> {
79        validate_inline_schema(input, &self.schema)
80    }
81
82    /// Get the schema for error feedback
83    pub fn schema(&self) -> &Value {
84        &self.schema
85    }
86
87    /// Convert to Claude tool format
88    pub fn to_claude_tool(&self) -> Value {
89        serde_json::json!({
90            "name": "submit",
91            "description": self.definition().description,
92            "input_schema": self.schema
93        })
94    }
95
96    /// Convert to OpenAI function format
97    pub fn to_openai_tool(&self) -> Value {
98        serde_json::json!({
99            "type": "function",
100            "function": {
101                "name": "submit",
102                "description": self.definition().description,
103                "parameters": self.schema
104            }
105        })
106    }
107
108    /// Convert to generic tool format (for rig-core)
109    pub fn to_rig_tool(&self) -> Value {
110        serde_json::json!({
111            "name": "submit",
112            "description": self.definition().description,
113            "parameters": self.schema
114        })
115    }
116}
117
118// ═══════════════════════════════════════════════════════════════════════════
119// TESTS
120// ═══════════════════════════════════════════════════════════════════════════
121
#[cfg(test)]
mod tests {
    //! Unit tests for `DynamicSubmitTool`: construction, definition contents,
    //! schema validation, and each provider-specific output format.

    use super::*;
    use serde_json::json;

    /// The schema handed to `new` is stored unchanged.
    #[test]
    fn test_submit_tool_new() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });

        let tool = DynamicSubmitTool::new(schema.clone());
        assert_eq!(tool.schema(), &schema);
    }

    /// The default definition uses the fixed name, the schema as parameters,
    /// and the built-in description.
    #[test]
    fn test_submit_tool_definition() {
        let schema = json!({
            "type": "object",
            "properties": {
                "keywords": {
                    "type": "array",
                    "items": { "type": "string" }
                }
            },
            "required": ["keywords"]
        });

        let tool = DynamicSubmitTool::new(schema.clone());
        let def = tool.definition();

        assert_eq!(def.name, "submit");
        assert_eq!(def.parameters, schema);
        assert!(def.description.contains("structured format"));
    }

    /// A custom description replaces the default one in the definition.
    #[test]
    fn test_submit_tool_with_custom_description() {
        let schema = json!({"type": "object"});
        let tool = DynamicSubmitTool::new(schema)
            .with_description("Extract SEO keywords from the content");

        let def = tool.definition();
        assert_eq!(def.description, "Extract SEO keywords from the content");
    }

    /// Input carrying the required field passes validation.
    #[test]
    fn test_submit_tool_validate_success() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            },
            "required": ["name"]
        });

        let tool = DynamicSubmitTool::new(schema);
        let input = json!({"name": "test"});
        assert!(tool.validate(&input).is_ok());
    }

    /// Input lacking the required field is rejected.
    #[test]
    fn test_submit_tool_validate_failure() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            },
            "required": ["name"]
        });

        let tool = DynamicSubmitTool::new(schema);
        let input = json!({"wrong": "field"});
        let result = tool.validate(&input);
        assert!(result.is_err());
    }

    /// Claude format exposes `name`, `description` and `input_schema`.
    #[test]
    fn test_to_claude_tool_format() {
        let schema = json!({
            "type": "object",
            "properties": {
                "result": { "type": "string" }
            }
        });

        let tool = DynamicSubmitTool::new(schema.clone());
        let claude_tool = tool.to_claude_tool();

        assert_eq!(claude_tool["name"], "submit");
        assert_eq!(claude_tool["input_schema"], schema);
        assert!(claude_tool["description"].is_string());
    }

    /// OpenAI format nests the tool under `function` with `type: "function"`.
    #[test]
    fn test_to_openai_tool_format() {
        let schema = json!({
            "type": "object",
            "properties": {
                "result": { "type": "string" }
            }
        });

        let tool = DynamicSubmitTool::new(schema.clone());
        let openai_tool = tool.to_openai_tool();

        assert_eq!(openai_tool["type"], "function");
        assert_eq!(openai_tool["function"]["name"], "submit");
        assert_eq!(openai_tool["function"]["parameters"], schema);
        assert!(openai_tool["function"]["description"].is_string());
    }

    /// Generic (rig) format exposes `name`, `description` and `parameters`.
    #[test]
    fn test_to_rig_tool_format() {
        let schema = json!({
            "type": "object",
            "properties": {
                "result": { "type": "string" }
            }
        });

        let tool = DynamicSubmitTool::new(schema.clone());
        let rig_tool = tool.to_rig_tool();

        assert_eq!(rig_tool["name"], "submit");
        assert_eq!(rig_tool["parameters"], schema);
        assert!(rig_tool["description"].is_string());
    }

    /// A custom description must reach every provider-specific format, not
    /// only `definition()`.
    #[test]
    fn test_custom_description_propagates_to_provider_formats() {
        let tool = DynamicSubmitTool::new(json!({"type": "object"}))
            .with_description("Extract SEO keywords from the content");

        let expected = "Extract SEO keywords from the content";
        assert_eq!(tool.to_claude_tool()["description"], expected);
        assert_eq!(tool.to_openai_tool()["function"]["description"], expected);
        assert_eq!(tool.to_rig_tool()["description"], expected);
    }

    /// Nested schema with required fields, types and numeric bounds.
    #[test]
    fn test_complex_schema_validation() {
        let schema = json!({
            "type": "object",
            "properties": {
                "keywords": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "value": { "type": "string" },
                            "slug_form": { "type": "string", "pattern": "^[a-z0-9-]+$" },
                            "volume": { "type": "integer", "minimum": 0 },
                            "difficulty": { "type": "integer", "minimum": 0, "maximum": 100 }
                        },
                        "required": ["value", "slug_form", "volume", "difficulty"]
                    }
                }
            },
            "required": ["keywords"]
        });

        let tool = DynamicSubmitTool::new(schema);

        // Valid input
        let valid = json!({
            "keywords": [{
                "value": "qr code generator",
                "slug_form": "qr-code-generator",
                "volume": 10000,
                "difficulty": 45
            }]
        });
        assert!(tool.validate(&valid).is_ok());

        // Invalid: missing required field
        let missing_field = json!({
            "keywords": [{
                "value": "qr code",
                "slug_form": "qr-code",
                "volume": 5000
                // missing difficulty
            }]
        });
        assert!(tool.validate(&missing_field).is_err());

        // Invalid: wrong type
        let wrong_type = json!({
            "keywords": [{
                "value": "qr code",
                "slug_form": "qr-code",
                "volume": "not a number",
                "difficulty": 50
            }]
        });
        assert!(tool.validate(&wrong_type).is_err());

        // Invalid: difficulty out of range
        let out_of_range = json!({
            "keywords": [{
                "value": "qr code",
                "slug_form": "qr-code",
                "volume": 5000,
                "difficulty": 150  // max is 100
            }]
        });
        assert!(tool.validate(&out_of_range).is_err());
    }
}