Skip to main content

magic_coder_types/
tools.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3
4#[cfg(feature = "schemars")]
5use schemars::{JsonSchema, schema_for};
6
/// Convert a Rust struct schema into an OpenAI tool `parameters` object.
///
/// The rewrite operates on the schema *tree*, not on its serialized text.
/// Property names, definition names, descriptions and enum values that happen to
/// contain `definitions`, `oneOf` or `/definitions/` are therefore left intact.
///
/// - remove the root `$schema` and `title`
/// - rename `definitions` to `$defs` and rewrite local `$ref`s
///   (`#/definitions/X` -> `#/$defs/X`) to match
/// - rename `oneOf` to `anyOf`, because it's better supported by the LLMs
/// - apply the OpenAI strict-mode rules of `enforce_openai_strict_schema`
///
/// # Panics
///
/// Panics if the generated schema cannot be converted into a `serde_json::Value`,
/// which would indicate a bug in schema generation rather than bad input.
#[cfg(feature = "schemars")]
pub fn tool_parameters<T: JsonSchema>() -> Value {
    let mut v = serde_json::to_value(schema_for!(T)).expect("can't parse value from schema");

    // The dialect marker and the Rust type name are noise in a tool definition.
    if let Some(root) = v.as_object_mut() {
        root.remove("$schema");
        root.remove("title");
    }

    rewrite_schema_keywords(&mut v);
    enforce_openai_strict_schema(&mut v);
    v
}

