1use serde::{Deserialize, Serialize};
7use sha2::{Digest, Sha256};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct CacheableRequest {
13 pub model: String,
15 pub prompt: String,
17 pub temperature: Option<f32>,
19 pub max_tokens: Option<u32>,
21 pub parameters: HashMap<String, serde_json::Value>,
23}
24
25impl CacheableRequest {
26 pub fn new(model: impl Into<String>, prompt: impl Into<String>) -> Self {
28 Self {
29 model: model.into(),
30 prompt: prompt.into(),
31 temperature: None,
32 max_tokens: None,
33 parameters: HashMap::new(),
34 }
35 }
36
37 pub fn with_temperature(mut self, temp: f32) -> Self {
39 self.temperature = Some(temp);
40 self
41 }
42
43 pub fn with_max_tokens(mut self, tokens: u32) -> Self {
45 self.max_tokens = Some(tokens);
46 self
47 }
48
49 pub fn with_parameter(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
51 self.parameters.insert(key.into(), value);
52 self
53 }
54}
55
56pub fn generate_cache_key(request: &CacheableRequest) -> String {
69 let mut hasher = Sha256::new();
70
71 hasher.update(request.model.as_bytes());
73 hasher.update(b"|");
74
75 hasher.update(request.prompt.as_bytes());
77 hasher.update(b"|");
78
79 if let Some(temp) = request.temperature {
81 hasher.update(format!("{:.2}", temp).as_bytes());
82 }
83 hasher.update(b"|");
84
85 if let Some(max_tokens) = request.max_tokens {
87 hasher.update(max_tokens.to_string().as_bytes());
88 }
89 hasher.update(b"|");
90
91 let mut param_keys: Vec<_> = request.parameters.keys().collect();
93 param_keys.sort();
94 for key in param_keys {
95 if let Some(value) = request.parameters.get(key) {
96 hasher.update(key.as_bytes());
97 hasher.update(b"=");
98 if let Ok(json_str) = serde_json::to_string(value) {
100 hasher.update(json_str.as_bytes());
101 }
102 hasher.update(b";");
103 }
104 }
105
106 let result = hasher.finalize();
108 hex::encode(result)
109}
110
111pub fn generate_short_key(request: &CacheableRequest) -> String {
114 let full_key = generate_cache_key(request);
115 full_key.chars().take(16).collect()
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request with only `model` and `prompt` set.
    fn plain(model: &str, prompt: &str) -> CacheableRequest {
        CacheableRequest::new(model, prompt)
    }

    #[test]
    fn test_cache_key_consistency() {
        let make = || {
            CacheableRequest::new("gpt-4", "Hello, world!")
                .with_temperature(0.7)
                .with_max_tokens(100)
        };
        let (first, second) = (make(), make());

        assert_eq!(
            generate_cache_key(&first),
            generate_cache_key(&second),
            "Identical requests should produce identical keys"
        );
    }

    #[test]
    fn test_cache_key_different_prompts() {
        let hello = generate_cache_key(&plain("gpt-4", "Hello, world!"));
        let goodbye = generate_cache_key(&plain("gpt-4", "Goodbye, world!"));

        assert_ne!(
            hello, goodbye,
            "Different prompts should produce different keys"
        );
    }

    #[test]
    fn test_cache_key_different_models() {
        let gpt4 = generate_cache_key(&plain("gpt-4", "Hello, world!"));
        let gpt35 = generate_cache_key(&plain("gpt-3.5-turbo", "Hello, world!"));

        assert_ne!(gpt4, gpt35, "Different models should produce different keys");
    }

    #[test]
    fn test_cache_key_temperature_normalization() {
        let exact = plain("gpt-4", "Hello").with_temperature(0.7);
        let noisy = plain("gpt-4", "Hello").with_temperature(0.700001);

        assert_eq!(
            generate_cache_key(&exact),
            generate_cache_key(&noisy),
            "Temperature should be normalized to avoid precision issues"
        );
    }

    #[test]
    fn test_cache_key_parameter_order_independence() {
        let ab = plain("gpt-4", "Hello")
            .with_parameter("param_a", serde_json::json!("value1"))
            .with_parameter("param_b", serde_json::json!("value2"));
        let ba = plain("gpt-4", "Hello")
            .with_parameter("param_b", serde_json::json!("value2"))
            .with_parameter("param_a", serde_json::json!("value1"));

        assert_eq!(
            generate_cache_key(&ab),
            generate_cache_key(&ba),
            "Parameter order should not affect cache key"
        );
    }

    #[test]
    fn test_short_key_length() {
        let short_key = generate_short_key(&plain("gpt-4", "Test prompt"));

        assert_eq!(short_key.len(), 16, "Short key should be 16 characters");
    }

    #[test]
    fn test_cache_key_is_hexadecimal() {
        let key = generate_cache_key(&plain("gpt-4", "Test prompt"));

        assert!(
            key.bytes().all(|b| b.is_ascii_hexdigit()),
            "Cache key should be valid hexadecimal"
        );
        assert_eq!(key.len(), 64, "SHA-256 hash should be 64 hex characters");
    }
}