use std::{collections::HashMap, fmt::Display, io, mem};
use aes_gcm::{aead::Aead, AeadCore, Aes256Gcm, Key, KeyInit};
use base64::{prelude::*, DecodeError};
use num::{traits::{FromBytes, ToBytes}, PrimInt};
use p256::{ecdh, ecdsa::{signature::{Signer, Verifier}, Signature, SigningKey, VerifyingKey}, elliptic_curve::sec1::ToEncodedPoint, PublicKey, SecretKey};
use rand_core::OsRng;
use sha2::Sha256;
use crate::{collection::Collection, profile::FieldFlags};
/// Reasons a collection id string could not be turned back into a collection.
#[derive(Debug)]
pub enum ParseIdError {
    /// The id string is not valid standard base64.
    InvalidBase64(DecodeError),
    /// The base64-decoded id has an unexpected byte length (the length is carried).
    InvalidLength(usize),
    /// The embedded collection public key could not be parsed.
    InvalidPublicKey(p256::elliptic_curve::Error),
    /// The embedded signature could not be parsed or did not verify.
    InvalidSignature(p256::ecdsa::Error),
    /// The encoded flag bits are not a valid field set; carries a bit position
    /// that is rendered as `1 << n`.
    UnknownFieldFlag(u8),
}

impl Display for ParseIdError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Match ergonomics: every binding below is a reference into `self`.
        match self {
            Self::InvalidBase64(source) => write!(f, "Invalid base64: {}", source),
            Self::InvalidLength(byte_len) => {
                write!(f, "Base64 decoded id has an invalid length of {}", byte_len)
            }
            Self::InvalidPublicKey(source) => write!(f, "Invalid public key: {}", source),
            Self::InvalidSignature(source) => write!(f, "Invalid signature: {}", source),
            Self::UnknownFieldFlag(bit) => write!(f, "Unknown field flag: 1 << {}", bit),
        }
    }
}

impl std::error::Error for ParseIdError {}
/// Reasons an auth code could not be issued.
#[derive(Debug)]
pub enum BuildAuthCodeError {
    /// None of the collection's fields had a value in the supplied profile,
    /// so there would be nothing to encrypt.
    NoFieldCollected,
    /// An I/O error occurred while serializing the auth-code plaintext.
    IoError(io::Error),
}

impl Display for BuildAuthCodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::NoFieldCollected => write!(f, "No field can be collected"),
            Self::IoError(ref err) => write!(f, "Encountered I/O error when building auth code: {}", err),
        }
    }
}

