chie_crypto/
shamir.rs

//! Shamir's Secret Sharing for secure key backup and recovery.
//!
//! This module implements (M, N) threshold secret sharing (1 <= M <= N <= 255) where:
//! - A secret is split into N shares
//! - Any M shares can reconstruct the secret
//! - Fewer than M shares reveal nothing about the secret
//!
//! Suitable for distributed key backup, multi-party authorization, etc.
//!
//! Security properties:
//! - Information-theoretic security: with fewer than M shares, no amount of computing power helps
//! - Each share's data is the same size as the secret (serialized shares add a 1-byte index)
//! - Any M-1 or fewer shares are uniformly random and independent of the secret
14
15use rand::RngCore;
16use thiserror::Error;
17use zeroize::Zeroize;
18
19/// Shamir secret sharing errors.
20#[derive(Debug, Error)]
21pub enum ShamirError {
22    #[error("Invalid threshold: M must be > 0 and <= N")]
23    InvalidThreshold,
24    #[error("Not enough shares to reconstruct secret (need {needed}, got {got})")]
25    InsufficientShares { needed: usize, got: usize },
26    #[error("Duplicate share indices")]
27    DuplicateIndices,
28    #[error("Invalid share index (must be 1-255)")]
29    InvalidShareIndex,
30    #[error("Shares have different lengths")]
31    InconsistentShareLengths,
32    #[error("Secret is empty")]
33    EmptySecret,
34}
35
36pub type ShamirResult<T> = Result<T, ShamirError>;
37
38/// A single share of a secret.
39#[derive(Clone, Debug, Zeroize)]
40#[zeroize(drop)]
41pub struct Share {
42    /// Share index (1-255)
43    pub index: u8,
44    /// Share data
45    pub data: Vec<u8>,
46}
47
48impl Share {
49    /// Create a new share.
50    pub fn new(index: u8, data: Vec<u8>) -> ShamirResult<Self> {
51        if index == 0 {
52            return Err(ShamirError::InvalidShareIndex);
53        }
54        Ok(Self { index, data })
55    }
56
57    /// Serialize to bytes (index || data).
58    pub fn to_bytes(&self) -> Vec<u8> {
59        let mut bytes = Vec::with_capacity(1 + self.data.len());
60        bytes.push(self.index);
61        bytes.extend_from_slice(&self.data);
62        bytes
63    }
64
65    /// Deserialize from bytes.
66    pub fn from_bytes(bytes: &[u8]) -> ShamirResult<Self> {
67        if bytes.is_empty() {
68            return Err(ShamirError::InvalidShareIndex);
69        }
70        let index = bytes[0];
71        let data = bytes[1..].to_vec();
72        Share::new(index, data)
73    }
74}
75
76/// Split a secret into N shares with threshold M.
77///
78/// Any M shares can reconstruct the secret, but M-1 or fewer reveal nothing.
79pub fn split(secret: &[u8], threshold: usize, num_shares: usize) -> ShamirResult<Vec<Share>> {
80    if secret.is_empty() {
81        return Err(ShamirError::EmptySecret);
82    }
83    if threshold == 0 || threshold > num_shares || num_shares > 255 {
84        return Err(ShamirError::InvalidThreshold);
85    }
86
87    let mut shares = Vec::with_capacity(num_shares);
88    let mut rng = rand::thread_rng();
89
90    // Split each byte independently using Shamir's scheme over GF(256)
91    for (byte_idx, &secret_byte) in secret.iter().enumerate() {
92        // Generate random polynomial coefficients (degree = threshold - 1)
93        let mut coeffs = vec![secret_byte];
94        for _ in 1..threshold {
95            let mut byte = [0u8; 1];
96            rng.fill_bytes(&mut byte);
97            coeffs.push(byte[0]);
98        }
99
100        // Evaluate polynomial at points 1..=num_shares
101        for share_idx in 0..num_shares {
102            let x = (share_idx + 1) as u8;
103            let y = eval_poly(&coeffs, x);
104
105            if byte_idx == 0 {
106                // First byte: create new share
107                shares.push(Share::new(x, vec![y])?);
108            } else {
109                // Subsequent bytes: append to existing share
110                shares[share_idx].data.push(y);
111            }
112        }
113    }
114
115    Ok(shares)
116}
117
118/// Reconstruct secret from M or more shares.
119///
120/// Returns error if fewer than threshold shares are provided.
121pub fn reconstruct(shares: &[Share]) -> ShamirResult<Vec<u8>> {
122    if shares.is_empty() {
123        return Err(ShamirError::InsufficientShares { needed: 1, got: 0 });
124    }
125
126    // Verify all shares have the same length
127    let share_len = shares[0].data.len();
128    if !shares.iter().all(|s| s.data.len() == share_len) {
129        return Err(ShamirError::InconsistentShareLengths);
130    }
131
132    // Verify no duplicate indices
133    let mut indices = shares.iter().map(|s| s.index).collect::<Vec<_>>();
134    indices.sort_unstable();
135    if indices.windows(2).any(|w| w[0] == w[1]) {
136        return Err(ShamirError::DuplicateIndices);
137    }
138
139    let mut secret = Vec::with_capacity(share_len);
140
141    // Reconstruct each byte independently
142    for byte_idx in 0..share_len {
143        let points: Vec<(u8, u8)> = shares
144            .iter()
145            .map(|share| (share.index, share.data[byte_idx]))
146            .collect();
147
148        let secret_byte = lagrange_interpolate(&points, 0);
149        secret.push(secret_byte);
150    }
151
152    Ok(secret)
153}
154
155/// Evaluate polynomial at x using Horner's method in GF(256).
156fn eval_poly(coeffs: &[u8], x: u8) -> u8 {
157    let mut result = 0u8;
158    for &coeff in coeffs.iter().rev() {
159        result = gf256_add(gf256_mul(result, x), coeff);
160    }
161    result
162}
163
164/// Lagrange interpolation in GF(256) to find f(x).
165fn lagrange_interpolate(points: &[(u8, u8)], x: u8) -> u8 {
166    let mut result = 0u8;
167
168    for (i, &(xi, yi)) in points.iter().enumerate() {
169        let mut basis = 1u8;
170
171        for (j, &(xj, _)) in points.iter().enumerate() {
172            if i != j {
173                let numerator = gf256_sub(x, xj);
174                let denominator = gf256_sub(xi, xj);
175                let inv_denom = gf256_inv(denominator);
176                basis = gf256_mul(basis, gf256_mul(numerator, inv_denom));
177            }
178        }
179
180        result = gf256_add(result, gf256_mul(basis, yi));
181    }
182
183    result
184}
185
// GF(256) arithmetic using the AES reduction polynomial x^8 + x^4 + x^3 + x + 1.
const GF256_POLY: u16 = 0x11B;

