Skip to main content

murk_obs/
cache.rs

//! Plan cache with space-topology-based invalidation.
//!
//! [`ObsPlanCache`] wraps an [`ObsSpec`] and lazily compiles an
//! [`ObsPlan`] on first use. Subsequent calls to [`ObsPlanCache::get_or_compile`]
//! return the cached plan as long as the same space instance (identified by
//! [`SpaceInstanceId`] and cell count) is provided; otherwise the plan is
//! recompiled automatically.
//!
//! The cache does **not** key on [`WorldGenerationId`](murk_core::WorldGenerationId)
//! because that counter increments on every tick, which would defeat
//! caching. Observation plans depend only on space topology (cell count,
//! canonical ordering), not on per-tick state.

14use murk_core::error::ObsError;
15use murk_core::{Coord, SnapshotAccess, SpaceInstanceId, TickId};
16use murk_space::Space;
17
18use crate::metadata::ObsMetadata;
19use crate::spec::ObsSpec;
20use crate::ObsPlan;
21
22/// Cached observation plan with space-topology-based invalidation.
23///
24/// Holds an [`ObsSpec`] and an optional compiled [`ObsPlan`]. On each
25/// call to [`execute`](Self::execute), checks whether the cached plan
26/// was compiled for the same space (by [`SpaceInstanceId`] and cell count).
27/// On mismatch, the plan is recompiled transparently.
28///
29/// # Example
30///
31/// ```ignore
32/// let mut cache = ObsPlanCache::new(spec);
33/// // First call compiles the plan:
34/// let meta = cache.execute(&space, &snapshot, None, &mut output, &mut mask)?;
35/// // Subsequent calls reuse it (same space):
36/// let meta = cache.execute(&space, &snapshot, None, &mut output, &mut mask)?;
37/// ```
38///
39/// # Invalidation
40///
41/// The plan is recompiled when:
42/// - No plan has been compiled yet.
43/// - A different space instance is passed (different [`SpaceInstanceId`]).
44/// - The same space object's `cell_count()` has changed (topology mutation).
45/// - [`invalidate`](Self::invalidate) is called explicitly.
46///
47/// The plan is **not** recompiled when:
48/// - The snapshot's `WorldGenerationId` changes (that is per-tick churn,
49///   not a topology change).
50#[derive(Debug)]
51pub struct ObsPlanCache {
52    spec: ObsSpec,
53    cached: Option<CachedPlan>,
54}
55
56/// Fingerprint of a `&dyn Space` for cache invalidation.
57///
58/// Uses the space's [`SpaceInstanceId`] (monotonic counter, no ABA risk)
59/// plus `cell_count` as a mutation guard.
60#[derive(Debug, Clone, Copy, PartialEq, Eq)]
61struct SpaceFingerprint {
62    instance_id: SpaceInstanceId,
63    cell_count: usize,
64}
65
66impl SpaceFingerprint {
67    fn of(space: &dyn Space) -> Self {
68        Self {
69            instance_id: space.instance_id(),
70            cell_count: space.cell_count(),
71        }
72    }
73}
74
75/// Internal: a compiled plan with its space fingerprint and layout info.
76#[derive(Debug)]
77struct CachedPlan {
78    plan: ObsPlan,
79    fingerprint: SpaceFingerprint,
80    output_len: usize,
81    mask_len: usize,
82    entry_shapes: Vec<Vec<usize>>,
83}
84
85impl ObsPlanCache {
86    /// Create a new cache for the given observation spec.
87    ///
88    /// The plan is not compiled until the first call to
89    /// [`execute`](Self::execute) or [`get_or_compile`](Self::get_or_compile).
90    pub fn new(spec: ObsSpec) -> Self {
91        Self { spec, cached: None }
92    }
93
94    /// Get the cached plan, recompiling if needed.
95    ///
96    /// Returns the cached plan if one exists and was compiled for the
97    /// same space (by [`SpaceInstanceId`] and cell count). Otherwise
98    /// recompiles from the stored [`ObsSpec`].
99    pub fn get_or_compile(&mut self, space: &dyn Space) -> Result<&ObsPlan, ObsError> {
100        let fingerprint = SpaceFingerprint::of(space);
101
102        let needs_recompile = match &self.cached {
103            None => true,
104            Some(cached) => cached.fingerprint != fingerprint,
105        };
106
107        if needs_recompile {
108            let result = ObsPlan::compile(&self.spec, space)?;
109            self.cached = Some(CachedPlan {
110                plan: result.plan,
111                fingerprint,
112                output_len: result.output_len,
113                mask_len: result.mask_len,
114                entry_shapes: result.entry_shapes,
115            });
116        }
117
118        Ok(&self.cached.as_ref().unwrap().plan)
119    }
120
121    /// Execute the observation plan against a snapshot, recompiling if
122    /// the space has changed.
123    ///
124    /// This is the primary convenience method. It calls
125    /// [`get_or_compile`](Self::get_or_compile) then
126    /// [`ObsPlan::execute`].
127    ///
128    /// `engine_tick` is the current engine tick for computing
129    /// [`ObsMetadata::age_ticks`]. Pass `None` in Lockstep mode
130    /// (age is always 0).
131    pub fn execute(
132        &mut self,
133        space: &dyn Space,
134        snapshot: &dyn SnapshotAccess,
135        engine_tick: Option<TickId>,
136        output: &mut [f32],
137        mask: &mut [u8],
138    ) -> Result<ObsMetadata, ObsError> {
139        let plan = self.get_or_compile(space)?;
140        plan.execute(snapshot, engine_tick, output, mask)
141    }
142
143    /// Execute the Standard plan for `N` agents, recompiling if the
144    /// space has changed.
145    ///
146    /// Convenience wrapper over [`get_or_compile`](Self::get_or_compile)
147    /// + [`ObsPlan::execute_agents`].
148    pub fn execute_agents(
149        &mut self,
150        space: &dyn Space,
151        snapshot: &dyn SnapshotAccess,
152        agent_centers: &[Coord],
153        engine_tick: Option<TickId>,
154        output: &mut [f32],
155        mask: &mut [u8],
156    ) -> Result<Vec<ObsMetadata>, ObsError> {
157        let plan = self.get_or_compile(space)?;
158        plan.execute_agents(snapshot, space, agent_centers, engine_tick, output, mask)
159    }
160
161    /// Output length of the currently cached plan, or `None` if no
162    /// plan has been compiled yet.
163    pub fn output_len(&self) -> Option<usize> {
164        self.cached.as_ref().map(|c| c.output_len)
165    }
166
167    /// Mask length of the currently cached plan.
168    pub fn mask_len(&self) -> Option<usize> {
169        self.cached.as_ref().map(|c| c.mask_len)
170    }
171
172    /// Entry shapes of the currently cached plan.
173    pub fn entry_shapes(&self) -> Option<&[Vec<usize>]> {
174        self.cached.as_ref().map(|c| c.entry_shapes.as_slice())
175    }
176
177    /// Whether a compiled plan is currently cached.
178    pub fn is_compiled(&self) -> bool {
179        self.cached.is_some()
180    }
181
182    /// Invalidate the cached plan, forcing recompilation on next use.
183    pub fn invalidate(&mut self) {
184        self.cached = None;
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::spec::{ObsDtype, ObsEntry, ObsRegion, ObsTransform};
    use murk_core::{FieldId, ParameterVersion, TickId, WorldGenerationId};
    use murk_space::{EdgeBehavior, RegionSpec, Square4};
    use murk_test_utils::MockSnapshot;

    fn space() -> Square4 {
        Square4::new(3, 3, EdgeBehavior::Absorb).unwrap()
    }

    fn spec() -> ObsSpec {
        ObsSpec {
            entries: vec![ObsEntry {
                field_id: FieldId(0),
                region: ObsRegion::Fixed(RegionSpec::All),
                pool: None,
                transform: ObsTransform::Identity,
                dtype: ObsDtype::F32,
            }],
        }
    }

    /// Mock snapshot whose field 0 holds nine `1.0` values.
    ///
    /// The parameter is named `generation` rather than `gen` because `gen`
    /// is a reserved keyword in the Rust 2024 edition.
    fn snap(generation: u64, tick: u64) -> MockSnapshot {
        let mut s = MockSnapshot::new(
            TickId(tick),
            WorldGenerationId(generation),
            ParameterVersion(0),
        );
        s.set_field(FieldId(0), vec![1.0; 9]);
        s
    }

    /// Fingerprint of the space the cache is currently keyed on, if any.
    fn cached_fingerprint(cache: &ObsPlanCache) -> Option<SpaceFingerprint> {
        cache.cached.as_ref().map(|c| c.fingerprint)
    }

    // ── Cache lifecycle tests ────────────────────────────────

    #[test]
    fn not_compiled_initially() {
        let cache = ObsPlanCache::new(spec());
        assert!(!cache.is_compiled());
        assert_eq!(cache.output_len(), None);
        assert_eq!(cache.mask_len(), None);
        assert!(cache.entry_shapes().is_none());
        assert_eq!(cached_fingerprint(&cache), None);
    }

    #[test]
    fn first_execute_compiles_plan() {
        let space = space();
        let snapshot = snap(1, 10);
        let mut cache = ObsPlanCache::new(spec());

        let mut output = vec![0.0f32; 9];
        let mut mask = vec![0u8; 9];
        cache
            .execute(&space, &snapshot, None, &mut output, &mut mask)
            .unwrap();

        assert!(cache.is_compiled());
        assert_eq!(cache.output_len(), Some(9));
        assert!(cache.mask_len().is_some());
        assert!(cache.entry_shapes().is_some());
        assert_eq!(
            cached_fingerprint(&cache),
            Some(SpaceFingerprint::of(&space))
        );
    }

    #[test]
    fn same_space_reuses_plan_across_generations() {
        let space = space();
        let expected = Some(SpaceFingerprint::of(&space));
        // Different WorldGenerationId values — cache should NOT recompile.
        let snap_gen1 = snap(1, 10);
        let snap_gen2 = snap(2, 20);
        let snap_gen3 = snap(3, 30);
        let mut cache = ObsPlanCache::new(spec());

        let mut output = vec![0.0f32; 9];
        let mut mask = vec![0u8; 9];
        cache
            .execute(&space, &snap_gen1, None, &mut output, &mut mask)
            .unwrap();
        assert!(cache.is_compiled());
        assert_eq!(cached_fingerprint(&cache), expected);

        // Same space, different generation — still keyed on the same space.
        cache
            .execute(&space, &snap_gen2, None, &mut output, &mut mask)
            .unwrap();
        assert!(cache.is_compiled());
        assert_eq!(cached_fingerprint(&cache), expected);

        // Third generation — still keyed on the same space.
        cache
            .execute(&space, &snap_gen3, None, &mut output, &mut mask)
            .unwrap();
        assert!(cache.is_compiled());
        assert_eq!(cached_fingerprint(&cache), expected);
    }

    #[test]
    fn different_space_triggers_recompile() {
        let space_a = Square4::new(3, 3, EdgeBehavior::Absorb).unwrap();
        let space_b = Square4::new(4, 4, EdgeBehavior::Absorb).unwrap();
        let mut cache = ObsPlanCache::new(spec());

        // Compile with 3x3 space (9 cells).
        cache.get_or_compile(&space_a).unwrap();
        assert!(cache.is_compiled());
        assert_eq!(cache.output_len(), Some(9));
        assert_eq!(
            cached_fingerprint(&cache),
            Some(SpaceFingerprint::of(&space_a))
        );

        // Different space object with different topology → recompile.
        cache.get_or_compile(&space_b).unwrap();
        assert!(cache.is_compiled());
        assert_eq!(cache.output_len(), Some(16));
        assert_eq!(
            cached_fingerprint(&cache),
            Some(SpaceFingerprint::of(&space_b))
        );
    }

    #[test]
    fn different_space_same_dimensions_triggers_recompile() {
        // Two distinct space objects with the same dimensions.
        // Different instance IDs → recompile, even though topology is identical.
        let space_a = Square4::new(3, 3, EdgeBehavior::Absorb).unwrap();
        let space_b = Square4::new(3, 3, EdgeBehavior::Absorb).unwrap();
        let mut cache = ObsPlanCache::new(spec());

        let fp_a = SpaceFingerprint::of(&space_a);
        let fp_b = SpaceFingerprint::of(&space_b);
        // Distinct objects have different instance IDs (monotonic counter).
        assert_ne!(fp_a.instance_id, fp_b.instance_id);

        cache.get_or_compile(&space_a).unwrap();
        assert!(cache.is_compiled());
        assert_eq!(cached_fingerprint(&cache), Some(fp_a));

        // Different instance ID → recompile (conservative but safe). The
        // cache only re-keys to `fp_b` on the recompile path.
        cache.get_or_compile(&space_b).unwrap();
        assert!(cache.is_compiled());
        assert_eq!(cached_fingerprint(&cache), Some(fp_b));
    }

    #[test]
    fn invalidate_forces_recompile() {
        let space = space();
        let snapshot = snap(1, 10);
        let mut cache = ObsPlanCache::new(spec());

        let mut output = vec![0.0f32; 9];
        let mut mask = vec![0u8; 9];
        cache
            .execute(&space, &snapshot, None, &mut output, &mut mask)
            .unwrap();
        assert!(cache.is_compiled());

        cache.invalidate();
        assert!(!cache.is_compiled());
        assert_eq!(cache.output_len(), None);
        assert_eq!(cached_fingerprint(&cache), None);

        // Re-executes fine.
        cache
            .execute(&space, &snapshot, None, &mut output, &mut mask)
            .unwrap();
        assert!(cache.is_compiled());
        assert_eq!(
            cached_fingerprint(&cache),
            Some(SpaceFingerprint::of(&space))
        );
    }

    // ── age_ticks tests ──────────────────────────────────────

    #[test]
    fn age_ticks_zero_when_engine_tick_none() {
        let space = space();
        let snapshot = snap(1, 42);
        let mut cache = ObsPlanCache::new(spec());

        let mut output = vec![0.0f32; 9];
        let mut mask = vec![0u8; 9];
        let meta = cache
            .execute(&space, &snapshot, None, &mut output, &mut mask)
            .unwrap();

        assert_eq!(meta.age_ticks, 0);
    }

    #[test]
    fn age_ticks_zero_for_lockstep_same_tick() {
        let space = space();
        let snapshot = snap(1, 10);
        let mut cache = ObsPlanCache::new(spec());

        let mut output = vec![0.0f32; 9];
        let mut mask = vec![0u8; 9];
        let meta = cache
            .execute(&space, &snapshot, Some(TickId(10)), &mut output, &mut mask)
            .unwrap();

        assert_eq!(meta.age_ticks, 0);
    }

    #[test]
    fn age_ticks_positive_for_stale_snapshot() {
        let space = space();
        // Snapshot at tick 10, engine at tick 15 → age = 5.
        let snapshot = snap(1, 10);
        let mut cache = ObsPlanCache::new(spec());

        let mut output = vec![0.0f32; 9];
        let mut mask = vec![0u8; 9];
        let meta = cache
            .execute(&space, &snapshot, Some(TickId(15)), &mut output, &mut mask)
            .unwrap();

        assert_eq!(meta.age_ticks, 5);
    }

    #[test]
    fn age_ticks_saturates_on_underflow() {
        let space = space();
        // Engine tick < snapshot tick (shouldn't happen, but saturating_sub handles it).
        let snapshot = snap(1, 100);
        let mut cache = ObsPlanCache::new(spec());

        let mut output = vec![0.0f32; 9];
        let mut mask = vec![0u8; 9];
        let meta = cache
            .execute(&space, &snapshot, Some(TickId(5)), &mut output, &mut mask)
            .unwrap();

        assert_eq!(meta.age_ticks, 0);
    }

    // ── get_or_compile tests ─────────────────────────────────

    #[test]
    fn get_or_compile_returns_unbound_plan() {
        let space = space();
        let mut cache = ObsPlanCache::new(spec());

        let plan = cache.get_or_compile(&space).unwrap();
        // Cache uses compile() not compile_bound(), so no generation binding.
        assert_eq!(plan.compiled_generation(), None);
    }

    #[test]
    fn get_or_compile_reuses_for_same_space() {
        let space = space();
        let expected = Some(SpaceFingerprint::of(&space));
        let mut cache = ObsPlanCache::new(spec());

        cache.get_or_compile(&space).unwrap();
        assert!(cache.is_compiled());
        assert_eq!(cached_fingerprint(&cache), expected);

        // Same space reference → reuse; still keyed on the same space.
        cache.get_or_compile(&space).unwrap();
        assert!(cache.is_compiled());
        assert_eq!(cached_fingerprint(&cache), expected);
    }

    // ── SpaceFingerprint tests ───────────────────────────────

    #[test]
    fn fingerprint_same_object_is_equal() {
        let space = space();
        let fp1 = SpaceFingerprint::of(&space);
        let fp2 = SpaceFingerprint::of(&space);
        assert_eq!(fp1, fp2);
    }

    #[test]
    fn fingerprint_different_objects_differ() {
        // Monotonic counter guarantees distinct IDs even for identical topology.
        let a = Square4::new(3, 3, EdgeBehavior::Absorb).unwrap();
        let b = Square4::new(3, 3, EdgeBehavior::Absorb).unwrap();
        let fp_a = SpaceFingerprint::of(&a);
        let fp_b = SpaceFingerprint::of(&b);
        assert_ne!(fp_a, fp_b);
    }

    #[test]
    fn fingerprint_different_sizes_differ() {
        let small = Square4::new(2, 2, EdgeBehavior::Absorb).unwrap();
        let big = Square4::new(5, 5, EdgeBehavior::Absorb).unwrap();
        let fp_s = SpaceFingerprint::of(&small);
        let fp_b = SpaceFingerprint::of(&big);
        assert_ne!(fp_s, fp_b);
    }

    #[test]
    fn fingerprint_clone_preserves_id() {
        // Cloning a space preserves instance_id (same topology, safe to reuse plan).
        let a = Square4::new(3, 3, EdgeBehavior::Absorb).unwrap();
        let b = a.clone();
        let fp_a = SpaceFingerprint::of(&a);
        let fp_b = SpaceFingerprint::of(&b);
        assert_eq!(fp_a, fp_b);
    }
}