// llm_edge_cache/key.rs

//! Cache key generation using SHA-256 hashing
//!
//! This module provides efficient cache key generation from LLM prompts and parameters.
//! Uses SHA-256 for consistent, collision-resistant hashing.

use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;

10/// Represents a cacheable LLM request
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct CacheableRequest {
13    /// The model name (e.g., "gpt-4", "claude-3-sonnet")
14    pub model: String,
15    /// The prompt or messages
16    pub prompt: String,
17    /// Temperature parameter
18    pub temperature: Option<f32>,
19    /// Max tokens to generate
20    pub max_tokens: Option<u32>,
21    /// Additional parameters that affect the response
22    pub parameters: HashMap<String, serde_json::Value>,
23}
24
25impl CacheableRequest {
26    /// Create a new cacheable request
27    pub fn new(model: impl Into<String>, prompt: impl Into<String>) -> Self {
28        Self {
29            model: model.into(),
30            prompt: prompt.into(),
31            temperature: None,
32            max_tokens: None,
33            parameters: HashMap::new(),
34        }
35    }
36
37    /// Set the temperature
38    pub fn with_temperature(mut self, temp: f32) -> Self {
39        self.temperature = Some(temp);
40        self
41    }
42
43    /// Set max tokens
44    pub fn with_max_tokens(mut self, tokens: u32) -> Self {
45        self.max_tokens = Some(tokens);
46        self
47    }
48
49    /// Add a custom parameter
50    pub fn with_parameter(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
51        self.parameters.insert(key.into(), value);
52        self
53    }
54}
55
56/// Generate a cache key from a request using SHA-256
57///
58/// The key includes:
59/// - Model name
60/// - Prompt content
61/// - Temperature (normalized to 2 decimal places)
62/// - Max tokens
63/// - All additional parameters (sorted for consistency)
64///
65/// # Performance
66/// - Target: <100μs for typical requests
67/// - SHA-256 is hardware-accelerated on most modern CPUs
68pub fn generate_cache_key(request: &CacheableRequest) -> String {
69    let mut hasher = Sha256::new();
70
71    // Add model name
72    hasher.update(request.model.as_bytes());
73    hasher.update(b"|");
74
75    // Add prompt
76    hasher.update(request.prompt.as_bytes());
77    hasher.update(b"|");
78
79    // Add temperature (normalized to 2 decimals to avoid floating point precision issues)
80    if let Some(temp) = request.temperature {
81        hasher.update(format!("{:.2}", temp).as_bytes());
82    }
83    hasher.update(b"|");
84
85    // Add max_tokens
86    if let Some(max_tokens) = request.max_tokens {
87        hasher.update(max_tokens.to_string().as_bytes());
88    }
89    hasher.update(b"|");
90
91    // Add sorted parameters for deterministic hashing
92    let mut param_keys: Vec<_> = request.parameters.keys().collect();
93    param_keys.sort();
94    for key in param_keys {
95        if let Some(value) = request.parameters.get(key) {
96            hasher.update(key.as_bytes());
97            hasher.update(b"=");
98            // Serialize value to JSON for consistent representation
99            if let Ok(json_str) = serde_json::to_string(value) {
100                hasher.update(json_str.as_bytes());
101            }
102            hasher.update(b";");
103        }
104    }
105
106    // Return hex-encoded hash
107    let result = hasher.finalize();
108    hex::encode(result)
109}
110
111/// Generate a short cache key (first 16 characters of the full hash)
112/// Useful for logging and debugging
113pub fn generate_short_key(request: &CacheableRequest) -> String {
114    let full_key = generate_cache_key(request);
115    full_key.chars().take(16).collect()
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    // Shorthand used by every test below.
    fn key_of(req: &CacheableRequest) -> String {
        generate_cache_key(req)
    }

    #[test]
    fn test_cache_key_consistency() {
        let make = || {
            CacheableRequest::new("gpt-4", "Hello, world!")
                .with_temperature(0.7)
                .with_max_tokens(100)
        };

        assert_eq!(
            key_of(&make()),
            key_of(&make()),
            "Identical requests should produce identical keys"
        );
    }

    #[test]
    fn test_cache_key_different_prompts() {
        let a = key_of(&CacheableRequest::new("gpt-4", "Hello, world!"));
        let b = key_of(&CacheableRequest::new("gpt-4", "Goodbye, world!"));

        assert_ne!(
            a, b,
            "Different prompts should produce different keys"
        );
    }

    #[test]
    fn test_cache_key_different_models() {
        let a = key_of(&CacheableRequest::new("gpt-4", "Hello, world!"));
        let b = key_of(&CacheableRequest::new("gpt-3.5-turbo", "Hello, world!"));

        assert_ne!(a, b, "Different models should produce different keys");
    }

    #[test]
    fn test_cache_key_temperature_normalization() {
        let a = key_of(&CacheableRequest::new("gpt-4", "Hello").with_temperature(0.7));
        let b = key_of(&CacheableRequest::new("gpt-4", "Hello").with_temperature(0.700001));

        assert_eq!(
            a, b,
            "Temperature should be normalized to avoid precision issues"
        );
    }

    #[test]
    fn test_cache_key_parameter_order_independence() {
        // Same parameters inserted in opposite orders.
        let req1 = CacheableRequest::new("gpt-4", "Hello")
            .with_parameter("param_a", serde_json::json!("value1"))
            .with_parameter("param_b", serde_json::json!("value2"));
        let req2 = CacheableRequest::new("gpt-4", "Hello")
            .with_parameter("param_b", serde_json::json!("value2"))
            .with_parameter("param_a", serde_json::json!("value1"));

        assert_eq!(
            key_of(&req1),
            key_of(&req2),
            "Parameter order should not affect cache key"
        );
    }

    #[test]
    fn test_short_key_length() {
        let short_key = generate_short_key(&CacheableRequest::new("gpt-4", "Test prompt"));

        assert_eq!(short_key.len(), 16, "Short key should be 16 characters");
    }

    #[test]
    fn test_cache_key_is_hexadecimal() {
        let key = key_of(&CacheableRequest::new("gpt-4", "Test prompt"));

        assert!(
            key.chars().all(|c| c.is_ascii_hexdigit()),
            "Cache key should be valid hexadecimal"
        );
        assert_eq!(key.len(), 64, "SHA-256 hash should be 64 hex characters");
    }
}