Skip to main content

ai_lib_rust/cache/
key.rs

//! Cache key generation: the [`CacheKey`] type and a SHA-256 based [`CacheKeyGenerator`].
2
3use serde::{Deserialize, Serialize};
4use sha2::{Digest, Sha256};
5use std::collections::BTreeMap;
6
7#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
8pub struct CacheKey {
9    pub hash: String,
10    pub model: Option<String>,
11    pub provider: Option<String>,
12    pub fingerprint: Option<String>,
13}
14
15impl CacheKey {
16    pub fn new(hash: impl Into<String>) -> Self {
17        Self { hash: hash.into(), model: None, provider: None, fingerprint: None }
18    }
19    pub fn with_model(mut self, model: impl Into<String>) -> Self { self.model = Some(model.into()); self }
20    pub fn with_provider(mut self, provider: impl Into<String>) -> Self { self.provider = Some(provider.into()); self }
21    pub fn with_fingerprint(mut self, fp: impl Into<String>) -> Self { self.fingerprint = Some(fp.into()); self }
22    pub fn as_str(&self) -> &str { &self.hash }
23}
24
25impl std::fmt::Display for CacheKey {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.hash) }
27}
28
29impl From<&str> for CacheKey { fn from(s: &str) -> Self { Self::new(s) } }
30impl From<String> for CacheKey { fn from(s: String) -> Self { Self::new(s) } }
31
/// Builds deterministic [`CacheKey`]s from chat-style request parameters.
///
/// The key hash is a lowercase hexadecimal SHA-256 digest over a canonical JSON
/// object assembled from the model name, temperature, messages and an optional
/// salt (see [`CacheKeyGenerator::generate`]).
#[derive(Debug, Clone)]
pub struct CacheKeyGenerator {
    /// Whether the model name contributes to the hash.
    include_model: bool,
    /// Whether the temperature (formatted to two decimals) contributes to the hash.
    include_temperature: bool,
    /// Optional extra string mixed into every hash, so generators with different
    /// salts produce different keys for the same request.
    salt: Option<String>,
}
37
38impl CacheKeyGenerator {
39    pub fn new() -> Self { Self { include_model: true, include_temperature: true, salt: None } }
40    pub fn with_salt(mut self, salt: impl Into<String>) -> Self { self.salt = Some(salt.into()); self }
41
42    pub fn generate(&self, model: Option<&str>, messages: &[serde_json::Value], temperature: Option<f64>, _max_tokens: Option<u32>) -> CacheKey {
43        let mut parts: BTreeMap<String, String> = BTreeMap::new();
44        if self.include_model { if let Some(m) = model { parts.insert("model".into(), m.into()); } }
45        if self.include_temperature { if let Some(t) = temperature { parts.insert("temperature".into(), format!("{:.2}", t)); } }
46        parts.insert("messages".into(), serde_json::to_string(messages).unwrap_or_default());
47        if let Some(ref s) = self.salt { parts.insert("salt".into(), s.clone()); }
48        let canonical = serde_json::to_string(&parts).unwrap_or_default();
49        let mut hasher = Sha256::new();
50        hasher.update(canonical.as_bytes());
51        let hash: String = hasher.finalize().iter().map(|b| format!("{:02x}", b)).collect();
52        let mut key = CacheKey::new(hash);
53        if let Some(m) = model { key = key.with_model(m); }
54        key
55    }
56
57    pub fn generate_from_json(&self, request: &serde_json::Value) -> CacheKey {
58        self.generate(request["model"].as_str(), request["messages"].as_array().cloned().unwrap_or_default().as_slice(), request["temperature"].as_f64(), request["max_tokens"].as_u64().map(|v| v as u32))
59    }
60}
61
62impl Default for CacheKeyGenerator { fn default() -> Self { Self::new() } }