use std::time::Duration;
use sha2::{Digest, Sha256};
/// Identity that partitions the cache: who is asking.
///
/// Both fields are fed into the key produced by [`derive_key`], ahead of any
/// query-specific inputs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Scope<'a> {
    /// Tenant identifier.
    pub tenant: &'a str,
    /// User identifier within the tenant; an empty string is a valid, distinct value.
    pub user: &'a str,
}
/// Query parameters that, together with a [`Scope`], determine the cache key.
///
/// Every field participates in [`derive_key`]; changing any one of them yields a
/// different key.
#[derive(Clone, Copy, Debug)]
pub struct Inputs<'a> {
    /// The question text as submitted.
    pub question: &'a str,
    /// Provider name (for example `"openai"`).
    pub provider: &'a str,
    /// Model name within the provider.
    pub model: &'a str,
    /// Sampling temperature; `None` is keyed differently from `Some(0.0)`.
    pub temperature: Option<f32>,
    /// Sampling seed; `None` is keyed differently from `Some(0)`.
    pub seed: Option<u64>,
    /// Opaque fingerprint of the source material. NOTE(review): presumably identifies
    /// the retrieved documents so that changed sources miss the cache; confirm with callers.
    pub sources_fingerprint: &'a str,
}
/// Per-query caching directive, resolved against [`Settings`] by [`decide`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Mode {
    /// Inherit the global [`Settings`].
    Default,
    /// Cache with this TTL, regardless of what [`Settings`] says.
    Cache(Duration),
    /// Always bypass the cache for this query.
    NoCache,
}

