Skip to main content

noether_engine/executor/
pure_cache.rs

//! In-memory cache for Pure stage outputs.
//!
//! A `Pure` stage is deterministic and side-effect-free: for the same input, it
//! always returns the same output. We can therefore skip execution entirely when
//! we've already seen a `(stage_id, input_hash)` pair during the current run.
//!
//! The cache lives for the duration of a single `run_composition` call and is
//! never persisted to disk.
9
10use noether_core::effects::Effect;
11use noether_core::stage::StageId;
12use noether_store::StageStore;
13use serde_json::Value;
14use sha2::{Digest, Sha256};
15use std::collections::{HashMap, HashSet};
16
17#[derive(Default)]
18pub struct PureStageCache {
19    /// Set of stage IDs whose `EffectSet` contains `Effect::Pure`.
20    pure_ids: HashSet<String>,
21    /// Cache from `(stage_id, input_hash)` to output value.
22    entries: HashMap<CacheKey, Value>,
23    pub hits: u32,
24    pub misses: u32,
25}
26
/// Lookup key for one memoized stage invocation.
#[derive(PartialEq, Eq, Hash)]
struct CacheKey {
    /// Raw ID string of the stage (the inner value of its `StageId`).
    stage_id: String,
    /// Hex digest of the stage's input, as produced by `hash_value`.
    input_hash: String,
}
32
33impl PureStageCache {
34    /// Build a cache pre-populated with the set of Pure stage IDs from the store.
35    pub fn from_store(store: &dyn StageStore) -> Self {
36        let pure_ids = store
37            .list(None)
38            .into_iter()
39            .filter(|s| s.signature.effects.contains(&Effect::Pure))
40            .map(|s| s.id.0.clone())
41            .collect();
42
43        Self {
44            pure_ids,
45            entries: HashMap::new(),
46            hits: 0,
47            misses: 0,
48        }
49    }
50
51    /// Returns `true` when the stage is declared Pure.
52    pub fn is_pure(&self, stage_id: &StageId) -> bool {
53        self.pure_ids.contains(&stage_id.0)
54    }
55
56    /// Look up a cached output. Returns `None` on a cache miss.
57    pub fn get(&mut self, stage_id: &StageId, input: &Value) -> Option<&Value> {
58        if !self.is_pure(stage_id) {
59            return None;
60        }
61        let key = CacheKey {
62            stage_id: stage_id.0.clone(),
63            input_hash: hash_value(input),
64        };
65        if self.entries.contains_key(&key) {
66            self.hits += 1;
67            self.entries.get(&key)
68        } else {
69            self.misses += 1;
70            None
71        }
72    }
73
74    /// Store an output in the cache. No-op for non-Pure stages.
75    pub fn put(&mut self, stage_id: &StageId, input: &Value, output: Value) {
76        if !self.is_pure(stage_id) {
77            return;
78        }
79        let key = CacheKey {
80            stage_id: stage_id.0.clone(),
81            input_hash: hash_value(input),
82        };
83        self.entries.insert(key, output);
84    }
85}
86
87fn hash_value(value: &Value) -> String {
88    let bytes = serde_json::to_vec(value).unwrap_or_default();
89    hex::encode(Sha256::digest(&bytes))
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use noether_core::stage::StageId;
    use serde_json::json;

    fn id(s: &str) -> StageId {
        StageId(s.into())
    }

    /// Build an empty cache that treats exactly the given stage IDs as Pure.
    fn cache_with_pure(ids: &[&str]) -> PureStageCache {
        let mut cache = PureStageCache::default();
        cache.pure_ids.extend(ids.iter().copied().map(String::from));
        cache
    }

    #[test]
    fn miss_on_non_pure_stage() {
        let mut cache = PureStageCache::default();
        // The stage is not in the Pure set, so the lookup returns None and is
        // counted neither as a hit nor as a miss.
        assert!(cache.get(&id("anything"), &json!("input")).is_none());
        assert_eq!((cache.hits, cache.misses), (0, 0));
    }

    #[test]
    fn put_is_noop_for_non_pure_stage() {
        let mut cache = PureStageCache::default();
        let stage = id("impure_stage");
        cache.put(&stage, &json!("input"), json!(1));

        // Marking the stage Pure afterwards must not reveal an entry, because
        // the earlier `put` stored nothing.
        cache.pure_ids.insert("impure_stage".into());
        assert!(cache.get(&stage, &json!("input")).is_none());
        assert_eq!((cache.hits, cache.misses), (0, 1));
    }

    #[test]
    fn hit_after_put() {
        let mut cache = cache_with_pure(&["pure_stage"]);

        let stage = id("pure_stage");
        let input = json!("hello");
        let output = json!(42);

        assert!(cache.get(&stage, &input).is_none());
        cache.put(&stage, &input, output.clone());
        let cached = cache.get(&stage, &input).unwrap();
        assert_eq!(*cached, output);
        assert_eq!(cache.hits, 1);
        assert_eq!(cache.misses, 1);
    }

    #[test]
    fn different_inputs_produce_different_keys() {
        let mut cache = cache_with_pure(&["pure_stage"]);

        let stage = id("pure_stage");
        cache.put(&stage, &json!("foo"), json!(1));
        cache.put(&stage, &json!("bar"), json!(2));

        assert_eq!(*cache.get(&stage, &json!("foo")).unwrap(), json!(1));
        assert_eq!(*cache.get(&stage, &json!("bar")).unwrap(), json!(2));
        assert_eq!((cache.hits, cache.misses), (2, 0));
    }

    #[test]
    fn same_input_on_different_stages_is_kept_separate() {
        let mut cache = cache_with_pure(&["a", "b"]);

        cache.put(&id("a"), &json!("x"), json!(1));

        assert!(cache.get(&id("b"), &json!("x")).is_none());
        assert_eq!(*cache.get(&id("a"), &json!("x")).unwrap(), json!(1));
    }

    #[test]
    fn put_overwrites_existing_entry() {
        let mut cache = cache_with_pure(&["pure_stage"]);
        let stage = id("pure_stage");

        cache.put(&stage, &json!("k"), json!(1));
        cache.put(&stage, &json!("k"), json!(2));

        assert_eq!(*cache.get(&stage, &json!("k")).unwrap(), json!(2));
    }
}