kaccy_bitcoin/
key_management.rs

1//! Advanced key management for Bitcoin wallets
2//!
3//! Provides advanced key management features including:
4//! - Key rotation for enhanced security
5//! - Time-delayed recovery for inheritance and emergency access
//! - Social recovery using m-of-n guardian recovery shares (threshold secret sharing)
7
use bitcoin::Network;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
14
15use crate::error::{BitcoinError, Result};
16
17/// Key rotation configuration
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct KeyRotationConfig {
20    /// Rotation interval in seconds
21    pub rotation_interval: u64,
22    /// Number of old keys to keep for transition
23    pub keys_to_keep: usize,
24    /// Whether to automatically rotate keys
25    pub auto_rotate: bool,
26}
27
28impl Default for KeyRotationConfig {
29    fn default() -> Self {
30        Self {
31            rotation_interval: 30 * 24 * 60 * 60, // 30 days
32            keys_to_keep: 3,
33            auto_rotate: false,
34        }
35    }
36}
37
38/// Rotated key information
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct RotatedKey {
41    /// Key identifier
42    pub key_id: String,
43    /// Creation timestamp
44    pub created_at: u64,
45    /// Expiry timestamp (when this key should be rotated out)
46    pub expires_at: u64,
47    /// Whether this is the active key
48    pub is_active: bool,
49    /// Derivation path (if applicable)
50    pub derivation_path: Option<String>,
51}
52
53/// Key rotation manager
54pub struct KeyRotationManager {
55    config: KeyRotationConfig,
56    keys: Arc<RwLock<HashMap<String, RotatedKey>>>,
57    active_key_id: Arc<RwLock<Option<String>>>,
58}
59
60impl KeyRotationManager {
61    /// Create a new key rotation manager
62    pub fn new(config: KeyRotationConfig) -> Self {
63        Self {
64            config,
65            keys: Arc::new(RwLock::new(HashMap::new())),
66            active_key_id: Arc::new(RwLock::new(None)),
67        }
68    }
69
70    /// Register a new key
71    pub async fn register_key(
72        &self,
73        key_id: String,
74        derivation_path: Option<String>,
75    ) -> Result<RotatedKey> {
76        let now = SystemTime::now()
77            .duration_since(UNIX_EPOCH)
78            .unwrap()
79            .as_secs();
80
81        let rotated_key = RotatedKey {
82            key_id: key_id.clone(),
83            created_at: now,
84            expires_at: now + self.config.rotation_interval,
85            is_active: false,
86            derivation_path,
87        };
88
89        self.keys.write().await.insert(key_id, rotated_key.clone());
90
91        tracing::info!(
92            key_id = %rotated_key.key_id,
93            expires_at = rotated_key.expires_at,
94            "Registered new key"
95        );
96
97        Ok(rotated_key)
98    }
99
100    /// Set the active key
101    pub async fn set_active_key(&self, key_id: String) -> Result<()> {
102        // Deactivate current active key
103        if let Some(old_key_id) = self.active_key_id.read().await.as_ref() {
104            if let Some(key) = self.keys.write().await.get_mut(old_key_id) {
105                key.is_active = false;
106            }
107        }
108
109        // Activate new key
110        let mut keys = self.keys.write().await;
111        if let Some(key) = keys.get_mut(&key_id) {
112            key.is_active = true;
113            *self.active_key_id.write().await = Some(key_id.clone());
114
115            tracing::info!(key_id = %key_id, "Activated new key");
116            Ok(())
117        } else {
118            Err(BitcoinError::Validation(format!(
119                "Key not found: {}",
120                key_id
121            )))
122        }
123    }
124
125    /// Check if rotation is needed
126    pub async fn needs_rotation(&self) -> bool {
127        let active_key_id = self.active_key_id.read().await;
128        if let Some(key_id) = active_key_id.as_ref() {
129            if let Some(key) = self.keys.read().await.get(key_id) {
130                let now = SystemTime::now()
131                    .duration_since(UNIX_EPOCH)
132                    .unwrap()
133                    .as_secs();
134                return now >= key.expires_at;
135            }
136        }
137        false
138    }
139
140    /// Get active key
141    pub async fn get_active_key(&self) -> Option<RotatedKey> {
142        let active_key_id = self.active_key_id.read().await;
143        if let Some(key_id) = active_key_id.as_ref() {
144            self.keys.read().await.get(key_id).cloned()
145        } else {
146            None
147        }
148    }
149
150    /// Clean up expired keys
151    pub async fn cleanup_expired_keys(&self) -> usize {
152        let now = SystemTime::now()
153            .duration_since(UNIX_EPOCH)
154            .unwrap()
155            .as_secs();
156
157        let mut keys = self.keys.write().await;
158        let active_key_id = self.active_key_id.read().await.clone();
159
160        // Get all keys sorted by creation time
161        let mut sorted_keys: Vec<_> = keys
162            .iter()
163            .map(|(id, key)| (id.clone(), key.clone()))
164            .collect();
165        sorted_keys.sort_by_key(|(_, key)| key.created_at);
166
167        // Keep the active key and the most recent N keys
168        let keys_to_keep: Vec<String> = sorted_keys
169            .iter()
170            .rev()
171            .take(self.config.keys_to_keep)
172            .map(|(id, _)| id.clone())
173            .collect();
174
175        let mut removed_count = 0;
176        keys.retain(|id, key| {
177            let keep = key.is_active
178                || keys_to_keep.contains(id)
179                || active_key_id.as_ref() == Some(id)
180                || key.expires_at > now;
181            if !keep {
182                removed_count += 1;
183            }
184            keep
185        });
186
187        if removed_count > 0 {
188            tracing::info!(removed = removed_count, "Cleaned up expired keys");
189        }
190
191        removed_count
192    }
193}
194
195/// Time-delayed recovery configuration
196#[derive(Debug, Clone, Serialize, Deserialize)]
197pub struct TimeDelayedRecoveryConfig {
198    /// Delay period in seconds before recovery is possible
199    pub delay_seconds: u64,
200    /// Network to use
201    pub network: Network,
202}
203
204impl Default for TimeDelayedRecoveryConfig {
205    fn default() -> Self {
206        Self {
207            delay_seconds: 90 * 24 * 60 * 60, // 90 days
208            network: Network::Bitcoin,
209        }
210    }
211}
212
213/// Recovery key information
214#[derive(Debug, Clone, Serialize, Deserialize)]
215pub struct RecoveryKey {
216    /// Recovery key identifier
217    pub id: String,
218    /// Activation timestamp (when recovery becomes available)
219    pub activation_time: u64,
220    /// Recovery key derivation path
221    pub derivation_path: String,
222    /// Whether this recovery key has been used
223    pub used: bool,
224}
225
226/// Time-delayed recovery manager
227pub struct TimeDelayedRecoveryManager {
228    config: TimeDelayedRecoveryConfig,
229    recovery_keys: Arc<RwLock<HashMap<String, RecoveryKey>>>,
230}
231
232impl TimeDelayedRecoveryManager {
233    /// Create a new time-delayed recovery manager
234    pub fn new(config: TimeDelayedRecoveryConfig) -> Self {
235        Self {
236            config,
237            recovery_keys: Arc::new(RwLock::new(HashMap::new())),
238        }
239    }
240
241    /// Create a new recovery key
242    pub async fn create_recovery_key(&self, derivation_path: String) -> Result<RecoveryKey> {
243        let now = SystemTime::now()
244            .duration_since(UNIX_EPOCH)
245            .unwrap()
246            .as_secs();
247
248        let recovery_key = RecoveryKey {
249            id: uuid::Uuid::new_v4().to_string(),
250            activation_time: now + self.config.delay_seconds,
251            derivation_path,
252            used: false,
253        };
254
255        self.recovery_keys
256            .write()
257            .await
258            .insert(recovery_key.id.clone(), recovery_key.clone());
259
260        tracing::info!(
261            id = %recovery_key.id,
262            activation_time = recovery_key.activation_time,
263            "Created time-delayed recovery key"
264        );
265
266        Ok(recovery_key)
267    }
268
269    /// Check if a recovery key is available for use
270    pub async fn is_recovery_available(&self, recovery_id: &str) -> Result<bool> {
271        let keys = self.recovery_keys.read().await;
272        if let Some(key) = keys.get(recovery_id) {
273            if key.used {
274                return Ok(false);
275            }
276
277            let now = SystemTime::now()
278                .duration_since(UNIX_EPOCH)
279                .unwrap()
280                .as_secs();
281            Ok(now >= key.activation_time)
282        } else {
283            Err(BitcoinError::Validation(format!(
284                "Recovery key not found: {}",
285                recovery_id
286            )))
287        }
288    }
289
290    /// Use a recovery key (mark as used)
291    pub async fn use_recovery_key(&self, recovery_id: &str) -> Result<RecoveryKey> {
292        if !self.is_recovery_available(recovery_id).await? {
293            return Err(BitcoinError::Validation(
294                "Recovery key not yet available or already used".to_string(),
295            ));
296        }
297
298        let mut keys = self.recovery_keys.write().await;
299        if let Some(key) = keys.get_mut(recovery_id) {
300            key.used = true;
301            tracing::info!(id = %recovery_id, "Used recovery key");
302            Ok(key.clone())
303        } else {
304            Err(BitcoinError::Validation(format!(
305                "Recovery key not found: {}",
306                recovery_id
307            )))
308        }
309    }
310
311    /// Get time remaining until recovery is available
312    pub async fn time_until_recovery(&self, recovery_id: &str) -> Result<Duration> {
313        let keys = self.recovery_keys.read().await;
314        if let Some(key) = keys.get(recovery_id) {
315            let now = SystemTime::now()
316                .duration_since(UNIX_EPOCH)
317                .unwrap()
318                .as_secs();
319
320            if now >= key.activation_time {
321                Ok(Duration::from_secs(0))
322            } else {
323                Ok(Duration::from_secs(key.activation_time - now))
324            }
325        } else {
326            Err(BitcoinError::Validation(format!(
327                "Recovery key not found: {}",
328                recovery_id
329            )))
330        }
331    }
332}
333
334/// Social recovery configuration
335#[derive(Debug, Clone, Serialize, Deserialize)]
336pub struct SocialRecoveryConfig {
337    /// Threshold (m in m-of-n)
338    pub threshold: usize,
339    /// Total guardians (n in m-of-n)
340    pub total_guardians: usize,
341    /// Network to use
342    pub network: Network,
343}
344
345impl Default for SocialRecoveryConfig {
346    fn default() -> Self {
347        Self {
348            threshold: 2,
349            total_guardians: 3,
350            network: Network::Bitcoin,
351        }
352    }
353}
354
355/// Guardian information
356#[derive(Debug, Clone, Serialize, Deserialize)]
357pub struct Guardian {
358    /// Guardian identifier
359    pub id: String,
360    /// Guardian's public key or xpub
361    pub public_key: String,
362    /// Guardian's name/identifier
363    pub name: String,
364    /// Whether this guardian is active
365    pub active: bool,
366}
367
368/// Recovery share from a guardian
369#[derive(Debug, Clone, Serialize, Deserialize)]
370pub struct RecoveryShare {
371    /// Guardian who provided this share
372    pub guardian_id: String,
373    /// Share data (encrypted key share)
374    pub share_data: Vec<u8>,
375    /// Timestamp when share was created
376    pub created_at: u64,
377}
378
379/// Social recovery manager using Shamir's Secret Sharing
380pub struct SocialRecoveryManager {
381    config: SocialRecoveryConfig,
382    guardians: Arc<RwLock<HashMap<String, Guardian>>>,
383    recovery_shares: Arc<RwLock<Vec<RecoveryShare>>>,
384}
385
386impl SocialRecoveryManager {
387    /// Create a new social recovery manager
388    pub fn new(config: SocialRecoveryConfig) -> Result<Self> {
389        if config.threshold > config.total_guardians {
390            return Err(BitcoinError::Validation(
391                "Threshold cannot exceed total guardians".to_string(),
392            ));
393        }
394
395        if config.threshold == 0 {
396            return Err(BitcoinError::Validation(
397                "Threshold must be at least 1".to_string(),
398            ));
399        }
400
401        Ok(Self {
402            config,
403            guardians: Arc::new(RwLock::new(HashMap::new())),
404            recovery_shares: Arc::new(RwLock::new(Vec::new())),
405        })
406    }
407
408    /// Add a guardian
409    pub async fn add_guardian(&self, guardian: Guardian) -> Result<()> {
410        let guardians_count = self.guardians.read().await.len();
411
412        if guardians_count >= self.config.total_guardians {
413            return Err(BitcoinError::Validation(format!(
414                "Maximum number of guardians ({}) already reached",
415                self.config.total_guardians
416            )));
417        }
418
419        self.guardians
420            .write()
421            .await
422            .insert(guardian.id.clone(), guardian.clone());
423
424        tracing::info!(
425            guardian_id = %guardian.id,
426            guardian_name = %guardian.name,
427            "Added guardian"
428        );
429
430        Ok(())
431    }
432
433    /// Remove a guardian
434    pub async fn remove_guardian(&self, guardian_id: &str) -> Result<Guardian> {
435        if let Some(guardian) = self.guardians.write().await.remove(guardian_id) {
436            tracing::info!(guardian_id = %guardian_id, "Removed guardian");
437            Ok(guardian)
438        } else {
439            Err(BitcoinError::Validation(format!(
440                "Guardian not found: {}",
441                guardian_id
442            )))
443        }
444    }
445
446    /// Get all guardians
447    pub async fn get_guardians(&self) -> Vec<Guardian> {
448        self.guardians.read().await.values().cloned().collect()
449    }
450
451    /// Submit a recovery share from a guardian
452    pub async fn submit_recovery_share(&self, share: RecoveryShare) -> Result<()> {
453        // Verify guardian exists
454        if !self.guardians.read().await.contains_key(&share.guardian_id) {
455            return Err(BitcoinError::Validation(format!(
456                "Unknown guardian: {}",
457                share.guardian_id
458            )));
459        }
460
461        // Check if this guardian already submitted a share
462        let mut shares = self.recovery_shares.write().await;
463        if shares.iter().any(|s| s.guardian_id == share.guardian_id) {
464            return Err(BitcoinError::Validation(format!(
465                "Guardian {} already submitted a share",
466                share.guardian_id
467            )));
468        }
469
470        shares.push(share.clone());
471
472        tracing::info!(
473            guardian_id = %share.guardian_id,
474            total_shares = shares.len(),
475            "Received recovery share"
476        );
477
478        Ok(())
479    }
480
481    /// Check if recovery is possible (threshold met)
482    pub async fn can_recover(&self) -> bool {
483        self.recovery_shares.read().await.len() >= self.config.threshold
484    }
485
486    /// Attempt recovery with collected shares
487    pub async fn attempt_recovery(&self) -> Result<Vec<u8>> {
488        let shares = self.recovery_shares.read().await;
489
490        if shares.len() < self.config.threshold {
491            return Err(BitcoinError::Validation(format!(
492                "Insufficient shares: have {}, need {}",
493                shares.len(),
494                self.config.threshold
495            )));
496        }
497
498        // In production, this would:
499        // 1. Use Shamir's Secret Sharing to reconstruct the secret
500        // 2. Derive the recovery key from the secret
501        // 3. Return the reconstructed key
502
503        tracing::info!(
504            shares_used = shares.len(),
505            threshold = self.config.threshold,
506            "Attempting recovery with collected shares"
507        );
508
509        // Placeholder: return empty recovery data
510        Ok(vec![0u8; 32])
511    }
512
513    /// Clear all recovery shares (e.g., after successful recovery)
514    pub async fn clear_shares(&self) {
515        self.recovery_shares.write().await.clear();
516        tracing::info!("Cleared all recovery shares");
517    }
518
519    /// Get recovery progress
520    pub async fn get_recovery_progress(&self) -> RecoveryProgress {
521        let shares_collected = self.recovery_shares.read().await.len();
522        RecoveryProgress {
523            shares_collected,
524            threshold: self.config.threshold,
525            total_guardians: self.config.total_guardians,
526            can_recover: shares_collected >= self.config.threshold,
527        }
528    }
529}
530
531/// Recovery progress information
532#[derive(Debug, Clone, Serialize, Deserialize)]
533pub struct RecoveryProgress {
534    /// Number of shares collected
535    pub shares_collected: usize,
536    /// Threshold required for recovery
537    pub threshold: usize,
538    /// Total number of guardians
539    pub total_guardians: usize,
540    /// Whether recovery is possible
541    pub can_recover: bool,
542}
543
#[cfg(test)]
mod tests {
    use super::*;

