1use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
12pub struct ExecuteCommand {
13 pub command: String,
15 pub working_dir: Option<String>,
17 pub needs_approval: bool,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
23pub struct ReadFile {
24 pub path: String,
26 pub line_range: Option<(usize, usize)>,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
32pub struct WriteFile {
33 pub path: String,
35 pub content: String,
37 pub append: bool,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
43#[serde(tag = "tool", rename_all = "snake_case")]
44pub enum ToolCall {
45 ExecuteCommand(ExecuteCommand),
46 ReadFile(ReadFile),
47 WriteFile(WriteFile),
48}
49
50pub fn generate_schema<T: JsonSchema>() -> String {
52 let schema = schemars::schema_for!(T);
53 serde_json::to_string_pretty(&schema).unwrap_or_default()
54}
55
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;

    /// Returns the names listed under the schema's top-level `required` key.
    fn required_fields(schema: &Value) -> Vec<String> {
        schema["required"]
            .as_array()
            .map(|fields| {
                fields
                    .iter()
                    .filter_map(|f| f.as_str().map(str::to_string))
                    .collect()
            })
            .unwrap_or_default()
    }

    #[test]
    fn test_tool_call_serialization() {
        let call = ToolCall::ExecuteCommand(ExecuteCommand {
            command: "cargo test".into(),
            working_dir: Some("/project".into()),
            needs_approval: false,
        });

        let json = serde_json::to_string(&call).unwrap();
        // Pin the exact wire shape: internal `"tool"` tag next to the struct fields.
        let value: Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["tool"], "execute_command");
        assert_eq!(value["command"], "cargo test");
        assert_eq!(value["working_dir"], "/project");
        assert_eq!(value["needs_approval"], false);

        match serde_json::from_str::<ToolCall>(&json).unwrap() {
            ToolCall::ExecuteCommand(cmd) => {
                assert_eq!(cmd.command, "cargo test");
                assert_eq!(cmd.working_dir.as_deref(), Some("/project"));
                assert!(!cmd.needs_approval);
            }
            other => panic!("Expected ExecuteCommand variant, got {:?}", other),
        }
    }

    #[test]
    fn test_read_file_round_trip_with_line_range() {
        let call = ToolCall::ReadFile(ReadFile {
            path: "src/lib.rs".into(),
            line_range: Some((10, 20)),
        });

        let json = serde_json::to_string(&call).unwrap();
        let value: Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["tool"], "read_file");
        // Tuples are encoded as two-element JSON arrays.
        assert_eq!(value["line_range"], serde_json::json!([10, 20]));

        match serde_json::from_str::<ToolCall>(&json).unwrap() {
            ToolCall::ReadFile(rf) => {
                assert_eq!(rf.path, "src/lib.rs");
                assert_eq!(rf.line_range, Some((10, 20)));
            }
            other => panic!("Expected ReadFile variant, got {:?}", other),
        }
    }

    #[test]
    fn test_optional_fields_may_be_omitted() {
        let json = r#"{"tool":"execute_command","command":"ls","needs_approval":true}"#;
        match serde_json::from_str::<ToolCall>(json).unwrap() {
            ToolCall::ExecuteCommand(cmd) => {
                assert_eq!(cmd.command, "ls");
                assert!(cmd.working_dir.is_none());
                assert!(cmd.needs_approval);
            }
            other => panic!("Expected ExecuteCommand variant, got {:?}", other),
        }

        let json = r#"{"tool":"read_file","path":"a.txt"}"#;
        match serde_json::from_str::<ToolCall>(json).unwrap() {
            ToolCall::ReadFile(rf) => {
                assert_eq!(rf.path, "a.txt");
                assert!(rf.line_range.is_none());
            }
            other => panic!("Expected ReadFile variant, got {:?}", other),
        }
    }

    #[test]
    fn test_unknown_or_missing_tool_is_rejected() {
        assert!(serde_json::from_str::<ToolCall>(r#"{"tool":"delete_file","path":"x"}"#).is_err());
        assert!(serde_json::from_str::<ToolCall>(r#"{"path":"x"}"#).is_err());
    }

    #[test]
    fn test_generate_schema() {
        let schema: Value = serde_json::from_str(&generate_schema::<ExecuteCommand>())
            .expect("generate_schema must produce valid JSON");
        for field in &["command", "working_dir", "needs_approval"] {
            assert!(
                schema["properties"].get(*field).is_some(),
                "missing property {}",
                field
            );
        }
        let required = required_fields(&schema);
        assert!(required.iter().any(|f| f == "command"));
        assert!(!required.iter().any(|f| f == "working_dir"));
    }

    #[test]
    fn test_read_file_schema() {
        let schema: Value = serde_json::from_str(&generate_schema::<ReadFile>())
            .expect("generate_schema must produce valid JSON");
        assert!(schema["properties"].get("path").is_some());
        assert!(schema["properties"].get("line_range").is_some());
        let required = required_fields(&schema);
        assert!(required.iter().any(|f| f == "path"));
        assert!(!required.iter().any(|f| f == "line_range"));
    }

    #[test]
    fn test_tool_call_schema_lists_every_tool() {
        let schema = generate_schema::<ToolCall>();
        assert!(serde_json::from_str::<Value>(&schema).is_ok());
        for tool in &["execute_command", "read_file", "write_file"] {
            assert!(schema.contains(*tool), "schema does not mention {}", tool);
        }
    }

    #[test]
    fn test_write_file_tool_call() {
        let call = ToolCall::WriteFile(WriteFile {
            path: "/tmp/test.txt".into(),
            content: "hello world".into(),
            append: false,
        });

        let json = serde_json::to_string(&call).unwrap();
        let value: Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["tool"], "write_file");

        match serde_json::from_str::<ToolCall>(&json).unwrap() {
            ToolCall::WriteFile(wf) => {
                assert_eq!(wf.path, "/tmp/test.txt");
                assert_eq!(wf.content, "hello world");
                assert!(!wf.append);
            }
            other => panic!("Expected WriteFile variant, got {:?}", other),
        }
    }
}