1use std::time::Duration;
49
50use sha2::{Digest, Sha256};
51
/// Identity that a cache entry belongs to.
///
/// Both fields are fed into the key computed by `derive_key`, so two requests
/// that differ only in tenant or only in user get different keys.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Scope<'a> {
    /// Tenant identifier (outermost isolation boundary).
    pub tenant: &'a str,
    /// User identifier within the tenant. The empty string is hashed like any
    /// other value, so it is distinct from every named user.
    pub user: &'a str,
}
60
/// Request parameters (everything except the [`Scope`]) that take part in the
/// cache key.
///
/// Not `PartialEq`: `temperature` is a float.
#[derive(Clone, Copy, Debug)]
pub struct Inputs<'a> {
    /// The question text, hashed verbatim.
    pub question: &'a str,
    /// Provider name, e.g. `"openai"`.
    pub provider: &'a str,
    /// Model identifier at that provider.
    pub model: &'a str,
    /// Sampling temperature; `None` is encoded differently from `Some(0.0)`.
    pub temperature: Option<f32>,
    /// Sampling seed; `None` is encoded differently from `Some(0)`.
    pub seed: Option<u64>,
    /// Fingerprint of the source material; a different fingerprint yields a
    /// different key even for an identical question.
    pub sources_fingerprint: &'a str,
}
80
/// Per-request caching preference, resolved against [`Settings`] by `decide`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Mode {
    /// Inherit the service-wide [`Settings`].
    Default,
    /// Cache with this explicit TTL, overriding [`Settings`] (even when
    /// caching is disabled there).
    Cache(Duration),
    /// Always bypass the cache for this request.
    NoCache,
}
95
96impl Default for Mode {
97 fn default() -> Self {
98 Mode::Default
99 }
100}
101
/// Service-wide cache configuration, consulted for [`Mode::Default`] requests.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Settings {
    /// Master switch for requests that inherit settings.
    pub enabled: bool,
    /// TTL for inheriting requests; when `None` they bypass the cache even if
    /// `enabled` is true.
    pub default_ttl: Option<Duration>,
    /// Entry budget for the cache store.
    /// NOTE(review): not read anywhere in this file; confirm with the store
    /// whether `0` means "unbounded" or "store nothing".
    pub max_entries: usize,
}
114
115impl Default for Settings {
116 fn default() -> Self {
117 Self {
118 enabled: false,
119 default_ttl: None,
120 max_entries: 0,
121 }
122 }
123}
124
/// Outcome of `decide` for a single request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Decision {
    /// Do not use the cache.
    Bypass,
    /// Use the cache with the given time-to-live.
    Use { ttl: Duration },
}
133
134pub fn decide(mode: Mode, settings: Settings) -> Decision {
145 match mode {
146 Mode::NoCache => Decision::Bypass,
147 Mode::Cache(ttl) => Decision::Use { ttl },
148 Mode::Default => match (settings.enabled, settings.default_ttl) {
149 (true, Some(ttl)) => Decision::Use { ttl },
150 _ => Decision::Bypass,
151 },
152 }
153}
154
155pub fn derive_key(scope: Scope<'_>, inputs: Inputs<'_>) -> String {
162 const SEP: u8 = 0x1f;
163 let mut hasher = Sha256::new();
164 hasher.update(scope.tenant.as_bytes());
165 hasher.update([SEP]);
166 hasher.update(scope.user.as_bytes());
167 hasher.update([SEP]);
168 hasher.update(inputs.question.as_bytes());
169 hasher.update([SEP]);
170 hasher.update(inputs.provider.as_bytes());
171 hasher.update([SEP]);
172 hasher.update(inputs.model.as_bytes());
173 hasher.update([SEP]);
174 hasher.update(format_temperature(inputs.temperature).as_bytes());
175 hasher.update([SEP]);
176 hasher.update(format_seed(inputs.seed).as_bytes());
177 hasher.update([SEP]);
178 hasher.update(inputs.sources_fingerprint.as_bytes());
179 let digest = hasher.finalize();
180 let mut out = String::with_capacity(digest.len() * 2);
181 for b in digest {
182 out.push_str(&format!("{b:02x}"));
183 }
184 out
185}
186
/// Canonical text form of a temperature for key derivation.
///
/// `None` becomes the literal `"none"`. A value is rendered with `Display`,
/// which can never produce `"none"`, so the two cases stay distinct.
fn format_temperature(t: Option<f32>) -> String {
    if let Some(value) = t {
        value.to_string()
    } else {
        String::from("none")
    }
}
193
/// Canonical text form of a seed for key derivation: decimal digits, or the
/// literal `"none"` when no seed was given.
fn format_seed(s: Option<u64>) -> String {
    s.map_or_else(|| String::from("none"), |seed| seed.to_string())
}
200
201pub fn parse_ttl(literal: &str) -> Result<Duration, TtlParseError> {
208 if literal.is_empty() {
209 return Err(TtlParseError::Empty);
210 }
211 let bytes = literal.as_bytes();
212 let unit_idx = bytes
213 .iter()
214 .position(|b| !b.is_ascii_digit())
215 .ok_or(TtlParseError::MissingUnit)?;
216 if unit_idx == 0 {
217 return Err(TtlParseError::MissingNumber);
218 }
219 let (num_part, unit_part) = literal.split_at(unit_idx);
220 let n: u64 = num_part
221 .parse()
222 .map_err(|_| TtlParseError::InvalidNumber)?;
223 if n == 0 {
224 return Err(TtlParseError::ZeroTtl);
225 }
226 let secs = match unit_part {
227 "s" => n,
228 "m" => n.checked_mul(60).ok_or(TtlParseError::Overflow)?,
229 "h" => n.checked_mul(3600).ok_or(TtlParseError::Overflow)?,
230 "d" => n.checked_mul(86_400).ok_or(TtlParseError::Overflow)?,
231 _ => return Err(TtlParseError::UnknownUnit),
232 };
233 Ok(Duration::from_secs(secs))
234}
235
/// Reasons a TTL literal was rejected by `parse_ttl`.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`], so it can be
/// boxed as `dyn Error` or propagated with `?` into broader error types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TtlParseError {
    /// The literal was the empty string.
    Empty,
    /// The literal does not start with an ASCII digit (e.g. `"m"`, `"-5m"`,
    /// `" 5m"`).
    MissingNumber,
    /// The literal is all digits, with no unit suffix (e.g. `"30"`).
    MissingUnit,
    /// The leading digits could not be parsed as a `u64`. Because only ASCII
    /// digits reach the parser, in practice this means the number is too large.
    InvalidNumber,
    /// The suffix after the digits is not exactly one of `s`, `m`, `h`, `d`.
    UnknownUnit,
    /// The number was zero.
    ZeroTtl,
    /// Converting the number to seconds overflowed `u64`.
    Overflow,
}

