1use base64ct::{Base64UrlUnpadded, Encoding};
4use serde::{de::DeserializeOwned, Serialize};
5
6use core::{marker::PhantomData, num::NonZeroUsize};
7
8#[cfg(feature = "ciborium")]
9use crate::error::CborSerError;
10use crate::{
11 alloc::{Cow, String, ToOwned, Vec},
12 token::CompleteHeader,
13 Claims, CreationError, Header, SignedToken, Token, UntrustedToken, ValidationError,
14};
15
16pub trait AlgorithmSignature: Sized {
21 const LENGTH: Option<NonZeroUsize> = None;
33
34 fn try_from_slice(slice: &[u8]) -> anyhow::Result<Self>;
37
38 fn as_bytes(&self) -> Cow<'_, [u8]>;
40}
41
42pub trait Algorithm {
44 type SigningKey;
46 type VerifyingKey;
49 type Signature: AlgorithmSignature;
51
52 fn name(&self) -> Cow<'static, str>;
54
55 fn sign(&self, signing_key: &Self::SigningKey, message: &[u8]) -> Self::Signature;
57
58 fn verify_signature(
60 &self,
61 signature: &Self::Signature,
62 verifying_key: &Self::VerifyingKey,
63 message: &[u8],
64 ) -> bool;
65}
66
/// Wrapper around an [`Algorithm`] that reports a custom name instead of the
/// wrapped algorithm's own name.
///
/// Signing and signature verification are delegated to the wrapped algorithm
/// unchanged; only [`Algorithm::name()`] differs. As a result, the custom name is
/// the one written to created tokens and expected in validated ones.
#[derive(Debug, Clone, Copy)]
pub struct Renamed<A> {
    /// Wrapped algorithm that performs the actual cryptographic work.
    inner: A,
    /// Name reported in place of the wrapped algorithm's name.
    name: &'static str,
}
95
96impl<A: Algorithm> Renamed<A> {
97 pub fn new(algorithm: A, new_name: &'static str) -> Self {
99 Self {
100 inner: algorithm,
101 name: new_name,
102 }
103 }
104}
105
106impl<A: Algorithm> Algorithm for Renamed<A> {
107 type SigningKey = A::SigningKey;
108 type VerifyingKey = A::VerifyingKey;
109 type Signature = A::Signature;
110
111 fn name(&self) -> Cow<'static, str> {
112 Cow::Borrowed(self.name)
113 }
114
115 fn sign(&self, signing_key: &Self::SigningKey, message: &[u8]) -> Self::Signature {
116 self.inner.sign(signing_key, message)
117 }
118
119 fn verify_signature(
120 &self,
121 signature: &Self::Signature,
122 verifying_key: &Self::VerifyingKey,
123 message: &[u8],
124 ) -> bool {
125 self.inner
126 .verify_signature(signature, verifying_key, message)
127 }
128}
129
130pub trait AlgorithmExt: Algorithm {
132 fn token<T>(
134 &self,
135 header: &Header<impl Serialize>,
136 claims: &Claims<T>,
137 signing_key: &Self::SigningKey,
138 ) -> Result<String, CreationError>
139 where
140 T: Serialize;
141
142 #[cfg(feature = "ciborium")]
144 #[cfg_attr(docsrs, doc(cfg(feature = "ciborium")))]
145 fn compact_token<T>(
146 &self,
147 header: &Header<impl Serialize>,
148 claims: &Claims<T>,
149 signing_key: &Self::SigningKey,
150 ) -> Result<String, CreationError>
151 where
152 T: Serialize;
153
154 fn validator<'a, T>(&'a self, verifying_key: &'a Self::VerifyingKey) -> Validator<'a, Self, T>;
157
158 #[deprecated = "Use `.validator().validate()` for added flexibility"]
160 fn validate_integrity<T>(
161 &self,
162 token: &UntrustedToken<'_>,
163 verifying_key: &Self::VerifyingKey,
164 ) -> Result<Token<T>, ValidationError>
165 where
166 T: DeserializeOwned;
167
168 #[deprecated = "Use `.validator().validate_for_signed_token()` for added flexibility"]
173 fn validate_for_signed_token<T>(
174 &self,
175 token: &UntrustedToken<'_>,
176 verifying_key: &Self::VerifyingKey,
177 ) -> Result<SignedToken<Self, T>, ValidationError>
178 where
179 T: DeserializeOwned;
180}
181
182impl<A: Algorithm> AlgorithmExt for A {
183 fn token<T>(
184 &self,
185 header: &Header<impl Serialize>,
186 claims: &Claims<T>,
187 signing_key: &Self::SigningKey,
188 ) -> Result<String, CreationError>
189 where
190 T: Serialize,
191 {
192 let complete_header = CompleteHeader {
193 algorithm: self.name(),
194 content_type: None,
195 inner: header,
196 };
197 let header = serde_json::to_string(&complete_header).map_err(CreationError::Header)?;
198 let mut buffer = Vec::new();
199 encode_base64_buf(&header, &mut buffer);
200
201 let claims = serde_json::to_string(claims).map_err(CreationError::Claims)?;
202 buffer.push(b'.');
203 encode_base64_buf(&claims, &mut buffer);
204
205 let signature = self.sign(signing_key, &buffer);
206 buffer.push(b'.');
207 encode_base64_buf(signature.as_bytes(), &mut buffer);
208
209 Ok(unsafe { String::from_utf8_unchecked(buffer) })
211 }
212
213 #[cfg(feature = "ciborium")]
214 fn compact_token<T>(
215 &self,
216 header: &Header<impl Serialize>,
217 claims: &Claims<T>,
218 signing_key: &Self::SigningKey,
219 ) -> Result<String, CreationError>
220 where
221 T: Serialize,
222 {
223 let complete_header = CompleteHeader {
224 algorithm: self.name(),
225 content_type: Some("CBOR".to_owned()),
226 inner: header,
227 };
228 let header = serde_json::to_string(&complete_header).map_err(CreationError::Header)?;
229 let mut buffer = Vec::new();
230 encode_base64_buf(&header, &mut buffer);
231
232 let mut serialized_claims = vec![];
233 ciborium::into_writer(claims, &mut serialized_claims).map_err(|err| {
234 CreationError::CborClaims(match err {
235 CborSerError::Value(message) => CborSerError::Value(message),
236 CborSerError::Io(_) => unreachable!(), })
238 })?;
239 buffer.push(b'.');
240 encode_base64_buf(&serialized_claims, &mut buffer);
241
242 let signature = self.sign(signing_key, &buffer);
243 buffer.push(b'.');
244 encode_base64_buf(signature.as_bytes(), &mut buffer);
245
246 Ok(unsafe { String::from_utf8_unchecked(buffer) })
248 }
249
250 fn validator<'a, T>(&'a self, verifying_key: &'a Self::VerifyingKey) -> Validator<'a, Self, T> {
251 Validator {
252 algorithm: self,
253 verifying_key,
254 _claims: PhantomData,
255 }
256 }
257
258 fn validate_integrity<T>(
259 &self,
260 token: &UntrustedToken<'_>,
261 verifying_key: &Self::VerifyingKey,
262 ) -> Result<Token<T>, ValidationError>
263 where
264 T: DeserializeOwned,
265 {
266 self.validator::<T>(verifying_key).validate(token)
267 }
268
269 fn validate_for_signed_token<T>(
270 &self,
271 token: &UntrustedToken<'_>,
272 verifying_key: &Self::VerifyingKey,
273 ) -> Result<SignedToken<Self, T>, ValidationError>
274 where
275 T: DeserializeOwned,
276 {
277 self.validator::<T>(verifying_key)
278 .validate_for_signed_token(token)
279 }
280}
281
/// Validator for tokens signed with algorithm `A`, bound to a specific verifying key
/// and a specific claims type `T`.
///
/// Obtained from [`AlgorithmExt::validator()`]. It holds only references (plus a
/// zero-sized marker), so it is `Copy` and is consumed by value by its validation methods.
#[derive(Debug)]
pub struct Validator<'a, A: Algorithm + ?Sized, T> {
    /// Algorithm whose name must match the token's and which verifies the signature.
    algorithm: &'a A,
    /// Key passed to [`Algorithm::verify_signature()`].
    verifying_key: &'a A::VerifyingKey,
    // `fn() -> T` (rather than plain `T`) marks the claims type without the validator
    // owning a `T`, so auto traits such as `Send`/`Sync` do not depend on `T`.
    _claims: PhantomData<fn() -> T>,
}

