// napa 0.3.1
//
// A simple and secure command line password manager
// SPDX-License-Identifier: MPL-2.0

use thiserror::Error;

use crate::crypto::argon2::{Iterations, Memory, Salt, Threads};
use crate::crypto::chacha20::Nonce;

/// Magic number stored in the first four bytes of a serialized [`Header`].
///
/// [`Header::deserialize`] rejects any input whose magic differs from
/// [`Magic::default`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Magic(pub u32);

impl Magic {
    /// Number of bytes the magic occupies in the serialized header.
    pub const SIZE: usize = std::mem::size_of::<u32>();
}

impl Default for Magic {
    /// Returns the magic value that a well-formed header must carry.
    fn default() -> Self {
        Self(0xA8F9_88BA)
    }
}

/// Format version stored in a serialized [`Header`], directly after the magic.
///
/// [`Header::deserialize`] only accepts input whose version equals
/// [`Version::default`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Version(pub u32);

impl Version {
    /// Number of bytes the version occupies in the serialized header.
    pub const SIZE: usize = std::mem::size_of::<u32>();
}

// NOTE(review): presumably kept as a hand-written impl (instead of
// `#[derive(Default)]`) so the version number is spelled out and the TODO
// below has a home; the lint is silenced for that reason.
#[allow(clippy::derivable_impls)]
impl Default for Version {
    /// Returns the format version currently written and accepted.
    fn default() -> Self {
        // TODO bump this to version 1 eventually
        Self(0)
    }
}

/// Errors returned by [`Header::deserialize`] when the input bytes do not
/// form an acceptable header.
#[derive(Error, Debug, PartialEq, Eq)]
pub enum HeaderSerdesError {
    /// The input holds fewer bytes (`found`) than a full header needs
    /// (`expected`, i.e. [`Header::SIZE`]).
    #[error("Header length {found:?} is less than {expected:?}")]
    TooShort { found: usize, expected: usize },
    /// The magic read from the input (`found`) is not the default magic
    /// (`expected`).
    #[error("Header magic {found:?} does not match {expected:?}")]
    WrongMagic { found: Magic, expected: Magic },
    /// The version read from the input (`found`) is not the default, and only
    /// supported, version (`expected`).
    #[error("Header version {found:?} does not match supported version {expected:?}")]
    UnsupportedVersion { found: Version, expected: Version },
}

/// Fixed-size header that can be converted to and from raw bytes with
/// [`Header::serialize`] and [`Header::deserialize`].
///
/// The fields are serialized in declaration order; the integer fields are
/// written as little-endian `u32` values.
#[derive(Debug, Default, PartialEq, Eq)]
pub struct Header {
    /// Magic number; must equal [`Magic::default`] to deserialize successfully.
    pub magic: Magic,
    /// Format version; must equal [`Version::default`] to deserialize successfully.
    pub version: Version,
    /// `Iterations` value from `crate::crypto::argon2` (presumably the Argon2
    /// iteration count — confirm against that module).
    pub iterations: Iterations,
    /// `Memory` value from `crate::crypto::argon2` (presumably the Argon2
    /// memory cost — confirm against that module).
    pub memory: Memory,
    /// `Threads` value from `crate::crypto::argon2` (presumably the Argon2
    /// parallelism — confirm against that module).
    pub threads: Threads,
    /// Salt from `crate::crypto::argon2`; serialized as bytes 20..36.
    pub salt: Salt,
    /// Nonce from `crate::crypto::chacha20`; serialized as bytes 36..60.
    pub nonce: Nonce,
}

impl Header {
    /// Length in bytes of a serialized header: the sum of every field's `SIZE`.
    pub const SIZE: usize = Magic::SIZE
        + Version::SIZE
        + Iterations::SIZE
        + Memory::SIZE
        + Threads::SIZE
        + Salt::SIZE
        + Nonce::SIZE;

    /// Encodes the header into its fixed-size byte representation.
    ///
    /// Layout (all integers little-endian):
    ///
    /// | bytes    | field        |
    /// |----------|--------------|
    /// | `0..4`   | `magic`      |
    /// | `4..8`   | `version`    |
    /// | `8..12`  | `iterations` |
    /// | `12..16` | `memory`     |
    /// | `16..20` | `threads`    |
    /// | `20..36` | `salt`       |
    /// | `36..60` | `nonce`      |
    pub fn serialize(&self) -> [u8; Header::SIZE] {
        let mut encoded = [0u8; Header::SIZE];

        // The five integer fields sit back to back at the start of the
        // buffer, so they can be written as consecutive 4-byte chunks.
        let words = [
            self.magic.0,
            self.version.0,
            self.iterations.0,
            self.memory.0,
            self.threads.0,
        ];
        for (slot, word) in encoded[..20].chunks_exact_mut(4).zip(words) {
            slot.copy_from_slice(&word.to_le_bytes());
        }

        // The raw byte fields follow the integers.
        encoded[20..36].copy_from_slice(&self.salt.0);
        encoded[36..60].copy_from_slice(&self.nonce.0);

        encoded
    }

