noether_engine/executor/
pure_cache.rs1use noether_core::effects::Effect;
11use noether_core::stage::StageId;
12use noether_store::StageStore;
13use serde_json::Value;
14use sha2::{Digest, Sha256};
15use std::collections::{HashMap, HashSet};
16
17#[derive(Default)]
18pub struct PureStageCache {
19 pure_ids: HashSet<String>,
21 entries: HashMap<CacheKey, Value>,
23 pub hits: u32,
24 pub misses: u32,
25}
26
/// Identity of one cached computation: which stage ran, on which input.
///
/// The input is stored as a digest string rather than the full JSON value.
#[derive(PartialEq, Eq, Hash)]
struct CacheKey {
    /// Inner string of the stage's `StageId`.
    stage_id: String,
    /// Digest of the stage's input value.
    input_hash: String,
}
32
33impl PureStageCache {
34 pub fn from_store(store: &dyn StageStore) -> Self {
36 let pure_ids = store
37 .list(None)
38 .into_iter()
39 .filter(|s| s.signature.effects.contains(&Effect::Pure))
40 .map(|s| s.id.0.clone())
41 .collect();
42
43 Self {
44 pure_ids,
45 entries: HashMap::new(),
46 hits: 0,
47 misses: 0,
48 }
49 }
50
51 pub fn is_pure(&self, stage_id: &StageId) -> bool {
53 self.pure_ids.contains(&stage_id.0)
54 }
55
56 pub fn get(&mut self, stage_id: &StageId, input: &Value) -> Option<&Value> {
58 if !self.is_pure(stage_id) {
59 return None;
60 }
61 let key = CacheKey {
62 stage_id: stage_id.0.clone(),
63 input_hash: hash_value(input),
64 };
65 if self.entries.contains_key(&key) {
66 self.hits += 1;
67 self.entries.get(&key)
68 } else {
69 self.misses += 1;
70 None
71 }
72 }
73
74 pub fn put(&mut self, stage_id: &StageId, input: &Value, output: Value) {
76 if !self.is_pure(stage_id) {
77 return;
78 }
79 let key = CacheKey {
80 stage_id: stage_id.0.clone(),
81 input_hash: hash_value(input),
82 };
83 self.entries.insert(key, output);
84 }
85}
86
87fn hash_value(value: &Value) -> String {
88 let bytes = serde_json::to_vec(value).unwrap_or_default();
89 hex::encode(Sha256::digest(&bytes))
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use noether_core::stage::StageId;
    use serde_json::json;

    fn id(s: &str) -> StageId {
        StageId(s.into())
    }

    /// A fresh cache in which only `pure_stage` is eligible for caching.
    fn cache_with_pure_stage() -> PureStageCache {
        let mut cache = PureStageCache::default();
        cache.pure_ids.insert("pure_stage".into());
        cache
    }

    #[test]
    fn miss_on_non_pure_stage() {
        let mut cache = PureStageCache::default();
        assert!(cache.get(&id("anything"), &json!("input")).is_none());
        // Impure lookups bypass the cache entirely and are not counted.
        assert_eq!(cache.hits, 0);
        assert_eq!(cache.misses, 0);
    }

    #[test]
    fn put_on_non_pure_stage_is_ignored() {
        let mut cache = cache_with_pure_stage();
        let impure = id("impure_stage");

        cache.put(&impure, &json!("x"), json!(1));

        assert!(cache.entries.is_empty());
        assert!(cache.get(&impure, &json!("x")).is_none());
    }

    #[test]
    fn hit_after_put() {
        let mut cache = cache_with_pure_stage();

        let stage = id("pure_stage");
        let input = json!("hello");
        let output = json!(42);

        assert!(cache.get(&stage, &input).is_none());
        cache.put(&stage, &input, output.clone());
        let cached = cache.get(&stage, &input).unwrap();
        assert_eq!(*cached, output);
        assert_eq!(cache.hits, 1);
        assert_eq!(cache.misses, 1);
    }

    #[test]
    fn different_inputs_produce_different_keys() {
        let mut cache = cache_with_pure_stage();

        let stage = id("pure_stage");
        cache.put(&stage, &json!("foo"), json!(1));
        cache.put(&stage, &json!("bar"), json!(2));

        assert_eq!(*cache.get(&stage, &json!("foo")).unwrap(), json!(1));
        assert_eq!(*cache.get(&stage, &json!("bar")).unwrap(), json!(2));
    }

    #[test]
    fn same_input_is_keyed_per_stage() {
        let mut cache = cache_with_pure_stage();
        cache.pure_ids.insert("other_pure_stage".into());
        let input = json!({"k": "v"});

        cache.put(&id("pure_stage"), &input, json!("a"));

        assert!(cache.get(&id("other_pure_stage"), &input).is_none());
        assert_eq!(*cache.get(&id("pure_stage"), &input).unwrap(), json!("a"));
    }

    #[test]
    fn put_overwrites_existing_entry() {
        let mut cache = cache_with_pure_stage();
        let stage = id("pure_stage");

        cache.put(&stage, &json!(0), json!("old"));
        cache.put(&stage, &json!(0), json!("new"));

        assert_eq!(*cache.get(&stage, &json!(0)).unwrap(), json!("new"));
        assert_eq!(cache.entries.len(), 1);
    }
}