    /// 2-of-3 guardian scheme on `Network::Bitcoin`, shared by the social recovery tests.
    fn two_of_three() -> SocialRecoveryConfig {
        SocialRecoveryConfig {
            threshold: 2,
            total_guardians: 3,
            network: Network::Bitcoin,
        }
    }

    /// Active guardian "Alice" registered under id `guardian1`.
    fn alice() -> Guardian {
        Guardian {
            id: "guardian1".to_string(),
            public_key: "xpub...".to_string(),
            name: "Alice".to_string(),
            active: true,
        }
    }

    /// Active guardian number `n`, identified as `guardian{n}`.
    fn numbered_guardian(n: u8) -> Guardian {
        Guardian {
            id: format!("guardian{}", n),
            public_key: format!("xpub{}", n),
            name: format!("Guardian {}", n),
            active: true,
        }
    }

    #[tokio::test]
    async fn test_key_rotation_registration() {
        let manager = KeyRotationManager::new(KeyRotationConfig::default());
        let path = Some("m/84'/0'/0'".to_string());

        let key = manager
            .register_key("key1".to_string(), path.clone())
            .await
            .unwrap();

        assert_eq!(key.key_id, "key1");
        assert!(!key.is_active);
        assert_eq!(key.derivation_path, path);
    }

