Skip to main content

bsv/primitives/
key_shares.rs

1//! Shamir's Secret Sharing for private keys.
2//!
3//! Implements key splitting and reconstruction using polynomial interpolation
4//! over GF(p). Follows the TS SDK PrivateKey.ts KeyShares implementation.
5
6use crate::primitives::big_number::{BigNumber, Endian};
7use crate::primitives::curve::Curve;
8use crate::primitives::error::PrimitivesError;
9use crate::primitives::hash::sha512_hmac;
10use crate::primitives::polynomial::{PointInFiniteField, Polynomial};
11use crate::primitives::private_key::PrivateKey;
12use crate::primitives::random::random_bytes;
13use crate::primitives::utils::to_hex;
14
/// Shamir's Secret Sharing for PrivateKey backup and recovery.
///
/// Splits a private key into `total` shares such that any `threshold`
/// shares can reconstruct the original key, but fewer shares reveal
/// nothing about the secret.
///
/// Share format (backup string): "base58(x).base58(y).threshold.integrity"
/// where integrity is the first 8 hex chars of hash160(pubkey).
pub struct KeyShares {
    /// Share points `(x, y)`; `split` produces each `y` by evaluating the
    /// secret polynomial at `x`.
    pub points: Vec<PointInFiniteField>,
    /// Minimum number of shares needed to reconstruct the key.
    pub threshold: usize,
    /// First 8 hex characters of the original key's public key hash.
    /// `reconstruct` compares it against the recovered key to detect
    /// wrong or corrupted shares.
    pub integrity: String,
}
28
29impl KeyShares {
30    /// Create a new KeyShares instance.
31    pub fn new(points: Vec<PointInFiniteField>, threshold: usize, integrity: String) -> Self {
32        KeyShares {
33            points,
34            threshold,
35            integrity,
36        }
37    }
38
39    /// Split a private key into shares using Shamir's Secret Sharing.
40    ///
41    /// # Arguments
42    /// * `key` - The private key to split
43    /// * `threshold` - Minimum shares needed to reconstruct (must be >= 2)
44    /// * `total` - Total shares to generate (must be >= threshold)
45    ///
46    /// # Returns
47    /// A KeyShares instance containing the shares, threshold, and integrity hash.
48    pub fn split(
49        key: &PrivateKey,
50        threshold: usize,
51        total: usize,
52    ) -> Result<Self, PrimitivesError> {
53        if threshold < 2 {
54            return Err(PrimitivesError::ThresholdError(
55                "threshold must be at least 2".to_string(),
56            ));
57        }
58        if total < 2 {
59            return Err(PrimitivesError::ThresholdError(
60                "totalShares must be at least 2".to_string(),
61            ));
62        }
63        if threshold > total {
64            return Err(PrimitivesError::ThresholdError(
65                "threshold should be less than or equal to totalShares".to_string(),
66            ));
67        }
68
69        let curve = Curve::secp256k1();
70        let key_bytes = key.to_bytes();
71        let poly = Polynomial::from_private_key(&key_bytes, threshold);
72
73        let mut points = Vec::with_capacity(total);
74        let mut used_x_coords: Vec<BigNumber> = Vec::new();
75
76        // Cryptographically secure x-coordinate generation
77        // Matching TS SDK: uses HMAC-SHA-512 with master seed for x-coordinate generation
78        let seed = random_bytes(64);
79
80        for i in 0..total {
81            let mut x: BigNumber;
82            let mut attempts = 0u32;
83
84            loop {
85                let mut counter = Vec::new();
86                counter.push(i as u8);
87                counter.push(attempts as u8);
88                counter.extend_from_slice(&random_bytes(32));
89
90                let h = sha512_hmac(&seed, &counter);
91                x = BigNumber::from_bytes(&h, Endian::Big);
92                x = x
93                    .umod(&curve.p)
94                    .map_err(|e| PrimitivesError::ArithmeticError(format!("mod p: {}", e)))?;
95
96                attempts += 1;
97                if attempts > 5 {
98                    return Err(PrimitivesError::ThresholdError(
99                        "Failed to generate unique x coordinate after 5 attempts".to_string(),
100                    ));
101                }
102
103                // Check x is non-zero and not already used
104                if x.is_zero() {
105                    continue;
106                }
107                let mut duplicate = false;
108                for existing in &used_x_coords {
109                    if existing.cmp(&x) == 0 {
110                        duplicate = true;
111                        break;
112                    }
113                }
114                if !duplicate {
115                    break;
116                }
117            }
118
119            used_x_coords.push(x.clone());
120            let y = poly.value_at(&x);
121            points.push(PointInFiniteField::new(x, y));
122        }
123
124        // Integrity hash: first 8 hex chars of hash160(compressed pubkey) as hex
125        // TS SDK: this.toPublicKey().toHash('hex').slice(0, 8)
126        let pubkey = key.to_public_key();
127        let pubkey_hash = pubkey.to_hash();
128        let integrity = to_hex(&pubkey_hash);
129        let integrity = integrity[..8].to_string();
130
131        Ok(KeyShares {
132            points,
133            threshold,
134            integrity,
135        })
136    }
137
138    /// Convert shares to backup format strings.
139    ///
140    /// Each share is formatted as: "base58(x).base58(y).threshold.integrity"
141    pub fn to_backup_format(&self) -> Vec<String> {
142        self.points
143            .iter()
144            .map(|share| {
145                format!(
146                    "{}.{}.{}",
147                    share.to_string_repr(),
148                    self.threshold,
149                    self.integrity
150                )
151            })
152            .collect()
153    }
154
155    /// Parse shares from backup format strings.
156    ///
157    /// Each share must be in format: "base58(x).base58(y).threshold.integrity"
158    pub fn from_backup_format(shares: &[String]) -> Result<Self, PrimitivesError> {
159        if shares.is_empty() {
160            return Err(PrimitivesError::InvalidFormat(
161                "No shares provided".to_string(),
162            ));
163        }
164
165        let mut threshold = 0usize;
166        let mut integrity = String::new();
167        let mut points = Vec::with_capacity(shares.len());
168
169        for (idx, share) in shares.iter().enumerate() {
170            let parts: Vec<&str> = share.split('.').collect();
171            if parts.len() != 4 {
172                return Err(PrimitivesError::InvalidFormat(format!(
173                    "Invalid share format in share {}. Expected format: \"x.y.t.i\" - received {}",
174                    idx, share
175                )));
176            }
177
178            let t_str = parts[2];
179            let i_str = parts[3];
180
181            let t: usize = t_str.parse().map_err(|_| {
182                PrimitivesError::InvalidFormat(format!(
183                    "Threshold not a valid number in share {}",
184                    idx
185                ))
186            })?;
187
188            if idx != 0 && threshold != t {
189                return Err(PrimitivesError::InvalidFormat(format!(
190                    "Threshold mismatch in share {}",
191                    idx
192                )));
193            }
194            if idx != 0 && integrity != i_str {
195                return Err(PrimitivesError::InvalidFormat(format!(
196                    "Integrity mismatch in share {}",
197                    idx
198                )));
199            }
200
201            threshold = t;
202            integrity = i_str.to_string();
203
204            let point_str = format!("{}.{}", parts[0], parts[1]);
205            let point = PointInFiniteField::from_string_repr(&point_str)?;
206            points.push(point);
207        }
208
209        Ok(KeyShares::new(points, threshold, integrity))
210    }
211
212    /// Reconstruct a private key from shares.
213    ///
214    /// Requires at least `threshold` shares. Uses Lagrange interpolation
215    /// to recover the secret (polynomial value at x=0).
216    ///
217    /// # Arguments
218    /// * `shares` - The KeyShares containing points, threshold, and integrity hash
219    ///
220    /// # Returns
221    /// The reconstructed PrivateKey, validated against the integrity hash.
222    pub fn reconstruct(shares: &KeyShares) -> Result<PrivateKey, PrimitivesError> {
223        let threshold = shares.threshold;
224
225        if threshold < 2 {
226            return Err(PrimitivesError::ThresholdError(
227                "threshold must be at least 2".to_string(),
228            ));
229        }
230
231        if shares.points.len() < threshold {
232            return Err(PrimitivesError::ThresholdError(format!(
233                "At least {} shares are required to reconstruct the private key",
234                threshold
235            )));
236        }
237
238        // Check for duplicate x values
239        for i in 0..threshold {
240            for j in (i + 1)..threshold {
241                if shares.points[i].x.cmp(&shares.points[j].x) == 0 {
242                    return Err(PrimitivesError::ThresholdError(
243                        "Duplicate share detected, each must be unique.".to_string(),
244                    ));
245                }
246            }
247        }
248
249        // Lagrange interpolation at x=0
250        let poly = Polynomial::new(shares.points.clone(), Some(threshold));
251        let secret = poly.value_at(&BigNumber::zero());
252
253        // Create PrivateKey from recovered secret
254        let secret_bytes = secret.to_array(Endian::Big, Some(32));
255        let key = PrivateKey::from_bytes(&secret_bytes)?;
256
257        // Validate integrity hash
258        let pubkey = key.to_public_key();
259        let pubkey_hash = pubkey.to_hash();
260        let integrity_hash = to_hex(&pubkey_hash);
261        let integrity_check = &integrity_hash[..8];
262
263        if integrity_check != shares.integrity {
264            return Err(PrimitivesError::ThresholdError(
265                "Integrity hash mismatch".to_string(),
266            ));
267        }
268
269        Ok(key)
270    }
271}
272
273// ---------------------------------------------------------------------------
274// Tests
275// ---------------------------------------------------------------------------
276
#[cfg(test)]
mod tests {
    use super::*;

