use core::{
fmt::{self, Debug},
ops::Deref,
};
use alloc::vec::Vec;
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
use super::BasicCredential;
#[cfg(feature = "x509")]
use super::CertificateChain;
/// Identifier for the kind of credential carried by a [`Credential`].
///
/// A thin newtype over a raw `u16`. Any `u16` value is representable (see
/// [`CredentialType::new`] and the `From<u16>` impl), not only the named
/// constants defined on this type.
///
/// - MLS size/encode/decode are derived, so the wire form is that of the inner
///   `u16`.
/// - `#[repr(transparent)]` gives it the same in-memory layout as `u16`.
/// - The derived `PartialOrd`/`Ord` order values by their raw numeric value.
#[derive(
Debug, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode,
)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(transparent)]
pub struct CredentialType(u16);
#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
impl CredentialType {
pub const BASIC: CredentialType = CredentialType(1);
#[cfg(feature = "x509")]
pub const X509: CredentialType = CredentialType(2);
pub const fn new(raw_value: u16) -> Self {
CredentialType(raw_value)
}
pub const fn raw_value(&self) -> u16 {
self.0
}
}
impl From<u16> for CredentialType {
fn from(value: u16) -> Self {
CredentialType(value)
}
}
impl Deref for CredentialType {
type Target = u16;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// A credential of an application-chosen type, carried as an opaque byte
/// payload.
///
/// `Credential::mls_decode` produces this for any credential type value it
/// does not recognise. Nothing here prevents `credential_type` from naming a
/// built-in type.
///
/// NOTE(review): field order matters. The derived `PartialOrd`/`Ord` compare
/// `credential_type` first and then `data`. The derived MLS codec presumably
/// writes the fields in the same declaration order.
#[derive(Clone, MlsSize, MlsEncode, MlsDecode, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(
all(feature = "ffi", not(test)),
safer_ffi_gen::ffi_type(clone, opaque)
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CustomCredential {
/// Type identifier reported by `Credential::credential_type` for this credential.
pub credential_type: CredentialType,
/// Opaque credential payload, encoded via `mls_rs_codec::byte_vec` (and
/// `crate::vec_serde` when the `serde` feature is enabled).
#[mls_codec(with = "mls_rs_codec::byte_vec")]
#[cfg_attr(feature = "serde", serde(with = "crate::vec_serde"))]
pub data: Vec<u8>,
}
impl Debug for CustomCredential {
    /// Formats as `CustomCredential { credential_type, data }`.
    ///
    /// The payload is rendered through `crate::debug::pretty_bytes` instead of
    /// the default `Vec<u8>` formatting.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            credential_type,
            data,
        } = self;

        f.debug_struct("CustomCredential")
            .field("credential_type", credential_type)
            .field("data", &crate::debug::pretty_bytes(data))
            .finish()
    }
}
#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
impl CustomCredential {
pub fn new(credential_type: CredentialType, data: Vec<u8>) -> CustomCredential {
CustomCredential {
credential_type,
data,
}
}
#[cfg(feature = "ffi")]
pub fn credential_type(&self) -> CredentialType {
self.credential_type
}
#[cfg(feature = "ffi")]
pub fn data(&self) -> &[u8] {
&self.data
}
}
/// A member credential: one of the built-in kinds or an opaque custom one.
///
/// The MLS codec for this type is implemented by hand rather than derived.
/// The encoding is the value of `credential_type()` followed by the active
/// variant's body, so a `Custom` credential is tagged with its own stored type.
///
/// Variant order matters: the derived `PartialOrd`/`Ord` compare variants in
/// declaration order. `#[non_exhaustive]` means code outside this crate must
/// include a wildcard arm when matching.
#[derive(Clone, Debug, PartialEq, Ord, PartialOrd, Eq, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(
all(feature = "ffi", not(test)),
safer_ffi_gen::ffi_type(clone, opaque)
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Credential {
/// Basic credential, tagged with `CredentialType::BASIC`.
Basic(BasicCredential),
/// Certificate chain, tagged with `CredentialType::X509`.
/// Only present with the `x509` feature.
#[cfg(feature = "x509")]
X509(CertificateChain),
/// Any other credential type, kept as an opaque payload.
Custom(CustomCredential),
}
impl Credential {
    /// Returns the type tag this credential is encoded with.
    ///
    /// Built-in variants map to their fixed constants. `Custom` reports
    /// whatever type it stores.
    pub fn credential_type(&self) -> CredentialType {
        match self {
            Self::Basic(_) => CredentialType::BASIC,
            #[cfg(feature = "x509")]
            Self::X509(_) => CredentialType::X509,
            Self::Custom(custom) => custom.credential_type,
        }
    }

