Skip to main content

magic_coder_types/
tools.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3
4#[cfg(feature = "schemars")]
5use schemars::{JsonSchema, schema_for};
6
/// Build the OpenAI tool `parameters` object for the argument type `T`.
///
/// Starting from the JSON Schema that `schemars` generates for `T`, this:
///
/// - removes the root `$schema` and `title` entries,
/// - renames the `definitions` keyword to `$defs` and retargets `$ref`
///   pointers (`/definitions/` becomes `/$defs/`),
/// - renames the `oneOf` keyword to `anyOf`, because it's better supported by
///   the LLMs,
/// - finally runs `enforce_openai_strict_schema` on the result.
///
/// The rewrite walks the JSON tree instead of editing serialized text, so
/// property names and literal payloads (`description`, `default`, `enum`, ...)
/// are never altered even when they happen to contain `definitions` or `oneOf`.
///
/// # Panics
///
/// Panics if the generated schema cannot be converted to a `serde_json::Value`.
#[cfg(feature = "schemars")]
pub fn tool_parameters<T: JsonSchema>() -> Value {
    let mut schema = serde_json::to_value(schema_for!(T)).expect("can't parse value from schema");

    // Only the root carries the `$schema` / `title` metadata we want to drop.
    if let Some(root) = schema.as_object_mut() {
        root.remove("$schema");
        root.remove("title");
    }

    rewrite_schema_keywords(&mut schema);
    enforce_openai_strict_schema(&mut schema);
    schema
}