/// Addition in GF(256) is XOR.
#[inline]
fn gf256_add(a: u8, b: u8) -> u8 {
    a ^ b
}

/// Subtraction in GF(256) is also XOR (every element is its own additive inverse).
#[inline]
fn gf256_sub(a: u8, b: u8) -> u8 {
    a ^ b
}

/// Multiplication in GF(256), reduced modulo `GF256_POLY`.
///
/// Written without data-dependent branches. The operands include
/// secret-derived bytes: the running Horner value in `eval_poly` and share
/// bytes in `lagrange_interpolate`. An early return on zero, or an `if` on
/// each bit, would make the running time depend on those values. Instead,
/// all-ones/all-zeros masks derived from each bit select the XORs.
fn gf256_mul(a: u8, b: u8) -> u8 {
    // Low byte of the reduction polynomial. The x^8 term is handled implicitly:
    // the u8 left shift drops the overflowing bit.
    const REDUCE: u8 = (GF256_POLY & 0xFF) as u8;

    let mut a = a;
    let mut b = b;
    let mut result = 0u8;
    for _ in 0..8 {
        // 0xFF when the low bit of `b` is set, 0x00 otherwise.
        let take = 0u8.wrapping_sub(b & 1);
        result ^= a & take;
        // 0xFF when the shift below carries out of bit 7, 0x00 otherwise.
        let carry = 0u8.wrapping_sub(a >> 7);
        a = (a << 1) ^ (REDUCE & carry);
        b >>= 1;
    }
    result
}

