lore_cli/daemon/
server.rs1use anyhow::{Context, Result};
8use serde::{Deserialize, Serialize};
9use std::path::Path;
10use std::sync::Arc;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
12use tokio::net::{UnixListener, UnixStream};
13use tokio::sync::{oneshot, RwLock};
14
15use super::state::DaemonStats;
16use crate::storage::Database;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20#[serde(tag = "command", rename_all = "snake_case")]
21pub enum DaemonCommand {
22 Status,
24 Stop,
26 Stats,
28 Ping,
30 GetCurrentSession {
32 working_directory: String,
34 },
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39#[serde(tag = "type", rename_all = "snake_case")]
40pub enum DaemonResponse {
41 Status {
43 running: bool,
44 pid: u32,
45 uptime_seconds: u64,
46 },
47 Stopping,
49 Stats(DaemonStats),
51 Pong,
53 CurrentSession {
55 session_id: Option<String>,
57 },
58 Error { message: String },
60}
61
62pub async fn run_server(
78 socket_path: &Path,
79 stats: Arc<RwLock<DaemonStats>>,
80 shutdown_tx: Option<oneshot::Sender<()>>,
81 mut shutdown_rx: tokio::sync::broadcast::Receiver<()>,
82) -> Result<()> {
83 if socket_path.exists() {
85 std::fs::remove_file(socket_path).context("Failed to remove existing socket file")?;
86 }
87
88 let listener = UnixListener::bind(socket_path).context("Failed to bind Unix socket")?;
89
90 tracing::info!("IPC server listening on {:?}", socket_path);
91
92 let shutdown_tx = Arc::new(std::sync::Mutex::new(shutdown_tx));
94
95 loop {
96 tokio::select! {
97 accept_result = listener.accept() => {
98 match accept_result {
99 Ok((stream, _addr)) => {
100 let stats_clone = stats.clone();
101 let shutdown_tx_clone = shutdown_tx.clone();
102 tokio::spawn(async move {
103 if let Err(e) = handle_connection(stream, stats_clone, shutdown_tx_clone).await {
104 tracing::warn!("Error handling IPC connection: {}", e);
105 }
106 });
107 }
108 Err(e) => {
109 tracing::warn!("Failed to accept connection: {}", e);
110 }
111 }
112 }
113 _ = shutdown_rx.recv() => {
114 tracing::info!("IPC server shutting down");
115 break;
116 }
117 }
118 }
119
120 Ok(())
121}
122
123async fn handle_connection(
125 stream: UnixStream,
126 stats: Arc<RwLock<DaemonStats>>,
127 shutdown_tx: Arc<std::sync::Mutex<Option<oneshot::Sender<()>>>>,
128) -> Result<()> {
129 let (reader, mut writer) = stream.into_split();
130 let mut reader = BufReader::new(reader);
131 let mut line = String::new();
132
133 reader
135 .read_line(&mut line)
136 .await
137 .context("Failed to read from socket")?;
138
139 let command: DaemonCommand =
140 serde_json::from_str(line.trim()).context("Failed to parse command")?;
141
142 tracing::debug!("Received IPC command: {:?}", command);
143
144 let response = match command {
145 DaemonCommand::Status => {
146 let stats_guard = stats.read().await;
147 let uptime = chrono::Utc::now()
148 .signed_duration_since(stats_guard.started_at)
149 .num_seconds() as u64;
150 DaemonResponse::Status {
151 running: true,
152 pid: std::process::id(),
153 uptime_seconds: uptime,
154 }
155 }
156 DaemonCommand::Stop => {
157 let mut guard = shutdown_tx
160 .lock()
161 .unwrap_or_else(|poisoned| poisoned.into_inner());
162 if let Some(tx) = guard.take() {
163 let _ = tx.send(());
164 }
165 DaemonResponse::Stopping
166 }
167 DaemonCommand::Stats => {
168 let stats_guard = stats.read().await;
169 DaemonResponse::Stats(stats_guard.clone())
170 }
171 DaemonCommand::Ping => DaemonResponse::Pong,
172 DaemonCommand::GetCurrentSession { working_directory } => {
173 match get_current_session_for_directory(&working_directory) {
175 Ok(session_id) => DaemonResponse::CurrentSession { session_id },
176 Err(e) => DaemonResponse::Error {
177 message: format!("Failed to get current session: {e}"),
178 },
179 }
180 }
181 };
182
183 let response_json = serde_json::to_string(&response).context("Failed to serialize response")?;
184
185 writer
186 .write_all(response_json.as_bytes())
187 .await
188 .context("Failed to write response")?;
189 writer
190 .write_all(b"\n")
191 .await
192 .context("Failed to write newline")?;
193 writer.flush().await.context("Failed to flush writer")?;
194
195 Ok(())
196}
197
198pub async fn send_command(socket_path: &Path, command: DaemonCommand) -> Result<DaemonResponse> {
212 let stream = UnixStream::connect(socket_path)
213 .await
214 .context("Failed to connect to daemon socket")?;
215
216 let (reader, mut writer) = stream.into_split();
217
218 let command_json = serde_json::to_string(&command).context("Failed to serialize command")?;
220 writer
221 .write_all(command_json.as_bytes())
222 .await
223 .context("Failed to write command")?;
224 writer
225 .write_all(b"\n")
226 .await
227 .context("Failed to write newline")?;
228 writer.flush().await.context("Failed to flush")?;
229
230 let mut reader = BufReader::new(reader);
232 let mut line = String::new();
233 reader
234 .read_line(&mut line)
235 .await
236 .context("Failed to read response")?;
237
238 let response: DaemonResponse =
239 serde_json::from_str(line.trim()).context("Failed to parse response")?;
240
241 Ok(response)
242}
243
244pub fn send_command_sync(socket_path: &Path, command: DaemonCommand) -> Result<DaemonResponse> {
249 let rt = tokio::runtime::Runtime::new().context("Failed to create tokio runtime")?;
250 rt.block_on(send_command(socket_path, command))
251}
252
253fn get_current_session_for_directory(working_dir: &str) -> Result<Option<String>> {
261 let db = Database::open_default()?;
262 let session = db.get_most_recent_session_for_directory(working_dir)?;
263 Ok(session.map(|s| s.id.to_string()))
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Sends `Ping` until the server answers, so tests do not depend on a fixed
    /// start-up delay. Panics if the server is not reachable within roughly a second.
    async fn wait_for_pong(socket_path: &Path) {
        let mut last_error = None;
        for _ in 0..100 {
            match send_command(socket_path, DaemonCommand::Ping).await {
                Ok(DaemonResponse::Pong) => return,
                Ok(other) => panic!("Expected Pong response, got {other:?}"),
                Err(e) => last_error = Some(e),
            }
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        }
        panic!("IPC server never became reachable: {last_error:?}");
    }

    #[test]
    fn test_daemon_command_serialization() {
        let commands = vec![
            DaemonCommand::Status,
            DaemonCommand::Stop,
            DaemonCommand::Stats,
            DaemonCommand::Ping,
            DaemonCommand::GetCurrentSession {
                working_directory: "/tmp/project".to_string(),
            },
        ];

        for cmd in commands {
            let json = serde_json::to_string(&cmd).expect("Failed to serialize");
            let parsed: DaemonCommand = serde_json::from_str(&json).expect("Failed to parse");
            // Compare via re-serialization so the roundtrip check is meaningful without
            // relying on `DaemonCommand: PartialEq`.
            let reserialized = serde_json::to_string(&parsed).expect("Failed to re-serialize");
            assert_eq!(json, reserialized);
        }
    }

    #[test]
    fn test_daemon_command_wire_format() {
        let json = serde_json::to_string(&DaemonCommand::Ping).expect("Failed to serialize");
        assert_eq!(json, r#"{"command":"ping"}"#);

        let json = serde_json::to_string(&DaemonCommand::GetCurrentSession {
            working_directory: "/repo".to_string(),
        })
        .expect("Failed to serialize");
        assert_eq!(
            json,
            r#"{"command":"get_current_session","working_directory":"/repo"}"#
        );

        let parsed: DaemonCommand =
            serde_json::from_str(r#"{"command":"stop"}"#).expect("Failed to parse");
        assert!(matches!(parsed, DaemonCommand::Stop));
    }

    #[test]
    fn test_daemon_response_status_serialization() {
        let response = DaemonResponse::Status {
            running: true,
            pid: 12345,
            uptime_seconds: 3600,
        };

        let json = serde_json::to_string(&response).expect("Failed to serialize");
        assert!(json.contains("\"type\":\"status\""));
        assert!(json.contains("\"running\":true"));
        assert!(json.contains("\"pid\":12345"));

        let parsed: DaemonResponse = serde_json::from_str(&json).expect("Failed to parse");
        match parsed {
            DaemonResponse::Status {
                running,
                pid,
                uptime_seconds,
            } => {
                assert!(running);
                assert_eq!(pid, 12345);
                assert_eq!(uptime_seconds, 3600);
            }
            _ => panic!("Expected Status response"),
        }
    }

    #[test]
    fn test_daemon_response_stopping_serialization() {
        let response = DaemonResponse::Stopping;
        let json = serde_json::to_string(&response).expect("Failed to serialize");
        assert!(json.contains("\"type\":\"stopping\""));
    }

    #[test]
    fn test_daemon_response_stats_serialization() {
        let stats = DaemonStats::default();
        let response = DaemonResponse::Stats(stats);

        let json = serde_json::to_string(&response).expect("Failed to serialize");
        assert!(json.contains("\"type\":\"stats\""));
        assert!(json.contains("\"files_watched\""));
    }

    #[test]
    fn test_daemon_response_error_serialization() {
        let response = DaemonResponse::Error {
            message: "Something went wrong".to_string(),
        };

        let json = serde_json::to_string(&response).expect("Failed to serialize");
        assert!(json.contains("\"type\":\"error\""));
        assert!(json.contains("Something went wrong"));
    }

    #[tokio::test]
    async fn test_server_client_communication() {
        let dir = tempdir().expect("Failed to create temp dir");
        let socket_path = dir.path().join("test.sock");

        let stats = Arc::new(RwLock::new(DaemonStats::default()));
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        let (broadcast_tx, broadcast_rx) = tokio::sync::broadcast::channel(1);

        let server_handle = tokio::spawn({
            let socket_path = socket_path.clone();
            let stats = Arc::clone(&stats);
            async move { run_server(&socket_path, stats, Some(shutdown_tx), broadcast_rx).await }
        });

        // Doubles as the Ping/Pong check.
        wait_for_pong(&socket_path).await;

        let response = send_command(&socket_path, DaemonCommand::Status)
            .await
            .expect("Failed to send command");
        match response {
            DaemonResponse::Status { running, pid, .. } => {
                assert!(running);
                // The server runs inside this test process.
                assert_eq!(pid, std::process::id());
            }
            other => panic!("Expected Status response, got {other:?}"),
        }

        let response = send_command(&socket_path, DaemonCommand::Stop)
            .await
            .expect("Failed to send command");
        assert!(
            matches!(response, DaemonResponse::Stopping),
            "Expected Stopping response, got {response:?}"
        );

        // The first Stop must actually fire the daemon's shutdown channel.
        tokio::time::timeout(tokio::time::Duration::from_secs(1), shutdown_rx)
            .await
            .expect("Timed out waiting for shutdown signal")
            .expect("Shutdown sender dropped without signalling");

        // A repeated Stop is still acknowledged even though the signal was already sent.
        let response = send_command(&socket_path, DaemonCommand::Stop)
            .await
            .expect("Failed to send command");
        assert!(
            matches!(response, DaemonResponse::Stopping),
            "Expected Stopping response, got {response:?}"
        );

        broadcast_tx
            .send(())
            .expect("IPC server is no longer listening for shutdown");
        let result = tokio::time::timeout(tokio::time::Duration::from_secs(1), server_handle)
            .await
            .expect("IPC server did not shut down in time")
            .expect("IPC server task panicked");
        result.expect("IPC server returned an error");
    }
}