lore_cli/daemon/
server.rs

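//! Unix socket IPC server for daemon communication.
//!
//! Provides a simple newline-delimited JSON request/response protocol over
//! Unix domain sockets for communicating with the running daemon: each
//! connection carries exactly one command and receives exactly one response.
//! Supported commands are status, stop, stats, ping, and current-session lookup.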
6
7use anyhow::{Context, Result};
8use serde::{Deserialize, Serialize};
9use std::path::Path;
10use std::sync::Arc;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
12use tokio::net::{UnixListener, UnixStream};
13use tokio::sync::{oneshot, RwLock};
14
15use super::state::DaemonStats;
16use crate::storage::Database;
17
18/// Commands that can be sent to the daemon via IPC.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20#[serde(tag = "command", rename_all = "snake_case")]
21pub enum DaemonCommand {
22    /// Request the daemon's current status.
23    Status,
24    /// Request the daemon to shut down gracefully.
25    Stop,
26    /// Request runtime statistics from the daemon.
27    Stats,
28    /// Ping to check if daemon is responsive.
29    Ping,
30    /// Request the current active session for a working directory.
31    GetCurrentSession {
32        /// The working directory to check for an active session.
33        working_directory: String,
34    },
35}
36
37/// Responses from the daemon to IPC commands.
38#[derive(Debug, Clone, Serialize, Deserialize)]
39#[serde(tag = "type", rename_all = "snake_case")]
40pub enum DaemonResponse {
41    /// Status response indicating daemon is running.
42    Status {
43        running: bool,
44        pid: u32,
45        uptime_seconds: u64,
46        version: String,
47    },
48    /// Acknowledgment that stop command was received.
49    Stopping,
50    /// Runtime statistics.
51    Stats(DaemonStats),
52    /// Ping response.
53    Pong,
54    /// Current session response.
55    CurrentSession {
56        /// The session ID if an active session was found.
57        session_id: Option<String>,
58    },
59    /// Error response.
60    Error { message: String },
61}
62
63/// Runs the IPC server on the given Unix socket path.
64///
65/// The server listens for incoming connections and processes commands
66/// until a shutdown signal is received or the Stop command is sent.
67///
68/// # Arguments
69///
70/// * `socket_path` - Path for the Unix domain socket
71/// * `stats` - Shared statistics that can be read by clients
72/// * `shutdown_tx` - Sender to signal daemon shutdown when Stop is received
73/// * `mut shutdown_rx` - Receiver that signals when to stop the server
74///
75/// # Errors
76///
77/// Returns an error if the socket cannot be created or bound.
78pub async fn run_server(
79    socket_path: &Path,
80    stats: Arc<RwLock<DaemonStats>>,
81    shutdown_tx: Option<oneshot::Sender<()>>,
82    mut shutdown_rx: tokio::sync::broadcast::Receiver<()>,
83) -> Result<()> {
84    // Remove existing socket file if present
85    if socket_path.exists() {
86        std::fs::remove_file(socket_path).context("Failed to remove existing socket file")?;
87    }
88
89    let listener = UnixListener::bind(socket_path).context("Failed to bind Unix socket")?;
90
91    tracing::info!("IPC server listening on {:?}", socket_path);
92
93    // Wrap shutdown_tx in Arc<Mutex> so it can be moved into the handler
94    let shutdown_tx = Arc::new(std::sync::Mutex::new(shutdown_tx));
95
96    loop {
97        tokio::select! {
98            accept_result = listener.accept() => {
99                match accept_result {
100                    Ok((stream, _addr)) => {
101                        let stats_clone = stats.clone();
102                        let shutdown_tx_clone = shutdown_tx.clone();
103                        tokio::spawn(async move {
104                            if let Err(e) = handle_connection(stream, stats_clone, shutdown_tx_clone).await {
105                                tracing::warn!("Error handling IPC connection: {}", e);
106                            }
107                        });
108                    }
109                    Err(e) => {
110                        tracing::warn!("Failed to accept connection: {}", e);
111                    }
112                }
113            }
114            _ = shutdown_rx.recv() => {
115                tracing::info!("IPC server shutting down");
116                break;
117            }
118        }
119    }
120
121    Ok(())
122}
123
124/// Handles a single client connection.
125async fn handle_connection(
126    stream: UnixStream,
127    stats: Arc<RwLock<DaemonStats>>,
128    shutdown_tx: Arc<std::sync::Mutex<Option<oneshot::Sender<()>>>>,
129) -> Result<()> {
130    let (reader, mut writer) = stream.into_split();
131    let mut reader = BufReader::new(reader);
132    let mut line = String::new();
133
134    // Read a single line (one command per connection)
135    reader
136        .read_line(&mut line)
137        .await
138        .context("Failed to read from socket")?;
139
140    let command: DaemonCommand =
141        serde_json::from_str(line.trim()).context("Failed to parse command")?;
142
143    tracing::debug!("Received IPC command: {:?}", command);
144
145    let response = match command {
146        DaemonCommand::Status => {
147            let stats_guard = stats.read().await;
148            let uptime = chrono::Utc::now()
149                .signed_duration_since(stats_guard.started_at)
150                .num_seconds() as u64;
151            DaemonResponse::Status {
152                running: true,
153                pid: std::process::id(),
154                uptime_seconds: uptime,
155                version: env!("CARGO_PKG_VERSION").to_string(),
156            }
157        }
158        DaemonCommand::Stop => {
159            // Signal the daemon to shut down
160            // If the lock is poisoned, we still want to try to shut down
161            let mut guard = shutdown_tx
162                .lock()
163                .unwrap_or_else(|poisoned| poisoned.into_inner());
164            if let Some(tx) = guard.take() {
165                let _ = tx.send(());
166            }
167            DaemonResponse::Stopping
168        }
169        DaemonCommand::Stats => {
170            let stats_guard = stats.read().await;
171            DaemonResponse::Stats(stats_guard.clone())
172        }
173        DaemonCommand::Ping => DaemonResponse::Pong,
174        DaemonCommand::GetCurrentSession { working_directory } => {
175            // Query the database for the most recent session in this directory
176            match get_current_session_for_directory(&working_directory) {
177                Ok(session_id) => DaemonResponse::CurrentSession { session_id },
178                Err(e) => DaemonResponse::Error {
179                    message: format!("Failed to get current session: {e}"),
180                },
181            }
182        }
183    };
184
185    let response_json = serde_json::to_string(&response).context("Failed to serialize response")?;
186
187    writer
188        .write_all(response_json.as_bytes())
189        .await
190        .context("Failed to write response")?;
191    writer
192        .write_all(b"\n")
193        .await
194        .context("Failed to write newline")?;
195    writer.flush().await.context("Failed to flush writer")?;
196
197    Ok(())
198}
199
200/// Sends a command to the daemon and returns the response.
201///
202/// Connects to the Unix socket, sends the command, and reads the response.
203///
204/// # Arguments
205///
206/// * `socket_path` - Path to the daemon's Unix socket
207/// * `command` - The command to send
208///
209/// # Errors
210///
211/// Returns an error if the connection fails, the command cannot be sent,
212/// or the response cannot be read or parsed.
213pub async fn send_command(socket_path: &Path, command: DaemonCommand) -> Result<DaemonResponse> {
214    let stream = UnixStream::connect(socket_path)
215        .await
216        .context("Failed to connect to daemon socket")?;
217
218    let (reader, mut writer) = stream.into_split();
219
220    // Send command
221    let command_json = serde_json::to_string(&command).context("Failed to serialize command")?;
222    writer
223        .write_all(command_json.as_bytes())
224        .await
225        .context("Failed to write command")?;
226    writer
227        .write_all(b"\n")
228        .await
229        .context("Failed to write newline")?;
230    writer.flush().await.context("Failed to flush")?;
231
232    // Read response
233    let mut reader = BufReader::new(reader);
234    let mut line = String::new();
235    reader
236        .read_line(&mut line)
237        .await
238        .context("Failed to read response")?;
239
240    let response: DaemonResponse =
241        serde_json::from_str(line.trim()).context("Failed to parse response")?;
242
243    Ok(response)
244}
245
246/// Synchronous wrapper for sending a command to the daemon.
247///
248/// Creates a temporary tokio runtime to send the command.
249/// Use this from non-async contexts like CLI commands.
250pub fn send_command_sync(socket_path: &Path, command: DaemonCommand) -> Result<DaemonResponse> {
251    let rt = tokio::runtime::Runtime::new().context("Failed to create tokio runtime")?;
252    rt.block_on(send_command(socket_path, command))
253}
254
255/// Gets the current active session for a working directory.
256///
257/// Queries the database for the most recent session whose working directory
258/// matches or contains the given path. This is used by the daemon to respond
259/// to GetCurrentSession IPC requests.
260///
261/// Returns the session ID as a string if found, or None if no matching session exists.
262fn get_current_session_for_directory(working_dir: &str) -> Result<Option<String>> {
263    let db = Database::open_default()?;
264    let session = db.get_most_recent_session_for_directory(working_dir)?;
265    Ok(session.map(|s| s.id.to_string()))
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Sends `Ping` until the server answers `Pong`.
    ///
    /// This replaces sleeping for a fixed interval and hoping the listener has
    /// been bound by then.
    async fn wait_until_listening(socket_path: &Path) {
        for _ in 0..200 {
            if matches!(
                send_command(socket_path, DaemonCommand::Ping).await,
                Ok(DaemonResponse::Pong)
            ) {
                return;
            }
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        }
        panic!("IPC server never started listening on {socket_path:?}");
    }

    #[test]
    fn test_daemon_command_serialization() {
        let cases = [
            (DaemonCommand::Status, "status"),
            (DaemonCommand::Stop, "stop"),
            (DaemonCommand::Stats, "stats"),
            (DaemonCommand::Ping, "ping"),
        ];

        for (cmd, tag) in cases {
            let json = serde_json::to_string(&cmd).expect("Failed to serialize");
            assert!(
                json.contains(&format!("\"command\":\"{tag}\"")),
                "unexpected encoding for {cmd:?}: {json}"
            );

            // Round-trip: re-encoding the parsed value must reproduce the same JSON.
            let parsed: DaemonCommand = serde_json::from_str(&json).expect("Failed to parse");
            let reencoded = serde_json::to_string(&parsed).expect("Failed to re-serialize");
            assert_eq!(reencoded, json);
        }
    }

    #[test]
    fn test_get_current_session_command_serialization() {
        let cmd = DaemonCommand::GetCurrentSession {
            working_directory: "/tmp/project".to_string(),
        };

        let json = serde_json::to_string(&cmd).expect("Failed to serialize");
        assert!(json.contains("\"command\":\"get_current_session\""));
        assert!(json.contains("\"working_directory\":\"/tmp/project\""));

        match serde_json::from_str::<DaemonCommand>(&json).expect("Failed to parse") {
            DaemonCommand::GetCurrentSession { working_directory } => {
                assert_eq!(working_directory, "/tmp/project");
            }
            other => panic!("Expected GetCurrentSession, got {other:?}"),
        }
    }

    #[test]
    fn test_daemon_response_status_serialization() {
        let response = DaemonResponse::Status {
            running: true,
            pid: 12345,
            uptime_seconds: 3600,
            version: "0.1.6".to_string(),
        };

        let json = serde_json::to_string(&response).expect("Failed to serialize");
        assert!(json.contains("\"type\":\"status\""));
        assert!(json.contains("\"running\":true"));
        assert!(json.contains("\"pid\":12345"));
        assert!(json.contains("\"version\":\"0.1.6\""));

        let parsed: DaemonResponse = serde_json::from_str(&json).expect("Failed to parse");
        match parsed {
            DaemonResponse::Status {
                running,
                pid,
                uptime_seconds,
                version,
            } => {
                assert!(running);
                assert_eq!(pid, 12345);
                assert_eq!(uptime_seconds, 3600);
                assert_eq!(version, "0.1.6");
            }
            other => panic!("Expected Status response, got {other:?}"),
        }
    }

    #[test]
    fn test_daemon_response_stopping_serialization() {
        let response = DaemonResponse::Stopping;
        let json = serde_json::to_string(&response).expect("Failed to serialize");
        assert!(json.contains("\"type\":\"stopping\""));
    }

    #[test]
    fn test_daemon_response_stats_serialization() {
        let stats = DaemonStats::default();
        let response = DaemonResponse::Stats(stats);

        let json = serde_json::to_string(&response).expect("Failed to serialize");
        assert!(json.contains("\"type\":\"stats\""));
        assert!(json.contains("\"files_watched\""));
    }

    #[test]
    fn test_daemon_response_current_session_serialization() {
        let found = DaemonResponse::CurrentSession {
            session_id: Some("abc".to_string()),
        };
        let json = serde_json::to_string(&found).expect("Failed to serialize");
        assert!(json.contains("\"type\":\"current_session\""));
        assert!(json.contains("\"session_id\":\"abc\""));

        let missing = DaemonResponse::CurrentSession { session_id: None };
        let json = serde_json::to_string(&missing).expect("Failed to serialize");
        assert!(json.contains("\"session_id\":null"));
    }

    #[test]
    fn test_daemon_response_error_serialization() {
        let response = DaemonResponse::Error {
            message: "Something went wrong".to_string(),
        };

        let json = serde_json::to_string(&response).expect("Failed to serialize");
        assert!(json.contains("\"type\":\"error\""));
        assert!(json.contains("Something went wrong"));
    }

    #[tokio::test]
    async fn test_server_client_communication() {
        let dir = tempdir().expect("Failed to create temp dir");
        let socket_path = dir.path().join("test.sock");

        let stats = Arc::new(RwLock::new(DaemonStats::default()));
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        let (broadcast_tx, broadcast_rx) = tokio::sync::broadcast::channel(1);

        // Start the server in the background.
        let server_handle = {
            let socket_path = socket_path.clone();
            let stats = Arc::clone(&stats);
            tokio::spawn(async move {
                run_server(&socket_path, stats, Some(shutdown_tx), broadcast_rx).await
            })
        };

        // Returns once a Ping is answered with Pong, so this also covers Ping.
        wait_until_listening(&socket_path).await;

        match send_command(&socket_path, DaemonCommand::Status)
            .await
            .expect("Failed to send command")
        {
            DaemonResponse::Status { running, pid, .. } => {
                assert!(running);
                assert_eq!(pid, std::process::id());
            }
            other => panic!("Expected Status response, got {other:?}"),
        }

        match send_command(&socket_path, DaemonCommand::Stats)
            .await
            .expect("Failed to send command")
        {
            DaemonResponse::Stats(_) => {}
            other => panic!("Expected Stats response, got {other:?}"),
        }

        match send_command(&socket_path, DaemonCommand::Stop)
            .await
            .expect("Failed to send command")
        {
            DaemonResponse::Stopping => {}
            other => panic!("Expected Stopping response, got {other:?}"),
        }

        // Stop must have fired the daemon's one-shot shutdown signal.
        tokio::time::timeout(tokio::time::Duration::from_secs(1), shutdown_rx)
            .await
            .expect("Stop did not signal shutdown in time")
            .expect("Shutdown sender was dropped without signalling");

        // The IPC loop itself only exits on the broadcast signal.
        let _ = broadcast_tx.send(());
        tokio::time::timeout(tokio::time::Duration::from_secs(1), server_handle)
            .await
            .expect("Server did not shut down in time")
            .expect("Server task panicked")
            .expect("Server returned an error");
    }
}
411}