1use core::{
6 fmt::{self, Debug},
7 ops::Deref,
8};
9
10use alloc::vec::Vec;
11use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
12
13use super::BasicCredential;
14
15#[cfg(feature = "x509")]
16use super::CertificateChain;
17
18#[derive(
21 Debug, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode,
22)]
23#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
24#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type)]
25#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
26#[repr(transparent)]
27pub struct CredentialType(u16);
28
29#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
30impl CredentialType {
31 pub const BASIC: CredentialType = CredentialType(1);
33
34 #[cfg(feature = "x509")]
35 pub const X509: CredentialType = CredentialType(2);
37
38 pub const fn new(raw_value: u16) -> Self {
39 CredentialType(raw_value)
40 }
41
42 pub const fn raw_value(&self) -> u16 {
43 self.0
44 }
45}
46
47impl From<u16> for CredentialType {
48 fn from(value: u16) -> Self {
49 CredentialType(value)
50 }
51}
52
53impl Deref for CredentialType {
54 type Target = u16;
55
56 fn deref(&self) -> &Self::Target {
57 &self.0
58 }
59}
60
61#[derive(Clone, MlsSize, MlsEncode, MlsDecode, PartialEq, Eq, Hash, PartialOrd, Ord)]
62#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
63#[cfg_attr(
64 all(feature = "ffi", not(test)),
65 safer_ffi_gen::ffi_type(clone, opaque)
66)]
67#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
68pub struct CustomCredential {
76 pub credential_type: CredentialType,
78 #[mls_codec(with = "mls_rs_codec::byte_vec")]
80 #[cfg_attr(feature = "serde", serde(with = "crate::vec_serde"))]
81 pub data: Vec<u8>,
82}
83
84impl Debug for CustomCredential {
85 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86 f.debug_struct("CustomCredential")
87 .field("credential_type", &self.credential_type)
88 .field("data", &crate::debug::pretty_bytes(&self.data))
89 .finish()
90 }
91}
92
93#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
94impl CustomCredential {
95 pub fn new(credential_type: CredentialType, data: Vec<u8>) -> CustomCredential {
102 CustomCredential {
103 credential_type,
104 data,
105 }
106 }
107
108 #[cfg(feature = "ffi")]
110 pub fn credential_type(&self) -> CredentialType {
111 self.credential_type
112 }
113
114 #[cfg(feature = "ffi")]
116 pub fn data(&self) -> &[u8] {
117 &self.data
118 }
119}
120
121#[derive(Clone, Debug, PartialEq, Ord, PartialOrd, Eq, Hash)]
123#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
124#[cfg_attr(
125 all(feature = "ffi", not(test)),
126 safer_ffi_gen::ffi_type(clone, opaque)
127)]
128#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
129#[non_exhaustive]
130pub enum Credential {
131 Basic(BasicCredential),
139 #[cfg(feature = "x509")]
140 X509(CertificateChain),
142 Custom(CustomCredential),
144}
145
146impl Credential {
147 pub fn credential_type(&self) -> CredentialType {
149 match self {
150 Credential::Basic(_) => CredentialType::BASIC,
151 #[cfg(feature = "x509")]
152 Credential::X509(_) => CredentialType::X509,
153 Credential::Custom(c) => c.credential_type,
154 }
155 }
156
157 pub fn as_basic(&self) -> Option<&BasicCredential> {
161 match self {
162 Credential::Basic(basic) => Some(basic),
163 _ => None,
164 }
165 }
166
167 #[cfg(feature = "x509")]
171 pub fn as_x509(&self) -> Option<&CertificateChain> {
172 match self {
173 Credential::X509(chain) => Some(chain),
174 _ => None,
175 }
176 }
177
178 pub fn as_custom(&self) -> Option<&CustomCredential> {
182 match self {
183 Credential::Custom(custom) => Some(custom),
184 _ => None,
185 }
186 }
187}
188
189impl MlsSize for Credential {
190 fn mls_encoded_len(&self) -> usize {
191 let inner_len = match self {
192 Credential::Basic(c) => c.mls_encoded_len(),
193 #[cfg(feature = "x509")]
194 Credential::X509(c) => c.mls_encoded_len(),
195 Credential::Custom(c) => mls_rs_codec::byte_vec::mls_encoded_len(&c.data),
196 };
197
198 self.credential_type().mls_encoded_len() + inner_len
199 }
200}
201
202impl MlsEncode for Credential {
203 fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
204 self.credential_type().mls_encode(writer)?;
205
206 match self {
207 Credential::Basic(c) => c.mls_encode(writer),
208 #[cfg(feature = "x509")]
209 Credential::X509(c) => c.mls_encode(writer),
210 Credential::Custom(c) => mls_rs_codec::byte_vec::mls_encode(&c.data, writer),
211 }
212 }
213}
214
215impl MlsDecode for Credential {
216 fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
217 let credential_type = CredentialType::mls_decode(reader)?;
218
219 Ok(match credential_type {
220 CredentialType::BASIC => Credential::Basic(BasicCredential::mls_decode(reader)?),
221 #[cfg(feature = "x509")]
222 CredentialType::X509 => Credential::X509(CertificateChain::mls_decode(reader)?),
223 custom => Credential::Custom(CustomCredential {
224 credential_type: custom,
225 data: mls_rs_codec::byte_vec::mls_decode(reader)?,
226 }),
227 })
228 }
229}
230
231pub trait MlsCredential: Sized {
234 type Error;
236
237 fn credential_type() -> CredentialType;
239
240 fn into_credential(self) -> Result<Credential, Self::Error>;
242}