1use std::sync::Arc;
13
14use crossbeam_channel::Receiver;
15
16use crate::epoch::{EpochCounter, WorkerEpoch};
17use crate::ring::SnapshotRing;
18
19use murk_core::Coord;
20use murk_obs::{ObsMetadata, ObsPlan};
21use murk_space::Space;
22
/// A unit of observation work handed to a worker loop over the task channel.
///
/// Every variant carries its own `reply` sender; the worker loop in this file
/// sends one [`ObsResult`] on it per task and ignores send failures.
pub(crate) enum ObsTask {
    /// Execute the plan once via `ObsPlan::execute`.
    Simple {
        /// Compiled observation plan to execute.
        plan: Arc<ObsPlan>,
        /// Number of `f32` elements allocated (zero-filled) for the output buffer.
        output_len: usize,
        /// Number of `u8` elements allocated (zero-filled) for the mask buffer.
        mask_len: usize,
        /// Channel on which the worker returns the result.
        reply: crossbeam_channel::Sender<ObsResult>,
    },
    /// Execute the plan for a batch of agents via `ObsPlan::execute_agents`.
    Agents {
        /// Compiled observation plan to execute.
        plan: Arc<ObsPlan>,
        /// Space passed through to `execute_agents`.
        space: Arc<dyn Space>,
        /// One center coordinate per agent; its length determines the batch size.
        agent_centers: Vec<Coord>,
        /// Output elements per agent; the buffer holds
        /// `output_len * agent_centers.len()` values.
        output_len: usize,
        /// Mask elements per agent; the buffer holds
        /// `mask_len * agent_centers.len()` bytes.
        mask_len: usize,
        /// Channel on which the worker returns the result.
        reply: crossbeam_channel::Sender<ObsResult>,
    },
}
42
/// Outcome of an [`ObsTask`], sent back to the requester on the task's reply
/// channel.
#[derive(Debug)]
pub(crate) enum ObsResult {
    /// Successful result of an [`ObsTask::Simple`] task.
    Simple {
        /// Metadata returned by `ObsPlan::execute`.
        metadata: ObsMetadata,
        /// Output buffer filled by the plan.
        output: Vec<f32>,
        /// Mask buffer filled by the plan.
        mask: Vec<u8>,
    },
    /// Successful result of an [`ObsTask::Agents`] task.
    Agents {
        /// Metadata returned by `ObsPlan::execute_agents`.
        metadata: Vec<ObsMetadata>,
        /// Output buffer covering all agents in the batch.
        output: Vec<f32>,
        /// Mask buffer covering all agents in the batch.
        mask: Vec<u8>,
    },
    /// The task was not completed: plan execution returned an error, the
    /// worker had a pending cancellation, or no snapshot was available.
    Error(murk_core::error::ObsError),
}
61
62pub(crate) fn worker_loop_indexed(
66 task_rx: Receiver<ObsTask>,
67 ring: Arc<SnapshotRing>,
68 epoch_counter: Arc<EpochCounter>,
69 worker_epochs: Arc<[WorkerEpoch]>,
70 worker_index: usize,
71) {
72 let worker_epoch = &worker_epochs[worker_index];
73 worker_loop_inner(task_rx, ring, epoch_counter, worker_epoch);
74}
75
/// Test-only variant of the worker loop that takes its epoch slot as a
/// standalone `Arc<WorkerEpoch>` instead of an index into a shared slice.
#[cfg(test)]
pub(crate) fn worker_loop(
    task_rx: Receiver<ObsTask>,
    ring: Arc<SnapshotRing>,
    epoch_counter: Arc<EpochCounter>,
    worker_epoch: Arc<WorkerEpoch>,
) {
    // Borrow the slot out of the `Arc`; the `Arc` itself lives until the loop returns.
    let epoch_slot: &WorkerEpoch = worker_epoch.as_ref();
    worker_loop_inner(task_rx, ring, epoch_counter, epoch_slot);
}
89
90fn worker_loop_inner(
91 task_rx: Receiver<ObsTask>,
92 ring: Arc<SnapshotRing>,
93 epoch_counter: Arc<EpochCounter>,
94 worker_epoch: &WorkerEpoch,
95) {
96 while let Ok(task) = task_rx.recv() {
97 if worker_epoch.is_cancelled() {
99 worker_epoch.clear_cancel();
100 send_error(
101 &task,
102 murk_core::error::ObsError::ExecutionFailed {
103 reason: "worker cancelled before execution".into(),
104 },
105 );
106 continue;
107 }
108
109 let snapshot = match ring.latest() {
111 Some(snap) => snap,
112 None => {
113 send_error(
114 &task,
115 murk_core::error::ObsError::ExecutionFailed {
116 reason: "no snapshot available".into(),
117 },
118 );
119 continue;
120 }
121 };
122
123 let epoch = epoch_counter.current();
125 worker_epoch.pin(epoch);
126
127 let result = execute_task(&task, &*snapshot, epoch_counter.current());
129 worker_epoch.unpin();
130
131 match task {
133 ObsTask::Simple { reply, .. } | ObsTask::Agents { reply, .. } => {
134 let _ = reply.send(result);
135 }
136 }
137 }
138 }
140
141fn execute_task(
143 task: &ObsTask,
144 snapshot: &dyn murk_core::traits::SnapshotAccess,
145 current_tick_val: u64,
146) -> ObsResult {
147 let engine_tick = Some(murk_core::id::TickId(current_tick_val));
148
149 match task {
150 ObsTask::Simple {
151 plan,
152 output_len,
153 mask_len,
154 ..
155 } => {
156 let mut output = vec![0.0f32; *output_len];
157 let mut mask = vec![0u8; *mask_len];
158 match plan.execute(snapshot, engine_tick, &mut output, &mut mask) {
159 Ok(metadata) => ObsResult::Simple {
160 metadata,
161 output,
162 mask,
163 },
164 Err(e) => ObsResult::Error(e),
165 }
166 }
167 ObsTask::Agents {
168 plan,
169 space,
170 agent_centers,
171 output_len,
172 mask_len,
173 ..
174 } => {
175 let n_agents = agent_centers.len();
176 let mut output = vec![0.0f32; output_len * n_agents];
177 let mut mask = vec![0u8; mask_len * n_agents];
178 match plan.execute_agents(
179 snapshot,
180 space.as_ref(),
181 agent_centers,
182 engine_tick,
183 &mut output,
184 &mut mask,
185 ) {
186 Ok(metadata) => ObsResult::Agents {
187 metadata,
188 output,
189 mask,
190 },
191 Err(e) => ObsResult::Error(e),
192 }
193 }
194 }
195}
196
197fn send_error(task: &ObsTask, err: murk_core::error::ObsError) {
199 match task {
200 ObsTask::Simple { reply, .. } => {
201 let _ = reply.send(ObsResult::Error(err));
202 }
203 ObsTask::Agents { reply, .. } => {
204 let _ = reply.send(ObsResult::Error(err));
205 }
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use murk_arena::config::ArenaConfig;
    use murk_arena::pingpong::PingPongArena;
    use murk_arena::static_arena::StaticArena;
    use murk_core::id::{FieldId, ParameterVersion, TickId};
    use murk_core::traits::FieldWriter as _;
    use murk_core::{BoundaryBehavior, FieldDef, FieldMutability, FieldType};
    use murk_obs::spec::ObsRegion;
    use murk_obs::{ObsEntry, ObsSpec};
    use murk_space::{EdgeBehavior, Line1D};
    use std::thread;

    /// Builds a snapshot published at `tick` whose single scalar field
    /// (`FieldId(0)`) is filled with `value`.
    fn make_test_snapshot(tick: u64, value: f32, cells: u32) -> murk_arena::OwnedSnapshot {
        let energy = FieldDef {
            name: "energy".into(),
            field_type: FieldType::Scalar,
            mutability: FieldMutability::PerTick,
            units: None,
            bounds: None,
            boundary_behavior: BoundaryBehavior::Clamp,
        };
        let mut arena = PingPongArena::new(
            ArenaConfig::new(cells),
            vec![(FieldId(0), energy)],
            StaticArena::new(&[]).into_shared(),
        )
        .unwrap();
        {
            // Scope the tick guard so it is released before publishing.
            let mut guard = arena.begin_tick().unwrap();
            guard.writer.write(FieldId(0)).unwrap().fill(value);
        }
        arena.publish(TickId(tick), ParameterVersion(0));
        arena.owned_snapshot()
    }

    /// Compiles a plan that observes `FieldId(0)` over the whole of a
    /// `cells`-long `Line1D`, returning `(plan, output_len, mask_len)`.
    fn compile_full_line_plan(cells: u32) -> (Arc<ObsPlan>, usize, usize) {
        let space = Line1D::new(cells, EdgeBehavior::Absorb).unwrap();
        let spec = ObsSpec {
            entries: vec![ObsEntry {
                field_id: FieldId(0),
                region: ObsRegion::Fixed(murk_space::RegionSpec::All),
                pool: None,
                transform: murk_obs::spec::ObsTransform::Identity,
                dtype: murk_obs::spec::ObsDtype::F32,
            }],
        };
        let compiled = ObsPlan::compile(&spec, &space).unwrap();
        (
            Arc::new(compiled.plan),
            compiled.output_len,
            compiled.mask_len,
        )
    }

    /// Starts `worker_loop` on a new thread using clones of the shared handles.
    fn spawn_worker(
        task_rx: crossbeam_channel::Receiver<ObsTask>,
        ring: &Arc<SnapshotRing>,
        epoch_counter: &Arc<EpochCounter>,
        worker_epoch: &Arc<WorkerEpoch>,
    ) -> thread::JoinHandle<()> {
        let ring = Arc::clone(ring);
        let epoch_counter = Arc::clone(epoch_counter);
        let worker_epoch = Arc::clone(worker_epoch);
        thread::spawn(move || {
            worker_loop(task_rx, ring, epoch_counter, worker_epoch);
        })
    }

    #[test]
    fn worker_executes_simple_plan() {
        let cells = 10u32;
        let (plan, output_len, mask_len) = compile_full_line_plan(cells);

        // One snapshot at tick 1 with every cell set to 42.0.
        let ring = Arc::new(SnapshotRing::new(4));
        ring.push(make_test_snapshot(1, 42.0, cells));

        let epoch_counter = Arc::new(EpochCounter::new());
        epoch_counter.advance();

        let worker_epoch = Arc::new(WorkerEpoch::new(0));

        let (task_tx, task_rx) = crossbeam_channel::bounded(4);
        let (reply_tx, reply_rx) = crossbeam_channel::bounded(1);

        let handle = spawn_worker(task_rx, &ring, &epoch_counter, &worker_epoch);

        let task = ObsTask::Simple {
            plan,
            output_len,
            mask_len,
            reply: reply_tx,
        };
        task_tx.send(task).unwrap();

        match reply_rx.recv().unwrap() {
            ObsResult::Simple {
                metadata,
                output,
                mask,
            } => {
                assert_eq!(metadata.tick_id, TickId(1));
                assert_eq!(output.len(), output_len);
                assert!(output.iter().all(|&v| v == 42.0));
                assert_eq!(mask.len(), mask_len);
            }
            other => panic!("expected Simple result, got error: {other:?}"),
        }

        // The reply is sent only after the worker has unpinned.
        assert!(!worker_epoch.is_pinned());

        // Closing the task channel lets the worker loop return.
        drop(task_tx);
        handle.join().unwrap();
    }

    #[test]
    fn worker_unpins_on_error() {
        // The ring is left empty, so the worker has no snapshot to observe.
        let ring = Arc::new(SnapshotRing::new(4));
        let epoch_counter = Arc::new(EpochCounter::new());
        let worker_epoch = Arc::new(WorkerEpoch::new(0));

        let (task_tx, task_rx) = crossbeam_channel::bounded(4);
        let (reply_tx, reply_rx) = crossbeam_channel::bounded(1);

        let handle = spawn_worker(task_rx, &ring, &epoch_counter, &worker_epoch);

        let (plan, output_len, mask_len) = compile_full_line_plan(10);

        let task = ObsTask::Simple {
            plan,
            output_len,
            mask_len,
            reply: reply_tx,
        };
        task_tx.send(task).unwrap();

        let result = reply_rx.recv().unwrap();
        assert!(matches!(result, ObsResult::Error(_)));
        assert!(!worker_epoch.is_pinned());

        drop(task_tx);
        handle.join().unwrap();
    }

    #[test]
    fn worker_exits_on_channel_close() {
        let ring = Arc::new(SnapshotRing::new(4));
        let epoch_counter = Arc::new(EpochCounter::new());
        let worker_epoch = Arc::new(WorkerEpoch::new(0));

        let (task_tx, task_rx) = crossbeam_channel::bounded::<ObsTask>(4);

        let handle = spawn_worker(task_rx, &ring, &epoch_counter, &worker_epoch);

        // With no senders left, `recv` fails and the loop must return.
        drop(task_tx);
        handle.join().unwrap();
    }
}