1use crate::queue::SyncQueue;
2use crate::{
3 Command, DEFAULT_IDLE_TIMEOUT_MS, DaemonError, EnqueueSyncPayload, EnqueueSyncResponse,
4 ErrorCode, PROTOCOL_VERSION, PingResponse, Request, Response, ResponsePayload,
5 ShutdownResponse, StatusPayload, StatusResponse, WaitSyncPayload, WaitSyncResponse,
6};
7use std::path::Path;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::{Duration, Instant};
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
12use tokio::net::UnixListener;
13use tokio::sync::broadcast;
14
/// Upper bound, in bytes, on a single request line (1 MiB).
const MAX_MESSAGE_SIZE: usize = 1 << 20;
16
17pub struct Server {
18 socket_path: String,
19 idle_timeout_ms: u64,
20 start_time: Instant,
21 shutdown_tx: broadcast::Sender<()>,
22 queue: Arc<SyncQueue>,
23 last_activity: Arc<AtomicU64>,
24}
25
26impl Server {
27 pub fn new(socket_path: impl Into<String>) -> Self {
28 Self::with_idle_timeout(socket_path, DEFAULT_IDLE_TIMEOUT_MS)
29 }
30
31 pub fn with_idle_timeout(socket_path: impl Into<String>, idle_timeout_ms: u64) -> Self {
32 let (shutdown_tx, _) = broadcast::channel(1);
33 Self {
34 socket_path: socket_path.into(),
35 idle_timeout_ms,
36 start_time: Instant::now(),
37 shutdown_tx,
38 queue: Arc::new(SyncQueue::new()),
39 last_activity: Arc::new(AtomicU64::new(0)),
40 }
41 }
42
43 pub fn expanded_socket_path(&self) -> String {
44 expand_tilde(&self.socket_path)
45 }
46
47 fn touch_activity(&self) {
48 #[allow(clippy::cast_possible_truncation)]
49 let now = self.start_time.elapsed().as_millis() as u64;
50 self.last_activity.store(now, Ordering::Relaxed);
51 }
52
53 fn is_idle(&self) -> bool {
54 if self.idle_timeout_ms == 0 {
55 return false;
56 }
57
58 #[allow(clippy::cast_possible_truncation)]
59 let now = self.start_time.elapsed().as_millis() as u64;
60 let last = self.last_activity.load(Ordering::Relaxed);
61 now.saturating_sub(last) > self.idle_timeout_ms
62 }
63
64 #[allow(clippy::cognitive_complexity)]
67 pub async fn run(&self) -> Result<(), DaemonError> {
68 let socket_path = self.expanded_socket_path();
69
70 if let Some(parent) = Path::new(&socket_path).parent() {
71 tokio::fs::create_dir_all(parent).await?;
72 }
73
74 if Path::new(&socket_path).exists() {
75 tokio::fs::remove_file(&socket_path).await?;
76 }
77
78 let listener = UnixListener::bind(&socket_path)?;
79 tracing::info!("ixcheld listening on {}", socket_path);
80
81 if self.idle_timeout_ms > 0 {
82 tracing::info!("Idle timeout: {}ms", self.idle_timeout_ms);
83 }
84
85 self.touch_activity();
86
87 let mut shutdown_rx = self.shutdown_tx.subscribe();
88 let idle_check_interval = Duration::from_secs(10);
89
90 loop {
91 tokio::select! {
92 accept_result = listener.accept() => {
93 match accept_result {
94 Ok((stream, _)) => {
95 self.touch_activity();
96 let queue = Arc::clone(&self.queue);
97 let start_time = self.start_time;
98 let shutdown_tx = self.shutdown_tx.clone();
99 let last_activity = Arc::clone(&self.last_activity);
100 tokio::spawn(async move {
101 if let Err(e) = handle_connection(stream, queue, start_time, shutdown_tx, last_activity).await {
102 tracing::error!("Connection error: {}", e);
103 }
104 });
105 }
106 Err(e) => {
107 tracing::error!("Accept error: {}", e);
108 }
109 }
110 }
111 _ = shutdown_rx.recv() => {
112 tracing::info!("Shutdown signal received");
113 break;
114 }
115 () = tokio::time::sleep(idle_check_interval), if self.idle_timeout_ms > 0 => {
116 if self.is_idle() && self.queue.list_queues().await.is_empty() {
117 tracing::info!("Idle timeout reached, shutting down");
118 break;
119 }
120 }
121 }
122 }
123
124 let _ = tokio::fs::remove_file(&socket_path).await;
125 Ok(())
126 }
127
128 pub fn shutdown(&self) {
129 let _ = self.shutdown_tx.send(());
130 }
131}
132
133async fn handle_connection(
134 stream: tokio::net::UnixStream,
135 queue: Arc<SyncQueue>,
136 start_time: Instant,
137 shutdown_tx: broadcast::Sender<()>,
138 last_activity: Arc<AtomicU64>,
139) -> Result<(), DaemonError> {
140 let (reader, mut writer) = stream.into_split();
141 let mut reader = BufReader::new(reader);
142 let mut line = String::new();
143
144 loop {
145 line.clear();
146 let bytes_read = reader.read_line(&mut line).await?;
147
148 if bytes_read == 0 {
149 break;
150 }
151
152 #[allow(clippy::cast_possible_truncation)]
153 let now = start_time.elapsed().as_millis() as u64;
154 last_activity.store(now, Ordering::Relaxed);
155
156 if line.len() > MAX_MESSAGE_SIZE {
157 let resp = Response::error("", ErrorCode::InvalidRequest, "Message too large");
158 let json = serde_json::to_string(&resp)?;
159 writer.write_all(json.as_bytes()).await?;
160 writer.write_all(b"\n").await?;
161 continue;
162 }
163
164 let response = match serde_json::from_str::<Request>(line.trim()) {
165 Ok(req) => {
166 if req.version == PROTOCOL_VERSION {
167 handle_command(&req, &queue, start_time, &shutdown_tx).await
168 } else {
169 Response::error(
170 &req.id,
171 ErrorCode::IncompatibleVersion,
172 format!(
173 "Protocol version mismatch: expected {PROTOCOL_VERSION}, got {}",
174 req.version
175 ),
176 )
177 }
178 }
179 Err(e) => Response::error("", ErrorCode::InvalidRequest, e.to_string()),
180 };
181
182 let json = serde_json::to_string(&response)?;
183 writer.write_all(json.as_bytes()).await?;
184 writer.write_all(b"\n").await?;
185 writer.flush().await?;
186 }
187
188 Ok(())
189}
190
191async fn handle_command(
192 req: &Request,
193 queue: &SyncQueue,
194 start_time: Instant,
195 shutdown_tx: &broadcast::Sender<()>,
196) -> Response {
197 match &req.command {
198 Command::Ping => Response::ok(
199 &req.id,
200 ResponsePayload::Ping(PingResponse {
201 daemon_version: env!("CARGO_PKG_VERSION").to_string(),
202 }),
203 ),
204
205 Command::EnqueueSync(EnqueueSyncPayload { directory, force }) => {
206 let (sync_id, _is_new) = queue
207 .enqueue(&req.repo_root, &req.tool, directory, *force)
208 .await;
209
210 queue.get(&sync_id).await.map_or_else(
211 || {
212 Response::error(
213 &req.id,
214 ErrorCode::InternalError,
215 "Failed to create sync job",
216 )
217 },
218 |job| {
219 Response::ok(
220 &req.id,
221 ResponsePayload::EnqueueSync(EnqueueSyncResponse {
222 sync_id,
223 queued_at_ms: job.queued_at_ms(),
224 }),
225 )
226 },
227 )
228 }
229
230 Command::WaitSync(WaitSyncPayload {
231 sync_id,
232 timeout_ms,
233 }) => {
234 let timeout = Duration::from_millis(*timeout_ms);
235
236 match queue.wait(sync_id, timeout).await {
237 Some(final_state) => {
238 let job_stats = queue.get(sync_id).await.and_then(|j| j.stats);
239 Response::ok(
240 &req.id,
241 ResponsePayload::WaitSync(WaitSyncResponse {
242 sync_id: sync_id.clone(),
243 state: final_state,
244 stats: job_stats,
245 }),
246 )
247 }
248 None => Response::error(
249 &req.id,
250 ErrorCode::Timeout,
251 format!("Timeout waiting for sync {sync_id}"),
252 ),
253 }
254 }
255
256 Command::Status(StatusPayload { .. }) => {
257 #[allow(clippy::cast_possible_truncation)]
258 let uptime_ms = start_time.elapsed().as_millis() as u64;
259 let queues = queue.list_queues().await;
260 Response::ok(
261 &req.id,
262 ResponsePayload::Status(StatusResponse { queues, uptime_ms }),
263 )
264 }
265
266 Command::Shutdown(payload) => {
267 tracing::info!("Shutdown requested: {}", payload.reason);
268 let _ = shutdown_tx.send(());
269 Response::ok(&req.id, ResponsePayload::Shutdown(ShutdownResponse {}))
270 }
271 }
272}
273
274fn expand_tilde(path: &str) -> String {
275 if let Some(rest) = path.strip_prefix("~/")
276 && let Some(home) = dirs_next::home_dir()
277 {
278 return home.join(rest).to_string_lossy().to_string();
279 }
280 path.to_string()
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// A `~/`-prefixed path loses its tilde and keeps the remainder intact.
    #[test]
    fn test_expand_tilde() {
        let result = expand_tilde("~/.ixchel/run/ixcheld.sock");

        let still_has_tilde = result.starts_with('~');
        assert!(!still_has_tilde);

        let keeps_remainder = result.contains(".ixchel/run/ixcheld.sock");
        assert!(keeps_remainder);
    }

    /// A path without a leading tilde is returned unchanged.
    #[test]
    fn test_expand_tilde_no_tilde() {
        let input = "/tmp/test.sock";
        let output = expand_tilde(input);
        assert_eq!(output, input);
    }
}