use thiserror::Error;
use crate::crypto::argon2::{Iterations, Memory, Salt, Threads};
use crate::crypto::chacha20::Nonce;
/// Four-byte signature written at the very start of every serialized header.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Magic(pub u32);

impl Magic {
    /// Width of the magic field in a serialized header, in bytes.
    pub const SIZE: usize = std::mem::size_of::<u32>();
}

impl Default for Magic {
    /// The signature this format writes, and the only one it accepts on read.
    fn default() -> Self {
        Self(0xA8F9_88BA)
    }
}
/// Format version, stored right after the magic in a serialized header.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Version(pub u32);

impl Version {
    /// Width of the version field in a serialized header, in bytes.
    pub const SIZE: usize = std::mem::size_of::<u32>();
}

// Hand-written rather than derived so the supported format version is spelled
// out explicitly at the definition site; clippy would otherwise ask for
// `#[derive(Default)]`.
#[allow(clippy::derivable_impls)]
impl Default for Version {
    /// The current format version: what is written and the only version accepted on read.
    fn default() -> Self {
        Self(0)
    }
}
#[derive(Error, Debug, PartialEq, Eq)]
pub enum HeaderSerdesError {
#[error("Header length {found:?} is less than {expected:?}")]
TooShort { found: usize, expected: usize },
#[error("Header magic {found:?} does not match {expected:?}")]
WrongMagic { found: Magic, expected: Magic },
#[error("Header version {found:?} does not match supported version {expected:?}")]
UnsupportedVersion { found: Version, expected: Version },
}
/// Fixed-size header holding the format signature, the format version, the
/// Argon2 parameters and salt, and the ChaCha20 nonce.
///
/// Field order here matches the order in which fields are laid out on disk.
#[derive(Default, Debug, Eq, PartialEq)]
pub struct Header {
    /// File signature; must equal `Magic::default()` to deserialize.
    pub magic: Magic,
    /// Format version; must equal `Version::default()` to deserialize.
    pub version: Version,
    /// Argon2 iteration count.
    pub iterations: Iterations,
    /// Argon2 memory parameter (units as defined by `Memory` — TODO confirm, likely KiB).
    pub memory: Memory,
    /// Argon2 thread/parallelism parameter.
    pub threads: Threads,
    /// Salt for the Argon2 key derivation.
    pub salt: Salt,
    /// ChaCha20 nonce.
    pub nonce: Nonce,
}
impl Header {
pub const SIZE: usize =
Magic::SIZE + Version::SIZE + Iterations::SIZE + Memory::SIZE + Threads::SIZE + Salt::SIZE + Nonce::SIZE;
pub fn serialize(&self) -> [u8; Header::SIZE] {
let mut bytes = [0u8; Header::SIZE];
bytes[0..4].copy_from_slice(&self.magic.0.to_le_bytes());
bytes[4..8].copy_from_slice(&self.version.0.to_le_bytes());
bytes[8..12].copy_from_slice(&self.iterations.0.to_le_bytes());
bytes[12..16].copy_from_slice(&self.memory.0.to_le_bytes());
bytes[16..20].copy_from_slice(&self.threads.0.to_le_bytes());
bytes[20..36].copy_from_slice(&self.salt.0);
bytes[36..60].copy_from_slice(&self.nonce.0);
bytes
}
pub fn deserialize(bytes: &[u8]) -> Result<Header, HeaderSerdesError> {
if bytes.len() < Header::SIZE {
return Err(HeaderSerdesError::TooShort {
found: bytes.len(),
expected: Header::SIZE,
});
}
let magic = Magic(u32::from_le_bytes(bytes[0..4].try_into().unwrap()));
let version = Version(u32::from_le_bytes(bytes[4..8].try_into().unwrap()));
let iterations = u32::from_le_bytes(bytes[8..12].try_into().unwrap());
let memory = u32::from_le_bytes(bytes[12..16].try_into().unwrap());
let threads = u32::from_le_bytes(bytes[16..20].try_into().unwrap());
let salt = bytes[20..36].try_into().unwrap();
let nonce = bytes[36..60].try_into().unwrap();
if magic != Magic::default() {
return Err(HeaderSerdesError::WrongMagic {
found: magic,
expected: Magic::default(),
});
}
if version != Version::default() {
return Err(HeaderSerdesError::UnsupportedVersion {
found: version,
expected: Version::default(),
});
}
let header = Header {
magic,
version,
iterations: Iterations(iterations),
memory: Memory(memory),
threads: Threads(threads),
salt: Salt(salt),
nonce: Nonce(nonce),
};
Ok(header)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn header_size_matches() {
        assert_eq!(Header::SIZE, 60);
    }

    #[test]
    fn serialization_roundtrips() {
        let header = Header::default();
        let bytes = header.serialize();
        let des_header = Header::deserialize(&bytes).unwrap();
        assert_eq!(header, des_header);
    }

    /// `Header::default()` leaves most fields zeroed, which would hide swapped
    /// or overlapping field offsets. Fill every byte with a distinct value and
    /// pin the on-disk layout explicitly.
    #[test]
    fn serialization_preserves_field_layout() {
        let mut bytes = [0u8; Header::SIZE];
        for (i, byte) in bytes.iter_mut().enumerate() {
            *byte = u8::try_from(i).unwrap();
        }
        bytes[0..4].copy_from_slice(&Magic::default().0.to_le_bytes());
        bytes[4..8].copy_from_slice(&Version::default().0.to_le_bytes());

        let header = Header::deserialize(&bytes).unwrap();
        assert_eq!(header.iterations, Iterations(u32::from_le_bytes([8, 9, 10, 11])));
        assert_eq!(header.memory, Memory(u32::from_le_bytes([12, 13, 14, 15])));
        assert_eq!(header.threads, Threads(u32::from_le_bytes([16, 17, 18, 19])));
        assert_eq!(header.serialize(), bytes);
    }

    #[test]
    fn deserialization_ignores_trailing_bytes() {
        let header = Header::default();
        let mut bytes = header.serialize().to_vec();
        bytes.extend_from_slice(&[0xFF; 8]);
        assert_eq!(Header::deserialize(&bytes).unwrap(), header);
    }

    #[test]
    fn deserialization_detects_too_short_header() {
        let bytes = vec![0x0; 10];
        let error = Header::deserialize(&bytes).unwrap_err();
        assert_eq!(
            error,
            HeaderSerdesError::TooShort {
                found: 10,
                expected: Header::SIZE
            }
        );
    }

    #[test]
    fn deserialization_rejects_header_one_byte_short() {
        let bytes = Header::default().serialize();
        let error = Header::deserialize(&bytes[..Header::SIZE - 1]).unwrap_err();
        assert_eq!(
            error,
            HeaderSerdesError::TooShort {
                found: Header::SIZE - 1,
                expected: Header::SIZE
            }
        );
    }

    #[test]
    fn deserialization_detects_wrong_magic() {
        let mut header = Header::default();
        let magic = Magic(0x1);
        header.magic = magic;
        let bytes = header.serialize();
        let error = Header::deserialize(&bytes).unwrap_err();
        assert_eq!(
            error,
            HeaderSerdesError::WrongMagic {
                found: magic,
                expected: Magic::default()
            }
        );
    }

    #[test]
    fn deserialization_detects_wrong_version() {
        let mut header = Header::default();
        let version = Version(100);
        header.version = version;
        let bytes = header.serialize();
        let error = Header::deserialize(&bytes).unwrap_err();
        assert_eq!(
            error,
            HeaderSerdesError::UnsupportedVersion {
                found: version,
                expected: Version::default()
            }
        );
    }
}