    /// Replace the dot-separated field at index `field` of a backup string.
    fn replace_field(backup: &str, field: usize, value: &str) -> String {
        let mut parts: Vec<&str> = backup.split('.').collect();
        parts[field] = value;
        parts.join(".")
    }

    #[test]
    fn test_key_shares_split_produces_correct_count() {
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 2, 5).unwrap();
        assert_eq!(shares.points.len(), 5);
        assert_eq!(shares.threshold, 2);
        assert!(!shares.integrity.is_empty());
    }

    #[test]
    fn test_key_shares_split_and_reconstruct_threshold_2_of_3() {
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 2, 3).unwrap();

        // Use first 2 shares (indices 0, 1)
        let subset = KeyShares::new(
            shares.points[..2].to_vec(),
            shares.threshold,
            shares.integrity.clone(),
        );
        let recovered = KeyShares::reconstruct(&subset).unwrap();
        assert_eq!(
            recovered.to_hex(),
            key.to_hex(),
            "Should recover original key from 2 of 3 shares"
        );
    }

    #[test]
    fn test_key_shares_split_and_reconstruct_threshold_3_of_5() {
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 3, 5).unwrap();

        // Use shares at indices 0, 2, 4
        let subset = KeyShares::new(
            vec![
                shares.points[0].clone(),
                shares.points[2].clone(),
                shares.points[4].clone(),
            ],
            shares.threshold,
            shares.integrity.clone(),
        );
        let recovered = KeyShares::reconstruct(&subset).unwrap();
        assert_eq!(
            recovered.to_hex(),
            key.to_hex(),
            "Should recover original key from 3 of 5 shares"
        );
    }

    #[test]
    fn test_key_shares_reconstruct_with_all_shares() {
        // Supplying more than `threshold` shares must still recover the key.
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 3, 5).unwrap();
        let recovered = KeyShares::reconstruct(&shares).unwrap();
        assert_eq!(recovered.to_hex(), key.to_hex());
    }

    #[test]
    fn test_key_shares_insufficient_shares_fails() {
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 3, 5).unwrap();

        // Only provide 2 shares when threshold is 3
        let subset = KeyShares::new(
            shares.points[..2].to_vec(),
            shares.threshold,
            shares.integrity.clone(),
        );
        let result = KeyShares::reconstruct(&subset);
        assert!(
            result.is_err(),
            "Should fail with fewer than threshold shares"
        );
    }

    #[test]
    fn test_key_shares_reconstruct_rejects_threshold_below_two() {
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 2, 3).unwrap();
        let weak = KeyShares::new(shares.points.clone(), 1, shares.integrity.clone());
        assert!(KeyShares::reconstruct(&weak).is_err());
    }

    #[test]
    fn test_key_shares_duplicate_share_rejected() {
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 2, 3).unwrap();
        let duplicated = KeyShares::new(
            vec![shares.points[0].clone(), shares.points[0].clone()],
            shares.threshold,
            shares.integrity.clone(),
        );
        assert!(
            KeyShares::reconstruct(&duplicated).is_err(),
            "Should reject two shares with the same x-coordinate"
        );
    }

    #[test]
    fn test_key_shares_threshold_validation() {
        let key = PrivateKey::from_random().unwrap();

        // threshold < 2
        assert!(KeyShares::split(&key, 1, 3).is_err());

        // total < 2
        assert!(KeyShares::split(&key, 2, 1).is_err());

        // threshold > total
        assert!(KeyShares::split(&key, 4, 3).is_err());
    }

    #[test]
    fn test_key_shares_backup_format_roundtrip() {
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 2, 3).unwrap();

        // Convert to backup format
        let backup = shares.to_backup_format();
        assert_eq!(backup.len(), 3);

        // Each backup string should have 4 dot-separated parts
        for b in &backup {
            let parts: Vec<&str> = b.split('.').collect();
            assert_eq!(parts.len(), 4, "Backup format should be x.y.t.i");
        }

        // Parse back and reconstruct
        let parsed = KeyShares::from_backup_format(&backup[..2]).unwrap();
        let recovered = KeyShares::reconstruct(&parsed).unwrap();
        assert_eq!(
            recovered.to_hex(),
            key.to_hex(),
            "Should recover from backup format"
        );
    }

    #[test]
    fn test_key_shares_from_backup_format_rejects_empty_input() {
        let empty: Vec<String> = Vec::new();
        assert!(KeyShares::from_backup_format(&empty).is_err());
    }

    #[test]
    fn test_key_shares_from_backup_format_threshold_mismatch() {
        let key = PrivateKey::from_random().unwrap();
        let backup = KeyShares::split(&key, 2, 3).unwrap().to_backup_format();
        let tampered = vec![backup[0].clone(), replace_field(&backup[1], 2, "3")];
        assert!(
            KeyShares::from_backup_format(&tampered).is_err(),
            "Should reject shares that disagree on threshold"
        );
    }

    #[test]
    fn test_key_shares_from_backup_format_integrity_mismatch() {
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 2, 3).unwrap();
        let backup = shares.to_backup_format();
        // Pick a tag guaranteed to differ from the real one.
        let other = if shares.integrity == "00000000" {
            "ffffffff"
        } else {
            "00000000"
        };
        let tampered = vec![backup[0].clone(), replace_field(&backup[1], 3, other)];
        assert!(
            KeyShares::from_backup_format(&tampered).is_err(),
            "Should reject shares that disagree on integrity"
        );
    }

    #[test]
    fn test_key_shares_from_backup_format_non_numeric_threshold() {
        let key = PrivateKey::from_random().unwrap();
        let backup = KeyShares::split(&key, 2, 3).unwrap().to_backup_format();
        let tampered = vec![replace_field(&backup[0], 2, "two")];
        assert!(KeyShares::from_backup_format(&tampered).is_err());
    }

    #[test]
    fn test_key_shares_integrity_hash() {
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 2, 3).unwrap();

        // Integrity should be 8 hex chars
        assert_eq!(
            shares.integrity.len(),
            8,
            "Integrity hash should be 8 hex chars"
        );

        // All shares in backup format should have the same integrity
        let backup = shares.to_backup_format();
        for b in &backup {
            assert!(b.ends_with(&shares.integrity));
        }
    }

    #[test]
    fn test_key_shares_integrity_mismatch_detected() {
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 2, 3).unwrap();

        // Corrupt integrity
        let corrupt_shares = KeyShares::new(
            shares.points[..2].to_vec(),
            shares.threshold,
            "deadbeef".to_string(), // wrong integrity
        );
        let result = KeyShares::reconstruct(&corrupt_shares);
        assert!(result.is_err(), "Should fail on integrity mismatch");
    }

    #[test]
    fn test_key_shares_known_key() {
        // Use a known key for reproducibility
        let key = PrivateKey::from_hex(
            "e8f32e723decf4051aefac8e2c93c9c5b214313817cdb01a1494b917c8436b35",
        )
        .unwrap();

        let shares = KeyShares::split(&key, 2, 3).unwrap();
        let subset = KeyShares::new(
            shares.points[1..3].to_vec(),
            shares.threshold,
            shares.integrity.clone(),
        );
        let recovered = KeyShares::reconstruct(&subset).unwrap();
        assert_eq!(recovered.to_hex(), key.to_hex(), "Should recover known key");
    }

    #[test]
    fn test_key_shares_invalid_backup_format() {
        let bad = vec!["not.valid".to_string()];
        assert!(KeyShares::from_backup_format(&bad).is_err());
    }

    #[test]
    fn test_key_shares_any_subset_reconstructs() {
        // With threshold=2, total=4, any 2 shares should work
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 2, 4).unwrap();

        // Try all pairs
        for i in 0..4 {
            for j in (i + 1)..4 {
                let subset = KeyShares::new(
                    vec![shares.points[i].clone(), shares.points[j].clone()],
                    shares.threshold,
                    shares.integrity.clone(),
                );
                let recovered = KeyShares::reconstruct(&subset).unwrap();
                assert_eq!(
                    recovered.to_hex(),
                    key.to_hex(),
                    "Shares ({}, {}) should reconstruct",
                    i,
                    j
                );
            }
        }
    }
}