impl Default for Mode {
    /// A query that states no preference inherits the global settings.
    fn default() -> Self {
        Self::Default
    }
}
/// Global cache configuration, consulted by [`decide`] for [`Mode::Default`] queries.
///
/// The [`Default`] value leaves caching off: `enabled` is `false`, `default_ttl` is
/// `None` and `max_entries` is `0`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct Settings {
    /// Master switch for inherited (`Mode::Default`) caching.
    pub enabled: bool,
    /// TTL for inherited queries; while `None`, inherited queries bypass the cache.
    pub default_ttl: Option<Duration>,
    /// Capacity limit. NOTE(review): not read anywhere in this module; presumably
    /// enforced by the cache store — confirm.
    pub max_entries: usize,
}
/// Outcome of [`decide`]: skip the cache, or use it with a concrete TTL.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Decision {
    /// Do not use the cache for this query.
    Bypass,
    /// Use the cache.
    Use {
        /// Time-to-live selected for this query.
        ttl: Duration,
    },
}
/// Resolves a per-query [`Mode`] against the global [`Settings`].
///
/// Precedence:
/// 1. `Mode::NoCache` always yields [`Decision::Bypass`].
/// 2. `Mode::Cache(ttl)` always uses the cache with `ttl`, even when
///    `settings.enabled` is `false`.
/// 3. `Mode::Default` uses `settings.default_ttl` only if `settings.enabled` is `true`
///    and a default TTL is configured; otherwise it bypasses.
pub fn decide(mode: Mode, settings: Settings) -> Decision {
    let chosen_ttl = match mode {
        Mode::NoCache => None,
        Mode::Cache(explicit) => Some(explicit),
        // Inherited caching needs both the switch and a TTL.
        Mode::Default => settings.default_ttl.filter(|_| settings.enabled),
    };
    chosen_ttl.map_or(Decision::Bypass, |ttl| Decision::Use { ttl })
}
/// Derives the cache key for a query as a 64-character lowercase hex SHA-256 digest.
///
/// The digest covers, in this order: tenant, user, question, provider, model,
/// temperature, seed and sources fingerprint. Consecutive fields are separated by the
/// ASCII unit separator byte `0x1f`.
///
/// Inside each field, every literal `0x1f` or `0x1b` (ESC) byte is prefixed with an
/// `0x1b` escape byte. An embedded separator therefore can never be mistaken for a
/// field boundary. Without escaping, tenant `"a\x1fb"` + user `"c"` would hash the same
/// bytes as tenant `"a"` + user `"b\x1fc"`.
///
/// Fields that contain neither byte are hashed exactly as a plain separator-joined
/// concatenation. Keys for such inputs are therefore identical to the unescaped scheme.
pub fn derive_key(scope: Scope<'_>, inputs: Inputs<'_>) -> String {
    use std::fmt::Write as _;

    const SEP: u8 = 0x1f;
    const ESC: u8 = 0x1b;

    let temperature = format_temperature(inputs.temperature);
    let seed = format_seed(inputs.seed);
    // Order is part of the key format; reordering changes every key.
    let fields: [&[u8]; 8] = [
        scope.tenant.as_bytes(),
        scope.user.as_bytes(),
        inputs.question.as_bytes(),
        inputs.provider.as_bytes(),
        inputs.model.as_bytes(),
        temperature.as_bytes(),
        seed.as_bytes(),
        inputs.sources_fingerprint.as_bytes(),
    ];

    let mut hasher = Sha256::new();
    for (idx, &field) in fields.iter().enumerate() {
        if idx > 0 {
            hasher.update([SEP]);
        }
        // Feed the field in runs, inserting ESC before each reserved byte. Incremental
        // updates hash the same as one update over the concatenated bytes.
        let mut rest = field;
        while let Some(pos) = rest.iter().position(|&b| b == SEP || b == ESC) {
            hasher.update(&rest[..pos]);
            hasher.update([ESC, rest[pos]]);
            rest = &rest[pos + 1..];
        }
        hasher.update(rest);
    }

    let digest = hasher.finalize();
    let mut out = String::with_capacity(digest.len() * 2);
    for byte in digest {
        write!(out, "{byte:02x}").expect("writing to a String cannot fail");
    }
    out
}
/// Canonical text form of an optional temperature, used as a key-derivation field.
///
/// `None` becomes `"none"`. `Some(v)` uses `f32`'s `Display` form (for example `0.0`
/// becomes `"0"` and `0.7` becomes `"0.7"`), which never spells `"none"`.
fn format_temperature(t: Option<f32>) -> String {
    t.map_or_else(|| String::from("none"), |value| value.to_string())
}
/// Canonical text form of an optional seed, used as a key-derivation field.
///
/// `None` becomes `"none"`; `Some(n)` becomes the decimal digits of `n`.
fn format_seed(s: Option<u64>) -> String {
    if let Some(seed) = s {
        format!("{seed}")
    } else {
        "none".to_owned()
    }
}
/// Parses a compact TTL literal such as `"30s"`, `"5m"`, `"2h"` or `"1d"`.
///
/// The literal must be a run of ASCII digits immediately followed by exactly one unit
/// suffix: `s` (seconds), `m` (minutes), `h` (hours) or `d` (days). Signs, whitespace,
/// fractions and compound forms such as `"1h30m"` are rejected.
///
/// # Errors
///
/// Checked in this order:
/// * [`TtlParseError::Empty`]: `literal` is `""`.
/// * [`TtlParseError::MissingUnit`]: `literal` consists only of digits.
/// * [`TtlParseError::MissingNumber`]: `literal` does not start with a digit.
/// * [`TtlParseError::InvalidNumber`]: the digit run does not fit in a `u64`.
/// * [`TtlParseError::ZeroTtl`]: the number is zero (reported before the unit is examined).
/// * [`TtlParseError::UnknownUnit`]: the suffix is not exactly `s`, `m`, `h` or `d`.
/// * [`TtlParseError::Overflow`]: the duration in seconds exceeds `u64::MAX`.
pub fn parse_ttl(literal: &str) -> Result<Duration, TtlParseError> {
    if literal.is_empty() {
        return Err(TtlParseError::Empty);
    }

    let digit_count = literal.bytes().take_while(u8::is_ascii_digit).count();
    if digit_count == literal.len() {
        return Err(TtlParseError::MissingUnit);
    }
    if digit_count == 0 {
        return Err(TtlParseError::MissingNumber);
    }

    // The prefix is pure ASCII, so `digit_count` is always a char boundary.
    let (digits, unit) = literal.split_at(digit_count);
    let amount = digits
        .parse::<u64>()
        .map_err(|_| TtlParseError::InvalidNumber)?;
    if amount == 0 {
        return Err(TtlParseError::ZeroTtl);
    }

    let seconds_per_unit: u64 = match unit {
        "s" => 1,
        "m" => 60,
        "h" => 3_600,
        "d" => 86_400,
        _ => return Err(TtlParseError::UnknownUnit),
    };
    amount
        .checked_mul(seconds_per_unit)
        .map(Duration::from_secs)
        .ok_or(TtlParseError::Overflow)
}
/// Reasons [`parse_ttl`] can reject a TTL literal.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`], so it can be propagated
/// with `?` into `Box<dyn std::error::Error>` or similar error types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TtlParseError {
    /// The literal was the empty string.
    Empty,
    /// The literal did not start with an ASCII digit (e.g. `"m"`, `"-5m"`, `" 5m"`).
    MissingNumber,
    /// The literal was all digits with no unit suffix (e.g. `"30"`).
    MissingUnit,
    /// The leading digit run did not fit in a `u64`.
    InvalidNumber,
    /// The suffix after the digits was not exactly `s`, `m`, `h` or `d`.
    UnknownUnit,
    /// The number was zero.
    ZeroTtl,
    /// Converting the number to seconds overflowed `u64`.
    Overflow,
}