// `Clone`/`Copy` are implemented by hand because `#[derive]` would add `A: Clone`
// and `T: Clone` bounds. Only references and a `PhantomData` are stored, so no such
// bounds are needed.
impl<A: Algorithm + ?Sized, T> Clone for Validator<'_, A, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<A: Algorithm + ?Sized, T> Copy for Validator<'_, A, T> {}
298
299impl<A: Algorithm + ?Sized, T: DeserializeOwned> Validator<'_, A, T> {
300 pub fn validate<H: Clone>(
302 self,
303 token: &UntrustedToken<'_, H>,
304 ) -> Result<Token<T, H>, ValidationError> {
305 self.validate_for_signed_token(token)
306 .map(|signed| signed.token)
307 }
308
309 pub fn validate_for_signed_token<H: Clone>(
312 self,
313 token: &UntrustedToken<'_, H>,
314 ) -> Result<SignedToken<A, T, H>, ValidationError> {
315 let expected_alg = self.algorithm.name();
316 if expected_alg != token.algorithm() {
317 return Err(ValidationError::AlgorithmMismatch {
318 expected: expected_alg.into_owned(),
319 actual: token.algorithm().to_owned(),
320 });
321 }
322
323 let signature = token.signature_bytes();
324 if let Some(expected_len) = A::Signature::LENGTH {
325 if signature.len() != expected_len.get() {
326 return Err(ValidationError::InvalidSignatureLen {
327 expected: expected_len.get(),
328 actual: signature.len(),
329 });
330 }
331 }
332
333 let signature =
334 A::Signature::try_from_slice(signature).map_err(ValidationError::MalformedSignature)?;
335 let claims = token.deserialize_claims_unchecked::<T>()?;
338 if !self
339 .algorithm
340 .verify_signature(&signature, self.verifying_key, &token.signed_data)
341 {
342 return Err(ValidationError::InvalidSignature);
343 }
344
345 Ok(SignedToken {
346 signature,
347 token: Token::new(token.header().clone(), claims),
348 })
349 }
350}
351
352fn encode_base64_buf(source: impl AsRef<[u8]>, buffer: &mut Vec<u8>) {
353 let source = source.as_ref();
354 let previous_len = buffer.len();
355 let claims_len = Base64UrlUnpadded::encoded_len(source);
356 buffer.resize(previous_len + claims_len, 0);
357 Base64UrlUnpadded::encode(source, &mut buffer[previous_len..])
358 .expect("miscalculated base64-encoded length; this should never happen");
359}