use crate::error::{BitcoinError, Result};
use bitcoin::hashes::{Hash, sha256};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
/// Presumably the bit width of a share-set identifier: identifiers in this file are
/// drawn from `0..=0x7FFF`, which is 15 bits.
///
/// Nothing in this file reads the constant, hence the `dead_code` allowance below.
#[allow(dead_code)]
const SLIP39_ID_LENGTH: usize = 15;
/// Iteration exponent written into every share created by this file.
// NOTE(review): no code here uses this value to drive any key stretching; it is only
// stored in shares and fed into their checksum. Confirm whether that is intended.
const SLIP39_ITERATION_EXPONENT: u8 = 0;
/// Size of the mnemonic word-index space: valid word indices are `0..WORDLIST_SIZE`.
const WORDLIST_SIZE: usize = 1024;
/// A "`threshold` of `total_shares`" splitting policy.
///
/// Both fields are public and the type is deserializable, so a value can exist
/// without having gone through `ShareThreshold::new`; `is_valid` re-checks the
/// same constraints for such values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ShareThreshold {
/// Number of shares needed to reconstruct the secret (at least 1).
pub threshold: u8,
/// Number of shares produced by a split (at most 16, and at least `threshold`).
pub total_shares: u8,
}
impl ShareThreshold {
    /// Builds a policy after checking that `1 <= threshold <= total_shares <= 16`.
    ///
    /// # Errors
    ///
    /// Returns `BitcoinError::InvalidInput` naming the first violated rule, checked
    /// in this order: threshold below 1, threshold above `total_shares`,
    /// `total_shares` above 16.
    pub fn new(threshold: u8, total_shares: u8) -> Result<Self> {
        let violation = if threshold == 0 {
            Some("Threshold must be at least 1")
        } else if threshold > total_shares {
            Some("Threshold cannot exceed total shares")
        } else if total_shares > 16 {
            Some("Total shares cannot exceed 16")
        } else {
            None
        };

        match violation {
            Some(message) => Err(BitcoinError::InvalidInput(message.to_string())),
            None => Ok(Self {
                threshold,
                total_shares,
            }),
        }
    }

    /// Reports whether the current field values satisfy the rules enforced by `new`.
    pub fn is_valid(&self) -> bool {
        let threshold_in_range = (1..=self.total_shares).contains(&self.threshold);
        threshold_in_range && self.total_shares <= 16
    }
}
/// Configuration of one share group used by `MultiGroupGenerator`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroupConfig {
/// Identifier of the group; `GroupConfig::new` only accepts `0..=15`.
pub group_id: u8,
/// Member policy of this group: how many member shares are created and how many
/// are needed. Despite the name, this is not the number of groups required.
pub group_threshold: ShareThreshold,
/// Optional free-form label; not used by any computation in this file.
pub description: Option<String>,
}
impl GroupConfig {
    /// Creates an undescribed group with the given id and member policy.
    ///
    /// # Errors
    ///
    /// Returns `BitcoinError::InvalidInput` when `group_id` is greater than 15.
    pub fn new(group_id: u8, threshold: ShareThreshold) -> Result<Self> {
        match group_id {
            0..=15 => Ok(Self {
                group_id,
                group_threshold: threshold,
                description: None,
            }),
            _ => Err(BitcoinError::InvalidInput(
                "Group ID must be 0-15".to_string(),
            )),
        }
    }

    /// Returns the same configuration with its description replaced.
    pub fn with_description(self, description: impl Into<String>) -> Self {
        Self {
            description: Some(description.into()),
            ..self
        }
    }
}
/// One share of a split secret together with the metadata needed to recombine it.
///
/// The derived `Serialize` implementation writes `share_value` in full, whereas the
/// hand-written `Debug` implementation in this file redacts it.
#[derive(Clone, Serialize, Deserialize)]
pub struct Slip39Share {
/// Identifier common to every share of one split; recovery rejects mixed values.
pub identifier: u16,
/// Iteration exponent recorded at creation time and covered by the checksum.
pub iteration_exponent: u8,
/// Id of the group this share belongs to (0 for single-group splits).
pub group_index: u8,
/// Number of groups required for multi-group recovery.
// NOTE(review): `Slip39Generator::generate_shares` stores the member threshold in
// this field even though it sets `group_count` to 1 — confirm that is intended.
pub group_threshold: u8,
/// Number of groups that were created.
pub group_count: u8,
/// Zero-based position of this share inside its group; the share is the
/// polynomial evaluated at `member_index + 1`.
pub member_index: u8,
/// Number of shares of this group required to recover the group's value.
pub member_threshold: u8,
/// The share bytes; same length as the value that was split.
pub share_value: Vec<u8>,
/// First four bytes (big-endian) of SHA-256 over identifier, iteration exponent,
/// group index, member index and share value. The three threshold/count fields
/// are not covered.
pub checksum: u32,
}
impl Slip39Share {
    /// Payload bits carried by one mnemonic word (`2^10 == WORDLIST_SIZE`).
    const BITS_PER_WORD: usize = 10;
    /// Words occupied by the packed header (40 bits).
    const HEADER_WORDS: usize = 4;
    /// Words occupied by the truncated checksum (30 bits).
    const CHECKSUM_WORDS: usize = 3;
    /// Selects the checksum bits that fit into `CHECKSUM_WORDS` words.
    const CHECKSUM_MASK: u32 = (1 << 30) - 1;

