1use serde::{Deserialize, Serialize};
4use sha2::{Digest, Sha256};
5use std::collections::BTreeMap;
6
7#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
8pub struct CacheKey {
9 pub hash: String,
10 pub model: Option<String>,
11 pub provider: Option<String>,
12 pub fingerprint: Option<String>,
13}
14
15impl CacheKey {
16 pub fn new(hash: impl Into<String>) -> Self {
17 Self {
18 hash: hash.into(),
19 model: None,
20 provider: None,
21 fingerprint: None,
22 }
23 }
24 pub fn with_model(mut self, model: impl Into<String>) -> Self {
25 self.model = Some(model.into());
26 self
27 }
28 pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
29 self.provider = Some(provider.into());
30 self
31 }
32 pub fn with_fingerprint(mut self, fp: impl Into<String>) -> Self {
33 self.fingerprint = Some(fp.into());
34 self
35 }
36 pub fn as_str(&self) -> &str {
37 &self.hash
38 }
39}
40
41impl std::fmt::Display for CacheKey {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 write!(f, "{}", self.hash)
44 }
45}
46
47impl From<&str> for CacheKey {
48 fn from(s: &str) -> Self {
49 Self::new(s)
50 }
51}
52impl From<String> for CacheKey {
53 fn from(s: String) -> Self {
54 Self::new(s)
55 }
56}
57
58pub struct CacheKeyGenerator {
59 include_model: bool,
60 include_temperature: bool,
61 salt: Option<String>,
62}
63
64impl CacheKeyGenerator {
65 pub fn new() -> Self {
66 Self {
67 include_model: true,
68 include_temperature: true,
69 salt: None,
70 }
71 }
72 pub fn with_salt(mut self, salt: impl Into<String>) -> Self {
73 self.salt = Some(salt.into());
74 self
75 }
76
77 pub fn generate(
78 &self,
79 model: Option<&str>,
80 messages: &[serde_json::Value],
81 temperature: Option<f64>,
82 _max_tokens: Option<u32>,
83 ) -> CacheKey {
84 let mut parts: BTreeMap<String, String> = BTreeMap::new();
85 if self.include_model {
86 if let Some(m) = model {
87 parts.insert("model".into(), m.into());
88 }
89 }
90 if self.include_temperature {
91 if let Some(t) = temperature {
92 parts.insert("temperature".into(), format!("{:.2}", t));
93 }
94 }
95 parts.insert(
96 "messages".into(),
97 serde_json::to_string(messages).unwrap_or_default(),
98 );
99 if let Some(ref s) = self.salt {
100 parts.insert("salt".into(), s.clone());
101 }
102 let canonical = serde_json::to_string(&parts).unwrap_or_default();
103 let mut hasher = Sha256::new();
104 hasher.update(canonical.as_bytes());
105 let hash: String = hasher
106 .finalize()
107 .iter()
108 .map(|b| format!("{:02x}", b))
109 .collect();
110 let mut key = CacheKey::new(hash);
111 if let Some(m) = model {
112 key = key.with_model(m);
113 }
114 key
115 }
116
117 pub fn generate_from_json(&self, request: &serde_json::Value) -> CacheKey {
118 self.generate(
119 request["model"].as_str(),
120 request["messages"]
121 .as_array()
122 .cloned()
123 .unwrap_or_default()
124 .as_slice(),
125 request["temperature"].as_f64(),
126 request["max_tokens"].as_u64().map(|v| v as u32),
127 )
128 }
129}
130
131impl Default for CacheKeyGenerator {
132 fn default() -> Self {
133 Self::new()
134 }
135}