Skip to main content

ix_daemon/
server.rs

1use crate::queue::SyncQueue;
2use crate::{
3    Command, DEFAULT_IDLE_TIMEOUT_MS, DaemonError, EnqueueSyncPayload, EnqueueSyncResponse,
4    ErrorCode, PROTOCOL_VERSION, PingResponse, Request, Response, ResponsePayload,
5    ShutdownResponse, StatusPayload, StatusResponse, WaitSyncPayload, WaitSyncResponse,
6};
7use std::path::Path;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::{Duration, Instant};
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
12use tokio::net::UnixListener;
13use tokio::sync::broadcast;
14
/// Largest request line, in bytes, that a connection handler will process (1 MiB).
/// Longer lines are answered with a "Message too large" error response.
const MAX_MESSAGE_SIZE: usize = 1 << 20;
16
/// Unix-socket daemon server: accepts connections, dispatches protocol commands to a
/// shared [`SyncQueue`], and can stop itself on request or after a period of inactivity.
pub struct Server {
    /// Socket path as supplied by the caller; a leading `~/` is resolved by
    /// `expanded_socket_path`.
    socket_path: String,
    /// Inactivity threshold in milliseconds; `0` disables idle shutdown.
    idle_timeout_ms: u64,
    /// Instant at which the server value was constructed; reference point for uptime
    /// and for the `last_activity` timestamps.
    start_time: Instant,
    /// Broadcast sender used to tell the accept loop in `run` to stop.
    shutdown_tx: broadcast::Sender<()>,
    /// Sync job queue shared with every connection task.
    queue: Arc<SyncQueue>,
    /// Milliseconds since `start_time` at which activity was last recorded.
    last_activity: Arc<AtomicU64>,
}
25
26impl Server {
27    pub fn new(socket_path: impl Into<String>) -> Self {
28        Self::with_idle_timeout(socket_path, DEFAULT_IDLE_TIMEOUT_MS)
29    }
30
31    pub fn with_idle_timeout(socket_path: impl Into<String>, idle_timeout_ms: u64) -> Self {
32        let (shutdown_tx, _) = broadcast::channel(1);
33        Self {
34            socket_path: socket_path.into(),
35            idle_timeout_ms,
36            start_time: Instant::now(),
37            shutdown_tx,
38            queue: Arc::new(SyncQueue::new()),
39            last_activity: Arc::new(AtomicU64::new(0)),
40        }
41    }
42
43    pub fn expanded_socket_path(&self) -> String {
44        expand_tilde(&self.socket_path)
45    }
46
47    fn touch_activity(&self) {
48        #[allow(clippy::cast_possible_truncation)]
49        let now = self.start_time.elapsed().as_millis() as u64;
50        self.last_activity.store(now, Ordering::Relaxed);
51    }
52
53    fn is_idle(&self) -> bool {
54        if self.idle_timeout_ms == 0 {
55            return false;
56        }
57
58        #[allow(clippy::cast_possible_truncation)]
59        let now = self.start_time.elapsed().as_millis() as u64;
60        let last = self.last_activity.load(Ordering::Relaxed);
61        now.saturating_sub(last) > self.idle_timeout_ms
62    }
63
64    // Event loop handling multiple async channels (commands, signals, idle timeout) requires
65    // this complexity; splitting would obscure the unified state machine logic.
66    #[allow(clippy::cognitive_complexity)]
67    pub async fn run(&self) -> Result<(), DaemonError> {
68        let socket_path = self.expanded_socket_path();
69
70        if let Some(parent) = Path::new(&socket_path).parent() {
71            tokio::fs::create_dir_all(parent).await?;
72        }
73
74        if Path::new(&socket_path).exists() {
75            tokio::fs::remove_file(&socket_path).await?;
76        }
77
78        let listener = UnixListener::bind(&socket_path)?;
79        tracing::info!("ixcheld listening on {}", socket_path);
80
81        if self.idle_timeout_ms > 0 {
82            tracing::info!("Idle timeout: {}ms", self.idle_timeout_ms);
83        }
84
85        self.touch_activity();
86
87        let mut shutdown_rx = self.shutdown_tx.subscribe();
88        let idle_check_interval = Duration::from_secs(10);
89
90        loop {
91            tokio::select! {
92                accept_result = listener.accept() => {
93                    match accept_result {
94                        Ok((stream, _)) => {
95                            self.touch_activity();
96                            let queue = Arc::clone(&self.queue);
97                            let start_time = self.start_time;
98                            let shutdown_tx = self.shutdown_tx.clone();
99                            let last_activity = Arc::clone(&self.last_activity);
100                            tokio::spawn(async move {
101                                if let Err(e) = handle_connection(stream, queue, start_time, shutdown_tx, last_activity).await {
102                                    tracing::error!("Connection error: {}", e);
103                                }
104                            });
105                        }
106                        Err(e) => {
107                            tracing::error!("Accept error: {}", e);
108                        }
109                    }
110                }
111                _ = shutdown_rx.recv() => {
112                    tracing::info!("Shutdown signal received");
113                    break;
114                }
115                () = tokio::time::sleep(idle_check_interval), if self.idle_timeout_ms > 0 => {
116                    if self.is_idle() && self.queue.list_queues().await.is_empty() {
117                        tracing::info!("Idle timeout reached, shutting down");
118                        break;
119                    }
120                }
121            }
122        }
123
124        let _ = tokio::fs::remove_file(&socket_path).await;
125        Ok(())
126    }
127
128    pub fn shutdown(&self) {
129        let _ = self.shutdown_tx.send(());
130    }
131}
132
133async fn handle_connection(
134    stream: tokio::net::UnixStream,
135    queue: Arc<SyncQueue>,
136    start_time: Instant,
137    shutdown_tx: broadcast::Sender<()>,
138    last_activity: Arc<AtomicU64>,
139) -> Result<(), DaemonError> {
140    let (reader, mut writer) = stream.into_split();
141    let mut reader = BufReader::new(reader);
142    let mut line = String::new();
143
144    loop {
145        line.clear();
146        let bytes_read = reader.read_line(&mut line).await?;
147
148        if bytes_read == 0 {
149            break;
150        }
151
152        #[allow(clippy::cast_possible_truncation)]
153        let now = start_time.elapsed().as_millis() as u64;
154        last_activity.store(now, Ordering::Relaxed);
155
156        if line.len() > MAX_MESSAGE_SIZE {
157            let resp = Response::error("", ErrorCode::InvalidRequest, "Message too large");
158            let json = serde_json::to_string(&resp)?;
159            writer.write_all(json.as_bytes()).await?;
160            writer.write_all(b"\n").await?;
161            continue;
162        }
163
164        let response = match serde_json::from_str::<Request>(line.trim()) {
165            Ok(req) => {
166                if req.version == PROTOCOL_VERSION {
167                    handle_command(&req, &queue, start_time, &shutdown_tx).await
168                } else {
169                    Response::error(
170                        &req.id,
171                        ErrorCode::IncompatibleVersion,
172                        format!(
173                            "Protocol version mismatch: expected {PROTOCOL_VERSION}, got {}",
174                            req.version
175                        ),
176                    )
177                }
178            }
179            Err(e) => Response::error("", ErrorCode::InvalidRequest, e.to_string()),
180        };
181
182        let json = serde_json::to_string(&response)?;
183        writer.write_all(json.as_bytes()).await?;
184        writer.write_all(b"\n").await?;
185        writer.flush().await?;
186    }
187
188    Ok(())
189}
190
191async fn handle_command(
192    req: &Request,
193    queue: &SyncQueue,
194    start_time: Instant,
195    shutdown_tx: &broadcast::Sender<()>,
196) -> Response {
197    match &req.command {
198        Command::Ping => Response::ok(
199            &req.id,
200            ResponsePayload::Ping(PingResponse {
201                daemon_version: env!("CARGO_PKG_VERSION").to_string(),
202            }),
203        ),
204
205        Command::EnqueueSync(EnqueueSyncPayload { directory, force }) => {
206            let (sync_id, _is_new) = queue
207                .enqueue(&req.repo_root, &req.tool, directory, *force)
208                .await;
209
210            queue.get(&sync_id).await.map_or_else(
211                || {
212                    Response::error(
213                        &req.id,
214                        ErrorCode::InternalError,
215                        "Failed to create sync job",
216                    )
217                },
218                |job| {
219                    Response::ok(
220                        &req.id,
221                        ResponsePayload::EnqueueSync(EnqueueSyncResponse {
222                            sync_id,
223                            queued_at_ms: job.queued_at_ms(),
224                        }),
225                    )
226                },
227            )
228        }
229
230        Command::WaitSync(WaitSyncPayload {
231            sync_id,
232            timeout_ms,
233        }) => {
234            let timeout = Duration::from_millis(*timeout_ms);
235
236            match queue.wait(sync_id, timeout).await {
237                Some(final_state) => {
238                    let job_stats = queue.get(sync_id).await.and_then(|j| j.stats);
239                    Response::ok(
240                        &req.id,
241                        ResponsePayload::WaitSync(WaitSyncResponse {
242                            sync_id: sync_id.clone(),
243                            state: final_state,
244                            stats: job_stats,
245                        }),
246                    )
247                }
248                None => Response::error(
249                    &req.id,
250                    ErrorCode::Timeout,
251                    format!("Timeout waiting for sync {sync_id}"),
252                ),
253            }
254        }
255
256        Command::Status(StatusPayload { .. }) => {
257            #[allow(clippy::cast_possible_truncation)]
258            let uptime_ms = start_time.elapsed().as_millis() as u64;
259            let queues = queue.list_queues().await;
260            Response::ok(
261                &req.id,
262                ResponsePayload::Status(StatusResponse { queues, uptime_ms }),
263            )
264        }
265
266        Command::Shutdown(payload) => {
267            tracing::info!("Shutdown requested: {}", payload.reason);
268            let _ = shutdown_tx.send(());
269            Response::ok(&req.id, ResponsePayload::Shutdown(ShutdownResponse {}))
270        }
271    }
272}
273
274fn expand_tilde(path: &str) -> String {
275    if let Some(rest) = path.strip_prefix("~/")
276        && let Some(home) = dirs_next::home_dir()
277    {
278        return home.join(rest).to_string_lossy().to_string();
279    }
280    path.to_string()
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// A `~/`-prefixed path loses its tilde and keeps the remainder of the path.
    #[test]
    fn test_expand_tilde() {
        let resolved = expand_tilde("~/.ixchel/run/ixcheld.sock");
        assert!(!resolved.starts_with('~'));
        assert!(resolved.contains(".ixchel/run/ixcheld.sock"));
    }

    /// A path without a tilde prefix comes back unchanged.
    #[test]
    fn test_expand_tilde_no_tilde() {
        assert_eq!(expand_tilde("/tmp/test.sock"), "/tmp/test.sock");
    }
}
299}