    /// Renders the share as space-separated placeholder words `word<N>`, `N < 1024`.
    ///
    /// The words carry one big-endian bit stream, 10 bits per word:
    ///
    /// | bits | content                                   |
    /// |------|-------------------------------------------|
    /// | 15   | `identifier`                              |
    /// | 5    | `iteration_exponent`                      |
    /// | 4    | `group_index`                             |
    /// | 4    | `group_threshold - 1`                     |
    /// | 4    | `group_count - 1`                         |
    /// | 4    | `member_index`                            |
    /// | 4    | `member_threshold - 1`                    |
    /// | pad  | zero bits so the value ends on a word     |
    /// | 8·n  | `share_value`                             |
    /// | 30   | low 30 bits of `checksum`                 |
    ///
    /// A 16-byte share yields 20 words and a 32-byte share 33 words.
    ///
    /// The encoding is lossless, and [`Slip39Share::from_mnemonic`] restores an
    /// equal share, when every header field fits its bit width (thresholds and
    /// group count in `1..=16`, indices in `0..=15`, identifier below `0x8000`,
    /// iteration exponent below 32) and `share_value` has an even length between
    /// 16 and 32 bytes. Out-of-range fields are truncated to their bit width.
    ///
    /// NOTE(review): the words are numeric placeholders, not the official SLIP-39
    /// word list, and the checksum is the truncated SHA-256 of this file rather
    /// than RS1024, so these mnemonics are not interoperable with other wallets.
    pub fn to_mnemonic(&self) -> String {
        self.to_words().join(" ")
    }

    /// Packs the share into 10-bit word indices; see [`Slip39Share::to_mnemonic`].
    fn to_words(&self) -> Vec<String> {
        let value_bits = self.share_value.len() * 8;
        let value_words = value_bits.div_ceil(Self::BITS_PER_WORD);
        // Zero bits placed before the value so that it ends on a word boundary.
        let padding = value_words * Self::BITS_PER_WORD - value_bits;
        let total_words = Self::HEADER_WORDS + value_words + Self::CHECKSUM_WORDS;

        let mut bits: Vec<bool> = Vec::with_capacity(total_words * Self::BITS_PER_WORD);
        Self::push_bits(&mut bits, u32::from(self.identifier), 15);
        Self::push_bits(&mut bits, u32::from(self.iteration_exponent), 5);
        Self::push_bits(&mut bits, u32::from(self.group_index), 4);
        // Thresholds and counts are stored minus one so that 1..=16 fits in 4 bits.
        Self::push_bits(&mut bits, u32::from(self.group_threshold.wrapping_sub(1)), 4);
        Self::push_bits(&mut bits, u32::from(self.group_count.wrapping_sub(1)), 4);
        Self::push_bits(&mut bits, u32::from(self.member_index), 4);
        Self::push_bits(&mut bits, u32::from(self.member_threshold.wrapping_sub(1)), 4);
        Self::push_bits(&mut bits, 0, padding);
        for &byte in &self.share_value {
            Self::push_bits(&mut bits, u32::from(byte), 8);
        }
        Self::push_bits(&mut bits, self.checksum & Self::CHECKSUM_MASK, 30);

        bits.chunks(Self::BITS_PER_WORD)
            .map(|chunk| {
                let index = chunk
                    .iter()
                    .fold(0usize, |acc, &bit| (acc << 1) | usize::from(bit));
                format!("word{}", index)
            })
            .collect()
    }

    /// Parses a mnemonic produced by [`Slip39Share::to_mnemonic`].
    ///
    /// The returned share carries the full 32-bit checksum, recomputed from the
    /// decoded fields after its low 30 bits were matched against the mnemonic.
    ///
    /// # Errors
    ///
    /// Returns `BitcoinError::InvalidInput` when the mnemonic does not have 20 to
    /// 33 words, when the word count cannot encode a whole number of 16-bit units,
    /// when a word is not a canonical `word<N>` with `N < WORDLIST_SIZE`, when the
    /// padding bits are not zero, or when the checksum does not match.
    pub fn from_mnemonic(mnemonic: &str) -> Result<Self> {
        let words: Vec<&str> = mnemonic.split_whitespace().collect();
        if !(20..=33).contains(&words.len()) {
            return Err(BitcoinError::InvalidInput(
                "Invalid mnemonic length for SLIP 39".to_string(),
            ));
        }

        // The value is a whole number of 16-bit units, so whatever is left over
        // modulo 16 must be padding, and padding never exceeds 8 bits.
        let value_words = words.len() - Self::HEADER_WORDS - Self::CHECKSUM_WORDS;
        let padding = (value_words * Self::BITS_PER_WORD) % 16;
        if padding > 8 {
            return Err(BitcoinError::InvalidInput(
                "Invalid mnemonic length for SLIP 39".to_string(),
            ));
        }

        let mut bits: Vec<bool> = Vec::with_capacity(words.len() * Self::BITS_PER_WORD);
        for word in &words {
            let index = Self::word_index(word).ok_or_else(|| {
                BitcoinError::InvalidInput(format!("Unknown mnemonic word: {}", word))
            })?;
            Self::push_bits(&mut bits, index, Self::BITS_PER_WORD);
        }

        // `read_bits` returns values below 2^count, so the narrowing casts are lossless.
        let mut cursor = 0usize;
        let identifier = Self::read_bits(&bits, &mut cursor, 15) as u16;
        let iteration_exponent = Self::read_bits(&bits, &mut cursor, 5) as u8;
        let group_index = Self::read_bits(&bits, &mut cursor, 4) as u8;
        let group_threshold = Self::read_bits(&bits, &mut cursor, 4) as u8 + 1;
        let group_count = Self::read_bits(&bits, &mut cursor, 4) as u8 + 1;
        let member_index = Self::read_bits(&bits, &mut cursor, 4) as u8;
        let member_threshold = Self::read_bits(&bits, &mut cursor, 4) as u8 + 1;

        if Self::read_bits(&bits, &mut cursor, padding) != 0 {
            return Err(BitcoinError::InvalidInput(
                "Invalid mnemonic padding".to_string(),
            ));
        }

        let value_len = (value_words * Self::BITS_PER_WORD - padding) / 8;
        let share_value: Vec<u8> = (0..value_len)
            .map(|_| Self::read_bits(&bits, &mut cursor, 8) as u8)
            .collect();
        let encoded_checksum = Self::read_bits(&bits, &mut cursor, 30);

        let mut share = Self {
            identifier,
            iteration_exponent,
            group_index,
            group_threshold,
            group_count,
            member_index,
            member_threshold,
            share_value,
            checksum: 0,
        };
        let full_checksum = share.compute_checksum();
        if full_checksum & Self::CHECKSUM_MASK != encoded_checksum {
            return Err(BitcoinError::InvalidInput(
                "Mnemonic checksum mismatch".to_string(),
            ));
        }
        share.checksum = full_checksum;
        Ok(share)
    }

