use std::io::{Read, Write};
use openmls_traits::signatures::Signer;
use serde::{Deserialize, Serialize};
use tls_codec::{
Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait,
Size, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes,
};
#[cfg(test)]
mod tests;
use crate::{ciphersuite::SignaturePublicKey, group::Member, treesync::LeafNode};
use errors::*;
#[cfg(doc)]
use crate::group::MlsGroup;
pub mod errors;
/// Identifies the kind of a [`Credential`].
///
/// On the wire this is a big-endian `u16`. The mapping between wire values and
/// variants is defined by the `From<u16> for CredentialType` and
/// `From<CredentialType> for u16` impls in this file, not by the discriminants
/// declared here.
///
/// NOTE: `PartialOrd`/`Ord` are derived, so variants compare by discriminant
/// first and then by payload. Changing the discriminants or the variant order
/// changes the ordering.
/// NOTE(review): constructing e.g. `Other(1)` by hand encodes as `1`, which
/// decodes back as `Basic`. Callers should only use `Other` for values that
/// `From<u16>` would itself map to `Other`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[repr(u16)]
pub enum CredentialType {
    /// Basic credential; wire value `1`.
    Basic = 1,
    /// X.509 credential; wire value `2`.
    X509 = 2,
    /// A value accepted by `crate::grease::is_grease_value`; holds the raw wire value.
    Grease(u16),
    /// Any other value not covered above; holds the raw wire value.
    Other(u16),
}
impl CredentialType {
pub fn is_grease(&self) -> bool {
matches!(self, CredentialType::Grease(_))
}
}
impl Size for CredentialType {
    /// Every credential type is encoded as a fixed-width `u16`, so the encoded
    /// length does not depend on the variant.
    fn tls_serialized_len(&self) -> usize {
        std::mem::size_of::<u16>()
    }
}
impl TlsDeserializeTrait for CredentialType {
    /// Reads a big-endian `u16` from `bytes` and maps it to a [`CredentialType`].
    ///
    /// Every `u16` maps to some variant (unrecognized values become `Grease` or
    /// `Other`), so this fails only if reading the two bytes fails.
    fn tls_deserialize<R: Read>(bytes: &mut R) -> Result<Self, Error>
    where
        Self: Sized,
    {
        let mut raw = [0u8; 2];
        bytes.read_exact(&mut raw)?;
        let wire_value = u16::from_be_bytes(raw);
        Ok(wire_value.into())
    }
}
impl TlsSerializeTrait for CredentialType {
    /// Writes the credential type as a big-endian `u16` and returns the number
    /// of bytes written (always 2).
    fn tls_serialize<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
        let encoded: [u8; 2] = u16::from(*self).to_be_bytes();
        writer.write_all(&encoded)?;
        Ok(encoded.len())
    }
}
impl DeserializeBytes for CredentialType {
    /// Decodes a [`CredentialType`] from the front of `bytes` and returns it
    /// together with the unconsumed remainder.
    ///
    /// # Errors
    /// Returns an error if `bytes` holds fewer than the two bytes of the tag.
    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error>
    where
        Self: Sized,
    {
        // `Read for &[u8]` advances the slice past the bytes it consumed, so after
        // decoding, `rest` is exactly the remainder. There is no need to recompute
        // the consumed length and index back into `bytes`, which would add a panic
        // path and silently break if the two ever disagreed.
        let mut rest = bytes;
        let credential_type = Self::tls_deserialize(&mut rest)?;
        Ok((credential_type, rest))
    }
}
impl From<u16> for CredentialType {
    /// Maps a wire value to its [`CredentialType`].
    ///
    /// `1` and `2` are the known types. Values accepted by
    /// `crate::grease::is_grease_value` become `Grease`; everything else is
    /// preserved as `Other`.
    fn from(value: u16) -> Self {
        match value {
            1 => Self::Basic,
            2 => Self::X509,
            _ if crate::grease::is_grease_value(value) => Self::Grease(value),
            _ => Self::Other(value),
        }
    }
}
impl From<CredentialType> for u16 {
    /// Returns the wire value of a [`CredentialType`].
    ///
    /// Known types map to their fixed codes; `Grease` and `Other` return the raw
    /// value they carry.
    fn from(value: CredentialType) -> Self {
        match value {
            CredentialType::Basic => 1,
            CredentialType::X509 => 2,
            CredentialType::Grease(raw) | CredentialType::Other(raw) => raw,
        }
    }
}
/// Raw certificate bytes.
///
/// NOTE(review): the encoding of `cert_data` (presumably DER) is not
/// established here — confirm.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct Certificate {
    /// The certificate as an opaque byte string.
    cert_data: Vec<u8>,
}
/// A credential: a [`CredentialType`] tag plus the serialized credential content.
///
/// The derived TLS encoding writes the fields in declaration order: the tag as a
/// `u16`, then the content as variable-length bytes. Reordering the fields
/// therefore changes the wire format. Use [`Credential::deserialized`] to decode
/// the content into a concrete credential type.
#[derive(
    Debug,
    PartialEq,
    Eq,
    Clone,
    Serialize,
    Deserialize,
    TlsSize,
    TlsSerialize,
    TlsDeserialize,
    TlsDeserializeBytes,
)]
pub struct Credential {
    /// Which kind of credential `serialized_credential_content` holds.
    credential_type: CredentialType,
    /// The already-serialized credential body, encoded as variable-length bytes.
    serialized_credential_content: VLBytes,
}
impl Credential {
pub fn credential_type(&self) -> CredentialType {
self.credential_type
}
pub fn new(credential_type: CredentialType, serialized_credential: Vec<u8>) -> Self {
Self {
credential_type,
serialized_credential_content: serialized_credential.into(),
}
}
pub fn serialized_content(&self) -> &[u8] {
self.serialized_credential_content.as_slice()
}
pub fn deserialized<T: tls_codec::Size + tls_codec::Deserialize>(
&self,
) -> Result<T, tls_codec::Error> {
T::tls_deserialize_exact(&self.serialized_credential_content)
}
}
/// A basic credential, consisting only of an opaque identity.
///
/// Convert it into a [`Credential`] with `From`. Convert back with `TryFrom`,
/// which fails unless the credential type is [`CredentialType::Basic`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BasicCredential {
    /// The identity bytes.
    identity: VLBytes,
}
impl BasicCredential {
pub fn new(identity: Vec<u8>) -> Self {
Self {
identity: identity.into(),
}
}
pub fn identity(&self) -> &[u8] {
self.identity.as_slice()
}
}
impl From<BasicCredential> for Credential {
fn from(credential: BasicCredential) -> Self {
Credential {
credential_type: CredentialType::Basic,
serialized_credential_content: credential.identity,
}
}
}
impl TryFrom<Credential> for BasicCredential {
type Error = BasicCredentialError;
fn try_from(credential: Credential) -> Result<Self, Self::Error> {
match credential.credential_type {
CredentialType::Basic => Ok(BasicCredential::new(
credential.serialized_credential_content.into(),
)),
_ => Err(errors::BasicCredentialError::WrongCredentialType),
}
}
}
/// Bundles a borrowed signer with a [`CredentialWithKey`].
///
/// NOTE(review): `derive(Debug, Clone)` adds `S: Debug` / `S: Clone` bounds to
/// the generated impls, even though only `&S` is stored.
#[derive(Debug, Clone)]
pub struct NewSignerBundle<'a, S: Signer> {
    /// The signer, borrowed for `'a`.
    pub signer: &'a S,
    /// The credential and signature public key to use with `signer`
    /// (presumably they must match the signer — callers should ensure this).
    pub credential_with_key: CredentialWithKey,
}
/// A [`Credential`] together with a signature public key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CredentialWithKey {
    /// The credential.
    pub credential: Credential,
    /// The signature public key paired with `credential`.
    pub signature_key: SignaturePublicKey,
}
impl From<&LeafNode> for CredentialWithKey {
    /// Clones the credential and signature public key out of a leaf node.
    fn from(leaf_node: &LeafNode) -> Self {
        let credential = leaf_node.credential().clone();
        let signature_key = leaf_node.signature_key().clone();
        Self {
            credential,
            signature_key,
        }
    }
}
impl From<&Member> for CredentialWithKey {
fn from(member: &Member) -> Self {
Self {
credential: member.credential.clone(),
signature_key: member.signature_key.clone().into(),
}
}
}
#[cfg(test)]
impl CredentialWithKey {
    /// Test helper: pairs `credential` with a signature public key built from
    /// the raw `key` bytes.
    pub fn from_parts(credential: Credential, key: &[u8]) -> Self {
        let signature_key: SignaturePublicKey = key.into();
        Self {
            credential,
            signature_key,
        }
    }
}
#[cfg(any(test, feature = "test-utils"))]
pub mod test_utils {
    use openmls_basic_credential::SignatureKeyPair;
    use openmls_traits::{types::SignatureScheme, OpenMlsProvider};

