use std::{
io::{Cursor, Read},
num::TryFromIntError,
};
use ciborium::value::Value;
use coset::{AsCborValue, CborSerializable, CoseKey};
use serde::{Deserialize, Serialize};
use crate::{
crypto::sha256,
ctap2::{Aaguid, Flags},
};
use super::{Ctap2Error, get_assertion, make_credential};
/// Parsed form of the authenticator data byte string.
///
/// [`AuthenticatorData::to_vec`] writes the fields in declaration order: the 32-byte RP ID
/// hash, one flag byte, a 4-byte big-endian counter, then the optional attested credential
/// data and the optional CBOR-encoded extensions. [`AuthenticatorData::from_slice`] reads the
/// same layout back.
#[derive(Debug, PartialEq)]
pub struct AuthenticatorData {
/// Output of `crypto::sha256` over the relying party ID bytes (see [`AuthenticatorData::new`]).
rp_id_hash: [u8; 32],
/// Flag bits; `AT` and `ED` announce the two optional trailing sections when parsing.
pub flags: Flags,
/// Counter value. `None` is serialized as zero; parsing always yields `Some`.
pub counter: Option<u32>,
/// Attested credential data; only parsed when the `AT` flag is set.
pub attested_credential_data: Option<AttestedCredentialData>,
/// Extension outputs as a CBOR value; only parsed when the `ED` flag is set.
pub extensions: Option<Value>,
}
impl AuthenticatorData {
    /// Builds authenticator data for `rp_id` with default flags and no optional sections.
    ///
    /// The RP ID is stored only as its `sha256` digest.
    pub fn new(rp_id: &str, counter: Option<u32>) -> Self {
        let rp_id_hash = sha256(rp_id.as_bytes());
        Self {
            rp_id_hash,
            flags: Flags::default(),
            counter,
            attested_credential_data: None,
            extensions: None,
        }
    }

    /// Attaches attested credential data and raises the `AT` flag.
    pub fn set_attested_credential_data(self, acd: AttestedCredentialData) -> Self {
        let with_credential = Self {
            attested_credential_data: Some(acd),
            ..self
        };
        with_credential.set_flags(Flags::AT)
    }

    /// Adds `flags` to the flags already set; existing bits are never cleared.
    pub fn set_flags(self, flags: Flags) -> Self {
        Self {
            flags: self.flags | flags,
            ..self
        }
    }

    /// Returns the stored RP ID hash.
    pub fn rp_id_hash(&self) -> &[u8] {
        self.rp_id_hash.as_slice()
    }

    /// Stores the signed extension outputs of a make-credential operation and raises `ED`.
    ///
    /// When `extensions` is `None`, or `zip_contents` yields `None`, the value is returned
    /// unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`Ctap2Error::CborUnexpectedType`] if the outputs cannot be converted into a
    /// CBOR value.
    pub fn set_make_credential_extensions(
        mut self,
        extensions: Option<make_credential::SignedExtensionOutputs>,
    ) -> Result<Self, Ctap2Error> {
        match extensions.and_then(|e| e.zip_contents()) {
            None => Ok(self),
            Some(outputs) => {
                let encoded =
                    Value::serialized(&outputs).map_err(|_| Ctap2Error::CborUnexpectedType)?;
                self.extensions = Some(encoded);
                Ok(self.set_flags(Flags::ED))
            }
        }
    }

