Skip to main content

atd_cli/
call.rs

1//! `atd call` — invoke a tool with JSON args and print the result.
2
3use atd_protocol::{AtdError, ToolResult};
4use atd_sdk::{AtdClient, CallOptions};
5use std::io::Write;
6
7use crate::cli::CallArgs;
8
9pub async fn run(client: &AtdClient, args: CallArgs, out: &mut impl Write) -> Result<(), AtdError> {
10    let call_args: serde_json::Value =
11        serde_json::from_str(&args.args).map_err(|e| AtdError::InvalidArguments {
12            tool_id: args.tool_id.clone(),
13            field: "--args".into(),
14            reason: format!("not valid JSON: {e}"),
15        })?;
16
17    let result = client
18        .call(
19            &args.tool_id,
20            call_args,
21            CallOptions {
22                dry_run: args.dry_run,
23                preferred_binding: None,
24            },
25        )
26        .await?;
27
28    if args.json {
29        let v = serde_json::to_string(&result).map_err(|e| AtdError::ProtocolError {
30            expected: "serializable ToolResult".into(),
31            got: format!("serde error: {e}"),
32        })?;
33        writeln!(out, "{v}").ok();
34        return Ok(());
35    }
36
37    match result {
38        ToolResult::Success { data, .. } => {
39            let pretty = serde_json::to_string_pretty(&data).unwrap_or_else(|_| "{}".into());
40            writeln!(out, "ok:").ok();
41            writeln!(out, "{pretty}").ok();
42            Ok(())
43        }
44        ToolResult::Error {
45            code,
46            message,
47            reason,
48            retryable,
49        } => Err(AtdError::ToolExecutionFailed {
50            tool_id: args.tool_id.clone(),
51            inner: Box::new(std::io::Error::other(format!(
52                "[{code}] {message}{}{}",
53                if retryable { " (retryable)" } else { "" },
54                reason
55                    .as_deref()
56                    .map(|r| format!(" — raw: {r}"))
57                    .unwrap_or_default()
58            ))),
59        }),
60    }
61}
62
#[cfg(test)]
mod tests {
    use super::*;
    use atd_sdk::Endpoint;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::UnixListener;

    /// Starts an in-process fake server on a fresh Unix socket.
    ///
    /// Frames are a 4-byte big-endian length prefix followed by a JSON body.
    /// Requests whose `type` is `ping` are answered with `{"type":"pong"}`;
    /// every other request is passed to `handler`, and its return value is
    /// sent back as the reply.
    ///
    /// Returns the temp-directory guard together with the socket path. The
    /// caller must keep the guard alive (bind it to a named variable such as
    /// `_dir`, not `_`) while the socket is in use; dropping it removes the
    /// directory, so nothing is left behind after the test.
    async fn spawn_fake_server(
        handler: fn(serde_json::Value) -> serde_json::Value,
    ) -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("s.sock");
        // `bind` also starts listening, so clients can connect as soon as this
        // returns — no delay is needed before handing the path out.
        let listener = UnixListener::bind(&path).unwrap();