    use super::{BasicCredential, CredentialWithKey};

    /// Creates a basic credential for `identity` and a fresh signature key pair
    /// for `signature_scheme`, stores the key pair in `provider`'s storage, and
    /// returns both.
    ///
    /// # Panics
    /// Panics if key-pair generation or storing the key pair fails.
    pub fn new_credential(
        provider: &impl OpenMlsProvider,
        identity: &[u8],
        signature_scheme: SignatureScheme,
    ) -> (CredentialWithKey, SignatureKeyPair) {
        let signature_keys = SignatureKeyPair::new(signature_scheme).unwrap();
        signature_keys.store(provider.storage()).unwrap();

        let credential_with_key = CredentialWithKey {
            credential: BasicCredential::new(identity.to_vec()).into(),
            signature_key: signature_keys.public().into(),
        };
        (credential_with_key, signature_keys)
    }
}
#[cfg(test)]
mod unit_tests {
    use tls_codec::{
        DeserializeBytes, Serialize, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize,
    };

    use super::{BasicCredential, Credential, CredentialType};

    #[test]
    fn basic_credential_identity_and_codec() {
        const IDENTITY: &str = "identity";
        let basic_credential = BasicCredential::new(IDENTITY.into());
        assert_eq!(basic_credential.identity(), IDENTITY.as_bytes());
        let credential = Credential::from(basic_credential.clone());
        let serialized = credential.tls_serialize_detached().unwrap();
        let deserialized = Credential::tls_deserialize_exact_bytes(&serialized).unwrap();
        assert_eq!(credential.credential_type(), deserialized.credential_type());
        assert_eq!(
            credential.serialized_content(),
            deserialized.serialized_content()
        );
        let deserialized_basic_credential = BasicCredential::try_from(deserialized).unwrap();
        assert_eq!(
            deserialized_basic_credential.identity(),
            IDENTITY.as_bytes()
        );
        assert_eq!(basic_credential, deserialized_basic_credential);
    }