    /// Returns `true` when the stored checksum matches the share's current content.
    pub fn validate_checksum(&self) -> bool {
        self.compute_checksum() == self.checksum
    }

    /// First four bytes (big-endian) of SHA-256 over identifier, iteration
    /// exponent, group index, member index and share value.
    ///
    /// The group threshold, group count and member threshold are not covered.
    fn compute_checksum(&self) -> u32 {
        let mut data = Vec::with_capacity(5 + self.share_value.len());
        data.extend_from_slice(&self.identifier.to_be_bytes());
        data.extend_from_slice(&[
            self.iteration_exponent,
            self.group_index,
            self.member_index,
        ]);
        data.extend_from_slice(&self.share_value);
        let digest = sha256::Hash::hash(&data);
        let bytes = digest.as_byte_array();
        u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
    }

    /// Appends the low `count` bits of `value` to `bits`, most significant first.
    fn push_bits(bits: &mut Vec<bool>, value: u32, count: usize) {
        bits.extend((0..count).rev().map(|shift| (value >> shift) & 1 == 1));
    }

    /// Reads `count` bits starting at `*cursor` as a big-endian integer and
    /// advances the cursor. The caller guarantees that the bits exist.
    fn read_bits(bits: &[bool], cursor: &mut usize, count: usize) -> u32 {
        let end = *cursor + count;
        let value = bits[*cursor..end]
            .iter()
            .fold(0u32, |acc, &bit| (acc << 1) | u32::from(bit));
        *cursor = end;
        value
    }

    /// Parses a canonical `word<N>` token into `N`, requiring `N < WORDLIST_SIZE`.
    fn word_index(word: &str) -> Option<u32> {
        let digits = word.strip_prefix("word")?;
        let index: u32 = digits.parse().ok()?;
        // Comparing against the canonical rendering rejects "+5", "005" and similar.
        let canonical = index.to_string() == digits;
        (canonical && (index as usize) < WORDLIST_SIZE).then_some(index)
    }
}
impl fmt::Debug for Slip39Share {
    /// Formats the share for diagnostics without revealing the share bytes.
    ///
    /// Prints the identifier, group index, member index and a combined
    /// `"<member_threshold>/<group_threshold>"` string; `share_value` is replaced
    /// by a fixed placeholder and the remaining fields are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let thresholds = format!("{}/{}", self.member_threshold, self.group_threshold);

        let mut output = f.debug_struct("Slip39Share");
        output.field("identifier", &self.identifier);
        output.field("group_index", &self.group_index);
        output.field("member_index", &self.member_index);
        output.field("threshold", &thresholds);
        // Never print secret material.
        output.field("share_value", &"<redacted>");
        output.finish()
    }
}
/// Splits a secret into single-group shares and recombines them.
pub struct Slip39Generator {
/// Identifier stamped on every share this generator produces.
identifier: u16,
}
impl Slip39Generator {
    /// Creates a generator whose shares all carry one freshly drawn identifier.
    pub fn new() -> Self {
        Self {
            identifier: Self::generate_identifier(),
        }
    }

    /// Draws a random 15-bit identifier (`0..=0x7FFF`).
    fn generate_identifier() -> u16 {
        use rand::RngExt;
        let mut rng = rand::rng();
        rng.random_range(0..=0x7FFF)
    }

