Skip to main content

reddb_server/runtime/ai/
answer_cache_key.rs

1//! `AnswerCacheKey` — pure key derivation and TTL policy for the ASK
2//! answer cache.
3//!
4//! Issue #403 (PRD #391): an opt-in answer cache lets ASK skip the LLM
5//! when the same question lands against the same data under the same
6//! determinism knobs. The cache is keyed by
7//! `hash(tenant, user_scope, question, provider, model, temperature,
8//! seed, sources_fingerprint)` and gated by per-query `CACHE TTL '5m'`
9//! / `NOCACHE` clauses on top of deployment defaults.
10//!
11//! Deep module: no I/O, no clock, no storage. The caller hands in the
12//! identity scope, the determinism-resolved request shape (`Applied`
13//! from #400 in real wiring, plain fields here so the module stays
14//! decoupled), and the source fingerprint that retrieval (#398) already
15//! computes. We return a stable lowercase-hex SHA-256 key and, given
16//! `Mode` + `Settings`, an effective TTL.
17//!
18//! ## Why the module owns these decisions
19//!
20//! The cache key is a security boundary: cross-tenant key collisions
21//! leak answers. Pinning the canonical form here — with tests around
22//! the per-tenant scope, around `Some(0)` vs `None` seed, around
23//! `temperature` float canonicalisation — keeps the key derivation in
24//! one place a reviewer can audit. The wiring slice that follows can
25//! treat the key as an opaque string.
26//!
27//! ## Key canonical form
28//!
29//! Fields are concatenated in fixed order with the ASCII Unit Separator
30//! (0x1f) as delimiter:
31//!
32//! ```text
33//! tenant | 0x1f | user | 0x1f | question | 0x1f | provider | 0x1f
34//!     | model | 0x1f | temperature | 0x1f | seed | 0x1f | fingerprint
35//! ```
36//!
37//! - `temperature` serializes as `"none"` when absent, otherwise as the
38//!   shortest round-tripping IEEE-754 representation produced by Rust's
39//!   `{}` formatter (`0`, `0.5`, etc.). `0` and `none` are distinct.
40//! - `seed` serializes as `"none"` when absent, otherwise as the decimal
41//!   `u64`. `0` and `none` are distinct (guards against the same kind
42//!   of `unwrap_or(0)` regression `DeterminismDecider` already pins).
43//! - `0x1f` cannot appear in any of the inputs (SQL parser rejects it
44//!   in strings; the fingerprint, provider, model, decimals, and hex
45//!   are all ASCII printable), so the concatenation is injective without
46//!   escaping. Same trick as [`super::determinism_decider::derive_seed`].
47
48use std::time::Duration;
49
50use sha2::{Digest, Sha256};
51
/// Who the cached answer belongs to.
///
/// `tenant` is always expected to be supplied. `user` is left empty
/// when the entry should be shared across the whole tenant. Callers
/// without any auth context (anonymous / embedded) pass `""` for both
/// fields.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Scope<'a> {
    /// Tenant identifier the answer is scoped to.
    pub tenant: &'a str,
    /// User identifier, or `""` for a tenant-wide entry.
    pub user: &'a str,
}
60
/// The request-shape half of the cache key: everything that decides
/// which answer a call would get back.
///
/// Each field is fed into the key as-is, so a change to `temperature`,
/// `seed`, `model`, or `sources_fingerprint` (or any other field)
/// produces a different key and therefore a cache miss.
#[derive(Clone, Copy, Debug)]
pub struct Inputs<'a> {
    /// The question text of the ASK call.
    pub question: &'a str,
    /// Provider name.
    pub provider: &'a str,
    /// Model name.
    pub model: &'a str,
    /// Temperature that was actually sent to the provider, i.e. the
    /// value `DeterminismDecider::decide` returned rather than the one
    /// the user requested. `None` means no temperature was sent.
    pub temperature: Option<f32>,
    /// Seed that was actually sent; the same "resolved, not requested"
    /// caveat as `temperature` applies.
    pub seed: Option<u64>,
    /// Opaque, stable fingerprint of the retrieved sources (URNs +
    /// content versions). Its format is owned by the retrieval layer
    /// (#398).
    pub sources_fingerprint: &'a str,
}
80
/// The caching clause attached to a single query (`CACHE TTL '...'` or
/// `NOCACHE`), as parsed from the SQL surface.
///
/// `Mode::default()` yields [`Mode::Default`], meaning the query
/// expressed no preference and deployment settings apply.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Mode {
    /// The query did not mention caching; behaviour is derived from
    /// [`Settings::enabled`] and [`Settings::default_ttl`].
    #[default]
    Default,
    /// `ASK '...' CACHE TTL '5m'` — read from and write to the cache
    /// using this TTL, whatever the global default says.
    Cache(Duration),
    /// `ASK '...' NOCACHE` — skip the cache completely for this call.
    NoCache,
}
96
/// Cache configuration for the whole deployment, exposed under the
/// `ask.cache.*` settings namespace.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct Settings {
    /// `ask.cache.enabled`; defaults to `false`.
    pub enabled: bool,
    /// `ask.cache.default_ttl`. With `None` there is no default TTL,
    /// so a query has to say `CACHE TTL '...'` itself to populate the
    /// cache.
    pub default_ttl: Option<Duration>,
    /// `ask.cache.max_entries`. This module never reads it — eviction
    /// is the cache store's job. Carried here for completeness.
    pub max_entries: usize,
}
109
/// Instruction to the cache wrapper for one ASK call.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Decision {
    /// Leave the cache alone: neither read from it nor write to it.
    Bypass,
    /// Look the key up in the cache; if it is missing, store the fresh
    /// answer with a lifetime of `ttl`.
    Use { ttl: Duration },
}
118
119/// Combine the per-query [`Mode`] with deployment [`Settings`] to get
120/// the effective behaviour for this call.
121///
122/// Rules:
123/// - `NOCACHE` always wins (explicit user opt-out).
124/// - `CACHE TTL t` always wins when present (explicit user opt-in;
125///   the deployment toggle does NOT gate per-query opt-in, only the
126///   silent default).
127/// - `Default` + `enabled=true` + `default_ttl=Some(t)` → use, ttl=t.
128/// - `Default` + anything else → bypass.
129pub fn decide(mode: Mode, settings: Settings) -> Decision {
130    match mode {
131        Mode::NoCache => Decision::Bypass,
132        Mode::Cache(ttl) => Decision::Use { ttl },
133        Mode::Default => match (settings.enabled, settings.default_ttl) {
134            (true, Some(ttl)) => Decision::Use { ttl },
135            _ => Decision::Bypass,
136        },
137    }
138}
139
140/// Derive the lowercase-hex SHA-256 cache key for one ASK call.
141///
142/// The key is a function of identity scope + request-shape inputs. It
143/// does NOT include the TTL — two calls with the same identity and
144/// shape collide on the same entry regardless of how long that entry
145/// will live, which is the correct hit/miss semantic.
146pub fn derive_key(scope: Scope<'_>, inputs: Inputs<'_>) -> String {
147    const SEP: u8 = 0x1f;
148    let mut hasher = Sha256::new();
149    hasher.update(scope.tenant.as_bytes());
150    hasher.update([SEP]);
151    hasher.update(scope.user.as_bytes());
152    hasher.update([SEP]);
153    hasher.update(inputs.question.as_bytes());
154    hasher.update([SEP]);
155    hasher.update(inputs.provider.as_bytes());
156    hasher.update([SEP]);
157    hasher.update(inputs.model.as_bytes());
158    hasher.update([SEP]);
159    hasher.update(format_temperature(inputs.temperature).as_bytes());
160    hasher.update([SEP]);
161    hasher.update(format_seed(inputs.seed).as_bytes());
162    hasher.update([SEP]);
163    hasher.update(inputs.sources_fingerprint.as_bytes());
164    let digest = hasher.finalize();
165    let mut out = String::with_capacity(digest.len() * 2);
166    for b in digest {
167        out.push_str(&format!("{b:02x}"));
168    }
169    out
170}
171
/// Render the temperature for the key: `"none"` when absent, otherwise
/// the value's `Display` form (so `Some(0.0)` becomes `"0"`, which is
/// distinct from `"none"`).
fn format_temperature(t: Option<f32>) -> String {
    t.map_or_else(|| String::from("none"), |value| value.to_string())
}
178
/// Render the seed for the key: `"none"` when absent, otherwise the
/// decimal `u64` (so `Some(0)` becomes `"0"`, distinct from `"none"`).
fn format_seed(s: Option<u64>) -> String {
    if let Some(seed) = s {
        seed.to_string()
    } else {
        String::from("none")
    }
}
185
/// Parse a TTL literal from `CACHE TTL '<lit>'`.
///
/// The accepted grammar is `<integer><unit>` with no whitespace, where
/// the unit is one of `s` (seconds), `m` (minutes), `h` (hours) or `d`
/// (days). The integer must be > 0: a zero TTL would mean "expire
/// immediately", a foot-gun the parser refuses on the user's behalf.
///
/// # Errors
///
/// Checks run in this order, and the first failure is returned:
/// - [`TtlParseError::Empty`] — the literal is `""`.
/// - [`TtlParseError::MissingUnit`] — the literal is all ASCII digits.
/// - [`TtlParseError::MissingNumber`] — the first byte is not an ASCII
///   digit (this also covers leading whitespace and a leading sign).
/// - [`TtlParseError::InvalidNumber`] — the digit run does not fit in
///   a `u64`.
/// - [`TtlParseError::ZeroTtl`] — the integer is `0`.
/// - [`TtlParseError::UnknownUnit`] — the text after the digits is not
///   exactly `s`, `m`, `h` or `d`.
/// - [`TtlParseError::Overflow`] — converting to seconds overflows
///   `u64`.
pub fn parse_ttl(literal: &str) -> Result<Duration, TtlParseError> {
    if literal.is_empty() {
        return Err(TtlParseError::Empty);
    }
    // ASCII digits are single bytes, so the offset of the first
    // non-digit byte is always a valid `str` split point.
    let unit_idx = literal
        .bytes()
        .position(|b| !b.is_ascii_digit())
        .ok_or(TtlParseError::MissingUnit)?;
    if unit_idx == 0 {
        return Err(TtlParseError::MissingNumber);
    }
    let (num_part, unit_part) = literal.split_at(unit_idx);
    // `num_part` is non-empty and all digits, so the only way this
    // parse can fail is a value too large for `u64`.
    let n: u64 = num_part.parse().map_err(|_| TtlParseError::InvalidNumber)?;
    if n == 0 {
        return Err(TtlParseError::ZeroTtl);
    }
    let secs_per_unit: u64 = match unit_part {
        "s" => 1,
        "m" => 60,
        "h" => 3_600,
        "d" => 86_400,
        _ => return Err(TtlParseError::UnknownUnit),
    };
    let secs = n
        .checked_mul(secs_per_unit)
        .ok_or(TtlParseError::Overflow)?;
    Ok(Duration::from_secs(secs))
}

