agent_tui/ipc/
client.rs

1use std::io::BufRead;
2use std::io::BufReader;
3use std::io::ErrorKind;
4use std::io::Write;
5use std::os::unix::net::UnixStream;
6use std::sync::atomic::AtomicU64;
7use std::sync::atomic::Ordering;
8use std::time::Duration;
9
10use serde::Deserialize;
11use serde::Serialize;
12use serde_json::Value;
13
14use crate::ipc::error::ClientError;
15use crate::ipc::error_codes;
16use crate::ipc::socket::socket_path;
17
/// Process-wide counter that supplies the `id` of each outgoing JSON-RPC request; starts at 1.
static REQUEST_ID: AtomicU64 = AtomicU64::new(1);
19
/// Timing knobs used while waiting for the daemon to come up or go away.
pub mod polling {
    use std::time::Duration;

    /// Delay before the first startup poll; later polls back off from here.
    pub const INITIAL_POLL_INTERVAL: Duration = Duration::from_millis(50);
    /// Ceiling for the delay between polls once backoff has grown it.
    pub const MAX_POLL_INTERVAL: Duration = Duration::from_millis(500);
    /// How many times daemon startup is polled before giving up.
    pub const MAX_STARTUP_POLLS: u32 = 50;
    /// How long to wait for the daemon to shut down.
    pub const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
}
33
/// Per-call transport settings for talking to the daemon.
///
/// Start from [`DaemonClientConfig::default`] and adjust individual values with
/// the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct DaemonClientConfig {
    /// Read timeout applied to the socket while waiting for the response.
    pub read_timeout: Duration,
    /// Write timeout applied to the socket while sending the request.
    pub write_timeout: Duration,
    /// Number of additional attempts allowed after the first one fails.
    pub max_retries: u32,
    /// Pause before the first retry; later retries back off from this value.
    pub initial_retry_delay: Duration,
}

impl Default for DaemonClientConfig {
    /// 60 s read timeout, 10 s write timeout, 3 retries, 100 ms initial retry delay.
    fn default() -> Self {
        Self {
            read_timeout: Duration::from_secs(60),
            write_timeout: Duration::from_secs(10),
            max_retries: 3,
            initial_retry_delay: Duration::from_millis(100),
        }
    }
}

impl DaemonClientConfig {
    /// Returns the configuration with `read_timeout` replaced.
    #[must_use]
    pub fn with_read_timeout(mut self, timeout: Duration) -> Self {
        self.read_timeout = timeout;
        self
    }

    /// Returns the configuration with `write_timeout` replaced.
    #[must_use]
    pub fn with_write_timeout(mut self, timeout: Duration) -> Self {
        self.write_timeout = timeout;
        self
    }

    /// Returns the configuration with `max_retries` replaced.
    #[must_use]
    pub fn with_max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }

    /// Returns the configuration with `initial_retry_delay` replaced.
    ///
    /// Completes the builder so every field can be set fluently.
    #[must_use]
    pub fn with_initial_retry_delay(mut self, delay: Duration) -> Self {
        self.initial_retry_delay = delay;
        self
    }
}
69
/// Outgoing JSON-RPC request, serialized as a single JSON object.
#[derive(Debug, Serialize)]
struct Request {
    /// Protocol version tag; this file always sends `"2.0"`.
    jsonrpc: String,
    /// Request identifier.
    id: u64,
    /// Name of the remote method to invoke.
    method: String,
    /// Method arguments; the key is omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<Value>,
}
78
79#[derive(Debug, Deserialize)]
80struct Response {
81    #[allow(dead_code)]
82    jsonrpc: String,
83    #[allow(dead_code)]
84    id: u64,
85    result: Option<Value>,
86    error: Option<RpcError>,
87}
88
/// The `error` member of a JSON-RPC response.
#[derive(Debug, Deserialize)]
struct RpcError {
    /// Numeric error code reported by the daemon.
    code: i32,
    /// Human-readable error description.
    message: String,
    /// Optional structured details; may be absent from the JSON. The client
    /// looks for the keys `category`, `retryable`, `context` and `suggestion`.
    #[serde(default)]
    data: Option<Value>,
}
96
/// Trait for daemon client implementations.
///
/// This trait abstracts the communication with the daemon, allowing for
/// different transport implementations (Unix socket, mock for testing, etc.).
/// Implementors must be `Send + Sync`.
pub trait DaemonClient: Send + Sync {
    /// Make an RPC call to the daemon.
    ///
    /// `method` is the remote method name and `params` its optional JSON
    /// arguments. Returns the JSON result on success.
    fn call(&mut self, method: &str, params: Option<Value>) -> Result<Value, ClientError>;

