1use atd_protocol::{AtdError, ToolResult};
4use atd_sdk::{AtdClient, CallOptions};
5use std::io::Write;
6
7use crate::cli::CallArgs;
8
9pub async fn run(client: &AtdClient, args: CallArgs, out: &mut impl Write) -> Result<(), AtdError> {
10 let call_args: serde_json::Value =
11 serde_json::from_str(&args.args).map_err(|e| AtdError::InvalidArguments {
12 tool_id: args.tool_id.clone(),
13 field: "--args".into(),
14 reason: format!("not valid JSON: {e}"),
15 })?;
16
17 let result = client
18 .call(
19 &args.tool_id,
20 call_args,
21 CallOptions {
22 dry_run: args.dry_run,
23 preferred_binding: None,
24 },
25 )
26 .await?;
27
28 if args.json {
29 let v = serde_json::to_string(&result).map_err(|e| AtdError::ProtocolError {
30 expected: "serializable ToolResult".into(),
31 got: format!("serde error: {e}"),
32 })?;
33 writeln!(out, "{v}").ok();
34 return Ok(());
35 }
36
37 match result {
38 ToolResult::Success { data, .. } => {
39 let pretty = serde_json::to_string_pretty(&data).unwrap_or_else(|_| "{}".into());
40 writeln!(out, "ok:").ok();
41 writeln!(out, "{pretty}").ok();
42 Ok(())
43 }
44 ToolResult::Error {
45 code,
46 message,
47 reason,
48 retryable,
49 } => Err(AtdError::ToolExecutionFailed {
50 tool_id: args.tool_id.clone(),
51 inner: Box::new(std::io::Error::other(format!(
52 "[{code}] {message}{}{}",
53 if retryable { " (retryable)" } else { "" },
54 reason
55 .as_deref()
56 .map(|r| format!(" — raw: {r}"))
57 .unwrap_or_default()
58 ))),
59 }),
60 }
61}
62
#[cfg(test)]
mod tests {
    use super::*;
    use atd_sdk::Endpoint;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::UnixListener;

    /// Tool id used by every request in these tests.
    const TOOL_ID: &str = "anos:fs.read";

    /// Starts a fake daemon on a fresh Unix socket and returns the directory
    /// guard together with the socket path.
    ///
    /// Frames are a 4-byte big-endian length prefix followed by a JSON body.
    /// `ping` requests are answered with `pong`; every other request is passed
    /// to `handler`, and its return value is sent back as the reply.
    ///
    /// Keep the returned `TempDir` alive for the whole test: dropping it removes
    /// the directory, so nothing is leaked once the test finishes.
    async fn spawn_fake_server(
        handler: fn(serde_json::Value) -> serde_json::Value,
    ) -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("s.sock");
        // `bind` also puts the socket into the listening state, so clients can
        // connect right away; no start-up delay is needed before returning.
        let listener = UnixListener::bind(&path).unwrap();

