use std::{collections::HashMap, fmt::Display, io::{self, Cursor, Read}, marker::PhantomData};
use aes_gcm::{aead::Aead, Aes256Gcm, Key, KeyInit, Nonce};
use base64::{prelude::*, DecodeError};
use bitflags::Flags;
use std::hash::Hash;
use num::{traits::{FromBytes, ToBytes}, One, PrimInt, Unsigned};
use p256::{ecdh, PublicKey, SecretKey};
use sha2::Sha256;
/// Builds a `HashMap` profile keyed by associated constants of a field-flag type.
///
/// Two equivalent syntaxes are accepted:
///
/// ```ignore
/// profile!([Fields] NAME = "Ada" EMAIL = "ada@example.com");
/// profile!(Fields { NAME: "Ada", EMAIL: "ada@example.com" });
/// ```
///
/// Every `KEY` is resolved as `<Type>::KEY` and inserted together with its value in the order
/// written, so a key that appears twice keeps its last value.
#[macro_export]
macro_rules! profile {
    ([$t:ty] $($k:ident = $v:expr)*) => {{
        let mut profile = ::std::collections::HashMap::new();
        $( profile.insert(<$t>::$k, $v); )*
        profile
    }};
    // Struct-literal style: forwards to the bracketed form above, which owns the expansion.
    ($t:ty { $($k:ident : $v:expr),* $(,)? }) => {
        $crate::profile!([$t] $($k = $v)*)
    };
}
/// Bound bundle for `bitflags` types that can be used as profile field keys.
///
/// Implementors are `Copy + Eq + Hash` flags types, so they can key a `HashMap`. Their bits
/// type, [`FieldFlags::Repr`], is an unsigned primitive integer that converts to and from bytes.
/// Every type meeting these bounds gets this trait through the blanket impl in this module.
pub trait FieldFlags
where
Self: Copy + Eq + Flags<Bits = Self::Repr> + Hash,
{
/// The unsigned integer type backing the flags; equal to the type's `Flags::Bits`.
type Repr: FromBytes<Bytes = Self::FromBytes> + PrimInt + ToBytes + Unsigned;
/// The byte-array type used by `Repr`'s `FromBytes` impl; it must be buildable from a byte
/// slice via `TryFrom`.
type FromBytes: for<'a> TryFrom<&'a [u8]>;
}
/// Blanket implementation: any hashable `bitflags` type whose bits are an unsigned primitive
/// integer with byte conversions is a [`FieldFlags`] type, with `Repr` set to those bits.
impl<T> FieldFlags for T
where
    T: Copy + Eq + Hash + Flags,
    <T as Flags>::Bits: PrimInt + Unsigned + FromBytes + ToBytes,
    <<T as Flags>::Bits as FromBytes>::Bytes: for<'a> TryFrom<&'a [u8]>,
{
    type Repr = <T as Flags>::Bits;
    type FromBytes = <<T as Flags>::Bits as FromBytes>::Bytes;
}
/// Errors returned by [`ProfileExtractor::extract_profile`].
#[derive(Debug)]
pub enum ExtractProfileError {
/// The auth code is not valid standard base64.
InvalidBase64(DecodeError),
/// The decoded auth code is too short to hold the 12-byte nonce plus at least one byte of
/// ciphertext; carries the decoded length.
InvalidLength(usize),
/// The AES-GCM `decrypt` call rejected the ciphertext.
FailedToDecrypt(aes_gcm::Error),
/// Reading from the decrypted payload failed, e.g. because a field record is truncated.
IoError(io::Error),
/// A record's bit index does not correspond to a flag of the field-flag type; carries the
/// bit index (the flag would be `1 << index`).
UnknownFieldFlag(u8),
/// A record's LEB128 value-length prefix could not be decoded.
Leb128ReadError(leb128::read::Error),
}
impl Display for ExtractProfileError {
    /// Writes a one-line, human-readable description, embedding the message of the wrapped
    /// error for the variants that carry one.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidBase64(source) => write!(f, "Invalid base64: {}", source),
            Self::InvalidLength(length) => write!(
                f,
                "Base64 decoded auth code has an invalid length of {}",
                length
            ),
            Self::FailedToDecrypt(source) => write!(f, "Failed to decrypt profile: {}", source),
            Self::IoError(source) => write!(
                f,
                "Encountered I/O error when extracting profile: {}",
                source
            ),
            Self::UnknownFieldFlag(bit) => write!(f, "Unknown field flag: 1 << {}", bit),
            Self::Leb128ReadError(source) => write!(f, "LEB128 read error: {}", source),
        }
    }
}
impl std::error::Error for ExtractProfileError {}
pub struct ProfileExtractor<F>
where
F: FieldFlags,
{
field_flags_type: PhantomData<F>,
cipher: Aes256Gcm,
}
impl<F> ProfileExtractor<F>
where
F: FieldFlags,
{
pub fn build(secret_key: &str, public_key: &str) -> Result<Self, Box<dyn std::error::Error>> {
let secret_key = BASE64_STANDARD.decode(secret_key)?;
let public_key = BASE64_STANDARD.decode(public_key)?;
let shared_secret = ecdh::diffie_hellman(
SecretKey::from_slice(&secret_key)?.to_nonzero_scalar(),
PublicKey::from_sec1_bytes(&public_key)?.as_affine(),
);
let mut okm = [0u8; 32];
shared_secret.extract::<Sha256>(None).expand(&[], &mut okm).unwrap();
Ok(Self {
field_flags_type: PhantomData,
cipher: Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&okm))
})
}
pub fn extract_profile(&self, auth_code: &str) -> Result<HashMap<F, String>, ExtractProfileError> {
let auth_code = BASE64_STANDARD.decode(auth_code).map_err(|e| ExtractProfileError::InvalidBase64(e))?;
if auth_code.len() < 13 {
return Err(ExtractProfileError::InvalidLength(auth_code.len()));
}
let nonce = Nonce::from_slice(&auth_code[..12]);
let mut cursor = Cursor::new(
self.cipher.decrypt(nonce, &auth_code[12..]).map_err(|e| ExtractProfileError::FailedToDecrypt(e))?,
);
let mut profile = HashMap::new();
while cursor.position() != cursor.get_ref().len().try_into().unwrap() {
let mut buf = [0u8];
cursor.read_exact(&mut buf).map_err(|e| ExtractProfileError::IoError(e))?;
let flag = F::from_bits(F::Repr::one() << buf[0].into()).ok_or(ExtractProfileError::UnknownFieldFlag(buf[0]))?;
let value_size = leb128::read::unsigned(&mut cursor).map_err(|e| ExtractProfileError::Leb128ReadError(e))?;
let mut buf = vec![0u8; value_size.try_into().unwrap()];
cursor.read_exact(&mut buf).map_err(|e| ExtractProfileError::IoError(e))?;
profile.insert(flag, String::from_utf8_lossy(&buf).to_string());
}
Ok(profile)
}
}