impl std::fmt::Display for TtlParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: a new variant must get a message.
        let message = match self {
            TtlParseError::Empty => "TTL literal is empty",
            TtlParseError::MissingNumber => "TTL literal must start with a decimal number",
            TtlParseError::MissingUnit => {
                "TTL literal is missing a unit suffix (expected one of s, m, h, d)"
            }
            TtlParseError::InvalidNumber => "TTL number is not a valid unsigned 64-bit integer",
            TtlParseError::UnknownUnit => "TTL unit is not recognised (expected one of s, m, h, d)",
            TtlParseError::ZeroTtl => "TTL must be greater than zero",
            TtlParseError::Overflow => "TTL is too large to represent in seconds",
        };
        f.write_str(message)
    }
}

impl std::error::Error for TtlParseError {}
249
#[cfg(test)]
mod tests {
    use super::*;

    // ---- fixtures --------------------------------------------------------

    /// Builds a `Scope` from a tenant/user pair.
    fn scoped(tenant: &'static str, user: &'static str) -> Scope<'static> {
        Scope { tenant, user }
    }

    /// Scope shared by most key tests.
    fn scope() -> Scope<'static> {
        scoped("acme", "alice")
    }

    /// Request inputs shared by most key tests.
    fn inputs() -> Inputs<'static> {
        Inputs {
            question: "what is the capital of france?",
            provider: "openai",
            model: "gpt-4o-mini",
            temperature: Some(0.0),
            seed: Some(42),
            sources_fingerprint: "abc123",
        }
    }

    /// Key for the shared scope and the given inputs.
    fn key_for(i: Inputs<'_>) -> String {
        derive_key(scope(), i)
    }

    /// `Settings` shorthand; `ttl_secs` becomes `default_ttl`.
    fn settings(enabled: bool, ttl_secs: Option<u64>, max_entries: usize) -> Settings {
        Settings {
            enabled,
            default_ttl: ttl_secs.map(Duration::from_secs),
            max_entries,
        }
    }

    // ---- derive_key ------------------------------------------------------

    #[test]
    fn key_is_deterministic_across_calls() {
        let first = key_for(inputs());
        let second = key_for(inputs());
        assert_eq!(first, second);
        assert_eq!(first.len(), 64);
        // Lowercase hexadecimal only.
        assert!(first.chars().all(|c| matches!(c, '0'..='9' | 'a'..='f')));
    }

    #[test]
    fn key_changes_with_tenant() {
        let acme = derive_key(scoped("acme", "alice"), inputs());
        let globex = derive_key(scoped("globex", "alice"), inputs());
        assert_ne!(acme, globex, "per-tenant scope must isolate cache keys");
    }

    #[test]
    fn key_changes_with_user() {
        let alice = derive_key(scoped("acme", "alice"), inputs());
        let bob = derive_key(scoped("acme", "bob"), inputs());
        assert_ne!(alice, bob);
    }

    #[test]
    fn empty_user_is_distinct_from_named_user() {
        let anonymous = derive_key(scoped("acme", ""), inputs());
        assert_ne!(anonymous, key_for(inputs()));
    }

    #[test]
    fn key_changes_with_question() {
        let changed = Inputs {
            question: "different question",
            ..inputs()
        };
        assert_ne!(key_for(inputs()), key_for(changed));
    }

    #[test]
    fn key_changes_with_provider() {
        let changed = Inputs {
            provider: "anthropic",
            ..inputs()
        };
        assert_ne!(key_for(inputs()), key_for(changed));
    }

    #[test]
    fn key_changes_with_model() {
        let changed = Inputs {
            model: "gpt-4o",
            ..inputs()
        };
        assert_ne!(key_for(inputs()), key_for(changed));
    }

    #[test]
    fn key_changes_with_temperature() {
        let changed = Inputs {
            temperature: Some(0.7),
            ..inputs()
        };
        assert_ne!(key_for(inputs()), key_for(changed));
    }

    #[test]
    fn key_changes_with_seed() {
        let changed = Inputs {
            seed: Some(43),
            ..inputs()
        };
        assert_ne!(key_for(inputs()), key_for(changed));
    }

    #[test]
    fn key_changes_with_fingerprint() {
        let changed = Inputs {
            sources_fingerprint: "def456",
            ..inputs()
        };
        assert_ne!(
            key_for(inputs()),
            key_for(changed),
            "different sources must miss cache even for identical question"
        );
    }

    #[test]
    fn temperature_none_distinct_from_zero() {
        let none = key_for(Inputs {
            temperature: None,
            ..inputs()
        });
        let zero = key_for(Inputs {
            temperature: Some(0.0),
            ..inputs()
        });
        assert_ne!(
            none, zero,
            "None and Some(0.0) must not collide — a provider that ignores temperature is not the same as one that received zero"
        );
    }

    #[test]
    fn seed_none_distinct_from_zero() {
        let none = key_for(Inputs {
            seed: None,
            ..inputs()
        });
        let zero = key_for(Inputs {
            seed: Some(0),
            ..inputs()
        });
        assert_ne!(none, zero);
    }

    #[test]
    fn key_pinned_against_known_value() {
        // Any change to the key encoding changes this digest, and every
        // previously stored entry stops being found. Update it only on purpose.
        let pinned = Inputs {
            question: "q",
            provider: "p",
            model: "m",
            temperature: Some(0.0),
            seed: Some(1),
            sources_fingerprint: "f",
        };
        assert_eq!(
            derive_key(scoped("t", "u"), pinned),
            "ca47974209a1e07b9890aa73b5bdbcc2fda1bae0ba1d77f186c9dc168b54f903"
        );
    }

    // ---- decide ----------------------------------------------------------

    #[test]
    fn decide_nocache_always_bypasses() {
        let s = settings(true, Some(60), 100);
        assert_eq!(decide(Mode::NoCache, s), Decision::Bypass);
    }

    #[test]
    fn decide_per_query_cache_wins_over_disabled_setting() {
        let ttl = Duration::from_secs(300);
        assert_eq!(
            decide(Mode::Cache(ttl), Settings::default()),
            Decision::Use { ttl }
        );
    }

    #[test]
    fn decide_default_bypass_when_disabled() {
        let s = settings(false, Some(60), 100);
        assert_eq!(decide(Mode::Default, s), Decision::Bypass);
    }

    #[test]
    fn decide_default_bypass_when_no_default_ttl() {
        let s = settings(true, None, 100);
        assert_eq!(decide(Mode::Default, s), Decision::Bypass);
    }

    #[test]
    fn decide_default_uses_setting_ttl_when_enabled_and_ttl_set() {
        let s = settings(true, Some(120), 100);
        assert_eq!(
            decide(Mode::Default, s),
            Decision::Use {
                ttl: Duration::from_secs(120)
            }
        );
    }

    #[test]
    fn decide_per_query_cache_overrides_setting_default() {
        let s = settings(true, Some(60), 100);
        let ttl = Duration::from_secs(900);
        assert_eq!(decide(Mode::Cache(ttl), s), Decision::Use { ttl });
    }

    // ---- parse_ttl -------------------------------------------------------

    #[test]
    fn parse_ttl_seconds() {
        assert_eq!(parse_ttl("30s"), Ok(Duration::from_secs(30)));
    }

    #[test]
    fn parse_ttl_minutes() {
        assert_eq!(parse_ttl("5m"), Ok(Duration::from_secs(300)));
    }

    #[test]
    fn parse_ttl_hours() {
        assert_eq!(parse_ttl("2h"), Ok(Duration::from_secs(7200)));
    }

    #[test]
    fn parse_ttl_days() {
        assert_eq!(parse_ttl("1d"), Ok(Duration::from_secs(86_400)));
    }

    #[test]
    fn parse_ttl_empty_rejected() {
        assert_eq!(parse_ttl(""), Err(TtlParseError::Empty));
    }

    #[test]
    fn parse_ttl_zero_rejected() {
        assert_eq!(parse_ttl("0s"), Err(TtlParseError::ZeroTtl));
    }

    #[test]
    fn parse_ttl_missing_unit_rejected() {
        assert_eq!(parse_ttl("30"), Err(TtlParseError::MissingUnit));
    }

    #[test]
    fn parse_ttl_missing_number_rejected() {
        assert_eq!(parse_ttl("m"), Err(TtlParseError::MissingNumber));
    }

    #[test]
    fn parse_ttl_unknown_unit_rejected() {
        for literal in ["5x", "5ms"] {
            assert_eq!(parse_ttl(literal), Err(TtlParseError::UnknownUnit));
        }
    }

    #[test]
    fn parse_ttl_whitespace_rejected() {
        // Inner space ends up in the unit; leading space means no number.
        assert_eq!(parse_ttl("5 m"), Err(TtlParseError::UnknownUnit));
        assert_eq!(parse_ttl(" 5m"), Err(TtlParseError::MissingNumber));
    }

    #[test]
    fn parse_ttl_negative_rejected() {
        assert_eq!(parse_ttl("-5m"), Err(TtlParseError::MissingNumber));
    }

    #[test]
    fn parse_ttl_invalid_number_rejected() {
        assert_eq!(
            parse_ttl("99999999999999999999s"),
            Err(TtlParseError::InvalidNumber)
        );
    }

    #[test]
    fn parse_ttl_overflow_on_unit_multiplication() {
        // Smallest day count whose value in seconds no longer fits in u64.
        let literal = format!("{}d", u64::MAX / 86_400 + 1);
        assert_eq!(parse_ttl(&literal), Err(TtlParseError::Overflow));
    }

    // ---- Mode / misc -----------------------------------------------------

    #[test]
    fn mode_default_is_inherit() {
        assert_eq!(Mode::default(), Mode::Default);
    }

    #[test]
    fn decide_is_deterministic_across_calls() {
        let s = settings(true, Some(60), 10);
        let modes = [
            Mode::Default,
            Mode::NoCache,
            Mode::Cache(Duration::from_secs(120)),
        ];
        for mode in modes {
            assert_eq!(decide(mode, s), decide(mode, s));
        }
    }
}