        tokio::spawn(async move {
            while let Ok((stream, _)) = listener.accept().await {
                tokio::spawn(serve_connection(stream, handler));
            }
        });
        (dir, path)
    }

    /// Answers length-prefixed JSON requests on one connection until the peer
    /// disconnects or an I/O error occurs.
    async fn serve_connection(
        stream: tokio::net::UnixStream,
        handler: fn(serde_json::Value) -> serde_json::Value,
    ) {
        let (mut reader, mut writer) = stream.into_split();
        loop {
            let mut len_bytes = [0u8; 4];
            if reader.read_exact(&mut len_bytes).await.is_err() {
                return;
            }
            let mut buf = vec![0u8; u32::from_be_bytes(len_bytes) as usize];
            if reader.read_exact(&mut buf).await.is_err() {
                return;
            }
            let req: serde_json::Value = serde_json::from_slice(&buf).unwrap();
            let reply = match req["type"].as_str() {
                Some("ping") => serde_json::json!({"type": "pong"}),
                _ => handler(req),
            };
            let body = serde_json::to_vec(&reply).unwrap();
            let len = u32::try_from(body.len()).expect("test reply fits in a u32 frame");
            if writer.write_all(&len.to_be_bytes()).await.is_err() {
                return;
            }
            if writer.write_all(&body).await.is_err() {
                return;
            }
            let _ = writer.flush().await;
        }
    }

    /// Replies to `run_tool` with a `tool_result` frame carrying `result` and
    /// `success`; any other request type gets an `error` frame.
    fn tool_result_reply(
        req: &serde_json::Value,
        result: serde_json::Value,
        success: bool,
    ) -> serde_json::Value {
        match req["type"].as_str() {
            Some("run_tool") => serde_json::json!({
                "type": "tool_result",
                "tool_id": req["tool_id"],
                "result": result,
                "success": success,
                "dry_run": false
            }),
            _ => serde_json::json!({"type": "error", "message": "no"}),
        }
    }

    /// Builds `CallArgs` for `TOOL_ID` with the given raw `--args` text.
    fn call_args(args: &str, json: bool) -> CallArgs {
        CallArgs {
            tool_id: TOOL_ID.into(),
            args: args.into(),
            dry_run: false,
            json,
        }
    }

    /// A successful call prints `ok:` followed by the pretty-printed data.
    #[tokio::test]
    async fn call_prints_ok_and_data_on_success() {
        let (_dir, sock) = spawn_fake_server(|req| {
            tool_result_reply(&req, serde_json::json!({"content": "hello"}), true)
        })
        .await;
        let client = AtdClient::connect(Endpoint::unix(sock)).await.unwrap();
        let mut out: Vec<u8> = Vec::new();
        run(&client, call_args(r#"{"path":"/tmp/x"}"#, false), &mut out)
            .await
            .unwrap();
        let s = String::from_utf8(out).unwrap();
        assert!(s.starts_with("ok:\n"), "got: {s}");
        assert!(s.contains("\"content\": \"hello\""), "got: {s}");
    }

    /// Malformed `--args` is rejected before anything is written.
    #[tokio::test]
    async fn call_errors_on_invalid_json_args() {
        let (_dir, sock) =
            spawn_fake_server(|_| serde_json::json!({"type": "error", "message": "no"})).await;
        let client = AtdClient::connect(Endpoint::unix(sock)).await.unwrap();
        let mut out: Vec<u8> = Vec::new();
        let err = run(&client, call_args("not json", false), &mut out)
            .await
            .unwrap_err();
        match err {
            AtdError::InvalidArguments { field, .. } => assert_eq!(field, "--args"),
            _ => panic!("expected InvalidArguments variant"),
        }
        assert!(out.is_empty());
    }

    /// `--json` emits the full `ToolResult` envelope as a single JSON document.
    #[tokio::test]
    async fn call_json_flag_emits_full_tool_result_envelope() {
        let (_dir, sock) = spawn_fake_server(|req| {
            tool_result_reply(&req, serde_json::json!({"k": "v"}), true)
        })
        .await;
        let client = AtdClient::connect(Endpoint::unix(sock)).await.unwrap();
        let mut out: Vec<u8> = Vec::new();
        run(&client, call_args("{}", true), &mut out)
            .await
            .unwrap();
        let s = String::from_utf8(out).unwrap();
        let v: serde_json::Value = serde_json::from_str(s.trim()).unwrap();
        assert_eq!(v["status"], "success");
        assert_eq!(v["data"]["k"], "v");
    }

    /// A server-reported tool failure becomes `ToolExecutionFailed` and
    /// produces no output.
    #[tokio::test]
    async fn call_surfaces_server_reported_failure_as_error() {
        let (_dir, sock) = spawn_fake_server(|req| {
            tool_result_reply(
                &req,
                serde_json::json!({"code": "EPERM", "message": "denied", "retryable": false}),
                false,
            )
        })
        .await;
        let client = AtdClient::connect(Endpoint::unix(sock)).await.unwrap();
        let mut out: Vec<u8> = Vec::new();
        let err = run(&client, call_args("{}", false), &mut out)
            .await
            .unwrap_err();
        assert!(
            matches!(err, AtdError::ToolExecutionFailed { .. }),
            "got: {err:?}"
        );
        assert!(out.is_empty());
    }
}