use std::{
path::Path,
str::FromStr,
fs::File,
io::Read,
sync::{Arc, Mutex},
};
use rand::Rng;
use salsa20::Key;
use thiserror::Error;
use derive_getters::Getters;
use rand_chacha::ChaCha20Rng;
use super::erase_bytes;
/// Where this module draws random bytes from.
///
/// Cloning a `ChaCha20` value clones the `Arc`, so every clone shares (and
/// advances) the same underlying generator.
#[derive(Debug, Clone)]
pub enum Randomness {
    /// Bytes come from `getrandom::getrandom`.
    Entropy,
    /// Bytes come from a ChaCha20 generator shared behind `Arc<Mutex<_>>`.
    ChaCha20(Arc<Mutex<ChaCha20Rng>>),
}
impl Randomness {
pub (in crate) fn try_fill(&self, buf: &mut [u8]) -> Result<(), RandomnessError> {
match self {
Self::Entropy => getrandom::getrandom(buf)?,
Self::ChaCha20(cha_m) => {
let mut cha_g = cha_m
.lock()
.expect("ChaCha20 poison.");
cha_g.try_fill(buf)?
},
}
Ok(())
}
}
impl Default for Randomness {
fn default() -> Self {
Randomness::Entropy
}
}
impl From<ChaCha20Rng> for Randomness {
fn from(c: ChaCha20Rng) -> Self {
Randomness::ChaCha20(Arc::new(Mutex::new(c)))
}
}
/// Failure while filling a buffer through `Randomness::try_fill`.
#[derive(Debug, Error)]
pub enum RandomnessError {
    /// `getrandom::getrandom` failed (the `Randomness::Entropy` path).
    #[error("Pure Randomness Error: {0}")]
    Pure(#[from] getrandom::Error),
    /// The ChaCha20 generator reported an error (the `Randomness::ChaCha20` path).
    #[error("ChaCha20 Error: {0}")]
    ChaCha20(#[from] rand::Error),
}
/// A symmetric `salsa20::Key` used as a wrapping key.
///
/// The key bytes are passed to `erase_bytes` when the value is dropped.
/// `Debug` is implemented by hand (not derived) so the key material is never
/// written into logs or panic messages through `{:?}`.
#[derive(Clone, PartialEq, Eq, Getters)]
pub struct WrapKey {
    key: Key,
}

impl std::fmt::Debug for WrapKey {
    /// Formats as `WrapKey { key: "<redacted>" }` without exposing any key bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WrapKey")
            .field("key", &"<redacted>")
            .finish()
    }
}
impl WrapKey {
    /// Builds a `WrapKey` holding its own copy of `key`.
    pub fn new(key: &Key) -> Self {
        let owned = key.clone();
        Self { key: owned }
    }

    /// Builds a `WrapKey` from 32 raw bytes, then hands `bytes` to
    /// `erase_bytes` so the caller's buffer no longer holds the key.
    pub fn from_bytes_and_erase_source(bytes: &mut [u8; 32]) -> Self {
        // Copy first; the source buffer is wiped only after the key owns its bytes.
        let wrapped = Self {
            key: Key::clone_from_slice(&bytes[..]),
        };
        erase_bytes(bytes);
        wrapped
    }
}
impl Drop for WrapKey {
    /// Passes the key bytes to `erase_bytes` before the memory is released.
    fn drop(&mut self) {
        let key_bytes = self.key.as_mut_slice();
        erase_bytes(key_bytes);
    }
}
impl From<&Key> for WrapKey {
fn from(key: &Key) -> Self {
Self::new(key)
}
}
impl FromStr for WrapKey {
    type Err = ParseKeyError;

