1use anyhow::{Context, Result};
8use serde::{Deserialize, Serialize};
9use std::path::Path;
10use std::sync::Arc;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
12use tokio::net::{UnixListener, UnixStream};
13use tokio::sync::{oneshot, RwLock};
14
15use super::state::DaemonStats;
16use crate::storage::Database;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20#[serde(tag = "command", rename_all = "snake_case")]
21pub enum DaemonCommand {
22 Status,
24 Stop,
26 Stats,
28 Ping,
30 GetCurrentSession {
32 working_directory: String,
34 },
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39#[serde(tag = "type", rename_all = "snake_case")]
40pub enum DaemonResponse {
41 Status {
43 running: bool,
44 pid: u32,
45 uptime_seconds: u64,
46 version: String,
47 },
48 Stopping,
50 Stats(DaemonStats),
52 Pong,
54 CurrentSession {
56 session_id: Option<String>,
58 },
59 Error { message: String },
61}
62
63pub async fn run_server(
79 socket_path: &Path,
80 stats: Arc<RwLock<DaemonStats>>,
81 shutdown_tx: Option<oneshot::Sender<()>>,
82 mut shutdown_rx: tokio::sync::broadcast::Receiver<()>,
83) -> Result<()> {
84 if socket_path.exists() {
86 std::fs::remove_file(socket_path).context("Failed to remove existing socket file")?;
87 }
88
89 let listener = UnixListener::bind(socket_path).context("Failed to bind Unix socket")?;
90
91 tracing::info!("IPC server listening on {:?}", socket_path);
92
93 let shutdown_tx = Arc::new(std::sync::Mutex::new(shutdown_tx));
95
96 loop {
97 tokio::select! {
98 accept_result = listener.accept() => {
99 match accept_result {
100 Ok((stream, _addr)) => {
101 let stats_clone = stats.clone();
102 let shutdown_tx_clone = shutdown_tx.clone();
103 tokio::spawn(async move {
104 if let Err(e) = handle_connection(stream, stats_clone, shutdown_tx_clone).await {
105 tracing::warn!("Error handling IPC connection: {}", e);
106 }
107 });
108 }
109 Err(e) => {
110 tracing::warn!("Failed to accept connection: {}", e);
111 }
112 }
113 }
114 _ = shutdown_rx.recv() => {
115 tracing::info!("IPC server shutting down");
116 break;
117 }
118 }
119 }
120
121 Ok(())
122}
123
124async fn handle_connection(
126 stream: UnixStream,
127 stats: Arc<RwLock<DaemonStats>>,
128 shutdown_tx: Arc<std::sync::Mutex<Option<oneshot::Sender<()>>>>,
129) -> Result<()> {
130 let (reader, mut writer) = stream.into_split();
131 let mut reader = BufReader::new(reader);
132 let mut line = String::new();
133
134 reader
136 .read_line(&mut line)
137 .await
138 .context("Failed to read from socket")?;
139
140 let command: DaemonCommand =
141 serde_json::from_str(line.trim()).context("Failed to parse command")?;
142
143 tracing::debug!("Received IPC command: {:?}", command);
144
145 let response = match command {
146 DaemonCommand::Status => {
147 let stats_guard = stats.read().await;
148 let uptime = chrono::Utc::now()
149 .signed_duration_since(stats_guard.started_at)
150 .num_seconds() as u64;
151 DaemonResponse::Status {
152 running: true,
153 pid: std::process::id(),
154 uptime_seconds: uptime,
155 version: env!("CARGO_PKG_VERSION").to_string(),
156 }
157 }
158 DaemonCommand::Stop => {
159 let mut guard = shutdown_tx
162 .lock()
163 .unwrap_or_else(|poisoned| poisoned.into_inner());
164 if let Some(tx) = guard.take() {
165 let _ = tx.send(());
166 }
167 DaemonResponse::Stopping
168 }
169 DaemonCommand::Stats => {
170 let stats_guard = stats.read().await;
171 DaemonResponse::Stats(stats_guard.clone())
172 }
173 DaemonCommand::Ping => DaemonResponse::Pong,
174 DaemonCommand::GetCurrentSession { working_directory } => {
175 match get_current_session_for_directory(&working_directory) {
177 Ok(session_id) => DaemonResponse::CurrentSession { session_id },
178 Err(e) => DaemonResponse::Error {
179 message: format!("Failed to get current session: {e}"),
180 },
181 }
182 }
183 };
184
185 let response_json = serde_json::to_string(&response).context("Failed to serialize response")?;
186
187 writer
188 .write_all(response_json.as_bytes())
189 .await
190 .context("Failed to write response")?;
191 writer
192 .write_all(b"\n")
193 .await
194 .context("Failed to write newline")?;
195 writer.flush().await.context("Failed to flush writer")?;
196
197 Ok(())
198}
199
200pub async fn send_command(socket_path: &Path, command: DaemonCommand) -> Result<DaemonResponse> {
214 let stream = UnixStream::connect(socket_path)
215 .await
216 .context("Failed to connect to daemon socket")?;
217
218 let (reader, mut writer) = stream.into_split();
219
220 let command_json = serde_json::to_string(&command).context("Failed to serialize command")?;
222 writer
223 .write_all(command_json.as_bytes())
224 .await
225 .context("Failed to write command")?;
226 writer
227 .write_all(b"\n")
228 .await
229 .context("Failed to write newline")?;
230 writer.flush().await.context("Failed to flush")?;
231
232 let mut reader = BufReader::new(reader);
234 let mut line = String::new();
235 reader
236 .read_line(&mut line)
237 .await
238 .context("Failed to read response")?;
239
240 let response: DaemonResponse =
241 serde_json::from_str(line.trim()).context("Failed to parse response")?;
242
243 Ok(response)
244}
245
246pub fn send_command_sync(socket_path: &Path, command: DaemonCommand) -> Result<DaemonResponse> {
251 let rt = tokio::runtime::Runtime::new().context("Failed to create tokio runtime")?;
252 rt.block_on(send_command(socket_path, command))
253}
254
255fn get_current_session_for_directory(working_dir: &str) -> Result<Option<String>> {
263 let db = Database::open_default()?;
264 let session = db.get_most_recent_session_for_directory(working_dir)?;
265 Ok(session.map(|s| s.id.to_string()))
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;
    use tokio::time::{timeout, Duration};

    /// Upper bound for every wait in the async tests, so a regression fails the test
    /// instead of hanging it.
    const TEST_TIMEOUT: Duration = Duration::from_secs(5);

    /// Polls the server with `Ping` until it answers, rather than sleeping a fixed
    /// amount and hoping the listener is ready by then.
    async fn wait_for_server(socket_path: &Path) {
        let ready = async {
            loop {
                if let Ok(DaemonResponse::Pong) =
                    send_command(socket_path, DaemonCommand::Ping).await
                {
                    return;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        };
        timeout(TEST_TIMEOUT, ready)
            .await
            .expect("IPC server did not become ready in time");
    }

    #[test]
    fn test_daemon_command_serialization() {
        let cases = [
            (DaemonCommand::Status, r#"{"command":"status"}"#),
            (DaemonCommand::Stop, r#"{"command":"stop"}"#),
            (DaemonCommand::Stats, r#"{"command":"stats"}"#),
            (DaemonCommand::Ping, r#"{"command":"ping"}"#),
            (
                DaemonCommand::GetCurrentSession {
                    working_directory: "/tmp/project".to_string(),
                },
                r#"{"command":"get_current_session","working_directory":"/tmp/project"}"#,
            ),
        ];

        for (cmd, expected_json) in cases {
            let json = serde_json::to_string(&cmd).expect("Failed to serialize");
            assert_eq!(json, expected_json);

            // Round trip: parsing and re-serializing must reproduce the same wire format.
            let parsed: DaemonCommand = serde_json::from_str(&json).expect("Failed to parse");
            let reserialized = serde_json::to_string(&parsed).expect("Failed to re-serialize");
            assert_eq!(reserialized, expected_json);
        }
    }

    #[test]
    fn test_daemon_response_status_serialization() {
        let response = DaemonResponse::Status {
            running: true,
            pid: 12345,
            uptime_seconds: 3600,
            version: "0.1.6".to_string(),
        };

        let json = serde_json::to_string(&response).expect("Failed to serialize");
        assert!(json.contains("\"type\":\"status\""));
        assert!(json.contains("\"running\":true"));
        assert!(json.contains("\"pid\":12345"));
        assert!(json.contains("\"version\":\"0.1.6\""));

        let parsed: DaemonResponse = serde_json::from_str(&json).expect("Failed to parse");
        match parsed {
            DaemonResponse::Status {
                running,
                pid,
                uptime_seconds,
                version,
            } => {
                assert!(running);
                assert_eq!(pid, 12345);
                assert_eq!(uptime_seconds, 3600);
                assert_eq!(version, "0.1.6");
            }
            other => panic!("Expected Status response, got {other:?}"),
        }
    }

    #[test]
    fn test_daemon_response_stopping_serialization() {
        let response = DaemonResponse::Stopping;
        let json = serde_json::to_string(&response).expect("Failed to serialize");
        assert!(json.contains("\"type\":\"stopping\""));
    }

    #[test]
    fn test_daemon_response_stats_serialization() {
        let stats = DaemonStats::default();
        let response = DaemonResponse::Stats(stats);

        let json = serde_json::to_string(&response).expect("Failed to serialize");
        assert!(json.contains("\"type\":\"stats\""));
        assert!(json.contains("\"files_watched\""));
    }

    #[test]
    fn test_daemon_response_error_serialization() {
        let response = DaemonResponse::Error {
            message: "Something went wrong".to_string(),
        };

        let json = serde_json::to_string(&response).expect("Failed to serialize");
        assert!(json.contains("\"type\":\"error\""));
        assert!(json.contains("Something went wrong"));
    }

    #[test]
    fn test_daemon_response_current_session_round_trip() {
        for session_id in [None, Some("abc-123".to_string())] {
            let response = DaemonResponse::CurrentSession {
                session_id: session_id.clone(),
            };
            let json = serde_json::to_string(&response).expect("Failed to serialize");
            assert!(json.contains("\"type\":\"current_session\""));

            let parsed: DaemonResponse = serde_json::from_str(&json).expect("Failed to parse");
            match parsed {
                DaemonResponse::CurrentSession { session_id: parsed_id } => {
                    assert_eq!(parsed_id, session_id);
                }
                other => panic!("Expected CurrentSession response, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn test_server_client_communication() {
        let dir = tempdir().expect("Failed to create temp dir");
        let socket_path = dir.path().join("test.sock");

        // Leave a stale socket file behind, as a crashed daemon would; the server must
        // replace it instead of failing to bind.
        drop(
            std::os::unix::net::UnixListener::bind(&socket_path)
                .expect("Failed to create stale socket"),
        );

        let stats = Arc::new(RwLock::new(DaemonStats::default()));
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        let (broadcast_tx, broadcast_rx) = tokio::sync::broadcast::channel(1);

        let server_handle = tokio::spawn({
            let socket_path = socket_path.clone();
            let stats = Arc::clone(&stats);
            async move { run_server(&socket_path, stats, Some(shutdown_tx), broadcast_rx).await }
        });

        // Also covers `Ping`: this only returns once the server answered with `Pong`.
        wait_for_server(&socket_path).await;

        let response = send_command(&socket_path, DaemonCommand::Status)
            .await
            .expect("Failed to send command");
        match response {
            DaemonResponse::Status { running, pid, .. } => {
                assert!(running);
                assert_eq!(pid, std::process::id());
            }
            other => panic!("Expected Status response, got {other:?}"),
        }

        let response = send_command(&socket_path, DaemonCommand::Stats)
            .await
            .expect("Failed to send command");
        assert!(
            matches!(response, DaemonResponse::Stats(_)),
            "Expected Stats response, got {response:?}"
        );

        let response = send_command(&socket_path, DaemonCommand::Stop)
            .await
            .expect("Failed to send command");
        assert!(
            matches!(response, DaemonResponse::Stopping),
            "Expected Stopping response, got {response:?}"
        );

        // `Stop` must actually fire the one-shot shutdown signal given to the server.
        timeout(TEST_TIMEOUT, shutdown_rx)
            .await
            .expect("Stop did not fire the shutdown signal in time")
            .expect("Shutdown sender was dropped without firing");

        broadcast_tx
            .send(())
            .expect("IPC server is no longer subscribed to shutdown");
        timeout(TEST_TIMEOUT, server_handle)
            .await
            .expect("IPC server did not shut down in time")
            .expect("IPC server task panicked")
            .expect("IPC server returned an error");
    }
}