/// Why [`parse_ttl`] rejected a literal. Named variants so the runtime
/// can map each to a deterministic error message without a stringly
/// typed switch.
///
/// The type also implements [`std::fmt::Display`] and
/// [`std::error::Error`], so it can be propagated with `?` into
/// `Box<dyn Error>`-style results.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TtlParseError {
    /// The literal was the empty string.
    Empty,
    /// The literal did not start with an ASCII digit.
    MissingNumber,
    /// The literal consisted only of ASCII digits.
    MissingUnit,
    /// The digit run did not fit in a `u64`.
    InvalidNumber,
    /// The text after the digits was not `s`, `m`, `h` or `d`.
    UnknownUnit,
    /// The integer was `0`.
    ZeroTtl,
    /// The value overflowed `u64` once converted to seconds.
    Overflow,
}

impl std::fmt::Display for TtlParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::Empty => "TTL literal is empty",
            Self::MissingNumber => "TTL literal must start with a positive integer",
            Self::MissingUnit => "TTL literal is missing a unit (expected s, m, h, or d)",
            Self::InvalidNumber => "TTL integer is too large",
            Self::UnknownUnit => "unknown TTL unit (expected s, m, h, or d)",
            Self::ZeroTtl => "TTL must be greater than zero",
            Self::Overflow => "TTL is too large to represent in seconds",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for TtlParseError {}