    /// Borrows the inner [`BasicCredential`] if this is a `Basic` credential.
    pub fn as_basic(&self) -> Option<&BasicCredential> {
        if let Self::Basic(basic) = self {
            Some(basic)
        } else {
            None
        }
    }

    /// Borrows the inner certificate chain if this is an `X509` credential.
    #[cfg(feature = "x509")]
    pub fn as_x509(&self) -> Option<&CertificateChain> {
        if let Self::X509(chain) = self {
            Some(chain)
        } else {
            None
        }
    }

    /// Borrows the inner [`CustomCredential`] if this is a `Custom` credential.
    pub fn as_custom(&self) -> Option<&CustomCredential> {
        if let Self::Custom(custom) = self {
            Some(custom)
        } else {
            None
        }
    }
}
impl MlsSize for Credential {
    /// Encoded length: the `CredentialType` tag plus the active variant's body.
    ///
    /// This mirrors `mls_encode`. `Custom` payloads are sized with
    /// `mls_rs_codec::byte_vec`, matching how they are written and read.
    fn mls_encoded_len(&self) -> usize {
        let tag_len = self.credential_type().mls_encoded_len();

        let body_len = match self {
            Self::Basic(basic) => basic.mls_encoded_len(),
            #[cfg(feature = "x509")]
            Self::X509(chain) => chain.mls_encoded_len(),
            Self::Custom(custom) => mls_rs_codec::byte_vec::mls_encoded_len(&custom.data),
        };

        tag_len + body_len
    }
}
impl MlsEncode for Credential {
    /// Appends the credential to `writer` as its `CredentialType` tag followed
    /// by the variant body.
    ///
    /// `Custom` payloads are written with `mls_rs_codec::byte_vec`. The first
    /// encoding error from either step is returned unchanged.
    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
        self.credential_type().mls_encode(writer)?;

        match self {
            Self::Basic(basic) => basic.mls_encode(writer)?,
            #[cfg(feature = "x509")]
            Self::X509(chain) => chain.mls_encode(writer)?,
            Self::Custom(custom) => mls_rs_codec::byte_vec::mls_encode(&custom.data, writer)?,
        }

        Ok(())
    }
}
impl MlsDecode for Credential {
    /// Reads a `CredentialType` tag from `reader`, then the body selected by it.
    ///
    /// `BASIC` (and `X509` when the `x509` feature is on) decode to their
    /// built-in variants. Every other tag value becomes a `Custom` credential
    /// whose payload is read with `mls_rs_codec::byte_vec`.
    fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
        let credential_type = CredentialType::mls_decode(reader)?;

        let credential = match credential_type {
            CredentialType::BASIC => Self::Basic(BasicCredential::mls_decode(reader)?),
            #[cfg(feature = "x509")]
            CredentialType::X509 => Self::X509(CertificateChain::mls_decode(reader)?),
            other => {
                let data = mls_rs_codec::byte_vec::mls_decode(reader)?;
                Self::Custom(CustomCredential {
                    credential_type: other,
                    data,
                })
            }
        };

        Ok(credential)
    }
}
/// A value that can be turned into a [`Credential`].
pub trait MlsCredential: Sized {
    /// Error returned when [`MlsCredential::into_credential`] fails.
    type Error;

    /// Consumes `self` and produces the corresponding [`Credential`].
    fn into_credential(self) -> Result<Credential, Self::Error>;

    /// The [`CredentialType`] associated with implementors of this trait.
    fn credential_type() -> CredentialType;
}