Skip to main content

murk_engine/
egress.rs

1//! Egress worker pool for RealtimeAsync observation extraction.
2//!
3//! Each worker receives `ObsTask` requests via a crossbeam channel,
4//! pins to the latest snapshot epoch, executes the observation plan,
5//! unpins, and sends the result back via a bounded(1) reply channel.
6//!
7//! Workers allocate their own output buffers and return them via the
8//! reply channel. The caller copies into its buffer. This maintains
9//! `#![forbid(unsafe_code)]` at the cost of a ~10μs memcpy (negligible
10//! vs the 16.7ms tick budget).
11
12use std::sync::Arc;
13
14use crossbeam_channel::Receiver;
15
16use crate::epoch::{EpochCounter, WorkerEpoch};
17use crate::ring::SnapshotRing;
18
19use murk_core::Coord;
20use murk_obs::{ObsMetadata, ObsPlan};
21use murk_space::Space;
22
23/// A task dispatched to an egress worker.
24pub(crate) enum ObsTask {
25    /// Simple observation (all Fixed regions).
26    Simple {
27        plan: Arc<ObsPlan>,
28        output_len: usize,
29        mask_len: usize,
30        reply: crossbeam_channel::Sender<ObsResult>,
31    },
32    /// Agent-relative observation (has AgentRelative regions).
33    Agents {
34        plan: Arc<ObsPlan>,
35        space: Arc<dyn Space>,
36        agent_centers: Vec<Coord>,
37        output_len: usize,
38        mask_len: usize,
39        reply: crossbeam_channel::Sender<ObsResult>,
40    },
41}
42
43/// Result of an egress worker's observation execution.
44#[derive(Debug)]
45pub(crate) enum ObsResult {
46    /// Simple plan result: one metadata + output + mask buffers.
47    Simple {
48        metadata: ObsMetadata,
49        output: Vec<f32>,
50        mask: Vec<u8>,
51    },
52    /// Agent plan result: one metadata per agent + output + mask buffers.
53    Agents {
54        metadata: Vec<ObsMetadata>,
55        output: Vec<f32>,
56        mask: Vec<u8>,
57    },
58    /// Plan execution failed.
59    Error(murk_core::error::ObsError),
60}
61
62/// Main loop for an egress worker thread, using an index into a shared
63/// `Arc<[WorkerEpoch]>` array. This ensures the tick thread's stall
64/// detector and the worker see the same `WorkerEpoch` instance.
65pub(crate) fn worker_loop_indexed(
66    task_rx: Receiver<ObsTask>,
67    ring: Arc<SnapshotRing>,
68    epoch_counter: Arc<EpochCounter>,
69    worker_epochs: Arc<[WorkerEpoch]>,
70    worker_index: usize,
71) {
72    let worker_epoch = &worker_epochs[worker_index];
73    worker_loop_inner(task_rx, ring, epoch_counter, worker_epoch);
74}
75
/// Test-only entry point for an egress worker that owns its `WorkerEpoch`
/// through an `Arc`.
///
/// Returns once the task channel is closed (all senders dropped). Per task
/// the shared loop does: recv task → pin epoch → execute plan → unpin → reply.
#[cfg(test)]
pub(crate) fn worker_loop(
    task_rx: Receiver<ObsTask>,
    ring: Arc<SnapshotRing>,
    epoch_counter: Arc<EpochCounter>,
    worker_epoch: Arc<WorkerEpoch>,
) {
    // Reborrow the Arc's contents; the Arc itself stays alive until return.
    let epoch_slot: &WorkerEpoch = &*worker_epoch;
    worker_loop_inner(task_rx, ring, epoch_counter, epoch_slot);
}
89
90fn worker_loop_inner(
91    task_rx: Receiver<ObsTask>,
92    ring: Arc<SnapshotRing>,
93    epoch_counter: Arc<EpochCounter>,
94    worker_epoch: &WorkerEpoch,
95) {
96    while let Ok(task) = task_rx.recv() {
97        // Check for cooperative cancellation before starting.
98        if worker_epoch.is_cancelled() {
99            worker_epoch.clear_cancel();
100            send_error(
101                &task,
102                murk_core::error::ObsError::ExecutionFailed {
103                    reason: "worker cancelled before execution".into(),
104                },
105            );
106            continue;
107        }
108
109        // Get latest snapshot.
110        let snapshot = match ring.latest() {
111            Some(snap) => snap,
112            None => {
113                send_error(
114                    &task,
115                    murk_core::error::ObsError::ExecutionFailed {
116                        reason: "no snapshot available".into(),
117                    },
118                );
119                continue;
120            }
121        };
122
123        // Pin to the current epoch.
124        let epoch = epoch_counter.current();
125        worker_epoch.pin(epoch);
126
127        // Execute the plan. Always unpin, even on error.
128        let result = execute_task(&task, &*snapshot, epoch_counter.current());
129        worker_epoch.unpin();
130
131        // Send result back.
132        match task {
133            ObsTask::Simple { reply, .. } | ObsTask::Agents { reply, .. } => {
134                let _ = reply.send(result);
135            }
136        }
137    }
138    // Channel closed — worker exits cleanly.
139}
140
141/// Execute a task against a snapshot, returning the result.
142fn execute_task(
143    task: &ObsTask,
144    snapshot: &dyn murk_core::traits::SnapshotAccess,
145    current_tick_val: u64,
146) -> ObsResult {
147    let engine_tick = Some(murk_core::id::TickId(current_tick_val));
148
149    match task {
150        ObsTask::Simple {
151            plan,
152            output_len,
153            mask_len,
154            ..
155        } => {
156            let mut output = vec![0.0f32; *output_len];
157            let mut mask = vec![0u8; *mask_len];
158            match plan.execute(snapshot, engine_tick, &mut output, &mut mask) {
159                Ok(metadata) => ObsResult::Simple {
160                    metadata,
161                    output,
162                    mask,
163                },
164                Err(e) => ObsResult::Error(e),
165            }
166        }
167        ObsTask::Agents {
168            plan,
169            space,
170            agent_centers,
171            output_len,
172            mask_len,
173            ..
174        } => {
175            let n_agents = agent_centers.len();
176            let mut output = vec![0.0f32; output_len * n_agents];
177            let mut mask = vec![0u8; mask_len * n_agents];
178            match plan.execute_agents(
179                snapshot,
180                space.as_ref(),
181                agent_centers,
182                engine_tick,
183                &mut output,
184                &mut mask,
185            ) {
186                Ok(metadata) => ObsResult::Agents {
187                    metadata,
188                    output,
189                    mask,
190                },
191                Err(e) => ObsResult::Error(e),
192            }
193        }
194    }
195}
196
197/// Send an error result back through the task's reply channel.
198fn send_error(task: &ObsTask, err: murk_core::error::ObsError) {
199    match task {
200        ObsTask::Simple { reply, .. } => {
201            let _ = reply.send(ObsResult::Error(err));
202        }
203        ObsTask::Agents { reply, .. } => {
204            let _ = reply.send(ObsResult::Error(err));
205        }
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use murk_arena::config::ArenaConfig;
    use murk_arena::pingpong::PingPongArena;
    use murk_arena::static_arena::StaticArena;
    use murk_core::id::{FieldId, ParameterVersion, TickId};
    use murk_core::traits::FieldWriter as _;
    use murk_core::{BoundaryBehavior, FieldDef, FieldMutability, FieldType};
    use murk_obs::spec::ObsRegion;
    use murk_obs::{ObsEntry, ObsSpec};
    use murk_space::{EdgeBehavior, Line1D};
    use std::thread;

    /// Builds an owned snapshot published at `tick` whose single scalar
    /// field (`FieldId(0)`) is filled with `value` across `cells` cells.
    fn make_test_snapshot(tick: u64, value: f32, cells: u32) -> murk_arena::OwnedSnapshot {
        let config = ArenaConfig::new(cells);
        let field_defs = vec![(
            FieldId(0),
            FieldDef {
                name: "energy".into(),
                field_type: FieldType::Scalar,
                mutability: FieldMutability::PerTick,
                units: None,
                bounds: None,
                boundary_behavior: BoundaryBehavior::Clamp,
            },
        )];
        let static_arena = StaticArena::new(&[]).into_shared();
        let mut arena = PingPongArena::new(config, field_defs, static_arena).unwrap();
        {
            let mut guard = arena.begin_tick().unwrap();
            let data = guard.writer.write(FieldId(0)).unwrap();
            data.fill(value);
        }
        arena.publish(TickId(tick), ParameterVersion(0));
        arena.owned_snapshot()
    }

    /// Compiles a plan that observes `FieldId(0)` over every cell of a
    /// `Line1D` space with `cells` cells.
    ///
    /// Returns `(plan, output_len, mask_len)` as reported by the compiler.
    fn compile_full_line_plan(cells: u32) -> (Arc<ObsPlan>, usize, usize) {
        let space = Line1D::new(cells, EdgeBehavior::Absorb).unwrap();
        let spec = ObsSpec {
            entries: vec![ObsEntry {
                field_id: FieldId(0),
                region: ObsRegion::Fixed(murk_space::RegionSpec::All),
                pool: None,
                transform: murk_obs::spec::ObsTransform::Identity,
                dtype: murk_obs::spec::ObsDtype::F32,
            }],
        };
        let compiled = ObsPlan::compile(&spec, &space).unwrap();
        (
            Arc::new(compiled.plan),
            compiled.output_len,
            compiled.mask_len,
        )
    }

    /// Spawns `worker_loop` on a new thread with a fresh `WorkerEpoch`.
    ///
    /// Returns the test's handle to that epoch (to inspect pin state) and
    /// the join handle of the worker thread.
    fn spawn_worker(
        task_rx: Receiver<ObsTask>,
        ring: Arc<SnapshotRing>,
        epoch_counter: Arc<EpochCounter>,
    ) -> (Arc<WorkerEpoch>, thread::JoinHandle<()>) {
        let worker_epoch = Arc::new(WorkerEpoch::new(0));
        let worker_side = Arc::clone(&worker_epoch);
        let handle = thread::spawn(move || {
            worker_loop(task_rx, ring, epoch_counter, worker_side);
        });
        (worker_epoch, handle)
    }

    #[test]
    fn worker_executes_simple_plan() {
        let cells = 10u32;
        let (plan, output_len, mask_len) = compile_full_line_plan(cells);

        // Set up ring + epoch.
        let ring = Arc::new(SnapshotRing::new(4));
        ring.push(make_test_snapshot(1, 42.0, cells));

        let epoch_counter = Arc::new(EpochCounter::new());
        epoch_counter.advance();

        // Create task and reply channels, then start the worker.
        let (task_tx, task_rx) = crossbeam_channel::bounded(4);
        let (reply_tx, reply_rx) = crossbeam_channel::bounded(1);
        let (worker_epoch, handle) = spawn_worker(task_rx, ring, epoch_counter);

        // Send task.
        task_tx
            .send(ObsTask::Simple {
                plan,
                output_len,
                mask_len,
                reply: reply_tx,
            })
            .unwrap();

        // Get result.
        let result = reply_rx.recv().unwrap();
        match result {
            ObsResult::Simple {
                metadata,
                output,
                mask,
            } => {
                assert_eq!(metadata.tick_id, TickId(1));
                assert_eq!(output.len(), output_len);
                assert!(output.iter().all(|&v| v == 42.0));
                assert_eq!(mask.len(), mask_len);
            }
            other => panic!("expected Simple result, got: {other:?}"),
        }

        // The worker unpins before replying, so it must be unpinned by now.
        assert!(!worker_epoch.is_pinned());

        // Drop sender to close channel and join worker.
        drop(task_tx);
        handle.join().unwrap();
    }

    #[test]
    fn worker_unpins_on_error() {
        // With an empty ring the worker replies with an error before it ever
        // pins, so the epoch must be observed unpinned afterwards.
        let ring = Arc::new(SnapshotRing::new(4));
        let epoch_counter = Arc::new(EpochCounter::new());

        let (task_tx, task_rx) = crossbeam_channel::bounded(4);
        let (reply_tx, reply_rx) = crossbeam_channel::bounded(1);
        let (worker_epoch, handle) = spawn_worker(task_rx, ring, epoch_counter);

        // The plan contents are irrelevant: the worker errors out before
        // executing it.
        let (plan, output_len, mask_len) = compile_full_line_plan(10);

        task_tx
            .send(ObsTask::Simple {
                plan,
                output_len,
                mask_len,
                reply: reply_tx,
            })
            .unwrap();

        let result = reply_rx.recv().unwrap();
        assert!(matches!(result, ObsResult::Error(_)));
        assert!(!worker_epoch.is_pinned());

        drop(task_tx);
        handle.join().unwrap();
    }

    #[test]
    fn worker_exits_on_channel_close() {
        let ring = Arc::new(SnapshotRing::new(4));
        let epoch_counter = Arc::new(EpochCounter::new());

        let (task_tx, task_rx) = crossbeam_channel::bounded::<ObsTask>(4);
        let (_worker_epoch, handle) = spawn_worker(task_rx, ring, epoch_counter);

        // Drop sender — worker should exit.
        drop(task_tx);
        handle.join().unwrap();
    }
}