use std::collections::HashMap;
use std::future::Future;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};

use libgrite_ipc::{
    framing::{read_framed_async, write_framed_async},
    messages::{ArchivedIpcRequest, IpcRequest, IpcResponse},
    IpcCommand, Notification, IPC_SCHEMA_VERSION,
};
use tokio::net::{UnixListener, UnixStream};
use tokio::sync::{mpsc, Mutex, Semaphore};
use tracing::{debug, info, warn};

use crate::error::DaemonError;
use crate::state::{AtomicSupervisorState, SupervisorState};
use crate::worker::{Worker, WorkerMessage};

const MAX_CONNECTIONS: usize = 256;

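/// Handle to a spawned per-repository worker: the channel used to send it
/// messages, its Tokio join handle, and the repository root it serves.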
struct WorkerHandle {
    tx: mpsc::Sender<WorkerMessage>,
    join_handle: Option<tokio::task::JoinHandle<()>>,
    repo_root: PathBuf,
    #[allow(dead_code)]
    state: Option<Arc<crate::state::AtomicWorkerState>>,
}

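/// Key identifying a worker by the repository root it was spawned for.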
#[derive(Hash, Eq, PartialEq, Clone)]
struct WorkerKey {
    repo_root: String,
}

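/// Shared supervisor state. `last_activity_ms` records the last observed
/// activity as milliseconds elapsed since `start_instant`; the heartbeat task
/// compares it against `idle_timeout` to decide when to shut down.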
struct DaemonState {
    daemon_id: String,
    host_id: String,
    pid: u32,
    started_ts: u64,
    socket_path: String,
    workers: Mutex<HashMap<WorkerKey, WorkerHandle>>,
    notify_tx: mpsc::Sender<Notification>,
    shutdown_tx: tokio::sync::broadcast::Sender<()>,
    conn_semaphore: Arc<Semaphore>,
    last_activity_ms: AtomicU64,
    start_instant: Instant,
    idle_timeout: Option<Duration>,
    supervisor_state: AtomicSupervisorState,
}

impl DaemonState {
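    /// Record activity now, resetting the idle-timeout clock.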
    fn touch_activity(&self) {
        let elapsed_ms = self.start_instant.elapsed().as_millis() as u64;
        self.last_activity_ms.store(elapsed_ms, Ordering::Relaxed);
    }
}

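/// The daemon supervisor: owns the shared `DaemonState` and the receiving end
/// of the notification channel handed to workers.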
pub struct Supervisor {
    state: Arc<DaemonState>,
    notify_rx: mpsc::Receiver<Notification>,
}

impl Supervisor {
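    /// Create a new supervisor bound to `socket_path`. If `idle_timeout` is
    /// `Some`, the supervisor shuts itself down after that long with no
    /// client activity.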
    pub fn new(socket_path: String, idle_timeout: Option<Duration>) -> Self {
        let (notify_tx, notify_rx) = mpsc::channel(1000);
        let (shutdown_tx, _) = tokio::sync::broadcast::channel::<()>(1);
        let start_instant = Instant::now();

        let started_ts = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u64;

        let state = Arc::new(DaemonState {
            daemon_id: uuid::Uuid::new_v4().to_string(),
            host_id: get_host_id(),
            pid: std::process::id(),
            started_ts,
            socket_path,
            workers: Mutex::new(HashMap::new()),
            notify_tx,
            shutdown_tx,
            conn_semaphore: Arc::new(Semaphore::new(MAX_CONNECTIONS)),
            last_activity_ms: AtomicU64::new(0),
            start_instant,
            idle_timeout,
            supervisor_state: AtomicSupervisorState::new(SupervisorState::Starting),
        });

        Self { state, notify_rx }
    }

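    /// Run the supervisor until `shutdown_signal` completes or an internal
    /// shutdown is requested (idle timeout or a `DaemonStop` command).
    ///
    /// A minimal usage sketch; the socket path and timeout are illustrative:
    ///
    /// ```ignore
    /// let supervisor = Supervisor::new(
    ///     "/tmp/grite-daemon.sock".to_string(),
    ///     Some(std::time::Duration::from_secs(300)),
    /// );
    /// supervisor.run(async { tokio::signal::ctrl_c().await.ok(); }).await?;
    /// ```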
    pub async fn run(
        mut self,
        shutdown_signal: impl Future<Output = ()> + Send,
    ) -> Result<(), DaemonError> {
        info!(
            daemon_id = %self.state.daemon_id,
            socket_path = %self.state.socket_path,
            idle_timeout_secs = ?self.state.idle_timeout.map(|d| d.as_secs()),
            "Supervisor starting"
        );

        self.state.touch_activity();

        // Refuse to start if another supervisor is already listening;
        // otherwise clean up a stale socket file left by a previous run.
        let socket_path = Path::new(&self.state.socket_path);
        if socket_path.exists() {
            if std::os::unix::net::UnixStream::connect(socket_path).is_ok() {
                return Err(DaemonError::BindFailed(format!(
                    "Another supervisor is already listening on {}",
                    self.state.socket_path,
                )));
            }
            std::fs::remove_file(socket_path).map_err(|e| {
                DaemonError::BindFailed(format!(
                    "Failed to remove stale socket {}: {}",
                    self.state.socket_path, e
                ))
            })?;
        }

        let listener = UnixListener::bind(&self.state.socket_path).map_err(|e| {
            DaemonError::BindFailed(format!(
                "Failed to bind to {}: {}",
                self.state.socket_path, e
            ))
        })?;

        info!("Listening on {}", self.state.socket_path);
        self.state
            .supervisor_state
            .transition(SupervisorState::Running, Ordering::SeqCst)
            .ok();

        // Heartbeat task: ping workers every 10 seconds and enforce the idle
        // timeout by broadcasting an internal shutdown when it elapses.
        let state_hb = self.state.clone();
        let mut heartbeat_shutdown = self.state.shutdown_tx.subscribe();
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(Duration::from_secs(10));
            loop {
                tokio::select! {
                    _ = interval.tick() => {
                        let workers = state_hb.workers.lock().await;
                        for handle in workers.values() {
                            let _ = handle.tx.send(WorkerMessage::Heartbeat).await;
                        }
                        drop(workers);

                        if let Some(timeout) = state_hb.idle_timeout {
                            let last_ms = state_hb.last_activity_ms.load(Ordering::Relaxed);
                            let now_ms = state_hb.start_instant.elapsed().as_millis() as u64;
                            let idle_ms = now_ms.saturating_sub(last_ms);
                            if idle_ms >= timeout.as_millis() as u64 {
                                info!("Idle timeout reached ({} ms), shutting down", idle_ms);
                                let _ = state_hb.shutdown_tx.send(());
                                break;
                            }
                        }
                    }
                    _ = heartbeat_shutdown.recv() => {
                        break;
                    }
                }
            }
        });

        // Drain notifications from workers so the channel does not back up;
        // they are currently only logged.
        let mut notify_rx = std::mem::replace(&mut self.notify_rx, mpsc::channel(1).1);
        let mut notify_shutdown = self.state.shutdown_tx.subscribe();
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    Some(notification) = notify_rx.recv() => {
                        debug!(
                            notification_type = %notification.notification_type(),
                            "Notification emitted"
                        );
                    }
                    _ = notify_shutdown.recv() => {
                        break;
                    }
                }
            }
        });

        // Accept loop: serve connections until an external or internal
        // shutdown is requested.
        let mut internal_shutdown = self.state.shutdown_tx.subscribe();
        tokio::pin!(shutdown_signal);

        loop {
            tokio::select! {
                _ = &mut shutdown_signal => {
                    info!("Received shutdown signal");
                    break;
                }
                _ = internal_shutdown.recv() => {
                    info!("Internal shutdown signal received");
                    break;
                }
                result = listener.accept() => {
                    match result {
                        Ok((stream, _addr)) => {
                            let permit = match self.state.conn_semaphore.clone().try_acquire_owned() {
                                Ok(permit) => permit,
                                Err(_) => {
                                    warn!("Connection limit reached ({}), dropping connection", MAX_CONNECTIONS);
                                    continue;
                                }
                            };
                            let state = self.state.clone();
                            tokio::spawn(async move {
                                state.touch_activity();
                                handle_connection(stream, &state).await;
                                state.touch_activity();
                                drop(permit);
                            });
                        }
                        Err(e) => {
                            warn!("Accept error: {}", e);
                        }
                    }
                }
            }
        }

        // Shutdown sequence: stop background tasks, remove the socket, wait
        // for in-flight connections to drain, then stop the workers.
        self.state
            .supervisor_state
            .transition(SupervisorState::ShuttingDown, Ordering::SeqCst)
            .ok();

        let _ = self.state.shutdown_tx.send(());

        let _ = std::fs::remove_file(&self.state.socket_path);

        drop(listener);

        // Acquiring every permit only succeeds once all connection handlers
        // have released theirs; give them up to 10 seconds.
        let _ = tokio::time::timeout(
            Duration::from_secs(10),
            self.state
                .conn_semaphore
                .acquire_many(MAX_CONNECTIONS as u32),
        )
        .await;

        shutdown_workers(&self.state).await;

        self.state
            .supervisor_state
            .transition(SupervisorState::Stopped, Ordering::SeqCst)
            .ok();

        info!("Supervisor stopped");
        Ok(())
    }
}

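/// Signal every worker to shut down, then wait (bounded) for each worker task
/// to exit.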
async fn shutdown_workers(state: &DaemonState) {
    // Take ownership of all handles so the lock is not held while waiting.
    let handles: Vec<WorkerHandle> = {
        let mut workers = state.workers.lock().await;
        workers.drain().map(|(_, h)| h).collect()
    };

    // Phase 1: ask every worker to shut down.
    for handle in &handles {
        let _ = handle.tx.send(WorkerMessage::Shutdown).await;
    }

    // Phase 2: wait up to 10 seconds for each worker task to finish.
    for mut handle in handles {
        if let Some(jh) = handle.join_handle.take() {
            match tokio::time::timeout(Duration::from_secs(10), jh).await {
                Ok(Ok(())) => {}
                Ok(Err(e)) => warn!("Worker task panicked: {}", e),
                Err(_) => warn!(
                    "Worker {} didn't shut down within 10s",
                    handle.repo_root.display()
                ),
            }
        }
    }
}

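/// Serve a single client connection: read one framed request, process it, and
/// write back one framed response. Both the read and the write are bounded by
/// timeouts.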
async fn handle_connection(mut stream: UnixStream, state: &DaemonState) {
    let request_bytes =
        match tokio::time::timeout(Duration::from_secs(30), read_framed_async(&mut stream)).await {
            Ok(Ok(bytes)) => bytes,
            Ok(Err(e)) => {
                debug!("Failed to read request: {}", e);
                return;
            }
            Err(_) => {
                debug!("Request read timed out");
                return;
            }
        };

    let response = process_request(&request_bytes, state).await;

    match rkyv::to_bytes::<rkyv::rancor::Error>(&response) {
        Ok(bytes) => {
            // Log both write failures and write timeouts distinctly.
            match tokio::time::timeout(
                Duration::from_secs(5),
                write_framed_async(&mut stream, &bytes),
            )
            .await
            {
                Ok(Ok(_)) => {}
                Ok(Err(e)) => warn!("Failed to send response: {}", e),
                Err(_) => warn!("Response write timed out"),
            }
        }
        Err(e) => {
            warn!("Failed to serialize response: {}", e);
        }
    }
}

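/// Decode and validate a framed request, answer supervisor-level commands
/// (`DaemonStop`, `DaemonStatus`) directly, and route everything else to the
/// worker for the request's repository.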
async fn process_request(raw: &[u8], state: &DaemonState) -> IpcResponse {
    // Access the archived bytes first so the schema version (and a request_id
    // for error responses) can be checked before full deserialization.
    let archived = match rkyv::access::<ArchivedIpcRequest, rkyv::rancor::Error>(raw) {
        Ok(a) => a,
        Err(e) => {
            return IpcResponse::error(
                "unknown".to_string(),
                "deserialization".to_string(),
                format!("Failed to deserialize request: {}", e),
            );
        }
    };

    let version: u32 = archived.ipc_schema_version.into();
    if version != IPC_SCHEMA_VERSION {
        return IpcResponse::error(
            archived.request_id.to_string(),
            "version_mismatch".to_string(),
            format!("Expected version {}, got {}", IPC_SCHEMA_VERSION, version),
        );
    }

    let request: IpcRequest = match rkyv::deserialize::<IpcRequest, rkyv::rancor::Error>(archived) {
        Ok(r) => r,
        Err(e) => {
            return IpcResponse::error(
                archived.request_id.to_string(),
                "deserialization".to_string(),
                format!("Failed to deserialize request: {}", e),
            );
        }
    };

    debug!(
        request_id = %request.request_id,
        repo = %request.repo_root,
        actor = %request.actor_id,
        "Handling request"
    );

    // Supervisor-level commands are answered here; everything else goes to a
    // per-repository worker.
    match &request.command {
        IpcCommand::DaemonStop => {
            let _ = state.shutdown_tx.send(());
            return IpcResponse::success(
                request.request_id,
                Some(serde_json::json!({"stopping": true}).to_string()),
            );
        }
        IpcCommand::DaemonStatus => {
            let workers_guard = state.workers.lock().await;
            let worker_count = workers_guard.len();
            drop(workers_guard);

            let supervisor_state = format!("{:?}", state.supervisor_state.load(Ordering::SeqCst));
            return IpcResponse::success(
                request.request_id,
                Some(
                    serde_json::json!({
                        "running": true,
                        "daemon_id": state.daemon_id,
                        "pid": state.pid,
                        "host_id": state.host_id,
                        "ipc_endpoint": state.socket_path,
                        "started_ts": state.started_ts,
                        "worker_count": worker_count,
                        "state": supervisor_state,
                    })
                    .to_string(),
                ),
            );
        }
        _ => {}
    }

    route_to_worker(request, state).await
}

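/// Find or lazily create the worker for the request's repository, then forward
/// the request to it. Worker creation happens outside the workers lock, so the
/// map is re-checked afterwards in case a concurrent request won the race.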
async fn route_to_worker(request: IpcRequest, state: &DaemonState) -> IpcResponse {
    let key = WorkerKey {
        repo_root: request.repo_root.clone(),
    };

    // Fast path: reuse an existing worker, evicting it first if its channel
    // has closed (the worker task has exited).
    {
        let mut workers_guard = state.workers.lock().await;

        if let Some(handle) = workers_guard.get(&key) {
            if handle.tx.is_closed() {
                warn!(
                    repo = %handle.repo_root.display(),
                    "Removing dead worker handle"
                );
                workers_guard.remove(&key);
            }
        }

        if let Some(handle) = workers_guard.get(&key) {
            let tx = handle.tx.clone();
            drop(workers_guard);
            return send_to_worker(&request, tx).await;
        }
    }

    // Slow path: create a new worker on the blocking thread pool, with the
    // workers lock released.
    let (tx, rx) = mpsc::channel(100);
    let repo_root = PathBuf::from(&request.repo_root);
    let actor_id = request.actor_id.clone();
    let ntx = state.notify_tx.clone();
    let hid = state.host_id.clone();
    let ipc = state.socket_path.clone();

    let worker_result =
        tokio::task::spawn_blocking(move || Worker::new(repo_root, actor_id, rx, ntx, hid, ipc))
            .await;

    let worker = match worker_result {
        Ok(Ok(w)) => w,
        Ok(Err(e)) => {
            // A concurrent request may have registered a live worker while
            // this one failed; prefer it over returning an error.
            let workers_guard = state.workers.lock().await;
            if let Some(handle) = workers_guard.get(&key) {
                if !handle.tx.is_closed() {
                    let tx = handle.tx.clone();
                    drop(workers_guard);
                    return send_to_worker(&request, tx).await;
                }
            }
            return IpcResponse::error(
                request.request_id,
                "worker_creation_failed".to_string(),
                e.to_string(),
            );
        }
        Err(e) => {
            return IpcResponse::error(
                request.request_id,
                "worker_creation_failed".to_string(),
                format!("Worker creation panicked: {}", e),
            );
        }
    };

    {
        let mut workers_guard = state.workers.lock().await;

        // If a concurrent request registered a live worker first, use it and
        // drop the one just built; otherwise replace any dead handle.
        if let Some(handle) = workers_guard.get(&key) {
            if !handle.tx.is_closed() {
                let tx = handle.tx.clone();
                drop(workers_guard);
                return send_to_worker(&request, tx).await;
            }
            workers_guard.remove(&key);
        }

        let repo_root = worker.repo_root.clone();
        let worker_state = Some(worker.state.clone());
        let join_handle = tokio::spawn(worker.run());

        workers_guard.insert(
            key,
            WorkerHandle {
                tx: tx.clone(),
                join_handle: Some(join_handle),
                repo_root,
                state: worker_state,
            },
        );
    }

    send_to_worker(&request, tx).await
}

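/// Send a command to a worker over its channel and wait (up to 30 seconds) for
/// the response on a oneshot channel.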
async fn send_to_worker(request: &IpcRequest, tx: mpsc::Sender<WorkerMessage>) -> IpcResponse {
    let (response_tx, response_rx) = tokio::sync::oneshot::channel();
    let msg = WorkerMessage::Command {
        request_id: request.request_id.clone(),
        actor_id: request.actor_id.clone(),
        command: request.command.clone(),
        response_tx,
    };

    if tx.send(msg).await.is_err() {
        return IpcResponse::error(
            request.request_id.clone(),
            "worker_unavailable".to_string(),
            "Worker channel closed".to_string(),
        );
    }

    match tokio::time::timeout(Duration::from_secs(30), response_rx).await {
        Ok(Ok(response)) => response,
        Ok(Err(_)) => IpcResponse::error(
            request.request_id.clone(),
            "worker_error".to_string(),
            "Worker response channel dropped".to_string(),
        ),
        Err(_) => IpcResponse::error(
            request.request_id.clone(),
            "timeout".to_string(),
            "Worker timed out".to_string(),
        ),
    }
}

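/// Best-effort host identifier: the HOSTNAME environment variable, then
/// /etc/hostname, then a random UUID as a last resort.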
fn get_host_id() -> String {
    std::env::var("HOSTNAME")
        .or_else(|_| std::fs::read_to_string("/etc/hostname").map(|s| s.trim().to_string()))
        .unwrap_or_else(|_| uuid::Uuid::new_v4().to_string())
}