1use std::collections::HashMap;
2use std::net::{SocketAddr, TcpStream};
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::mpsc::{Receiver, RecvTimeoutError, Sender, channel};
5use std::sync::{Arc, Mutex, MutexGuard};
6use std::time::Duration;
7
8use beamr::atom::{Atom, AtomTable};
9use beamr::module::ModuleRegistry;
10use beamr::native::native_process::NativeHandlerFactory;
11use beamr::process::ExitReason;
12use beamr::scheduler::{Scheduler, SchedulerConfig};
13
14use liminal::protocol::WorkerRegistration;
15
16use super::notifier::ConnectionNotifier;
17use super::process::ConnectionProcess;
18use super::services::{ConnectionServices, LiminalConnectionServices};
19use crate::ServerError;
20use crate::config::types::ServerConfig;
21
/// Thread count requested from the scheduler that runs connection processes.
const CONNECTION_SCHEDULER_THREADS: usize = 4;
/// Name interned to obtain the atom that tells a connection process a control is queued for it.
const CONNECTION_SHUTDOWN_CONTROL_ATOM: &str = "liminal_server_connection_shutdown_control";

// Unit tests live in the sibling `supervisor_tests.rs` file and are only compiled for test builds.
#[cfg(test)]
#[path = "supervisor_tests.rs"]
mod tests;
28
/// Public entry point for spawning and managing connection processes.
///
/// Clones share the same [`SupervisorInner`] through an `Arc`, so every clone observes the
/// same scheduler and connection bookkeeping.
#[derive(Clone, Debug)]
pub struct ConnectionSupervisor {
    /// Scheduler plus connection bookkeeping shared with every [`ConnectionHandle`].
    inner: Arc<SupervisorInner>,
}
34
35impl ConnectionSupervisor {
36 pub fn from_config(config: &ServerConfig) -> Result<Self, ServerError> {
41 Self::with_services(Arc::new(LiminalConnectionServices::from_config(config)?))
42 }
43
44 pub fn new() -> Result<Self, ServerError> {
49 Self::with_services(Arc::new(LiminalConnectionServices::empty()?))
50 }
51
52 pub fn with_services(services: Arc<dyn ConnectionServices>) -> Result<Self, ServerError> {
57 SupervisorInner::new(services, None).map(|inner| Self {
58 inner: Arc::new(inner),
59 })
60 }
61
62 pub fn with_services_and_notifier(
74 services: Arc<dyn ConnectionServices>,
75 notifier: Arc<dyn ConnectionNotifier>,
76 ) -> Result<Self, ServerError> {
77 SupervisorInner::new(services, Some(notifier)).map(|inner| Self {
78 inner: Arc::new(inner),
79 })
80 }
81
82 pub fn spawn_connection(&self, stream: TcpStream) -> Result<ConnectionHandle, ServerError> {
87 self.inner.spawn_connection(stream)
88 }
89
90 #[must_use]
92 pub fn scheduler(&self) -> Arc<Scheduler> {
93 Arc::clone(&self.inner.scheduler)
94 }
95
96 #[must_use]
98 pub fn reap_crashed_connections(&self) -> usize {
99 self.inner.runtime.reap_crashed(&self.inner.scheduler)
100 }
101
102 #[must_use]
104 pub fn is_tracked(&self, pid: u64) -> bool {
105 self.inner.runtime.contains(pid)
106 }
107
108 #[must_use]
110 pub fn active_connection_count(&self) -> usize {
111 self.inner.runtime.active_count()
112 }
113
114 #[must_use]
120 pub fn active_connection_pids(&self) -> Vec<u64> {
121 self.inner
122 .runtime
123 .active_connections()
124 .into_iter()
125 .map(|connection| connection.pid)
126 .collect()
127 }
128
129 pub fn notify_shutdown_subscribers(&self) {
134 self.inner
135 .broadcast_control(&ConnectionControl::NotifyShutdown);
136 }
137
138 pub fn force_close_active_connections(&self) {
143 for connection in self.inner.runtime.active_connections() {
144 tracing::warn!(
145 connection_pid = connection.pid,
146 peer_addr = ?connection.peer_addr,
147 "forcefully closing connection after drain timeout"
148 );
149 if !self
150 .inner
151 .enqueue_control(connection.pid, ConnectionControl::ForceClose)
152 {
153 tracing::warn!(
154 connection_pid = connection.pid,
155 peer_addr = ?connection.peer_addr,
156 "failed to request forceful connection close; process is not live"
157 );
158 }
159 }
160 }
161
162 pub fn push_to_connection(
178 &self,
179 pid: u64,
180 payload: Vec<u8>,
181 ) -> Result<PushReplyAwaiter, ServerError> {
182 let correlation_id = self.inner.runtime.next_push_correlation_id();
183 let receiver = self.inner.runtime.register_push(pid, correlation_id)?;
184 let control = ConnectionControl::Push {
185 correlation_id,
186 payload,
187 };
188 if self.inner.enqueue_control(pid, control) {
189 Ok(PushReplyAwaiter {
190 correlation_id,
191 receiver,
192 })
193 } else {
194 self.inner.runtime.cancel_push(correlation_id);
197 Err(ServerError::ListenerAccept {
198 message: format!("cannot push to connection process {pid}: process is not live"),
199 })
200 }
201 }
202
203 pub fn flush_durable_state(&self) -> Result<(), ServerError> {
208 self.inner.runtime.services().flush_durable_state()
209 }
210
211 pub fn shutdown(&self) {
213 self.inner.scheduler.shutdown();
214 }
215}
216
/// Handle to one spawned connection process.
///
/// Carries the process id, the peer address captured at spawn time, and a shared reference
/// to the supervisor state used for liveness checks and crash requests.
#[derive(Clone, Debug)]
pub struct ConnectionHandle {
    /// Scheduler process id of the connection process.
    pid: u64,
    /// Peer address of the stream, or `None` when it could not be read at spawn time.
    peer_addr: Option<SocketAddr>,
    /// Supervisor state that owns the scheduler running this process.
    supervisor: Arc<SupervisorInner>,
}
224
225impl ConnectionHandle {
226 #[must_use]
228 pub const fn pid(&self) -> u64 {
229 self.pid
230 }
231
232 #[must_use]
234 pub const fn peer_addr(&self) -> Option<SocketAddr> {
235 self.peer_addr
236 }
237
238 #[must_use]
240 pub fn is_live(&self) -> bool {
241 self.supervisor
242 .scheduler
243 .process_table()
244 .get(self.pid)
245 .is_some()
246 }
247
248 pub fn request_crash(&self) -> Result<(), ServerError> {
253 if self
254 .supervisor
255 .scheduler
256 .enqueue_atom_message(self.pid, Atom::ERROR)
257 {
258 Ok(())
259 } else {
260 Err(ServerError::ListenerAccept {
261 message: format!("connection process {} is not live", self.pid),
262 })
263 }
264 }
265}
266
/// Receiving side of a push reply, returned by `ConnectionSupervisor::push_to_connection`.
#[derive(Debug)]
pub struct PushReplyAwaiter {
    /// Correlation id allocated for the push this awaiter belongs to.
    correlation_id: u64,
    /// Channel end on which the reply payload is delivered.
    receiver: Receiver<Vec<u8>>,
}
278
279impl PushReplyAwaiter {
280 #[must_use]
282 pub const fn correlation_id(&self) -> u64 {
283 self.correlation_id
284 }
285
286 pub fn receive(&self, timeout: Duration) -> Result<Vec<u8>, ServerError> {
295 self.receiver
296 .recv_timeout(timeout)
297 .map_err(|error| match error {
298 RecvTimeoutError::Timeout => ServerError::PushReplyTimeout {
299 correlation_id: self.correlation_id,
300 },
301 RecvTimeoutError::Disconnected => ServerError::PushReplyDisconnected {
302 correlation_id: self.correlation_id,
303 },
304 })
305 }
306}
307
/// State shared by every `ConnectionSupervisor` clone and every `ConnectionHandle`.
pub(super) struct SupervisorInner {
    /// Scheduler on which connection processes are spawned.
    scheduler: Arc<Scheduler>,
    /// Connection bookkeeping shared with the connection processes themselves.
    runtime: Arc<ConnectionRuntime>,
}
312
313impl std::fmt::Debug for SupervisorInner {
314 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315 formatter
316 .debug_struct("SupervisorInner")
317 .field("runtime", &self.runtime)
318 .finish_non_exhaustive()
319 }
320}
321
322impl SupervisorInner {
323 fn new(
324 services: Arc<dyn ConnectionServices>,
325 notifier: Option<Arc<dyn ConnectionNotifier>>,
326 ) -> Result<Self, ServerError> {
327 let atoms = AtomTable::with_common_atoms();
328 let control_atom = atoms.intern(CONNECTION_SHUTDOWN_CONTROL_ATOM);
329 let registry = Arc::new(ModuleRegistry::new());
330
331 let scheduler = Scheduler::new(
332 SchedulerConfig {
333 thread_count: Some(CONNECTION_SCHEDULER_THREADS),
334 ..SchedulerConfig::default()
335 },
336 registry,
337 )
338 .map_err(|message| ServerError::ListenerAccept {
339 message: format!("failed to start connection scheduler: {message}"),
340 })?;
341 Ok(Self {
342 scheduler: Arc::new(scheduler),
343 runtime: Arc::new(ConnectionRuntime::new(services, control_atom, notifier)),
344 })
345 }
346
347 fn spawn_connection(
348 self: &Arc<Self>,
349 stream: TcpStream,
350 ) -> Result<ConnectionHandle, ServerError> {
351 stream
352 .set_nonblocking(true)
353 .map_err(|error| ServerError::ListenerAccept {
354 message: format!("failed to configure connection stream: {error}"),
355 })?;
356 let peer_addr = stream.peer_addr().ok();
357 let holder = Arc::new(Mutex::new(Some(stream)));
358 let runtime = Arc::clone(&self.runtime);
359 let process_holder = Arc::clone(&holder);
360 let factory: NativeHandlerFactory = Box::new(move || {
361 Box::new(ConnectionProcess::from_holder(
362 Arc::clone(&runtime),
363 peer_addr,
364 &process_holder,
365 ))
366 });
367 let pid =
368 self.scheduler
369 .spawn_native(factory)
370 .map_err(|error| ServerError::ListenerAccept {
371 message: format!("failed to spawn connection process: {error}"),
372 })?;
373 self.runtime.register(pid, peer_addr)?;
374 Ok(ConnectionHandle {
375 pid,
376 peer_addr,
377 supervisor: Arc::clone(self),
378 })
379 }
380
381 fn broadcast_control(&self, control: &ConnectionControl) {
382 for connection in self.runtime.active_connections() {
383 if !self.enqueue_control(connection.pid, control.clone()) {
384 tracing::debug!(
385 connection_pid = connection.pid,
386 peer_addr = ?connection.peer_addr,
387 ?control,
388 "connection control message skipped because process is not live"
389 );
390 }
391 }
392 }
393
394 fn enqueue_control(&self, pid: u64, control: ConnectionControl) -> bool {
395 let removal_key = control.clone();
399 if self.runtime.push_control(pid, control).is_err() {
400 return false;
401 }
402 if self
403 .scheduler
404 .enqueue_atom_message(pid, self.runtime.control_atom())
405 {
406 true
407 } else {
408 self.runtime.remove_control(pid, &removal_key);
409 false
410 }
411 }
412}
413
/// Out-of-band instruction queued for a connection process and announced with the control atom.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) enum ConnectionControl {
    /// Queued for every tracked connection by `notify_shutdown_subscribers`.
    NotifyShutdown,
    /// Queued for every tracked connection by `force_close_active_connections`.
    ForceClose,
    /// Queued by `push_to_connection`.
    Push {
        /// Id under which the reply slot for this push is registered.
        correlation_id: u64,
        /// Bytes handed to `push_to_connection` by the caller.
        payload: Vec<u8>,
    },
}
425
/// Snapshot of one tracked connection, copied out of the connection registry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ActiveConnection {
    /// Scheduler process id of the connection.
    pid: u64,
    /// Peer address stored in the connection record, if any.
    peer_addr: Option<SocketAddr>,
}
431
/// Bookkeeping shared between the supervisor and the connection processes.
#[derive(Debug)]
pub(super) struct ConnectionRuntime {
    /// Services made available through `services()`.
    services: Arc<dyn ConnectionServices>,
    /// Tracked connections, keyed by process id.
    records: Mutex<HashMap<u64, ConnectionRecord>>,
    /// Controls queued for all connections, in insertion order; each entry names its target pid.
    controls: Mutex<Vec<QueuedConnectionControl>>,
    /// Atom sent to a process after a control has been queued for it.
    control_atom: Atom,
    /// Reply slots of pushes that have not been answered yet, keyed by correlation id.
    push_replies: Mutex<HashMap<u64, PendingPush>>,
    /// Source of push correlation ids; starts at 1 and is incremented per push.
    next_push_id: AtomicU64,
    /// Optional observer told when a connection that had a registration is removed.
    notifier: Option<Arc<dyn ConnectionNotifier>>,
}
452
453impl ConnectionRuntime {
454 fn new(
455 services: Arc<dyn ConnectionServices>,
456 control_atom: Atom,
457 notifier: Option<Arc<dyn ConnectionNotifier>>,
458 ) -> Self {
459 Self {
460 services,
461 records: Mutex::new(HashMap::new()),
462 controls: Mutex::new(Vec::new()),
463 control_atom,
464 push_replies: Mutex::new(HashMap::new()),
465 next_push_id: AtomicU64::new(1),
466 notifier,
467 }
468 }
469
470 #[cfg(test)]
474 pub(super) fn for_tests(services: Arc<dyn ConnectionServices>) -> Self {
475 let atoms = AtomTable::with_common_atoms();
476 let control_atom = atoms.intern(CONNECTION_SHUTDOWN_CONTROL_ATOM);
477 Self::new(services, control_atom, None)
478 }
479
480 #[cfg(test)]
483 pub(super) fn for_tests_with_notifier(
484 services: Arc<dyn ConnectionServices>,
485 notifier: Arc<dyn ConnectionNotifier>,
486 ) -> Self {
487 let atoms = AtomTable::with_common_atoms();
488 let control_atom = atoms.intern(CONNECTION_SHUTDOWN_CONTROL_ATOM);
489 Self::new(services, control_atom, Some(notifier))
490 }
491
492 pub(super) fn services(&self) -> &dyn ConnectionServices {
493 self.services.as_ref()
494 }
495
496 pub(super) fn notifier(&self) -> Option<&Arc<dyn ConnectionNotifier>> {
498 self.notifier.as_ref()
499 }
500
501 pub(super) fn set_registration(
509 &self,
510 pid: u64,
511 registration: WorkerRegistration,
512 ) -> Result<(), ServerError> {
513 if let Some(record) = lock(&self.records, "connection registry")?.get_mut(&pid) {
514 record.registration = Some(registration);
515 }
516 Ok(())
517 }
518
519 fn next_push_correlation_id(&self) -> u64 {
521 self.next_push_id.fetch_add(1, Ordering::Relaxed)
522 }
523
524 fn register_push(
532 &self,
533 pid: u64,
534 correlation_id: u64,
535 ) -> Result<Receiver<Vec<u8>>, ServerError> {
536 let (sender, receiver) = channel();
537 lock(&self.push_replies, "push correlation registry")?
538 .insert(correlation_id, PendingPush { pid, sender });
539 Ok(receiver)
540 }
541
542 pub(super) fn cancel_push(&self, correlation_id: u64) {
546 if let Ok(mut slots) = self.push_replies.lock() {
547 slots.remove(&correlation_id);
548 }
549 }
550
551 fn cancel_pushes_for_connection(&self, pid: u64) {
558 if let Ok(mut slots) = self.push_replies.lock() {
559 slots.retain(|_correlation_id, pending| pending.pid != pid);
560 }
561 }
562
563 pub(super) fn resolve_push(&self, correlation_id: u64, payload: Vec<u8>) {
568 let pending = self
569 .push_replies
570 .lock()
571 .ok()
572 .and_then(|mut slots| slots.remove(&correlation_id));
573 if let Some(pending) = pending {
574 pending.sender.send(payload).ok();
577 }
578 }
579
580 pub(super) const fn control_atom(&self) -> Atom {
581 self.control_atom
582 }
583
584 fn register(&self, pid: u64, peer_addr: Option<SocketAddr>) -> Result<(), ServerError> {
599 lock(&self.records, "connection registry")?.insert(
600 pid,
601 ConnectionRecord {
602 peer_addr,
603 registration: None,
604 },
605 );
606 Ok(())
607 }
608
609 pub(super) fn mark_crashed(&self, pid: u64, reason: ExitReason, peer_addr: Option<SocketAddr>) {
610 let removed = self.remove(pid).unwrap_or(ConnectionRecord {
611 peer_addr,
612 registration: None,
613 });
614 self.fire_unregistered(pid, &removed);
615 tracing::warn!(
616 connection_pid = pid,
617 peer_addr = ?removed.peer_addr,
618 reason = ?reason,
619 "connection process crashed"
620 );
621 }
622
623 pub(super) fn finish(&self, pid: u64) {
624 if let Some(removed) = self.remove(pid) {
625 self.fire_unregistered(pid, &removed);
626 }
627 }
628
629 fn fire_unregistered(&self, pid: u64, record: &ConnectionRecord) {
634 if record.registration.is_some() {
635 if let Some(notifier) = self.notifier.as_ref() {
636 notifier.on_worker_unregistered(pid);
637 }
638 }
639 }
640
641 fn reap_crashed(&self, scheduler: &Scheduler) -> usize {
642 let pids = match self.records.lock() {
643 Ok(records) => records.keys().copied().collect::<Vec<_>>(),
644 Err(error) => {
645 tracing::warn!(%error, "connection registry unavailable during crash reap");
646 return 0;
647 }
648 };
649 let mut reaped = 0;
650 for pid in pids {
651 if scheduler.process_table().get(pid).is_none() {
652 let removed = self.remove(pid);
653 if let Some(record) = removed.as_ref() {
654 self.fire_unregistered(pid, record);
655 }
656 let peer_addr = removed.and_then(|record| record.peer_addr);
657 tracing::warn!(
666 connection_pid = pid,
667 ?peer_addr,
668 reason = "terminated externally (no exit reason recorded by supervisor)",
669 "connection process crashed"
670 );
671 reaped += 1;
672 }
673 }
674 reaped
675 }
676
677 fn contains(&self, pid: u64) -> bool {
678 self.records
679 .lock()
680 .is_ok_and(|records| records.contains_key(&pid))
681 }
682
683 fn active_connections(&self) -> Vec<ActiveConnection> {
684 self.records.lock().map_or_else(
685 |_| Vec::new(),
686 |records| {
687 records
688 .iter()
689 .map(|(&pid, record)| ActiveConnection {
690 pid,
691 peer_addr: record.peer_addr,
692 })
693 .collect()
694 },
695 )
696 }
697
698 fn push_control(&self, pid: u64, control: ConnectionControl) -> Result<(), ServerError> {
699 lock(&self.controls, "connection control queue")?
700 .push(QueuedConnectionControl { pid, control });
701 Ok(())
702 }
703
704 pub(super) fn pop_control(&self, pid: u64) -> Option<ConnectionControl> {
705 let mut controls = self.controls.lock().ok()?;
706 let index = controls.iter().position(|queued| queued.pid == pid)?;
707 Some(controls.remove(index).control)
708 }
709
710 fn remove_control(&self, pid: u64, control: &ConnectionControl) {
711 let Ok(mut controls) = self.controls.lock() else {
712 return;
713 };
714 let Some(index) = controls
715 .iter()
716 .position(|queued| queued.pid == pid && &queued.control == control)
717 else {
718 return;
719 };
720 controls.remove(index);
721 }
722
723 fn active_count(&self) -> usize {
724 self.records.lock().map_or(0, |records| records.len())
725 }
726
727 fn remove(&self, pid: u64) -> Option<ConnectionRecord> {
734 self.cancel_pushes_for_connection(pid);
735 self.records
736 .lock()
737 .ok()
738 .and_then(|mut records| records.remove(&pid))
739 }
740}
741
/// Reply slot of a push that has not been answered yet.
#[derive(Debug)]
struct PendingPush {
    /// Connection the push was sent to; used to drop the slot when that connection is removed.
    pid: u64,
    /// Sending end of the channel whose receiver is held by the `PushReplyAwaiter`.
    sender: Sender<Vec<u8>>,
}
752
/// Registry entry for one tracked connection.
#[derive(Debug, Clone)]
struct ConnectionRecord {
    /// Peer address captured when the connection was spawned, if it was available.
    peer_addr: Option<SocketAddr>,
    /// Registration stored through `set_registration`; `None` until that happens.
    registration: Option<WorkerRegistration>,
}
761
/// One entry of the shared control queue.
#[derive(Debug, Clone, PartialEq, Eq)]
struct QueuedConnectionControl {
    /// Connection process the control is addressed to.
    pid: u64,
    /// The queued control itself.
    control: ConnectionControl,
}
767
768fn lock<'a, T>(mutex: &'a Mutex<T>, context: &str) -> Result<MutexGuard<'a, T>, ServerError> {
769 mutex.lock().map_err(|error| ServerError::ListenerAccept {
770 message: format!("{context} unavailable: {error}"),
771 })
772}