    /// Stores the signed extension outputs of a get-assertion operation and raises `ED`.
    ///
    /// When `extensions` is `None`, or `zip_contents` yields `None`, the value is returned
    /// unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`Ctap2Error::CborUnexpectedType`] if the outputs cannot be converted into a
    /// CBOR value.
    pub fn set_assertion_extensions(
        mut self,
        extensions: Option<get_assertion::SignedExtensionOutputs>,
    ) -> Result<Self, Ctap2Error> {
        match extensions.and_then(|e| e.zip_contents()) {
            None => Ok(self),
            Some(outputs) => {
                let encoded =
                    Value::serialized(&outputs).map_err(|_| Ctap2Error::CborUnexpectedType)?;
                self.extensions = Some(encoded);
                Ok(self.set_flags(Flags::ED))
            }
        }
    }
}
impl Serialize for AuthenticatorData {
    /// Serializes the value as a single byte string holding the output of
    /// [`AuthenticatorData::to_vec`].
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_bytes(self.to_vec().as_slice())
    }
}
impl<'de> Deserialize<'de> for AuthenticatorData {
    /// Deserializes from a byte string by handing the bytes to
    /// [`AuthenticatorData::from_slice`]; parse failures surface as custom serde errors.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        /// Visitor accepting only byte-string input.
        struct AuthDataVisitor;

        impl serde::de::Visitor<'_> for AuthDataVisitor {
            type Value = AuthenticatorData;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("Authenticator Data")
            }

            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                match AuthenticatorData::from_slice(v) {
                    Ok(parsed) => Ok(parsed),
                    Err(err) => Err(E::custom(err.to_string())),
                }
            }
        }

        deserializer.deserialize_bytes(AuthDataVisitor)
    }
}
fn io_error<E>(_: E) -> coset::CoseError {
coset::CoseError::DecodeFailed(ciborium::de::Error::Io(coset::EndOfFile))
}
impl AuthenticatorData {
    /// Parses authenticator data from its binary form.
    ///
    /// Layout: 32-byte RP ID hash, 1 flag byte, 4-byte big-endian counter, then attested
    /// credential data if the `AT` flag is set, then a CBOR value with the extensions if the
    /// `ED` flag is set. Bytes following the last parsed section are not examined.
    ///
    /// # Errors
    ///
    /// * An end-of-file decode error if `v` is shorter than the 37-byte fixed header or an
    ///   optional section cannot be read.
    /// * [`coset::CoseError::OutOfRangeIntegerValue`] if `Flags::from_bits` rejects the flag
    ///   byte.
    /// * Any error produced while converting the credential key from CBOR.
    pub fn from_slice(v: &[u8]) -> coset::Result<Self> {
        // 32 (hash) + 1 (flags) + 4 (counter): guarantees the three splits below cannot panic.
        if v.len() < 37 {
            return Err(io_error(()));
        }
        let (rp_id_hash, v) = v.split_at(32);
        let (flag_byte, v) = v.split_at(1);
        let (counter, v) = v.split_at(4);

        let flags =
            Flags::from_bits(flag_byte[0]).ok_or(coset::CoseError::OutOfRangeIntegerValue)?;

        // Both optional sections are read from the same cursor so the extensions start exactly
        // where the attested credential data ended.
        let mut managed_reader = Cursor::new(v);
        let attested_credential_data = flags
            .contains(Flags::AT)
            .then(|| AttestedCredentialData::from_reader(&mut managed_reader))
            .transpose()?;
        let extensions = flags
            .contains(Flags::ED)
            .then(|| ciborium::de::from_reader(&mut managed_reader).map_err(io_error))
            .transpose()?;

        Ok(AuthenticatorData {
            // The slices were split at exactly 32 and 4 bytes, so the conversions cannot fail.
            rp_id_hash: rp_id_hash.try_into().unwrap(),
            flags,
            counter: Some(u32::from_be_bytes(counter.try_into().unwrap())),
            attested_credential_data,
            extensions,
        })
    }