/// Multiplicative inverse in GF(256).
///
/// The non-zero elements of GF(256) form a multiplicative group of order 255,
/// so a^255 = 1 and therefore a^254 = a^(-1) for every non-zero `a`. The
/// exponent is a fixed constant, so the square-and-multiply loop performs the
/// same sequence of multiplications for every input.
///
/// # Panics
///
/// Panics if `a` is zero, which has no inverse.
fn gf256_inv(a: u8) -> u8 {
    if a == 0 {
        panic!("Cannot invert zero in GF(256)");
    }

    let mut result = 1u8;
    let mut base = a;

    // Compute a^254 by scanning the exponent's bits from least significant.
    for i in 0..8 {
        if 254 & (1 << i) != 0 {
            result = gf256_mul(result, base);
        }
        base = gf256_mul(base, base);
    }

    result
}
246
247/// Split a 32-byte key using Shamir's secret sharing.
248pub fn split_key_32(
249    key: &[u8; 32],
250    threshold: usize,
251    num_shares: usize,
252) -> ShamirResult<Vec<Share>> {
253    split(key, threshold, num_shares)
254}
255
256/// Reconstruct a 32-byte key from shares.
257pub fn reconstruct_key_32(shares: &[Share]) -> ShamirResult<[u8; 32]> {
258    let secret = reconstruct(shares)?;
259    if secret.len() != 32 {
260        return Err(ShamirError::InconsistentShareLengths);
261    }
262    let mut key = [0u8; 32];
263    key.copy_from_slice(&secret);
264    Ok(key)
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_split_and_reconstruct() {
        let secret = b"This is a secret message!";
        let shares = split(secret, 3, 5).unwrap();

        assert_eq!(shares.len(), 5);

        // Any 3 shares should reconstruct the secret
        let reconstructed = reconstruct(&shares[0..3]).unwrap();
        assert_eq!(&reconstructed, secret);

        let reconstructed2 = reconstruct(&shares[1..4]).unwrap();
        assert_eq!(&reconstructed2, secret);

        let reconstructed3 = reconstruct(&shares[2..5]).unwrap();
        assert_eq!(&reconstructed3, secret);
    }

    #[test]
    fn test_insufficient_shares() {
        let secret = b"secret";
        let shares = split(secret, 3, 5).unwrap();

        // Shares do not record their threshold, so reconstruct cannot detect
        // that too few were given: it still returns a value of the right
        // length. That value is unrelated to the secret (it matches only with
        // probability 2^-48 for this 6-byte secret), so only the length is
        // asserted.
        let result = reconstruct(&shares[0..2]).unwrap();
        assert_eq!(result.len(), secret.len());
    }

    #[test]
    fn test_32_byte_key() {
        let key = [42u8; 32];
        let shares = split_key_32(&key, 2, 3).unwrap();

        assert_eq!(shares.len(), 3);

        // Reconstruct with 2 shares
        let reconstructed = reconstruct_key_32(&shares[0..2]).unwrap();
        assert_eq!(reconstructed, key);

        // Reconstruct with all 3 shares
        let reconstructed2 = reconstruct_key_32(&shares).unwrap();
        assert_eq!(reconstructed2, key);
    }

    #[test]
    fn test_reconstruct_key_32_wrong_length() {
        let shares = split(&[7u8; 16], 2, 3).unwrap();
        assert!(matches!(
            reconstruct_key_32(&shares[0..2]),
            Err(ShamirError::InconsistentShareLengths)
        ));
    }

    #[test]
    fn test_invalid_threshold() {
        let secret = b"secret";

        // Threshold = 0
        assert!(split(secret, 0, 5).is_err());

        // Threshold > num_shares
        assert!(split(secret, 6, 5).is_err());

        // num_shares > 255
        assert!(split(secret, 2, 256).is_err());
    }

    #[test]
    fn test_max_shares() {
        let secret = b"edge";
        let shares = split(secret, 4, 255).unwrap();

        assert_eq!(shares.len(), 255);
        assert_eq!(shares[254].index, 255);

        let reconstructed = reconstruct(&shares[251..255]).unwrap();
        assert_eq!(&reconstructed, secret);
    }

    #[test]
    fn test_duplicate_indices() {
        let secret = b"secret";
        let shares = split(secret, 2, 3).unwrap();

        // Create duplicate by cloning
        let dup_shares = vec![shares[0].clone(), shares[0].clone()];
        assert!(reconstruct(&dup_shares).is_err());
    }

    #[test]
    fn test_reconstruct_rejects_empty_input() {
        assert!(matches!(
            reconstruct(&[]),
            Err(ShamirError::InsufficientShares { needed: 1, got: 0 })
        ));
    }

    #[test]
    fn test_reconstruct_rejects_inconsistent_lengths() {
        let shares = vec![
            Share::new(1, vec![1, 2]).unwrap(),
            Share::new(2, vec![3]).unwrap(),
        ];
        assert!(matches!(
            reconstruct(&shares),
            Err(ShamirError::InconsistentShareLengths)
        ));
    }

    #[test]
    fn test_share_serialization() {
        let secret = b"test";
        let shares = split(secret, 2, 3).unwrap();

        for share in &shares {
            let bytes = share.to_bytes();
            let deserialized = Share::from_bytes(&bytes).unwrap();
            assert_eq!(deserialized.index, share.index);
            assert_eq!(deserialized.data, share.data);
        }
    }

    #[test]
    fn test_share_index_zero_rejected() {
        assert!(matches!(
            Share::new(0, vec![1]),
            Err(ShamirError::InvalidShareIndex)
        ));
        assert!(matches!(
            Share::from_bytes(&[]),
            Err(ShamirError::InvalidShareIndex)
        ));
        assert!(matches!(
            Share::from_bytes(&[0, 7]),
            Err(ShamirError::InvalidShareIndex)
        ));
    }

    #[test]
    fn test_different_combinations() {
        let secret = b"0123456789abcdef";
        let shares = split(secret, 3, 6).unwrap();

        // Test multiple different 3-share combinations
        let combo1 = vec![shares[0].clone(), shares[2].clone(), shares[4].clone()];
        let combo2 = vec![shares[1].clone(), shares[3].clone(), shares[5].clone()];

        let combinations: Vec<&[Share]> = vec![
            &shares[0..3],
            &shares[1..4],
            &shares[2..5],
            &shares[3..6],
            &combo1,
            &combo2,
        ];

        for combo in combinations {
            let reconstructed = reconstruct(combo).unwrap();
            assert_eq!(&reconstructed, secret);
        }
    }

    #[test]
    fn test_gf256_arithmetic() {
        // Test basic properties
        assert_eq!(gf256_add(5, 3), 5 ^ 3);
        assert_eq!(gf256_sub(7, 2), 7 ^ 2);

        // Test multiplicative identity
        assert_eq!(gf256_mul(42, 1), 42);

        // Test multiplicative inverse
        for x in 1u8..=255 {
            let inv = gf256_inv(x);
            assert_eq!(gf256_mul(x, inv), 1);
        }
    }

    #[test]
    fn test_gf256_known_answers() {
        // Worked examples from FIPS-197 (AES), which uses the same field polynomial.
        assert_eq!(gf256_add(0x57, 0x83), 0xD4);
        assert_eq!(gf256_mul(0x57, 0x83), 0xC1);
        assert_eq!(gf256_mul(0x57, 0x13), 0xFE);
        assert_eq!(gf256_inv(0x53), 0xCA);
        assert_eq!(gf256_mul(0, 0x57), 0);
        assert_eq!(gf256_mul(0x57, 0), 0);
    }

    #[test]
    fn test_empty_secret() {
        assert!(split(&[], 2, 3).is_err());
    }

    #[test]
    fn test_share_zeroize() {
        let mut share = Share::new(1, vec![1, 2, 3]).unwrap();
        // Explicit zeroize is observable; zeroize-on-drop uses the same impl.
        zeroize::Zeroize::zeroize(&mut share);
        assert_eq!(share.index, 0);
        assert!(share.data.iter().all(|&b| b == 0));
        drop(share);
    }

    #[test]
    fn test_threshold_one() {
        let secret = b"simple";
        let shares = split(secret, 1, 3).unwrap();

        // Each single share should reconstruct the secret
        for share in &shares {
            let reconstructed = reconstruct(std::slice::from_ref(share)).unwrap();
            assert_eq!(&reconstructed, secret);
        }
    }

    #[test]
    fn test_large_secret() {
        let secret = vec![0xAAu8; 1024]; // 1KB secret
        let shares = split(&secret, 5, 10).unwrap();

        let reconstructed = reconstruct(&shares[0..5]).unwrap();
        assert_eq!(reconstructed, secret);
    }
}