idcrypt 0.2.0

A library for securely collecting, encrypting, and verifying identity information with field-level authorization.
Documentation
use std::{collections::HashMap, fmt::Display, io::{self, Cursor, Read}, marker::PhantomData};
use aes_gcm::{aead::Aead, Aes256Gcm, Key, KeyInit, Nonce};
use base64::{prelude::*, DecodeError};
use bitflags::Flags;
use std::hash::Hash;
use num::{traits::{FromBytes, ToBytes}, One, PrimInt, Unsigned};
use p256::{ecdh, PublicKey, SecretKey};
use sha2::Sha256;

/// Builds a `HashMap` profile keyed by associated items (e.g. flag constants) of a type.
///
/// Two equivalent syntaxes are accepted:
///
/// ```ignore
/// let a = profile!([Fields] NAME="Alice" EMAIL="alice@example.com");
/// let b = profile!(Fields { NAME: "Alice", EMAIL: "alice@example.com" });
/// ```
///
/// Every key `K` is resolved as `<Fields>::K` and inserted with `HashMap::insert`,
/// so when a key is repeated the last value wins.
#[macro_export]
macro_rules! profile {
    // Bracketed form: re-expressed in the braced form so there is a single expansion.
    ([$t:ty]$($k:ident=$v:expr)*) => {
        $crate::profile!($t { $($k: $v),* })
    };
    // Braced, struct-literal-like form with an optional trailing comma.
    ($t:ty{$($k:ident:$v:expr),*$(,)?}) => {{
        let mut profile = ::std::collections::HashMap::new();
        $(
            profile.insert(<$t>::$k, $v);
        )*
        profile
    }};
}

//NOTE - Trait bounds (defined by where-clause) on associated types of supertraits
//       (and many others) must be repeated everywhere, so here we work around this
//       behavior by using duplicate associated types in `FieldFlags` itself.
//       See https://stackoverflow.com/questions/37600687/requiring-a-trait-bound-on-the-associated-type-of-an-inherited-trait
/// A flags type whose individual flags identify profile fields.
///
/// It is provided automatically by the blanket `impl<T> FieldFlags for T` that follows,
/// for every type satisfying that impl's bounds.
pub trait FieldFlags
where
    Self: Copy + Eq + Flags<Bits = Self::Repr> + Hash,
{
    /// The flags' underlying bits type; the supertrait bound forces it to equal `Flags::Bits`.
    type Repr: FromBytes<Bytes = Self::FromBytes> + PrimInt + ToBytes + Unsigned;
    /// The `FromBytes::Bytes` type of `Repr`, required to be fallibly convertible
    /// from a byte slice.
    type FromBytes: for<'a> TryFrom<&'a [u8]>;
}

// Any hashable, copyable flags type whose bits are an unsigned integer (`PrimInt +
// Unsigned`) with byte conversions qualifies as `FieldFlags`. The associated types
// simply mirror `Flags::Bits` and that type's `FromBytes::Bytes`.
impl<T> FieldFlags for T
where
    T: Copy + Eq + Hash + Flags,
    <T as Flags>::Bits: FromBytes + PrimInt + ToBytes + Unsigned,
    <<T as Flags>::Bits as FromBytes>::Bytes: for<'a> TryFrom<&'a [u8]>,
{
    type Repr = <T as Flags>::Bits;
    type FromBytes = <<T as Flags>::Bits as FromBytes>::Bytes;
}

/// Error returned by [`ProfileExtractor::extract_profile`].
#[derive(Debug)]
pub enum ExtractProfileError {
    /// The auth code is not valid standard base64.
    InvalidBase64(DecodeError),
    /// The decoded auth code is too short to be a valid message; carries the decoded length.
    InvalidLength(usize),
    /// AES-GCM decryption of the auth code failed.
    FailedToDecrypt(aes_gcm::Error),
    /// Reading from the decrypted profile failed (e.g. a record is truncated).
    IoError(io::Error),
    /// A record's bit index does not name a known field flag; carries the bit index.
    UnknownFieldFlag(u8),
    /// A record's LEB128-encoded value length could not be decoded.
    Leb128ReadError(leb128::read::Error),
}