    /// Make an RPC call with custom configuration.
    ///
    /// Same as [`DaemonClient::call`], but the caller supplies the
    /// [`DaemonClientConfig`] to use for this call.
    fn call_with_config(
        &mut self,
        method: &str,
        params: Option<Value>,
        config: &DaemonClientConfig,
    ) -> Result<Value, ClientError>;

    /// Make an RPC call with retry logic.
    ///
    /// The Unix-socket implementation in this file makes at most
    /// `max_retries + 1` attempts and only repeats after errors it classifies
    /// as retriable.
    fn call_with_retry(
        &mut self,
        method: &str,
        params: Option<Value>,
        max_retries: u32,
    ) -> Result<Value, ClientError>;
}
121
/// Unix socket-based daemon client implementation.
///
/// Holds no state: `connect` only probes the socket, and every RPC call opens
/// its own `UnixStream`.
pub struct UnixSocketClient;
124
125fn is_retriable_error(error: &ClientError) -> bool {
126    match error {
127        ClientError::ConnectionFailed(io_err) => matches!(
128            io_err.kind(),
129            ErrorKind::ConnectionRefused | ErrorKind::WouldBlock | ErrorKind::TimedOut
130        ),
131        ClientError::RpcError { retryable, .. } => *retryable,
132        _ => false,
133    }
134}
135
136impl UnixSocketClient {
137    pub fn connect() -> Result<Self, ClientError> {
138        let path = socket_path();
139        if !path.exists() {
140            return Err(ClientError::DaemonNotRunning);
141        }
142
143        let stream = UnixStream::connect(&path)?;
144        drop(stream);
145
146        Ok(Self)
147    }
148
149    pub fn is_daemon_running() -> bool {
150        let path = socket_path();
151        if !path.exists() {
152            return false;
153        }
154
155        UnixStream::connect(path).is_ok()
156    }
157}
158
159impl DaemonClient for UnixSocketClient {
160    fn call(&mut self, method: &str, params: Option<Value>) -> Result<Value, ClientError> {
161        self.call_with_config(method, params, &DaemonClientConfig::default())
162    }
163
164    fn call_with_config(
165        &mut self,
166        method: &str,
167        params: Option<Value>,
168        config: &DaemonClientConfig,
169    ) -> Result<Value, ClientError> {
170        let path = socket_path();
171        let mut stream = UnixStream::connect(&path)?;
172
173        stream.set_read_timeout(Some(config.read_timeout))?;
174        stream.set_write_timeout(Some(config.write_timeout))?;
175
176        let request = Request {
177            jsonrpc: "2.0".to_string(),
178            id: REQUEST_ID.fetch_add(1, Ordering::SeqCst),
179            method: method.to_string(),
180            params,
181        };
182
183        let request_json = serde_json::to_string(&request)?;
184
185        writeln!(stream, "{}", request_json)?;
186        stream.flush()?;
187
188        let mut reader = BufReader::new(&stream);
189        let mut response_line = String::new();
190        reader.read_line(&mut response_line)?;
191
192        let response: Response = serde_json::from_str(&response_line)?;
193
194        if let Some(error) = response.error {
195            let (category, retryable, context, suggestion) = if let Some(data) = error.data.as_ref()
196            {
197                let cat = data
198                    .get("category")
199                    .and_then(|v| v.as_str())
200                    .and_then(|s| s.parse::<error_codes::ErrorCategory>().ok());
201                let retry = data
202                    .get("retryable")
203                    .and_then(|v| v.as_bool())
204                    .unwrap_or_else(|| error_codes::is_retryable(error.code));
205                let ctx = data.get("context").cloned();
206                let sug = data
207                    .get("suggestion")
208                    .and_then(|v| v.as_str())
209                    .map(String::from);
210                (cat, retry, ctx, sug)
211            } else {
212                (
213                    Some(error_codes::category_for_code(error.code)),
214                    error_codes::is_retryable(error.code),
215                    None,
216                    None,
217                )
218            };
219
220            return Err(ClientError::RpcError {
221                code: error.code,
222                message: error.message,
223                category,
224                retryable,
225                context,
226                suggestion,
227            });
228        }
229
230        response.result.ok_or(ClientError::InvalidResponse)
231    }
232
233    fn call_with_retry(
234        &mut self,
235        method: &str,
236        params: Option<Value>,
237        max_retries: u32,
238    ) -> Result<Value, ClientError> {
239        let config = DaemonClientConfig::default().with_max_retries(max_retries);
240        let mut delay = config.initial_retry_delay;
241        let mut last_error = None;
242
243        for attempt in 0..=config.max_retries {
244            let params_clone = params.clone();
245            match self.call_with_config(method, params_clone, &config) {
246                Ok(result) => return Ok(result),
247                Err(e) => {
248                    if !is_retriable_error(&e) || attempt == config.max_retries {
249                        return Err(e);
250                    }
251                    last_error = Some(e);
252                    std::thread::sleep(delay);
253                    delay *= 2; // exponential backoff: 100ms, 200ms, 400ms
254                }
255            }
256        }
257
258        Err(last_error.unwrap_or(ClientError::DaemonNotRunning))
259    }
260}
261
262pub fn start_daemon_background() -> Result<(), ClientError> {
263    use std::fs::OpenOptions;
264    use std::process::Command;
265    use std::process::Stdio;
266
267    let exe = std::env::current_exe()?;
268    let log_path = socket_path().with_extension("log");
269
270    let log_file = match OpenOptions::new().create(true).append(true).open(&log_path) {
271        Ok(f) => Some(f),
272        Err(e) => {
273            eprintln!(
274                "Warning: Could not open daemon log file {}: {}",
275                log_path.display(),
276                e
277            );
278            None
279        }
280    };
281
282    let stderr = match log_file {
283        Some(f) => Stdio::from(f),
284        None => Stdio::null(),
285    };
286
287    Command::new(exe)
288        .args(["daemon", "start", "--foreground"])
289        .stdin(Stdio::null())
290        .stdout(Stdio::null())
291        .stderr(stderr)
292        .spawn()?;
293
294    let mut delay = polling::INITIAL_POLL_INTERVAL;
295    for i in 0..polling::MAX_STARTUP_POLLS {
296        std::thread::sleep(delay);
297        if UnixSocketClient::is_daemon_running() {
298            return Ok(());
299        }
300        // Exponential backoff with cap
301        delay = (delay * 2).min(polling::MAX_POLL_INTERVAL);
302
303        if i == polling::MAX_STARTUP_POLLS - 1 {
304            if let Ok(log_content) = std::fs::read_to_string(&log_path) {
305                let last_lines: String = log_content
306                    .lines()
307                    .rev()
308                    .take(5)
309                    .collect::<Vec<_>>()
310                    .join("\n");
311                if !last_lines.is_empty() {
312                    eprintln!("Daemon failed to start. Recent log output:\n{}", last_lines);
313                }
314            }
315        }
316    }
317
318    Err(ClientError::DaemonNotRunning)
319}
320
321pub fn ensure_daemon() -> Result<UnixSocketClient, ClientError> {
322    if !UnixSocketClient::is_daemon_running() {
323        start_daemon_background()?;
324    }
325
326    UnixSocketClient::connect()
327}
328
/// Result of PID lookup from lock file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PidLookupResult {
    /// The lock file contained this PID. The lookup only parses the file; it
    /// does not check that a process with this PID is still alive.
    Found(u32),
    /// No lock file exists (daemon not running).
    NotRunning,
    /// Lock file exists but could not be read or parsed; the string is a
    /// human-readable description of the failure.
    Error(String),
}
339
340/// Get the daemon PID from the lock file.
341pub fn get_daemon_pid() -> PidLookupResult {
342    let lock_path = socket_path().with_extension("lock");
343    if !lock_path.exists() {
344        return PidLookupResult::NotRunning;
345    }
346
347    match std::fs::read_to_string(&lock_path) {
348        Err(e) => PidLookupResult::Error(format!(
349            "Failed to read lock file {}: {}",
350            lock_path.display(),
351            e
352        )),
353        Ok(content) => match content.trim().parse::<u32>() {
354            Ok(pid) => PidLookupResult::Found(pid),
355            Err(e) => PidLookupResult::Error(format!(
356                "Lock file contains invalid PID '{}': {}",
357                content.trim(),
358                e
359            )),
360        },
361    }
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes a version-"2.0" request with the given id, method and params.
    fn encode(id: u64, method: &str, params: Option<Value>) -> String {
        let request = Request {
            jsonrpc: "2.0".to_string(),
            id,
            method: method.to_string(),
            params,
        };
        serde_json::to_string(&request).unwrap()
    }

    /// Builds a `ConnectionFailed` error around an I/O error of the given kind.
    fn connection_failed(kind: ErrorKind, text: &str) -> ClientError {
        ClientError::ConnectionFailed(std::io::Error::new(kind, text))
    }

    /// Builds an `RpcError` without context or suggestion.
    fn rpc_error(
        code: i32,
        message: &str,
        category: Option<error_codes::ErrorCategory>,
        retryable: bool,
    ) -> ClientError {
        ClientError::RpcError {
            code,
            message: message.to_string(),
            category,
            retryable,
            context: None,
            suggestion: None,
        }
    }

    #[test]
    fn test_request_serializes_to_jsonrpc_2_0() {
        let json = encode(1, "health", None);
        for fragment in ["\"jsonrpc\":\"2.0\"", "\"id\":1", "\"method\":\"health\""] {
            assert!(json.contains(fragment));
        }
        // `params` must be omitted entirely when there are none.
        assert!(!json.contains("\"params\""));
    }

    #[test]
    fn test_request_serializes_with_params() {
        let params = serde_json::json!({"command": "bash", "cols": 80});
        let json = encode(42, "spawn", Some(params));
        assert!(json.contains("\"params\""));
        assert!(json.contains("\"command\":\"bash\""));
    }

    #[test]
    fn test_response_deserializes_success_result() {
        let raw = r#"{"jsonrpc":"2.0","id":1,"result":{"status":"ok"}}"#;
        let parsed: Response = serde_json::from_str(raw).unwrap();
        assert!(parsed.result.is_some());
        assert!(parsed.error.is_none());
    }

    #[test]
    fn test_response_deserializes_error() {
        let raw =
            r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"Invalid Request"}}"#;
        let parsed: Response = serde_json::from_str(raw).unwrap();
        assert!(parsed.result.is_none());
        assert!(parsed.error.is_some());
        assert_eq!(parsed.error.map(|e| e.code), Some(-32600));
    }

    #[test]
    fn test_client_error_daemon_not_running_display() {
        assert_eq!(
            ClientError::DaemonNotRunning.to_string(),
            "Daemon not running"
        );
    }

    #[test]
    fn test_client_error_invalid_response_display() {
        assert_eq!(
            ClientError::InvalidResponse.to_string(),
            "Invalid response from daemon"
        );
    }

    #[test]
    fn test_client_error_rpc_error_display() {
        let err = rpc_error(-32601, "Method not found", None, false);
        assert_eq!(err.to_string(), "RPC error (-32601): Method not found");
    }

    #[test]
    fn test_config_default_values() {
        let defaults = DaemonClientConfig::default();
        assert_eq!(defaults.read_timeout, Duration::from_secs(60));
        assert_eq!(defaults.write_timeout, Duration::from_secs(10));
        assert_eq!(defaults.max_retries, 3);
        assert_eq!(defaults.initial_retry_delay, Duration::from_millis(100));
    }

    #[test]
    fn test_config_builder_pattern() {
        let read = Duration::from_secs(30);
        let write = Duration::from_secs(5);
        let built = DaemonClientConfig::default()
            .with_read_timeout(read)
            .with_write_timeout(write)
            .with_max_retries(5);
        assert_eq!(built.read_timeout, read);
        assert_eq!(built.write_timeout, write);
        assert_eq!(built.max_retries, 5);
    }

    #[test]
    fn test_is_retriable_error_connection_refused() {
        let err = connection_failed(ErrorKind::ConnectionRefused, "connection refused");
        assert!(is_retriable_error(&err));
    }

    #[test]
    fn test_is_retriable_error_would_block() {
        let err = connection_failed(ErrorKind::WouldBlock, "would block");
        assert!(is_retriable_error(&err));
    }

    #[test]
    fn test_is_retriable_error_timed_out() {
        let err = connection_failed(ErrorKind::TimedOut, "timed out");
        assert!(is_retriable_error(&err));
    }

    #[test]
    fn test_is_retriable_error_rpc_error_not_retriable() {
        let err = rpc_error(-32600, "Invalid request", None, false);
        assert!(!is_retriable_error(&err));
    }

    #[test]
    fn test_is_retriable_error_rpc_lock_timeout() {
        let err = rpc_error(
            error_codes::LOCK_TIMEOUT,
            "Lock timeout",
            Some(error_codes::ErrorCategory::Busy),
            true,
        );
        assert!(is_retriable_error(&err));
    }

    #[test]
    fn test_is_retriable_error_daemon_not_running() {
        assert!(!is_retriable_error(&ClientError::DaemonNotRunning));
    }
}