1use std::collections::HashMap;
8
9#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
11pub struct ProductKeyCacheEntry {
12 pub product_id: String,
14 pub index: u32,
16 pub pubkey: [u8; 32],
18}
19
/// Error returned when an operation is attempted with a phone identity that does
/// not match the identity the `ProductKeyCache` is currently bound to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CacheIdentityMismatch;
24
25impl std::fmt::Display for CacheIdentityMismatch {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 f.write_str("product key cache identity mismatch")
28 }
29}
30
/// In-memory cache of per-product 32-byte public keys, keyed by
/// `(product_id, index)` and bound to a single 32-byte phone identity.
#[derive(Debug, Clone)]
pub struct ProductKeyCache {
    /// Identity the cached entries belong to; operations supplying a different
    /// identity are treated as mismatches.
    phone_identity: [u8; 32],
    /// Cached public keys keyed by `(product_id, index)`.
    entries: HashMap<(String, u32), [u8; 32]>,
}
40
41impl ProductKeyCache {
42 pub fn new(phone_identity: [u8; 32]) -> Self {
44 Self {
45 phone_identity,
46 entries: HashMap::new(),
47 }
48 }
49
50 pub fn identity(&self) -> &[u8; 32] {
52 &self.phone_identity
53 }
54
55 pub fn get(
61 &mut self,
62 phone_identity: &[u8; 32],
63 product_id: &str,
64 index: u32,
65 ) -> Option<[u8; 32]> {
66 if *phone_identity != self.phone_identity {
67 log::warn!("product key cache identity mismatch — clearing");
68 self.entries.clear();
69 self.phone_identity = *phone_identity;
70 return None;
71 }
72 self.entries.get(&(product_id.to_string(), index)).copied()
73 }
74
75 pub fn insert(
81 &mut self,
82 phone_identity: &[u8; 32],
83 product_id: &str,
84 index: u32,
85 pubkey: [u8; 32],
86 ) -> Result<(), CacheIdentityMismatch> {
87 if *phone_identity != self.phone_identity {
88 return Err(CacheIdentityMismatch);
89 }
90 self.entries.insert((product_id.to_string(), index), pubkey);
91 Ok(())
92 }
93
94 pub fn remove(&mut self, product_id: &str, index: u32) {
96 self.entries.remove(&(product_id.to_string(), index));
97 }
98
99 pub fn clear(&mut self) {
101 self.entries.clear();
102 }
103
104 pub fn to_entries(&self) -> Vec<ProductKeyCacheEntry> {
106 self.entries
107 .iter()
108 .map(|((product_id, index), pubkey)| ProductKeyCacheEntry {
109 product_id: product_id.clone(),
110 index: *index,
111 pubkey: *pubkey,
112 })
113 .collect()
114 }
115
116 pub fn load_from(&mut self, identity: &[u8; 32], entries: &[ProductKeyCacheEntry]) {
121 if *identity != self.phone_identity {
122 return;
124 }
125 for e in entries {
126 self.entries
127 .insert((e.product_id.clone(), e.index), e.pubkey);
128 }
129 }
130}
131
/// Unit tests for `ProductKeyCache`: lookups, identity binding, removal and
/// snapshot save/restore.
#[cfg(test)]
mod tests {
    use super::*;

    const IDENTITY_A: [u8; 32] = [0x01u8; 32];
    const IDENTITY_B: [u8; 32] = [0x02u8; 32];
    const PUBKEY_1: [u8; 32] = [0xAAu8; 32];
    const PUBKEY_2: [u8; 32] = [0xBBu8; 32];

    /// Fresh, empty cache bound to `IDENTITY_A`.
    fn make_cache() -> ProductKeyCache {
        ProductKeyCache::new(IDENTITY_A)
    }

    fn entry(product_id: &str, index: u32, pubkey: [u8; 32]) -> ProductKeyCacheEntry {
        ProductKeyCacheEntry {
            product_id: product_id.to_string(),
            index,
            pubkey,
        }
    }