/// Recursively renames `definitions` -> `$defs` and `oneOf` -> `anyOf` and rewrites
/// local `$ref`s accordingly.
///
/// Only positions that hold (sub)schemas are visited. Keys of name->schema maps
/// (`properties`, `$defs`, ...) are user data and are never renamed. Instance-data
/// keywords (`enum`, `default`, ...) are not descended into.
#[cfg(feature = "schemars")]
fn rewrite_schema_keywords(schema: &mut Value) {
    let Value::Object(map) = schema else {
        // Boolean schemas and scalar values have nothing to rewrite.
        return;
    };

    if let Some(defs) = map.remove("definitions") {
        map.insert("$defs".to_string(), defs);
    }
    // Never clobber an existing `anyOf`; schemars does not emit both, but if it
    // ever did, dropping either list would silently change the schema.
    if !map.contains_key("anyOf") {
        if let Some(variants) = map.remove("oneOf") {
            map.insert("anyOf".to_string(), variants);
        }
    }
    if let Some(Value::String(reference)) = map.get_mut("$ref") {
        if let Some(name) = reference.strip_prefix("#/definitions/") {
            *reference = format!("#/$defs/{name}");
        }
    }

    for (keyword, value) in map.iter_mut() {
        match keyword.as_str() {
            // name -> subschema maps: recurse into each entry, keep the names as-is.
            "properties" | "patternProperties" | "$defs" | "definitions" | "dependencies"
            | "dependentSchemas" => {
                if let Value::Object(named) = value {
                    named.values_mut().for_each(rewrite_schema_keywords);
                }
            }
            // Instance data, not schemas.
            "enum" | "const" | "default" | "examples" | "required" => {}
            // Everything else is a subschema, a list of subschemas, or a scalar.
            _ => match value {
                Value::Array(items) => items.iter_mut().for_each(rewrite_schema_keywords),
                other => rewrite_schema_keywords(other),
            },
        }
    }
}
34
/// Apply OpenAI structured-output "strict mode" rules to every object schema in `v`.
///
/// An object schema is one with `"type": "object"`, a type list containing
/// `"object"`, or a `properties` keyword. For each such schema:
///
/// - `additionalProperties` defaults to `false` (an explicit value is kept);
/// - `required` is replaced by the sorted list of all property names. Strict mode
///   requires every property to be listed, so optional fields have to be expressed
///   as nullable types.
///
/// The walk is keyword-aware:
///
/// - Maps of user-chosen names (`properties`, `$defs`, ...) are descended into per
///   entry but never treated as schemas themselves. Without this, a field literally
///   named `properties` would get `additionalProperties`/`required` injected into
///   its parent's property map.
/// - Instance data (`enum`, `const`, `default`, `examples`, `required`) is left alone.
#[cfg(feature = "schemars")]
fn enforce_openai_strict_schema(v: &mut Value) {
    match v {
        Value::Object(map) => {
            // Recurse first so we fix nested schemas too.
            for (keyword, child) in map.iter_mut() {
                match keyword.as_str() {
                    "properties" | "patternProperties" | "$defs" | "definitions"
                    | "dependencies" | "dependentSchemas" => {
                        if let Value::Object(named) = child {
                            named.values_mut().for_each(enforce_openai_strict_schema);
                        }
                    }
                    "enum" | "const" | "default" | "examples" | "required" => {}
                    _ => enforce_openai_strict_schema(child),
                }
            }

            // If this looks like an object schema, enforce strict rules.
            let is_object = match map.get("type") {
                Some(Value::String(t)) => t == "object",
                // Nullable objects are emitted as e.g. `["object", "null"]`.
                Some(Value::Array(types)) => types.iter().any(|t| t.as_str() == Some("object")),
                _ => false,
            };
            let has_props = map.contains_key("properties");
            if is_object || has_props {
                map.entry("additionalProperties".to_string())
                    .or_insert(Value::Bool(false));

                if let Some(Value::Object(props)) = map.get("properties") {
                    let mut keys: Vec<String> = props.keys().cloned().collect();
                    keys.sort();
                    map.insert(
                        "required".to_string(),
                        Value::Array(keys.into_iter().map(Value::String).collect()),
                    );
                }
            }
        }
        Value::Array(arr) => {
            for child in arr.iter_mut() {
                enforce_openai_strict_schema(child);
            }
        }
        _ => {}
    }
}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
74#[cfg_attr(feature = "schemars", derive(JsonSchema))]
75pub struct ReadFileArgs {
76    /// Path to file.
77    pub path: String,
78    /// Optional starting line (0-based).
79    pub offset: Option<usize>,
80    /// Optional maximum number of lines to read.
81    pub limit: Option<usize>,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
85#[cfg_attr(feature = "schemars", derive(JsonSchema))]
86pub struct ListDirArgs {
87    /// Directory path to list.
88    pub path: String,
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
92#[cfg_attr(feature = "schemars", derive(JsonSchema))]
93#[serde(rename_all = "snake_case")]
94pub enum GlobKind {
95    Files,
96    Dirs,
97    All,
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
101#[cfg_attr(feature = "schemars", derive(JsonSchema))]
102pub struct GlobArgs {
103    /// Glob pattern to match. Supports `*`, `**`, `?`, and character classes.
104    pub pattern: String,
105    /// Optional directory root to search under. Defaults to `"."`.
106    pub path: Option<String>,
107    /// Optional maximum number of returned paths. Defaults to `50`.
108    pub limit: Option<usize>,
109    /// Optional match kind. Defaults to `files`.
110    pub kind: Option<GlobKind>,
111    /// Optional exclude patterns. Defaults to an empty list.
112    #[serde(default)]
113    pub exclude: Vec<String>,
114}
115
116#[derive(Debug, Clone, Serialize, Deserialize)]
117#[cfg_attr(feature = "schemars", derive(JsonSchema))]
118pub struct GrepArgs {
119    /// Regex pattern to search for.
120    pub pattern: String,
121    /// Optional path (file or directory) to search in.
122    pub path: Option<String>,
123    /// Optional glob filter, e.g. `"*.rs"`.
124    pub glob: Option<String>,
125    /// Optional limit for returned matches.
126    pub head_limit: Option<usize>,
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
130#[cfg_attr(feature = "schemars", derive(JsonSchema))]
131pub struct RunShellArgs {
132    /// Shell command line to run (executed via `bash -lc`), supports pipes/redirection.
133    pub command: String,
134    /// Optional working directory (relative to project root).
135    pub cwd: Option<String>,
136    /// Optional timeout in seconds. Omit to use the default 30 second timeout.
137    /// For longer-running work like model training, set a larger value up front on the safe side to avoid retries.
138    pub timeout_seconds: Option<u64>,
139    /// Optional maximum captured bytes per stream (stdout/stderr).
140    ///
141    /// Truncated output keeps roughly the first 30% and last 70%, so very large
142    /// values are usually unnecessary; prefer a few KB or low tens of KB and only
143    /// increase if needed.
144    pub max_output_bytes: Option<u64>,
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize)]
148#[cfg_attr(feature = "schemars", derive(JsonSchema))]
149pub struct ApplyDiffArgs {
150    /// A git-style unified diff to apply to the working tree.
151    ///
152    /// You may pass either:
153    /// - the raw diff text starting with `diff --git ...`, OR
154    /// - a fenced diff block like ```diff ... ``` (indentation is OK).
155    pub diff: String,
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize)]
159#[cfg_attr(feature = "schemars", derive(JsonSchema))]
160pub struct DeleteFilesArgs {
161    /// Paths to delete (relative to project root; no absolute paths; no `..`).
162    pub paths: Vec<String>,
163}
164
/// Tools (function definitions) to send to the OpenAI Responses API.
///
/// Every entry is a strict function tool. Its `parameters` schema is generated
/// from the matching `*Args` type by `tool_parameters`, and the tools are returned
/// in a fixed order.
#[cfg(feature = "schemars")]
pub fn openai_tools() -> Vec<Value> {
    // (tool name, model-facing description, JSON schema of the arguments)
    let specs: Vec<(&str, &str, Value)> = vec![
        (
            "read_file",
            "Read a local file (by path), optionally with offset/limit. Returns a JSON string with keys: path, offset, limit, total_lines, content, numbered_content, fingerprint{hash64,len_bytes}, truncated.",
            tool_parameters::<ReadFileArgs>(),
        ),
        (
            "list_dir",
            "List a local directory (by path). Returns a JSON string: { path, entries: [{ name, is_dir, is_file }, ...] }.",
            tool_parameters::<ListDirArgs>(),
        ),
        (
            "glob",
            "Find local file or directory paths using a glob pattern under a search root. Use this for path discovery when you need matching paths, not file contents. Returns plain text with Returned, Total, and one relative path per line.",
            tool_parameters::<GlobArgs>(),
        ),
        (
            "grep",
            "Search for a regex pattern in files. Returns a JSON string including matches (file, line_number, line). May be truncated to head_limit.",
            tool_parameters::<GrepArgs>(),
        ),
        (
            "run_shell",
            "Run a shell command via `bash -lc` (supports pipes/redirection). Requires user confirmation. Use `max_output_bytes` intentionally: prefer the smallest limit that answers the question, and increase only when needed. Oversize output is cut from the middle, preserving roughly the first 30% and last 70%, so large requests are rarely necessary just to inspect the tail.",
            tool_parameters::<RunShellArgs>(),
        ),
        (
            "apply_diff",
            "Apply a git-style unified diff to the local working tree (create/update files). Returns a JSON string describing what changed or an error.",
            tool_parameters::<ApplyDiffArgs>(),
        ),
        (
            "delete_files",
            "Delete one or more files by path (relative to project root). Returns a JSON string listing deleted and missing paths.",
            tool_parameters::<DeleteFilesArgs>(),
        ),
    ];

    // Wrap each spec in the function-tool envelope, keeping the same key order.
    specs
        .into_iter()
        .map(|(name, description, parameters)| {
            serde_json::json!({
                "type": "function",
                "name": name,
                "description": description,
                "strict": true,
                "parameters": parameters,
            })
        })
        .collect()
}
220
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Looks up a tool definition by its `name` in `openai_tools()`.
    #[cfg(feature = "schemars")]
    fn find_tool(name: &str) -> Value {
        openai_tools()
            .into_iter()
            .find(|tool| tool.get("name").and_then(Value::as_str) == Some(name))
            .unwrap_or_else(|| panic!("{name} tool"))
    }

    /// Returns the `description` of parameter `property` in `tool`'s schema.
    #[cfg(feature = "schemars")]
    fn param_description<'a>(tool: &'a Value, property: &str) -> &'a str {
        tool.pointer(&format!("/parameters/properties/{property}/description"))
            .and_then(Value::as_str)
            .unwrap_or_else(|| panic!("{property} description"))
    }

    /// Collects every `$ref` string found anywhere in `value`.
    #[cfg(feature = "schemars")]
    fn collect_refs<'a>(value: &'a Value, out: &mut Vec<&'a str>) {
        match value {
            Value::Object(map) => {
                for (key, child) in map {
                    match (key.as_str(), child) {
                        ("$ref", Value::String(reference)) => out.push(reference.as_str()),
                        _ => collect_refs(child, out),
                    }
                }
            }
            Value::Array(items) => items.iter().for_each(|item| collect_refs(item, out)),
            _ => {}
        }
    }

    #[test]
    fn glob_args_default_exclude_to_empty_list() {
        let args: GlobArgs = serde_json::from_value(json!({
            "pattern": "**/*.rs",
            "path": "src",
            "limit": 50,
            "kind": "files",
        }))
        .expect("glob args");

        assert_eq!(args.pattern, "**/*.rs");
        assert_eq!(args.path.as_deref(), Some("src"));
        assert_eq!(args.limit, Some(50));
        assert_eq!(args.kind, Some(GlobKind::Files));
        assert!(args.exclude.is_empty());
    }

    #[cfg(feature = "schemars")]
    #[test]
    fn run_shell_tool_schema_encourages_small_output_limits() {
        let run_shell = find_tool("run_shell");

        let description = run_shell
            .get("description")
            .and_then(Value::as_str)
            .expect("run_shell description");
        for needle in ["max_output_bytes", "30%", "70%", "smallest limit"] {
            assert!(description.contains(needle), "run_shell description lacks {needle:?}");
        }

        let timeout_description = param_description(&run_shell, "timeout_seconds");
        for needle in ["30 second timeout", "model training", "safe side"] {
            assert!(
                timeout_description.contains(needle),
                "timeout_seconds description lacks {needle:?}"
            );
        }

        let max_output_description = param_description(&run_shell, "max_output_bytes");
        for needle in ["30%", "70%", "few KB"] {
            assert!(
                max_output_description.contains(needle),
                "max_output_bytes description lacks {needle:?}"
            );
        }
    }

    #[cfg(feature = "schemars")]
    #[test]
    fn openai_tools_include_glob_tool() {
        let glob_tool = find_tool("glob");

        let description = glob_tool
            .get("description")
            .and_then(Value::as_str)
            .expect("glob description");
        for needle in ["path discovery", "Returned", "Total"] {
            assert!(description.contains(needle), "glob description lacks {needle:?}");
        }

        let properties = glob_tool
            .pointer("/parameters/properties")
            .and_then(Value::as_object)
            .expect("glob parameters");
        for key in ["pattern", "path", "limit", "kind", "exclude"] {
            assert!(properties.contains_key(key), "glob parameters lack {key:?}");
        }
    }

    /// Pins the invariants strict mode relies on for every advertised tool:
    /// - root noise is stripped;
    /// - extra keys are forbidden and every property is required;
    /// - all `$ref`s point into `$defs` and resolve.
    #[cfg(feature = "schemars")]
    #[test]
    fn openai_tools_parameters_follow_strict_mode_rules() {
        for tool in openai_tools() {
            let name = tool.get("name").and_then(Value::as_str).expect("tool name");
            assert_eq!(tool.get("strict"), Some(&Value::Bool(true)), "{name}: strict flag");

            let parameters = tool.get("parameters").expect("tool parameters");
            let schema = parameters
                .as_object()
                .unwrap_or_else(|| panic!("{name}: parameters must be an object"));
            for stripped in ["$schema", "title", "definitions"] {
                assert!(!schema.contains_key(stripped), "{name}: `{stripped}` must be removed");
            }
            assert_eq!(
                schema.get("additionalProperties"),
                Some(&Value::Bool(false)),
                "{name}: additionalProperties"
            );

            let mut properties: Vec<&str> = schema
                .get("properties")
                .and_then(Value::as_object)
                .unwrap_or_else(|| panic!("{name}: properties"))
                .keys()
                .map(String::as_str)
                .collect();
            properties.sort_unstable();
            let required: Vec<&str> = schema
                .get("required")
                .and_then(Value::as_array)
                .unwrap_or_else(|| panic!("{name}: required"))
                .iter()
                .map(|key| key.as_str().expect("required entries are strings"))
                .collect();
            assert_eq!(required, properties, "{name}: every property must be required");

            let mut refs = Vec::new();
            collect_refs(parameters, &mut refs);
            for reference in refs {
                assert!(reference.starts_with("#/$defs/"), "{name}: unexpected ref {reference}");
                let pointer = reference.trim_start_matches('#');
                assert!(
                    parameters.pointer(pointer).is_some(),
                    "{name}: dangling ref {reference}"
                );
            }
        }
    }
}