Skip to main content

ai_lib_rust/cache/
key.rs

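//! Cache key generation.
//!
//! [`CacheKeyGenerator`] hashes selected request fields (model, messages,
//! temperature and an optional salt) into a lowercase hex SHA-256 [`CacheKey`].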
2
3use serde::{Deserialize, Serialize};
4use sha2::{Digest, Sha256};
5use std::collections::BTreeMap;
6
/// A cache lookup key: a hash string plus optional descriptive metadata.
///
/// `PartialEq`, `Eq` and `Hash` are derived over *every* field, so two keys with
/// the same `hash` but different metadata (e.g. one with `model` set, one without)
/// compare unequal and land in different buckets when used as map keys.
// NOTE(review): if the cache store is keyed on `CacheKey` itself rather than on
// `as_str()`, lookups must attach exactly the same metadata as inserts — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CacheKey {
    /// The key string. `CacheKeyGenerator` fills it with a lowercase hex SHA-256
    /// digest; `CacheKey::new` and the `From` impls accept any string.
    pub hash: String,
    /// Model name, if attached (`CacheKeyGenerator` attaches it whenever a model
    /// is supplied to `generate`).
    pub model: Option<String>,
    /// Provider identifier, if attached via `with_provider`.
    pub provider: Option<String>,
    /// Optional fingerprint attached via `with_fingerprint`; its format is not
    /// defined in this module.
    pub fingerprint: Option<String>,
}
14
15impl CacheKey {
16    pub fn new(hash: impl Into<String>) -> Self {
17        Self {
18            hash: hash.into(),
19            model: None,
20            provider: None,
21            fingerprint: None,
22        }
23    }
24    pub fn with_model(mut self, model: impl Into<String>) -> Self {
25        self.model = Some(model.into());
26        self
27    }
28    pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
29        self.provider = Some(provider.into());
30        self
31    }
32    pub fn with_fingerprint(mut self, fp: impl Into<String>) -> Self {
33        self.fingerprint = Some(fp.into());
34        self
35    }
36    pub fn as_str(&self) -> &str {
37        &self.hash
38    }
39}
40
41impl std::fmt::Display for CacheKey {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        write!(f, "{}", self.hash)
44    }
45}
46
47impl From<&str> for CacheKey {
48    fn from(s: &str) -> Self {
49        Self::new(s)
50    }
51}
52impl From<String> for CacheKey {
53    fn from(s: String) -> Self {
54        Self::new(s)
55    }
56}
57
/// Builds [`CacheKey`]s by hashing a canonical JSON encoding of selected
/// request fields.
///
/// Construct with `CacheKeyGenerator::new` (or `Default`); optionally mix in a
/// salt with `with_salt`.
#[derive(Debug, Clone)]
pub struct CacheKeyGenerator {
    /// Whether the model name (when supplied) contributes to the hash.
    include_model: bool,
    /// Whether the temperature (when supplied) contributes to the hash.
    include_temperature: bool,
    /// Extra string mixed into every hash, if set.
    salt: Option<String>,
}
63
64impl CacheKeyGenerator {
65    pub fn new() -> Self {
66        Self {
67            include_model: true,
68            include_temperature: true,
69            salt: None,
70        }
71    }
72    pub fn with_salt(mut self, salt: impl Into<String>) -> Self {
73        self.salt = Some(salt.into());
74        self
75    }
76
77    pub fn generate(
78        &self,
79        model: Option<&str>,
80        messages: &[serde_json::Value],
81        temperature: Option<f64>,
82        _max_tokens: Option<u32>,
83    ) -> CacheKey {
84        let mut parts: BTreeMap<String, String> = BTreeMap::new();
85        if self.include_model {
86            if let Some(m) = model {
87                parts.insert("model".into(), m.into());
88            }
89        }
90        if self.include_temperature {
91            if let Some(t) = temperature {
92                parts.insert("temperature".into(), format!("{:.2}", t));
93            }
94        }
95        parts.insert(
96            "messages".into(),
97            serde_json::to_string(messages).unwrap_or_default(),
98        );
99        if let Some(ref s) = self.salt {
100            parts.insert("salt".into(), s.clone());
101        }
102        let canonical = serde_json::to_string(&parts).unwrap_or_default();
103        let mut hasher = Sha256::new();
104        hasher.update(canonical.as_bytes());
105        let hash: String = hasher
106            .finalize()
107            .iter()
108            .map(|b| format!("{:02x}", b))
109            .collect();
110        let mut key = CacheKey::new(hash);
111        if let Some(m) = model {
112            key = key.with_model(m);
113        }
114        key
115    }
116
117    pub fn generate_from_json(&self, request: &serde_json::Value) -> CacheKey {
118        self.generate(
119            request["model"].as_str(),
120            request["messages"]
121                .as_array()
122                .cloned()
123                .unwrap_or_default()
124                .as_slice(),
125            request["temperature"].as_f64(),
126            request["max_tokens"].as_u64().map(|v| v as u32),
127        )
128    }
129}
130
131impl Default for CacheKeyGenerator {
132    fn default() -> Self {
133        Self::new()
134    }
135}