Skip to main content

forge_runtime/function/
cache.rs

1use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
2use std::hash::{Hash, Hasher};
3use std::sync::{Arc, RwLock};
4use std::time::{Duration, Instant};
5
6use forge_core::realtime::Change;
7use forge_core::{AuthContext, FunctionInfo};
8use serde_json::Value;
9use sha2::{Digest, Sha256};
10use tokio::sync::broadcast;
11
12use super::registry::FunctionRegistry;
13
14/// `Hasher` adapter that funnels writes into a SHA-256 digest. Used to keep
15/// cache keys deterministic across rolling deploys; the standard library
16/// `DefaultHasher` is explicitly not stable across Rust versions.
17struct Sha256Hasher(Sha256);
18
19impl Sha256Hasher {
20    fn new() -> Self {
21        Self(Sha256::new())
22    }
23
24    /// Consume the hasher and return a 64-bit cache key.
25    ///
26    /// Truncation to u64 is acceptable for cache key dedup — collision
27    /// probability is ~1/2^64 per birthday bound which is negligible for
28    /// cache sizes under billions of entries.
29    fn finish_u64(self) -> u64 {
30        let digest = self.0.finalize();
31        let mut buf = [0u8; 8];
32        if let Some(prefix) = digest.get(..8) {
33            buf.copy_from_slice(prefix);
34        }
35        u64::from_be_bytes(buf)
36    }
37}
38
39impl Hasher for Sha256Hasher {
40    fn write(&mut self, bytes: &[u8]) {
41        self.0.update(bytes);
42    }
43
44    fn finish(&self) -> u64 {
45        // Required by the trait for `Hash::hash` to work. Clones internal
46        // state because `Hasher::finish` is non-consuming. All real callers
47        // use `finish_u64(self)` instead.
48        let digest = self.0.clone().finalize();
49        let mut buf = [0u8; 8];
50        if let Some(prefix) = digest.get(..8) {
51            buf.copy_from_slice(prefix);
52        }
53        u64::from_be_bytes(buf)
54    }
55}
56
57/// A simple in-memory cache for query results.
58pub struct QueryCache {
59    entries: RwLock<CacheState>,
60    max_entries: usize,
61}
62
63struct CacheState {
64    map: HashMap<CacheKey, CacheEntry>,
65    /// Insertion-order queue for O(1) eviction of oldest entries.
66    insertion_order: VecDeque<CacheKey>,
67    /// Reverse index: function name -> set of cache keys for that function.
68    /// Avoids O(N) scan in `invalidate_by_tables` / `invalidate_function`.
69    name_to_keys: HashMap<Arc<str>, HashSet<CacheKey>>,
70}
71
/// Composite cache key built from three 64-bit hashes, so lookups never need
/// to allocate or compare the function-name string.
///
/// Each component is a truncated SHA-256 hash of, respectively, the function
/// name, the canonicalized call arguments, and the auth scope (`""` when no
/// scope is given).
#[derive(Clone, Copy, Eq, PartialEq, Hash)]
struct CacheKey {
    function_name_hash: u64,
    args_hash: u64,
    auth_scope_hash: u64,
}
80
81struct CacheEntry {
82    value: Arc<Value>,
83    expires_at: Instant,
84    /// The original function name, retained for invalidation-by-name lookups.
85    function_name: Arc<str>,
86}
87
88impl QueryCache {
89    /// Create a new query cache with default settings.
90    pub fn new() -> Self {
91        Self::with_max_entries(10_000)
92    }
93
94    /// Create a new query cache with a maximum number of entries.
95    pub fn with_max_entries(max_entries: usize) -> Self {
96        Self {
97            entries: RwLock::new(CacheState {
98                map: HashMap::new(),
99                insertion_order: VecDeque::new(),
100                name_to_keys: HashMap::new(),
101            }),
102            max_entries,
103        }
104    }
105
106    /// Get a cached value if it exists and hasn't expired.
107    pub fn get(
108        &self,
109        function_name: &str,
110        args: &Value,
111        auth_scope: Option<&str>,
112    ) -> Option<Arc<Value>> {
113        let key = Self::make_key(function_name, args, auth_scope);
114
115        let state = self.entries.read().ok()?;
116        let entry = state.map.get(&key)?;
117
118        if Instant::now() < entry.expires_at {
119            Some(Arc::clone(&entry.value))
120        } else {
121            None
122        }
123    }
124
125    /// Set a cached value with a TTL.
126    pub fn set(
127        &self,
128        function_name: &str,
129        args: &Value,
130        auth_scope: Option<&str>,
131        value: Value,
132        ttl: Duration,
133    ) {
134        self.set_arc(function_name, args, auth_scope, Arc::new(value), ttl);
135    }
136
137    /// Store a pre-wrapped `Arc<Value>` in the cache.
138    pub fn set_arc(
139        &self,
140        function_name: &str,
141        args: &Value,
142        auth_scope: Option<&str>,
143        value: Arc<Value>,
144        ttl: Duration,
145    ) {
146        let key = Self::make_key(function_name, args, auth_scope);
147        let now = Instant::now();
148
149        let entry = CacheEntry {
150            value,
151            expires_at: now + ttl,
152            function_name: Arc::from(function_name),
153        };
154
155        if let Ok(mut state) = self.entries.write() {
156            if state.map.len() >= self.max_entries {
157                Self::evict_expired(&mut state);
158            }
159
160            if state.map.len() >= self.max_entries {
161                Self::evict_oldest(&mut state, (self.max_entries / 10).max(1));
162            }
163
164            if !state.map.contains_key(&key) {
165                state.insertion_order.push_back(key);
166            }
167            state
168                .name_to_keys
169                .entry(Arc::clone(&entry.function_name))
170                .or_default()
171                .insert(key);
172            state.map.insert(key, entry);
173        }
174    }
175
176    /// Invalidate a specific cache entry.
177    pub fn invalidate(&self, function_name: &str, args: &Value) {
178        let name_hash = hash_str(function_name);
179        let args_hash = hash_value(args);
180        if let Ok(mut state) = self.entries.write() {
181            let matching: Vec<CacheKey> = state
182                .map
183                .keys()
184                .filter(|k| k.function_name_hash == name_hash && k.args_hash == args_hash)
185                .copied()
186                .collect();
187            for key in &matching {
188                if let Some(entry) = state.map.remove(key)
189                    && let Some(keys) = state.name_to_keys.get_mut(&entry.function_name)
190                {
191                    keys.remove(key);
192                    if keys.is_empty() {
193                        state.name_to_keys.remove(&entry.function_name);
194                    }
195                }
196            }
197            if !matching.is_empty() {
198                let removed: HashSet<&CacheKey> = matching.iter().collect();
199                state.insertion_order.retain(|k| !removed.contains(k));
200            }
201        }
202    }
203
204    /// Invalidate all entries for a function.
205    pub fn invalidate_function(&self, function_name: &str) {
206        if let Ok(mut state) = self.entries.write() {
207            let name_arc: Arc<str> = Arc::from(function_name);
208            if let Some(keys) = state.name_to_keys.remove(&name_arc) {
209                for key in &keys {
210                    state.map.remove(key);
211                }
212                state.insertion_order.retain(|k| !keys.contains(k));
213            }
214        }
215    }
216
217    /// Invalidate all cached queries that depend on any of the given tables.
218    pub fn invalidate_by_tables(&self, query_names: &[&str]) {
219        if query_names.is_empty() {
220            return;
221        }
222        if let Ok(mut state) = self.entries.write() {
223            let mut all_removed_keys: HashSet<CacheKey> = HashSet::new();
224            for name in query_names {
225                let name_arc: Arc<str> = Arc::from(*name);
226                if let Some(keys) = state.name_to_keys.remove(&name_arc) {
227                    for key in &keys {
228                        state.map.remove(key);
229                    }
230                    all_removed_keys.extend(keys);
231                }
232            }
233            if !all_removed_keys.is_empty() {
234                state
235                    .insertion_order
236                    .retain(|k| !all_removed_keys.contains(k));
237            }
238        }
239    }
240
241    /// Clear the entire cache.
242    pub fn clear(&self) {
243        if let Ok(mut state) = self.entries.write() {
244            state.map.clear();
245            state.insertion_order.clear();
246            state.name_to_keys.clear();
247        }
248    }
249
250    /// Get the number of cached entries.
251    pub fn len(&self) -> usize {
252        self.entries.read().map(|e| e.map.len()).unwrap_or(0)
253    }
254
255    /// Check if the cache is empty.
256    pub fn is_empty(&self) -> bool {
257        self.len() == 0
258    }
259
260    fn make_key(function_name: &str, args: &Value, auth_scope: Option<&str>) -> CacheKey {
261        CacheKey {
262            function_name_hash: hash_str(function_name),
263            args_hash: hash_value(args),
264            auth_scope_hash: hash_str(auth_scope.unwrap_or("")),
265        }
266    }
267
268    fn evict_expired(state: &mut CacheState) {
269        let now = Instant::now();
270        let expired_keys: Vec<CacheKey> = state
271            .map
272            .iter()
273            .filter(|(_, v)| v.expires_at <= now)
274            .map(|(k, _)| *k)
275            .collect();
276        for key in &expired_keys {
277            if let Some(entry) = state.map.remove(key)
278                && let Some(keys) = state.name_to_keys.get_mut(&entry.function_name)
279            {
280                keys.remove(key);
281                if keys.is_empty() {
282                    state.name_to_keys.remove(&entry.function_name);
283                }
284            }
285        }
286        if !expired_keys.is_empty() {
287            let removed: HashSet<&CacheKey> = expired_keys.iter().collect();
288            state.insertion_order.retain(|k| !removed.contains(k));
289        }
290    }
291
292    /// Evict the oldest entries by popping from the front of the insertion queue.
293    fn evict_oldest(state: &mut CacheState, count: usize) {
294        let mut evicted = 0;
295        while evicted < count {
296            let Some(key) = state.insertion_order.pop_front() else {
297                break;
298            };
299            if let Some(entry) = state.map.remove(&key) {
300                if let Some(keys) = state.name_to_keys.get_mut(&entry.function_name) {
301                    keys.remove(&key);
302                    if keys.is_empty() {
303                        state.name_to_keys.remove(&entry.function_name);
304                    }
305                }
306                evicted += 1;
307            }
308        }
309    }
310}
311
312impl Default for QueryCache {
313    fn default() -> Self {
314        Self::new()
315    }
316}
317
318/// Coordinates query-result caching: owns the [`QueryCache`], the reverse
319/// table -> query index used for mutation invalidation, and the auth-scope
320/// derivation that keys cache entries per principal.
321///
322/// Splitting these out of [`super::router::FunctionRouter`] keeps the router
323/// focused on dispatch and lets the cache concerns be tested in isolation.
324pub struct QueryCacheCoordinator {
325    cache: QueryCache,
326    /// Reverse index: table name -> dependent queries with the columns each one
327    /// reads. Carrying the column set here lets `invalidate_for_mutation` skip
328    /// queries whose read set doesn't overlap the mutation's write set without
329    /// re-traversing the registry.
330    table_to_queries: HashMap<String, Vec<QueryDep>>,
331}
332
/// One query's dependency on a particular table.
///
/// `selected_columns` holds the columns the query reads from that table. An
/// empty set means the columns could not be determined. It is treated as
/// "any column", so the query is invalidated whenever the table is written.
#[derive(Clone)]
struct QueryDep {
    /// Registered name of the dependent query.
    name: String,
    /// Columns read from the table; empty means unknown (wildcard).
    selected_columns: HashSet<String>,
}
341
342impl QueryCacheCoordinator {
343    /// Create a coordinator from the function registry. The reverse index is
344    /// built once at construction so mutation invalidation is a hash lookup.
345    pub fn new(registry: &FunctionRegistry) -> Self {
346        Self {
347            cache: QueryCache::new(),
348            table_to_queries: build_table_index(registry),
349        }
350    }
351
352    /// Lookup a cached value with an already-derived scope. Use this when the
353    /// caller has computed [`Self::auth_scope`] up front (e.g. before `auth` is
354    /// moved into a context).
355    pub fn get_by_scope(
356        &self,
357        function_name: &str,
358        args: &Value,
359        scope: Option<&str>,
360    ) -> Option<Arc<Value>> {
361        self.cache.get(function_name, args, scope)
362    }
363
364    /// Store a value with an already-derived scope.
365    pub fn set_by_scope(
366        &self,
367        function_name: &str,
368        args: &Value,
369        scope: Option<&str>,
370        value: Value,
371        ttl: Duration,
372    ) {
373        self.cache.set(function_name, args, scope, value, ttl);
374    }
375
376    /// Store a pre-wrapped `Arc<Value>` with an already-derived scope.
377    pub fn set_arc_by_scope(
378        &self,
379        function_name: &str,
380        args: &Value,
381        scope: Option<&str>,
382        value: Arc<Value>,
383        ttl: Duration,
384    ) {
385        self.cache.set_arc(function_name, args, scope, value, ttl);
386    }
387
388    /// Invalidate cached queries whose table dependencies overlap with the
389    /// mutation's write set, narrowed by column intersection when both sides
390    /// know which columns they touch.
391    ///
392    /// Conservative on uncertainty: when either the mutation's `changed_columns`
393    /// or the candidate query's `selected_columns` is empty (extractor couldn't
394    /// determine columns from the SQL), the query is invalidated. False positives
395    /// are cheap (re-execute a query); false negatives would serve stale data.
396    pub fn invalidate_for_mutation(&self, info: &FunctionInfo) {
397        if info.table_dependencies.is_empty() {
398            return;
399        }
400        let mutation_cols: HashSet<&str> = info.changed_columns.iter().copied().collect();
401        let mutation_cols_unknown = mutation_cols.is_empty();
402
403        let mut affected: HashSet<&str> = HashSet::new();
404        for table in info.table_dependencies {
405            let Some(queries) = self.table_to_queries.get(*table) else {
406                continue;
407            };
408            for dep in queries {
409                if mutation_cols_unknown
410                    || dep.selected_columns.is_empty()
411                    || dep
412                        .selected_columns
413                        .iter()
414                        .any(|c| mutation_cols.contains(c.as_str()))
415                {
416                    affected.insert(dep.name.as_str());
417                }
418            }
419        }
420        if !affected.is_empty() {
421            let names: Vec<&str> = affected.into_iter().collect();
422            self.cache.invalidate_by_tables(&names);
423            tracing::trace!(
424                mutation = info.name,
425                invalidated_queries = ?names,
426                "Cache invalidated after mutation"
427            );
428        }
429    }
430
431    /// Invalidate cached queries affected by a `forge_changes` NOTIFY event.
432    ///
433    /// Used to propagate mutations across cluster nodes: every node listening
434    /// to `forge_changes` runs this for each peer-emitted change so its local
435    /// cache stays consistent. Without this, mutations on node A leave stale
436    /// cache entries on node B until TTL expires.
437    ///
438    /// Mirrors `invalidate_for_mutation`'s column-intersection logic but uses
439    /// the change's runtime `changed_columns` rather than the mutation's
440    /// compile-time set. An empty change column list means "could be any
441    /// column" (older trigger payloads, INSERT/DELETE) — fall back to full
442    /// invalidation for that table.
443    pub fn invalidate_by_change(&self, change: &Change) {
444        let Some(queries) = self.table_to_queries.get(&change.table) else {
445            return;
446        };
447        let change_cols_unknown = change.changed_columns.is_empty();
448        let change_cols: HashSet<&str> =
449            change.changed_columns.iter().map(String::as_str).collect();
450
451        let mut affected: Vec<&str> = Vec::new();
452        for dep in queries {
453            if change_cols_unknown
454                || dep.selected_columns.is_empty()
455                || dep
456                    .selected_columns
457                    .iter()
458                    .any(|c| change_cols.contains(c.as_str()))
459            {
460                affected.push(dep.name.as_str());
461            }
462        }
463        if !affected.is_empty() {
464            self.cache.invalidate_by_tables(&affected);
465            tracing::trace!(
466                table = %change.table,
467                invalidated_queries = ?affected,
468                "Cache invalidated by cluster change"
469            );
470        }
471    }
472
473    /// Spawn a background task that drains a `Change` broadcast and evicts
474    /// matching cache entries. Returns a handle the caller can abort on
475    /// shutdown. The receiver typically comes from
476    /// `Reactor::change_subscriber()`.
477    pub fn spawn_cluster_invalidator(
478        self: Arc<Self>,
479        mut rx: broadcast::Receiver<Change>,
480    ) -> tokio::task::JoinHandle<()> {
481        tokio::spawn(async move {
482            loop {
483                match rx.recv().await {
484                    Ok(change) => self.invalidate_by_change(&change),
485                    Err(broadcast::error::RecvError::Lagged(n)) => {
486                        // We dropped n change events; the local cache could
487                        // be holding values that should have been evicted.
488                        // Clear everything to recover correctness.
489                        tracing::warn!(
490                            dropped = n,
491                            "Cache invalidator lagged; clearing local cache"
492                        );
493                        self.cache.clear();
494                    }
495                    Err(broadcast::error::RecvError::Closed) => {
496                        tracing::debug!("Change channel closed; cache invalidator stopping");
497                        break;
498                    }
499                }
500            }
501        })
502    }
503
504    /// Derive a stable cache scope from auth context. Anonymous callers share
505    /// the `"anon"` scope; authenticated callers get a hash that mixes role +
506    /// claims so cross-tenant cache bleed is impossible.
507    ///
508    /// Volatile JWT claims (`iat`, `nbf`, `exp`, `jti`, `auth_time`, `sid`,
509    /// `nonce`) are excluded — they change on every token refresh for the
510    /// same logical principal and would otherwise fragment the cache, killing
511    /// hit rate for any system using refresh tokens.
512    pub fn auth_scope(auth: &AuthContext) -> Option<String> {
513        if !auth.is_authenticated() {
514            return Some("anon".to_string());
515        }
516
517        let mut roles = auth.roles().to_vec();
518        roles.sort();
519        roles.dedup();
520
521        let mut claims = BTreeMap::new();
522        for (k, v) in auth.claims() {
523            if is_volatile_claim(k) {
524                continue;
525            }
526            claims.insert(k.clone(), v.clone());
527        }
528
529        let claims_json = match serde_json::to_string(&claims) {
530            Ok(json) => json,
531            Err(_) => {
532                tracing::error!(
533                    "BTreeMap<String, Value> serialization failed — cache scope degraded"
534                );
535                String::new()
536            }
537        };
538        let mut buf = String::with_capacity(64 + claims_json.len());
539        for role in &roles {
540            buf.push_str(role);
541            buf.push('\x1f');
542        }
543        buf.push('\x1e');
544        buf.push_str(&claims_json);
545        let scope = crate::stable_hash::stable_u64(buf.as_bytes());
546
547        let principal = auth
548            .principal_id()
549            .unwrap_or_else(|| "authenticated".to_string());
550
551        Some(format!("subject:{principal}:scope:{scope:016x}"))
552    }
553}
554
/// Returns `true` for standard JWT claims whose values may differ between
/// refreshed tokens of the same principal: issue, not-before and expiry
/// times, token and session identifiers, and nonces. Such claims must not
/// contribute to the per-principal cache scope.
fn is_volatile_claim(name: &str) -> bool {
    const VOLATILE_CLAIMS: &[&str] = &["iat", "nbf", "exp", "jti", "auth_time", "sid", "nonce"];
    VOLATILE_CLAIMS.iter().any(|&claim| claim == name)
}
563
564fn build_table_index(registry: &FunctionRegistry) -> HashMap<String, Vec<QueryDep>> {
565    let mut index: HashMap<String, Vec<QueryDep>> = HashMap::new();
566    for (name, info) in registry.queries() {
567        let selected_columns: HashSet<String> = info
568            .selected_columns
569            .iter()
570            .map(|c| (*c).to_string())
571            .collect();
572        for table in info.table_dependencies {
573            index
574                .entry((*table).to_string())
575                .or_default()
576                .push(QueryDep {
577                    name: name.to_string(),
578                    selected_columns: selected_columns.clone(),
579                });
580        }
581    }
582    index
583}
584
585fn hash_value(value: &Value) -> u64 {
586    let mut hasher = Sha256Hasher::new();
587    hash_value_recursive(value, &mut hasher);
588    hasher.finish_u64()
589}
590
591fn hash_str(value: &str) -> u64 {
592    let mut hasher = Sha256Hasher::new();
593    value.hash(&mut hasher);
594    hasher.finish_u64()
595}
596
597fn hash_value_recursive<H: Hasher>(value: &Value, hasher: &mut H) {
598    match value {
599        Value::Null => 0u8.hash(hasher),
600        Value::Bool(b) => {
601            1u8.hash(hasher);
602            b.hash(hasher);
603        }
604        Value::Number(n) => {
605            2u8.hash(hasher);
606            n.to_string().hash(hasher);
607        }
608        Value::String(s) => {
609            3u8.hash(hasher);
610            s.hash(hasher);
611        }
612        Value::Array(arr) => {
613            4u8.hash(hasher);
614            arr.len().hash(hasher);
615            for v in arr {
616                hash_value_recursive(v, hasher);
617            }
618        }
619        Value::Object(obj) => {
620            5u8.hash(hasher);
621            obj.len().hash(hasher);
622            let mut keys: Vec<_> = obj.keys().collect();
623            keys.sort();
624            for key in keys {
625                key.hash(hasher);
626                if let Some(v) = obj.get(key.as_str()) {
627                    hash_value_recursive(v, hasher);
628                }
629            }
630        }
631    }
632}
633
634#[cfg(test)]
635mod tests {
636    use super::*;
637    use serde_json::json;
638
639    #[test]
640    fn test_cache_set_get() {
641        let cache = QueryCache::new();
642        let args = json!({"id": 123});
643        let value = json!({"name": "test"});
644
645        cache.set(
646            "get_user",
647            &args,
648            Some("user:1"),
649            value.clone(),
650            Duration::from_secs(60),
651        );
652
653        let result = cache.get("get_user", &args, Some("user:1"));
654        assert_eq!(result.as_deref(), Some(&value));
655    }
656
657    #[test]
658    fn test_cache_miss() {
659        let cache = QueryCache::new();
660        let args = json!({"id": 123});
661
662        let result = cache.get("get_user", &args, Some("user:1"));
663        assert_eq!(result, None);
664    }
665
666    #[test]
667    fn test_cache_invalidate() {
668        let cache = QueryCache::new();
669        let args = json!({"id": 123});
670        let value = json!({"name": "test"});
671
672        cache.set(
673            "get_user",
674            &args,
675            Some("user:1"),
676            value,
677            Duration::from_secs(60),
678        );
679        cache.invalidate("get_user", &args);
680
681        let result = cache.get("get_user", &args, Some("user:1"));
682        assert_eq!(result, None);
683    }
684
685    #[test]
686    fn test_cache_invalidate_function() {
687        let cache = QueryCache::new();
688        let args1 = json!({"id": 1});
689        let args2 = json!({"id": 2});
690
691        cache.set(
692            "get_user",
693            &args1,
694            Some("user:1"),
695            json!({"name": "a"}),
696            Duration::from_secs(60),
697        );
698        cache.set(
699            "get_user",
700            &args2,
701            Some("user:1"),
702            json!({"name": "b"}),
703            Duration::from_secs(60),
704        );
705        cache.set(
706            "list_users",
707            &json!({}),
708            Some("user:1"),
709            json!([]),
710            Duration::from_secs(60),
711        );
712
713        cache.invalidate_function("get_user");
714
715        assert_eq!(cache.get("get_user", &args1, Some("user:1")), None);
716        assert_eq!(cache.get("get_user", &args2, Some("user:1")), None);
717        assert!(
718            cache
719                .get("list_users", &json!({}), Some("user:1"))
720                .is_some()
721        );
722    }
723
724    #[test]
725    fn test_hash_consistency() {
726        let v1 = json!({"a": 1, "b": 2});
727        let v2 = json!({"b": 2, "a": 1});
728
729        // Object keys should be sorted for consistent hashing
730        assert_eq!(hash_value(&v1), hash_value(&v2));
731    }
732
733    #[test]
734    fn test_auth_scope_stable_across_token_refresh() {
735        let user_id = uuid::Uuid::new_v4();
736        let mut claims_t1 = std::collections::HashMap::new();
737        claims_t1.insert(
738            "sub".to_string(),
739            serde_json::Value::String(user_id.to_string()),
740        );
741        claims_t1.insert(
742            "tenant_id".to_string(),
743            serde_json::Value::String("acme".to_string()),
744        );
745        claims_t1.insert("iat".to_string(), serde_json::Value::from(1_700_000_000));
746        claims_t1.insert("exp".to_string(), serde_json::Value::from(1_700_003_600));
747        claims_t1.insert("nbf".to_string(), serde_json::Value::from(1_700_000_000));
748        claims_t1.insert(
749            "jti".to_string(),
750            serde_json::Value::String("token-uuid-1".to_string()),
751        );
752
753        let mut claims_t2 = claims_t1.clone();
754        claims_t2.insert("iat".to_string(), serde_json::Value::from(1_700_010_000));
755        claims_t2.insert("exp".to_string(), serde_json::Value::from(1_700_013_600));
756        claims_t2.insert("nbf".to_string(), serde_json::Value::from(1_700_010_000));
757        claims_t2.insert(
758            "jti".to_string(),
759            serde_json::Value::String("token-uuid-2".to_string()),
760        );
761
762        let auth_t1 = AuthContext::authenticated(user_id, vec!["member".to_string()], claims_t1);
763        let auth_t2 = AuthContext::authenticated(user_id, vec!["member".to_string()], claims_t2);
764
765        assert_eq!(
766            QueryCacheCoordinator::auth_scope(&auth_t1),
767            QueryCacheCoordinator::auth_scope(&auth_t2),
768            "Token refresh must not change cache scope for the same principal"
769        );
770    }
771
772    #[test]
773    fn test_auth_scope_differs_by_tenant() {
774        let user_id = uuid::Uuid::new_v4();
775        let mut claims_a = std::collections::HashMap::new();
776        claims_a.insert(
777            "sub".to_string(),
778            serde_json::Value::String(user_id.to_string()),
779        );
780        claims_a.insert(
781            "tenant_id".to_string(),
782            serde_json::Value::String("tenant-a".to_string()),
783        );
784
785        let mut claims_b = std::collections::HashMap::new();
786        claims_b.insert(
787            "sub".to_string(),
788            serde_json::Value::String(user_id.to_string()),
789        );
790        claims_b.insert(
791            "tenant_id".to_string(),
792            serde_json::Value::String("tenant-b".to_string()),
793        );
794
795        let auth_a = AuthContext::authenticated(user_id, vec!["member".to_string()], claims_a);
796        let auth_b = AuthContext::authenticated(user_id, vec!["member".to_string()], claims_b);
797
798        assert_ne!(
799            QueryCacheCoordinator::auth_scope(&auth_a),
800            QueryCacheCoordinator::auth_scope(&auth_b),
801            "Different tenant claims must produce distinct scopes"
802        );
803    }
804
805    /// Build a coordinator with a pre-populated table -> dependent-queries
806    /// index. Lets the invalidation tests assert behaviour without standing
807    /// up a full FunctionRegistry of fake handlers.
808    fn coordinator_with_deps(deps: Vec<(&str, &str, &[&str])>) -> QueryCacheCoordinator {
809        let mut index: HashMap<String, Vec<QueryDep>> = HashMap::new();
810        for (table, name, cols) in deps {
811            index.entry(table.to_string()).or_default().push(QueryDep {
812                name: name.to_string(),
813                selected_columns: cols.iter().map(|c| (*c).to_string()).collect(),
814            });
815        }
816        QueryCacheCoordinator {
817            cache: QueryCache::new(),
818            table_to_queries: index,
819        }
820    }
821
822    fn mutation_info(
823        name: &'static str,
824        tables: &'static [&'static str],
825        changed: &'static [&'static str],
826    ) -> FunctionInfo {
827        FunctionInfo {
828            name,
829            description: None,
830            kind: forge_core::FunctionKind::Mutation,
831            required_role: None,
832            is_public: false,
833            cache_ttl: None,
834            timeout: None,
835            http_timeout: None,
836            rate_limit_requests: None,
837            rate_limit_per_secs: None,
838            rate_limit_key: None,
839            log_level: None,
840            table_dependencies: tables,
841            selected_columns: &[],
842            changed_columns: changed,
843            transactional: true,
844            consistent: false,
845            max_upload_size_bytes: None,
846            requires_tenant_scope: false,
847        }
848    }
849
850    #[test]
851    fn invalidate_skips_query_when_columns_disjoint() {
852        let coord = coordinator_with_deps(vec![
853            ("users", "list_user_emails", &["id", "email"]),
854            ("users", "list_user_names", &["id", "name"]),
855        ]);
856        coord.set_by_scope(
857            "list_user_emails",
858            &json!({}),
859            Some("anon"),
860            json!([]),
861            Duration::from_secs(60),
862        );
863        coord.set_by_scope(
864            "list_user_names",
865            &json!({}),
866            Some("anon"),
867            json!([]),
868            Duration::from_secs(60),
869        );
870
871        // Mutation changes only `name`. Email-only query should survive.
872        coord.invalidate_for_mutation(&mutation_info("rename_user", &["users"], &["name"]));
873
874        assert!(
875            coord
876                .get_by_scope("list_user_emails", &json!({}), Some("anon"))
877                .is_some(),
878            "email query must survive a name-only mutation"
879        );
880        assert!(
881            coord
882                .get_by_scope("list_user_names", &json!({}), Some("anon"))
883                .is_none(),
884            "name query must be invalidated"
885        );
886    }
887
888    #[test]
889    fn invalidate_falls_back_when_mutation_columns_unknown() {
890        let coord = coordinator_with_deps(vec![("users", "list_user_emails", &["id", "email"])]);
891        coord.set_by_scope(
892            "list_user_emails",
893            &json!({}),
894            Some("anon"),
895            json!([]),
896            Duration::from_secs(60),
897        );
898
899        // Empty changed_columns => we don't know what the mutation touched,
900        // so every dependent query has to go.
901        coord.invalidate_for_mutation(&mutation_info("opaque_mutation", &["users"], &[]));
902
903        assert!(
904            coord
905                .get_by_scope("list_user_emails", &json!({}), Some("anon"))
906                .is_none(),
907            "unknown column set must fall back to full invalidation"
908        );
909    }
910
911    #[test]
912    fn invalidate_falls_back_when_query_columns_unknown() {
913        // Query with no extracted columns means dynamic SQL; we can't reason
914        // about what it reads, so any mutation on the table invalidates it.
915        let coord = coordinator_with_deps(vec![("users", "dynamic_query", &[])]);
916        coord.set_by_scope(
917            "dynamic_query",
918            &json!({}),
919            Some("anon"),
920            json!([]),
921            Duration::from_secs(60),
922        );
923
924        coord.invalidate_for_mutation(&mutation_info("rename_user", &["users"], &["name"]));
925
926        assert!(
927            coord
928                .get_by_scope("dynamic_query", &json!({}), Some("anon"))
929                .is_none(),
930            "queries with unknown selected columns must always be invalidated"
931        );
932    }
933
934    #[test]
935    fn invalidate_by_change_evicts_matching_query() {
936        use forge_core::realtime::{Change, ChangeOperation};
937
938        let coord = coordinator_with_deps(vec![("users", "list_user_names", &["id", "name"])]);
939        coord.set_by_scope(
940            "list_user_names",
941            &json!({}),
942            Some("anon"),
943            json!([]),
944            Duration::from_secs(60),
945        );
946
947        let change =
948            Change::new("users", ChangeOperation::Update).with_columns(vec!["name".to_string()]);
949        coord.invalidate_by_change(&change);
950
951        assert!(
952            coord
953                .get_by_scope("list_user_names", &json!({}), Some("anon"))
954                .is_none(),
955            "name change must invalidate name-reading query"
956        );
957    }
958
959    #[test]
960    fn invalidate_by_change_skips_disjoint_columns() {
961        use forge_core::realtime::{Change, ChangeOperation};
962
963        let coord = coordinator_with_deps(vec![("users", "list_user_emails", &["id", "email"])]);
964        coord.set_by_scope(
965            "list_user_emails",
966            &json!({}),
967            Some("anon"),
968            json!([]),
969            Duration::from_secs(60),
970        );
971
972        // Change touched only `name`; the email-only query should survive.
973        let change =
974            Change::new("users", ChangeOperation::Update).with_columns(vec!["name".to_string()]);
975        coord.invalidate_by_change(&change);
976
977        assert!(
978            coord
979                .get_by_scope("list_user_emails", &json!({}), Some("anon"))
980                .is_some(),
981            "disjoint column change must not invalidate"
982        );
983    }
984
985    #[test]
986    fn invalidate_by_change_falls_back_when_change_columns_unknown() {
987        use forge_core::realtime::{Change, ChangeOperation};
988
989        let coord = coordinator_with_deps(vec![("users", "list_user_emails", &["id", "email"])]);
990        coord.set_by_scope(
991            "list_user_emails",
992            &json!({}),
993            Some("anon"),
994            json!([]),
995            Duration::from_secs(60),
996        );
997
998        // Empty changed_columns (e.g. INSERT/DELETE) means "could be anything".
999        let change = Change::new("users", ChangeOperation::Insert);
1000        coord.invalidate_by_change(&change);
1001
1002        assert!(
1003            coord
1004                .get_by_scope("list_user_emails", &json!({}), Some("anon"))
1005                .is_none(),
1006            "unknown change columns must fall back to full invalidation"
1007        );
1008    }
1009
1010    #[test]
1011    fn invalidate_by_change_ignores_unrelated_table() {
1012        use forge_core::realtime::{Change, ChangeOperation};
1013
1014        let coord = coordinator_with_deps(vec![("users", "list_users", &["id"])]);
1015        coord.set_by_scope(
1016            "list_users",
1017            &json!({}),
1018            Some("anon"),
1019            json!([]),
1020            Duration::from_secs(60),
1021        );
1022
1023        let change = Change::new("orders", ChangeOperation::Update);
1024        coord.invalidate_by_change(&change);
1025
1026        assert!(
1027            coord
1028                .get_by_scope("list_users", &json!({}), Some("anon"))
1029                .is_some(),
1030            "change to unrelated table must not invalidate"
1031        );
1032    }
1033
1034    #[tokio::test]
1035    async fn cluster_invalidator_evicts_on_broadcast() {
1036        use forge_core::realtime::{Change, ChangeOperation};
1037
1038        let coord = Arc::new(coordinator_with_deps(vec![(
1039            "users",
1040            "list_user_names",
1041            &["id", "name"],
1042        )]));
1043        coord.set_by_scope(
1044            "list_user_names",
1045            &json!({}),
1046            Some("anon"),
1047            json!([]),
1048            Duration::from_secs(60),
1049        );
1050
1051        let (tx, rx) = broadcast::channel::<Change>(8);
1052        let handle = Arc::clone(&coord).spawn_cluster_invalidator(rx);
1053
1054        tx.send(
1055            Change::new("users", ChangeOperation::Update).with_columns(vec!["name".to_string()]),
1056        )
1057        .expect("send must succeed with active receiver");
1058
1059        // Allow the spawned task to drain the broadcast.
1060        for _ in 0..50 {
1061            if coord
1062                .get_by_scope("list_user_names", &json!({}), Some("anon"))
1063                .is_none()
1064            {
1065                break;
1066            }
1067            tokio::time::sleep(Duration::from_millis(10)).await;
1068        }
1069
1070        assert!(
1071            coord
1072                .get_by_scope("list_user_names", &json!({}), Some("anon"))
1073                .is_none(),
1074            "broadcast change must reach the invalidator and evict the entry"
1075        );
1076
1077        drop(tx);
1078        handle.await.expect("invalidator task must exit cleanly");
1079    }
1080
1081    #[test]
1082    fn test_cache_isolation_by_auth_scope() {
1083        let cache = QueryCache::new();
1084        let args = json!({"id": 1});
1085
1086        cache.set(
1087            "get_profile",
1088            &args,
1089            Some("subject:user-a"),
1090            json!({"name": "Alice"}),
1091            Duration::from_secs(60),
1092        );
1093
1094        assert!(
1095            cache
1096                .get("get_profile", &args, Some("subject:user-b"))
1097                .is_none()
1098        );
1099        assert!(
1100            cache
1101                .get("get_profile", &args, Some("subject:user-a"))
1102                .is_some()
1103        );
1104    }
1105}