    /// Serializes the authenticator data into its binary form.
    ///
    /// The emitted flag byte always has `AT` set when attested credential data is present and
    /// `ED` set when extensions are present, even if the public fields were assigned without
    /// updating `flags`. Without this, `from_slice` would skip the trailing sections of the
    /// produced bytes. A `None` counter is written as zero.
    ///
    /// # Panics
    ///
    /// Panics if the extensions value cannot be CBOR-encoded, or if the attested credential
    /// data cannot be serialized.
    pub fn to_vec(&self) -> Vec<u8> {
        let mut flags = self.flags;
        if self.attested_credential_data.is_some() {
            flags |= Flags::AT;
        }
        if self.extensions.is_some() {
            flags |= Flags::ED;
        }

        // 37 bytes is the fixed header; optional sections grow the buffer as needed.
        let mut out = Vec::with_capacity(37);
        out.extend_from_slice(&self.rp_id_hash);
        out.push(flags.into());
        out.extend_from_slice(&self.counter.unwrap_or_default().to_be_bytes());

        if let Some(acd) = &self.attested_credential_data {
            // The byte iterator consumes its input, hence the clone.
            out.extend(acd.clone());
        }
        if let Some(extensions) = &self.extensions {
            // Appending to a `Vec` has no I/O failure mode.
            // NOTE(review): assumes every stored `Value` is encodable - confirm.
            ciborium::ser::into_writer(extensions, &mut out).unwrap();
        }
        out
    }
}
/// Attested credential data section: AAGUID, credential ID and COSE key.
///
/// Its byte form (produced through `IntoIterator`) is the 16 AAGUID bytes, a 2-byte big-endian
/// credential ID length, the credential ID itself, and the serialized key.
#[derive(Debug, Clone, PartialEq)]
pub struct AttestedCredentialData {
/// AAGUID, emitted as the first 16 bytes of the section.
pub aaguid: Aaguid,
/// Credential ID. Private so that its length stays representable as a `u16`, which
/// [`AttestedCredentialData::new`] verifies.
credential_id: Vec<u8>,
/// Credential key in COSE form.
pub key: CoseKey,
}
impl AttestedCredentialData {
    /// Creates attested credential data after checking the credential ID length.
    ///
    /// # Errors
    ///
    /// Returns the `u16` conversion error when `credential_id` is longer than `u16::MAX`
    /// bytes, since the length is later encoded in two bytes.
    pub fn new(
        aaguid: Aaguid,
        credential_id: Vec<u8>,
        key: CoseKey,
    ) -> Result<Self, TryFromIntError> {
        u16::try_from(credential_id.len()).map(|_checked_len| Self {
            aaguid,
            credential_id,
            key,
        })
    }

    /// Returns the credential ID bytes.
    pub fn credential_id(&self) -> &[u8] {
        self.credential_id.as_slice()
    }
}
impl AttestedCredentialData {
    /// Reads attested credential data from `reader`.
    ///
    /// Consumes, in order: 16 AAGUID bytes, a 2-byte big-endian credential ID length, that
    /// many credential ID bytes, and one CBOR value that is converted into a [`CoseKey`].
    /// Read and CBOR decode failures are all reported through `io_error`.
    fn from_reader<R: Read>(reader: &mut R) -> coset::Result<Self> {
        let mut aaguid_bytes = [0u8; 16];
        reader.read_exact(&mut aaguid_bytes).map_err(io_error)?;

        let mut len_bytes = [0u8; 2];
        reader.read_exact(&mut len_bytes).map_err(io_error)?;

        let id_len = usize::from(u16::from_be_bytes(len_bytes));
        let mut credential_id = vec![0u8; id_len];
        reader.read_exact(&mut credential_id).map_err(io_error)?;

        let key_value = ciborium::de::from_reader(reader).map_err(io_error)?;
        let key = CoseKey::from_cbor_value(key_value)?;

        Ok(Self {
            aaguid: Aaguid(aaguid_bytes),
            credential_id,
            key,
        })
    }
}
/// Byte-by-byte iterator over the serialized form of an [`AttestedCredentialData`].
///
/// Every section is encoded up front when the iterator is built; `state` records which
/// section, and which offset inside it, the next byte comes from.
pub struct AttestedCredentialDataIterator {
/// The 16 AAGUID bytes.
aaguid: [u8; 16],
/// Length of `credential_id` as a big-endian `u16`.
credential_id_len: [u8; 2],
/// Raw credential ID bytes.
credential_id: Vec<u8>,
/// COSE key, already serialized to bytes.
cose_key: Vec<u8>,
/// Position of the next byte to yield.
state: AttestedCredentialDataIteratorState,
}
/// Cursor of [`AttestedCredentialDataIterator`]: the section being emitted together with the
/// index of the next byte inside that section.
enum AttestedCredentialDataIteratorState {
/// Inside the 16-byte AAGUID.
Aaguid(u8),
/// Inside the 2-byte credential ID length.
CredIdLen(u8),
/// Inside the credential ID.
CredId(u16),
/// Inside the serialized COSE key.
CoseKey(usize),
/// Every byte has been yielded.
Done,
}
impl AttestedCredentialDataIterator {
fn new(data: AttestedCredentialData) -> Self {
let aaguid = data.aaguid.0;
let cred_id_len: [u8; 2] = u16::try_from(data.credential_id.len())
.expect("Credential ID length is guaranteed to fit within 16 bytes by AttestedCredentialData constructors")
.to_be_bytes();
let cose_key = data
.key
.clone()
.to_vec()
.expect("Properly formatted COSE key");
AttestedCredentialDataIterator {
aaguid,
credential_id_len: cred_id_len,
credential_id: data.credential_id,
cose_key,
state: AttestedCredentialDataIteratorState::Aaguid(0),
}
}
}
impl Iterator for AttestedCredentialDataIterator {
    type Item = u8;

