lore_cli/daemon/
server.rs

1//! Unix socket IPC server for daemon communication.
2//!
3//! Provides a simple request/response protocol over Unix domain sockets
4//! for communicating with the running daemon. Supports commands like
5//! status, stop, and stats.
6
7use anyhow::{Context, Result};
8use serde::{Deserialize, Serialize};
9use std::path::Path;
10use std::sync::Arc;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
12use tokio::net::{UnixListener, UnixStream};
13use tokio::sync::{oneshot, RwLock};
14
15use super::state::DaemonStats;
16use crate::storage::Database;
17
/// Commands that can be sent to the daemon via IPC.
///
/// Each command is exchanged as a single JSON object. The enum is internally
/// tagged: the variant name is stored in the `command` field in snake_case
/// (for example `{"command":"ping"}`), and any variant fields are siblings of
/// that tag.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "command", rename_all = "snake_case")]
pub enum DaemonCommand {
    /// Request the daemon's current status.
    Status,
    /// Request the daemon to shut down gracefully.
    Stop,
    /// Request runtime statistics from the daemon.
    Stats,
    /// Ping to check if daemon is responsive.
    Ping,
    /// Request the current active session for a working directory.
    GetCurrentSession {
        /// The working directory to check for an active session.
        working_directory: String,
    },
}
36
/// Responses from the daemon to IPC commands.
///
/// Each response is exchanged as a single JSON object. The enum is internally
/// tagged: the variant name is stored in the `type` field in snake_case (for
/// example `{"type":"pong"}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DaemonResponse {
    /// Status response indicating daemon is running.
    Status {
        /// Whether the daemon is running; the IPC handler in this file always
        /// reports `true`.
        running: bool,
        /// Process ID of the daemon answering the request.
        pid: u32,
        /// Seconds elapsed since the start time recorded in the daemon's stats.
        uptime_seconds: u64,
    },
    /// Acknowledgment that stop command was received.
    Stopping,
    /// Runtime statistics.
    Stats(DaemonStats),
    /// Ping response.
    Pong,
    /// Current session response.
    CurrentSession {
        /// The session ID if an active session was found.
        session_id: Option<String>,
    },
    /// Error response; `message` is a human-readable description of what failed.
    Error { message: String },
}
61
62/// Runs the IPC server on the given Unix socket path.
63///
64/// The server listens for incoming connections and processes commands
65/// until a shutdown signal is received or the Stop command is sent.
66///
67/// # Arguments
68///
69/// * `socket_path` - Path for the Unix domain socket
70/// * `stats` - Shared statistics that can be read by clients
71/// * `shutdown_tx` - Sender to signal daemon shutdown when Stop is received
72/// * `mut shutdown_rx` - Receiver that signals when to stop the server
73///
74/// # Errors
75///
76/// Returns an error if the socket cannot be created or bound.
77pub async fn run_server(
78    socket_path: &Path,
79    stats: Arc<RwLock<DaemonStats>>,
80    shutdown_tx: Option<oneshot::Sender<()>>,
81    mut shutdown_rx: tokio::sync::broadcast::Receiver<()>,
82) -> Result<()> {
83    // Remove existing socket file if present
84    if socket_path.exists() {
85        std::fs::remove_file(socket_path).context("Failed to remove existing socket file")?;
86    }
87
88    let listener = UnixListener::bind(socket_path).context("Failed to bind Unix socket")?;
89
90    tracing::info!("IPC server listening on {:?}", socket_path);
91
92    // Wrap shutdown_tx in Arc<Mutex> so it can be moved into the handler
93    let shutdown_tx = Arc::new(std::sync::Mutex::new(shutdown_tx));
94
95    loop {
96        tokio::select! {
97            accept_result = listener.accept() => {
98                match accept_result {
99                    Ok((stream, _addr)) => {
100                        let stats_clone = stats.clone();
101                        let shutdown_tx_clone = shutdown_tx.clone();
102                        tokio::spawn(async move {
103                            if let Err(e) = handle_connection(stream, stats_clone, shutdown_tx_clone).await {
104                                tracing::warn!("Error handling IPC connection: {}", e);
105                            }
106                        });
107                    }
108                    Err(e) => {
109                        tracing::warn!("Failed to accept connection: {}", e);
110                    }
111                }
112            }
113            _ = shutdown_rx.recv() => {
114                tracing::info!("IPC server shutting down");
115                break;
116            }
117        }
118    }
119
120    Ok(())
121}
122
123/// Handles a single client connection.
124async fn handle_connection(
125    stream: UnixStream,
126    stats: Arc<RwLock<DaemonStats>>,
127    shutdown_tx: Arc<std::sync::Mutex<Option<oneshot::Sender<()>>>>,
128) -> Result<()> {
129    let (reader, mut writer) = stream.into_split();
130    let mut reader = BufReader::new(reader);
131    let mut line = String::new();
132
133    // Read a single line (one command per connection)
134    reader
135        .read_line(&mut line)
136        .await
137        .context("Failed to read from socket")?;
138
139    let command: DaemonCommand =
140        serde_json::from_str(line.trim()).context("Failed to parse command")?;
141
142    tracing::debug!("Received IPC command: {:?}", command);
143
144    let response = match command {
145        DaemonCommand::Status => {
146            let stats_guard = stats.read().await;
147            let uptime = chrono::Utc::now()
148                .signed_duration_since(stats_guard.started_at)
149                .num_seconds() as u64;
150            DaemonResponse::Status {
151                running: true,
152                pid: std::process::id(),
153                uptime_seconds: uptime,
154            }
155        }
156        DaemonCommand::Stop => {
157            // Signal the daemon to shut down
158            // If the lock is poisoned, we still want to try to shut down
159            let mut guard = shutdown_tx
160                .lock()
161                .unwrap_or_else(|poisoned| poisoned.into_inner());
162            if let Some(tx) = guard.take() {
163                let _ = tx.send(());
164            }
165            DaemonResponse::Stopping
166        }
167        DaemonCommand::Stats => {
168            let stats_guard = stats.read().await;
169            DaemonResponse::Stats(stats_guard.clone())
170        }
171        DaemonCommand::Ping => DaemonResponse::Pong,
172        DaemonCommand::GetCurrentSession { working_directory } => {
173            // Query the database for the most recent session in this directory
174            match get_current_session_for_directory(&working_directory) {
175                Ok(session_id) => DaemonResponse::CurrentSession { session_id },
176                Err(e) => DaemonResponse::Error {
177                    message: format!("Failed to get current session: {e}"),
178                },
179            }
180        }
181    };
182
183    let response_json = serde_json::to_string(&response).context("Failed to serialize response")?;
184
185    writer
186        .write_all(response_json.as_bytes())
187        .await
188        .context("Failed to write response")?;
189    writer
190        .write_all(b"\n")
191        .await
192        .context("Failed to write newline")?;
193    writer.flush().await.context("Failed to flush writer")?;
194
195    Ok(())
196}
197
198/// Sends a command to the daemon and returns the response.
199///
200/// Connects to the Unix socket, sends the command, and reads the response.
201///
202/// # Arguments
203///
204/// * `socket_path` - Path to the daemon's Unix socket
205/// * `command` - The command to send
206///
207/// # Errors
208///
209/// Returns an error if the connection fails, the command cannot be sent,
210/// or the response cannot be read or parsed.
211pub async fn send_command(socket_path: &Path, command: DaemonCommand) -> Result<DaemonResponse> {
212    let stream = UnixStream::connect(socket_path)
213        .await
214        .context("Failed to connect to daemon socket")?;
215
216    let (reader, mut writer) = stream.into_split();
217
218    // Send command
219    let command_json = serde_json::to_string(&command).context("Failed to serialize command")?;
220    writer
221        .write_all(command_json.as_bytes())
222        .await
223        .context("Failed to write command")?;
224    writer
225        .write_all(b"\n")
226        .await
227        .context("Failed to write newline")?;
228    writer.flush().await.context("Failed to flush")?;
229
230    // Read response
231    let mut reader = BufReader::new(reader);
232    let mut line = String::new();
233    reader
234        .read_line(&mut line)
235        .await
236        .context("Failed to read response")?;
237
238    let response: DaemonResponse =
239        serde_json::from_str(line.trim()).context("Failed to parse response")?;
240
241    Ok(response)
242}
243
244/// Synchronous wrapper for sending a command to the daemon.
245///
246/// Creates a temporary tokio runtime to send the command.
247/// Use this from non-async contexts like CLI commands.
248pub fn send_command_sync(socket_path: &Path, command: DaemonCommand) -> Result<DaemonResponse> {
249    let rt = tokio::runtime::Runtime::new().context("Failed to create tokio runtime")?;
250    rt.block_on(send_command(socket_path, command))
251}
252
253/// Gets the current active session for a working directory.
254///
255/// Queries the database for the most recent session whose working directory
256/// matches or contains the given path. This is used by the daemon to respond
257/// to GetCurrentSession IPC requests.
258///
259/// Returns the session ID as a string if found, or None if no matching session exists.
260fn get_current_session_for_directory(working_dir: &str) -> Result<Option<String>> {
261    let db = Database::open_default()?;
262    let session = db.get_most_recent_session_for_directory(working_dir)?;
263    Ok(session.map(|s| s.id.to_string()))
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Every command, including the one carrying a payload, must survive a
    /// JSON round trip with its variant and fields intact.
    #[test]
    fn test_daemon_command_serialization() {
        let commands = vec![
            DaemonCommand::Status,
            DaemonCommand::Stop,
            DaemonCommand::Stats,
            DaemonCommand::Ping,
            DaemonCommand::GetCurrentSession {
                working_directory: "/tmp/project".to_string(),
            },
        ];

        for cmd in commands {
            let json = serde_json::to_string(&cmd).expect("Failed to serialize");
            let parsed: DaemonCommand = serde_json::from_str(&json).expect("Failed to parse");

            // Same variant after the round trip...
            assert_eq!(
                std::mem::discriminant(&cmd),
                std::mem::discriminant(&parsed)
            );
            // ...and same payload: re-serializing must reproduce the JSON.
            let json_again = serde_json::to_string(&parsed).expect("Failed to re-serialize");
            assert_eq!(json, json_again);
        }
    }

    /// The wire format uses a snake_case `command` tag with sibling fields.
    #[test]
    fn test_daemon_command_wire_format() {
        let json = serde_json::to_string(&DaemonCommand::Ping).expect("Failed to serialize");
        assert_eq!(json, "{\"command\":\"ping\"}");

        let json = serde_json::to_string(&DaemonCommand::GetCurrentSession {
            working_directory: "/tmp/project".to_string(),
        })
        .expect("Failed to serialize");
        assert!(json.contains("\"command\":\"get_current_session\""));
        assert!(json.contains("\"working_directory\":\"/tmp/project\""));
    }

    #[test]
    fn test_daemon_response_status_serialization() {
        let response = DaemonResponse::Status {
            running: true,
            pid: 12345,
            uptime_seconds: 3600,
        };

        let json = serde_json::to_string(&response).expect("Failed to serialize");
        assert!(json.contains("\"type\":\"status\""));
        assert!(json.contains("\"running\":true"));
        assert!(json.contains("\"pid\":12345"));

        let parsed: DaemonResponse = serde_json::from_str(&json).expect("Failed to parse");
        match parsed {
            DaemonResponse::Status {
                running,
                pid,
                uptime_seconds,
            } => {
                assert!(running);
                assert_eq!(pid, 12345);
                assert_eq!(uptime_seconds, 3600);
            }
            _ => panic!("Expected Status response"),
        }
    }

    #[test]
    fn test_daemon_response_stopping_serialization() {
        let response = DaemonResponse::Stopping;
        let json = serde_json::to_string(&response).expect("Failed to serialize");
        assert!(json.contains("\"type\":\"stopping\""));
    }

    #[test]
    fn test_daemon_response_stats_serialization() {
        let stats = DaemonStats::default();
        let response = DaemonResponse::Stats(stats);

        let json = serde_json::to_string(&response).expect("Failed to serialize");
        assert!(json.contains("\"type\":\"stats\""));
        assert!(json.contains("\"files_watched\""));
    }

    #[test]
    fn test_daemon_response_current_session_serialization() {
        let response = DaemonResponse::CurrentSession {
            session_id: Some("abc-123".to_string()),
        };

        let json = serde_json::to_string(&response).expect("Failed to serialize");
        assert!(json.contains("\"type\":\"current_session\""));
        assert!(json.contains("\"session_id\":\"abc-123\""));

        let parsed: DaemonResponse = serde_json::from_str(&json).expect("Failed to parse");
        match parsed {
            DaemonResponse::CurrentSession { session_id } => {
                assert_eq!(session_id.as_deref(), Some("abc-123"));
            }
            _ => panic!("Expected CurrentSession response"),
        }
    }

    #[test]
    fn test_daemon_response_error_serialization() {
        let response = DaemonResponse::Error {
            message: "Something went wrong".to_string(),
        };

        let json = serde_json::to_string(&response).expect("Failed to serialize");
        assert!(json.contains("\"type\":\"error\""));
        assert!(json.contains("Something went wrong"));
    }

    /// End-to-end check: Ping and Status are answered, Stop fires the oneshot
    /// shutdown channel, and the broadcast signal makes the server exit cleanly.
    #[tokio::test]
    async fn test_server_client_communication() {
        use tokio::time::{sleep, timeout, Duration};

        let dir = tempdir().expect("Failed to create temp dir");
        let socket_path = dir.path().join("test.sock");

        let stats = Arc::new(RwLock::new(DaemonStats::default()));
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        let (broadcast_tx, broadcast_rx) = tokio::sync::broadcast::channel(1);

        // Start server in background
        let socket_path_clone = socket_path.clone();
        let stats_clone = stats.clone();
        let server_handle = tokio::spawn(async move {
            run_server(
                &socket_path_clone,
                stats_clone,
                Some(shutdown_tx),
                broadcast_rx,
            )
            .await
        });

        // Retry the first request until the server is accepting connections
        // instead of relying on a fixed startup delay.
        let mut ping = send_command(&socket_path, DaemonCommand::Ping).await;
        let mut retries = 0;
        while ping.is_err() && retries < 100 {
            sleep(Duration::from_millis(10)).await;
            ping = send_command(&socket_path, DaemonCommand::Ping).await;
            retries += 1;
        }
        let response = ping.expect("Failed to send command");
        assert!(
            matches!(response, DaemonResponse::Pong),
            "Expected Pong response, got {response:?}"
        );

        // Send status command
        let response = send_command(&socket_path, DaemonCommand::Status)
            .await
            .expect("Failed to send command");

        match response {
            DaemonResponse::Status { running, pid, .. } => {
                assert!(running);
                // The server runs inside this test process.
                assert_eq!(pid, std::process::id());
            }
            other => panic!("Expected Status response, got {other:?}"),
        }

        // Send stop command
        let response = send_command(&socket_path, DaemonCommand::Stop)
            .await
            .expect("Failed to send command");
        assert!(
            matches!(response, DaemonResponse::Stopping),
            "Expected Stopping response, got {response:?}"
        );

        // Stop must actually signal the daemon's shutdown channel.
        timeout(Duration::from_secs(1), shutdown_rx)
            .await
            .expect("Stop did not signal shutdown in time")
            .expect("Shutdown sender was dropped without sending");

        // Signal broadcast shutdown and verify the server exits cleanly.
        broadcast_tx
            .send(())
            .expect("Server should still be listening for shutdown");
        let server_result = timeout(Duration::from_secs(1), server_handle)
            .await
            .expect("Server did not stop in time")
            .expect("Server task panicked");
        server_result.expect("Server returned an error");
    }
}