1use std::collections::BTreeMap;
2use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
3use std::sync::{Arc, Mutex, MutexGuard, OnceLock, Weak};
4
5use super::channels::{ChannelName, dispatch_channel};
6use super::codec::dispatch_request_schema;
7use super::dispatch::{DispatchWorker, DispatchWorkerPool};
8use super::error::AionSurfaceError;
9use super::types::{ActivityRequest, WorkerCapacity};
10use crate::channel::{ChannelConfig, ChannelHandle, ChannelMode, SubscriptionHandle};
11use crate::conversation::{ConversationSupervisor, ParticipantPid};
12use crate::routing::{ConsumerId, ConsumerStateView};
13
14mod link;
15use link::WorkerLinkMonitor;
16
/// Cloneable handle to the worker registry.
///
/// All clones share one `WorkerContextInner` through an [`Arc`], so a worker
/// registered through one clone is visible through every other clone.
#[derive(Clone)]
pub struct WorkerContext {
    /// Shared registry state (channel map, worker sequence counter, supervisor slot).
    inner: Arc<WorkerContextInner>,
}
21
22impl WorkerContext {
23 #[must_use]
32 pub fn new() -> Self {
33 Self {
34 inner: Arc::new(WorkerContextInner::default()),
35 }
36 }
37
38 pub fn register_worker(
41 &self,
42 namespace: &str,
43 task_queue: &str,
44 capacity: WorkerCapacity,
45 ) -> Result<WorkerRegistration, AionSurfaceError> {
46 let channel_name = dispatch_channel(namespace, task_queue)?;
47 let sequence = self.next_sequence();
48 let participant = link::spawn_worker_process(self, &channel_name)?;
49 self.register_worker_on_channel(
50 &channel_name,
51 format!("worker-{sequence}"),
52 participant,
53 capacity,
54 Some(participant),
55 )
56 }
57
58 pub fn register_worker_with_participant(
61 &self,
62 namespace: &str,
63 task_queue: &str,
64 worker_id: impl Into<String>,
65 participant: ParticipantPid,
66 capacity: WorkerCapacity,
67 ) -> Result<WorkerRegistration, AionSurfaceError> {
68 let channel_name = dispatch_channel(namespace, task_queue)?;
69 self.register_worker_on_channel(
70 &channel_name,
71 worker_id.into(),
72 participant,
73 capacity,
74 None,
75 )
76 }
77
78 fn register_worker_on_channel(
79 &self,
80 channel_name: &ChannelName,
81 worker_id: String,
82 participant: ParticipantPid,
83 capacity: WorkerCapacity,
84 owned_participant: Option<ParticipantPid>,
85 ) -> Result<WorkerRegistration, AionSurfaceError> {
86 let session = self.session_for(channel_name)?;
87 let subscription = session
88 .handle
89 .subscribe()
90 .map_err(|error| lifecycle_failed(channel_name, error))?;
91 let subscription = Mutex::new(Some(subscription));
92 let entry = Arc::new(WorkerEntry {
93 channel_name: channel_name.clone(),
94 worker_id,
95 participant,
96 capacity,
97 subscription,
98 current_in_flight: AtomicU32::new(0),
99 active: AtomicBool::new(true),
100 });
101 let monitor = link::monitor_worker_process(
107 self.clone(),
108 channel_name,
109 participant,
110 Arc::downgrade(&entry),
111 owned_participant,
112 )?;
113 self.insert_entry(channel_name, &entry)?;
114 Ok(WorkerRegistration::new(self.clone(), entry, monitor))
115 }
116
117 pub fn workers_for_channel(
120 &self,
121 channel_name: &ChannelName,
122 request: &ActivityRequest,
123 ) -> Result<Vec<DispatchWorker>, AionSurfaceError> {
124 <Self as DispatchWorkerPool>::workers_for(self, channel_name, request)
125 }
126
127 fn next_sequence(&self) -> u64 {
128 self.inner
129 .next_worker
130 .fetch_add(1, Ordering::Relaxed)
131 .saturating_add(1)
132 }
133
134 fn session_for(&self, channel_name: &ChannelName) -> Result<ChannelSession, AionSurfaceError> {
135 if let Some(session) = self.lookup_session(channel_name)? {
136 return Ok(session);
137 }
138
139 let schema =
140 dispatch_request_schema().map_err(|error| lifecycle_failed(channel_name, error))?;
141 let handle = ChannelHandle::new(ChannelConfig::new(
142 channel_name.as_str().to_owned(),
143 schema,
144 ChannelMode::Ephemeral,
145 ));
146 let session = ChannelSession { handle };
147 self.insert_or_reuse_session(channel_name, session)
148 }
149
150 fn lookup_session(
151 &self,
152 channel_name: &ChannelName,
153 ) -> Result<Option<ChannelSession>, AionSurfaceError> {
154 let session = {
155 let channels = self.lock_channels(channel_name)?;
156 channels
157 .get(channel_name.as_str())
158 .map(|state| state.session.clone())
159 };
160 Ok(session)
161 }
162
163 fn insert_or_reuse_session(
164 &self,
165 channel_name: &ChannelName,
166 session: ChannelSession,
167 ) -> Result<ChannelSession, AionSurfaceError> {
168 let mut channels = self.lock_channels(channel_name)?;
169 let state = channels
170 .entry(channel_name.as_str().to_owned())
171 .or_insert_with(|| ChannelState::new(session));
172 let stored = state.session.clone();
173 drop(channels);
174 Ok(stored)
175 }
176
177 fn insert_entry(
178 &self,
179 channel_name: &ChannelName,
180 entry: &Arc<WorkerEntry>,
181 ) -> Result<(), AionSurfaceError> {
182 let mut channels = self.lock_channels(channel_name)?;
183 let state = channels
184 .get_mut(channel_name.as_str())
185 .ok_or_else(|| lifecycle_failed(channel_name, "dispatch channel missing"))?;
186 state.entries.push(Arc::downgrade(entry));
187 drop(channels);
188 Ok(())
189 }
190
191 fn remove_inactive(&self, channel_name: &ChannelName) -> Result<(), AionSurfaceError> {
192 let mut channels = self.lock_channels(channel_name)?;
193 if let Some(state) = channels.get_mut(channel_name.as_str()) {
194 state.retain_active();
195 }
196 drop(channels);
197 Ok(())
198 }
199
200 fn snapshot(
201 &self,
202 channel_name: &ChannelName,
203 ) -> Result<Vec<DispatchWorker>, AionSurfaceError> {
204 let mut channels = self.lock_channels(channel_name)?;
205 let Some(state) = channels.get_mut(channel_name.as_str()) else {
206 return Ok(Vec::new());
207 };
208 state.retain_active();
209 let workers = state
210 .entries
211 .iter()
212 .filter_map(Weak::upgrade)
213 .map(|entry| entry.to_dispatch_worker())
214 .collect();
215 drop(channels);
216 Ok(workers)
217 }
218
219 fn lock_channels(
220 &self,
221 channel_name: &ChannelName,
222 ) -> Result<MutexGuard<'_, BTreeMap<String, ChannelState>>, AionSurfaceError> {
223 self.inner
224 .channels
225 .lock()
226 .map_err(|error| lifecycle_failed(channel_name, error))
227 }
228
229 fn supervisor_for(
230 &self,
231 channel_name: &ChannelName,
232 ) -> Result<ConversationSupervisor, AionSurfaceError> {
233 if let Some(supervisor) = self.inner.supervisor.get() {
234 return Ok(supervisor.clone());
235 }
236 let supervisor =
237 ConversationSupervisor::new().map_err(|error| lifecycle_failed(channel_name, error))?;
238 if self.inner.supervisor.set(supervisor.clone()).is_ok() {
239 Ok(supervisor)
240 } else {
241 self.inner.supervisor.get().cloned().ok_or_else(|| {
242 lifecycle_failed(channel_name, "worker process supervisor unavailable")
243 })
244 }
245 }
246}
247
impl Default for WorkerContext {
    /// Equivalent to [`WorkerContext::new`].
    fn default() -> Self {
        Self::new()
    }
}
253
254impl std::fmt::Debug for WorkerContext {
255 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
256 formatter
257 .debug_struct("WorkerContext")
258 .finish_non_exhaustive()
259 }
260}
261
262impl DispatchWorkerPool for WorkerContext {
263 fn workers_for(
264 &self,
265 channel_name: &ChannelName,
266 request: &ActivityRequest,
267 ) -> Result<Vec<DispatchWorker>, AionSurfaceError> {
268 let _ = request;
269 self.snapshot(channel_name)
270 }
271}
272
/// Alias for [`WorkerContext`]; both names refer to the same type.
pub type WorkerPool = WorkerContext;
274
/// Handle for one registered worker.
///
/// Holds the strong reference to the worker's entry; the owning context only
/// keeps a weak one. Dropping the handle marks the worker inactive, prunes it
/// from its channel, releases its subscription and shuts the link monitor down.
#[derive(Debug)]
pub struct WorkerRegistration {
    /// Context the worker was registered with; used to prune the entry on teardown.
    context: WorkerContext,
    /// Shared per-worker state.
    entry: Arc<WorkerEntry>,
    /// Link monitor for the worker; `None` once it has been shut down.
    link_monitor: Option<WorkerLinkMonitor>,
}
281
282impl WorkerRegistration {
283 const fn new(
284 context: WorkerContext,
285 entry: Arc<WorkerEntry>,
286 link_monitor: WorkerLinkMonitor,
287 ) -> Self {
288 Self {
289 context,
290 entry,
291 link_monitor: Some(link_monitor),
292 }
293 }
294
295 #[must_use]
296 pub fn worker_id(&self) -> &str {
297 self.entry.worker_id.as_str()
298 }
299
300 #[must_use]
301 pub fn channel_name(&self) -> &ChannelName {
302 &self.entry.channel_name
303 }
304
305 #[must_use]
306 pub fn participant(&self) -> ParticipantPid {
307 self.entry.participant
308 }
309
310 #[must_use]
311 pub fn capacity(&self) -> &WorkerCapacity {
312 &self.entry.capacity
313 }
314
315 #[must_use]
316 pub fn current_in_flight(&self) -> u32 {
317 self.entry.current_in_flight.load(Ordering::Acquire)
318 }
319
320 pub fn set_in_flight(&self, count: u32) {
321 self.entry.current_in_flight.store(count, Ordering::Release);
322 }
323
324 pub fn try_next(&self) -> Result<Option<crate::envelope::Envelope>, AionSurfaceError> {
327 let subscription = self
328 .entry
329 .subscription
330 .lock()
331 .map_err(|error| lifecycle_failed(&self.entry.channel_name, error))?;
332 subscription.as_ref().map_or(Ok(None), |subscription| {
333 subscription
334 .try_next()
335 .map_err(|error| lifecycle_failed(&self.entry.channel_name, error))
336 })
337 }
338
339 pub fn unregister(mut self) -> Result<(), AionSurfaceError> {
342 self.deactivate()?;
343 self.entry.drop_subscription();
344 if let Some(mut monitor) = self.link_monitor.take() {
345 monitor.shutdown();
346 }
347 Ok(())
348 }
349
350 fn deactivate(&self) -> Result<(), AionSurfaceError> {
351 self.entry.active.store(false, Ordering::Release);
352 self.context.remove_inactive(&self.entry.channel_name)
353 }
354}
355
356impl Drop for WorkerRegistration {
357 fn drop(&mut self) {
358 self.entry.active.store(false, Ordering::Release);
359 let _ = self.context.remove_inactive(&self.entry.channel_name);
360 self.entry.drop_subscription();
361 if let Some(mut monitor) = self.link_monitor.take() {
362 monitor.shutdown();
363 }
364 }
365}
366
/// State shared by every clone of a `WorkerContext`.
#[derive(Debug, Default)]
struct WorkerContextInner {
    /// Per-channel state, keyed by the channel name string.
    channels: Mutex<BTreeMap<String, ChannelState>>,
    /// Counter used to derive generated worker ids.
    next_worker: AtomicU64,
    /// Supervisor slot, filled at most once on first use.
    supervisor: OnceLock<ConversationSupervisor>,
}
373
/// Cloneable wrapper around the handle of one dispatch channel.
#[derive(Clone, Debug)]
struct ChannelSession {
    /// Handle used to subscribe workers to the channel.
    handle: ChannelHandle,
}
378
/// Registry record for one dispatch channel.
#[derive(Debug)]
struct ChannelState {
    /// Session shared by every worker on the channel.
    session: ChannelSession,
    /// Weak references to the registered workers; the strong references are
    /// held by the corresponding `WorkerRegistration` values.
    entries: Vec<Weak<WorkerEntry>>,
}
384
385impl ChannelState {
386 const fn new(session: ChannelSession) -> Self {
387 Self {
388 session,
389 entries: Vec::new(),
390 }
391 }
392
393 fn retain_active(&mut self) {
394 self.entries.retain(|entry| {
395 entry
396 .upgrade()
397 .is_some_and(|entry| entry.active.load(Ordering::Acquire))
398 });
399 }
400}
401
/// Per-worker state shared between a `WorkerRegistration` and the channel map.
#[derive(Debug)]
struct WorkerEntry {
    /// Dispatch channel the worker is registered on.
    channel_name: ChannelName,
    /// Identifier used for both the dispatch worker and its consumer id.
    worker_id: String,
    /// Participant pid associated with the worker.
    participant: ParticipantPid,
    /// Capacity declared at registration time.
    capacity: WorkerCapacity,
    /// Channel subscription; `None` once it has been released.
    subscription: Mutex<Option<SubscriptionHandle>>,
    /// In-flight count reported to dispatch; starts at zero.
    current_in_flight: AtomicU32,
    /// Cleared when the worker is deactivated; inactive entries are pruned.
    active: AtomicBool,
}
412
413impl WorkerEntry {
414 pub(super) fn drop_subscription(&self) {
415 if let Ok(mut subscription) = self.subscription.lock() {
416 subscription.take();
417 }
418 }
419
420 fn to_dispatch_worker(&self) -> DispatchWorker {
421 let max_in_flight =
422 u32::try_from(self.capacity.max_concurrent).map_or(u32::MAX, |value| value);
423 let consumer_state = ConsumerStateView::new(
424 ConsumerId::new(self.worker_id.clone()),
425 self.current_in_flight.load(Ordering::Acquire),
426 max_in_flight,
427 0,
428 self.capacity.activity_types.clone(),
429 );
430 DispatchWorker::with_consumer_state(
431 self.worker_id.clone(),
432 self.participant,
433 consumer_state,
434 )
435 }
436}
437
438#[must_use]
439pub fn default_worker_context() -> &'static WorkerContext {
440 static DEFAULT_CONTEXT: OnceLock<WorkerContext> = OnceLock::new();
441 DEFAULT_CONTEXT.get_or_init(WorkerContext::new)
442}
443
444pub fn register_worker(
447 namespace: &str,
448 task_queue: &str,
449 capacity: WorkerCapacity,
450) -> Result<WorkerRegistration, AionSurfaceError> {
451 default_worker_context().register_worker(namespace, task_queue, capacity)
452}
453
454fn lifecycle_failed(
455 channel_name: &ChannelName,
456 message: impl std::fmt::Display,
457) -> AionSurfaceError {
458 AionSurfaceError::ChannelLifecycleError {
459 channel_name: String::from(channel_name.clone()),
460 message: message.to_string(),
461 }
462}
463
464#[cfg(test)]
465mod tests;