1use std::io::BufRead;
2use std::io::BufReader;
3use std::io::ErrorKind;
4use std::io::Write;
5use std::os::unix::net::UnixStream;
6use std::sync::atomic::AtomicU64;
7use std::sync::atomic::Ordering;
8use std::time::Duration;
9
10use serde::Deserialize;
11use serde::Serialize;
12use serde_json::Value;
13
14use crate::ipc::error::ClientError;
15use crate::ipc::error_codes;
16use crate::ipc::socket::socket_path;
17
/// Process-wide counter that hands out JSON-RPC request ids.
///
/// `fetch_add` returns the previous value, so the first request sent by this process gets id 1.
static REQUEST_ID: AtomicU64 = AtomicU64::new(1_u64);
19
/// Timing knobs used while waiting for the daemon socket to appear or go away.
pub mod polling {
    use std::time::Duration;

    /// How many times startup checks the socket before declaring the daemon failed.
    pub const MAX_STARTUP_POLLS: u32 = 50;

    /// Wait before the first startup check; the wait doubles after each failed check.
    pub const INITIAL_POLL_INTERVAL: Duration = Duration::from_millis(50);

    /// Upper bound for the doubling wait between startup checks.
    pub const MAX_POLL_INTERVAL: Duration = Duration::from_millis(500);

    /// Time budget for daemon shutdown.
    /// NOTE(review): not referenced in this file; presumably consumed by the shutdown path — confirm.
    pub const SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(10_000);
}
33
/// Settings for a single daemon RPC and for the retry wrapper built on top of it.
///
/// Build one with [`DaemonClientConfig::default`] and adjust it with the `with_…` methods.
#[derive(Debug, Clone)]
pub struct DaemonClientConfig {
    /// Socket read timeout applied while waiting for the daemon's reply.
    pub read_timeout: Duration,
    /// Socket write timeout applied while sending the request.
    pub write_timeout: Duration,
    /// Extra attempts allowed after the first one fails with a retriable error
    /// (so up to `max_retries + 1` attempts in total).
    pub max_retries: u32,
    /// Sleep before the first retry; it doubles after every further failed attempt.
    ///
    /// NOTE: `UnixSocketClient::call_with_retry` starts from `DaemonClientConfig::default()`,
    /// so only the default value of this field is used there.
    pub initial_retry_delay: Duration,
}

impl Default for DaemonClientConfig {
    /// 60 s read timeout, 10 s write timeout, 3 retries, 100 ms initial retry delay.
    fn default() -> Self {
        Self {
            read_timeout: Duration::from_secs(60),
            write_timeout: Duration::from_secs(10),
            max_retries: 3,
            initial_retry_delay: Duration::from_millis(100),
        }
    }
}

impl DaemonClientConfig {
    /// Returns this config with `read_timeout` replaced.
    #[must_use]
    pub fn with_read_timeout(mut self, timeout: Duration) -> Self {
        self.read_timeout = timeout;
        self
    }

    /// Returns this config with `write_timeout` replaced.
    #[must_use]
    pub fn with_write_timeout(mut self, timeout: Duration) -> Self {
        self.write_timeout = timeout;
        self
    }