        tokio::spawn(async move {
            while let Ok((stream, _)) = listener.accept().await {
                tokio::spawn(async move {
                    let (mut r, mut w) = stream.into_split();
                    loop {
                        let mut lb = [0u8; 4];
                        if r.read_exact(&mut lb).await.is_err() {
                            return;
                        }
                        let n = u32::from_be_bytes(lb) as usize;
                        let mut buf = vec![0u8; n];
                        if r.read_exact(&mut buf).await.is_err() {
                            return;
                        }
                        let req: serde_json::Value = serde_json::from_slice(&buf).unwrap();
                        let reply = match req["type"].as_str() {
                            Some("ping") => serde_json::json!({"type":"pong"}),
                            _ => handler(req),
                        };
                        let body = serde_json::to_vec(&reply).unwrap();
                        if w.write_all(&(body.len() as u32).to_be_bytes())
                            .await
                            .is_err()
                        {
                            return;
                        }
                        if w.write_all(&body).await.is_err() {
                            return;
                        }
                        let _ = w.flush().await;
                    }
                });
            }
        });
        (dir, path)
    }

    #[tokio::test]
    async fn call_prints_ok_and_data_on_success() {
        let (_dir, sock) = spawn_fake_server(|req| match req["type"].as_str() {
            Some("run_tool") => serde_json::json!({
                "type":"tool_result",
                "tool_id": req["tool_id"],
                "result": {"content":"hello"},
                "success": true,
                "dry_run": false
            }),
            _ => serde_json::json!({"type":"error","message":"no"}),
        })
        .await;
        let client = AtdClient::connect(Endpoint::unix(sock)).await.unwrap();
        let mut out: Vec<u8> = Vec::new();
        run(
            &client,
            CallArgs {
                tool_id: "anos:fs.read".into(),
                args: r#"{"path":"/tmp/x"}"#.into(),
                dry_run: false,
                json: false,
            },
            &mut out,
        )
        .await
        .unwrap();
        let s = String::from_utf8(out).unwrap();
        assert!(s.starts_with("ok:\n"));
        assert!(s.contains("\"content\": \"hello\""));
    }

    #[tokio::test]
    async fn call_errors_on_invalid_json_args() {
        let (_dir, sock) =
            spawn_fake_server(|_| serde_json::json!({"type":"error","message":"no"})).await;
        let client = AtdClient::connect(Endpoint::unix(sock)).await.unwrap();
        let mut out: Vec<u8> = Vec::new();
        let err = run(
            &client,
            CallArgs {
                tool_id: "anos:fs.read".into(),
                args: "not json".into(),
                dry_run: false,
                json: false,
            },
            &mut out,
        )
        .await
        .unwrap_err();
        match err {
            AtdError::InvalidArguments { field, .. } => assert_eq!(field, "--args"),
            other => panic!("expected InvalidArguments variant, got: {other:?}"),
        }
        assert!(out.is_empty(), "nothing should be printed when --args is invalid");
    }

    #[tokio::test]
    async fn call_json_flag_emits_full_tool_result_envelope() {
        let (_dir, sock) = spawn_fake_server(|req| match req["type"].as_str() {
            Some("run_tool") => serde_json::json!({
                "type":"tool_result",
                "tool_id": req["tool_id"],
                "result": {"k":"v"},
                "success": true,
                "dry_run": false
            }),
            _ => serde_json::json!({"type":"error","message":"no"}),
        })
        .await;
        let client = AtdClient::connect(Endpoint::unix(sock)).await.unwrap();
        let mut out: Vec<u8> = Vec::new();
        run(
            &client,
            CallArgs {
                tool_id: "anos:fs.read".into(),
                args: "{}".into(),
                dry_run: false,
                json: true,
            },
            &mut out,
        )
        .await
        .unwrap();
        let s = String::from_utf8(out).unwrap();
        let v: serde_json::Value = serde_json::from_str(s.trim()).unwrap();
        assert_eq!(v["status"], "success");
        assert_eq!(v["data"]["k"], "v");
    }

    #[tokio::test]
    async fn call_surfaces_server_reported_failure_as_error() {
        let (_dir, sock) = spawn_fake_server(|req| match req["type"].as_str() {
            Some("run_tool") => serde_json::json!({
                "type":"tool_result",
                "tool_id": req["tool_id"],
                "result": {"code":"EPERM","message":"denied","retryable":false},
                "success": false,
                "dry_run": false
            }),
            _ => serde_json::json!({"type":"error","message":"no"}),
        })
        .await;
        let client = AtdClient::connect(Endpoint::unix(sock)).await.unwrap();
        let mut out: Vec<u8> = Vec::new();
        let err = run(
            &client,
            CallArgs {
                tool_id: "anos:fs.read".into(),
                args: "{}".into(),
                dry_run: false,
                json: false,
            },
            &mut out,
        )
        .await
        .unwrap_err();
        // Match on the variant itself rather than on its Debug text.
        assert!(
            matches!(err, AtdError::ToolExecutionFailed { .. }),
            "expected ToolExecutionFailed, got: {err:?}"
        );
        assert!(out.is_empty(), "a failed call must not print `ok:` output");
    }
}