// krait/client.rs

1use std::path::{Path, PathBuf};
2use std::time::Duration;
3
4use anyhow::{bail, Context};
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::net::UnixStream;
7use tracing::debug;
8
9use crate::protocol::{Request, Response};
10
11const MAX_RETRIES: u32 = 3;
12const RETRY_DELAYS_MS: [u64; 3] = [100, 200, 500];
13
14pub struct DaemonClient {
15    stream: UnixStream,
16}
17
18impl DaemonClient {
19    /// Connect to a running daemon at the given socket path.
20    ///
21    /// # Errors
22    /// Returns an error if the connection fails.
23    pub async fn connect(socket_path: &Path) -> anyhow::Result<Self> {
24        let stream = UnixStream::connect(socket_path)
25            .await
26            .with_context(|| format!("failed to connect to daemon at {}", socket_path.display()))?;
27        Ok(Self { stream })
28    }
29
30    /// Connect to the daemon, auto-starting it if not running.
31    ///
32    /// # Errors
33    /// Returns an error if the daemon cannot be started or connected to.
34    pub async fn connect_or_start(socket_path: &Path) -> anyhow::Result<Self> {
35        if let Ok(client) = Self::connect(socket_path).await {
36            return Ok(client);
37        }
38
39        debug!("daemon not running, starting it");
40        spawn_daemon()?;
41
42        for (attempt, delay_ms) in RETRY_DELAYS_MS.iter().enumerate() {
43            tokio::time::sleep(Duration::from_millis(*delay_ms)).await;
44            match Self::connect(socket_path).await {
45                Ok(client) => {
46                    debug!("connected after {} retries", attempt + 1);
47                    return Ok(client);
48                }
49                Err(e) if attempt == (MAX_RETRIES as usize - 1) => return Err(e),
50                Err(_) => {}
51            }
52        }
53
54        bail!(
55            "Daemon failed to start after {MAX_RETRIES} attempts. \
56             Run `krait daemon start` manually for debug output."
57        )
58    }
59
60    /// Send a request and receive the response.
61    ///
62    /// # Errors
63    /// Returns an error on IO or serialization failure.
64    pub async fn send(&mut self, request: &Request) -> anyhow::Result<Response> {
65        let json = serde_json::to_vec(request)?;
66        let len = u32::try_from(json.len())?;
67
68        self.stream.write_u32(len).await?;
69        self.stream.write_all(&json).await?;
70        self.stream.flush().await?;
71
72        let resp_len = self.stream.read_u32().await?;
73        if resp_len > crate::protocol::MAX_FRAME_SIZE {
74            anyhow::bail!("oversized response frame: {resp_len} bytes");
75        }
76        let mut buf = vec![0u8; resp_len as usize];
77        self.stream.read_exact(&mut buf).await?;
78
79        let response = serde_json::from_slice(&buf)?;
80        Ok(response)
81    }
82}
83
84/// Spawn the daemon as a detached background process.
85fn spawn_daemon() -> anyhow::Result<()> {
86    let exe = std::env::current_exe().context("failed to get current executable path")?;
87
88    std::process::Command::new(exe)
89        .args(["daemon", "start"])
90        .stdin(std::process::Stdio::null())
91        .stdout(std::process::Stdio::null())
92        .stderr(std::process::Stdio::null())
93        .spawn()
94        .context("failed to spawn daemon process")?;
95
96    Ok(())
97}
98
99/// Convert CLI command to protocol request.
100#[must_use]
101pub fn command_to_request(command: &crate::cli::Command) -> Request {
102    use crate::cli::{Command, EditCommand, FindCommand, ListCommand, ReadCommand};
103
104    match command {
105        Command::Init { .. } => unreachable!("init is handled locally"),
106        Command::Status => Request::Status,
107        Command::Check { path, errors_only } => Request::Check {
108            path: path.clone(),
109            errors_only: *errors_only,
110        },
111        Command::Find(FindCommand::Symbol {
112            name,
113            path,
114            src_only,
115            include_body,
116        }) => Request::FindSymbol {
117            name: name.clone(),
118            path_filter: path.clone(),
119            src_only: *src_only,
120            include_body: *include_body,
121        },
122        Command::Find(FindCommand::Refs { name, with_symbol }) => Request::FindRefs {
123            name: name.clone(),
124            with_symbol: *with_symbol,
125        },
126        Command::Find(FindCommand::Impl { name }) => Request::FindImpl { name: name.clone() },
127        Command::List(ListCommand::Symbols { path, depth }) => Request::ListSymbols {
128            path: path.clone(),
129            depth: *depth,
130        },
131        Command::Read(ReadCommand::File {
132            path,
133            from,
134            to,
135            max_lines,
136        }) => Request::ReadFile {
137            path: path.clone(),
138            from: *from,
139            to: *to,
140            max_lines: *max_lines,
141        },
142        Command::Read(ReadCommand::Symbol {
143            name,
144            signature_only,
145            max_lines,
146            path,
147            has_body,
148        }) => Request::ReadSymbol {
149            name: name.clone(),
150            signature_only: *signature_only,
151            max_lines: *max_lines,
152            path_filter: path.clone(),
153            has_body: *has_body,
154        },
155        Command::Edit(EditCommand::Replace { symbol }) => Request::EditReplace {
156            symbol: symbol.clone(),
157            code: String::new(), // stdin will be read separately
158        },
159        Command::Edit(EditCommand::InsertAfter { symbol }) => Request::EditInsertAfter {
160            symbol: symbol.clone(),
161            code: String::new(),
162        },
163        Command::Edit(EditCommand::InsertBefore { symbol }) => Request::EditInsertBefore {
164            symbol: symbol.clone(),
165            code: String::new(),
166        },
167        Command::Daemon(_) => unreachable!("daemon commands are handled directly"),
168        Command::Hover { name } => Request::Hover { name: name.clone() },
169        Command::Format { path } => Request::Format { path: path.clone() },
170        Command::Rename { symbol, new_name } => Request::Rename {
171            name: symbol.clone(),
172            new_name: new_name.clone(),
173        },
174        Command::Fix { path } => Request::Fix { path: path.clone() },
175        Command::Watch { .. } => unreachable!("watch is handled client-side"),
176        Command::Search { .. } => unreachable!("search is handled client-side"),
177        Command::Server(_) => unreachable!("server commands are handled client-side"),
178    }
179}
180
181/// Format a PID file path from a socket path for display purposes.
182#[must_use]
183pub fn pid_path_from_socket(socket_path: &Path) -> PathBuf {
184    socket_path.with_extension("pid")
185}
186
187#[cfg(test)]
188mod tests {
189    use std::time::Duration;
190
191    use super::*;
192    use crate::daemon::server::run_server;
193
194    #[tokio::test]
195    async fn client_connects_to_running_daemon() {
196        let dir = tempfile::tempdir().unwrap();
197        let sock = dir.path().join("test.sock");
198        let dir_root = dir.path().to_path_buf();
199
200        let sock_clone = sock.clone();
201        let _handle = tokio::spawn(async move {
202            run_server(&sock_clone, Duration::from_secs(5), &dir_root)
203                .await
204                .unwrap();
205        });
206        tokio::time::sleep(Duration::from_millis(50)).await;
207
208        let mut client = DaemonClient::connect(&sock).await.unwrap();
209        let resp = client.send(&Request::Status).await.unwrap();
210        assert!(resp.success);
211
212        // Clean up
213        let mut client = DaemonClient::connect(&sock).await.unwrap();
214        client.send(&Request::DaemonStop).await.unwrap();
215    }
216
217    #[tokio::test]
218    async fn client_stop_shuts_down_daemon() {
219        let dir = tempfile::tempdir().unwrap();
220        let sock = dir.path().join("test.sock");
221        let dir_root = dir.path().to_path_buf();
222
223        let sock_clone = sock.clone();
224        let handle = tokio::spawn(async move {
225            run_server(&sock_clone, Duration::from_secs(5), &dir_root)
226                .await
227                .unwrap();
228        });
229        tokio::time::sleep(Duration::from_millis(50)).await;
230
231        let mut client = DaemonClient::connect(&sock).await.unwrap();
232        let resp = client.send(&Request::DaemonStop).await.unwrap();
233        assert!(resp.success);
234
235        let _ = handle.await;
236    }
237
238    #[test]
239    fn command_to_request_list_symbols() {
240        use crate::cli::{Command, ListCommand};
241        use crate::protocol::Request;
242        use std::path::PathBuf;
243
244        let cmd = Command::List(ListCommand::Symbols {
245            path: PathBuf::from("src/lib.rs"),
246            depth: 2,
247        });
248        let req = command_to_request(&cmd);
249        assert!(matches!(req, Request::ListSymbols { depth: 2, .. }));
250    }
251
252    #[test]
253    fn command_to_request_read_file() {
254        use crate::cli::{Command, ReadCommand};
255        use crate::protocol::Request;
256        use std::path::PathBuf;
257
258        let cmd = Command::Read(ReadCommand::File {
259            path: PathBuf::from("main.rs"),
260            from: Some(1),
261            to: Some(10),
262            max_lines: None,
263        });
264        let req = command_to_request(&cmd);
265        assert!(matches!(
266            req,
267            Request::ReadFile {
268                from: Some(1),
269                to: Some(10),
270                ..
271            }
272        ));
273    }
274
275    #[test]
276    fn command_to_request_read_symbol() {
277        use crate::cli::{Command, ReadCommand};
278        use crate::protocol::Request;
279
280        let cmd = Command::Read(ReadCommand::Symbol {
281            name: "Config".into(),
282            signature_only: true,
283            max_lines: Some(20),
284            path: None,
285            has_body: false,
286        });
287        let req = command_to_request(&cmd);
288        assert!(matches!(
289            req,
290            Request::ReadSymbol {
291                signature_only: true,
292                max_lines: Some(20),
293                ..
294            }
295        ));
296    }
297
298    #[tokio::test]
299    async fn handle_connection_rejects_oversized_frame() {
300        use crate::daemon::server::run_server;
301        use tokio::io::{AsyncReadExt, AsyncWriteExt};
302        use tokio::net::UnixStream;
303
304        let dir = tempfile::tempdir().unwrap();
305        let sock = dir.path().join("test.sock");
306        let dir_root = dir.path().to_path_buf();
307
308        let sock_clone = sock.clone();
309        let _handle = tokio::spawn(async move {
310            run_server(&sock_clone, Duration::from_secs(5), &dir_root)
311                .await
312                .unwrap();
313        });
314        tokio::time::sleep(Duration::from_millis(50)).await;
315
316        // Send an oversized frame (20 MB > 10 MB limit)
317        let mut stream = UnixStream::connect(&sock).await.unwrap();
318        let oversized_len: u32 = 20 * 1024 * 1024;
319        stream.write_u32(oversized_len).await.unwrap();
320        stream.flush().await.unwrap();
321
322        // The connection should be closed by the server (we won't get a valid response)
323        let result = stream.read_u32().await;
324        assert!(
325            result.is_err(),
326            "server should close connection on oversized frame"
327        );
328
329        // Clean up
330        if let Ok(mut client) = DaemonClient::connect(&sock).await {
331            let _ = client.send(&Request::DaemonStop).await;
332        }
333    }
334
335    #[tokio::test]
336    async fn client_connect_fails_without_daemon() {
337        let dir = tempfile::tempdir().unwrap();
338        let sock = dir.path().join("nonexistent.sock");
339
340        let result = DaemonClient::connect(&sock).await;
341        assert!(result.is_err());
342    }
343}