1use crate::primitives::big_number::{BigNumber, Endian};
7use crate::primitives::curve::Curve;
8use crate::primitives::error::PrimitivesError;
9use crate::primitives::hash::sha512_hmac;
10use crate::primitives::polynomial::{PointInFiniteField, Polynomial};
11use crate::primitives::private_key::PrivateKey;
12use crate::primitives::random::random_bytes;
13use crate::primitives::utils::to_hex;
14
15pub struct KeyShares {
24 pub points: Vec<PointInFiniteField>,
25 pub threshold: usize,
26 pub integrity: String,
27}
28
29impl KeyShares {
30 pub fn new(points: Vec<PointInFiniteField>, threshold: usize, integrity: String) -> Self {
32 KeyShares {
33 points,
34 threshold,
35 integrity,
36 }
37 }
38
39 pub fn split(
49 key: &PrivateKey,
50 threshold: usize,
51 total: usize,
52 ) -> Result<Self, PrimitivesError> {
53 if threshold < 2 {
54 return Err(PrimitivesError::ThresholdError(
55 "threshold must be at least 2".to_string(),
56 ));
57 }
58 if total < 2 {
59 return Err(PrimitivesError::ThresholdError(
60 "totalShares must be at least 2".to_string(),
61 ));
62 }
63 if threshold > total {
64 return Err(PrimitivesError::ThresholdError(
65 "threshold should be less than or equal to totalShares".to_string(),
66 ));
67 }
68
69 let curve = Curve::secp256k1();
70 let key_bytes = key.to_bytes();
71 let poly = Polynomial::from_private_key(&key_bytes, threshold);
72
73 let mut points = Vec::with_capacity(total);
74 let mut used_x_coords: Vec<BigNumber> = Vec::new();
75
76 let seed = random_bytes(64);
79
80 for i in 0..total {
81 let mut x: BigNumber;
82 let mut attempts = 0u32;
83
84 loop {
85 let mut counter = Vec::new();
86 counter.push(i as u8);
87 counter.push(attempts as u8);
88 counter.extend_from_slice(&random_bytes(32));
89
90 let h = sha512_hmac(&seed, &counter);
91 x = BigNumber::from_bytes(&h, Endian::Big);
92 x = x
93 .umod(&curve.p)
94 .map_err(|e| PrimitivesError::ArithmeticError(format!("mod p: {}", e)))?;
95
96 attempts += 1;
97 if attempts > 5 {
98 return Err(PrimitivesError::ThresholdError(
99 "Failed to generate unique x coordinate after 5 attempts".to_string(),
100 ));
101 }
102
103 if x.is_zero() {
105 continue;
106 }
107 let mut duplicate = false;
108 for existing in &used_x_coords {
109 if existing.cmp(&x) == 0 {
110 duplicate = true;
111 break;
112 }
113 }
114 if !duplicate {
115 break;
116 }
117 }
118
119 used_x_coords.push(x.clone());
120 let y = poly.value_at(&x);
121 points.push(PointInFiniteField::new(x, y));
122 }
123
124 let pubkey = key.to_public_key();
127 let pubkey_hash = pubkey.to_hash();
128 let integrity = to_hex(&pubkey_hash);
129 let integrity = integrity[..8].to_string();
130
131 Ok(KeyShares {
132 points,
133 threshold,
134 integrity,
135 })
136 }
137
138 pub fn to_backup_format(&self) -> Vec<String> {
142 self.points
143 .iter()
144 .map(|share| {
145 format!(
146 "{}.{}.{}",
147 share.to_string_repr(),
148 self.threshold,
149 self.integrity
150 )
151 })
152 .collect()
153 }
154
155 pub fn from_backup_format(shares: &[String]) -> Result<Self, PrimitivesError> {
159 if shares.is_empty() {
160 return Err(PrimitivesError::InvalidFormat(
161 "No shares provided".to_string(),
162 ));
163 }
164
165 let mut threshold = 0usize;
166 let mut integrity = String::new();
167 let mut points = Vec::with_capacity(shares.len());
168
169 for (idx, share) in shares.iter().enumerate() {
170 let parts: Vec<&str> = share.split('.').collect();
171 if parts.len() != 4 {
172 return Err(PrimitivesError::InvalidFormat(format!(
173 "Invalid share format in share {}. Expected format: \"x.y.t.i\" - received {}",
174 idx, share
175 )));
176 }
177
178 let t_str = parts[2];
179 let i_str = parts[3];
180
181 let t: usize = t_str.parse().map_err(|_| {
182 PrimitivesError::InvalidFormat(format!(
183 "Threshold not a valid number in share {}",
184 idx
185 ))
186 })?;
187
188 if idx != 0 && threshold != t {
189 return Err(PrimitivesError::InvalidFormat(format!(
190 "Threshold mismatch in share {}",
191 idx
192 )));
193 }
194 if idx != 0 && integrity != i_str {
195 return Err(PrimitivesError::InvalidFormat(format!(
196 "Integrity mismatch in share {}",
197 idx
198 )));
199 }
200
201 threshold = t;
202 integrity = i_str.to_string();
203
204 let point_str = format!("{}.{}", parts[0], parts[1]);
205 let point = PointInFiniteField::from_string_repr(&point_str)?;
206 points.push(point);
207 }
208
209 Ok(KeyShares::new(points, threshold, integrity))
210 }
211
212 pub fn reconstruct(shares: &KeyShares) -> Result<PrivateKey, PrimitivesError> {
223 let threshold = shares.threshold;
224
225 if threshold < 2 {
226 return Err(PrimitivesError::ThresholdError(
227 "threshold must be at least 2".to_string(),
228 ));
229 }
230
231 if shares.points.len() < threshold {
232 return Err(PrimitivesError::ThresholdError(format!(
233 "At least {} shares are required to reconstruct the private key",
234 threshold
235 )));
236 }
237
238 for i in 0..threshold {
240 for j in (i + 1)..threshold {
241 if shares.points[i].x.cmp(&shares.points[j].x) == 0 {
242 return Err(PrimitivesError::ThresholdError(
243 "Duplicate share detected, each must be unique.".to_string(),
244 ));
245 }
246 }
247 }
248
249 let poly = Polynomial::new(shares.points.clone(), Some(threshold));
251 let secret = poly.value_at(&BigNumber::zero());
252
253 let secret_bytes = secret.to_array(Endian::Big, Some(32));
255 let key = PrivateKey::from_bytes(&secret_bytes)?;
256
257 let pubkey = key.to_public_key();
259 let pubkey_hash = pubkey.to_hash();
260 let integrity_hash = to_hex(&pubkey_hash);
261 let integrity_check = &integrity_hash[..8];
262
263 if integrity_check != shares.integrity {
264 return Err(PrimitivesError::ThresholdError(
265 "Integrity hash mismatch".to_string(),
266 ));
267 }
268
269 Ok(key)
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_key_shares_split_produces_correct_count() {
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 2, 5).unwrap();
        assert_eq!(shares.points.len(), 5);
        assert_eq!(shares.threshold, 2);
        assert!(!shares.integrity.is_empty());
    }

    #[test]
    fn test_key_shares_split_and_reconstruct_threshold_2_of_3() {
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 2, 3).unwrap();

        let subset = KeyShares::new(
            shares.points[..2].to_vec(),
            shares.threshold,
            shares.integrity.clone(),
        );
        let recovered = KeyShares::reconstruct(&subset).unwrap();
        assert_eq!(
            recovered.to_hex(),
            key.to_hex(),
            "Should recover original key from 2 of 3 shares"
        );
    }

    #[test]
    fn test_key_shares_split_and_reconstruct_threshold_3_of_5() {
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 3, 5).unwrap();

        let subset = KeyShares::new(
            vec![
                shares.points[0].clone(),
                shares.points[2].clone(),
                shares.points[4].clone(),
            ],
            shares.threshold,
            shares.integrity.clone(),
        );
        let recovered = KeyShares::reconstruct(&subset).unwrap();
        assert_eq!(
            recovered.to_hex(),
            key.to_hex(),
            "Should recover original key from 3 of 5 shares"
        );
    }

    #[test]
    fn test_key_shares_insufficient_shares_fails() {
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 3, 5).unwrap();

        let subset = KeyShares::new(
            shares.points[..2].to_vec(),
            shares.threshold,
            shares.integrity.clone(),
        );
        let result = KeyShares::reconstruct(&subset);
        assert!(
            result.is_err(),
            "Should fail with fewer than threshold shares"
        );
    }

    #[test]
    fn test_key_shares_duplicate_shares_rejected() {
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 2, 3).unwrap();

        // The same share twice satisfies the count check but must not be
        // accepted as two independent shares.
        let duplicated = KeyShares::new(
            vec![shares.points[0].clone(), shares.points[0].clone()],
            shares.threshold,
            shares.integrity.clone(),
        );
        assert!(
            KeyShares::reconstruct(&duplicated).is_err(),
            "Should reject duplicate shares"
        );
    }

    #[test]
    fn test_key_shares_threshold_validation() {
        let key = PrivateKey::from_random().unwrap();

        // threshold below 2
        assert!(KeyShares::split(&key, 1, 3).is_err());
        // total below 2
        assert!(KeyShares::split(&key, 2, 1).is_err());
        // threshold above total
        assert!(KeyShares::split(&key, 4, 3).is_err());
    }

    #[test]
    fn test_key_shares_backup_format_roundtrip() {
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 2, 3).unwrap();

        let backup = shares.to_backup_format();
        assert_eq!(backup.len(), 3);

        for b in &backup {
            let parts: Vec<&str> = b.split('.').collect();
            assert_eq!(parts.len(), 4, "Backup format should be x.y.t.i");
        }

        let parsed = KeyShares::from_backup_format(&backup[..2]).unwrap();
        let recovered = KeyShares::reconstruct(&parsed).unwrap();
        assert_eq!(
            recovered.to_hex(),
            key.to_hex(),
            "Should recover from backup format"
        );
    }

    #[test]
    fn test_key_shares_from_backup_format_rejects_empty_input() {
        let empty: Vec<String> = Vec::new();
        assert!(KeyShares::from_backup_format(&empty).is_err());
    }

    #[test]
    fn test_key_shares_from_backup_format_rejects_mismatched_metadata() {
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 2, 3).unwrap();
        let backup = shares.to_backup_format();

        let bad_threshold = format!(
            "{}.{}.{}",
            shares.points[1].to_string_repr(),
            shares.threshold + 1,
            shares.integrity
        );
        assert!(
            KeyShares::from_backup_format(&[backup[0].clone(), bad_threshold]).is_err(),
            "Should reject shares that disagree on threshold"
        );

        // "zzzzzzzz" is not hex, so it can never equal a real integrity tag.
        let bad_integrity = format!(
            "{}.{}.{}",
            shares.points[1].to_string_repr(),
            shares.threshold,
            "zzzzzzzz"
        );
        assert!(
            KeyShares::from_backup_format(&[backup[0].clone(), bad_integrity]).is_err(),
            "Should reject shares that disagree on integrity"
        );
    }

    #[test]
    fn test_key_shares_from_backup_format_rejects_non_numeric_threshold() {
        let bad = vec!["x.y.notanumber.deadbeef".to_string()];
        assert!(KeyShares::from_backup_format(&bad).is_err());
    }

    #[test]
    fn test_key_shares_integrity_hash() {
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 2, 3).unwrap();

        assert_eq!(
            shares.integrity.len(),
            8,
            "Integrity hash should be 8 hex chars"
        );

        let backup = shares.to_backup_format();
        for b in &backup {
            assert!(b.ends_with(&shares.integrity));
        }
    }

    #[test]
    fn test_key_shares_integrity_mismatch_detected() {
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 2, 3).unwrap();

        let corrupt_shares = KeyShares::new(
            shares.points[..2].to_vec(),
            shares.threshold,
            "deadbeef".to_string(),
        );
        let result = KeyShares::reconstruct(&corrupt_shares);
        assert!(result.is_err(), "Should fail on integrity mismatch");
    }

    #[test]
    fn test_key_shares_known_key() {
        let key = PrivateKey::from_hex(
            "e8f32e723decf4051aefac8e2c93c9c5b214313817cdb01a1494b917c8436b35",
        )
        .unwrap();

        let shares = KeyShares::split(&key, 2, 3).unwrap();
        let subset = KeyShares::new(
            shares.points[1..3].to_vec(),
            shares.threshold,
            shares.integrity.clone(),
        );
        let recovered = KeyShares::reconstruct(&subset).unwrap();
        assert_eq!(recovered.to_hex(), key.to_hex(), "Should recover known key");
    }

    #[test]
    fn test_key_shares_invalid_backup_format() {
        let bad = vec!["not.valid".to_string()];
        assert!(KeyShares::from_backup_format(&bad).is_err());
    }

    #[test]
    fn test_key_shares_any_subset_reconstructs() {
        let key = PrivateKey::from_random().unwrap();
        let shares = KeyShares::split(&key, 2, 4).unwrap();

        for i in 0..4 {
            for j in (i + 1)..4 {
                let subset = KeyShares::new(
                    vec![shares.points[i].clone(), shares.points[j].clone()],
                    shares.threshold,
                    shares.integrity.clone(),
                );
                let recovered = KeyShares::reconstruct(&subset).unwrap();
                assert_eq!(
                    recovered.to_hex(),
                    key.to_hex(),
                    "Shares ({}, {}) should reconstruct",
                    i,
                    j
                );
            }
        }
    }
}