1use rand::RngCore;
16use thiserror::Error;
17use zeroize::Zeroize;
18
19#[derive(Debug, Error)]
21pub enum ShamirError {
22 #[error("Invalid threshold: M must be > 0 and <= N")]
23 InvalidThreshold,
24 #[error("Not enough shares to reconstruct secret (need {needed}, got {got})")]
25 InsufficientShares { needed: usize, got: usize },
26 #[error("Duplicate share indices")]
27 DuplicateIndices,
28 #[error("Invalid share index (must be 1-255)")]
29 InvalidShareIndex,
30 #[error("Shares have different lengths")]
31 InconsistentShareLengths,
32 #[error("Secret is empty")]
33 EmptySecret,
34}
35
36pub type ShamirResult<T> = Result<T, ShamirError>;
37
38#[derive(Clone, Debug, Zeroize)]
40#[zeroize(drop)]
41pub struct Share {
42 pub index: u8,
44 pub data: Vec<u8>,
46}
47
48impl Share {
49 pub fn new(index: u8, data: Vec<u8>) -> ShamirResult<Self> {
51 if index == 0 {
52 return Err(ShamirError::InvalidShareIndex);
53 }
54 Ok(Self { index, data })
55 }
56
57 pub fn to_bytes(&self) -> Vec<u8> {
59 let mut bytes = Vec::with_capacity(1 + self.data.len());
60 bytes.push(self.index);
61 bytes.extend_from_slice(&self.data);
62 bytes
63 }
64
65 pub fn from_bytes(bytes: &[u8]) -> ShamirResult<Self> {
67 if bytes.is_empty() {
68 return Err(ShamirError::InvalidShareIndex);
69 }
70 let index = bytes[0];
71 let data = bytes[1..].to_vec();
72 Share::new(index, data)
73 }
74}
75
76pub fn split(secret: &[u8], threshold: usize, num_shares: usize) -> ShamirResult<Vec<Share>> {
80 if secret.is_empty() {
81 return Err(ShamirError::EmptySecret);
82 }
83 if threshold == 0 || threshold > num_shares || num_shares > 255 {
84 return Err(ShamirError::InvalidThreshold);
85 }
86
87 let mut shares = Vec::with_capacity(num_shares);
88 let mut rng = rand::thread_rng();
89
90 for (byte_idx, &secret_byte) in secret.iter().enumerate() {
92 let mut coeffs = vec![secret_byte];
94 for _ in 1..threshold {
95 let mut byte = [0u8; 1];
96 rng.fill_bytes(&mut byte);
97 coeffs.push(byte[0]);
98 }
99
100 for share_idx in 0..num_shares {
102 let x = (share_idx + 1) as u8;
103 let y = eval_poly(&coeffs, x);
104
105 if byte_idx == 0 {
106 shares.push(Share::new(x, vec![y])?);
108 } else {
109 shares[share_idx].data.push(y);
111 }
112 }
113 }
114
115 Ok(shares)
116}
117
118pub fn reconstruct(shares: &[Share]) -> ShamirResult<Vec<u8>> {
122 if shares.is_empty() {
123 return Err(ShamirError::InsufficientShares { needed: 1, got: 0 });
124 }
125
126 let share_len = shares[0].data.len();
128 if !shares.iter().all(|s| s.data.len() == share_len) {
129 return Err(ShamirError::InconsistentShareLengths);
130 }
131
132 let mut indices = shares.iter().map(|s| s.index).collect::<Vec<_>>();
134 indices.sort_unstable();
135 if indices.windows(2).any(|w| w[0] == w[1]) {
136 return Err(ShamirError::DuplicateIndices);
137 }
138
139 let mut secret = Vec::with_capacity(share_len);
140
141 for byte_idx in 0..share_len {
143 let points: Vec<(u8, u8)> = shares
144 .iter()
145 .map(|share| (share.index, share.data[byte_idx]))
146 .collect();
147
148 let secret_byte = lagrange_interpolate(&points, 0);
149 secret.push(secret_byte);
150 }
151
152 Ok(secret)
153}
154
155fn eval_poly(coeffs: &[u8], x: u8) -> u8 {
157 let mut result = 0u8;
158 for &coeff in coeffs.iter().rev() {
159 result = gf256_add(gf256_mul(result, x), coeff);
160 }
161 result
162}
163
164fn lagrange_interpolate(points: &[(u8, u8)], x: u8) -> u8 {
166 let mut result = 0u8;
167
168 for (i, &(xi, yi)) in points.iter().enumerate() {
169 let mut basis = 1u8;
170
171 for (j, &(xj, _)) in points.iter().enumerate() {
172 if i != j {
173 let numerator = gf256_sub(x, xj);
174 let denominator = gf256_sub(xi, xj);
175 let inv_denom = gf256_inv(denominator);
176 basis = gf256_mul(basis, gf256_mul(numerator, inv_denom));
177 }
178 }
179
180 result = gf256_add(result, gf256_mul(basis, yi));
181 }
182
183 result
184}
185
// Reduction polynomial for GF(2^8): bits 8, 4, 3, 1, 0 give
// x^8 + x^4 + x^3 + x + 1 (the same value as 0x11B).
const GF256_POLY: u16 = 0b1_0001_1011;
188
/// Field addition in GF(2^8), which is a bitwise XOR of the two bytes.
#[inline]
fn gf256_add(a: u8, b: u8) -> u8 {
    b ^ a
}
194
/// Field subtraction in GF(2^8). In characteristic 2 every element is its own
/// additive inverse, so this is the same XOR as addition.
#[inline]
fn gf256_sub(a: u8, b: u8) -> u8 {
    b ^ a
}
200
201fn gf256_mul(a: u8, b: u8) -> u8 {
203 if a == 0 || b == 0 {
204 return 0;
205 }
206
207 let mut result = 0u16;
208 let mut a = a as u16;
209 let mut b = b as u16;
210
211 for _ in 0..8 {
212 if b & 1 != 0 {
213 result ^= a;
214 }
215 let carry = a & 0x80;
216 a <<= 1;
217 if carry != 0 {
218 a ^= GF256_POLY;
219 }
220 b >>= 1;
221 }
222
223 (result & 0xFF) as u8
224}
225
226fn gf256_inv(a: u8) -> u8 {
228 if a == 0 {
229 panic!("Cannot invert zero in GF(256)");
230 }
231
232 let mut result = 1u8;
234 let mut base = a;
235
236 for i in 0..8 {
238 if 254 & (1 << i) != 0 {
239 result = gf256_mul(result, base);
240 }
241 base = gf256_mul(base, base);
242 }
243
244 result
245}
246
247pub fn split_key_32(
249 key: &[u8; 32],
250 threshold: usize,
251 num_shares: usize,
252) -> ShamirResult<Vec<Share>> {
253 split(key, threshold, num_shares)
254}
255
256pub fn reconstruct_key_32(shares: &[Share]) -> ShamirResult<[u8; 32]> {
258 let secret = reconstruct(shares)?;
259 if secret.len() != 32 {
260 return Err(ShamirError::InconsistentShareLengths);
261 }
262 let mut key = [0u8; 32];
263 key.copy_from_slice(&secret);
264 Ok(key)
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use zeroize::Zeroize;

    #[test]
    fn test_split_and_reconstruct() {
        let secret = b"This is a secret message!";
        let shares = split(secret, 3, 5).unwrap();

        assert_eq!(shares.len(), 5);

        // Every run of three consecutive shares meets the threshold.
        for window in shares.windows(3) {
            let reconstructed = reconstruct(window).unwrap();
            assert_eq!(&reconstructed, secret);
        }
    }

    #[test]
    fn test_split_assigns_indices_and_lengths() {
        let secret = b"abc";
        let shares = split(secret, 2, 4).unwrap();

        let indices: Vec<u8> = shares.iter().map(|s| s.index).collect();
        assert_eq!(indices, vec![1, 2, 3, 4]);
        assert!(shares.iter().all(|s| s.data.len() == secret.len()));
    }

    #[test]
    fn test_insufficient_shares() {
        let secret = b"secret";
        let shares = split(secret, 3, 5).unwrap();

        // Shares don't record the threshold, so reconstruction below it
        // "succeeds" with the wrong bytes. Per byte, the interpolated value
        // equals the secret only if the random top coefficient is 0
        // (probability 1/256), so all six bytes matching has probability 2^-48.
        let result = reconstruct(&shares[0..2]).unwrap();
        assert_eq!(result.len(), secret.len());
        assert_ne!(&result[..], &secret[..]);
    }

    #[test]
    fn test_32_byte_key() {
        let key = [42u8; 32];
        let shares = split_key_32(&key, 2, 3).unwrap();

        assert_eq!(shares.len(), 3);

        let reconstructed = reconstruct_key_32(&shares[0..2]).unwrap();
        assert_eq!(reconstructed, key);

        let reconstructed2 = reconstruct_key_32(&shares).unwrap();
        assert_eq!(reconstructed2, key);
    }

    #[test]
    fn test_invalid_threshold() {
        let secret = b"secret";

        // Threshold of zero.
        assert!(matches!(split(secret, 0, 5), Err(ShamirError::InvalidThreshold)));
        // Threshold above the number of shares.
        assert!(matches!(split(secret, 6, 5), Err(ShamirError::InvalidThreshold)));
        // More shares than there are non-zero u8 indices.
        assert!(matches!(split(secret, 2, 256), Err(ShamirError::InvalidThreshold)));
    }

    #[test]
    fn test_duplicate_indices() {
        let secret = b"secret";
        let shares = split(secret, 2, 3).unwrap();

        let dup_shares = vec![shares[0].clone(), shares[0].clone()];
        assert!(matches!(
            reconstruct(&dup_shares),
            Err(ShamirError::DuplicateIndices)
        ));
    }

    #[test]
    fn test_inconsistent_share_lengths() {
        let shares = vec![
            Share::new(1, vec![1, 2]).unwrap(),
            Share::new(2, vec![3]).unwrap(),
        ];
        assert!(matches!(
            reconstruct(&shares),
            Err(ShamirError::InconsistentShareLengths)
        ));
    }

    #[test]
    fn test_empty_share_list() {
        assert!(matches!(
            reconstruct(&[]),
            Err(ShamirError::InsufficientShares { needed: 1, got: 0 })
        ));
    }

    #[test]
    fn test_share_serialization() {
        let secret = b"test";
        let shares = split(secret, 2, 3).unwrap();

        for share in &shares {
            let bytes = share.to_bytes();
            let deserialized = Share::from_bytes(&bytes).unwrap();
            assert_eq!(deserialized.index, share.index);
            assert_eq!(deserialized.data, share.data);
        }
    }

    #[test]
    fn test_share_rejects_index_zero() {
        assert!(matches!(Share::new(0, vec![1]), Err(ShamirError::InvalidShareIndex)));
        assert!(matches!(Share::from_bytes(&[]), Err(ShamirError::InvalidShareIndex)));
        assert!(matches!(
            Share::from_bytes(&[0, 7]),
            Err(ShamirError::InvalidShareIndex)
        ));
    }

    #[test]
    fn test_different_combinations() {
        let secret = b"0123456789abcdef";
        let shares = split(secret, 3, 6).unwrap();

        let combo1 = vec![shares[0].clone(), shares[2].clone(), shares[4].clone()];
        let combo2 = vec![shares[1].clone(), shares[3].clone(), shares[5].clone()];

        let combinations: Vec<&[Share]> = vec![
            &shares[0..3],
            &shares[1..4],
            &shares[2..5],
            &shares[3..6],
            &combo1,
            &combo2,
        ];

        for combo in combinations {
            let reconstructed = reconstruct(combo).unwrap();
            assert_eq!(&reconstructed, secret);
        }
    }

    #[test]
    fn test_gf256_arithmetic() {
        assert_eq!(gf256_add(5, 3), 5 ^ 3);
        assert_eq!(gf256_sub(7, 2), 7 ^ 2);

        assert_eq!(gf256_mul(42, 1), 42);
        assert_eq!(gf256_mul(0, 0x57), 0);
        // Known-answer vectors for this field (FIPS-197, section 4.2).
        assert_eq!(gf256_mul(0x57, 0x83), 0xC1);
        assert_eq!(gf256_mul(0x57, 0x13), 0xFE);

        for x in 1u8..=255 {
            let inv = gf256_inv(x);
            assert_eq!(gf256_mul(x, inv), 1);
        }
    }

    #[test]
    fn test_empty_secret() {
        assert!(matches!(split(&[], 2, 3), Err(ShamirError::EmptySecret)));
    }

    #[test]
    fn test_share_zeroize() {
        let mut share = Share::new(1, vec![1, 2, 3]).unwrap();
        share.zeroize();
        assert_eq!(share.index, 0);
        assert!(share.data.is_empty());
    }

    #[test]
    fn test_threshold_one() {
        let secret = b"simple";
        let shares = split(secret, 1, 3).unwrap();

        // With threshold 1 every share on its own recovers the secret.
        for share in &shares {
            let reconstructed = reconstruct(std::slice::from_ref(share)).unwrap();
            assert_eq!(&reconstructed, secret);
        }
    }

    #[test]
    fn test_large_secret() {
        let secret = vec![0xAAu8; 1024];
        let shares = split(&secret, 5, 10).unwrap();

        let reconstructed = reconstruct(&shares[0..5]).unwrap();
        assert_eq!(reconstructed, secret);
    }
}