1use std::time::Duration;
49
50use sha2::{Digest, Sha256};
51
/// Identity that a cache entry is scoped to.
///
/// Both fields are hashed into the key produced by [`derive_key`], so requests
/// that differ only in tenant or user produce different keys.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Scope<'a> {
    /// Tenant identifier.
    pub tenant: &'a str,
    /// User identifier within the tenant; the empty string is allowed and is
    /// keyed separately from any named user.
    pub user: &'a str,
}
60
/// Request parameters that feed into the cache key alongside the [`Scope`].
#[derive(Clone, Copy, Debug)]
pub struct Inputs<'a> {
    /// The question text, hashed verbatim.
    pub question: &'a str,
    /// Provider name, e.g. `"openai"`.
    pub provider: &'a str,
    /// Model identifier, e.g. `"gpt-4o-mini"`.
    pub model: &'a str,
    /// Sampling temperature; `None` is rendered as `"none"` in the key, which
    /// keeps it distinct from `Some(0.0)`.
    pub temperature: Option<f32>,
    /// Sampling seed; `None` is rendered as `"none"` in the key, which keeps it
    /// distinct from `Some(0)`.
    pub seed: Option<u64>,
    /// Fingerprint of the sources the answer is based on; a different value
    /// yields a different key even for an identical question.
    pub sources_fingerprint: &'a str,
}
80
/// Per-query caching directive, resolved against [`Settings`] by [`decide`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Mode {
    /// Defer to the configured [`Settings`]; this is the `Default` variant.
    #[default]
    Default,
    /// Use the cache with this TTL, regardless of whether caching is enabled
    /// in the settings.
    Cache(Duration),
    /// Bypass the cache for this query.
    NoCache,
}
96
/// Global cache configuration.
///
/// The derived `Default` leaves caching off: `enabled` is `false`,
/// `default_ttl` is `None` and `max_entries` is `0`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct Settings {
    /// Whether [`Mode::Default`] queries may use the cache at all.
    pub enabled: bool,
    /// TTL applied to [`Mode::Default`] queries; `None` makes them bypass.
    pub default_ttl: Option<Duration>,
    /// Cache capacity. NOTE(review): not read by `decide` — presumably enforced
    /// by the cache store itself; confirm.
    pub max_entries: usize,
}
109
/// Outcome of [`decide`]: whether this query should go through the cache.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Decision {
    /// Skip the cache for this query.
    Bypass,
    /// Use the cache with the given time-to-live.
    Use { ttl: Duration },
}
118
119pub fn decide(mode: Mode, settings: Settings) -> Decision {
130 match mode {
131 Mode::NoCache => Decision::Bypass,
132 Mode::Cache(ttl) => Decision::Use { ttl },
133 Mode::Default => match (settings.enabled, settings.default_ttl) {
134 (true, Some(ttl)) => Decision::Use { ttl },
135 _ => Decision::Bypass,
136 },
137 }
138}
139
140pub fn derive_key(scope: Scope<'_>, inputs: Inputs<'_>) -> String {
147 const SEP: u8 = 0x1f;
148 let mut hasher = Sha256::new();
149 hasher.update(scope.tenant.as_bytes());
150 hasher.update([SEP]);
151 hasher.update(scope.user.as_bytes());
152 hasher.update([SEP]);
153 hasher.update(inputs.question.as_bytes());
154 hasher.update([SEP]);
155 hasher.update(inputs.provider.as_bytes());
156 hasher.update([SEP]);
157 hasher.update(inputs.model.as_bytes());
158 hasher.update([SEP]);
159 hasher.update(format_temperature(inputs.temperature).as_bytes());
160 hasher.update([SEP]);
161 hasher.update(format_seed(inputs.seed).as_bytes());
162 hasher.update([SEP]);
163 hasher.update(inputs.sources_fingerprint.as_bytes());
164 let digest = hasher.finalize();
165 let mut out = String::with_capacity(digest.len() * 2);
166 for b in digest {
167 out.push_str(&format!("{b:02x}"));
168 }
169 out
170}
171
/// Renders an optional temperature for inclusion in the cache key.
///
/// `None` renders as `"none"`, so it never collides with `Some(0.0)`
/// (rendered `"0"`). Negative zero is normalised to positive zero first.
/// `-0.0` and `0.0` are the same sampling temperature, but `Display` renders
/// them as `"-0"` and `"0"`, which would split identical requests across two
/// cache entries.
fn format_temperature(t: Option<f32>) -> String {
    match t {
        None => "none".to_string(),
        Some(v) => {
            // `-0.0 == 0.0` holds, so both signed zeros collapse to `+0.0`.
            let v = if v == 0.0 { 0.0 } else { v };
            format!("{v}")
        }
    }
}
178
/// Renders an optional seed for inclusion in the cache key.
///
/// `None` becomes `"none"`, which can never equal a decimal seed such as `"0"`.
fn format_seed(s: Option<u64>) -> String {
    s.map_or_else(|| String::from("none"), |seed| seed.to_string())
}
185
186pub fn parse_ttl(literal: &str) -> Result<Duration, TtlParseError> {
193 if literal.is_empty() {
194 return Err(TtlParseError::Empty);
195 }
196 let bytes = literal.as_bytes();
197 let unit_idx = bytes
198 .iter()
199 .position(|b| !b.is_ascii_digit())
200 .ok_or(TtlParseError::MissingUnit)?;
201 if unit_idx == 0 {
202 return Err(TtlParseError::MissingNumber);
203 }
204 let (num_part, unit_part) = literal.split_at(unit_idx);
205 let n: u64 = num_part.parse().map_err(|_| TtlParseError::InvalidNumber)?;
206 if n == 0 {
207 return Err(TtlParseError::ZeroTtl);
208 }
209 let secs = match unit_part {
210 "s" => n,
211 "m" => n.checked_mul(60).ok_or(TtlParseError::Overflow)?,
212 "h" => n.checked_mul(3600).ok_or(TtlParseError::Overflow)?,
213 "d" => n.checked_mul(86_400).ok_or(TtlParseError::Overflow)?,
214 _ => return Err(TtlParseError::UnknownUnit),
215 };
216 Ok(Duration::from_secs(secs))
217}
218
/// Reasons [`parse_ttl`] can reject a TTL literal.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`], so it can be
/// shown to users and propagated with `?` into `Box<dyn Error>`-style results.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TtlParseError {
    /// The literal was empty.
    Empty,
    /// The literal did not start with a digit.
    MissingNumber,
    /// The literal contained only digits, with no unit suffix.
    MissingUnit,
    /// The numeric part could not be parsed as a `u64`.
    InvalidNumber,
    /// The unit suffix was not one of `s`, `m`, `h`, `d`.
    UnknownUnit,
    /// The TTL was zero.
    ZeroTtl,
    /// Converting the TTL to seconds overflowed a `u64`.
    Overflow,
}