    #[test]
    fn test_new_reports_identity() {
        let cache = make_cache();
        assert_eq!(cache.identity(), &IDENTITY_A);
    }

    #[test]
    fn test_cache_hit_returns_stored_key() {
        let mut cache = make_cache();
        cache.insert(&IDENTITY_A, "acme.dot", 0, PUBKEY_1).unwrap();

        let result = cache.get(&IDENTITY_A, "acme.dot", 0);
        assert_eq!(result, Some(PUBKEY_1));
    }

    #[test]
    fn test_cache_miss_returns_none() {
        let mut cache = make_cache();
        let result = cache.get(&IDENTITY_A, "acme.dot", 99);
        assert_eq!(result, None);
    }

    #[test]
    fn test_insert_overwrites_existing_key() {
        let mut cache = make_cache();
        cache.insert(&IDENTITY_A, "acme.dot", 0, PUBKEY_1).unwrap();
        cache.insert(&IDENTITY_A, "acme.dot", 0, PUBKEY_2).unwrap();

        assert_eq!(cache.get(&IDENTITY_A, "acme.dot", 0), Some(PUBKEY_2));
    }

    #[test]
    fn test_cache_clears_on_identity_mismatch() {
        let mut cache = make_cache();
        cache.insert(&IDENTITY_A, "acme.dot", 0, PUBKEY_1).unwrap();

        let result = cache.get(&IDENTITY_B, "acme.dot", 0);
        assert_eq!(result, None, "mismatched identity must yield None");

        let result2 = cache.get(&IDENTITY_B, "acme.dot", 0);
        assert_eq!(
            result2, None,
            "cache must be empty after identity mismatch clear"
        );
    }

    #[test]
    fn test_get_mismatch_rebinds_identity() {
        let mut cache = make_cache();
        cache.insert(&IDENTITY_A, "acme.dot", 0, PUBKEY_1).unwrap();

        assert_eq!(cache.get(&IDENTITY_B, "acme.dot", 0), None);
        assert_eq!(
            cache.identity(),
            &IDENTITY_B,
            "get must adopt the new identity after a mismatch"
        );

        cache.insert(&IDENTITY_B, "acme.dot", 0, PUBKEY_2).unwrap();
        assert!(
            cache.insert(&IDENTITY_A, "acme.dot", 0, PUBKEY_1).is_err(),
            "old identity must be rejected after re-binding"
        );
        assert_eq!(cache.get(&IDENTITY_B, "acme.dot", 0), Some(PUBKEY_2));
    }

    #[test]
    fn test_insert_rejects_identity_mismatch() {
        let mut cache = make_cache();
        let err = cache.insert(&IDENTITY_B, "acme.dot", 0, PUBKEY_1);
        assert!(err.is_err(), "insert must reject a mismatched identity");
    }

    #[test]
    fn test_insert_mismatch_leaves_cache_untouched() {
        let mut cache = make_cache();
        cache.insert(&IDENTITY_A, "acme.dot", 0, PUBKEY_1).unwrap();

        assert!(cache.insert(&IDENTITY_B, "acme.dot", 0, PUBKEY_2).is_err());
        assert_eq!(cache.identity(), &IDENTITY_A, "insert must not re-bind");
        assert_eq!(cache.get(&IDENTITY_A, "acme.dot", 0), Some(PUBKEY_1));
    }

    #[test]
    fn test_remove_deletes_entry() {
        let mut cache = make_cache();
        cache.insert(&IDENTITY_A, "acme.dot", 0, PUBKEY_1).unwrap();
        cache.remove("acme.dot", 0);

        let result = cache.get(&IDENTITY_A, "acme.dot", 0);
        assert_eq!(result, None, "entry must be absent after remove");
    }

    #[test]
    fn test_remove_missing_key_is_noop() {
        let mut cache = make_cache();
        cache.insert(&IDENTITY_A, "acme.dot", 0, PUBKEY_1).unwrap();

        cache.remove("acme.dot", 1);
        cache.remove("other.dot", 0);

        assert_eq!(cache.get(&IDENTITY_A, "acme.dot", 0), Some(PUBKEY_1));
    }

