#![feature(const_eval_limit)]
#![feature(test)]
#![const_eval_limit = "0"]
#![warn(
clippy::all,
clippy::pedantic,
clippy::nursery,
clippy::cargo,
clippy::style
)]
#![allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
// Top-level command line: exactly one sub-command is expected. This is deliberately a plain
// `//` comment: StructOpt turns `///` doc comments into `--help` text, and one placed here
// would become the program's about text.
#[derive(StructOpt)]
enum ShamirMainCommand {
    /// Split a secret file into shares
    #[structopt(name = "share")]
    Share(ShareCommand),
    /// Recover a secret file from its shares
    #[structopt(name = "recover")]
    Recover(RecoverCommand),
}
// Positional arguments of the `share` sub-command. The `///` field docs below are shown by
// StructOpt as argument help in `--help`.
#[derive(StructOpt)]
struct ShareCommand {
    /// File containing the secret to split
    secret_file: String,
    /// Number of share files to create (at least 2)
    num_shares: u8,
    /// Number of shares needed to recover the secret (at least 2, at most num_shares)
    threshold: u8,
}
// Positional arguments of the `recover` sub-command. The `///` field docs below are shown by
// StructOpt as argument help in `--help`.
#[derive(StructOpt)]
struct RecoverCommand {
    /// File to write the recovered secret to
    secret_file: String,
    /// Share files produced by the `share` sub-command
    shares: Vec<String>,
}
mod galois_field_256;
mod polynomials;
mod shamir;
/// Magic identifier written as the first field of every share-file header.
///
/// Recovery rejects any input whose header does not start with this value, so changing it
/// makes previously written share files unreadable.
const PROGRAM_UUID: uuid::Uuid =
    uuid::Uuid::from_u128(0x1bc0_96ca_eec2_4565_8634_f34a_5903_fcda_u128);
