1use std::collections::HashMap;
2use std::net::{SocketAddr, TcpStream};
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::mpsc::{Receiver, RecvTimeoutError, Sender, channel};
5use std::sync::{Arc, Mutex, MutexGuard};
6use std::time::Duration;
7
8use beamr::atom::{Atom, AtomTable};
9use beamr::module::ModuleRegistry;
10use beamr::native::native_process::NativeHandlerFactory;
11use beamr::process::ExitReason;
12use beamr::scheduler::{Scheduler, SchedulerConfig};
13
14use liminal::protocol::WorkerRegistration;
15
16use super::notifier::ConnectionNotifier;
17use super::process::ConnectionProcess;
18use super::services::{ConnectionServices, LiminalConnectionServices};
19use crate::ServerError;
20use crate::config::types::ServerConfig;
21
/// Thread count requested from the scheduler that runs connection processes
/// (passed as `SchedulerConfig::thread_count`).
const CONNECTION_SCHEDULER_THREADS: usize = 4;
/// Name interned to obtain the atom that is enqueued to a connection process every time
/// a `ConnectionControl` is queued for it.
///
/// NOTE(review): despite the "shutdown" in the name, the same atom is used for every
/// control variant, including `Push`.
const CONNECTION_SHUTDOWN_CONTROL_ATOM: &str = "liminal_server_connection_shutdown_control";
24
// Unit tests are kept in the sibling file `supervisor_tests.rs` and are only compiled
// for test builds.
#[cfg(test)]
#[path = "supervisor_tests.rs"]
mod tests;
28
/// Cloneable front door for spawning and managing connection processes.
///
/// Every clone shares the same `SupervisorInner` through an `Arc`, so all clones see the
/// same scheduler and the same connection runtime.
#[derive(Clone, Debug)]
pub struct ConnectionSupervisor {
    /// Scheduler plus connection bookkeeping shared by all clones and handles.
    inner: Arc<SupervisorInner>,
}
34
35impl ConnectionSupervisor {
36 pub fn from_config(config: &ServerConfig) -> Result<Self, ServerError> {
41 Self::with_services(Arc::new(LiminalConnectionServices::from_config(config)?))
42 }
43
44 pub fn new() -> Result<Self, ServerError> {
49 Self::with_services(Arc::new(LiminalConnectionServices::empty()?))
50 }
51
52 pub fn with_services(services: Arc<dyn ConnectionServices>) -> Result<Self, ServerError> {
57 SupervisorInner::new(services, None).map(|inner| Self {
58 inner: Arc::new(inner),
59 })
60 }
61
62 pub fn with_services_and_notifier(
74 services: Arc<dyn ConnectionServices>,
75 notifier: Arc<dyn ConnectionNotifier>,
76 ) -> Result<Self, ServerError> {
77 SupervisorInner::new(services, Some(notifier)).map(|inner| Self {
78 inner: Arc::new(inner),
79 })
80 }
81
82 pub fn spawn_connection(&self, stream: TcpStream) -> Result<ConnectionHandle, ServerError> {
87 self.inner.spawn_connection(stream)
88 }
89
90 #[must_use]
92 pub fn scheduler(&self) -> Arc<Scheduler> {
93 Arc::clone(&self.inner.scheduler)
94 }
95
96 #[must_use]
98 pub fn reap_crashed_connections(&self) -> usize {
99 self.inner.runtime.reap_crashed(&self.inner.scheduler)
100 }
101
102 #[must_use]
104 pub fn is_tracked(&self, pid: u64) -> bool {
105 self.inner.runtime.contains(pid)
106 }
107
108 #[must_use]
110 pub fn active_connection_count(&self) -> usize {
111 self.inner.runtime.active_count()
112 }
113
114 #[must_use]
120 pub fn active_connection_pids(&self) -> Vec<u64> {
121 self.inner
122 .runtime
123 .active_connections()
124 .into_iter()
125 .map(|connection| connection.pid)
126 .collect()
127 }
128
129 pub fn notify_shutdown_subscribers(&self) {
134 self.inner
135 .broadcast_control(&ConnectionControl::NotifyShutdown);
136 }
137
138 pub fn force_close_active_connections(&self) {
143 for connection in self.inner.runtime.active_connections() {
144 tracing::warn!(
145 connection_pid = connection.pid,
146 peer_addr = ?connection.peer_addr,
147 "forcefully closing connection after drain timeout"
148 );
149 if !self
150 .inner
151 .enqueue_control(connection.pid, ConnectionControl::ForceClose)
152 {
153 tracing::warn!(
154 connection_pid = connection.pid,
155 peer_addr = ?connection.peer_addr,
156 "failed to request forceful connection close; process is not live"
157 );
158 }
159 }
160 }
161
162 pub fn push_to_connection(
178 &self,
179 pid: u64,
180 payload: Vec<u8>,
181 ) -> Result<PushReplyAwaiter, ServerError> {
182 let correlation_id = self.inner.runtime.next_push_correlation_id();
183 let receiver = self.inner.runtime.register_push(pid, correlation_id)?;
184 let control = ConnectionControl::Push {
185 correlation_id,
186 payload,
187 };
188 if self.inner.enqueue_control(pid, control) {
189 Ok(PushReplyAwaiter {
190 correlation_id,
191 receiver,
192 })
193 } else {
194 self.inner.runtime.cancel_push(correlation_id);
197 Err(ServerError::ListenerAccept {
198 message: format!("cannot push to connection process {pid}: process is not live"),
199 })
200 }
201 }
202
203 pub fn flush_durable_state(&self) -> Result<(), ServerError> {
208 self.inner.runtime.services().flush_durable_state()
209 }
210
211 pub fn shutdown(&self) {
213 self.inner.scheduler.shutdown();
214 }
215}
216
/// Handle to one spawned connection process.
///
/// Cloning the handle is cheap: it copies the pid and peer address and shares the
/// supervisor state through an `Arc`.
#[derive(Clone, Debug)]
pub struct ConnectionHandle {
    /// Pid the scheduler assigned to the connection process.
    pid: u64,
    /// Peer address captured at spawn time; `None` when the stream could not report it.
    peer_addr: Option<SocketAddr>,
    /// Shared supervisor state, used to reach the scheduler.
    supervisor: Arc<SupervisorInner>,
}
224
225impl ConnectionHandle {
226 #[must_use]
228 pub const fn pid(&self) -> u64 {
229 self.pid
230 }
231
232 #[must_use]
234 pub const fn peer_addr(&self) -> Option<SocketAddr> {
235 self.peer_addr
236 }
237
238 #[must_use]
240 pub fn is_live(&self) -> bool {
241 self.supervisor
242 .scheduler
243 .process_table()
244 .get(self.pid)
245 .is_some()
246 }
247
248 pub fn request_crash(&self) -> Result<(), ServerError> {
253 if self
254 .supervisor
255 .scheduler
256 .enqueue_atom_message(self.pid, Atom::ERROR)
257 {
258 Ok(())
259 } else {
260 Err(ServerError::ListenerAccept {
261 message: format!("connection process {} is not live", self.pid),
262 })
263 }
264 }
265}
266
/// Receiving side of a push sent with `ConnectionSupervisor::push_to_connection`.
///
/// The matching sender is stored in the runtime's pending-push table; it is consumed
/// when the push is resolved and dropped when the push is cancelled or its connection is
/// removed.
#[derive(Debug)]
pub struct PushReplyAwaiter {
    /// Correlation id allocated for this push.
    correlation_id: u64,
    /// Channel on which the reply payload arrives.
    receiver: Receiver<Vec<u8>>,
}
278
279impl PushReplyAwaiter {
280 #[must_use]
282 pub const fn correlation_id(&self) -> u64 {
283 self.correlation_id
284 }
285
286 pub fn receive(&self, timeout: Duration) -> Result<Vec<u8>, ServerError> {
295 self.receiver
296 .recv_timeout(timeout)
297 .map_err(|error| match error {
298 RecvTimeoutError::Timeout => ServerError::PushReplyTimeout {
299 correlation_id: self.correlation_id,
300 },
301 RecvTimeoutError::Disconnected => ServerError::PushReplyDisconnected {
302 correlation_id: self.correlation_id,
303 },
304 })
305 }
306}
307
/// State shared by every `ConnectionSupervisor` clone and every `ConnectionHandle`.
pub(super) struct SupervisorInner {
    /// Scheduler on which connection processes are spawned and messaged.
    scheduler: Arc<Scheduler>,
    /// Connection bookkeeping: registry, queued controls, and pending pushes.
    runtime: Arc<ConnectionRuntime>,
}
312
313impl std::fmt::Debug for SupervisorInner {
314 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315 formatter
316 .debug_struct("SupervisorInner")
317 .field("runtime", &self.runtime)
318 .finish_non_exhaustive()
319 }
320}
321
322impl SupervisorInner {
323 fn new(
324 services: Arc<dyn ConnectionServices>,
325 notifier: Option<Arc<dyn ConnectionNotifier>>,
326 ) -> Result<Self, ServerError> {
327 let atoms = AtomTable::with_common_atoms();
328 let control_atom = atoms.intern(CONNECTION_SHUTDOWN_CONTROL_ATOM);
329 let registry = Arc::new(ModuleRegistry::new());
330
331 let scheduler = Scheduler::new(
332 SchedulerConfig {
333 thread_count: Some(CONNECTION_SCHEDULER_THREADS),
334 ..SchedulerConfig::default()
335 },
336 registry,
337 )
338 .map_err(|message| ServerError::ListenerAccept {
339 message: format!("failed to start connection scheduler: {message}"),
340 })?;
341 Ok(Self {
342 scheduler: Arc::new(scheduler),
343 runtime: Arc::new(ConnectionRuntime::new(services, control_atom, notifier)),
344 })
345 }
346
347 fn spawn_connection(
348 self: &Arc<Self>,
349 stream: TcpStream,
350 ) -> Result<ConnectionHandle, ServerError> {
351 stream
352 .set_nonblocking(true)
353 .map_err(|error| ServerError::ListenerAccept {
354 message: format!("failed to configure connection stream: {error}"),
355 })?;
356 let peer_addr = stream.peer_addr().ok();
357 let holder = Arc::new(Mutex::new(Some(stream)));
358 let runtime = Arc::clone(&self.runtime);
359 let process_holder = Arc::clone(&holder);
360 let factory: NativeHandlerFactory = Box::new(move || {
361 Box::new(ConnectionProcess::from_holder(
362 Arc::clone(&runtime),
363 peer_addr,
364 &process_holder,
365 ))
366 });
367 let pid =
368 self.scheduler
369 .spawn_native(factory)
370 .map_err(|error| ServerError::ListenerAccept {
371 message: format!("failed to spawn connection process: {error}"),
372 })?;
373 self.runtime.register(pid, peer_addr)?;
374 Ok(ConnectionHandle {
375 pid,
376 peer_addr,
377 supervisor: Arc::clone(self),
378 })
379 }
380
381 fn broadcast_control(&self, control: &ConnectionControl) {
382 for connection in self.runtime.active_connections() {
383 if !self.enqueue_control(connection.pid, control.clone()) {
384 tracing::debug!(
385 connection_pid = connection.pid,
386 peer_addr = ?connection.peer_addr,
387 ?control,
388 "connection control message skipped because process is not live"
389 );
390 }
391 }
392 }
393
394 fn enqueue_control(&self, pid: u64, control: ConnectionControl) -> bool {
395 let removal_key = control.clone();
399 if self.runtime.push_control(pid, control).is_err() {
400 return false;
401 }
402 if self
403 .scheduler
404 .enqueue_atom_message(pid, self.runtime.control_atom())
405 {
406 true
407 } else {
408 self.runtime.remove_control(pid, &removal_key);
409 false
410 }
411 }
412}
413
/// Out-of-band instruction queued by the supervisor for a connection process.
///
/// Controls are stored in the runtime's control queue and fetched per pid with
/// `ConnectionRuntime::pop_control`. How a process reacts to each variant is not defined
/// in this file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) enum ConnectionControl {
    /// Queued for every tracked connection by `notify_shutdown_subscribers`.
    NotifyShutdown,
    /// Queued for every tracked connection by `force_close_active_connections`.
    ForceClose,
    /// Queued by `push_to_connection` for a single connection.
    Push {
        /// Id under which the reply slot was registered.
        correlation_id: u64,
        /// Opaque bytes supplied by the caller of `push_to_connection`.
        payload: Vec<u8>,
    },
}
425
/// Snapshot of one entry in the connection registry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ActiveConnection {
    /// Pid of the connection process.
    pid: u64,
    /// Peer address stored in the registry record, if one was captured.
    peer_addr: Option<SocketAddr>,
}
431
/// Bookkeeping shared between the supervisor and the connection processes.
#[derive(Debug)]
pub(super) struct ConnectionRuntime {
    /// Services handed to the runtime at construction and exposed via `services()`.
    services: Arc<dyn ConnectionServices>,
    /// Registry of tracked connections, keyed by pid.
    records: Mutex<HashMap<u64, ConnectionRecord>>,
    /// Controls queued for connection processes, in insertion order across all pids.
    controls: Mutex<Vec<QueuedConnectionControl>>,
    /// Atom enqueued to a process to tell it a control is waiting.
    control_atom: Atom,
    /// Reply slots for in-flight pushes, keyed by correlation id.
    push_replies: Mutex<HashMap<u64, PendingPush>>,
    /// Source of push correlation ids; starts at 1 and only increases.
    next_push_id: AtomicU64,
    /// Optional observer told about channel publishes and worker unregistration.
    notifier: Option<Arc<dyn ConnectionNotifier>>,
}
452
453impl ConnectionRuntime {
454 fn new(
455 services: Arc<dyn ConnectionServices>,
456 control_atom: Atom,
457 notifier: Option<Arc<dyn ConnectionNotifier>>,
458 ) -> Self {
459 Self {
460 services,
461 records: Mutex::new(HashMap::new()),
462 controls: Mutex::new(Vec::new()),
463 control_atom,
464 push_replies: Mutex::new(HashMap::new()),
465 next_push_id: AtomicU64::new(1),
466 notifier,
467 }
468 }
469
470 #[cfg(test)]
474 pub(super) fn for_tests(services: Arc<dyn ConnectionServices>) -> Self {
475 let atoms = AtomTable::with_common_atoms();
476 let control_atom = atoms.intern(CONNECTION_SHUTDOWN_CONTROL_ATOM);
477 Self::new(services, control_atom, None)
478 }
479
480 #[cfg(test)]
483 pub(super) fn for_tests_with_notifier(
484 services: Arc<dyn ConnectionServices>,
485 notifier: Arc<dyn ConnectionNotifier>,
486 ) -> Self {
487 let atoms = AtomTable::with_common_atoms();
488 let control_atom = atoms.intern(CONNECTION_SHUTDOWN_CONTROL_ATOM);
489 Self::new(services, control_atom, Some(notifier))
490 }
491
492 pub(super) fn services(&self) -> &dyn ConnectionServices {
493 self.services.as_ref()
494 }
495
496 pub(super) fn notifier(&self) -> Option<&Arc<dyn ConnectionNotifier>> {
498 self.notifier.as_ref()
499 }
500
501 pub(super) fn notifier_channel_publish(&self, pid: u64, channel: &str, payload: &[u8]) -> bool {
507 self.notifier
508 .as_ref()
509 .is_some_and(|notifier| notifier.on_channel_publish(pid, channel, payload))
510 }
511
512 pub(super) fn set_registration(
520 &self,
521 pid: u64,
522 registration: WorkerRegistration,
523 ) -> Result<(), ServerError> {
524 if let Some(record) = lock(&self.records, "connection registry")?.get_mut(&pid) {
525 record.registration = Some(registration);
526 }
527 Ok(())
528 }
529
530 fn next_push_correlation_id(&self) -> u64 {
532 self.next_push_id.fetch_add(1, Ordering::Relaxed)
533 }
534
535 fn register_push(
543 &self,
544 pid: u64,
545 correlation_id: u64,
546 ) -> Result<Receiver<Vec<u8>>, ServerError> {
547 let (sender, receiver) = channel();
548 lock(&self.push_replies, "push correlation registry")?
549 .insert(correlation_id, PendingPush { pid, sender });
550 Ok(receiver)
551 }
552
553 pub(super) fn cancel_push(&self, correlation_id: u64) {
557 if let Ok(mut slots) = self.push_replies.lock() {
558 slots.remove(&correlation_id);
559 }
560 }
561
562 fn cancel_pushes_for_connection(&self, pid: u64) {
569 if let Ok(mut slots) = self.push_replies.lock() {
570 slots.retain(|_correlation_id, pending| pending.pid != pid);
571 }
572 }
573
574 pub(super) fn resolve_push(&self, correlation_id: u64, payload: Vec<u8>) {
579 let pending = self
580 .push_replies
581 .lock()
582 .ok()
583 .and_then(|mut slots| slots.remove(&correlation_id));
584 if let Some(pending) = pending {
585 pending.sender.send(payload).ok();
588 }
589 }
590
591 pub(super) const fn control_atom(&self) -> Atom {
592 self.control_atom
593 }
594
595 fn register(&self, pid: u64, peer_addr: Option<SocketAddr>) -> Result<(), ServerError> {
610 lock(&self.records, "connection registry")?.insert(
611 pid,
612 ConnectionRecord {
613 peer_addr,
614 registration: None,
615 },
616 );
617 Ok(())
618 }
619
620 pub(super) fn mark_crashed(&self, pid: u64, reason: ExitReason, peer_addr: Option<SocketAddr>) {
621 let removed = self.remove(pid).unwrap_or(ConnectionRecord {
622 peer_addr,
623 registration: None,
624 });
625 self.fire_unregistered(pid, &removed);
626 tracing::warn!(
627 connection_pid = pid,
628 peer_addr = ?removed.peer_addr,
629 reason = ?reason,
630 "connection process crashed"
631 );
632 }
633
634 pub(super) fn finish(&self, pid: u64) {
635 if let Some(removed) = self.remove(pid) {
636 self.fire_unregistered(pid, &removed);
637 }
638 }
639
640 fn fire_unregistered(&self, pid: u64, record: &ConnectionRecord) {
645 if record.registration.is_some() {
646 if let Some(notifier) = self.notifier.as_ref() {
647 notifier.on_worker_unregistered(pid);
648 }
649 }
650 }
651
652 fn reap_crashed(&self, scheduler: &Scheduler) -> usize {
653 let pids = match self.records.lock() {
654 Ok(records) => records.keys().copied().collect::<Vec<_>>(),
655 Err(error) => {
656 tracing::warn!(%error, "connection registry unavailable during crash reap");
657 return 0;
658 }
659 };
660 let mut reaped = 0;
661 for pid in pids {
662 if scheduler.process_table().get(pid).is_none() {
663 let removed = self.remove(pid);
664 if let Some(record) = removed.as_ref() {
665 self.fire_unregistered(pid, record);
666 }
667 let peer_addr = removed.and_then(|record| record.peer_addr);
668 tracing::warn!(
677 connection_pid = pid,
678 ?peer_addr,
679 reason = "terminated externally (no exit reason recorded by supervisor)",
680 "connection process crashed"
681 );
682 reaped += 1;
683 }
684 }
685 reaped
686 }
687
688 fn contains(&self, pid: u64) -> bool {
689 self.records
690 .lock()
691 .is_ok_and(|records| records.contains_key(&pid))
692 }
693
694 fn active_connections(&self) -> Vec<ActiveConnection> {
695 self.records.lock().map_or_else(
696 |_| Vec::new(),
697 |records| {
698 records
699 .iter()
700 .map(|(&pid, record)| ActiveConnection {
701 pid,
702 peer_addr: record.peer_addr,
703 })
704 .collect()
705 },
706 )
707 }
708
709 fn push_control(&self, pid: u64, control: ConnectionControl) -> Result<(), ServerError> {
710 lock(&self.controls, "connection control queue")?
711 .push(QueuedConnectionControl { pid, control });
712 Ok(())
713 }
714
715 pub(super) fn pop_control(&self, pid: u64) -> Option<ConnectionControl> {
716 let mut controls = self.controls.lock().ok()?;
717 let index = controls.iter().position(|queued| queued.pid == pid)?;
718 Some(controls.remove(index).control)
719 }
720
721 fn remove_control(&self, pid: u64, control: &ConnectionControl) {
722 let Ok(mut controls) = self.controls.lock() else {
723 return;
724 };
725 let Some(index) = controls
726 .iter()
727 .position(|queued| queued.pid == pid && &queued.control == control)
728 else {
729 return;
730 };
731 controls.remove(index);
732 }
733
734 fn active_count(&self) -> usize {
735 self.records.lock().map_or(0, |records| records.len())
736 }
737
738 fn remove(&self, pid: u64) -> Option<ConnectionRecord> {
745 self.cancel_pushes_for_connection(pid);
746 self.records
747 .lock()
748 .ok()
749 .and_then(|mut records| records.remove(&pid))
750 }
751}
752
/// Reply slot for one in-flight push.
#[derive(Debug)]
struct PendingPush {
    /// Connection the push was sent to; used to cancel the slot when that connection is removed.
    pid: u64,
    /// Sending half of the channel whose receiver is held by the `PushReplyAwaiter`.
    sender: Sender<Vec<u8>>,
}
763
/// Registry entry for one tracked connection.
#[derive(Debug, Clone)]
struct ConnectionRecord {
    /// Peer address captured when the connection was spawned, if available.
    peer_addr: Option<SocketAddr>,
    /// Registration stored through `set_registration`; `None` until then. Only records
    /// with a registration trigger `on_worker_unregistered` on removal.
    registration: Option<WorkerRegistration>,
}
772
/// One entry in the runtime's control queue.
#[derive(Debug, Clone, PartialEq, Eq)]
struct QueuedConnectionControl {
    /// Connection process the control is addressed to.
    pid: u64,
    /// The control itself.
    control: ConnectionControl,
}
778
779fn lock<'a, T>(mutex: &'a Mutex<T>, context: &str) -> Result<MutexGuard<'a, T>, ServerError> {
780 mutex.lock().map_err(|error| ServerError::ListenerAccept {
781 message: format!("{context} unavailable: {error}"),
782 })
783}