    #[test]
    fn test_clear_removes_all() {
        let mut cache = make_cache();
        cache.insert(&IDENTITY_A, "acme.dot", 0, PUBKEY_1).unwrap();
        cache.insert(&IDENTITY_A, "foo.dot", 1, PUBKEY_2).unwrap();

        cache.clear();

        assert_eq!(cache.identity(), &IDENTITY_A, "clear must keep the identity");
        assert_eq!(cache.get(&IDENTITY_A, "acme.dot", 0), None);
        assert_eq!(cache.get(&IDENTITY_A, "foo.dot", 1), None);
    }

    #[test]
    fn test_load_from_populates_matching_identity() {
        let mut cache = make_cache();
        let snapshot = vec![entry("acme.dot", 0, PUBKEY_1), entry("foo.dot", 2, PUBKEY_2)];

        cache.load_from(&IDENTITY_A, &snapshot);

        assert_eq!(cache.get(&IDENTITY_A, "acme.dot", 0), Some(PUBKEY_1));
        assert_eq!(cache.get(&IDENTITY_A, "foo.dot", 2), Some(PUBKEY_2));
    }

    #[test]
    fn test_load_from_merges_into_existing() {
        let mut cache = make_cache();
        cache.insert(&IDENTITY_A, "acme.dot", 0, PUBKEY_1).unwrap();

        cache.load_from(&IDENTITY_A, &[entry("foo.dot", 1, PUBKEY_2)]);

        assert_eq!(cache.get(&IDENTITY_A, "acme.dot", 0), Some(PUBKEY_1));
        assert_eq!(cache.get(&IDENTITY_A, "foo.dot", 1), Some(PUBKEY_2));
    }

    #[test]
    fn test_load_from_ignores_mismatching_identity() {
        let mut cache = make_cache();
        let snapshot = vec![entry("acme.dot", 0, PUBKEY_1)];

        cache.load_from(&IDENTITY_B, &snapshot);

        assert_eq!(cache.identity(), &IDENTITY_A, "load_from must not re-bind");
        assert_eq!(
            cache.get(&IDENTITY_A, "acme.dot", 0),
            None,
            "stale snapshot must be discarded on identity mismatch"
        );
    }

    #[test]
    fn test_to_entries_snapshots_all() {
        let mut cache = make_cache();
        cache.insert(&IDENTITY_A, "acme.dot", 0, PUBKEY_1).unwrap();
        cache.insert(&IDENTITY_A, "foo.dot", 3, PUBKEY_2).unwrap();

        let entries = cache.to_entries();
        assert_eq!(entries.len(), 2);

        let has_acme = entries
            .iter()
            .any(|e| e.product_id == "acme.dot" && e.index == 0 && e.pubkey == PUBKEY_1);
        let has_foo = entries
            .iter()
            .any(|e| e.product_id == "foo.dot" && e.index == 3 && e.pubkey == PUBKEY_2);

        assert!(has_acme, "snapshot must contain acme.dot entry");
        assert!(has_foo, "snapshot must contain foo.dot entry");
    }

    #[test]
    fn test_to_entries_round_trips_through_load_from() {
        let mut source = make_cache();
        source.insert(&IDENTITY_A, "acme.dot", 0, PUBKEY_1).unwrap();
        source.insert(&IDENTITY_A, "foo.dot", 7, PUBKEY_2).unwrap();

        let snapshot = source.to_entries();
        let mut restored = make_cache();
        restored.load_from(&IDENTITY_A, &snapshot);

        assert_eq!(restored.get(&IDENTITY_A, "acme.dot", 0), Some(PUBKEY_1));
        assert_eq!(restored.get(&IDENTITY_A, "foo.dot", 7), Some(PUBKEY_2));
        assert_eq!(restored.to_entries().len(), 2);
    }
}