
grite_daemon/supervisor.rs

//! Supervisor module - manages workers and IPC sockets
//!
//! The supervisor:
//! - Listens on a Unix socket for commands
//! - Manages worker lifecycle
//! - Routes commands to appropriate workers
//! - Broadcasts notifications via internal channels
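//!
//! # Protocol sketch
//!
//! Each connection carries exactly one framed rkyv request and one framed
//! response. A hypothetical client round-trip is sketched below; the
//! `IpcRequest` field set shown is inferred from how this module reads the
//! struct, not taken from its authoritative definition:
//!
//! ```ignore
//! use libgrite_ipc::framing::{read_framed_async, write_framed_async};
//! use libgrite_ipc::messages::IpcRequest;
//! use libgrite_ipc::{IpcCommand, IPC_SCHEMA_VERSION};
//!
//! let mut stream = tokio::net::UnixStream::connect("/tmp/grite.sock").await?;
//! let request = IpcRequest {
//!     ipc_schema_version: IPC_SCHEMA_VERSION,
//!     request_id: "req-1".to_string(),        // illustrative values
//!     repo_root: "/path/to/repo".to_string(),
//!     actor_id: "cli".to_string(),
//!     command: IpcCommand::DaemonStatus,
//! };
//! let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&request)?;
//! write_framed_async(&mut stream, &bytes).await?;
//! let raw = read_framed_async(&mut stream).await?;
//! // Decode `raw` with `rkyv::access`, as `process_request` does below.
//! ```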

use std::collections::HashMap;
use std::future::Future;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};

use libgrite_ipc::{
    framing::{read_framed_async, write_framed_async},
    messages::{ArchivedIpcRequest, IpcRequest, IpcResponse},
    IpcCommand, Notification, IPC_SCHEMA_VERSION,
};
use tokio::net::{UnixListener, UnixStream};
use tokio::sync::{mpsc, Mutex, Semaphore};
use tracing::{debug, info, warn};

use crate::error::DaemonError;
use crate::state::{AtomicSupervisorState, SupervisorState};
use crate::worker::{Worker, WorkerMessage};

/// Maximum concurrent connections the daemon will handle
const MAX_CONNECTIONS: usize = 256;

/// Worker handle for communication
struct WorkerHandle {
    tx: mpsc::Sender<WorkerMessage>,
    /// `Option` so shutdown can `take()` and await the task exactly once
    join_handle: Option<tokio::task::JoinHandle<()>>,
    repo_root: PathBuf,
    #[allow(dead_code)]
    state: Option<Arc<crate::state::AtomicWorkerState>>,
}

/// Key for worker lookup — one worker per repository
#[derive(Hash, Eq, PartialEq, Clone)]
struct WorkerKey {
    repo_root: String,
}

/// Shared daemon state accessible from all connection tasks.
///
/// Wrapped in `Arc` and passed to every spawned connection task,
/// replacing the previous pattern of cloning 8+ individual values.
struct DaemonState {
    daemon_id: String,
    host_id: String,
    pid: u32,
    started_ts: u64,
    socket_path: String,
    workers: Mutex<HashMap<WorkerKey, WorkerHandle>>,
    notify_tx: mpsc::Sender<Notification>,
    shutdown_tx: tokio::sync::broadcast::Sender<()>,
    conn_semaphore: Arc<Semaphore>,
    last_activity_ms: AtomicU64,
    start_instant: Instant,
    idle_timeout: Option<Duration>,
    supervisor_state: AtomicSupervisorState,
}

impl DaemonState {
    /// Record "now" as the last-activity time, in milliseconds since the
    /// daemon started. Uses the monotonic `Instant` clock, so idle
    /// tracking is immune to wall-clock adjustments.
    fn touch_activity(&self) {
        let elapsed_ms = self.start_instant.elapsed().as_millis() as u64;
        self.last_activity_ms.store(elapsed_ms, Ordering::Relaxed);
    }
}

/// Supervisor manages workers and IPC
pub struct Supervisor {
    state: Arc<DaemonState>,
    notify_rx: mpsc::Receiver<Notification>,
}

impl Supervisor {
    /// Create a new supervisor that will listen on `socket_path`.
    ///
    /// An `idle_timeout` of `None` disables idle auto-shutdown.
    pub fn new(socket_path: String, idle_timeout: Option<Duration>) -> Self {
        let (notify_tx, notify_rx) = mpsc::channel(1000);
        let (shutdown_tx, _) = tokio::sync::broadcast::channel::<()>(1);
        let start_instant = Instant::now();

        let started_ts = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u64;

        let state = Arc::new(DaemonState {
            daemon_id: uuid::Uuid::new_v4().to_string(),
            host_id: get_host_id(),
            pid: std::process::id(),
            started_ts,
            socket_path,
            workers: Mutex::new(HashMap::new()),
            notify_tx,
            shutdown_tx,
            conn_semaphore: Arc::new(Semaphore::new(MAX_CONNECTIONS)),
            last_activity_ms: AtomicU64::new(0),
            start_instant,
            idle_timeout,
            supervisor_state: AtomicSupervisorState::new(SupervisorState::Starting),
        });

        Self { state, notify_rx }
    }

    /// Run the supervisor until shutdown.
    ///
    /// Shutdown is triggered by either:
    /// - The external `shutdown_signal` future resolving (e.g. SIGTERM)
    /// - An internal trigger (idle timeout, DaemonStop command)
    ///
    /// All cleanup (socket removal, worker shutdown) is handled here.
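    ///
    /// # Example
    ///
    /// A minimal sketch of driving the supervisor from an async `main`
    /// that returns a `Result`; the socket path and idle timeout are
    /// illustrative:
    ///
    /// ```ignore
    /// let supervisor = Supervisor::new(
    ///     "/tmp/grite.sock".to_string(),
    ///     Some(Duration::from_secs(300)), // auto-shutdown after 5 idle minutes
    /// );
    /// // Any `Future<Output = ()> + Send` can serve as the shutdown signal.
    /// supervisor.run(async {
    ///     let _ = tokio::signal::ctrl_c().await;
    /// })
    /// .await?;
    /// ```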
    pub async fn run(
        mut self,
        shutdown_signal: impl Future<Output = ()> + Send,
    ) -> Result<(), DaemonError> {
        info!(
            daemon_id = %self.state.daemon_id,
            socket_path = %self.state.socket_path,
            idle_timeout_secs = ?self.state.idle_timeout.map(|d| d.as_secs()),
            "Supervisor starting"
        );

        // Initialize last activity to now
        self.state.touch_activity();

        // Clean up the stale socket file, but only if no live supervisor
        // owns it. The brief blocking connect is acceptable here: we are
        // not serving anything yet.
        let socket_path = Path::new(&self.state.socket_path);
        if socket_path.exists() {
            if std::os::unix::net::UnixStream::connect(socket_path).is_ok() {
                return Err(DaemonError::BindFailed(format!(
                    "Another supervisor is already listening on {}",
                    self.state.socket_path,
                )));
            }
            std::fs::remove_file(socket_path).map_err(|e| {
                DaemonError::BindFailed(format!(
                    "Failed to remove stale socket {}: {}",
                    self.state.socket_path, e
                ))
            })?;
        }

        // Bind Unix listener
        let listener = UnixListener::bind(&self.state.socket_path).map_err(|e| {
            DaemonError::BindFailed(format!(
                "Failed to bind to {}: {}",
                self.state.socket_path, e
            ))
        })?;

        info!("Listening on {}", self.state.socket_path);
        self.state
            .supervisor_state
            .transition(SupervisorState::Running, Ordering::SeqCst)
            .ok();

        // Spawn heartbeat task (also checks idle timeout)
        let state_hb = self.state.clone();
        let mut heartbeat_shutdown = self.state.shutdown_tx.subscribe();
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(Duration::from_secs(10));
            loop {
                tokio::select! {
                    _ = interval.tick() => {
                        // Send heartbeats to workers
                        let workers = state_hb.workers.lock().await;
                        for handle in workers.values() {
                            let _ = handle.tx.send(WorkerMessage::Heartbeat).await;
                        }
                        drop(workers);

                        // Check idle timeout
                        if let Some(timeout) = state_hb.idle_timeout {
                            let last_ms = state_hb.last_activity_ms.load(Ordering::Relaxed);
                            let now_ms = state_hb.start_instant.elapsed().as_millis() as u64;
                            let idle_ms = now_ms.saturating_sub(last_ms);
                            if idle_ms >= timeout.as_millis() as u64 {
                                info!("Idle timeout reached ({} ms), shutting down", idle_ms);
                                let _ = state_hb.shutdown_tx.send(());
                                break;
                            }
                        }
                    }
                    _ = heartbeat_shutdown.recv() => {
                        break;
                    }
                }
            }
        });

        // Spawn the notification consumer. It only logs for now, since the
        // PUB socket has been removed. Swap a dummy receiver into `self` so
        // the real one can move into the spawned task.
        let mut notify_rx = std::mem::replace(&mut self.notify_rx, mpsc::channel(1).1);
        let mut notify_shutdown = self.state.shutdown_tx.subscribe();
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    Some(notification) = notify_rx.recv() => {
                        debug!(
                            notification_type = %notification.notification_type(),
                            "Notification emitted"
                        );
                    }
                    _ = notify_shutdown.recv() => {
                        break;
                    }
                }
            }
        });

        // Main accept loop
        let mut internal_shutdown = self.state.shutdown_tx.subscribe();
        tokio::pin!(shutdown_signal);

        loop {
            tokio::select! {
                _ = &mut shutdown_signal => {
                    info!("Received shutdown signal");
                    break;
                }
                _ = internal_shutdown.recv() => {
                    info!("Internal shutdown signal received");
                    break;
                }
                result = listener.accept() => {
                    match result {
                        Ok((stream, _addr)) => {
                            let permit = match self.state.conn_semaphore.clone().try_acquire_owned() {
                                Ok(permit) => permit,
                                Err(_) => {
                                    warn!("Connection limit reached ({}), dropping connection", MAX_CONNECTIONS);
                                    continue;
                                }
                            };
                            let state = self.state.clone();
                            tokio::spawn(async move {
                                state.touch_activity();
                                handle_connection(stream, &state).await;
                                state.touch_activity();
                                drop(permit);
                            });
                        }
                        Err(e) => {
                            warn!("Accept error: {}", e);
                        }
                    }
                }
            }
        }

        // === Single cleanup path ===
        self.state
            .supervisor_state
            .transition(SupervisorState::ShuttingDown, Ordering::SeqCst)
            .ok();

        // Signal background tasks (heartbeat, notifications) to stop. If
        // shutdown was already triggered internally (idle timeout /
        // DaemonStop), those tasks have already received the broadcast and
        // exited, so this extra send is harmless.
        let _ = self.state.shutdown_tx.send(());

        // Clean up socket file
        let _ = std::fs::remove_file(&self.state.socket_path);

        // Stop accepting new connections so no new tasks can spawn.
        drop(listener);

        // Wait for all in-flight connection tasks to finish (they each
        // hold a semaphore permit). This prevents the race where a
        // connection task inserts a new worker after we drain the map.
        // The result is deliberately discarded: on timeout we proceed
        // with shutdown anyway, and any acquired permits drop here.
        let _ = tokio::time::timeout(
            Duration::from_secs(10),
            self.state
                .conn_semaphore
                .acquire_many(MAX_CONNECTIONS as u32),
        )
        .await;

        // Now safe to drain — no connection tasks are running
        shutdown_workers(&self.state).await;

        self.state
            .supervisor_state
            .transition(SupervisorState::Stopped, Ordering::SeqCst)
            .ok();

        info!("Supervisor stopped");
        Ok(())
    }
}

/// Drain workers from the map and shut them down.
///
/// The mutex is released before sending shutdown messages or awaiting
/// join handles, preventing deadlocks with in-flight connection tasks
/// that may be waiting to insert new workers.
async fn shutdown_workers(state: &DaemonState) {
    let handles: Vec<WorkerHandle> = {
        let mut workers = state.workers.lock().await;
        workers.drain().map(|(_, h)| h).collect()
    };
    // Mutex released — in-flight connection tasks can now complete

    for handle in &handles {
        let _ = handle.tx.send(WorkerMessage::Shutdown).await;
    }

    for mut handle in handles {
        if let Some(jh) = handle.join_handle.take() {
            match tokio::time::timeout(Duration::from_secs(10), jh).await {
                Ok(Ok(())) => {}
                Ok(Err(e)) => warn!("Worker task panicked: {}", e),
                Err(_) => warn!(
                    "Worker {} didn't shut down within 10s",
                    handle.repo_root.display()
                ),
            }
        }
    }
}

/// Handle a single client connection: read one framed request, send one
/// framed response
async fn handle_connection(mut stream: UnixStream, state: &DaemonState) {
    // Read request with timeout
    let request_bytes =
        match tokio::time::timeout(Duration::from_secs(30), read_framed_async(&mut stream)).await {
            Ok(Ok(bytes)) => bytes,
            Ok(Err(e)) => {
                debug!("Failed to read request: {}", e);
                return;
            }
            Err(_) => {
                debug!("Request read timed out");
                return;
            }
        };

    let response = process_request(&request_bytes, state).await;

    // Serialize and send response, distinguishing write failures from
    // timeouts (the previous `if let Err` swallowed write errors)
    match rkyv::to_bytes::<rkyv::rancor::Error>(&response) {
        Ok(bytes) => {
            match tokio::time::timeout(
                Duration::from_secs(5),
                write_framed_async(&mut stream, &bytes),
            )
            .await
            {
                Ok(Ok(_)) => {}
                Ok(Err(e)) => warn!("Failed to send response: {}", e),
                Err(_) => warn!("Response write timed out"),
            }
        }
        Err(e) => {
            warn!("Failed to serialize response: {}", e);
        }
    }
}

/// Process a raw request and return a response
async fn process_request(raw: &[u8], state: &DaemonState) -> IpcResponse {
    // Deserialize request
    let archived = match rkyv::access::<ArchivedIpcRequest, rkyv::rancor::Error>(raw) {
        Ok(a) => a,
        Err(e) => {
            return IpcResponse::error(
                "unknown".to_string(),
                "deserialization".to_string(),
                format!("Failed to deserialize request: {}", e),
            );
        }
    };

    // Check version
    let version: u32 = archived.ipc_schema_version.into();
    if version != IPC_SCHEMA_VERSION {
        return IpcResponse::error(
            archived.request_id.to_string(),
            "version_mismatch".to_string(),
            format!("Expected version {}, got {}", IPC_SCHEMA_VERSION, version),
        );
    }

    // Deserialize to owned type
    let request: IpcRequest = match rkyv::deserialize::<IpcRequest, rkyv::rancor::Error>(archived) {
        Ok(r) => r,
        Err(e) => {
            return IpcResponse::error(
                archived.request_id.to_string(),
                "deserialization".to_string(),
                format!("Failed to deserialize request: {}", e),
            );
        }
    };

    debug!(
        request_id = %request.request_id,
        repo = %request.repo_root,
        actor = %request.actor_id,
        "Handling request"
    );

    // Handle daemon-level commands at the supervisor, not in workers
    match &request.command {
        IpcCommand::DaemonStop => {
            let _ = state.shutdown_tx.send(());
            return IpcResponse::success(
                request.request_id,
                Some(serde_json::json!({"stopping": true}).to_string()),
            );
        }
        IpcCommand::DaemonStatus => {
            let workers_guard = state.workers.lock().await;
            let worker_count = workers_guard.len();
            drop(workers_guard);

            let supervisor_state = format!("{:?}", state.supervisor_state.load(Ordering::SeqCst));
            return IpcResponse::success(
                request.request_id,
                Some(
                    serde_json::json!({
                        "running": true,
                        "daemon_id": state.daemon_id,
                        "pid": state.pid,
                        "host_id": state.host_id,
                        "ipc_endpoint": state.socket_path,
                        "started_ts": state.started_ts,
                        "worker_count": worker_count,
                        "state": supervisor_state,
                    })
                    .to_string(),
                ),
            );
        }
        _ => {}
    }

    // Route to worker
    route_to_worker(request, state).await
}

/// Route a request to the appropriate worker, creating one if needed.
///
/// If the worker's channel is dead (task panicked or exited), the stale
/// handle is removed and a fresh worker is spawned automatically.
///
/// Uses double-checked locking: the workers mutex is NOT held during
/// `Worker::new` (which does blocking sled I/O). If two tasks race to
/// create the same worker, the loser finds the winner's entry on re-check.
async fn route_to_worker(request: IpcRequest, state: &DaemonState) -> IpcResponse {
    let key = WorkerKey {
        repo_root: request.repo_root.clone(),
    };

    // Fast path: check for existing live worker (mutex held briefly)
    {
        let mut workers_guard = state.workers.lock().await;

        // Remove dead worker handles
        if let Some(handle) = workers_guard.get(&key) {
            if handle.tx.is_closed() {
                warn!(
                    repo = %handle.repo_root.display(),
                    "Removing dead worker handle"
                );
                workers_guard.remove(&key);
            }
        }

        if let Some(handle) = workers_guard.get(&key) {
            let tx = handle.tx.clone();
            drop(workers_guard);
            return send_to_worker(&request, tx).await;
        }
    }
    // Mutex released — slow path: create worker on blocking thread pool.
    // Worker::new opens the sled store which can block for seconds.
    let (tx, rx) = mpsc::channel(100);
    let repo_root = PathBuf::from(&request.repo_root);
    let actor_id = request.actor_id.clone();
    let ntx = state.notify_tx.clone();
    let hid = state.host_id.clone();
    let ipc = state.socket_path.clone();

    let worker_result =
        tokio::task::spawn_blocking(move || Worker::new(repo_root, actor_id, rx, ntx, hid, ipc))
            .await;

    let worker = match worker_result {
        Ok(Ok(w)) => w,
        Ok(Err(e)) => {
            // Creation failed — another task may have won the race.
            // Re-check the map before returning an error.
            let workers_guard = state.workers.lock().await;
            if let Some(handle) = workers_guard.get(&key) {
                if !handle.tx.is_closed() {
                    let tx = handle.tx.clone();
                    drop(workers_guard);
                    return send_to_worker(&request, tx).await;
                }
            }
            return IpcResponse::error(
                request.request_id,
                "worker_creation_failed".to_string(),
                e.to_string(),
            );
        }
        Err(e) => {
            return IpcResponse::error(
                request.request_id,
                "worker_creation_failed".to_string(),
                format!("Worker creation panicked: {}", e),
            );
        }
    };

    // Re-acquire lock and insert (double-check for races)
    {
        let mut workers_guard = state.workers.lock().await;

        // Another task may have created a worker for this key while we
        // were blocked. If so, use theirs and drop ours.
        if let Some(handle) = workers_guard.get(&key) {
            if !handle.tx.is_closed() {
                let tx = handle.tx.clone();
                drop(workers_guard);
                // Our worker is dropped here — its sled lock releases on Drop
                return send_to_worker(&request, tx).await;
            }
            workers_guard.remove(&key);
        }

        let repo_root = worker.repo_root.clone();
        let worker_state = Some(worker.state.clone());
        let join_handle = tokio::spawn(worker.run());

        workers_guard.insert(
            key,
            WorkerHandle {
                tx: tx.clone(),
                join_handle: Some(join_handle),
                repo_root,
                state: worker_state,
            },
        );
    }

    send_to_worker(&request, tx).await
}

/// Send a request to an existing worker and wait for the response
async fn send_to_worker(request: &IpcRequest, tx: mpsc::Sender<WorkerMessage>) -> IpcResponse {
    let (response_tx, response_rx) = tokio::sync::oneshot::channel();
    let msg = WorkerMessage::Command {
        request_id: request.request_id.clone(),
        actor_id: request.actor_id.clone(),
        command: request.command.clone(),
        response_tx,
    };

    if tx.send(msg).await.is_err() {
        return IpcResponse::error(
            request.request_id.clone(),
            "worker_unavailable".to_string(),
            "Worker channel closed".to_string(),
        );
    }

    // Wait for response with timeout
    match tokio::time::timeout(Duration::from_secs(30), response_rx).await {
        Ok(Ok(response)) => response,
        Ok(Err(_)) => IpcResponse::error(
            request.request_id.clone(),
            "worker_error".to_string(),
            "Worker response channel dropped".to_string(),
        ),
        Err(_) => IpcResponse::error(
            request.request_id.clone(),
            "timeout".to_string(),
            "Worker timed out".to_string(),
        ),
    }
}

/// Get a stable host identifier.
///
/// Tries `$HOSTNAME`, then `/etc/hostname`, and falls back to a random
/// UUID; in that case the id is only stable for this process's lifetime.
fn get_host_id() -> String {
    std::env::var("HOSTNAME")
        .or_else(|_| std::fs::read_to_string("/etc/hostname").map(|s| s.trim().to_string()))
        .unwrap_or_else(|_| uuid::Uuid::new_v4().to_string())
}