    /// Decodes a header from the start of `bytes`.
    ///
    /// Only the first [`Header::SIZE`] bytes are read; anything after them is
    /// ignored.
    ///
    /// # Errors
    ///
    /// * [`HeaderSerdesError::TooShort`] if `bytes` is shorter than
    ///   [`Header::SIZE`].
    /// * [`HeaderSerdesError::WrongMagic`] if the stored magic is not
    ///   [`Magic::default`].
    /// * [`HeaderSerdesError::UnsupportedVersion`] if the stored version is
    ///   not [`Version::default`].
    ///
    /// The magic is checked before the version, so input that is wrong in
    /// both reports `WrongMagic`.
    pub fn deserialize(bytes: &[u8]) -> Result<Header, HeaderSerdesError> {
        let found_len = bytes.len();
        if found_len < Header::SIZE {
            return Err(HeaderSerdesError::TooShort {
                found: found_len,
                expected: Header::SIZE,
            });
        }

        // The length check above guarantees that every range read below is
        // in bounds, so neither the slicing nor the array conversions fail.
        let read_u32 = |offset: usize| -> u32 {
            let mut word = [0u8; 4];
            word.copy_from_slice(&bytes[offset..offset + 4]);
            u32::from_le_bytes(word)
        };

        let magic = Magic(read_u32(0));
        let version = Version(read_u32(4));
        let iterations = Iterations(read_u32(8));
        let memory = Memory(read_u32(12));
        let threads = Threads(read_u32(16));
        let salt_bytes = bytes[20..36].try_into().unwrap();
        let nonce_bytes = bytes[36..60].try_into().unwrap();

        let expected_magic = Magic::default();
        if magic != expected_magic {
            return Err(HeaderSerdesError::WrongMagic {
                found: magic,
                expected: expected_magic,
            });
        }

        let expected_version = Version::default();
        if version != expected_version {
            return Err(HeaderSerdesError::UnsupportedVersion {
                found: version,
                expected: expected_version,
            });
        }

        Ok(Header {
            magic,
            version,
            iterations,
            memory,
            threads,
            salt: Salt(salt_bytes),
            nonce: Nonce(nonce_bytes),
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The serialized header length is pinned at 60 bytes.
    #[test]
    fn header_size_matches() {
        assert_eq!(Header::SIZE, 60);
    }

    /// Serializing and then deserializing a default header yields an equal header.
    #[test]
    fn serialization_roundtrips() {
        let header = Header::default();
        let bytes = header.serialize();
        let des_header = Header::deserialize(&bytes).unwrap();

        assert_eq!(header, des_header);
    }

    /// Input shorter than `Header::SIZE` is rejected with `TooShort`.
    #[test]
    fn deserialization_detects_too_short_header() {
        // A stack array is enough here; no heap allocation is needed.
        let bytes = [0u8; 10];
        let error = Header::deserialize(&bytes).unwrap_err();
        assert_eq!(
            error,
            HeaderSerdesError::TooShort {
                found: 10,
                expected: Header::SIZE
            }
        );
    }

    /// A header carrying a non-default magic is rejected with `WrongMagic`.
    #[test]
    fn deserialization_detects_wrong_magic() {
        let magic = Magic(0x1);
        let header = Header {
            magic,
            ..Header::default()
        };
        let bytes = header.serialize();
        // `unwrap_err` (rather than `.err().unwrap()`) matches the other tests
        // and prints the unexpected `Ok` value if this ever fails.
        let error = Header::deserialize(&bytes).unwrap_err();
        assert_eq!(
            error,
            HeaderSerdesError::WrongMagic {
                found: magic,
                expected: Magic::default()
            }
        );
    }

    /// A header carrying a non-default version is rejected with `UnsupportedVersion`.
    #[test]
    fn deserialization_detects_wrong_version() {
        let version = Version(100);
        let header = Header {
            version,
            ..Header::default()
        };
        let bytes = header.serialize();
        let error = Header::deserialize(&bytes).unwrap_err();
        assert_eq!(
            error,
            HeaderSerdesError::UnsupportedVersion {
                found: version,
                expected: Version::default()
            }
        );
    }
}