    /// Yields the next serialized byte: AAGUID, then the credential ID length, then the
    /// credential ID, then the COSE key. Returns `None` once everything has been emitted.
    ///
    /// A section is left only after a bounds-checked lookup misses, so an empty section
    /// (for example a zero-length credential ID) is skipped instead of underflowing an
    /// index computation and panicking.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // Each arm either returns the byte under the cursor and advances within the
            // section, or, when the section is exhausted, moves to the next section and
            // lets the loop retry there.
            match self.state {
                AttestedCredentialDataIteratorState::Aaguid(x) => {
                    if let Some(&byte) = self.aaguid.get(usize::from(x)) {
                        self.state = AttestedCredentialDataIteratorState::Aaguid(x + 1);
                        return Some(byte);
                    }
                    self.state = AttestedCredentialDataIteratorState::CredIdLen(0);
                }
                AttestedCredentialDataIteratorState::CredIdLen(x) => {
                    if let Some(&byte) = self.credential_id_len.get(usize::from(x)) {
                        self.state = AttestedCredentialDataIteratorState::CredIdLen(x + 1);
                        return Some(byte);
                    }
                    self.state = AttestedCredentialDataIteratorState::CredId(0);
                }
                AttestedCredentialDataIteratorState::CredId(x) => {
                    if let Some(&byte) = self.credential_id.get(usize::from(x)) {
                        // The credential ID length fits in a `u16` (checked when the
                        // iterator is built), so an in-range `x` is below `u16::MAX`.
                        self.state = AttestedCredentialDataIteratorState::CredId(x + 1);
                        return Some(byte);
                    }
                    self.state = AttestedCredentialDataIteratorState::CoseKey(0);
                }
                AttestedCredentialDataIteratorState::CoseKey(x) => {
                    if let Some(&byte) = self.cose_key.get(x) {
                        self.state = AttestedCredentialDataIteratorState::CoseKey(x + 1);
                        return Some(byte);
                    }
                    self.state = AttestedCredentialDataIteratorState::Done;
                }
                AttestedCredentialDataIteratorState::Done => return None,
            }
        }
    }
}
impl IntoIterator for AttestedCredentialData {
type Item = u8;
type IntoIter = AttestedCredentialDataIterator;
fn into_iter(self) -> Self::IntoIter {
AttestedCredentialDataIterator::new(self)
}
}
#[cfg(test)]
mod test;