impl std::fmt::Display for TtlParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Empty => "TTL literal is empty",
            Self::MissingNumber => "TTL literal must start with a number",
            Self::MissingUnit => "TTL literal is missing a unit suffix (s, m, h, d)",
            Self::InvalidNumber => "TTL number is not a valid unsigned integer",
            Self::UnknownUnit => "TTL unit must be one of s, m, h, d",
            Self::ZeroTtl => "TTL must be greater than zero",
            Self::Overflow => "TTL is too large to represent in seconds",
        };
        f.write_str(message)
    }
}

impl std::error::Error for TtlParseError {}
// Unit tests for key derivation, cache-mode resolution and TTL parsing.
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline scope shared by most key-derivation tests.
    fn scope() -> Scope<'static> {
        Scope {
            tenant: "acme",
            user: "alice",
        }
    }

    /// Baseline inputs; individual tests change one field to show it affects the key.
    fn inputs() -> Inputs<'static> {
        Inputs {
            question: "what is the capital of france?",
            provider: "openai",
            model: "gpt-4o-mini",
            temperature: Some(0.0),
            seed: Some(42),
            sources_fingerprint: "abc123",
        }
    }

    // Identical inputs give identical keys, rendered as 64 lowercase hex characters
    // (a 32-byte SHA-256 digest).
    #[test]
    fn key_is_deterministic_across_calls() {
        let k1 = derive_key(scope(), inputs());
        let k2 = derive_key(scope(), inputs());
        assert_eq!(k1, k2);
        assert_eq!(k1.len(), 64);
        assert!(k1
            .chars()
            .all(|c| c.is_ascii_hexdigit() && !c.is_uppercase()));
    }

    // Tenants must never share cache entries.
    #[test]
    fn key_changes_with_tenant() {
        let a = derive_key(
            Scope {
                tenant: "acme",
                user: "alice",
            },
            inputs(),
        );
        let b = derive_key(
            Scope {
                tenant: "globex",
                user: "alice",
            },
            inputs(),
        );
        assert_ne!(a, b, "per-tenant scope must isolate cache keys");
    }

    // Users within the same tenant must never share cache entries.
    #[test]
    fn key_changes_with_user() {
        let a = derive_key(
            Scope {
                tenant: "acme",
                user: "alice",
            },
            inputs(),
        );
        let b = derive_key(
            Scope {
                tenant: "acme",
                user: "bob",
            },
            inputs(),
        );
        assert_ne!(a, b);
    }

    // An empty user id is a distinct scope, not a wildcard.
    #[test]
    fn empty_user_is_distinct_from_named_user() {
        let anon = derive_key(
            Scope {
                tenant: "acme",
                user: "",
            },
            inputs(),
        );
        let named = derive_key(scope(), inputs());
        assert_ne!(anon, named);
    }

    // The following tests each change a single `Inputs` field and require a new key.
    #[test]
    fn key_changes_with_question() {
        let mut i = inputs();
        let base = derive_key(scope(), i);
        i.question = "different question";
        let other = derive_key(scope(), i);
        assert_ne!(base, other);
    }

    #[test]
    fn key_changes_with_provider() {
        let mut i = inputs();
        let base = derive_key(scope(), i);
        i.provider = "anthropic";
        let other = derive_key(scope(), i);
        assert_ne!(base, other);
    }

    #[test]
    fn key_changes_with_model() {
        let mut i = inputs();
        let base = derive_key(scope(), i);
        i.model = "gpt-4o";
        let other = derive_key(scope(), i);
        assert_ne!(base, other);
    }

    #[test]
    fn key_changes_with_temperature() {
        let mut i = inputs();
        let base = derive_key(scope(), i);
        i.temperature = Some(0.7);
        let other = derive_key(scope(), i);
        assert_ne!(base, other);
    }

    #[test]
    fn key_changes_with_seed() {
        let mut i = inputs();
        let base = derive_key(scope(), i);
        i.seed = Some(43);
        let other = derive_key(scope(), i);
        assert_ne!(base, other);
    }

    #[test]
    fn key_changes_with_fingerprint() {
        let mut i = inputs();
        let base = derive_key(scope(), i);
        i.sources_fingerprint = "def456";
        let other = derive_key(scope(), i);
        assert_ne!(
            base, other,
            "different sources must miss cache even for identical question"
        );
    }

    // `None` is rendered as "none" and `Some(0.0)` as a number, so they must not collide.
    #[test]
    fn temperature_none_distinct_from_zero() {
        let mut i = inputs();
        i.temperature = None;
        let none = derive_key(scope(), i);
        i.temperature = Some(0.0);
        let zero = derive_key(scope(), i);
        assert_ne!(
            none, zero,
            "None and Some(0.0) must not collide — a provider that ignores temperature is not the same as one that received zero"
        );
    }

    // Same reasoning as above for the seed.
    #[test]
    fn seed_none_distinct_from_zero() {
        let mut i = inputs();
        i.seed = None;
        let none = derive_key(scope(), i);
        i.seed = Some(0);
        let zero = derive_key(scope(), i);
        assert_ne!(none, zero);
    }

    // Pins the exact digest for a fixed input. Any change to field order, separators or
    // value formatting changes this value and every previously derived key, so update
    // it only deliberately.
    #[test]
    fn key_pinned_against_known_value() {
        let scope = Scope {
            tenant: "t",
            user: "u",
        };
        let i = Inputs {
            question: "q",
            provider: "p",
            model: "m",
            temperature: Some(0.0),
            seed: Some(1),
            sources_fingerprint: "f",
        };
        let key = derive_key(scope, i);
        assert_eq!(
            key,
            "ca47974209a1e07b9890aa73b5bdbcc2fda1bae0ba1d77f186c9dc168b54f903"
        );
    }

    // `decide` precedence: NoCache > explicit Cache(ttl) > inherited settings.
    #[test]
    fn decide_nocache_always_bypasses() {
        let s = Settings {
            enabled: true,
            default_ttl: Some(Duration::from_secs(60)),
            max_entries: 100,
        };
        assert_eq!(decide(Mode::NoCache, s), Decision::Bypass);
    }

    // An explicit per-query TTL applies even with default (disabled) settings.
    #[test]
    fn decide_per_query_cache_wins_over_disabled_setting() {
        let s = Settings::default();
        assert_eq!(
            decide(Mode::Cache(Duration::from_secs(300)), s),
            Decision::Use {
                ttl: Duration::from_secs(300)
            }
        );
    }

    // Inherited caching needs `enabled == true` ...
    #[test]
    fn decide_default_bypass_when_disabled() {
        let s = Settings {
            enabled: false,
            default_ttl: Some(Duration::from_secs(60)),
            max_entries: 100,
        };
        assert_eq!(decide(Mode::Default, s), Decision::Bypass);
    }

    // ... and a configured default TTL.
    #[test]
    fn decide_default_bypass_when_no_default_ttl() {
        let s = Settings {
            enabled: true,
            default_ttl: None,
            max_entries: 100,
        };
        assert_eq!(decide(Mode::Default, s), Decision::Bypass);
    }

    #[test]
    fn decide_default_uses_setting_ttl_when_enabled_and_ttl_set() {
        let s = Settings {
            enabled: true,
            default_ttl: Some(Duration::from_secs(120)),
            max_entries: 100,
        };
        assert_eq!(
            decide(Mode::Default, s),
            Decision::Use {
                ttl: Duration::from_secs(120)
            }
        );
    }

    // An explicit TTL takes precedence over the configured default TTL.
    #[test]
    fn decide_per_query_cache_overrides_setting_default() {
        let s = Settings {
            enabled: true,
            default_ttl: Some(Duration::from_secs(60)),
            max_entries: 100,
        };
        assert_eq!(
            decide(Mode::Cache(Duration::from_secs(900)), s),
            Decision::Use {
                ttl: Duration::from_secs(900)
            }
        );
    }

    // Accepted units: s, m, h, d.
    #[test]
    fn parse_ttl_seconds() {
        assert_eq!(parse_ttl("30s").unwrap(), Duration::from_secs(30));
    }

    #[test]
    fn parse_ttl_minutes() {
        assert_eq!(parse_ttl("5m").unwrap(), Duration::from_secs(300));
    }

    #[test]
    fn parse_ttl_hours() {
        assert_eq!(parse_ttl("2h").unwrap(), Duration::from_secs(7200));
    }

    #[test]
    fn parse_ttl_days() {
        assert_eq!(parse_ttl("1d").unwrap(), Duration::from_secs(86_400));
    }

    // Each malformed literal maps to one specific error variant.
    #[test]
    fn parse_ttl_empty_rejected() {
        assert_eq!(parse_ttl(""), Err(TtlParseError::Empty));
    }

    #[test]
    fn parse_ttl_zero_rejected() {
        assert_eq!(parse_ttl("0s"), Err(TtlParseError::ZeroTtl));
    }

    // All digits, no suffix.
    #[test]
    fn parse_ttl_missing_unit_rejected() {
        assert_eq!(parse_ttl("30"), Err(TtlParseError::MissingUnit));
    }

    // First byte is not a digit.
    #[test]
    fn parse_ttl_missing_number_rejected() {
        assert_eq!(parse_ttl("m"), Err(TtlParseError::MissingNumber));
    }

    // Only single-letter units are recognised; "ms" is not milliseconds.
    #[test]
    fn parse_ttl_unknown_unit_rejected() {
        assert_eq!(parse_ttl("5x"), Err(TtlParseError::UnknownUnit));
        assert_eq!(parse_ttl("5ms"), Err(TtlParseError::UnknownUnit));
    }

    // Inner whitespace makes the suffix " m" (unknown unit); leading whitespace means
    // the literal does not start with a digit.
    #[test]
    fn parse_ttl_whitespace_rejected() {
        assert_eq!(parse_ttl("5 m"), Err(TtlParseError::UnknownUnit));
        assert_eq!(parse_ttl(" 5m"), Err(TtlParseError::MissingNumber));
    }

    // A leading '-' is a non-digit at index 0.
    #[test]
    fn parse_ttl_negative_rejected() {
        assert_eq!(parse_ttl("-5m"), Err(TtlParseError::MissingNumber));
    }

    // The digit run exceeds u64::MAX, so `u64` parsing itself fails (InvalidNumber, not
    // Overflow).
    #[test]
    fn parse_ttl_invalid_number_rejected() {
        assert_eq!(
            parse_ttl("99999999999999999999s"),
            Err(TtlParseError::InvalidNumber)
        );
    }

    // The number parses, but multiplying by the unit's seconds overflows.
    #[test]
    fn parse_ttl_overflow_on_unit_multiplication() {
        let max_d = u64::MAX / 86_400 + 1;
        let lit = format!("{}d", max_d);
        assert_eq!(parse_ttl(&lit), Err(TtlParseError::Overflow));
    }

    // Mode::default() is the inherit-from-settings variant.
    #[test]
    fn mode_default_is_inherit() {
        assert_eq!(Mode::default(), Mode::Default);
    }

    // `decide` is a pure function of its arguments.
    #[test]
    fn decide_is_deterministic_across_calls() {
        let s = Settings {
            enabled: true,
            default_ttl: Some(Duration::from_secs(60)),
            max_entries: 10,
        };
        for mode in [
            Mode::Default,
            Mode::NoCache,
            Mode::Cache(Duration::from_secs(120)),
        ] {
            let d1 = decide(mode, s);
            let d2 = decide(mode, s);
            assert_eq!(d1, d2);
        }
    }
}