impl Display for ExtractProfileError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidBase64(err) => write!(f, "Invalid base64: {}", err),
            Self::InvalidLength(len) => write!(f, "Base64 decoded auth code has an invalid length of {}", len),
            Self::FailedToDecrypt(err) => write!(f, "Failed to decrypt profile: {}", err),
            Self::IoError(err) => write!(f, "Encountered I/O error when extracting profile: {}", err),
            Self::UnknownFieldFlag(flag) => write!(f, "Unknown field flag: 1 << {}", flag),
            Self::Leb128ReadError(err) => write!(f, "LEB128 read error: {}", err),
        }
    }
}

impl std::error::Error for ExtractProfileError {
    /// Exposes the wrapped lower-level error so callers can walk the error chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::InvalidBase64(err) => Some(err),
            Self::IoError(err) => Some(err),
            Self::Leb128ReadError(err) => Some(err),
            // `aes_gcm::Error` only implements `std::error::Error` when the `std`
            // feature of its `aead` dependency is enabled, so it is not exposed here.
            Self::FailedToDecrypt(_) | Self::InvalidLength(_) | Self::UnknownFieldFlag(_) => None,
        }
    }
}

/// Decrypts auth codes and parses them into per-field profile values.
///
/// `F` is the flags type whose single-bit flags name the profile fields. It is only
/// used at the type level, so it is carried by a [`PhantomData`] marker.
pub struct ProfileExtractor<F: FieldFlags> {
    /// AES-256-GCM cipher keyed from the ECDH/HKDF-derived key computed in `build`.
    cipher: Aes256Gcm,
    field_flags_type: PhantomData<F>,
}

impl<F> ProfileExtractor<F>
where
    F: FieldFlags,
{
    pub fn build(secret_key: &str, public_key: &str) -> Result<Self, Box<dyn std::error::Error>> {
        let secret_key = BASE64_STANDARD.decode(secret_key)?;
        let public_key = BASE64_STANDARD.decode(public_key)?;
        let shared_secret = ecdh::diffie_hellman(
            SecretKey::from_slice(&secret_key)?.to_nonzero_scalar(),
            PublicKey::from_sec1_bytes(&public_key)?.as_affine(),
        );
        let mut okm = [0u8; 32];
        shared_secret.extract::<Sha256>(None).expand(&[], &mut okm).unwrap();
        Ok(Self {
            field_flags_type: PhantomData,
            cipher: Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&okm))
        })
    }
    pub fn extract_profile(&self, auth_code: &str) -> Result<HashMap<F, String>, ExtractProfileError> {
        let auth_code = BASE64_STANDARD.decode(auth_code).map_err(|e| ExtractProfileError::InvalidBase64(e))?;
        if auth_code.len() < 13 {
            return Err(ExtractProfileError::InvalidLength(auth_code.len()));
        }
        let nonce = Nonce::from_slice(&auth_code[..12]);
        let mut cursor = Cursor::new(
            self.cipher.decrypt(nonce, &auth_code[12..]).map_err(|e| ExtractProfileError::FailedToDecrypt(e))?,
        );
        let mut profile = HashMap::new();
        while cursor.position() != cursor.get_ref().len().try_into().unwrap() {
            let mut buf = [0u8];
            cursor.read_exact(&mut buf).map_err(|e| ExtractProfileError::IoError(e))?;
            let flag = F::from_bits(F::Repr::one() << buf[0].into()).ok_or(ExtractProfileError::UnknownFieldFlag(buf[0]))?;
            let value_size = leb128::read::unsigned(&mut cursor).map_err(|e| ExtractProfileError::Leb128ReadError(e))?;
            let mut buf = vec![0u8; value_size.try_into().unwrap()];
            cursor.read_exact(&mut buf).map_err(|e| ExtractProfileError::IoError(e))?;
            profile.insert(flag, String::from_utf8_lossy(&buf).to_string());
        }
        Ok(profile)
    }
}