    /// Exercises the hand-written `CredentialType` codec. It checks that u16
    /// conversions round-trip, that the TLS encoding is the big-endian u16, that
    /// `tls_deserialize_bytes` returns exactly the bytes following the 2-byte
    /// tag, and that input shorter than the tag is rejected.
    #[test]
    fn credential_type_codec() {
        const TRAILER: [u8; 2] = [0xAA, 0xBB];
        for credential_type in [
            CredentialType::Basic,
            CredentialType::X509,
            CredentialType::Other(1234),
        ] {
            let wire = u16::from(credential_type);
            assert_eq!(CredentialType::from(wire), credential_type);

            let mut encoded = credential_type.tls_serialize_detached().unwrap();
            assert_eq!(encoded, wire.to_be_bytes());

            encoded.extend_from_slice(&TRAILER);
            let (decoded, remainder) = CredentialType::tls_deserialize_bytes(&encoded).unwrap();
            assert_eq!(decoded, credential_type);
            assert_eq!(remainder, &TRAILER[..]);
        }
        assert!(CredentialType::tls_deserialize_bytes(&[0x00]).is_err());
    }

    #[test]
    fn custom_credential() {
        #[derive(
            Debug, Clone, PartialEq, Eq, TlsSize, TlsSerialize, TlsDeserialize, TlsDeserializeBytes,
        )]
        struct CustomCredential {
            custom_field1: u32,
            custom_field2: Vec<u8>,
            custom_field3: Option<u8>,
        }
        let custom_credential = CustomCredential {
            custom_field1: 42,
            custom_field2: vec![1, 2, 3],
            custom_field3: Some(2),
        };
        let credential = Credential::new(
            CredentialType::Other(1234),
            custom_credential.tls_serialize_detached().unwrap(),
        );
        let serialized = credential.tls_serialize_detached().unwrap();
        let deserialized = Credential::tls_deserialize_exact_bytes(&serialized).unwrap();
        assert_eq!(credential, deserialized);
        let deserialized_custom_credential =
            CustomCredential::tls_deserialize_exact_bytes(deserialized.serialized_content())
                .unwrap();
        assert_eq!(custom_credential, deserialized_custom_credential);
    }
}