/// Rename schema keywords in place: `definitions` -> `$defs`, `oneOf` -> `anyOf`,
/// and `/definitions/` -> `/$defs/` inside `$ref` targets.
///
/// `schema` is treated as a JSON Schema node (or an array of nodes). Maps from
/// user-chosen names to schemas are descended into by value only, so a property
/// that is itself called `definitions` or `oneOf` keeps its name.
#[cfg(feature = "schemars")]
fn rewrite_schema_keywords(schema: &mut Value) {
    let map = match schema {
        Value::Object(map) => map,
        Value::Array(items) => {
            items.iter_mut().for_each(rewrite_schema_keywords);
            return;
        }
        _ => return,
    };

    // Rename before descending so the loop below sees the final keyword names.
    // `insert` overwrites an existing target key, i.e. the renamed entry wins.
    for (old, new) in [("definitions", "$defs"), ("oneOf", "anyOf")] {
        if let Some(moved) = map.remove(old) {
            map.insert(new.to_string(), moved);
        }
    }

    for (key, child) in map.iter_mut() {
        match key.as_str() {
            // References are the only strings that legitimately point at definitions.
            "$ref" => {
                if let Value::String(target) = child {
                    *target = target.replace("/definitions/", "/$defs/");
                }
            }
            // name -> schema maps: the keys are user-chosen names, not keywords.
            "properties" | "patternProperties" | "$defs" => {
                if let Value::Object(named) = child {
                    named.values_mut().for_each(rewrite_schema_keywords);
                }
            }
            // Literal payloads are data, not schemas; leave them untouched.
            "default" | "const" | "enum" | "examples" => {}
            // Everything else (`items`, `anyOf`, `allOf`, `additionalProperties`, ...)
            // holds a schema, a list of schemas, or a scalar the walk ignores.
            _ => rewrite_schema_keywords(child),
        }
    }
}
34
/// Recursively apply the object rules used for OpenAI strict tool schemas.
///
/// Every schema object that declares `"type": "object"` or has a `properties`
/// key is adjusted in place:
///
/// - `additionalProperties` is set to `false` unless the schema already sets it;
/// - when `properties` is an object, `required` is replaced by the sorted list
///   of all property names.
///
/// Traversal is keyword-aware. `properties`, `patternProperties`, `$defs` and
/// `definitions` map names to schemas, so only their values are visited; the
/// name map itself is never treated as a schema (a field called `properties`
/// therefore cannot trigger the rules on the map that contains it). Literal
/// payloads under `default`, `const`, `enum` and `examples` are left untouched.
#[cfg(feature = "schemars")]
fn enforce_openai_strict_schema(v: &mut Value) {
    let map = match v {
        Value::Object(map) => map,
        Value::Array(items) => {
            items.iter_mut().for_each(enforce_openai_strict_schema);
            return;
        }
        _ => return,
    };

    // Recurse first so nested schemas are fixed before this level is inspected.
    for (key, child) in map.iter_mut() {
        match key.as_str() {
            // name -> schema maps: visit each schema, never the map itself.
            "properties" | "patternProperties" | "$defs" | "definitions" => {
                if let Value::Object(named) = child {
                    named.values_mut().for_each(enforce_openai_strict_schema);
                }
            }
            // Literal payloads are data, not schemas.
            "default" | "const" | "enum" | "examples" => {}
            _ => enforce_openai_strict_schema(child),
        }
    }

    // If this looks like an object schema, enforce the strict rules.
    let is_object = map.get("type").and_then(Value::as_str) == Some("object");
    if !(is_object || map.contains_key("properties")) {
        return;
    }

    // An explicit `additionalProperties` chosen by the schema is kept as-is.
    map.entry("additionalProperties".to_string())
        .or_insert(Value::Bool(false));

    // Strict mode wants every declared property listed in `required`.
    if let Some(Value::Object(props)) = map.get("properties") {
        let mut names: Vec<String> = props.keys().cloned().collect();
        names.sort();
        map.insert(
            "required".to_string(),
            Value::Array(names.into_iter().map(Value::String).collect()),
        );
    }
}
72
// Arguments for the `read_file` tool; `openai_tools()` derives that tool's
// `parameters` schema from this struct via `tool_parameters::<ReadFileArgs>()`.
//
// NOTE: this header deliberately uses `//`. With the `schemars` feature on, field
// `///` text ends up as `description` strings in the generated schema (the tests
// in this file read such descriptions back), so treat `///` text here as runtime
// text. A struct-level `///` would presumably be exported too -- verify before adding.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
pub struct ReadFileArgs {
    /// Path to file.
    pub path: String,
    // `Option` fields may be absent from the incoming JSON and then hold `None`.
    /// Optional starting line (0-based).
    pub offset: Option<usize>,
    /// Optional maximum number of lines to read.
    pub limit: Option<usize>,
}
83
// Arguments for the `list_dir` tool; `openai_tools()` derives that tool's
// `parameters` schema from this struct via `tool_parameters::<ListDirArgs>()`.
//
// NOTE: `//` is used on purpose. With the `schemars` feature on, field `///` text is
// exported into the generated schema as `description`, so it is runtime text.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
pub struct ListDirArgs {
    /// Directory path to list.
    pub path: String,
}
90
// Which kind of path the `glob` tool should return (used by `GlobArgs::kind`).
// `rename_all = "snake_case"` makes the variants (de)serialize as lowercase
// strings; the tests in this file parse `"files"` into `GlobKind::Files`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
#[serde(rename_all = "snake_case")]
pub enum GlobKind {
    Files,
    Dirs,
    All,
}
99
// Arguments for the `glob` tool; `openai_tools()` derives that tool's
// `parameters` schema from this struct via `tool_parameters::<GlobArgs>()`.
//
// NOTE: `//` is used on purpose. With the `schemars` feature on, field `///` text is
// exported into the generated schema as `description`, so it is runtime text.
// The defaults mentioned in the field docs are not applied in this file;
// presumably the tool implementation applies them -- confirm there.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
pub struct GlobArgs {
    /// Glob pattern to match. Supports `*`, `**`, `?`, and character classes.
    pub pattern: String,
    /// Optional directory root to search under. Defaults to `"."`.
    pub path: Option<String>,
    /// Optional maximum number of returned paths. Defaults to `50`.
    pub limit: Option<usize>,
    /// Optional match kind. Defaults to `files`.
    pub kind: Option<GlobKind>,
    // `serde(default)`: a payload without `exclude` deserializes to an empty Vec
    // (covered by `glob_args_default_exclude_to_empty_list`).
    /// Optional exclude patterns. Defaults to an empty list.
    #[serde(default)]
    pub exclude: Vec<String>,
}
115
// Arguments for the `grep` tool; `openai_tools()` derives that tool's
// `parameters` schema from this struct via `tool_parameters::<GrepArgs>()`.
//
// NOTE: `//` is used on purpose. With the `schemars` feature on, field `///` text is
// exported into the generated schema as `description`, so it is runtime text.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
pub struct GrepArgs {
    /// Regex pattern to search for.
    pub pattern: String,
    /// Optional path (file or directory) to search in.
    pub path: Option<String>,
    /// Optional glob filter, e.g. `"*.rs"`.
    pub glob: Option<String>,
    // The `grep` tool description in `openai_tools()` says results may be
    // truncated to this value.
    /// Optional limit for returned matches.
    pub head_limit: Option<usize>,
}
128
// Arguments for the `run_shell` tool; `openai_tools()` derives that tool's
// `parameters` schema from this struct via `tool_parameters::<RunShellArgs>()`.
//
// NOTE: `//` is used on purpose. With the `schemars` feature on, field `///` text is
// exported into the generated schema as `description`. The test
// `run_shell_tool_schema_encourages_small_output_limits` asserts on exact phrases
// from the `timeout_seconds` and `max_output_bytes` docs below, so editing that
// wording changes both the schema sent to the model and the test outcome.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
pub struct RunShellArgs {
    /// Shell command line to run (executed via `bash -lc`), supports pipes/redirection.
    pub command: String,
    /// Optional working directory (relative to project root).
    pub cwd: Option<String>,
    // The "must be omitted when bg=true" rule is only documented here; nothing in
    // this file enforces it -- presumably the executor validates it.
    /// Optional timeout in seconds for foreground execution.
    /// Omit to use the default 30 second timeout.
    /// Must be omitted when `bg=true`.
    /// For longer-running work like model training, set a larger value up front on the safe side to avoid retries.
    pub timeout_seconds: Option<u64>,
    /// Optional maximum captured bytes per stream (stdout/stderr) for foreground execution.
    /// Must be omitted when `bg=true`.
    ///
    /// Truncated output keeps roughly the first 30% and last 70%, so very large
    /// values are usually unnecessary; prefer a few KB or low tens of KB and only
    /// increase if needed.
    pub max_output_bytes: Option<u64>,
    // `serde(default)`: a payload without `bg` deserializes to `false`
    // (covered by `background_shell_tool_args_default_to_tail_reads`).
    /// When true, spawn the shell in the background and return immediately with a shell id.
    #[serde(default)]
    pub bg: bool,
}
152
// Arguments for the `read_shell_output` tool; `openai_tools()` derives that tool's
// `parameters` schema from this struct via `tool_parameters::<ReadShellOutputArgs>()`.
//
// NOTE: `//` is used on purpose. With the `schemars` feature on, field `///` text is
// exported into the generated schema as `description`, so it is runtime text.
// The numeric defaults and the `1000` cap named in the field docs are not applied
// in this file; presumably the tool implementation applies them -- confirm there.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
pub struct ReadShellOutputArgs {
    /// Background shell id returned by `run_shell` with `bg=true`.
    pub shell_id: String,
    // `serde(default)`: a payload without `from_start` deserializes to `false`
    // (covered by `background_shell_tool_args_default_to_tail_reads`).
    /// When true, read from the start of the log. Defaults to `false` meaning read from the end.
    #[serde(default)]
    pub from_start: bool,
    /// Optional 0-based line offset from the selected side. Defaults to `0`.
    pub offset: Option<usize>,
    /// Optional maximum number of lines to read. Defaults to `200`, max `1000`.
    pub limit: Option<usize>,
}
166
// Arguments for the `stop_shell` tool; `openai_tools()` derives that tool's
// `parameters` schema from this struct via `tool_parameters::<StopShellArgs>()`.
//
// NOTE: `//` is used on purpose. With the `schemars` feature on, field `///` text is
// exported into the generated schema as `description`, so it is runtime text.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
pub struct StopShellArgs {
    /// Background shell id returned by `run_shell` with `bg=true`.
    pub shell_id: String,
}
173
// Arguments for the `sleep` tool; `openai_tools()` derives that tool's
// `parameters` schema from this struct via `tool_parameters::<SleepArgs>()`.
//
// NOTE: `//` is used on purpose. With the `schemars` feature on, field `///` text is
// exported into the generated schema as `description`, so it is runtime text.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
pub struct SleepArgs {
    /// Sleep duration in seconds. Clients may clamp this to a supported range.
    pub seconds: u64,
    // No `serde(default)` here, so `shell_ids` must be present in the payload;
    // an empty array is the documented way to request a plain timer.
    /// Background shell ids to watch. Use an empty array for a plain timer.
    /// If any watched shell exits early, the sleep may end early.
    pub shell_ids: Vec<String>,
}
183
// Arguments for the `apply_diff` tool; `openai_tools()` derives that tool's
// `parameters` schema from this struct via `tool_parameters::<ApplyDiffArgs>()`.
//
// NOTE: `//` is used on purpose. With the `schemars` feature on, field `///` text is
// exported into the generated schema as `description`, so it is runtime text.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
pub struct ApplyDiffArgs {
    // Parsing of the two accepted input shapes happens outside this file.
    /// A git-style unified diff to apply to the working tree.
    ///
    /// You may pass either:
    /// - the raw diff text starting with `diff --git ...`, OR
    /// - a fenced diff block like ```diff ... ``` (indentation is OK).
    pub diff: String,
}
194
// Arguments for the `delete_files` tool; `openai_tools()` derives that tool's
// `parameters` schema from this struct via `tool_parameters::<DeleteFilesArgs>()`.
//
// NOTE: `//` is used on purpose. With the `schemars` feature on, field `///` text is
// exported into the generated schema as `description`, so it is runtime text.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
pub struct DeleteFilesArgs {
    // The path restrictions below are only documented here; nothing in this file
    // validates them -- presumably the tool implementation does.
    /// Paths to delete (relative to project root; no absolute paths; no `..`).
    pub paths: Vec<String>,
}
201
/// Function-tool definitions to send to the OpenAI Responses API.
///
/// Returns one JSON object per tool, in a fixed order, each carrying `type`,
/// `name`, `description`, `parameters` (generated by `tool_parameters` from the
/// matching `*Args` struct) and, for every tool except `run_shell`, `strict: true`.
#[cfg(feature = "schemars")]
pub fn openai_tools() -> Vec<Value> {
    // Assemble a single function-tool object; `strict` is emitted only when requested.
    fn function_tool(name: &str, description: &str, strict: bool, parameters: Value) -> Value {
        let mut tool = serde_json::json!({
            "type": "function",
            "name": name,
            "description": description,
        });
        if strict {
            tool["strict"] = Value::Bool(true);
        }
        tool["parameters"] = parameters;
        tool
    }

    vec![
        function_tool(
            "read_file",
            "Read a local file (by path), optionally with offset/limit. Returns a JSON string with keys: path, offset, limit, total_lines, content, numbered_content, fingerprint{hash64,len_bytes}, truncated.",
            true,
            tool_parameters::<ReadFileArgs>(),
        ),
        function_tool(
            "list_dir",
            "List a local directory (by path). Returns a JSON string: { path, entries: [{ name, is_dir, is_file }, ...] }.",
            true,
            tool_parameters::<ListDirArgs>(),
        ),
        function_tool(
            "glob",
            "Find local file or directory paths using a glob pattern under a search root. Use this for path discovery when you need matching paths, not file contents. Returns plain text with Returned, Total, and one relative path per line.",
            true,
            tool_parameters::<GlobArgs>(),
        ),
        function_tool(
            "grep",
            "Search for a regex pattern in files. Returns a JSON string including matches (file, line_number, line). May be truncated to head_limit.",
            true,
            tool_parameters::<GrepArgs>(),
        ),
        // NOTE(review): `run_shell` is the only tool defined without `"strict": true`.
        // Its description tells the model to *omit* some fields, which may be why;
        // confirm this is intentional.
        function_tool(
            "run_shell",
            "Run a shell command via `bash -lc` (supports pipes/redirection). Requires user confirmation unless the client auto-approves it. Use `max_output_bytes` intentionally for foreground runs: prefer the smallest limit that answers the question, and increase only when needed. Oversize output is cut from the middle, preserving roughly the first 30% and last 70%, so large requests are rarely necessary just to inspect the tail. Set `bg=true` to start a background shell that returns immediately with a shell id. When `bg=true`, omit `timeout_seconds` and omit `max_output_bytes`.",
            false,
            tool_parameters::<RunShellArgs>(),
        ),
        function_tool(
            "read_shell_output",
            "Read captured output from a background shell started with `run_shell(bg=true)`. Output is line-oriented. By default it reads from the end; set `from_start=true` to read from the beginning.",
            true,
            tool_parameters::<ReadShellOutputArgs>(),
        ),
        function_tool(
            "stop_shell",
            "Stop a background shell started with `run_shell(bg=true)` and discard its retained state and logs.",
            true,
            tool_parameters::<StopShellArgs>(),
        ),
        function_tool(
            "sleep",
            "Wait for a short period. Provide `shell_ids` to return early when any watched background shell exits. Use `shell_ids: []` for a plain timer.",
            true,
            tool_parameters::<SleepArgs>(),
        ),
        function_tool(
            "apply_diff",
            "Apply a git-style unified diff to the local working tree (create/update files). Returns a JSON string describing what changed or an error.",
            true,
            tool_parameters::<ApplyDiffArgs>(),
        ),
        function_tool(
            "delete_files",
            "Delete one or more files by path (relative to project root). Returns a JSON string listing deleted and missing paths.",
            true,
            tool_parameters::<DeleteFilesArgs>(),
        ),
    ]
}
277
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Find the tool definition whose `name` field equals `name`.
    #[cfg(feature = "schemars")]
    fn tool_named(name: &str) -> Option<Value> {
        openai_tools()
            .into_iter()
            .find(|tool| tool.get("name").and_then(Value::as_str) == Some(name))
    }

    // Top-level `description` string of a tool definition.
    #[cfg(feature = "schemars")]
    fn tool_description(tool: &Value) -> Option<&str> {
        tool.get("description").and_then(Value::as_str)
    }

    // The `parameters.properties` object of a tool definition.
    #[cfg(feature = "schemars")]
    fn parameter_properties(tool: &Value) -> Option<&serde_json::Map<String, Value>> {
        tool.get("parameters")?.get("properties")?.as_object()
    }

    // The `description` string of one property inside `parameters.properties`.
    #[cfg(feature = "schemars")]
    fn parameter_description<'a>(tool: &'a Value, field: &str) -> Option<&'a str> {
        tool.get("parameters")?
            .get("properties")?
            .get(field)?
            .get("description")?
            .as_str()
    }

    /// A payload without `exclude` still parses, and `exclude` comes back empty.
    #[test]
    fn glob_args_default_exclude_to_empty_list() {
        let payload = json!({
            "pattern": "**/*.rs",
            "path": "src",
            "limit": 50,
            "kind": "files",
        });
        let GlobArgs {
            pattern,
            path,
            limit,
            kind,
            exclude,
        } = serde_json::from_value(payload).expect("glob args");

        assert_eq!(pattern, "**/*.rs");
        assert_eq!(path.as_deref(), Some("src"));
        assert_eq!(limit, Some(50));
        assert_eq!(kind, Some(GlobKind::Files));
        assert!(exclude.is_empty());
    }

    /// The `run_shell` tool and its parameter docs keep the guidance about
    /// output limits, timeouts and background mode.
    #[cfg(feature = "schemars")]
    #[test]
    fn run_shell_tool_schema_encourages_small_output_limits() {
        let run_shell = tool_named("run_shell").expect("run_shell tool");

        let description = tool_description(&run_shell).expect("run_shell description");
        for needle in [
            "max_output_bytes",
            "30%",
            "70%",
            "smallest limit",
            "bg=true",
            "omit `timeout_seconds`",
            "omit `max_output_bytes`",
        ] {
            assert!(description.contains(needle), "missing {needle:?}");
        }

        let timeout_description = parameter_description(&run_shell, "timeout_seconds")
            .expect("timeout_seconds description");
        for needle in [
            "30 second timeout",
            "model training",
            "safe side",
            "Must be omitted when `bg=true`",
        ] {
            assert!(timeout_description.contains(needle), "missing {needle:?}");
        }

        let max_output_description = parameter_description(&run_shell, "max_output_bytes")
            .expect("max_output_bytes description");
        for needle in ["30%", "70%", "few KB", "Must be omitted when `bg=true`"] {
            assert!(max_output_description.contains(needle), "missing {needle:?}");
        }

        let properties = parameter_properties(&run_shell).expect("run_shell parameters");
        assert!(properties.contains_key("bg"));
    }

    /// The `glob` tool is advertised with its description and all five parameters.
    #[cfg(feature = "schemars")]
    #[test]
    fn openai_tools_include_glob_tool() {
        let glob_tool = tool_named("glob").expect("glob tool");

        let description = tool_description(&glob_tool).expect("glob description");
        for needle in ["path discovery", "Returned", "Total"] {
            assert!(description.contains(needle), "missing {needle:?}");
        }

        let properties = parameter_properties(&glob_tool).expect("glob parameters");
        for field in ["pattern", "path", "limit", "kind", "exclude"] {
            assert!(properties.contains_key(field), "missing {field:?}");
        }
    }

    /// Minimal payloads for the background-shell tools parse with the expected defaults.
    #[test]
    fn background_shell_tool_args_default_to_tail_reads() {
        let ReadShellOutputArgs {
            shell_id,
            from_start,
            offset,
            limit,
        } = serde_json::from_value(json!({
            "shell_id": "bg_123"
        }))
        .expect("read_shell_output args");
        assert_eq!(shell_id, "bg_123");
        assert!(!from_start);
        assert_eq!(offset, None);
        assert_eq!(limit, None);

        let RunShellArgs { command, bg, .. } = serde_json::from_value(json!({
            "command": "echo hi"
        }))
        .expect("run_shell args");
        assert_eq!(command, "echo hi");
        assert!(!bg);

        let SleepArgs { seconds, shell_ids } = serde_json::from_value(json!({
            "seconds": 30,
            "shell_ids": ["bg_123", "bg_456"]
        }))
        .expect("sleep args");
        assert_eq!(seconds, 30);
        assert_eq!(shell_ids, vec!["bg_123", "bg_456"]);
    }

    /// The three background-shell helper tools are present in the tool list.
    #[cfg(feature = "schemars")]
    #[test]
    fn openai_tools_include_background_shell_tools() {
        let tools = openai_tools();
        let names: Vec<&str> = tools
            .iter()
            .filter_map(|tool| tool.get("name").and_then(Value::as_str))
            .collect();

        for expected in ["read_shell_output", "stop_shell", "sleep"] {
            assert!(names.contains(&expected), "missing {expected:?}");
        }
    }
}