1use crate::crypto::{generate_keypair, AgentIdentity};
4use crate::error::{CryptoError, Error, Result};
5use ed25519_dalek::{SigningKey, VerifyingKey, PUBLIC_KEY_LENGTH, SECRET_KEY_LENGTH};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10use tokio::fs;
11use tokio::sync::RwLock;
12use uuid::Uuid;
13use zeroize::{Zeroize, ZeroizeOnDrop};
14use base64::{Engine as _, engine::general_purpose::STANDARD};
15
16#[derive(Clone)]
18pub struct KeyPair {
19 signing_key: SigningKey,
20 verifying_key: VerifyingKey,
21}
22
23impl KeyPair {
24 pub fn generate() -> Result<Self> {
26 let (signing_key, verifying_key) = generate_keypair()?;
27 Ok(Self {
28 signing_key,
29 verifying_key,
30 })
31 }
32
33 pub fn from_bytes(signing_key_bytes: &[u8]) -> Result<Self> {
35 if signing_key_bytes.len() != SECRET_KEY_LENGTH {
36 return Err(Error::Crypto(CryptoError::InvalidPrivateKey {
37 details: format!(
38 "Invalid key length: expected {}, got {}",
39 SECRET_KEY_LENGTH,
40 signing_key_bytes.len()
41 ),
42 }));
43 }
44
45 let signing_key = SigningKey::from_bytes(
46 signing_key_bytes.try_into().map_err(|_| {
47 Error::Crypto(CryptoError::InvalidPrivateKey {
48 details: "Invalid key bytes".to_string(),
49 })
50 })?
51 );
52 let verifying_key = signing_key.verifying_key();
53
54 Ok(Self {
55 signing_key,
56 verifying_key,
57 })
58 }
59
60 pub fn signing_key(&self) -> &SigningKey {
62 &self.signing_key
63 }
64
65 pub fn verifying_key(&self) -> &VerifyingKey {
67 &self.verifying_key
68 }
69
70 pub fn to_bytes(&self) -> [u8; SECRET_KEY_LENGTH] {
72 self.signing_key.to_bytes()
73 }
74
75 pub fn verifying_key_bytes(&self) -> [u8; PUBLIC_KEY_LENGTH] {
77 self.verifying_key.to_bytes()
78 }
79
80 pub fn to_identity(self) -> Result<AgentIdentity> {
82 AgentIdentity::from_signing_key(self.signing_key.clone())
83 }
84}
85
// `KeyPair` deliberately has no custom `Drop` impl. An empty one performs no cleanup and
// would also forbid moving fields out of the struct. Wiping the secret is left to
// `SigningKey`'s own drop behaviour.
// NOTE(review): ed25519-dalek only zeroizes `SigningKey` on drop when its `zeroize`
// feature is enabled — confirm it is on in Cargo.toml.
91
92impl std::fmt::Debug for KeyPair {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.debug_struct("KeyPair")
95 .field("verifying_key", &hex::encode(self.verifying_key.to_bytes()))
96 .finish_non_exhaustive()
97 }
98}
99
100#[derive(Clone, Serialize, Deserialize)]
102pub struct StoredKey {
103 pub id: Uuid,
105 pub alias: String,
107 #[serde(with = "base64_serde")]
109 pub signing_key: Vec<u8>,
110 #[serde(with = "base64_array_serde")]
112 pub verifying_key: [u8; PUBLIC_KEY_LENGTH],
113 pub created_at: chrono::DateTime<chrono::Utc>,
115 pub last_used: Option<chrono::DateTime<chrono::Utc>>,
117 pub tags: HashMap<String, String>,
119}
120
121impl StoredKey {
122 pub fn new(alias: String, keypair: &KeyPair) -> Self {
124 Self {
125 id: Uuid::new_v4(),
126 alias,
127 signing_key: keypair.to_bytes().to_vec(),
128 verifying_key: keypair.verifying_key_bytes(),
129 created_at: chrono::Utc::now(),
130 last_used: None,
131 tags: HashMap::new(),
132 }
133 }
134
135 pub fn to_keypair(&self) -> Result<KeyPair> {
137 KeyPair::from_bytes(&self.signing_key)
138 }
139
140 pub fn add_tag(&mut self, key: String, value: String) {
142 self.tags.insert(key, value);
143 }
144
145 pub fn mark_used(&mut self) {
147 self.last_used = Some(chrono::Utc::now());
148 }
149}
150
151impl Zeroize for StoredKey {
152 fn zeroize(&mut self) {
153 self.signing_key.zeroize();
154 }
155}
156
157impl Drop for StoredKey {
158 fn drop(&mut self) {
159 self.zeroize();
160 }
161}
162
163pub struct KeyManager {
165 keys: Arc<RwLock<HashMap<Uuid, StoredKey>>>,
167 aliases: Arc<RwLock<HashMap<String, Uuid>>>,
169 storage_path: Option<PathBuf>,
171}
172
173impl KeyManager {
174 pub fn new() -> Self {
176 Self {
177 keys: Arc::new(RwLock::new(HashMap::new())),
178 aliases: Arc::new(RwLock::new(HashMap::new())),
179 storage_path: None,
180 }
181 }
182
183 pub fn with_storage<P: AsRef<Path>>(path: P) -> Self {
185 Self {
186 keys: Arc::new(RwLock::new(HashMap::new())),
187 aliases: Arc::new(RwLock::new(HashMap::new())),
188 storage_path: Some(path.as_ref().to_path_buf()),
189 }
190 }
191
192 pub async fn store(&self, alias: String, keypair: KeyPair) -> Result<Uuid> {
194 let stored_key = StoredKey::new(alias.clone(), &keypair);
195 let id = stored_key.id;
196
197 {
198 let mut keys = self.keys.write().await;
199 let mut aliases = self.aliases.write().await;
200
201 keys.insert(id, stored_key.clone());
202 aliases.insert(alias, id);
203 }
204
205 if let Some(ref path) = self.storage_path {
207 self.persist_key(path, &stored_key).await?;
208 }
209
210 Ok(id)
211 }
212
213 pub async fn get(&self, id: &Uuid) -> Result<KeyPair> {
215 let keys = self.keys.read().await;
216 let stored_key = keys.get(id)
217 .ok_or_else(|| Error::KeyNotFound(id.to_string()))?;
218
219 stored_key.to_keypair()
220 }
221
222 pub async fn get_by_alias(&self, alias: &str) -> Result<KeyPair> {
224 let aliases = self.aliases.read().await;
225 let id = aliases.get(alias)
226 .ok_or_else(|| Error::KeyNotFound(format!("alias: {}", alias)))?
227 .clone();
228
229 drop(aliases);
230 self.get(&id).await
231 }
232
233 pub async fn list(&self) -> Vec<Uuid> {
235 let keys = self.keys.read().await;
236 keys.keys().copied().collect()
237 }
238
239 pub async fn list_aliases(&self) -> Vec<String> {
241 let aliases = self.aliases.read().await;
242 aliases.keys().cloned().collect()
243 }
244
245 pub async fn remove(&self, id: &Uuid) -> Result<()> {
247 let mut keys = self.keys.write().await;
248 let mut aliases = self.aliases.write().await;
249
250 if let Some(stored_key) = keys.remove(id) {
251 aliases.remove(&stored_key.alias);
252
253 if let Some(ref path) = self.storage_path {
255 self.remove_persisted_key(path, id).await?;
256 }
257 }
258
259 Ok(())
260 }
261
262 pub async fn exists(&self, id: &Uuid) -> bool {
264 let keys = self.keys.read().await;
265 keys.contains_key(id)
266 }
267
268 pub async fn alias_exists(&self, alias: &str) -> bool {
270 let aliases = self.aliases.read().await;
271 aliases.contains_key(alias)
272 }
273
274 pub async fn load_from_storage(&self) -> Result<usize> {
276 let path = self.storage_path.as_ref()
277 .ok_or_else(|| Error::Configuration("No storage path configured".to_string()))?;
278
279 if !path.exists() {
280 fs::create_dir_all(path).await?;
281 return Ok(0);
282 }
283
284 let mut entries = fs::read_dir(path).await?;
285 let mut count = 0;
286
287 while let Some(entry) = entries.next_entry().await? {
288 if entry.path().extension().and_then(|s| s.to_str()) == Some("json") {
289 let content = fs::read_to_string(entry.path()).await?;
290 let stored_key: StoredKey = serde_json::from_str(&content)
291 .map_err(|e| Error::Serialization(e.into()))?;
292
293 let mut keys = self.keys.write().await;
294 let mut aliases = self.aliases.write().await;
295
296 keys.insert(stored_key.id, stored_key.clone());
297 aliases.insert(stored_key.alias.clone(), stored_key.id);
298
299 count += 1;
300 }
301 }
302
303 Ok(count)
304 }
305
306 async fn persist_key(&self, base_path: &Path, key: &StoredKey) -> Result<()> {
308 if !base_path.exists() {
309 fs::create_dir_all(base_path).await?;
310 }
311
312 let file_path = base_path.join(format!("{}.json", key.id));
313 let content = serde_json::to_string_pretty(key)
314 .map_err(|e| Error::Serialization(e.into()))?;
315 fs::write(file_path, content).await?;
316
317 Ok(())
318 }
319
320 async fn remove_persisted_key(&self, base_path: &Path, id: &Uuid) -> Result<()> {
322 let file_path = base_path.join(format!("{}.json", id));
323 if file_path.exists() {
324 fs::remove_file(file_path).await?;
325 }
326 Ok(())
327 }
328}
329
330impl Default for KeyManager {
331 fn default() -> Self {
332 Self::new()
333 }
334}
335
336mod base64_serde {
338 use super::STANDARD;
339 use base64::Engine as _;
340 use serde::{Deserialize, Deserializer, Serializer};
341
342 pub fn serialize<S, T>(data: T, serializer: S) -> std::result::Result<S::Ok, S::Error>
343 where
344 S: Serializer,
345 T: AsRef<[u8]>,
346 {
347 let encoded = STANDARD.encode(data.as_ref());
348 serializer.serialize_str(&encoded)
349 }
350
351 pub fn deserialize<'de, D>(deserializer: D) -> std::result::Result<Vec<u8>, D::Error>
352 where
353 D: Deserializer<'de>,
354 {
355 let s: String = Deserialize::deserialize(deserializer)?;
356 STANDARD.decode(&s).map_err(serde::de::Error::custom)
357 }
358}
359
360mod base64_array_serde {
362 use super::{STANDARD, PUBLIC_KEY_LENGTH};
363 use base64::Engine as _;
364 use serde::{Deserialize, Deserializer, Serializer};
365
366 pub fn serialize<S>(data: &[u8; PUBLIC_KEY_LENGTH], serializer: S) -> std::result::Result<S::Ok, S::Error>
367 where
368 S: Serializer,
369 {
370 let encoded = STANDARD.encode(data);
371 serializer.serialize_str(&encoded)
372 }
373
374 pub fn deserialize<'de, D>(deserializer: D) -> std::result::Result<[u8; PUBLIC_KEY_LENGTH], D::Error>
375 where
376 D: Deserializer<'de>,
377 {
378 let s: String = Deserialize::deserialize(deserializer)?;
379 let vec = STANDARD.decode(&s).map_err(serde::de::Error::custom)?;
380
381 if vec.len() != PUBLIC_KEY_LENGTH {
382 return Err(serde::de::Error::custom(format!(
383 "Expected {} bytes, got {}",
384 PUBLIC_KEY_LENGTH,
385 vec.len()
386 )));
387 }
388
389 let mut array = [0u8; PUBLIC_KEY_LENGTH];
390 array.copy_from_slice(&vec);
391 Ok(array)
392 }
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Temporary directory that is removed when dropped, even if the test panics.
    struct TempDir(std::path::PathBuf);

    impl TempDir {
        fn new(label: &str) -> Self {
            Self(std::env::temp_dir().join(format!("{}_{}", label, Uuid::new_v4())))
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Best effort: the directory may never have been created.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_keypair_generate() {
        let keypair = KeyPair::generate().unwrap();
        assert_eq!(keypair.verifying_key_bytes().len(), PUBLIC_KEY_LENGTH);
    }

    #[test]
    fn test_keypair_from_bytes() {
        let keypair1 = KeyPair::generate().unwrap();
        let bytes = keypair1.to_bytes();

        let keypair2 = KeyPair::from_bytes(&bytes).unwrap();
        assert_eq!(
            keypair1.verifying_key_bytes(),
            keypair2.verifying_key_bytes()
        );
    }

    #[test]
    fn test_keypair_from_bytes_rejects_wrong_length() {
        assert!(KeyPair::from_bytes(&[0u8; 0]).is_err());
        assert!(KeyPair::from_bytes(&[0u8; SECRET_KEY_LENGTH - 1]).is_err());
        assert!(KeyPair::from_bytes(&[0u8; SECRET_KEY_LENGTH + 1]).is_err());
    }

    #[test]
    fn test_stored_key() {
        let keypair = KeyPair::generate().unwrap();
        let stored = StoredKey::new("test_key".to_string(), &keypair);

        assert_eq!(stored.alias, "test_key");
        assert!(!stored.id.is_nil());

        let restored = stored.to_keypair().unwrap();
        assert_eq!(
            keypair.verifying_key_bytes(),
            restored.verifying_key_bytes()
        );
    }

    #[test]
    fn test_stored_key_json_roundtrip() {
        let keypair = KeyPair::generate().unwrap();
        let mut stored = StoredKey::new("json_key".to_string(), &keypair);
        stored.add_tag("purpose".to_string(), "test".to_string());

        let json = serde_json::to_string(&stored).unwrap();
        let decoded: StoredKey = serde_json::from_str(&json).unwrap();

        assert_eq!(decoded.id, stored.id);
        assert_eq!(decoded.alias, stored.alias);
        assert_eq!(decoded.signing_key, stored.signing_key);
        assert_eq!(decoded.verifying_key, stored.verifying_key);
        assert_eq!(
            decoded.tags.get("purpose").map(String::as_str),
            Some("test")
        );
    }

    #[tokio::test]
    async fn test_key_manager_store_and_get() {
        let manager = KeyManager::new();
        let keypair = KeyPair::generate().unwrap();
        let alias = "test_key".to_string();

        let id = manager.store(alias.clone(), keypair.clone()).await.unwrap();
        let retrieved = manager.get(&id).await.unwrap();

        assert_eq!(
            keypair.verifying_key_bytes(),
            retrieved.verifying_key_bytes()
        );
    }

    #[tokio::test]
    async fn test_key_manager_get_by_alias() {
        let manager = KeyManager::new();
        let keypair = KeyPair::generate().unwrap();
        let alias = "test_key".to_string();

        manager.store(alias.clone(), keypair.clone()).await.unwrap();
        let retrieved = manager.get_by_alias(&alias).await.unwrap();

        assert_eq!(
            keypair.verifying_key_bytes(),
            retrieved.verifying_key_bytes()
        );
    }

    #[tokio::test]
    async fn test_key_manager_list() {
        let manager = KeyManager::new();

        let kp1 = KeyPair::generate().unwrap();
        let kp2 = KeyPair::generate().unwrap();

        manager.store("key1".to_string(), kp1).await.unwrap();
        manager.store("key2".to_string(), kp2).await.unwrap();

        let ids = manager.list().await;
        assert_eq!(ids.len(), 2);

        let aliases = manager.list_aliases().await;
        assert_eq!(aliases.len(), 2);
    }

    #[tokio::test]
    async fn test_key_manager_remove() {
        let manager = KeyManager::new();
        let keypair = KeyPair::generate().unwrap();
        let alias = "test_key".to_string();

        let id = manager.store(alias.clone(), keypair).await.unwrap();
        assert!(manager.exists(&id).await);

        manager.remove(&id).await.unwrap();
        assert!(!manager.exists(&id).await);
        assert!(!manager.alias_exists(&alias).await);
    }

    #[tokio::test]
    async fn test_key_manager_persistence_roundtrip() {
        let dir = TempDir::new("key_manager_test");
        let keypair = KeyPair::generate().unwrap();

        let writer = KeyManager::with_storage(&dir.0);
        let id = writer
            .store("persisted".to_string(), keypair.clone())
            .await
            .unwrap();

        let reader = KeyManager::with_storage(&dir.0);
        assert_eq!(reader.load_from_storage().await.unwrap(), 1);
        assert!(reader.exists(&id).await);
        let restored = reader.get_by_alias("persisted").await.unwrap();
        assert_eq!(
            keypair.verifying_key_bytes(),
            restored.verifying_key_bytes()
        );

        reader.remove(&id).await.unwrap();
        let after_remove = KeyManager::with_storage(&dir.0);
        assert_eq!(after_remove.load_from_storage().await.unwrap(), 0);
    }
}