232
#[cfg(test)]
mod tests {
    use super::*;

    // ---- fixtures ----------------------------------------------------

    /// Build a scope from two literals.
    fn scope_of(tenant: &'static str, user: &'static str) -> Scope<'static> {
        Scope { tenant, user }
    }

    /// Baseline scope used by most key tests.
    fn scope() -> Scope<'static> {
        scope_of("acme", "alice")
    }

    /// Baseline request shape used by most key tests.
    fn inputs() -> Inputs<'static> {
        Inputs {
            question: "what is the capital of france?",
            provider: "openai",
            model: "gpt-4o-mini",
            temperature: Some(0.0),
            seed: Some(42),
            sources_fingerprint: "abc123",
        }
    }

    /// Key for the baseline scope and baseline inputs.
    fn baseline_key() -> String {
        derive_key(scope(), inputs())
    }

    /// Key for the baseline scope after `tweak` edits the baseline inputs.
    fn key_with(tweak: impl FnOnce(&mut Inputs<'static>)) -> String {
        let mut shape = inputs();
        tweak(&mut shape);
        derive_key(scope(), shape)
    }

    fn secs(n: u64) -> Duration {
        Duration::from_secs(n)
    }

    fn settings(enabled: bool, default_ttl: Option<Duration>, max_entries: usize) -> Settings {
        Settings {
            enabled,
            default_ttl,
            max_entries,
        }
    }

    // ---- key: determinism & scope separation -------------------------

    #[test]
    fn key_is_deterministic_across_calls() {
        let first = baseline_key();
        let second = baseline_key();
        assert_eq!(first, second);
        // A SHA-256 digest rendered as hex is 64 characters long.
        assert_eq!(first.len(), 64);
        let lowercase_hex = first
            .chars()
            .all(|c| c.is_ascii_hexdigit() && !c.is_uppercase());
        assert!(lowercase_hex);
    }

    #[test]
    fn key_changes_with_tenant() {
        let acme = derive_key(scope_of("acme", "alice"), inputs());
        let globex = derive_key(scope_of("globex", "alice"), inputs());
        assert_ne!(acme, globex, "per-tenant scope must isolate cache keys");
    }

    #[test]
    fn key_changes_with_user() {
        let alice = derive_key(scope_of("acme", "alice"), inputs());
        let bob = derive_key(scope_of("acme", "bob"), inputs());
        assert_ne!(alice, bob);
    }

    #[test]
    fn empty_user_is_distinct_from_named_user() {
        let anon = derive_key(scope_of("acme", ""), inputs());
        let named = baseline_key();
        assert_ne!(anon, named);
    }

    // ---- key: every input field actually feeds the digest ------------

    #[test]
    fn key_changes_with_question() {
        let other = key_with(|i| i.question = "different question");
        assert_ne!(baseline_key(), other);
    }

    #[test]
    fn key_changes_with_provider() {
        let other = key_with(|i| i.provider = "anthropic");
        assert_ne!(baseline_key(), other);
    }

    #[test]
    fn key_changes_with_model() {
        let other = key_with(|i| i.model = "gpt-4o");
        assert_ne!(baseline_key(), other);
    }

    #[test]
    fn key_changes_with_temperature() {
        let other = key_with(|i| i.temperature = Some(0.7));
        assert_ne!(baseline_key(), other);
    }

    #[test]
    fn key_changes_with_seed() {
        let other = key_with(|i| i.seed = Some(43));
        assert_ne!(baseline_key(), other);
    }

    #[test]
    fn key_changes_with_fingerprint() {
        let other = key_with(|i| i.sources_fingerprint = "def456");
        assert_ne!(
            baseline_key(),
            other,
            "different sources must miss cache even for identical question"
        );
    }

    // ---- key: None vs Some(0) for optional knobs ---------------------

    #[test]
    fn temperature_none_distinct_from_zero() {
        let none = key_with(|i| i.temperature = None);
        let zero = key_with(|i| i.temperature = Some(0.0));
        assert_ne!(
            none, zero,
            "None and Some(0.0) must not collide — a provider that ignores temperature is not the same as one that received zero"
        );
    }

    #[test]
    fn seed_none_distinct_from_zero() {
        let none = key_with(|i| i.seed = None);
        let zero = key_with(|i| i.seed = Some(0));
        assert_ne!(none, zero);
    }

    // ---- key: pin the canonical form against accidental change ------

    #[test]
    fn key_pinned_against_known_value() {
        // Any change to the canonical form (delimiter, field order,
        // float/seed serialization) makes this fail loudly. Only touch
        // the literal as part of a deliberate schema change.
        let shape = Inputs {
            question: "q",
            provider: "p",
            model: "m",
            temperature: Some(0.0),
            seed: Some(1),
            sources_fingerprint: "f",
        };
        let key = derive_key(scope_of("t", "u"), shape);
        // Computed by `printf 't\x1fu\x1fq\x1fp\x1fm\x1f0\x1f1\x1ff' | sha256sum`.
        assert_eq!(
            key,
            "ca47974209a1e07b9890aa73b5bdbcc2fda1bae0ba1d77f186c9dc168b54f903"
        );
    }

    // ---- decide(): TTL policy ---------------------------------------

    #[test]
    fn decide_nocache_always_bypasses() {
        let s = settings(true, Some(secs(60)), 100);
        assert_eq!(decide(Mode::NoCache, s), Decision::Bypass);
    }

    #[test]
    fn decide_per_query_cache_wins_over_disabled_setting() {
        let got = decide(Mode::Cache(secs(300)), Settings::default());
        assert_eq!(got, Decision::Use { ttl: secs(300) });
    }

    #[test]
    fn decide_default_bypass_when_disabled() {
        let s = settings(false, Some(secs(60)), 100);
        assert_eq!(decide(Mode::Default, s), Decision::Bypass);
    }

    #[test]
    fn decide_default_bypass_when_no_default_ttl() {
        let s = settings(true, None, 100);
        assert_eq!(decide(Mode::Default, s), Decision::Bypass);
    }

    #[test]
    fn decide_default_uses_setting_ttl_when_enabled_and_ttl_set() {
        let s = settings(true, Some(secs(120)), 100);
        assert_eq!(decide(Mode::Default, s), Decision::Use { ttl: secs(120) });
    }

    #[test]
    fn decide_per_query_cache_overrides_setting_default() {
        let s = settings(true, Some(secs(60)), 100);
        let got = decide(Mode::Cache(secs(900)), s);
        assert_eq!(got, Decision::Use { ttl: secs(900) });
    }

    // ---- parse_ttl() ------------------------------------------------

    #[test]
    fn parse_ttl_seconds() {
        assert_eq!(parse_ttl("30s"), Ok(secs(30)));
    }

    #[test]
    fn parse_ttl_minutes() {
        assert_eq!(parse_ttl("5m"), Ok(secs(300)));
    }

    #[test]
    fn parse_ttl_hours() {
        assert_eq!(parse_ttl("2h"), Ok(secs(7200)));
    }

    #[test]
    fn parse_ttl_days() {
        assert_eq!(parse_ttl("1d"), Ok(secs(86_400)));
    }

    #[test]
    fn parse_ttl_empty_rejected() {
        assert_eq!(parse_ttl(""), Err(TtlParseError::Empty));
    }

    #[test]
    fn parse_ttl_zero_rejected() {
        // An entry with a 0s TTL would expire the moment it is written;
        // rejecting it surfaces the misconfiguration at parse time.
        assert_eq!(parse_ttl("0s"), Err(TtlParseError::ZeroTtl));
    }

    #[test]
    fn parse_ttl_missing_unit_rejected() {
        assert_eq!(parse_ttl("30"), Err(TtlParseError::MissingUnit));
    }

    #[test]
    fn parse_ttl_missing_number_rejected() {
        assert_eq!(parse_ttl("m"), Err(TtlParseError::MissingNumber));
    }

    #[test]
    fn parse_ttl_unknown_unit_rejected() {
        for literal in ["5x", "5ms"] {
            assert_eq!(parse_ttl(literal), Err(TtlParseError::UnknownUnit));
        }
    }

    #[test]
    fn parse_ttl_whitespace_rejected() {
        // Quotes are already stripped by the SQL surface; whitespace
        // inside the literal itself must still be refused.
        assert_eq!(parse_ttl("5 m"), Err(TtlParseError::UnknownUnit));
        assert_eq!(parse_ttl(" 5m"), Err(TtlParseError::MissingNumber));
    }

    #[test]
    fn parse_ttl_negative_rejected() {
        // '-' is not a digit, so the first non-digit sits at index 0
        // and the parser reports MissingNumber. Pinned for clarity.
        assert_eq!(parse_ttl("-5m"), Err(TtlParseError::MissingNumber));
    }

    #[test]
    fn parse_ttl_invalid_number_rejected() {
        // The digit run is too large for u64, so the integer parse fails.
        let got = parse_ttl("99999999999999999999s");
        assert_eq!(got, Err(TtlParseError::InvalidNumber));
    }

    #[test]
    fn parse_ttl_overflow_on_unit_multiplication() {
        // Fits in u64 on its own, but not after multiplying by 86_400.
        let literal = format!("{}d", u64::MAX / 86_400 + 1);
        assert_eq!(parse_ttl(&literal), Err(TtlParseError::Overflow));
    }

    // ---- mode default ----------------------------------------------

    #[test]
    fn mode_default_is_inherit() {
        assert_eq!(Mode::default(), Mode::Default);
    }

    // ---- determinism across modes ----------------------------------

    #[test]
    fn decide_is_deterministic_across_calls() {
        let s = settings(true, Some(secs(60)), 10);
        let modes = [Mode::Default, Mode::NoCache, Mode::Cache(secs(120))];
        for mode in modes {
            assert_eq!(decide(mode, s), decide(mode, s));
        }
    }
}