Skip to main content

atd_cli/
doctor.rs

1//! `atd doctor` — connectivity sanity check: socket exists, ping succeeds,
2//! how many tools does `discover` return.
3
4use atd_protocol::AtdError;
5use atd_sdk::{AtdClient, DiscoverFilter};
6use serde::Serialize;
7use std::io::Write;
8use std::path::PathBuf;
9
10use crate::cli::DoctorArgs;
11
12#[derive(Serialize)]
13pub struct DoctorReport {
14    pub socket_path: String,
15    pub socket_exists: bool,
16    pub ping_ok: bool,
17    pub tool_count: Option<usize>,
18    pub error: Option<String>,
19}
20
21/// `sock` is the resolved endpoint path — we need it separately from the
22/// connected client to report socket existence when connect fails.
23pub async fn run(sock: PathBuf, args: DoctorArgs, out: &mut impl Write) -> Result<(), AtdError> {
24    let socket_exists = sock.exists();
25    let socket_path = sock.to_string_lossy().into_owned();
26
27    let (ping_ok, tool_count, error) =
28        match AtdClient::connect(atd_sdk::Endpoint::unix(&sock)).await {
29            Ok(client) => match client.discover(None, DiscoverFilter::default()).await {
30                Ok(v) => (true, Some(v.len()), None),
31                Err(e) => (true, None, Some(format!("discover failed: {e}"))),
32            },
33            Err(e) => (false, None, Some(format!("connect failed: {e}"))),
34        };
35
36    let report = DoctorReport {
37        socket_path,
38        socket_exists,
39        ping_ok,
40        tool_count,
41        error,
42    };
43
44    if args.json {
45        serde_json::to_writer(&mut *out, &report).map_err(|e| AtdError::ProtocolError {
46            expected: "serializable DoctorReport".into(),
47            got: format!("serde error: {e}"),
48        })?;
49        writeln!(out).ok();
50    } else {
51        writeln!(out, "socket path:   {}", report.socket_path).ok();
52        writeln!(out, "socket exists: {}", report.socket_exists).ok();
53        writeln!(
54            out,
55            "ping:          {}",
56            if report.ping_ok { "ok" } else { "FAIL" }
57        )
58        .ok();
59        match report.tool_count {
60            Some(n) => writeln!(out, "tool count:    {n}").ok(),
61            None => writeln!(out, "tool count:    unavailable").ok(),
62        };
63        if let Some(e) = &report.error {
64            writeln!(out, "error:         {e}").ok();
65        }
66    }
67    Ok(())
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::{UnixListener, UnixStream};

    /// Picks the canned reply for a decoded request based on its `type` field.
    ///
    /// `ping` gets a `pong`, `tool_list` gets a fixed list of three tools, and
    /// anything else gets an error reply.
    fn canned_reply(request: &serde_json::Value) -> serde_json::Value {
        match request["type"].as_str() {
            Some("ping") => serde_json::json!({"type":"pong"}),
            Some("tool_list") => serde_json::json!({
                "type":"tool_list",
                "tools":[
                    {"id":"anos:fs.read","description":"r","tier":"hot","visibility":"read"},
                    {"id":"anos:fs.write","description":"w","tier":"hot","visibility":"write"},
                    {"id":"anos:web.search","description":"s","tier":"hot","visibility":"read"}
                ]
            }),
            _ => serde_json::json!({"type":"error","message":"no"}),
        }
    }

    /// Answers requests on one connection until a read or write fails.
    ///
    /// Each frame, in both directions, is a 4-byte big-endian length followed
    /// by that many bytes of JSON.
    async fn serve_connection(stream: UnixStream) {
        let (mut reader, mut writer) = stream.into_split();
        loop {
            // Length prefix of the incoming frame.
            let mut len_prefix = [0u8; 4];
            if reader.read_exact(&mut len_prefix).await.is_err() {
                return;
            }
            let frame_len = u32::from_be_bytes(len_prefix) as usize;

            // Frame payload.
            let mut payload = vec![0u8; frame_len];
            if reader.read_exact(&mut payload).await.is_err() {
                return;
            }

            let request: serde_json::Value = serde_json::from_slice(&payload).unwrap();
            let body = serde_json::to_vec(&canned_reply(&request)).unwrap();

            let reply_len = (body.len() as u32).to_be_bytes();
            if writer.write_all(&reply_len).await.is_err() {
                return;
            }
            if writer.write_all(&body).await.is_err() {
                return;
            }
            // A failed flush is ignored.
            let _ = writer.flush().await;
        }
    }

    /// Starts a background fake server on a fresh Unix socket and returns the
    /// socket path.
    async fn spawn_server_with_3_tools() -> std::path::PathBuf {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("s.sock");
        let listener = UnixListener::bind(&path).unwrap();
        // Skip the temp dir's destructor so the directory holding the socket
        // is not cleaned up while the server is in use.
        std::mem::forget(dir);

        tokio::spawn(async move {
            loop {
                match listener.accept().await {
                    Ok((stream, _)) => {
                        tokio::spawn(serve_connection(stream));
                    }
                    Err(_) => break,
                }
            }
        });

        // Short pause before handing the path back; presumably to let the
        // accept task get scheduled first.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        path
    }

    #[tokio::test]
    async fn doctor_reports_ok_against_reachable_server() {
        let sock = spawn_server_with_3_tools().await;
        let mut captured = Vec::<u8>::new();
        let outcome = run(sock, DoctorArgs { json: false }, &mut captured).await;
        outcome.unwrap();

        let text = String::from_utf8(captured).unwrap();
        assert!(text.contains("socket exists: true"));
        assert!(text.contains("ping:          ok"));
        assert!(text.contains("tool count:    3"));
    }

    #[tokio::test]
    async fn doctor_json_flag_emits_structured_report() {
        let sock = spawn_server_with_3_tools().await;
        let mut captured = Vec::<u8>::new();
        let outcome = run(sock, DoctorArgs { json: true }, &mut captured).await;
        outcome.unwrap();

        let text = String::from_utf8(captured).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(text.trim()).unwrap();
        assert_eq!(parsed["socket_exists"], true);
        assert_eq!(parsed["ping_ok"], true);
        assert_eq!(parsed["tool_count"], 3);
        assert!(parsed["error"].is_null());
    }

    #[tokio::test]
    async fn doctor_reports_unreachable_when_socket_missing() {
        // The temp dir exists, but nothing is bound at the path inside it.
        let dir = tempfile::tempdir().unwrap();
        let missing = dir.path().join("does-not-exist.sock");
        let mut captured = Vec::<u8>::new();
        let outcome = run(missing, DoctorArgs { json: false }, &mut captured).await;
        outcome.unwrap();

        let text = String::from_utf8(captured).unwrap();
        assert!(text.contains("socket exists: false"));
        assert!(text.contains("ping:          FAIL"));
        assert!(text.contains("error:"));
    }
}