1use async_trait::async_trait;
4use std::collections::HashMap;
5use std::sync::Mutex;
6
7use crate::error::{CryptoError, Result};
8use crate::types::{KeyData, TofuRecord};
9
#[async_trait]
/// Async persistence backend for group keys and TOFU (trust-on-first-use)
/// records.
///
/// The trait requires `Send + Sync`, so one store instance can be shared across
/// threads/tasks. The implementations in this module report a missing entry as
/// `Ok(None)` from the `load_*` methods and treat deleting a missing entry as
/// success.
pub trait KeyStore: Send + Sync {
    /// Persist `data` as the group key for `session_id`.
    ///
    /// In this module's implementations an existing key for the same session is
    /// replaced.
    async fn save_group_key(&self, session_id: &str, data: KeyData) -> Result<()>;
    /// Fetch the group key for `session_id`, or `Ok(None)` if none is stored.
    async fn load_group_key(&self, session_id: &str) -> Result<Option<KeyData>>;
    /// Remove the group key for `session_id`.
    async fn delete_group_key(&self, session_id: &str) -> Result<()>;
    /// Persist a TOFU record under `id` (presumably a peer identifier — confirm
    /// against callers).
    async fn save_tofu_record(&self, id: &str, record: TofuRecord) -> Result<()>;
    /// Fetch the TOFU record stored under `id`, or `Ok(None)` if none is stored.
    async fn load_tofu_record(&self, id: &str) -> Result<Option<TofuRecord>>;
}
24
25pub struct MemoryKeyStore {
27 group_keys: Mutex<HashMap<String, KeyData>>,
28 tofu_records: Mutex<HashMap<String, TofuRecord>>,
29}
30
31impl MemoryKeyStore {
32 pub fn new() -> Self {
33 Self {
34 group_keys: Mutex::new(HashMap::new()),
35 tofu_records: Mutex::new(HashMap::new()),
36 }
37 }
38}
39
40impl Default for MemoryKeyStore {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46#[async_trait]
47impl KeyStore for MemoryKeyStore {
48 async fn save_group_key(&self, session_id: &str, data: KeyData) -> Result<()> {
49 self.group_keys
50 .lock()
51 .map_err(|_| CryptoError::Storage("key store lock poisoned".into()))?
52 .insert(session_id.to_string(), data);
53 Ok(())
54 }
55
56 async fn load_group_key(&self, session_id: &str) -> Result<Option<KeyData>> {
57 Ok(self
58 .group_keys
59 .lock()
60 .map_err(|_| CryptoError::Storage("key store lock poisoned".into()))?
61 .get(session_id)
62 .cloned())
63 }
64
65 async fn delete_group_key(&self, session_id: &str) -> Result<()> {
66 self.group_keys
67 .lock()
68 .map_err(|_| CryptoError::Storage("key store lock poisoned".into()))?
69 .remove(session_id);
70 Ok(())
71 }
72
73 async fn save_tofu_record(&self, id: &str, record: TofuRecord) -> Result<()> {
74 self.tofu_records
75 .lock()
76 .map_err(|_| CryptoError::Storage("key store lock poisoned".into()))?
77 .insert(id.to_string(), record);
78 Ok(())
79 }
80
81 async fn load_tofu_record(&self, id: &str) -> Result<Option<TofuRecord>> {
82 Ok(self
83 .tofu_records
84 .lock()
85 .map_err(|_| CryptoError::Storage("key store lock poisoned".into()))?
86 .get(id)
87 .cloned())
88 }
89}
90
#[cfg(feature = "fs-store")]
/// [`KeyStore`] that persists each entry as a pretty-printed JSON file.
///
/// Group keys are stored under `<base_dir>/group-keys/` and TOFU records under
/// `<base_dir>/tofu/`. Each file is named after the SHA-256 hex digest of its
/// id, so ids never appear verbatim in paths.
#[derive(Debug, Clone)]
pub struct FileSystemKeyStore {
    /// Root directory that holds the `group-keys/` and `tofu/` subdirectories.
    base_dir: std::path::PathBuf,
}
101
#[cfg(feature = "fs-store")]
impl FileSystemKeyStore {
    /// Create a store rooted at `base_dir`.
    ///
    /// No I/O happens here; directories are created lazily by the first write.
    pub fn new(base_dir: impl Into<std::path::PathBuf>) -> Self {
        Self {
            base_dir: base_dir.into(),
        }
    }

    /// Lowercase hex SHA-256 digest of `id`, used as the on-disk file stem.
    ///
    /// The output contains only hex characters, so ids containing `/`, `..` or
    /// other path syntax cannot escape the store directory.
    fn hash_id(id: &str) -> String {
        use sha2::{Digest, Sha256};
        use std::fmt::Write as _;

        let digest = Sha256::digest(id.as_bytes());
        let mut hex = String::with_capacity(digest.len() * 2);
        for byte in digest.iter() {
            // Formatting into a `String` cannot fail.
            let _ = write!(hex, "{byte:02x}");
        }
        hex
    }

    /// Path of the JSON file holding the group key for `session_id`.
    fn group_key_path(&self, session_id: &str) -> std::path::PathBuf {
        self.base_dir
            .join("group-keys")
            .join(format!("{}.json", Self::hash_id(session_id)))
    }

    /// Path of the JSON file holding the TOFU record for `id`.
    fn tofu_path(&self, id: &str) -> std::path::PathBuf {
        self.base_dir
            .join("tofu")
            .join(format!("{}.json", Self::hash_id(id)))
    }

    /// Atomically replace the file at `path` with `data`.
    ///
    /// The bytes are written to a temporary file in the same directory, flushed
    /// and fsynced, and only then renamed over `path`. Readers therefore see
    /// either the previous or the new contents, never a partial write. Every
    /// call uses its own temporary name, so concurrent writes to the same
    /// `path` cannot interleave inside one temp file.
    ///
    /// On Unix, newly created directories get mode `0o700` and the file gets
    /// mode `0o600`, because the payload may contain secret key material.
    ///
    /// # Errors
    ///
    /// Returns [`CryptoError::Storage`] in these cases:
    /// - the parent directory cannot be created;
    /// - the temporary file cannot be written, synced or renamed.
    ///
    /// On failure the temporary file is removed on a best-effort basis.
    async fn atomic_write(
        path: &std::path::Path,
        data: &[u8],
    ) -> std::result::Result<(), CryptoError> {
        use std::sync::atomic::{AtomicUsize, Ordering};

        // Per-process sequence number that makes temp names unique per call.
        static TMP_SEQ: AtomicUsize = AtomicUsize::new(0);

        if let Some(parent) = path.parent() {
            let mut dirs = tokio::fs::DirBuilder::new();
            dirs.recursive(true);
            #[cfg(unix)]
            dirs.mode(0o700);
            dirs.create(parent)
                .await
                .map_err(|e| CryptoError::Storage(format!("create dir: {e}")))?;
        }

        let file_name = path
            .file_name()
            .map(|name| name.to_string_lossy().into_owned())
            .unwrap_or_default();
        let tmp = path.with_file_name(format!(
            "{file_name}.{}.{}.tmp",
            std::process::id(),
            TMP_SEQ.fetch_add(1, Ordering::Relaxed),
        ));

        let result = Self::write_tmp_and_rename(&tmp, path, data).await;
        if result.is_err() {
            // Best effort: the caller needs the original error, and the temp
            // file may not exist if creating it is what failed.
            let _ = tokio::fs::remove_file(&tmp).await;
        }
        result
    }

    /// Write `data` to `tmp`, make it durable, then rename it to `path`.
    ///
    /// On Unix the file is created with mode `0o600`. It is flushed and
    /// `sync_all`ed before the rename, so a crash cannot leave `path` pointing
    /// at unwritten data.
    async fn write_tmp_and_rename(
        tmp: &std::path::Path,
        path: &std::path::Path,
        data: &[u8],
    ) -> std::result::Result<(), CryptoError> {
        use tokio::io::AsyncWriteExt;

        let write_err = |e: std::io::Error| CryptoError::Storage(format!("write tmp: {e}"));

        let mut options = tokio::fs::OpenOptions::new();
        options.write(true).create(true).truncate(true);
        #[cfg(unix)]
        options.mode(0o600);

        let mut file = options.open(tmp).await.map_err(write_err)?;
        file.write_all(data).await.map_err(write_err)?;
        file.flush().await.map_err(write_err)?;
        file.sync_all()
            .await
            .map_err(|e| CryptoError::Storage(format!("sync tmp: {e}")))?;
        // Close the handle before renaming (required on Windows).
        drop(file);

        tokio::fs::rename(tmp, path)
            .await
            .map_err(|e| CryptoError::Storage(format!("rename: {e}")))
    }
}
147
/// Read the whole file at `path`, mapping "file does not exist" to `Ok(None)`.
///
/// Any other I/O failure becomes `CryptoError::Storage("read {what}: …")`,
/// where `what` names the kind of entry for the message.
#[cfg(feature = "fs-store")]
async fn read_if_exists(path: &std::path::Path, what: &str) -> Result<Option<Vec<u8>>> {
    match tokio::fs::read(path).await {
        Ok(bytes) => Ok(Some(bytes)),
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
        Err(e) => Err(CryptoError::Storage(format!("read {what}: {e}"))),
    }
}

// Entries are JSON documents. Encoding failures surface as
// `CryptoError::Serialization`; decoding failures of stored files surface as
// `CryptoError::Storage`.
#[cfg(feature = "fs-store")]
#[async_trait]
impl KeyStore for FileSystemKeyStore {
    async fn save_group_key(&self, session_id: &str, data: KeyData) -> Result<()> {
        let encoded = serde_json::to_vec_pretty(&data)
            .map_err(|e| CryptoError::Serialization(e.to_string()))?;
        Self::atomic_write(&self.group_key_path(session_id), &encoded).await
    }

    async fn load_group_key(&self, session_id: &str) -> Result<Option<KeyData>> {
        read_if_exists(&self.group_key_path(session_id), "group key")
            .await?
            .map(|bytes| {
                serde_json::from_slice::<KeyData>(&bytes)
                    .map_err(|e| CryptoError::Storage(format!("parse group key: {e}")))
            })
            .transpose()
    }

    async fn delete_group_key(&self, session_id: &str) -> Result<()> {
        match tokio::fs::remove_file(self.group_key_path(session_id)).await {
            Err(e) if e.kind() != std::io::ErrorKind::NotFound => {
                Err(CryptoError::Storage(format!("delete group key: {e}")))
            }
            // Removing an entry that is already gone counts as success.
            _ => Ok(()),
        }
    }

    async fn save_tofu_record(&self, id: &str, record: TofuRecord) -> Result<()> {
        let encoded = serde_json::to_vec_pretty(&record)
            .map_err(|e| CryptoError::Serialization(e.to_string()))?;
        Self::atomic_write(&self.tofu_path(id), &encoded).await
    }

    async fn load_tofu_record(&self, id: &str) -> Result<Option<TofuRecord>> {
        read_if_exists(&self.tofu_path(id), "tofu")
            .await?
            .map(|bytes| {
                serde_json::from_slice::<TofuRecord>(&bytes)
                    .map_err(|e| CryptoError::Storage(format!("parse tofu: {e}")))
            })
            .transpose()
    }
}
200
#[cfg(test)]
mod tests {
    use super::*;

    /// Group key fixture shared by the tests below.
    fn sample_key() -> KeyData {
        KeyData {
            key: serde_json::json!({"kty": "oct", "k": "dGVzdA=="}),
            stored_at: 1000,
        }
    }

    #[tokio::test]
    async fn memory_store_group_key_round_trip() {
        let store = MemoryKeyStore::new();
        store.save_group_key("session-1", sample_key()).await.unwrap();
        let loaded = store
            .load_group_key("session-1")
            .await
            .unwrap()
            .expect("saved group key should load");
        assert_eq!(loaded.key["kty"], "oct");
        assert_eq!(loaded.stored_at, 1000);
    }

    #[tokio::test]
    async fn memory_store_group_key_missing() {
        let store = MemoryKeyStore::new();
        let loaded = store.load_group_key("nonexistent").await.unwrap();
        assert!(loaded.is_none());
    }

    #[tokio::test]
    async fn memory_store_group_key_overwrite() {
        let store = MemoryKeyStore::new();
        store.save_group_key("s1", sample_key()).await.unwrap();
        let newer = KeyData {
            key: serde_json::json!({"kty": "oct", "k": "bmV3ZXI="}),
            stored_at: 2000,
        };
        store.save_group_key("s1", newer).await.unwrap();
        let loaded = store
            .load_group_key("s1")
            .await
            .unwrap()
            .expect("overwritten group key should load");
        assert_eq!(loaded.key["k"], "bmV3ZXI=");
        assert_eq!(loaded.stored_at, 2000);
    }

    #[tokio::test]
    async fn memory_store_delete_group_key() {
        let store = MemoryKeyStore::new();
        store.save_group_key("s1", sample_key()).await.unwrap();
        store.delete_group_key("s1").await.unwrap();
        let loaded = store.load_group_key("s1").await.unwrap();
        assert!(loaded.is_none());
    }

    #[tokio::test]
    async fn memory_store_delete_missing_group_key_is_ok() {
        let store = MemoryKeyStore::new();
        store.delete_group_key("never-saved").await.unwrap();
    }

    #[tokio::test]
    async fn memory_store_tofu_round_trip() {
        let store = MemoryKeyStore::new();
        let record = TofuRecord {
            fingerprint: "abcd 1234".to_string(),
            first_seen: 5000,
        };
        store.save_tofu_record("peer-1", record).await.unwrap();
        let loaded = store
            .load_tofu_record("peer-1")
            .await
            .unwrap()
            .expect("saved TOFU record should load");
        assert_eq!(loaded.fingerprint, "abcd 1234");
        assert_eq!(loaded.first_seen, 5000);
    }

    #[tokio::test]
    async fn memory_store_tofu_missing() {
        let store = MemoryKeyStore::new();
        let loaded = store.load_tofu_record("unknown-peer").await.unwrap();
        assert!(loaded.is_none());
    }
}
254
#[cfg(all(test, feature = "fs-store"))]
mod fs_tests {
    use super::*;

    /// Group key fixture shared by the tests below.
    fn sample_key() -> KeyData {
        KeyData {
            key: serde_json::json!({"kty": "oct", "k": "dGVzdA=="}),
            stored_at: 1000,
        }
    }

    #[tokio::test]
    async fn fs_store_group_key_round_trip() {
        let dir = tempfile::tempdir().unwrap();
        let store = FileSystemKeyStore::new(dir.path());
        store.save_group_key("session-1", sample_key()).await.unwrap();
        let loaded = store
            .load_group_key("session-1")
            .await
            .unwrap()
            .expect("saved group key should load");
        assert_eq!(loaded.key["kty"], "oct");
        assert_eq!(loaded.stored_at, 1000);
    }

    #[tokio::test]
    async fn fs_store_group_key_missing() {
        let dir = tempfile::tempdir().unwrap();
        let store = FileSystemKeyStore::new(dir.path());
        let loaded = store.load_group_key("nonexistent").await.unwrap();
        assert!(loaded.is_none());
    }

    #[tokio::test]
    async fn fs_store_group_key_overwrite() {
        let dir = tempfile::tempdir().unwrap();
        let store = FileSystemKeyStore::new(dir.path());
        store.save_group_key("s1", sample_key()).await.unwrap();
        let newer = KeyData {
            key: serde_json::json!({"kty": "oct", "k": "bmV3ZXI="}),
            stored_at: 2000,
        };
        store.save_group_key("s1", newer).await.unwrap();
        let loaded = store
            .load_group_key("s1")
            .await
            .unwrap()
            .expect("overwritten group key should load");
        assert_eq!(loaded.key["k"], "bmV3ZXI=");
        assert_eq!(loaded.stored_at, 2000);
    }

    #[tokio::test]
    async fn fs_store_delete_group_key() {
        let dir = tempfile::tempdir().unwrap();
        let store = FileSystemKeyStore::new(dir.path());
        store.save_group_key("s1", sample_key()).await.unwrap();
        store.delete_group_key("s1").await.unwrap();
        let loaded = store.load_group_key("s1").await.unwrap();
        assert!(loaded.is_none());
    }

    #[tokio::test]
    async fn fs_store_delete_missing_group_key_is_ok() {
        let dir = tempfile::tempdir().unwrap();
        let store = FileSystemKeyStore::new(dir.path());
        store.delete_group_key("never-saved").await.unwrap();
    }

    #[tokio::test]
    async fn fs_store_tofu_round_trip() {
        let dir = tempfile::tempdir().unwrap();
        let store = FileSystemKeyStore::new(dir.path());
        let record = TofuRecord {
            fingerprint: "abcd 1234".to_string(),
            first_seen: 5000,
        };
        store.save_tofu_record("peer-1", record).await.unwrap();
        let loaded = store
            .load_tofu_record("peer-1")
            .await
            .unwrap()
            .expect("saved TOFU record should load");
        assert_eq!(loaded.fingerprint, "abcd 1234");
        assert_eq!(loaded.first_seen, 5000);
    }

    #[tokio::test]
    async fn fs_store_tofu_missing() {
        let dir = tempfile::tempdir().unwrap();
        let store = FileSystemKeyStore::new(dir.path());
        let loaded = store.load_tofu_record("unknown-peer").await.unwrap();
        assert!(loaded.is_none());
    }

    #[tokio::test]
    async fn fs_store_hashes_ids_and_leaves_no_temp_files() {
        let dir = tempfile::tempdir().unwrap();
        let store = FileSystemKeyStore::new(dir.path());
        // An id containing path syntax must still land inside `group-keys/`.
        store
            .save_group_key("../../escape", sample_key())
            .await
            .unwrap();

        let names: Vec<String> = std::fs::read_dir(dir.path().join("group-keys"))
            .unwrap()
            .map(|entry| entry.unwrap().file_name().to_string_lossy().into_owned())
            .collect();
        assert_eq!(names.len(), 1, "unexpected files: {names:?}");
        let stem = names[0]
            .strip_suffix(".json")
            .expect("entry should be named <sha256-hex>.json");
        assert_eq!(stem.len(), 64);
        assert!(stem.chars().all(|c| c.is_ascii_hexdigit()));

        let loaded = store.load_group_key("../../escape").await.unwrap();
        assert!(loaded.is_some());
    }

    #[tokio::test]
    async fn fs_store_persist_and_reload() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().to_path_buf();

        // Write with one store instance...
        {
            let store = FileSystemKeyStore::new(&path);
            let data = KeyData {
                key: serde_json::json!({"kty": "oct", "k": "dGVzdGtleQ=="}),
                stored_at: 42000,
            };
            store.save_group_key("persist-test", data).await.unwrap();
        }

        // ...and read back with a fresh one pointed at the same directory.
        {
            let store = FileSystemKeyStore::new(&path);
            let loaded = store
                .load_group_key("persist-test")
                .await
                .unwrap()
                .expect("persisted group key should reload");
            assert_eq!(loaded.key["k"], "dGVzdGtleQ==");
            assert_eq!(loaded.stored_at, 42000);
        }
    }
}