impl std::fmt::Display for TtlParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::Empty => "TTL literal is empty",
            Self::MissingNumber => "TTL literal must start with a number",
            Self::MissingUnit => "TTL literal is missing a unit (expected s, m, h or d)",
            Self::InvalidNumber => "TTL number is not a valid unsigned 64-bit integer",
            Self::UnknownUnit => "unknown TTL unit (expected s, m, h or d)",
            Self::ZeroTtl => "TTL must be greater than zero",
            Self::Overflow => "TTL is too large to represent in seconds",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for TtlParseError {}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline scope shared by most key tests.
    fn scope() -> Scope<'static> {
        Scope {
            tenant: "acme",
            user: "alice",
        }
    }

    /// Baseline inputs shared by most key tests.
    fn inputs() -> Inputs<'static> {
        Inputs {
            question: "what is the capital of france?",
            provider: "openai",
            model: "gpt-4o-mini",
            temperature: Some(0.0),
            seed: Some(42),
            sources_fingerprint: "abc123",
        }
    }

    /// Key for the baseline scope combined with `inputs`.
    fn key_for(inputs: Inputs<'_>) -> String {
        derive_key(scope(), inputs)
    }

    /// Key for the baseline inputs under the given tenant and user.
    fn key_in(tenant: &str, user: &str) -> String {
        derive_key(Scope { tenant, user }, inputs())
    }

    /// Builds `Settings`, taking the default TTL in whole seconds.
    fn settings(enabled: bool, default_ttl_secs: Option<u64>, max_entries: usize) -> Settings {
        Settings {
            enabled,
            default_ttl: default_ttl_secs.map(Duration::from_secs),
            max_entries,
        }
    }

    // ---- derive_key -------------------------------------------------------

    #[test]
    fn key_is_deterministic_across_calls() {
        let first = derive_key(scope(), inputs());
        let second = derive_key(scope(), inputs());
        assert_eq!(first, second);
        assert_eq!(first.len(), 64);
        // Lowercase hex only.
        assert!(first.chars().all(|c| matches!(c, '0'..='9' | 'a'..='f')));
    }

    #[test]
    fn key_changes_with_tenant() {
        assert_ne!(
            key_in("acme", "alice"),
            key_in("globex", "alice"),
            "per-tenant scope must isolate cache keys"
        );
    }

    #[test]
    fn key_changes_with_user() {
        assert_ne!(key_in("acme", "alice"), key_in("acme", "bob"));
    }

    #[test]
    fn empty_user_is_distinct_from_named_user() {
        assert_ne!(key_in("acme", ""), derive_key(scope(), inputs()));
    }

    #[test]
    fn key_changes_with_question() {
        let base = key_for(inputs());
        let other = key_for(Inputs {
            question: "different question",
            ..inputs()
        });
        assert_ne!(base, other);
    }

    #[test]
    fn key_changes_with_provider() {
        let base = key_for(inputs());
        let other = key_for(Inputs {
            provider: "anthropic",
            ..inputs()
        });
        assert_ne!(base, other);
    }

    #[test]
    fn key_changes_with_model() {
        let base = key_for(inputs());
        let other = key_for(Inputs {
            model: "gpt-4o",
            ..inputs()
        });
        assert_ne!(base, other);
    }

    #[test]
    fn key_changes_with_temperature() {
        let base = key_for(inputs());
        let other = key_for(Inputs {
            temperature: Some(0.7),
            ..inputs()
        });
        assert_ne!(base, other);
    }

    #[test]
    fn key_changes_with_seed() {
        let base = key_for(inputs());
        let other = key_for(Inputs {
            seed: Some(43),
            ..inputs()
        });
        assert_ne!(base, other);
    }

    #[test]
    fn key_changes_with_fingerprint() {
        let base = key_for(inputs());
        let other = key_for(Inputs {
            sources_fingerprint: "def456",
            ..inputs()
        });
        assert_ne!(
            base, other,
            "different sources must miss cache even for identical question"
        );
    }

    #[test]
    fn temperature_none_distinct_from_zero() {
        let none = key_for(Inputs {
            temperature: None,
            ..inputs()
        });
        let zero = key_for(Inputs {
            temperature: Some(0.0),
            ..inputs()
        });
        assert_ne!(
            none, zero,
            "None and Some(0.0) must not collide — a provider that ignores temperature is not the same as one that received zero"
        );
    }

    #[test]
    fn seed_none_distinct_from_zero() {
        let none = key_for(Inputs {
            seed: None,
            ..inputs()
        });
        let zero = key_for(Inputs {
            seed: Some(0),
            ..inputs()
        });
        assert_ne!(none, zero);
    }

    #[test]
    fn key_pinned_against_known_value() {
        let pinned_scope = Scope {
            tenant: "t",
            user: "u",
        };
        let pinned_inputs = Inputs {
            question: "q",
            provider: "p",
            model: "m",
            temperature: Some(0.0),
            seed: Some(1),
            sources_fingerprint: "f",
        };
        assert_eq!(
            derive_key(pinned_scope, pinned_inputs),
            "ca47974209a1e07b9890aa73b5bdbcc2fda1bae0ba1d77f186c9dc168b54f903"
        );
    }

    // ---- decide -----------------------------------------------------------

    #[test]
    fn decide_nocache_always_bypasses() {
        let s = settings(true, Some(60), 100);
        assert_eq!(decide(Mode::NoCache, s), Decision::Bypass);
    }

    #[test]
    fn decide_per_query_cache_wins_over_disabled_setting() {
        let ttl = Duration::from_secs(300);
        assert_eq!(
            decide(Mode::Cache(ttl), Settings::default()),
            Decision::Use { ttl }
        );
    }

    #[test]
    fn decide_default_bypass_when_disabled() {
        let s = settings(false, Some(60), 100);
        assert_eq!(decide(Mode::Default, s), Decision::Bypass);
    }

    #[test]
    fn decide_default_bypass_when_no_default_ttl() {
        let s = settings(true, None, 100);
        assert_eq!(decide(Mode::Default, s), Decision::Bypass);
    }

    #[test]
    fn decide_default_uses_setting_ttl_when_enabled_and_ttl_set() {
        let s = settings(true, Some(120), 100);
        assert_eq!(
            decide(Mode::Default, s),
            Decision::Use {
                ttl: Duration::from_secs(120)
            }
        );
    }

    #[test]
    fn decide_per_query_cache_overrides_setting_default() {
        let s = settings(true, Some(60), 100);
        let ttl = Duration::from_secs(900);
        assert_eq!(decide(Mode::Cache(ttl), s), Decision::Use { ttl });
    }

    // ---- parse_ttl --------------------------------------------------------

    #[test]
    fn parse_ttl_seconds() {
        assert_eq!(parse_ttl("30s"), Ok(Duration::from_secs(30)));
    }

    #[test]
    fn parse_ttl_minutes() {
        assert_eq!(parse_ttl("5m"), Ok(Duration::from_secs(300)));
    }

    #[test]
    fn parse_ttl_hours() {
        assert_eq!(parse_ttl("2h"), Ok(Duration::from_secs(7200)));
    }

    #[test]
    fn parse_ttl_days() {
        assert_eq!(parse_ttl("1d"), Ok(Duration::from_secs(86_400)));
    }

    #[test]
    fn parse_ttl_empty_rejected() {
        assert_eq!(parse_ttl(""), Err(TtlParseError::Empty));
    }

    #[test]
    fn parse_ttl_zero_rejected() {
        assert_eq!(parse_ttl("0s"), Err(TtlParseError::ZeroTtl));
    }

    #[test]
    fn parse_ttl_missing_unit_rejected() {
        assert_eq!(parse_ttl("30"), Err(TtlParseError::MissingUnit));
    }

    #[test]
    fn parse_ttl_missing_number_rejected() {
        assert_eq!(parse_ttl("m"), Err(TtlParseError::MissingNumber));
    }

    #[test]
    fn parse_ttl_unknown_unit_rejected() {
        for literal in ["5x", "5ms"] {
            assert_eq!(parse_ttl(literal), Err(TtlParseError::UnknownUnit));
        }
    }

    #[test]
    fn parse_ttl_whitespace_rejected() {
        // Inner whitespace ends up in the unit; leading whitespace hides the number.
        assert_eq!(parse_ttl("5 m"), Err(TtlParseError::UnknownUnit));
        assert_eq!(parse_ttl(" 5m"), Err(TtlParseError::MissingNumber));
    }

    #[test]
    fn parse_ttl_negative_rejected() {
        assert_eq!(parse_ttl("-5m"), Err(TtlParseError::MissingNumber));
    }

    #[test]
    fn parse_ttl_invalid_number_rejected() {
        // Twenty nines exceeds u64::MAX.
        assert_eq!(
            parse_ttl("99999999999999999999s"),
            Err(TtlParseError::InvalidNumber)
        );
    }

    #[test]
    fn parse_ttl_overflow_on_unit_multiplication() {
        // Smallest day count whose second total no longer fits in a u64.
        let literal = format!("{}d", u64::MAX / 86_400 + 1);
        assert_eq!(parse_ttl(&literal), Err(TtlParseError::Overflow));
    }

    // ---- misc -------------------------------------------------------------

    #[test]
    fn mode_default_is_inherit() {
        assert_eq!(Mode::default(), Mode::Default);
    }

    #[test]
    fn decide_is_deterministic_across_calls() {
        let s = settings(true, Some(60), 10);
        let modes = [
            Mode::Default,
            Mode::NoCache,
            Mode::Cache(Duration::from_secs(120)),
        ];
        for mode in modes {
            assert_eq!(decide(mode, s), decide(mode, s));
        }
    }
}