impl std::error::Error for BuildAuthCodeError {
    /// Exposes the wrapped `io::Error` so callers can walk the error chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::NoFieldCollected => None,
            Self::IoError(err) => Some(err),
        }
    }
}
pub struct Authority {
ecdh_secret_key: SecretKey,
ecdh_public_key: PublicKey,
ecdsa_signing_key: SigningKey,
ecdsa_verifying_key: VerifyingKey,
}
impl Authority {
pub fn new((ecdh_secret_key, ecdsa_signing_key): (SecretKey, SigningKey)) -> Self {
let ecdh_public_key = ecdh_secret_key.public_key();
let ecdsa_verifying_key = ecdsa_signing_key.verifying_key().clone();
Self {
ecdh_secret_key,
ecdh_public_key,
ecdsa_signing_key,
ecdsa_verifying_key,
}
}
pub fn ecdh_secret_key(&self) -> &SecretKey {
&self.ecdh_secret_key
}
pub fn ecdh_public_key(&self) -> PublicKey {
self.ecdh_public_key
}
pub fn ecdsa_signing_key(&self) -> &SigningKey {
&self.ecdsa_signing_key
}
pub fn ecdsa_verifying_key(&self) -> VerifyingKey {
self.ecdsa_verifying_key
}
fn build_collection_id<F>(&self, collection: &Collection<F>) -> String
where
F: FieldFlags,
{
let mut id = Vec::new();
id.extend_from_slice(collection.collected_fields.bits().to_le_bytes().as_ref());
id.extend_from_slice(&collection.public_key.to_encoded_point(true).to_bytes());
let mut signee = Vec::new();
signee.extend_from_slice(&id);
signee.extend_from_slice(collection.context.as_bytes());
id.extend_from_slice(
&<SigningKey as Signer<Signature>>::sign(&self.ecdsa_signing_key, &signee).to_bytes(),
);
BASE64_STANDARD.encode(id)
}
pub fn issue_auth_code<F, V>(&self, collection: &Collection<F>, profile: &HashMap<F, V>) -> Result<String, BuildAuthCodeError>
where
F: FieldFlags,
V: ToString,
{
let mut plaintext = Vec::<u8>::new();
for (_, flag) in collection.collected_fields.iter_names() {
let value = match profile.get(&flag) {
Some(value) => value,
None => continue,
}.to_string();
plaintext.push(flag.bits().trailing_zeros().try_into().unwrap());
leb128::write::unsigned(&mut plaintext, value.as_bytes().len().try_into().unwrap()).map_err(|e| BuildAuthCodeError::IoError(e))?;
plaintext.extend_from_slice(value.as_bytes());
}
if plaintext.len() == 0 {
return Err(BuildAuthCodeError::NoFieldCollected);
}
let shared_secret = ecdh::diffie_hellman(
self.ecdh_secret_key.to_nonzero_scalar(),
collection.public_key.as_affine(),
);
let mut okm = [0u8; 32];
shared_secret.extract::<Sha256>(None).expand(&[], &mut okm).unwrap();
let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&okm));
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
let mut auth_code = Vec::new();
auth_code.extend_from_slice(&nonce);
auth_code.extend_from_slice(&cipher.encrypt(&nonce, &plaintext[..]).unwrap());
return Ok(BASE64_STANDARD.encode(auth_code));
}
pub fn new_collection<F>(&self, collected_fields: F, context: &str) -> Collection<F>
where
F: FieldFlags,
{
let collection_secret_key = SecretKey::random(&mut OsRng);
let collection_public_key = collection_secret_key.public_key();
Collection {
authority: self,
collected_fields,
public_key: collection_public_key,
secret_key: Some(collection_secret_key),
context: context.into(),
}
}
pub fn load_collection<F>(&self, id: &str, context: &str) -> Result<Collection<F>, ParseIdError>
where
F: FieldFlags,
{
const PUBKEY_SIZE: usize = 33;
let size_of_repr = mem::size_of::<F::Repr>();
let pubkey_end = size_of_repr + PUBKEY_SIZE;
let id = BASE64_STANDARD.decode(id).map_err(|e| ParseIdError::InvalidBase64(e))?;
if id.len() < PUBKEY_SIZE + size_of_repr + 1 {
return Err(ParseIdError::InvalidLength(id.len()));
}
let flag_bits = F::Repr::from_le_bytes(
&(&id[..size_of_repr]).try_into().map_err(|_| unreachable!())?
);
let collected_fields = F::from_bits(flag_bits).ok_or_else(|| ParseIdError::UnknownFieldFlag(flag_bits.trailing_zeros().try_into().unwrap()))?;
let collection_public_key = PublicKey::from_sec1_bytes(&id[size_of_repr..pubkey_end]).map_err(|e| ParseIdError::InvalidPublicKey(e))?;
let mut signee = Vec::new();
signee.extend_from_slice(&id[..pubkey_end]);
signee.extend_from_slice(context.as_bytes());
self.ecdsa_verifying_key.verify(&signee, &Signature::from_slice(&id[pubkey_end..]).map_err(|e| ParseIdError::InvalidSignature(e))?).map_err(|e| ParseIdError::InvalidSignature(e))?;
Ok(Collection {
authority: &self,
collected_fields,
public_key: collection_public_key,
secret_key: None,
context: context.into(),
})
}
}
impl<F> Collection<'_, F>
where
F: FieldFlags,
{
pub fn id(&self) -> String {
self.authority.build_collection_id(self)
}
}