1use std::path::{Path, PathBuf};
2use std::time::Duration;
3
4use anyhow::{bail, Context};
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::net::UnixStream;
7use tracing::debug;
8
9use crate::protocol::{Request, Response};
10
/// Number of connection attempts `DaemonClient::connect_or_start` makes after
/// spawning the daemon.
const MAX_RETRIES: u32 = 3;
/// Delay (in milliseconds) slept before each of those attempts. The array length is
/// tied to `MAX_RETRIES` so the compiler rejects any edit that lets the two drift apart.
const RETRY_DELAYS_MS: [u64; MAX_RETRIES as usize] = [100, 200, 500];
13
14pub struct DaemonClient {
15 stream: UnixStream,
16}
17
18impl DaemonClient {
19 pub async fn connect(socket_path: &Path) -> anyhow::Result<Self> {
24 let stream = UnixStream::connect(socket_path)
25 .await
26 .with_context(|| format!("failed to connect to daemon at {}", socket_path.display()))?;
27 Ok(Self { stream })
28 }
29
30 pub async fn connect_or_start(socket_path: &Path) -> anyhow::Result<Self> {
35 if let Ok(client) = Self::connect(socket_path).await {
36 return Ok(client);
37 }
38
39 debug!("daemon not running, starting it");
40 spawn_daemon()?;
41
42 for (attempt, delay_ms) in RETRY_DELAYS_MS.iter().enumerate() {
43 tokio::time::sleep(Duration::from_millis(*delay_ms)).await;
44 match Self::connect(socket_path).await {
45 Ok(client) => {
46 debug!("connected after {} retries", attempt + 1);
47 return Ok(client);
48 }
49 Err(e) if attempt == (MAX_RETRIES as usize - 1) => return Err(e),
50 Err(_) => {}
51 }
52 }
53
54 bail!(
55 "Daemon failed to start after {MAX_RETRIES} attempts. \
56 Run `krait daemon start` manually for debug output."
57 )
58 }
59
60 pub async fn send(&mut self, request: &Request) -> anyhow::Result<Response> {
65 let json = serde_json::to_vec(request)?;
66 let len = u32::try_from(json.len())?;
67
68 self.stream.write_u32(len).await?;
69 self.stream.write_all(&json).await?;
70 self.stream.flush().await?;
71
72 let resp_len = self.stream.read_u32().await?;
73 if resp_len > crate::protocol::MAX_FRAME_SIZE {
74 anyhow::bail!("oversized response frame: {resp_len} bytes");
75 }
76 let mut buf = vec![0u8; resp_len as usize];
77 self.stream.read_exact(&mut buf).await?;
78
79 let response = serde_json::from_slice(&buf)?;
80 Ok(response)
81 }
82}
83
84fn spawn_daemon() -> anyhow::Result<()> {
86 let exe = std::env::current_exe().context("failed to get current executable path")?;
87
88 std::process::Command::new(exe)
89 .args(["daemon", "start"])
90 .stdin(std::process::Stdio::null())
91 .stdout(std::process::Stdio::null())
92 .stderr(std::process::Stdio::null())
93 .spawn()
94 .context("failed to spawn daemon process")?;
95
96 Ok(())
97}
98
99#[must_use]
101pub fn command_to_request(command: &crate::cli::Command) -> Request {
102 use crate::cli::{Command, EditCommand, FindCommand, ListCommand, ReadCommand};
103
104 match command {
105 Command::Init { .. } => unreachable!("init is handled locally"),
106 Command::Status => Request::Status,
107 Command::Check { path, errors_only } => Request::Check {
108 path: path.clone(),
109 errors_only: *errors_only,
110 },
111 Command::Find(FindCommand::Symbol {
112 name,
113 path,
114 src_only,
115 include_body,
116 }) => Request::FindSymbol {
117 name: name.clone(),
118 path_filter: path.clone(),
119 src_only: *src_only,
120 include_body: *include_body,
121 },
122 Command::Find(FindCommand::Refs { name, with_symbol }) => Request::FindRefs {
123 name: name.clone(),
124 with_symbol: *with_symbol,
125 },
126 Command::Find(FindCommand::Impl { name }) => Request::FindImpl { name: name.clone() },
127 Command::List(ListCommand::Symbols { path, depth }) => Request::ListSymbols {
128 path: path.clone(),
129 depth: *depth,
130 },
131 Command::Read(ReadCommand::File {
132 path,
133 from,
134 to,
135 max_lines,
136 }) => Request::ReadFile {
137 path: path.clone(),
138 from: *from,
139 to: *to,
140 max_lines: *max_lines,
141 },
142 Command::Read(ReadCommand::Symbol {
143 name,
144 signature_only,
145 max_lines,
146 path,
147 has_body,
148 }) => Request::ReadSymbol {
149 name: name.clone(),
150 signature_only: *signature_only,
151 max_lines: *max_lines,
152 path_filter: path.clone(),
153 has_body: *has_body,
154 },
155 Command::Edit(EditCommand::Replace { symbol }) => Request::EditReplace {
156 symbol: symbol.clone(),
157 code: String::new(), },
159 Command::Edit(EditCommand::InsertAfter { symbol }) => Request::EditInsertAfter {
160 symbol: symbol.clone(),
161 code: String::new(),
162 },
163 Command::Edit(EditCommand::InsertBefore { symbol }) => Request::EditInsertBefore {
164 symbol: symbol.clone(),
165 code: String::new(),
166 },
167 Command::Daemon(_) => unreachable!("daemon commands are handled directly"),
168 Command::Hover { name } => Request::Hover { name: name.clone() },
169 Command::Format { path } => Request::Format { path: path.clone() },
170 Command::Rename { symbol, new_name } => Request::Rename {
171 name: symbol.clone(),
172 new_name: new_name.clone(),
173 },
174 Command::Fix { path } => Request::Fix { path: path.clone() },
175 Command::Watch { .. } => unreachable!("watch is handled client-side"),
176 Command::Search { .. } => unreachable!("search is handled client-side"),
177 Command::Server(_) => unreachable!("server commands are handled client-side"),
178 }
179}
180
/// Returns the PID-file path that pairs with `socket_path`.
///
/// The socket path's extension is replaced by `pid`, or `.pid` is appended when it
/// has none (e.g. `/tmp/krait.sock` -> `/tmp/krait.pid`).
#[must_use]
pub fn pid_path_from_socket(socket_path: &Path) -> PathBuf {
    let mut pid_path = socket_path.to_path_buf();
    pid_path.set_extension("pid");
    pid_path
}
186
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::daemon::server::run_server;

    /// Spawns `run_server` on a fresh `test.sock` inside `dir`.
    ///
    /// Returns the socket path and the server task handle. The server may not be
    /// listening yet when this returns; connect through `connect_when_ready`.
    fn spawn_server(dir: &tempfile::TempDir) -> (PathBuf, tokio::task::JoinHandle<()>) {
        let sock = dir.path().join("test.sock");
        let dir_root = dir.path().to_path_buf();
        let sock_clone = sock.clone();
        let handle = tokio::spawn(async move {
            run_server(&sock_clone, Duration::from_secs(5), &dir_root)
                .await
                .unwrap();
        });
        (sock, handle)
    }

    /// Connects to the daemon at `sock`, polling for up to ~2s while the server task
    /// binds its socket. A fixed sleep races server startup on a loaded machine;
    /// polling tolerates a slow start and still returns as soon as the socket is up.
    async fn connect_when_ready(sock: &Path) -> DaemonClient {
        let mut last_err = None;
        for _ in 0..200 {
            match DaemonClient::connect(sock).await {
                Ok(client) => return client,
                Err(e) => last_err = Some(e),
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
        panic!("daemon at {} never became ready: {last_err:?}", sock.display());
    }

    #[tokio::test]
    async fn client_connects_to_running_daemon() {
        let dir = tempfile::tempdir().unwrap();
        let (sock, _handle) = spawn_server(&dir);

        let mut client = connect_when_ready(&sock).await;
        let resp = client.send(&Request::Status).await.unwrap();
        assert!(resp.success);

        let mut client = DaemonClient::connect(&sock).await.unwrap();
        client.send(&Request::DaemonStop).await.unwrap();
    }

    #[tokio::test]
    async fn client_stop_shuts_down_daemon() {
        let dir = tempfile::tempdir().unwrap();
        let (sock, handle) = spawn_server(&dir);

        let mut client = connect_when_ready(&sock).await;
        let resp = client.send(&Request::DaemonStop).await.unwrap();
        assert!(resp.success);

        let _ = handle.await;
    }

    #[test]
    fn command_to_request_list_symbols() {
        use crate::cli::{Command, ListCommand};

        let cmd = Command::List(ListCommand::Symbols {
            path: PathBuf::from("src/lib.rs"),
            depth: 2,
        });
        let req = command_to_request(&cmd);
        assert!(matches!(req, Request::ListSymbols { depth: 2, .. }));
    }

    #[test]
    fn command_to_request_read_file() {
        use crate::cli::{Command, ReadCommand};

        let cmd = Command::Read(ReadCommand::File {
            path: PathBuf::from("main.rs"),
            from: Some(1),
            to: Some(10),
            max_lines: None,
        });
        let req = command_to_request(&cmd);
        assert!(matches!(
            req,
            Request::ReadFile {
                from: Some(1),
                to: Some(10),
                ..
            }
        ));
    }

    #[test]
    fn command_to_request_read_symbol() {
        use crate::cli::{Command, ReadCommand};

        let cmd = Command::Read(ReadCommand::Symbol {
            name: "Config".into(),
            signature_only: true,
            max_lines: Some(20),
            path: None,
            has_body: false,
        });
        let req = command_to_request(&cmd);
        assert!(matches!(
            req,
            Request::ReadSymbol {
                signature_only: true,
                max_lines: Some(20),
                ..
            }
        ));
    }

    #[tokio::test]
    async fn handle_connection_rejects_oversized_frame() {
        let dir = tempfile::tempdir().unwrap();
        let (sock, _handle) = spawn_server(&dir);

        // Use the raw stream so we can send a length prefix that `send` would refuse.
        let mut stream = connect_when_ready(&sock).await.stream;
        // Derive from the protocol limit so the test keeps exercising rejection if
        // the limit ever changes.
        let oversized_len = crate::protocol::MAX_FRAME_SIZE
            .checked_add(1)
            .expect("MAX_FRAME_SIZE must leave room for an oversized frame");
        stream.write_u32(oversized_len).await.unwrap();
        stream.flush().await.unwrap();

        let result = stream.read_u32().await;
        assert!(
            result.is_err(),
            "server should close connection on oversized frame"
        );

        if let Ok(mut client) = DaemonClient::connect(&sock).await {
            let _ = client.send(&Request::DaemonStop).await;
        }
    }

    #[tokio::test]
    async fn client_connect_fails_without_daemon() {
        let dir = tempfile::tempdir().unwrap();
        let sock = dir.path().join("nonexistent.sock");

        let result = DaemonClient::connect(&sock).await;
        assert!(result.is_err());
    }
}
343}