    /// Returns this config with `max_retries` replaced.
    #[must_use]
    pub fn with_max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }

    /// Returns this config with `initial_retry_delay` replaced.
    ///
    /// Completes the builder set: previously this was the only field without a `with_…` method.
    #[must_use]
    pub fn with_initial_retry_delay(mut self, delay: Duration) -> Self {
        self.initial_retry_delay = delay;
        self
    }
}
69
/// Outgoing JSON-RPC 2.0 request; serialized to a single JSON line before being sent.
///
/// Field names and declaration order determine the serialized key names and order,
/// so they are part of the wire format.
#[derive(Debug, Serialize)]
struct Request {
    /// Protocol version string; the client always sets `"2.0"`.
    jsonrpc: String,
    /// Request id taken from the process-wide `REQUEST_ID` counter.
    id: u64,
    /// Name of the daemon method to invoke.
    method: String,
    /// Method parameters; the `params` key is omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<Value>,
}
78
79#[derive(Debug, Deserialize)]
80struct Response {
81 #[allow(dead_code)]
82 jsonrpc: String,
83 #[allow(dead_code)]
84 id: u64,
85 result: Option<Value>,
86 error: Option<RpcError>,
87}
88
/// The `error` member of a JSON-RPC 2.0 response.
#[derive(Debug, Deserialize)]
struct RpcError {
    /// Numeric error code; also fed to `error_codes` helpers when `data` lacks details.
    code: i32,
    /// Error message text as sent by the daemon.
    message: String,
    /// Optional structured details. The client reads the `category`, `retryable`,
    /// `context`, and `suggestion` keys from it when present.
    #[serde(default)]
    data: Option<Value>,
}
96
97pub trait DaemonClient: Send + Sync {
102 fn call(&mut self, method: &str, params: Option<Value>) -> Result<Value, ClientError>;
104
105 fn call_with_config(
107 &mut self,
108 method: &str,
109 params: Option<Value>,
110 config: &DaemonClientConfig,
111 ) -> Result<Value, ClientError>;
112
113 fn call_with_retry(
115 &mut self,
116 method: &str,
117 params: Option<Value>,
118 max_retries: u32,
119 ) -> Result<Value, ClientError>;
120}
121
/// Stateless [`DaemonClient`] that opens a fresh Unix-socket connection for every call.
pub struct UnixSocketClient;
124
125fn is_retriable_error(error: &ClientError) -> bool {
126 match error {
127 ClientError::ConnectionFailed(io_err) => matches!(
128 io_err.kind(),
129 ErrorKind::ConnectionRefused | ErrorKind::WouldBlock | ErrorKind::TimedOut
130 ),
131 ClientError::RpcError { retryable, .. } => *retryable,
132 _ => false,
133 }
134}
135
136impl UnixSocketClient {
137 pub fn connect() -> Result<Self, ClientError> {
138 let path = socket_path();
139 if !path.exists() {
140 return Err(ClientError::DaemonNotRunning);
141 }
142
143 let stream = UnixStream::connect(&path)?;
144 drop(stream);
145
146 Ok(Self)
147 }
148
149 pub fn is_daemon_running() -> bool {
150 let path = socket_path();
151 if !path.exists() {
152 return false;
153 }
154
155 UnixStream::connect(path).is_ok()
156 }
157}
158
159impl DaemonClient for UnixSocketClient {
160 fn call(&mut self, method: &str, params: Option<Value>) -> Result<Value, ClientError> {
161 self.call_with_config(method, params, &DaemonClientConfig::default())
162 }
163
164 fn call_with_config(
165 &mut self,
166 method: &str,
167 params: Option<Value>,
168 config: &DaemonClientConfig,
169 ) -> Result<Value, ClientError> {
170 let path = socket_path();
171 let mut stream = UnixStream::connect(&path)?;
172
173 stream.set_read_timeout(Some(config.read_timeout))?;
174 stream.set_write_timeout(Some(config.write_timeout))?;
175
176 let request = Request {
177 jsonrpc: "2.0".to_string(),
178 id: REQUEST_ID.fetch_add(1, Ordering::SeqCst),
179 method: method.to_string(),
180 params,
181 };
182
183 let request_json = serde_json::to_string(&request)?;
184
185 writeln!(stream, "{}", request_json)?;
186 stream.flush()?;
187
188 let mut reader = BufReader::new(&stream);
189 let mut response_line = String::new();
190 reader.read_line(&mut response_line)?;
191
192 let response: Response = serde_json::from_str(&response_line)?;
193
194 if let Some(error) = response.error {
195 let (category, retryable, context, suggestion) = if let Some(data) = error.data.as_ref()
196 {
197 let cat = data
198 .get("category")
199 .and_then(|v| v.as_str())
200 .and_then(|s| s.parse::<error_codes::ErrorCategory>().ok());
201 let retry = data
202 .get("retryable")
203 .and_then(|v| v.as_bool())
204 .unwrap_or_else(|| error_codes::is_retryable(error.code));
205 let ctx = data.get("context").cloned();
206 let sug = data
207 .get("suggestion")
208 .and_then(|v| v.as_str())
209 .map(String::from);
210 (cat, retry, ctx, sug)
211 } else {
212 (
213 Some(error_codes::category_for_code(error.code)),
214 error_codes::is_retryable(error.code),
215 None,
216 None,
217 )
218 };
219
220 return Err(ClientError::RpcError {
221 code: error.code,
222 message: error.message,
223 category,
224 retryable,
225 context,
226 suggestion,
227 });
228 }
229
230 response.result.ok_or(ClientError::InvalidResponse)
231 }
232
233 fn call_with_retry(
234 &mut self,
235 method: &str,
236 params: Option<Value>,
237 max_retries: u32,
238 ) -> Result<Value, ClientError> {
239 let config = DaemonClientConfig::default().with_max_retries(max_retries);
240 let mut delay = config.initial_retry_delay;
241 let mut last_error = None;
242
243 for attempt in 0..=config.max_retries {
244 let params_clone = params.clone();
245 match self.call_with_config(method, params_clone, &config) {
246 Ok(result) => return Ok(result),
247 Err(e) => {
248 if !is_retriable_error(&e) || attempt == config.max_retries {
249 return Err(e);
250 }
251 last_error = Some(e);
252 std::thread::sleep(delay);
253 delay *= 2; }
255 }
256 }
257
258 Err(last_error.unwrap_or(ClientError::DaemonNotRunning))
259 }
260}
261
262pub fn start_daemon_background() -> Result<(), ClientError> {
263 use std::fs::OpenOptions;
264 use std::process::Command;
265 use std::process::Stdio;
266
267 let exe = std::env::current_exe()?;
268 let log_path = socket_path().with_extension("log");
269
270 let log_file = match OpenOptions::new().create(true).append(true).open(&log_path) {
271 Ok(f) => Some(f),
272 Err(e) => {
273 eprintln!(
274 "Warning: Could not open daemon log file {}: {}",
275 log_path.display(),
276 e
277 );
278 None
279 }
280 };
281
282 let stderr = match log_file {
283 Some(f) => Stdio::from(f),
284 None => Stdio::null(),
285 };
286
287 Command::new(exe)
288 .args(["daemon", "start", "--foreground"])
289 .stdin(Stdio::null())
290 .stdout(Stdio::null())
291 .stderr(stderr)
292 .spawn()?;
293
294 let mut delay = polling::INITIAL_POLL_INTERVAL;
295 for i in 0..polling::MAX_STARTUP_POLLS {
296 std::thread::sleep(delay);
297 if UnixSocketClient::is_daemon_running() {
298 return Ok(());
299 }
300 delay = (delay * 2).min(polling::MAX_POLL_INTERVAL);
302
303 if i == polling::MAX_STARTUP_POLLS - 1 {
304 if let Ok(log_content) = std::fs::read_to_string(&log_path) {
305 let last_lines: String = log_content
306 .lines()
307 .rev()
308 .take(5)
309 .collect::<Vec<_>>()
310 .join("\n");
311 if !last_lines.is_empty() {
312 eprintln!("Daemon failed to start. Recent log output:\n{}", last_lines);
313 }
314 }
315 }
316 }
317
318 Err(ClientError::DaemonNotRunning)
319}
320
321pub fn ensure_daemon() -> Result<UnixSocketClient, ClientError> {
322 if !UnixSocketClient::is_daemon_running() {
323 start_daemon_background()?;
324 }
325
326 UnixSocketClient::connect()
327}
328
/// Outcome of reading the daemon's PID from its lock file.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PidLookupResult {
    /// The lock file held this PID.
    Found(u32),
    /// No lock file exists.
    NotRunning,
    /// The lock file exists but could not be read or parsed; the string explains why.
    Error(String),
}
339
340pub fn get_daemon_pid() -> PidLookupResult {
342 let lock_path = socket_path().with_extension("lock");
343 if !lock_path.exists() {
344 return PidLookupResult::NotRunning;
345 }
346
347 match std::fs::read_to_string(&lock_path) {
348 Err(e) => PidLookupResult::Error(format!(
349 "Failed to read lock file {}: {}",
350 lock_path.display(),
351 e
352 )),
353 Ok(content) => match content.trim().parse::<u32>() {
354 Ok(pid) => PidLookupResult::Found(pid),
355 Err(e) => PidLookupResult::Error(format!(
356 "Lock file contains invalid PID '{}': {}",
357 content.trim(),
358 e
359 )),
360 },
361 }
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps an I/O error of `kind` in `ClientError::ConnectionFailed`.
    fn connection_failure(kind: ErrorKind, text: &str) -> ClientError {
        ClientError::ConnectionFailed(std::io::Error::new(kind, text))
    }

    /// Builds a `ClientError::RpcError` with no context or suggestion.
    fn rpc_failure(
        code: i32,
        message: &str,
        category: Option<error_codes::ErrorCategory>,
        retryable: bool,
    ) -> ClientError {
        ClientError::RpcError {
            code,
            message: message.to_string(),
            category,
            retryable,
            context: None,
            suggestion: None,
        }
    }

    #[test]
    fn test_request_serializes_to_jsonrpc_2_0() {
        let json = serde_json::to_string(&Request {
            jsonrpc: "2.0".to_string(),
            id: 1,
            method: "health".to_string(),
            params: None,
        })
        .unwrap();
        for fragment in ["\"jsonrpc\":\"2.0\"", "\"id\":1", "\"method\":\"health\""] {
            assert!(json.contains(fragment), "expected {} in {}", fragment, json);
        }
        assert!(!json.contains("\"params\""));
    }

    #[test]
    fn test_request_serializes_with_params() {
        let json = serde_json::to_string(&Request {
            jsonrpc: "2.0".to_string(),
            id: 42,
            method: "spawn".to_string(),
            params: Some(serde_json::json!({"command": "bash", "cols": 80})),
        })
        .unwrap();
        for fragment in ["\"params\"", "\"command\":\"bash\""] {
            assert!(json.contains(fragment), "expected {} in {}", fragment, json);
        }
    }

    #[test]
    fn test_response_deserializes_success_result() {
        let response: Response =
            serde_json::from_str(r#"{"jsonrpc":"2.0","id":1,"result":{"status":"ok"}}"#)
                .unwrap();
        assert!(matches!(
            response,
            Response {
                result: Some(_),
                error: None,
                ..
            }
        ));
    }

    #[test]
    fn test_response_deserializes_error() {
        let raw =
            r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"Invalid Request"}}"#;
        let response: Response = serde_json::from_str(raw).unwrap();
        assert!(response.result.is_none());
        assert_eq!(response.error.map(|e| e.code), Some(-32600));
    }

    #[test]
    fn test_client_error_daemon_not_running_display() {
        assert_eq!(ClientError::DaemonNotRunning.to_string(), "Daemon not running");
    }

    #[test]
    fn test_client_error_invalid_response_display() {
        assert_eq!(
            ClientError::InvalidResponse.to_string(),
            "Invalid response from daemon"
        );
    }

    #[test]
    fn test_client_error_rpc_error_display() {
        let err = rpc_failure(-32601, "Method not found", None, false);
        assert_eq!(err.to_string(), "RPC error (-32601): Method not found");
    }

    #[test]
    fn test_config_default_values() {
        let DaemonClientConfig {
            read_timeout,
            write_timeout,
            max_retries,
            initial_retry_delay,
        } = DaemonClientConfig::default();
        assert_eq!(read_timeout, Duration::from_secs(60));
        assert_eq!(write_timeout, Duration::from_secs(10));
        assert_eq!(max_retries, 3);
        assert_eq!(initial_retry_delay, Duration::from_millis(100));
    }

    #[test]
    fn test_config_builder_pattern() {
        let built = DaemonClientConfig::default()
            .with_read_timeout(Duration::from_secs(30))
            .with_write_timeout(Duration::from_secs(5))
            .with_max_retries(5);
        assert_eq!(
            (built.read_timeout, built.write_timeout, built.max_retries),
            (Duration::from_secs(30), Duration::from_secs(5), 5)
        );
    }

    #[test]
    fn test_is_retriable_error_connection_refused() {
        let err = connection_failure(ErrorKind::ConnectionRefused, "connection refused");
        assert!(is_retriable_error(&err));
    }

    #[test]
    fn test_is_retriable_error_would_block() {
        let err = connection_failure(ErrorKind::WouldBlock, "would block");
        assert!(is_retriable_error(&err));
    }

    #[test]
    fn test_is_retriable_error_timed_out() {
        let err = connection_failure(ErrorKind::TimedOut, "timed out");
        assert!(is_retriable_error(&err));
    }

    #[test]
    fn test_is_retriable_error_rpc_error_not_retriable() {
        let err = rpc_failure(-32600, "Invalid request", None, false);
        assert!(!is_retriable_error(&err));
    }

    #[test]
    fn test_is_retriable_error_rpc_lock_timeout() {
        let err = rpc_failure(
            error_codes::LOCK_TIMEOUT,
            "Lock timeout",
            Some(error_codes::ErrorCategory::Busy),
            true,
        );
        assert!(is_retriable_error(&err));
    }

    #[test]
    fn test_is_retriable_error_daemon_not_running() {
        assert!(!is_retriable_error(&ClientError::DaemonNotRunning));
    }
}