    /// Parses a base64-encoded key (surrounding whitespace ignored).
    ///
    /// The decoded buffer is passed to `erase_bytes` on both the success and
    /// the wrong-length path.
    ///
    /// # Errors
    /// - `ParseKeyError::Base64` if `s` is not valid base64.
    /// - `ParseKeyError::Length` if it decodes to anything other than 32 bytes.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut decoded = base64::decode(s.trim().as_bytes())?;
        let decoded_len = decoded.len();
        let mut key_bytes: [u8; 32] = [0; 32];
        let well_sized = decoded_len == key_bytes.len();
        if well_sized {
            key_bytes.copy_from_slice(decoded.as_slice());
        }
        // Wipe the heap copy regardless of outcome; only the stack copy survives.
        erase_bytes(decoded.as_mut_slice());
        if well_sized {
            Ok(WrapKey::from_bytes_and_erase_source(&mut key_bytes))
        } else {
            Err(ParseKeyError::Length(decoded_len))
        }
    }
}
/// Failure while parsing a `WrapKey` from text.
#[derive(Debug, Error)]
pub enum ParseKeyError {
    /// The input was not valid base64.
    #[error("Base64: {0}")]
    Base64(#[from] base64::DecodeError),
    /// The input decoded to the contained number of bytes instead of 32.
    #[error("Invalid Length: {0}. Need 32 bytes.")]
    Length(usize),
}
/// A pre-shared key: a wrapping key plus an accompanying `check` byte string.
///
/// NOTE(review): `check` is presumably used to validate the key elsewhere —
/// confirm against its consumers. Its bytes are erased on drop.
#[derive(Debug, Clone, PartialEq, Eq, Getters)]
pub struct Psk {
    /// The wrapping key (first line of the PSK text).
    wrap_k: WrapKey,
    /// Decoded check bytes (remaining lines of the PSK text).
    check: Vec<u8>,
}
impl Psk {
pub fn new(wrap_k: WrapKey, check: Vec<u8>) -> Self {
Self { wrap_k, check }
}
}
impl Drop for Psk {
    /// Passes the check bytes to `erase_bytes`; `wrap_k` is wiped by its own `Drop`.
    fn drop(&mut self) {
        let check_bytes = self.check.as_mut_slice();
        erase_bytes(check_bytes);
    }
}
impl FromStr for Psk {
type Err = LoadPskError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut lines: Vec<&str> = s
.split('\n')
.collect();
if lines.len() < 2 {
return Err(LoadPskError::Lines(lines.len()))
}
let raw_k = lines.remove(0);
let raw_c = lines
.iter()
.fold(String::new(), |mut raw_c, l| {
raw_c.push_str(l);
raw_c
});
let wrap_k = WrapKey::from_str(&raw_k)?;
let check = base64::decode(raw_c.trim().as_bytes())?;
Ok(Psk::new(wrap_k, check))
}
}
pub fn load_psk<P: AsRef<Path>>(path: P) -> Result<Psk, LoadPskError> {
let mut file = File::open(path)?;
let mut contents = String::new();
file.read_to_string(&mut contents)?;
let psk = Psk::from_str(&contents)?;
let mut contents = contents.into_bytes();
erase_bytes(contents.as_mut_slice());
Ok(psk)
}
/// Failure while loading or parsing a `Psk`.
#[derive(Debug, Error)]
pub enum LoadPskError {
    /// The key line could not be parsed as a `WrapKey`.
    #[error("Couldn't parse key: {0}")]
    Key(#[from] ParseKeyError),
    /// Too few lines were present; carries the number found.
    #[error("Must have two base64 encoded numbers. Found {0}.")]
    Lines(usize),
    /// The check value was not valid base64.
    #[error("Base64: {0}")]
    Base64(#[from] base64::DecodeError),
    /// Opening or reading the PSK file failed.
    #[error("IO Error: {0}")]
    IO(#[from] std::io::Error),
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;

    const KEY_B64: &str = "zvy9fTrwI/QdBISSLEZhrrKc2Bir4/WmLGOKfNTPaMg=";
    const PSK_B64: &str = "iSDarIpxoyTGUgFlar66/3J7LIq69pD2oKOU1o02HR0=
ObW84/49Zx8EFDtHegEGb/I4lje8/hBb5EgJriJ9SHCha29UUMptmJ5WaS8wxq5v
+NS8FhdD5rpK1cZzGsf4VDmm0OaWHGP8fBJ4Vh2piQh98cCTZmes0cepmCP3PF8b
H1VMlsUdzRZzmVYn50drKaexxvS/UhEn7qy2LrGX9J7CrZ1p8P15lqRxDD+jaJuh
3hX5XTB77kmaYZzUQrkCgqA7kKB8nE1K4ETzXzK77zsPr39Stcim3OloXcwW1EbD
CL0VW2id0/5EJ0v/xd7LvM/OjGRo/A8XrGV2R4SsnxbpkCyeK+bsu38OHLi8rb1T
MxAneKN2CKvT9JoNbqtJCA==";

    /// Size of the buffer each randomness test asks its source to fill.
    const FILL_LEN: usize = 16384;

    /// Asserts that `source` overwrites an all-zero buffer with different bytes.
    fn assert_fills(source: &Randomness) {
        let mut buffer = [0u8; FILL_LEN];
        let zeroed = buffer;
        assert!(zeroed == buffer);
        source.try_fill(&mut buffer).unwrap();
        assert!(zeroed != buffer);
    }

    #[test]
    fn parse_b64_key() {
        WrapKey::from_str(KEY_B64).unwrap();
    }

    #[test]
    fn parse_psk() {
        Psk::from_str(PSK_B64).unwrap();
    }

    #[test]
    fn randomness_fills() {
        assert_fills(&Randomness::default());
    }

    #[test]
    fn chacha20_fills() {
        let seeded: Randomness = ChaCha20Rng::seed_from_u64(1101).into();
        assert_fills(&seeded);
    }
}