    /// Splits `secret` into `threshold.total_shares` shares, any
    /// `threshold.threshold` of which reconstruct it.
    ///
    /// When `passphrase` is given, the secret is masked with it before splitting
    /// (see `encrypt_secret`); [`Slip39Generator::recover_secret`] then returns the
    /// masked value and the caller has to undo the mask.
    ///
    /// # Errors
    ///
    /// Returns `BitcoinError::InvalidInput` when `secret` is not 16 or 32 bytes long
    /// or when `threshold` fails `ShareThreshold::is_valid`.
    pub fn generate_shares(
        &self,
        secret: &[u8],
        threshold: ShareThreshold,
        passphrase: Option<&str>,
    ) -> Result<Vec<Slip39Share>> {
        if secret.len() != 16 && secret.len() != 32 {
            return Err(BitcoinError::InvalidInput(
                "Secret must be 128 or 256 bits (16 or 32 bytes)".to_string(),
            ));
        }
        if !threshold.is_valid() {
            return Err(BitcoinError::InvalidInput("Invalid threshold".to_string()));
        }

        let encrypted_secret = match passphrase {
            Some(pass) => self.encrypt_secret(secret, pass)?,
            None => secret.to_vec(),
        };
        let share_values = self.split_secret(
            &encrypted_secret,
            threshold.threshold,
            threshold.total_shares,
        )?;

        let shares = share_values
            .into_iter()
            .enumerate()
            .map(|(index, share_value)| {
                let mut share = Slip39Share {
                    identifier: self.identifier,
                    iteration_exponent: SLIP39_ITERATION_EXPONENT,
                    group_index: 0,
                    // NOTE(review): with `group_count` 1 the member threshold is
                    // stored as group threshold; kept as is for compatibility.
                    group_threshold: threshold.threshold,
                    group_count: 1,
                    // Lossless: `is_valid` caps the number of shares at 16.
                    member_index: index as u8,
                    member_threshold: threshold.threshold,
                    share_value,
                    checksum: 0,
                };
                share.checksum = share.compute_checksum();
                share
            })
            .collect();
        Ok(shares)
    }

    /// Reconstructs the split value from shares of a single group.
    ///
    /// The first `member_threshold` shares with distinct member indices are used;
    /// repeated indices and surplus shares are ignored. If the shares were created
    /// with a passphrase, the result is still masked with it.
    ///
    /// # Errors
    ///
    /// Returns `BitcoinError::InvalidInput` when no shares are given, when shares
    /// disagree on identifier, group, member threshold or length, when a checksum
    /// does not validate, when the member threshold is zero, when a member index
    /// is 255, or when fewer than `member_threshold` distinct shares are present.
    pub fn recover_secret(shares: &[Slip39Share]) -> Result<Vec<u8>> {
        let first = shares
            .first()
            .ok_or_else(|| BitcoinError::InvalidInput("No shares provided".to_string()))?;

        for share in shares {
            if share.identifier != first.identifier {
                return Err(BitcoinError::InvalidInput(
                    "Shares have different identifiers".to_string(),
                ));
            }
            if !share.validate_checksum() {
                return Err(BitcoinError::InvalidInput(
                    "Share checksum validation failed".to_string(),
                ));
            }
            if share.group_index != first.group_index {
                return Err(BitcoinError::InvalidInput(
                    "Shares belong to different groups".to_string(),
                ));
            }
            if share.member_threshold != first.member_threshold {
                return Err(BitcoinError::InvalidInput(
                    "Shares have different member thresholds".to_string(),
                ));
            }
        }

        let threshold = usize::from(first.member_threshold);
        if threshold == 0 {
            // Without this guard an empty selection would reach the interpolation.
            return Err(BitcoinError::InvalidInput(
                "Member threshold must be at least 1".to_string(),
            ));
        }

        // Interpolating with a repeated x-coordinate yields garbage, so only the
        // first occurrence of every member index is used.
        let mut points: Vec<(u8, &[u8])> = Vec::with_capacity(threshold);
        for share in shares {
            if points.len() == threshold {
                break;
            }
            let point = Self::share_point(share)?;
            if points.iter().all(|&(x, _)| x != point.0) {
                points.push(point);
            }
        }
        if points.len() < threshold {
            return Err(BitcoinError::InvalidInput(format!(
                "Insufficient shares: need {}, got {}",
                threshold,
                points.len()
            )));
        }
        Self::interpolate_at_zero(&points)
    }

    /// Evaluates a random polynomial of degree `threshold - 1` with constant term
    /// `secret` (byte-wise over GF(2^8)) at `x = 1..=total`.
    ///
    /// Element `i` of the result is the share for `x = i + 1`.
    ///
    /// # Errors
    ///
    /// Returns `BitcoinError::InvalidInput` when `threshold` is zero or exceeds
    /// `total`, because such shares could never be recombined correctly.
    fn split_secret(&self, secret: &[u8], threshold: u8, total: u8) -> Result<Vec<Vec<u8>>> {
        if threshold == 0 || threshold > total {
            return Err(BitcoinError::InvalidInput("Invalid threshold".to_string()));
        }

        // coefficients[0] is the secret; the others are uniformly random.
        let mut coefficients: Vec<Vec<u8>> = Vec::with_capacity(usize::from(threshold));
        coefficients.push(secret.to_vec());
        for _ in 1..threshold {
            let mut coefficient = vec![0u8; secret.len()];
            Self::random_bytes(&mut coefficient);
            coefficients.push(coefficient);
        }

        let shares = (1..=total)
            .map(|x| {
                (0..secret.len())
                    .map(|byte_index| {
                        // Horner's rule: (((c_k * x) + c_{k-1}) * x + ...) + c_0.
                        coefficients.iter().rev().fold(0u8, |acc, coefficient| {
                            Self::gf256_multiply(acc, x) ^ coefficient[byte_index]
                        })
                    })
                    .collect::<Vec<u8>>()
            })
            .collect::<Vec<Vec<u8>>>();
        Ok(shares)
    }

