Skip to main content

magic_coder_types/
tools.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3
4#[cfg(feature = "schemars")]
5use schemars::{JsonSchema, schema_for};
6
/// Convert a Rust struct schema into an OpenAI tool `parameters` object.
///
/// Starting from the schema produced by `schema_for!(T)`, this:
///
/// - removes the root `$schema` and `title` fields
/// - renames `definitions` to `$defs` and rewrites `$ref` pointers to match
/// - renames `oneOf` to `anyOf`, because it's better supported by the LLMs
/// - runs `enforce_openai_strict_schema` over the result
///
/// The rewrite is done on the JSON tree rather than on serialized text, so
/// free-form strings (descriptions, enum values, defaults) that happen to
/// contain `/definitions/` are left untouched.
///
/// # Panics
///
/// Panics if the generated schema cannot be converted into a `Value`.
#[cfg(feature = "schemars")]
pub fn tool_parameters<T: JsonSchema>() -> Value {
    let mut schema = serde_json::to_value(schema_for!(T)).expect("can't parse value from schema");

    // Only the root carries the `$schema` / `title` metadata we want to drop.
    if let Some(root) = schema.as_object_mut() {
        root.remove("$schema");
        root.remove("title");
    }

    normalize_schema_keywords(&mut schema);
    enforce_openai_strict_schema(&mut schema);
    schema
}