/// Version of the share-file layout, written right after `PROGRAM_UUID` in every header.
///
/// Recovery rejects files carrying any other version.
const FILE_FORMAT_VERSION: u32 = 1;
use structopt::StructOpt;
/// Splits the contents of `filepath` into `num_shares` Shamir secret shares, any `threshold`
/// of which are enough to recover it.
///
/// Share `i` (1-based) is written to `{filepath}_{i:03}`. Each share file starts with a
/// bincode-encoded header `(program UUID, format version, operation UUID, threshold, x value)`
/// and is followed by one share byte per byte of the input file.
///
/// # Errors
///
/// Returns an error if `num_shares < 2`, `threshold < 2` or `threshold > num_shares`; if the
/// input cannot be read or an output file cannot be created or written (including when the
/// final flush fails); or if a header cannot be serialized.
fn split_secret_file_into_shares(
    filepath: &str,
    num_shares: u8,
    threshold: u8,
) -> Result<(), Box<dyn std::error::Error>> {
    use std::fs::File;
    use std::io::prelude::*;
    use std::io::{BufReader, BufWriter};

    if num_shares < 2 {
        return Err("Must have at least two shares".into());
    }
    if threshold < 2 {
        return Err("Must have a threshold of at least two".into());
    }
    if threshold > num_shares {
        return Err("Must have at least as many shares as the threshold for recovery".into());
    }
    // Stored in every share of this split so recovery can reject shares from different runs.
    let unique_uuid_for_this_operation = uuid::Uuid::new_v4();
    let mut rng = rand::thread_rng();
    let abscissae = shamir::create_distinct_x_values(num_shares, &mut rng);
    let input_file = BufReader::new(File::open(filepath)?);
    let mut output_files = Vec::with_capacity(usize::from(num_shares));
    for index in 1..=num_shares {
        let file_name = format!("{filepath}_{index:03}");
        output_files.push(BufWriter::new(File::create(file_name)?));
    }
    if abscissae.len() != output_files.len() {
        return Err("Logic error".into());
    }
    // Write headers through each BufWriter rather than the raw file underneath, so every
    // byte of a share file goes through one ordered, buffered write path.
    for (output, &abscissa) in output_files.iter_mut().zip(abscissae.iter()) {
        bincode::serialize_into(
            output,
            &(
                PROGRAM_UUID.as_bytes(),
                FILE_FORMAT_VERSION,
                unique_uuid_for_this_operation.as_bytes(),
                threshold,
                abscissa,
            ),
        )?;
    }
    for input_byte in input_file.bytes() {
        let shares =
            shamir::create_secret_shares_from_value(input_byte?, &abscissae, threshold, &mut rng);
        if shares.len() != output_files.len() {
            return Err("Logic error".into());
        }
        for (output, &share) in output_files.iter_mut().zip(shares.iter()) {
            output.write_all(&[share])?;
        }
    }
    // BufWriter flushes on drop but silently discards any error there. Flush explicitly so a
    // failed write (e.g. a full disk) is reported instead of leaving truncated shares behind.
    for output in &mut output_files {
        output.flush()?;
    }
    Ok(())
}
fn recover_shared_secret(
inputs: Vec<String>,
output: &str,
) -> Result<(), Box<dyn std::error::Error>> {
use std::fs::File;
use std::io::prelude::*;
let mut input_files: Vec<_> = Vec::new();
for input_file_name in inputs {
let opened_file = std::io::BufReader::new(File::open(input_file_name)?);
input_files.push(opened_file);
}
let mut output_file = std::io::BufWriter::new(File::create(output)?);
let mut program_uuids: Vec<uuid::Uuid> = Vec::new();
for stream in &input_files {
program_uuids.push(uuid::Uuid::from_bytes(bincode::deserialize_from(
stream.get_ref(),
)?));
}
let program_uuids = program_uuids;
if program_uuids.into_iter().any(|uuid| uuid != PROGRAM_UUID) {
return Err("One of the input files was not created by this program".into());
}
let mut file_format_versions: Vec<u32> = Vec::new();
for stream in &input_files {
file_format_versions.push(bincode::deserialize_from(stream.get_ref())?);
}
let file_format_versions = file_format_versions;
if file_format_versions
.into_iter()
.any(|ver| ver != FILE_FORMAT_VERSION)
{
return Err(
"One of the input files was created by a non-supported version of this program".into(),
);
}
let mut file_identifiers: Vec<uuid::Uuid> = Vec::new();
for stream in &input_files {
file_identifiers.push(uuid::Uuid::from_bytes(bincode::deserialize_from(
stream.get_ref(),
)?));
}
file_identifiers.sort();
file_identifiers.dedup();
if file_identifiers.len() != 1 {
return Err("The secrets shares do not pertain to the same set of secrets".into());
}
let mut threshold: Vec<u8> = Vec::new();
for stream in &input_files {
threshold.push(bincode::deserialize_from(stream.get_ref())?);
}
threshold.sort_unstable();
threshold.dedup();
if threshold.len() != 1 {
return Err("The secrets shares do not pertain to the same set of secrets".into());
}
let threshold = threshold[0];
let mut abscissae: Vec<u8> = Vec::new();
for stream in &input_files {
abscissae.push(bincode::deserialize_from(stream.get_ref())?);
}
let abscissae = abscissae;
let mut sorted_abscissae = abscissae.clone();
sorted_abscissae.sort_unstable();
sorted_abscissae.dedup();
if sorted_abscissae.len() != input_files.len() {
return Err("The same share has been given twice".into());
}
let mut y_values_as_options: Vec<Option<u8>> = Vec::with_capacity(input_files.len());
let mut y_values: Vec<u8> = Vec::with_capacity(input_files.len());
loop {
y_values_as_options.clear();
y_values_as_options.extend(input_files.iter_mut().map(|stream| {
let mut x = [0];
match stream.read(&mut x) {
Ok(1) => Some(x[0]),
_ => None,
}
}));
if y_values_as_options.iter().any(std::option::Option::is_none) {
if y_values_as_options.iter().all(std::option::Option::is_none) {
break;
}
return Err("An input file has reached its end".into());
}
y_values.clear();
y_values.extend(y_values_as_options.iter().map(|x| x.map_or(0, |val| val)));
let recovered_byte = shamir::recover_secret_from_points(&abscissae, &y_values, threshold);
output_file.write_all(&[recovered_byte])?;
}
Ok(())
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
let operation = ShamirMainCommand::from_args();
match operation {
ShamirMainCommand::Share(cmd) => {
split_secret_file_into_shares(&cmd.secret_file, cmd.num_shares, cmd.threshold)?;
}
ShamirMainCommand::Recover(cmd) => recover_shared_secret(cmd.shares, &cmd.secret_file)?,
};
Ok(())
}