    /// Interpolates the given shares at `x = 0`, using `member_index + 1` as each
    /// share's x-coordinate. All given shares take part in the interpolation.
    ///
    /// # Errors
    ///
    /// Returns `BitcoinError::InvalidInput` for an empty slice, a member index of
    /// 255, repeated member indices, or share values of different lengths.
    fn combine_shares(shares: &[Slip39Share]) -> Result<Vec<u8>> {
        let points = shares
            .iter()
            .map(Self::share_point)
            .collect::<Result<Vec<_>>>()?;
        Self::interpolate_at_zero(&points)
    }

    /// Maps a share to its interpolation point `(member_index + 1, share bytes)`.
    fn share_point(share: &Slip39Share) -> Result<(u8, &[u8])> {
        let x = share.member_index.checked_add(1).ok_or_else(|| {
            BitcoinError::InvalidInput("Member index out of range".to_string())
        })?;
        Ok((x, share.share_value.as_slice()))
    }

    /// Lagrange interpolation at `x = 0` over GF(2^8), applied to every byte position.
    fn interpolate_at_zero(points: &[(u8, &[u8])]) -> Result<Vec<u8>> {
        let (_, first_value) = points
            .first()
            .ok_or_else(|| BitcoinError::InvalidInput("No shares provided".to_string()))?;
        let secret_len = first_value.len();

        for (position, (x, value)) in points.iter().enumerate() {
            if value.len() != secret_len {
                return Err(BitcoinError::InvalidInput(
                    "Shares have different lengths".to_string(),
                ));
            }
            if points[..position].iter().any(|(seen, _)| seen == x) {
                return Err(BitcoinError::InvalidInput(
                    "Duplicate member index".to_string(),
                ));
            }
        }

        let mut secret = vec![0u8; secret_len];
        for &(x, value) in points {
            let coefficient = Self::lagrange_coefficient(x, points);
            let term = Self::gf256_multiply_scalar(value, coefficient);
            Self::gf256_add(&mut secret, &term);
        }
        Ok(secret)
    }

    /// Lagrange basis polynomial of the point at `x`, evaluated at zero: the product
    /// of `xi / (xi - x)` over all other x-coordinates `xi`.
    fn lagrange_coefficient(x: u8, points: &[(u8, &[u8])]) -> u8 {
        points
            .iter()
            .map(|&(xi, _)| xi)
            .filter(|&xi| xi != x)
            .fold(1u8, |coefficient, xi| {
                let inverse = Self::gf256_inverse(Self::gf256_sub(xi, x));
                Self::gf256_multiply(coefficient, Self::gf256_multiply(xi, inverse))
            })
    }

    /// Adds `b` into `a` element-wise; addition in GF(2^8) is XOR.
    fn gf256_add(a: &mut [u8], b: &[u8]) {
        for (ai, bi) in a.iter_mut().zip(b.iter()) {
            *ai ^= bi;
        }
    }

    /// Subtraction in GF(2^8), which is the same XOR as addition.
    fn gf256_sub(a: u8, b: u8) -> u8 {
        a ^ b
    }

    /// Multiplies two GF(2^8) elements, reducing by `x^8 + x^4 + x^3 + x + 1`.
    fn gf256_multiply(mut a: u8, mut b: u8) -> u8 {
        let mut result = 0u8;
        for _ in 0..8 {
            if b & 1 != 0 {
                result ^= a;
            }
            let high_bit = a & 0x80;
            a <<= 1;
            if high_bit != 0 {
                // The bit shifted out is x^8; replace it by x^4 + x^3 + x + 1.
                a ^= 0x1B;
            }
            b >>= 1;
        }
        result
    }

    /// Multiplies every byte of `data` by `scalar` in GF(2^8).
    fn gf256_multiply_scalar(data: &[u8], scalar: u8) -> Vec<u8> {
        data.iter()
            .map(|&byte| Self::gf256_multiply(byte, scalar))
            .collect()
    }

    /// Raises `base` to the power `exp` in GF(2^8) by square-and-multiply.
    fn gf256_pow(base: u8, exp: u8) -> u8 {
        let mut result = 1u8;
        let mut b = base;
        let mut e = exp;
        while e > 0 {
            if e & 1 != 0 {
                result = Self::gf256_multiply(result, b);
            }
            b = Self::gf256_multiply(b, b);
            e >>= 1;
        }
        result
    }

    /// Multiplicative inverse in GF(2^8), computed as `a^254`.
    ///
    /// Zero has no inverse; 0 is returned for it.
    fn gf256_inverse(a: u8) -> u8 {
        if a == 0 {
            return 0;
        }
        Self::gf256_pow(a, 254)
    }

    /// Fills `buf` with random bytes from the thread-local generator.
    fn random_bytes(buf: &mut [u8]) {
        use rand::RngExt;
        let mut rng = rand::rng();
        for byte in buf.iter_mut() {
            *byte = rng.random_range(0..=255u8);
        }
    }