    #[tokio::test]
    async fn test_key_rotation_activation() {
        let manager = KeyRotationManager::new(KeyRotationConfig::default());

        manager
            .register_key("key1".to_string(), None)
            .await
            .unwrap();
        manager.set_active_key("key1".to_string()).await.unwrap();

        let active = manager
            .get_active_key()
            .await
            .expect("key1 was just activated");
        assert_eq!(active.key_id, "key1");
        assert!(active.is_active);
    }

    #[tokio::test]
    async fn test_key_rotation_cleanup() {
        let manager = KeyRotationManager::new(KeyRotationConfig {
            rotation_interval: 1, // 1 second
            keys_to_keep: 2,
            auto_rotate: false,
        });

        manager
            .register_key("key1".to_string(), None)
            .await
            .unwrap();
        // Let key1 age past its one-second rotation interval.
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
        for id in &["key2", "key3"] {
            manager.register_key(id.to_string(), None).await.unwrap();
        }
        manager.set_active_key("key3".to_string()).await.unwrap();

        // At most key1 qualifies: the active key and the recent keys must survive.
        let removed = manager.cleanup_expired_keys().await;
        assert!(removed <= 1);
    }

    #[tokio::test]
    async fn test_time_delayed_recovery_creation() {
        let manager = TimeDelayedRecoveryManager::new(TimeDelayedRecoveryConfig::default());
        let path = "m/84'/0'/0'/0/0";

        let recovery_key = manager
            .create_recovery_key(path.to_string())
            .await
            .unwrap();

        assert!(!recovery_key.used);
        assert_eq!(recovery_key.derivation_path, path);
    }

