// liminal/aion/worker.rs

1use std::collections::BTreeMap;
2use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
3use std::sync::{Arc, Mutex, MutexGuard, OnceLock, Weak};
4
5use super::channels::{ChannelName, dispatch_channel};
6use super::codec::dispatch_request_schema;
7use super::dispatch::{DispatchWorker, DispatchWorkerPool};
8use super::error::AionSurfaceError;
9use super::types::{ActivityRequest, WorkerCapacity};
10use crate::channel::{ChannelConfig, ChannelHandle, ChannelMode, SubscriptionHandle};
11use crate::conversation::{ConversationSupervisor, ParticipantPid};
12use crate::routing::{ConsumerId, ConsumerStateView};
13
14mod link;
15use link::WorkerLinkMonitor;
16
17#[derive(Clone)]
18pub struct WorkerContext {
19    inner: Arc<WorkerContextInner>,
20}
21
22impl WorkerContext {
23    /// Create a worker context backed by an embedded, lazily-created
24    /// [`ConversationSupervisor`] per dispatch channel.
25    ///
26    /// Note: the embedded supervisor runs on a single-thread beamr scheduler, so
27    /// all worker link monitors for a channel share one scheduler thread. This is
28    /// appropriate for the single-process embedded case; deployments that register
29    /// very large worker pools should front them with an externally-constructed,
30    /// multi-threaded supervisor rather than relying on this default.
31    #[must_use]
32    pub fn new() -> Self {
33        Self {
34            inner: Arc::new(WorkerContextInner::default()),
35        }
36    }
37
38    /// # Errors
39    /// Returns when channel resolution, channel creation, or subscription opening fails.
40    pub fn register_worker(
41        &self,
42        namespace: &str,
43        task_queue: &str,
44        capacity: WorkerCapacity,
45    ) -> Result<WorkerRegistration, AionSurfaceError> {
46        let channel_name = dispatch_channel(namespace, task_queue)?;
47        let sequence = self.next_sequence();
48        let participant = link::spawn_worker_process(self, &channel_name)?;
49        self.register_worker_on_channel(
50            &channel_name,
51            format!("worker-{sequence}"),
52            participant,
53            capacity,
54            Some(participant),
55        )
56    }
57
58    /// # Errors
59    /// Returns when the dispatch channel cannot be resolved or subscribed.
60    pub fn register_worker_with_participant(
61        &self,
62        namespace: &str,
63        task_queue: &str,
64        worker_id: impl Into<String>,
65        participant: ParticipantPid,
66        capacity: WorkerCapacity,
67    ) -> Result<WorkerRegistration, AionSurfaceError> {
68        let channel_name = dispatch_channel(namespace, task_queue)?;
69        self.register_worker_on_channel(
70            &channel_name,
71            worker_id.into(),
72            participant,
73            capacity,
74            None,
75        )
76    }
77
78    fn register_worker_on_channel(
79        &self,
80        channel_name: &ChannelName,
81        worker_id: String,
82        participant: ParticipantPid,
83        capacity: WorkerCapacity,
84        owned_participant: Option<ParticipantPid>,
85    ) -> Result<WorkerRegistration, AionSurfaceError> {
86        let session = self.session_for(channel_name)?;
87        let subscription = session
88            .handle
89            .subscribe()
90            .map_err(|error| lifecycle_failed(channel_name, error))?;
91        let subscription = Mutex::new(Some(subscription));
92        let entry = Arc::new(WorkerEntry {
93            channel_name: channel_name.clone(),
94            worker_id,
95            participant,
96            capacity,
97            subscription,
98            current_in_flight: AtomicU32::new(0),
99            active: AtomicBool::new(true),
100        });
101        // Arm crash detection BEFORE the entry is inserted into the pool, so a
102        // worker is never dispatch-eligible without a live link monitor. If the
103        // participant dies in the window before insert, the listener deactivates
104        // the entry and the next `retain_active` prunes it — the pool never
105        // exposes a worker whose crash would go unnoticed.
106        let monitor = link::monitor_worker_process(
107            self.clone(),
108            channel_name,
109            participant,
110            Arc::downgrade(&entry),
111            owned_participant,
112        )?;
113        self.insert_entry(channel_name, &entry)?;
114        Ok(WorkerRegistration::new(self.clone(), entry, monitor))
115    }
116
117    /// # Errors
118    /// Returns when the worker pool cannot be read.
119    pub fn workers_for_channel(
120        &self,
121        channel_name: &ChannelName,
122        request: &ActivityRequest,
123    ) -> Result<Vec<DispatchWorker>, AionSurfaceError> {
124        <Self as DispatchWorkerPool>::workers_for(self, channel_name, request)
125    }
126
127    fn next_sequence(&self) -> u64 {
128        self.inner
129            .next_worker
130            .fetch_add(1, Ordering::Relaxed)
131            .saturating_add(1)
132    }
133
134    fn session_for(&self, channel_name: &ChannelName) -> Result<ChannelSession, AionSurfaceError> {
135        if let Some(session) = self.lookup_session(channel_name)? {
136            return Ok(session);
137        }
138
139        let schema =
140            dispatch_request_schema().map_err(|error| lifecycle_failed(channel_name, error))?;
141        let handle = ChannelHandle::new(ChannelConfig::new(
142            channel_name.as_str().to_owned(),
143            schema,
144            ChannelMode::Ephemeral,
145        ));
146        let session = ChannelSession { handle };
147        self.insert_or_reuse_session(channel_name, session)
148    }
149
150    fn lookup_session(
151        &self,
152        channel_name: &ChannelName,
153    ) -> Result<Option<ChannelSession>, AionSurfaceError> {
154        let session = {
155            let channels = self.lock_channels(channel_name)?;
156            channels
157                .get(channel_name.as_str())
158                .map(|state| state.session.clone())
159        };
160        Ok(session)
161    }
162
163    fn insert_or_reuse_session(
164        &self,
165        channel_name: &ChannelName,
166        session: ChannelSession,
167    ) -> Result<ChannelSession, AionSurfaceError> {
168        let mut channels = self.lock_channels(channel_name)?;
169        let state = channels
170            .entry(channel_name.as_str().to_owned())
171            .or_insert_with(|| ChannelState::new(session));
172        let stored = state.session.clone();
173        drop(channels);
174        Ok(stored)
175    }
176
177    fn insert_entry(
178        &self,
179        channel_name: &ChannelName,
180        entry: &Arc<WorkerEntry>,
181    ) -> Result<(), AionSurfaceError> {
182        let mut channels = self.lock_channels(channel_name)?;
183        let state = channels
184            .get_mut(channel_name.as_str())
185            .ok_or_else(|| lifecycle_failed(channel_name, "dispatch channel missing"))?;
186        state.entries.push(Arc::downgrade(entry));
187        drop(channels);
188        Ok(())
189    }
190
191    fn remove_inactive(&self, channel_name: &ChannelName) -> Result<(), AionSurfaceError> {
192        let mut channels = self.lock_channels(channel_name)?;
193        if let Some(state) = channels.get_mut(channel_name.as_str()) {
194            state.retain_active();
195        }
196        drop(channels);
197        Ok(())
198    }
199
200    fn snapshot(
201        &self,
202        channel_name: &ChannelName,
203    ) -> Result<Vec<DispatchWorker>, AionSurfaceError> {
204        let mut channels = self.lock_channels(channel_name)?;
205        let Some(state) = channels.get_mut(channel_name.as_str()) else {
206            return Ok(Vec::new());
207        };
208        state.retain_active();
209        let workers = state
210            .entries
211            .iter()
212            .filter_map(Weak::upgrade)
213            .map(|entry| entry.to_dispatch_worker())
214            .collect();
215        drop(channels);
216        Ok(workers)
217    }
218
219    fn lock_channels(
220        &self,
221        channel_name: &ChannelName,
222    ) -> Result<MutexGuard<'_, BTreeMap<String, ChannelState>>, AionSurfaceError> {
223        self.inner
224            .channels
225            .lock()
226            .map_err(|error| lifecycle_failed(channel_name, error))
227    }
228
229    fn supervisor_for(
230        &self,
231        channel_name: &ChannelName,
232    ) -> Result<ConversationSupervisor, AionSurfaceError> {
233        if let Some(supervisor) = self.inner.supervisor.get() {
234            return Ok(supervisor.clone());
235        }
236        let supervisor =
237            ConversationSupervisor::new().map_err(|error| lifecycle_failed(channel_name, error))?;
238        if self.inner.supervisor.set(supervisor.clone()).is_ok() {
239            Ok(supervisor)
240        } else {
241            self.inner.supervisor.get().cloned().ok_or_else(|| {
242                lifecycle_failed(channel_name, "worker process supervisor unavailable")
243            })
244        }
245    }
246}
247
248impl Default for WorkerContext {
249    fn default() -> Self {
250        Self::new()
251    }
252}
253
254impl std::fmt::Debug for WorkerContext {
255    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
256        formatter
257            .debug_struct("WorkerContext")
258            .finish_non_exhaustive()
259    }
260}
261
262impl DispatchWorkerPool for WorkerContext {
263    fn workers_for(
264        &self,
265        channel_name: &ChannelName,
266        request: &ActivityRequest,
267    ) -> Result<Vec<DispatchWorker>, AionSurfaceError> {
268        let _ = request;
269        self.snapshot(channel_name)
270    }
271}
272
273pub type WorkerPool = WorkerContext;
274
275#[derive(Debug)]
276pub struct WorkerRegistration {
277    context: WorkerContext,
278    entry: Arc<WorkerEntry>,
279    link_monitor: Option<WorkerLinkMonitor>,
280}
281
282impl WorkerRegistration {
283    const fn new(
284        context: WorkerContext,
285        entry: Arc<WorkerEntry>,
286        link_monitor: WorkerLinkMonitor,
287    ) -> Self {
288        Self {
289            context,
290            entry,
291            link_monitor: Some(link_monitor),
292        }
293    }
294
295    #[must_use]
296    pub fn worker_id(&self) -> &str {
297        self.entry.worker_id.as_str()
298    }
299
300    #[must_use]
301    pub fn channel_name(&self) -> &ChannelName {
302        &self.entry.channel_name
303    }
304
305    #[must_use]
306    pub fn participant(&self) -> ParticipantPid {
307        self.entry.participant
308    }
309
310    #[must_use]
311    pub fn capacity(&self) -> &WorkerCapacity {
312        &self.entry.capacity
313    }
314
315    #[must_use]
316    pub fn current_in_flight(&self) -> u32 {
317        self.entry.current_in_flight.load(Ordering::Acquire)
318    }
319
320    pub fn set_in_flight(&self, count: u32) {
321        self.entry.current_in_flight.store(count, Ordering::Release);
322    }
323
324    /// # Errors
325    /// Returns when the subscription inbox cannot be read.
326    pub fn try_next(&self) -> Result<Option<crate::envelope::Envelope>, AionSurfaceError> {
327        let subscription = self
328            .entry
329            .subscription
330            .lock()
331            .map_err(|error| lifecycle_failed(&self.entry.channel_name, error))?;
332        subscription.as_ref().map_or(Ok(None), |subscription| {
333            subscription
334                .try_next()
335                .map_err(|error| lifecycle_failed(&self.entry.channel_name, error))
336        })
337    }
338
339    /// # Errors
340    /// Returns when the worker pool cannot be updated.
341    pub fn unregister(mut self) -> Result<(), AionSurfaceError> {
342        self.deactivate()?;
343        self.entry.drop_subscription();
344        if let Some(mut monitor) = self.link_monitor.take() {
345            monitor.shutdown();
346        }
347        Ok(())
348    }
349
350    fn deactivate(&self) -> Result<(), AionSurfaceError> {
351        self.entry.active.store(false, Ordering::Release);
352        self.context.remove_inactive(&self.entry.channel_name)
353    }
354}
355
356impl Drop for WorkerRegistration {
357    fn drop(&mut self) {
358        self.entry.active.store(false, Ordering::Release);
359        let _ = self.context.remove_inactive(&self.entry.channel_name);
360        self.entry.drop_subscription();
361        if let Some(mut monitor) = self.link_monitor.take() {
362            monitor.shutdown();
363        }
364    }
365}
366
367#[derive(Debug, Default)]
368struct WorkerContextInner {
369    channels: Mutex<BTreeMap<String, ChannelState>>,
370    next_worker: AtomicU64,
371    supervisor: OnceLock<ConversationSupervisor>,
372}
373
374#[derive(Clone, Debug)]
375struct ChannelSession {
376    handle: ChannelHandle,
377}
378
379#[derive(Debug)]
380struct ChannelState {
381    session: ChannelSession,
382    entries: Vec<Weak<WorkerEntry>>,
383}
384
385impl ChannelState {
386    const fn new(session: ChannelSession) -> Self {
387        Self {
388            session,
389            entries: Vec::new(),
390        }
391    }
392
393    fn retain_active(&mut self) {
394        self.entries.retain(|entry| {
395            entry
396                .upgrade()
397                .is_some_and(|entry| entry.active.load(Ordering::Acquire))
398        });
399    }
400}
401
402#[derive(Debug)]
403struct WorkerEntry {
404    channel_name: ChannelName,
405    worker_id: String,
406    participant: ParticipantPid,
407    capacity: WorkerCapacity,
408    subscription: Mutex<Option<SubscriptionHandle>>,
409    current_in_flight: AtomicU32,
410    active: AtomicBool,
411}
412
413impl WorkerEntry {
414    pub(super) fn drop_subscription(&self) {
415        if let Ok(mut subscription) = self.subscription.lock() {
416            subscription.take();
417        }
418    }
419
420    fn to_dispatch_worker(&self) -> DispatchWorker {
421        let max_in_flight =
422            u32::try_from(self.capacity.max_concurrent).map_or(u32::MAX, |value| value);
423        let consumer_state = ConsumerStateView::new(
424            ConsumerId::new(self.worker_id.clone()),
425            self.current_in_flight.load(Ordering::Acquire),
426            max_in_flight,
427            0,
428            self.capacity.activity_types.clone(),
429        );
430        DispatchWorker::with_consumer_state(
431            self.worker_id.clone(),
432            self.participant,
433            consumer_state,
434        )
435    }
436}
437
438#[must_use]
439pub fn default_worker_context() -> &'static WorkerContext {
440    static DEFAULT_CONTEXT: OnceLock<WorkerContext> = OnceLock::new();
441    DEFAULT_CONTEXT.get_or_init(WorkerContext::new)
442}
443
444/// # Errors
445/// Returns when the worker cannot subscribe to its dispatch channel.
446pub fn register_worker(
447    namespace: &str,
448    task_queue: &str,
449    capacity: WorkerCapacity,
450) -> Result<WorkerRegistration, AionSurfaceError> {
451    default_worker_context().register_worker(namespace, task_queue, capacity)
452}
453
454fn lifecycle_failed(
455    channel_name: &ChannelName,
456    message: impl std::fmt::Display,
457) -> AionSurfaceError {
458    AionSurfaceError::ChannelLifecycleError {
459        channel_name: String::from(channel_name.clone()),
460        message: message.to_string(),
461    }
462}
463
464#[cfg(test)]
465mod tests;