    /// Masks `secret` by XOR-ing it with the passphrase bytes repeated cyclically.
    ///
    /// An empty passphrase leaves the secret unchanged. Applying the function twice
    /// with the same passphrase restores the input.
    ///
    /// NOTE(review): a repeating-key XOR is not the passphrase encryption defined by
    /// SLIP-39 and offers very little protection. It is kept because existing
    /// shares depend on it; replace it with a real key-derivation-based cipher
    /// before relying on the passphrase for security.
    fn encrypt_secret(&self, secret: &[u8], passphrase: &str) -> Result<Vec<u8>> {
        let key = passphrase.as_bytes();
        if key.is_empty() {
            return Ok(secret.to_vec());
        }
        Ok(secret
            .iter()
            .zip(key.iter().cycle())
            .map(|(byte, mask)| byte ^ mask)
            .collect())
    }
}
impl Default for Slip39Generator {
fn default() -> Self {
Self::new()
}
}
/// Splits a secret across several groups, each with its own member policy.
pub struct MultiGroupGenerator {
/// Identifier stamped on every share this generator produces.
identifier: u16,
/// The groups to create shares for.
groups: Vec<GroupConfig>,
/// Number of groups whose shares are required for recovery.
group_threshold: u8,
}
impl MultiGroupGenerator {
    /// Largest number of groups; group ids are limited to `0..=15`.
    const MAX_GROUPS: usize = 16;

    /// Creates a generator for the given groups, `group_threshold` of which will be
    /// needed for recovery.
    ///
    /// # Errors
    ///
    /// Returns `BitcoinError::InvalidInput` when `groups` is empty or has more than
    /// 16 entries, when a group id is above 15 or used twice, when a group's member
    /// policy is invalid, or when `group_threshold` is zero or exceeds the number
    /// of groups.
    pub fn new(groups: Vec<GroupConfig>, group_threshold: u8) -> Result<Self> {
        if groups.is_empty() {
            return Err(BitcoinError::InvalidInput("No groups provided".to_string()));
        }
        if groups.len() > Self::MAX_GROUPS {
            return Err(BitcoinError::InvalidInput(
                "Number of groups cannot exceed 16".to_string(),
            ));
        }

        // `GroupConfig` has public fields, so its invariants are re-checked here.
        let mut seen_ids = 0u16;
        for group in &groups {
            if group.group_id > 15 {
                return Err(BitcoinError::InvalidInput(
                    "Group ID must be 0-15".to_string(),
                ));
            }
            let id_bit = 1u16 << group.group_id;
            if seen_ids & id_bit != 0 {
                // Two groups with one id would overwrite each other's shares.
                return Err(BitcoinError::InvalidInput(format!(
                    "Duplicate group ID {}",
                    group.group_id
                )));
            }
            seen_ids |= id_bit;
            if !group.group_threshold.is_valid() {
                return Err(BitcoinError::InvalidInput("Invalid threshold".to_string()));
            }
        }

        if group_threshold == 0 {
            return Err(BitcoinError::InvalidInput(
                "Group threshold must be at least 1".to_string(),
            ));
        }
        if usize::from(group_threshold) > groups.len() {
            return Err(BitcoinError::InvalidInput(
                "Group threshold exceeds number of groups".to_string(),
            ));
        }

        Ok(Self {
            identifier: Slip39Generator::generate_identifier(),
            groups,
            group_threshold,
        })
    }

    /// Splits `secret` into one group share per group and each group share into
    /// member shares, returned per group id.
    ///
    /// The group share of a group is the group-level polynomial evaluated at
    /// `group_id + 1`. Tying the x-coordinate to the id stored in every share is
    /// what lets [`MultiGroupGenerator::recover_from_groups`] recombine groups.
    ///
    /// A non-empty `passphrase` masks the secret before splitting; recovery
    /// returns the masked value. An empty passphrase is treated like `None`.
    ///
    /// # Errors
    ///
    /// Propagates errors from the underlying split and returns
    /// `BitcoinError::InvalidInput` for a group id above 15.
    pub fn generate_all_shares(
        &self,
        secret: &[u8],
        passphrase: Option<&str>,
    ) -> Result<HashMap<u8, Vec<Slip39Share>>> {
        let generator = Slip39Generator {
            identifier: self.identifier,
        };
        let encrypted_secret = match passphrase {
            Some(pass) if !pass.is_empty() => generator.encrypt_secret(secret, pass)?,
            _ => secret.to_vec(),
        };

        // Evaluate the group polynomial at every possible x (1..=16) and pick the
        // entry belonging to each configured id; entry `i` is the value at `i + 1`.
        let group_shares = generator.split_secret(
            &encrypted_secret,
            self.group_threshold,
            Self::MAX_GROUPS as u8,
        )?;
        // Lossless: `new` limits the number of groups to 16.
        let group_count = self.groups.len() as u8;

        let mut all_shares = HashMap::with_capacity(self.groups.len());
        for group in &self.groups {
            let group_secret = group_shares
                .get(usize::from(group.group_id))
                .ok_or_else(|| BitcoinError::InvalidInput("Group ID must be 0-15".to_string()))?;
            let policy = group.group_threshold;
            let member_values =
                generator.split_secret(group_secret, policy.threshold, policy.total_shares)?;

            let members: Vec<Slip39Share> = member_values
                .into_iter()
                .enumerate()
                .map(|(member_idx, share_value)| {
                    let mut share = Slip39Share {
                        identifier: self.identifier,
                        iteration_exponent: SLIP39_ITERATION_EXPONENT,
                        group_index: group.group_id,
                        group_threshold: self.group_threshold,
                        group_count,
                        // Lossless: a valid policy has at most 16 shares.
                        member_index: member_idx as u8,
                        member_threshold: policy.threshold,
                        share_value,
                        checksum: 0,
                    };
                    share.checksum = share.compute_checksum();
                    share
                })
                .collect();
            all_shares.insert(group.group_id, members);
        }
        Ok(all_shares)
    }