/// Recursively renames `definitions` -> `$defs` and `oneOf` -> `anyOf`, and
/// rewrites `/definitions/` to `/$defs/` inside `$ref` pointer strings.
#[cfg(feature = "schemars")]
fn normalize_schema_keywords(v: &mut Value) {
    match v {
        Value::Object(map) => {
            // Drain and rebuild the map so that renamed keys keep their
            // original position regardless of the map's ordering mode.
            for (key, mut child) in std::mem::take(map) {
                normalize_schema_keywords(&mut child);

                // Only `$ref` values are pointers; other strings are user text.
                if key == "$ref" {
                    if let Value::String(pointer) = &mut child {
                        *pointer = pointer.replace("/definitions/", "/$defs/");
                    }
                }

                let renamed = match key.as_str() {
                    "definitions" => Some("$defs"),
                    "oneOf" => Some("anyOf"),
                    _ => None,
                };
                map.insert(renamed.map_or(key, String::from), child);
            }
        }
        Value::Array(items) => items.iter_mut().for_each(normalize_schema_keywords),
        _ => {}
    }
}
34
/// Walks a JSON schema in place and tightens every object-like node.
///
/// A node counts as object-like when its `type` is the string `"object"` or
/// when it has a `properties` key. For such nodes:
///
/// - `additionalProperties` is set to `false` unless already present
/// - when `properties` is a JSON object, `required` is overwritten with the
///   alphabetically sorted list of all property names
///
/// Children (object values and array elements) are processed before their
/// parent.
#[cfg(feature = "schemars")]
fn enforce_openai_strict_schema(v: &mut Value) {
    match v {
        Value::Array(items) => items.iter_mut().for_each(enforce_openai_strict_schema),
        Value::Object(node) => {
            // Depth-first: nested schemas are fixed before this node.
            for child in node.values_mut() {
                enforce_openai_strict_schema(child);
            }

            let declares_object = node.get("type").and_then(Value::as_str) == Some("object");
            if !declares_object && !node.contains_key("properties") {
                return;
            }

            // Keep an explicit setting if the schema already has one.
            if !node.contains_key("additionalProperties") {
                node.insert("additionalProperties".to_string(), Value::Bool(false));
            }

            let mut names: Vec<String> = match node.get("properties") {
                Some(Value::Object(props)) => props.keys().cloned().collect(),
                _ => return,
            };
            names.sort();

            let required: Vec<Value> = names.into_iter().map(Value::String).collect();
            node.insert("required".to_string(), Value::Array(required));
        }
        _ => {}
    }
}
72
// Arguments for the `read_file` tool: `openai_tools` builds that tool's
// `parameters` schema from this struct via `tool_parameters::<ReadFileArgs>()`.
//
// NOTE(review): the `///` field docs below are presumably exported by the
// schemars derive as property descriptions, i.e. they are model-facing text —
// confirm before rewording. `enforce_openai_strict_schema` lists every
// property under `required`, so the `Option` fields are not optional at the
// schema level.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
pub struct ReadFileArgs {
    /// Path to file.
    pub path: String,
    /// Optional starting line (0-based).
    pub offset: Option<usize>,
    /// Optional maximum number of lines to read.
    pub limit: Option<usize>,
}
83
// Arguments for the `list_dir` tool: `openai_tools` builds that tool's
// `parameters` schema from this struct via `tool_parameters::<ListDirArgs>()`.
//
// NOTE(review): the `///` field doc is presumably exported by the schemars
// derive as the property description (model-facing text) — confirm before
// rewording.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
pub struct ListDirArgs {
    /// Directory path to list.
    pub path: String,
}
90
// Arguments for the `grep` tool: `openai_tools` builds that tool's
// `parameters` schema from this struct via `tool_parameters::<GrepArgs>()`.
//
// NOTE(review): the `///` field docs below are presumably exported by the
// schemars derive as property descriptions, i.e. they are model-facing text —
// confirm before rewording. `enforce_openai_strict_schema` lists every
// property under `required`, so the `Option` fields are not optional at the
// schema level.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
pub struct GrepArgs {
    /// Regex pattern to search for.
    pub pattern: String,
    /// Optional path (file or directory) to search in.
    pub path: Option<String>,
    /// Optional glob filter, e.g. `"*.rs"`.
    pub glob: Option<String>,
    /// Optional limit for returned matches.
    pub head_limit: Option<usize>,
}
103
// Arguments for the `run_shell` tool: `openai_tools` builds that tool's
// `parameters` schema from this struct via `tool_parameters::<RunShellArgs>()`.
//
// The `///` field docs below are model-facing: the test in this file reads
// `parameters.properties.<field>.description` for `timeout_seconds` and
// `max_output_bytes` and asserts on specific phrases, so rewording them can
// break that test. `enforce_openai_strict_schema` lists every property under
// `required`, so the `Option` fields are not optional at the schema level.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
pub struct RunShellArgs {
    /// Shell command line to run (executed via `bash -lc`), supports pipes/redirection.
    pub command: String,
    /// Optional working directory (relative to project root).
    pub cwd: Option<String>,
    /// Optional timeout in seconds. Omit to use the default 30 second timeout.
    /// For longer-running work like model training, set a larger value up front on the safe side to avoid retries.
    pub timeout_seconds: Option<u64>,
    /// Optional maximum captured bytes per stream (stdout/stderr).
    ///
    /// Truncated output keeps roughly the first 30% and last 70%, so very large
    /// values are usually unnecessary; prefer a few KB or low tens of KB and only
    /// increase if needed.
    pub max_output_bytes: Option<u64>,
}
121
// Arguments for the `apply_diff` tool: `openai_tools` builds that tool's
// `parameters` schema from this struct via `tool_parameters::<ApplyDiffArgs>()`.
//
// NOTE(review): the `///` field doc is presumably exported by the schemars
// derive as the property description (model-facing text) — confirm before
// rewording.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
pub struct ApplyDiffArgs {
    /// A git-style unified diff to apply to the working tree.
    ///
    /// You may pass either:
    /// - the raw diff text starting with `diff --git ...`, OR
    /// - a fenced diff block like ```diff ... ``` (indentation is OK).
    pub diff: String,
}
132
// Arguments for the `delete_files` tool: `openai_tools` builds that tool's
// `parameters` schema from this struct via
// `tool_parameters::<DeleteFilesArgs>()`.
//
// NOTE(review): the `///` field doc is presumably exported by the schemars
// derive as the property description (model-facing text) — confirm before
// rewording. The path restrictions it states are not enforced in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
pub struct DeleteFilesArgs {
    /// Paths to delete (relative to project root; no absolute paths; no `..`).
    pub paths: Vec<String>,
}
139
/// Builds the function-tool definitions to send to the OpenAI Responses API.
///
/// Every entry is a `"type": "function"` tool with `"strict": true`, whose
/// `parameters` schema is generated by `tool_parameters` from the matching
/// `*Args` struct. Tools are returned in a fixed order: `read_file`,
/// `list_dir`, `grep`, `run_shell`, `apply_diff`, `delete_files`.
#[cfg(feature = "schemars")]
pub fn openai_tools() -> Vec<Value> {
    // All tools share one envelope; only name, description and schema vary.
    fn function_tool(name: &str, description: &str, parameters: Value) -> Value {
        serde_json::json!({
            "type": "function",
            "name": name,
            "description": description,
            "strict": true,
            "parameters": parameters,
        })
    }

    vec![
        function_tool(
            "read_file",
            "Read a local file (by path), optionally with offset/limit. Returns a JSON string with keys: path, offset, limit, total_lines, content, numbered_content, fingerprint{hash64,len_bytes}, truncated.",
            tool_parameters::<ReadFileArgs>(),
        ),
        function_tool(
            "list_dir",
            "List a local directory (by path). Returns a JSON string: { path, entries: [{ name, is_dir, is_file }, ...] }.",
            tool_parameters::<ListDirArgs>(),
        ),
        function_tool(
            "grep",
            "Search for a regex pattern in files. Returns a JSON string including matches (file, line_number, line). May be truncated to head_limit.",
            tool_parameters::<GrepArgs>(),
        ),
        function_tool(
            "run_shell",
            "Run a shell command via `bash -lc` (supports pipes/redirection). Requires user confirmation. Use `max_output_bytes` intentionally: prefer the smallest limit that answers the question, and increase only when needed. Oversize output is cut from the middle, preserving roughly the first 30% and last 70%, so large requests are rarely necessary just to inspect the tail.",
            tool_parameters::<RunShellArgs>(),
        ),
        function_tool(
            "apply_diff",
            "Apply a git-style unified diff to the local working tree (create/update files). Returns a JSON string describing what changed or an error.",
            tool_parameters::<ApplyDiffArgs>(),
        ),
        function_tool(
            "delete_files",
            "Delete one or more files by path (relative to project root). Returns a JSON string listing deleted and missing paths.",
            tool_parameters::<DeleteFilesArgs>(),
        ),
    ]
}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `parameters.properties.<field>.description` of a tool
    /// definition, or `None` if any step of that path is missing.
    #[cfg(feature = "schemars")]
    fn property_description<'a>(tool: &'a Value, field: &str) -> Option<&'a str> {
        tool.get("parameters")?
            .get("properties")?
            .get(field)?
            .get("description")?
            .as_str()
    }

    /// Guards the model-facing wording of the `run_shell` tool: both the tool
    /// description and the per-field descriptions must keep the hints that
    /// steer callers toward small output limits and an adequate timeout.
    #[cfg(feature = "schemars")]
    #[test]
    fn run_shell_tool_schema_encourages_small_output_limits() {
        let tools = openai_tools();
        let run_shell = tools
            .iter()
            .find(|tool| tool.get("name").and_then(Value::as_str) == Some("run_shell"))
            .expect("run_shell tool");

        // Tool-level description.
        let description = run_shell
            .get("description")
            .and_then(Value::as_str)
            .expect("run_shell description");
        assert!(description.contains("max_output_bytes"));
        assert!(description.contains("30%"));
        assert!(description.contains("70%"));
        assert!(description.contains("smallest limit"));

        // Field-level description generated from the `timeout_seconds` docs.
        let timeout_description =
            property_description(run_shell, "timeout_seconds").expect("timeout_seconds description");
        assert!(timeout_description.contains("30 second timeout"));
        assert!(timeout_description.contains("model training"));
        assert!(timeout_description.contains("safe side"));

        // Field-level description generated from the `max_output_bytes` docs.
        let max_output_description = property_description(run_shell, "max_output_bytes")
            .expect("max_output_bytes description");
        assert!(max_output_description.contains("30%"));
        assert!(max_output_description.contains("70%"));
        assert!(max_output_description.contains("few KB"));
    }
}