    #[tokio::test]
    async fn test_time_delayed_recovery_not_available() {
        let manager = TimeDelayedRecoveryManager::new(TimeDelayedRecoveryConfig {
            delay_seconds: 60 * 60, // 1 hour
            network: Network::Bitcoin,
        });

        let recovery_key = manager
            .create_recovery_key("m/84'/0'/0'/0/0".to_string())
            .await
            .unwrap();

        let available = manager
            .is_recovery_available(&recovery_key.id)
            .await
            .unwrap();
        assert!(!available);
    }

    #[tokio::test]
    async fn test_social_recovery_add_guardian() {
        let manager = SocialRecoveryManager::new(two_of_three()).unwrap();

        manager.add_guardian(alice()).await.unwrap();

        let guardians = manager.get_guardians().await;
        assert_eq!(guardians.len(), 1);
        assert_eq!(guardians[0].name, "Alice");
    }

    #[tokio::test]
    async fn test_social_recovery_threshold_validation() {
        // Invalid: a threshold of 5 cannot be met by only 3 guardians.
        let invalid = SocialRecoveryConfig {
            threshold: 5,
            ..two_of_three()
        };

        assert!(SocialRecoveryManager::new(invalid).is_err());
    }

    #[tokio::test]
    async fn test_social_recovery_share_submission() {
        let manager = SocialRecoveryManager::new(two_of_three()).unwrap();
        manager.add_guardian(alice()).await.unwrap();

        manager
            .submit_recovery_share(RecoveryShare {
                guardian_id: "guardian1".to_string(),
                share_data: vec![1, 2, 3, 4],
                created_at: 0,
            })
            .await
            .unwrap();

        let progress = manager.get_recovery_progress().await;
        assert_eq!(progress.shares_collected, 1);
        assert!(!progress.can_recover); // one of the two required shares
    }

    #[tokio::test]
    async fn test_social_recovery_threshold_met() {
        let manager = SocialRecoveryManager::new(two_of_three()).unwrap();

        for n in 1..=3u8 {
            manager.add_guardian(numbered_guardian(n)).await.unwrap();
        }

        // Two shares meet the 2-of-3 threshold.
        for n in 1..=2u8 {
            manager
                .submit_recovery_share(RecoveryShare {
                    guardian_id: format!("guardian{}", n),
                    share_data: vec![n; 32],
                    created_at: 0,
                })
                .await
                .unwrap();
        }

        assert!(manager.can_recover().await);

        let progress = manager.get_recovery_progress().await;
        assert!(progress.can_recover);
        assert_eq!(progress.shares_collected, 2);
    }
}