    /// Reconstructs the split value from member shares of several groups.
    ///
    /// Shares are grouped by `group_index`. Groups holding fewer shares than their
    /// member threshold are skipped. The remaining groups are processed in
    /// ascending id order until `group_threshold` group values are recovered, and
    /// those are interpolated at the x-coordinates `group_index + 1`.
    ///
    /// If a passphrase was used at generation time the result is still masked.
    ///
    /// # Errors
    ///
    /// Returns `BitcoinError::InvalidInput` when no shares are given, when shares
    /// disagree on identifier or group threshold, when a group index is above 15,
    /// when a threshold is zero, when fewer than `group_threshold` groups are
    /// complete, or when recovering a complete group fails.
    pub fn recover_from_groups(shares: &[Slip39Share]) -> Result<Vec<u8>> {
        let first = shares
            .first()
            .ok_or_else(|| BitcoinError::InvalidInput("No shares provided".to_string()))?;
        let group_threshold = usize::from(first.group_threshold);
        if group_threshold == 0 {
            return Err(BitcoinError::InvalidInput(
                "Group threshold must be at least 1".to_string(),
            ));
        }

        let mut groups: HashMap<u8, Vec<Slip39Share>> = HashMap::new();
        for share in shares {
            if share.identifier != first.identifier {
                return Err(BitcoinError::InvalidInput(
                    "Shares have different identifiers".to_string(),
                ));
            }
            if share.group_threshold != first.group_threshold {
                return Err(BitcoinError::InvalidInput(
                    "Shares have different group thresholds".to_string(),
                ));
            }
            if share.group_index > 15 {
                return Err(BitcoinError::InvalidInput(
                    "Group ID must be 0-15".to_string(),
                ));
            }
            groups
                .entry(share.group_index)
                .or_default()
                .push(share.clone());
        }

        // Fixed processing order keeps the choice of groups deterministic.
        let mut group_ids: Vec<u8> = groups.keys().copied().collect();
        group_ids.sort_unstable();

        // Each recovered group value becomes one point of the group-level polynomial.
        // It is wrapped in a share whose `member_index` is the group id so that the
        // single-group interpolation (x = member_index + 1) can be reused.
        let mut group_points: Vec<Slip39Share> = Vec::with_capacity(group_threshold);
        for group_id in group_ids {
            if group_points.len() == group_threshold {
                break;
            }
            let members = &groups[&group_id];
            let needed = usize::from(members[0].member_threshold);
            if needed == 0 {
                return Err(BitcoinError::InvalidInput(
                    "Member threshold must be at least 1".to_string(),
                ));
            }
            if members.len() < needed {
                // An incomplete group must not prevent recovery from the others.
                continue;
            }
            let group_secret = Slip39Generator::recover_secret(members)?;
            group_points.push(Slip39Share {
                identifier: first.identifier,
                iteration_exponent: first.iteration_exponent,
                group_index: group_id,
                group_threshold: first.group_threshold,
                group_count: first.group_count,
                member_index: group_id,
                member_threshold: first.group_threshold,
                share_value: group_secret,
                checksum: 0,
            });
        }

        if group_points.len() < group_threshold {
            return Err(BitcoinError::InvalidInput(format!(
                "Insufficient groups: need {}, got {}",
                group_threshold,
                group_points.len()
            )));
        }
        let secret_len = group_points[0].share_value.len();
        if group_points
            .iter()
            .any(|point| point.share_value.len() != secret_len)
        {
            return Err(BitcoinError::InvalidInput(
                "Group secrets have different lengths".to_string(),
            ));
        }

        Slip39Generator::combine_shares(&group_points)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `ShareThreshold::new` accepts valid policies and rejects each kind of violation.
    #[test]
    fn test_share_threshold_creation() {
        let threshold = ShareThreshold::new(3, 5).unwrap();
        assert_eq!(threshold.threshold, 3);
        assert_eq!(threshold.total_shares, 5);
        assert!(threshold.is_valid());
        assert!(ShareThreshold::new(0, 5).is_err());
        assert!(ShareThreshold::new(6, 5).is_err());
        assert!(ShareThreshold::new(3, 20).is_err());
    }

    /// `GroupConfig::new` validates the id and `with_description` stores the label.
    #[test]
    fn test_group_config() {
        let config = GroupConfig::new(0, ShareThreshold::new(2, 3).unwrap())
            .unwrap()
            .with_description("Family");
        assert_eq!(config.group_id, 0);
        assert_eq!(config.description, Some("Family".to_string()));
        assert!(GroupConfig::new(16, ShareThreshold::new(2, 3).unwrap()).is_err());
    }

    /// Recovery must return the exact secret, not merely something of equal length.
    #[test]
    fn test_share_generation_and_recovery() {
        // Distinct bytes so that a reordering or mixing bug cannot go unnoticed.
        let secret: Vec<u8> = (0u8..32).collect();
        let generator = Slip39Generator::new();
        let shares = generator
            .generate_shares(&secret, ShareThreshold::new(3, 5).unwrap(), None)
            .unwrap();
        assert_eq!(shares.len(), 5);

        let id = shares[0].identifier;
        for (index, share) in shares.iter().enumerate() {
            assert_eq!(share.identifier, id);
            assert_eq!(usize::from(share.member_index), index);
            assert_eq!(share.member_threshold, 3);
            assert_eq!(share.share_value.len(), secret.len());
            assert!(share.validate_checksum());
        }

        let recovered = Slip39Generator::recover_secret(&shares[0..3]).unwrap();
        assert_eq!(recovered, secret);
        let recovered_with_extra = Slip39Generator::recover_secret(&shares[0..4]).unwrap();
        assert_eq!(recovered_with_extra, secret);
        let recovered_from_tail = Slip39Generator::recover_secret(&shares[2..5]).unwrap();
        assert_eq!(recovered_from_tail, secret);
    }

    /// Fewer shares than the threshold must be refused.
    #[test]
    fn test_insufficient_shares() {
        let secret = [42u8; 32];
        let generator = Slip39Generator::new();
        let shares = generator
            .generate_shares(&secret, ShareThreshold::new(3, 5).unwrap(), None)
            .unwrap();
        let result = Slip39Generator::recover_secret(&shares[0..2]);
        assert!(result.is_err());
        assert!(Slip39Generator::recover_secret(&[]).is_err());
    }

    /// A share whose bytes were altered fails checksum validation during recovery.
    #[test]
    fn test_corrupted_share_is_rejected() {
        let secret = [42u8; 16];
        let generator = Slip39Generator::new();
        let shares = generator
            .generate_shares(&secret, ShareThreshold::new(2, 3).unwrap(), None)
            .unwrap();
        let mut tampered = shares[0..2].to_vec();
        tampered[1].share_value[0] ^= 0xFF;
        assert!(!tampered[1].validate_checksum());
        assert!(Slip39Generator::recover_secret(&tampered).is_err());
    }

    /// Shares carrying different identifiers must not be combined.
    #[test]
    fn test_mixed_identifiers_are_rejected() {
        let secret = [42u8; 16];
        let generator = Slip39Generator::new();
        let shares = generator
            .generate_shares(&secret, ShareThreshold::new(2, 3).unwrap(), None)
            .unwrap();
        let mut mixed = shares[0..2].to_vec();
        mixed[1].identifier ^= 1;
        assert!(Slip39Generator::recover_secret(&mixed).is_err());
    }

    /// Spot values and the inverse law of the GF(2^8) arithmetic.
    #[test]
    fn test_gf256_arithmetic() {
        assert_eq!(Slip39Generator::gf256_sub(5, 3), 5 ^ 3);
        // (x + 1)(x^2 + x + 1) = x^3 + 1
        assert_eq!(Slip39Generator::gf256_multiply(3, 7), 9);
        // Known inverse pair of the field reduced by x^8 + x^4 + x^3 + x + 1.
        assert_eq!(Slip39Generator::gf256_multiply(0x53, 0xCA), 1);
        assert_eq!(Slip39Generator::gf256_multiply(0, 0x7F), 0);
        for a in 1..=255u8 {
            let inv = Slip39Generator::gf256_inverse(a);
            assert_eq!(Slip39Generator::gf256_multiply(a, inv), 1);
        }
    }

    /// A rendered mnemonic is non-empty and accepted by the parser.
    #[test]
    fn test_share_mnemonic_conversion() {
        let secret = [42u8; 32];
        let generator = Slip39Generator::new();
        let shares = generator
            .generate_shares(&secret, ShareThreshold::new(2, 3).unwrap(), None)
            .unwrap();
        for share in &shares {
            let mnemonic = share.to_mnemonic();
            assert!(!mnemonic.is_empty());
            let parsed = Slip39Share::from_mnemonic(&mnemonic);
            assert!(parsed.is_ok());
        }
        assert!(Slip39Share::from_mnemonic("").is_err());
        assert!(Slip39Share::from_mnemonic("too short").is_err());
    }

    /// Every group receives the configured number of well-formed shares.
    #[test]
    fn test_multi_group_generation() {
        let groups = vec![
            GroupConfig::new(0, ShareThreshold::new(2, 3).unwrap()).unwrap(),
            GroupConfig::new(1, ShareThreshold::new(2, 3).unwrap()).unwrap(),
            GroupConfig::new(2, ShareThreshold::new(3, 5).unwrap()).unwrap(),
        ];
        let generator = MultiGroupGenerator::new(groups, 2).unwrap();
        let secret = [42u8; 32];
        let all_shares = generator.generate_all_shares(&secret, None).unwrap();
        assert_eq!(all_shares.len(), 3);
        assert_eq!(all_shares.get(&0).unwrap().len(), 3);
        assert_eq!(all_shares.get(&1).unwrap().len(), 3);
        assert_eq!(all_shares.get(&2).unwrap().len(), 5);

        for (group_id, shares) in &all_shares {
            for share in shares {
                assert_eq!(share.group_index, *group_id);
                assert_eq!(share.group_threshold, 2);
                assert_eq!(share.group_count, 3);
                assert_eq!(share.share_value.len(), secret.len());
                assert!(share.validate_checksum());
            }
        }
    }

    /// Group configurations that cannot work are refused up front.
    #[test]
    fn test_multi_group_invalid_configuration() {
        assert!(MultiGroupGenerator::new(Vec::new(), 1).is_err());
        let one_group = vec![GroupConfig::new(0, ShareThreshold::new(2, 3).unwrap()).unwrap()];
        assert!(MultiGroupGenerator::new(one_group.clone(), 0).is_err());
        assert!(MultiGroupGenerator::new(one_group, 2).is_err());
    }

    /// Secrets that are neither 16 nor 32 bytes long are refused.
    #[test]
    fn test_invalid_secret_length() {
        let generator = Slip39Generator::new();
        let invalid_secret = [42u8; 20];
        let result =
            generator.generate_shares(&invalid_secret, ShareThreshold::new(2, 3).unwrap(), None);
        assert!(result.is_err());
    }
}