1use base64ct::{Base64UrlUnpadded, Encoding};
4use parity_scale_codec::Encode;
5use scale_info::TypeInfo;
6use serde::{de::DeserializeOwned, Serialize};
7
8use core::{marker::PhantomData, num::NonZeroUsize};
9
10#[cfg(feature = "ciborium")]
11use crate::error::CborSerError;
12use crate::{
13 alloc::{Cow, String, ToOwned, Vec},
14 token::CompleteHeader,
15 Claims, CreationError, Header, SignedToken, Token, UntrustedToken, ValidationError,
16};
17
18pub trait AlgorithmSignature: Sized {
23 const LENGTH: Option<NonZeroUsize> = None;
35
36 fn try_from_slice(slice: &[u8]) -> anyhow::Result<Self>;
39
40 fn as_bytes(&self) -> Cow<'_, [u8]>;
42}
43
44pub trait Algorithm {
46 type SigningKey;
48 type VerifyingKey;
51 type Signature: AlgorithmSignature;
53
54 fn name(&self) -> Cow<'static, str>;
56
57 fn sign(&self, signing_key: &Self::SigningKey, message: &[u8]) -> Self::Signature;
59
60 fn verify_signature(&self, signature: &Self::Signature, verifying_key: &Self::VerifyingKey, message: &[u8])
62 -> bool;
63}
64
65#[derive(Debug, Clone, Copy, TypeInfo, Encode)]
89pub struct Renamed<A> {
90 inner: A,
91 name: &'static str,
92}
93
94impl<A: Algorithm> Renamed<A> {
95 pub const fn new(algorithm: A, new_name: &'static str) -> Self {
97 Self { inner: algorithm, name: new_name }
98 }
99}
100
101impl<A: Algorithm> Algorithm for Renamed<A> {
102 type Signature = A::Signature;
103 type SigningKey = A::SigningKey;
104 type VerifyingKey = A::VerifyingKey;
105
106 fn name(&self) -> Cow<'static, str> {
107 Cow::Borrowed(self.name)
108 }
109
110 fn sign(&self, signing_key: &Self::SigningKey, message: &[u8]) -> Self::Signature {
111 self.inner.sign(signing_key, message)
112 }
113
114 fn verify_signature(
115 &self,
116 signature: &Self::Signature,
117 verifying_key: &Self::VerifyingKey,
118 message: &[u8],
119 ) -> bool {
120 self.inner.verify_signature(signature, verifying_key, message)
121 }
122}
123
124pub trait AlgorithmExt: Algorithm {
126 fn token<T>(
128 &self,
129 header: &Header<impl Serialize>,
130 claims: &Claims<T>,
131 signing_key: &Self::SigningKey,
132 ) -> Result<String, CreationError>
133 where
134 T: Serialize;
135
136 #[cfg(feature = "ciborium")]
138 #[cfg_attr(docsrs, doc(cfg(feature = "ciborium")))]
139 fn compact_token<T>(
140 &self,
141 header: &Header<impl Serialize>,
142 claims: &Claims<T>,
143 signing_key: &Self::SigningKey,
144 ) -> Result<String, CreationError>
145 where
146 T: Serialize;
147
148 fn validator<'a, T>(&'a self, verifying_key: &'a Self::VerifyingKey) -> Validator<'a, Self, T>;
151
152 #[deprecated = "Use `.validator().validate()` for added flexibility"]
154 fn validate_integrity<T>(
155 &self,
156 token: &UntrustedToken,
157 verifying_key: &Self::VerifyingKey,
158 ) -> Result<Token<T>, ValidationError>
159 where
160 T: DeserializeOwned;
161
162 #[deprecated = "Use `.validator().validate_for_signed_token()` for added flexibility"]
167 fn validate_for_signed_token<T>(
168 &self,
169 token: &UntrustedToken,
170 verifying_key: &Self::VerifyingKey,
171 ) -> Result<SignedToken<Self, T>, ValidationError>
172 where
173 T: DeserializeOwned;
174}
175
176impl<A: Algorithm> AlgorithmExt for A {
177 fn token<T>(
178 &self,
179 header: &Header<impl Serialize>,
180 claims: &Claims<T>,
181 signing_key: &Self::SigningKey,
182 ) -> Result<String, CreationError>
183 where
184 T: Serialize,
185 {
186 let complete_header = CompleteHeader { algorithm: self.name(), content_type: None, inner: header };
187 let header = serde_json::to_string(&complete_header).map_err(CreationError::Header)?;
188 let mut buffer = Vec::new();
189 encode_base64_buf(&header, &mut buffer);
190
191 let claims = serde_json::to_string(claims).map_err(CreationError::Claims)?;
192 buffer.push(b'.');
193 encode_base64_buf(&claims, &mut buffer);
194
195 let signature = self.sign(signing_key, &buffer);
196 buffer.push(b'.');
197 encode_base64_buf(signature.as_bytes(), &mut buffer);
198
199 Ok(unsafe { String::from_utf8_unchecked(buffer) })
201 }
202
203 #[cfg(feature = "ciborium")]
204 fn compact_token<T>(
205 &self,
206 header: &Header<impl Serialize>,
207 claims: &Claims<T>,
208 signing_key: &Self::SigningKey,
209 ) -> Result<String, CreationError>
210 where
211 T: Serialize,
212 {
213 let complete_header =
214 CompleteHeader { algorithm: self.name(), content_type: Some("CBOR".to_owned()), inner: header };
215 let header = serde_json::to_string(&complete_header).map_err(CreationError::Header)?;
216 let mut buffer = Vec::new();
217 encode_base64_buf(&header, &mut buffer);
218
219 let mut serialized_claims = vec![];
220 ciborium::into_writer(claims, &mut serialized_claims).map_err(|err| {
221 CreationError::CborClaims(match err {
222 CborSerError::Value(message) => CborSerError::Value(message),
223 CborSerError::Io(_) => unreachable!(), })
225 })?;
226 buffer.push(b'.');
227 encode_base64_buf(&serialized_claims, &mut buffer);
228
229 let signature = self.sign(signing_key, &buffer);
230 buffer.push(b'.');
231 encode_base64_buf(signature.as_bytes(), &mut buffer);
232
233 Ok(unsafe { String::from_utf8_unchecked(buffer) })
235 }
236
237 fn validator<'a, T>(&'a self, verifying_key: &'a Self::VerifyingKey) -> Validator<'a, Self, T> {
238 Validator { algorithm: self, verifying_key, _claims: PhantomData }
239 }
240
241 fn validate_integrity<T>(
242 &self,
243 token: &UntrustedToken,
244 verifying_key: &Self::VerifyingKey,
245 ) -> Result<Token<T>, ValidationError>
246 where
247 T: DeserializeOwned,
248 {
249 self.validator::<T>(verifying_key).validate(token)
250 }
251
252 fn validate_for_signed_token<T>(
253 &self,
254 token: &UntrustedToken,
255 verifying_key: &Self::VerifyingKey,
256 ) -> Result<SignedToken<Self, T>, ValidationError>
257 where
258 T: DeserializeOwned,
259 {
260 self.validator::<T>(verifying_key).validate_for_signed_token(token)
261 }
262}
263
264#[derive(Debug)]
267pub struct Validator<'a, A: Algorithm + ?Sized, T> {
268 algorithm: &'a A,
269 verifying_key: &'a A::VerifyingKey,
270 _claims: PhantomData<fn() -> T>,
271}
272
273impl<A: Algorithm + ?Sized, T> Clone for Validator<'_, A, T> {
274 fn clone(&self) -> Self {
275 *self
276 }
277}
278
279impl<A: Algorithm + ?Sized, T> Copy for Validator<'_, A, T> {}
280
281impl<A: Algorithm + ?Sized, T: DeserializeOwned> Validator<'_, A, T> {
282 pub fn validate<H: Clone>(self, token: &UntrustedToken<H>) -> Result<Token<T, H>, ValidationError> {
284 self.validate_for_signed_token(token).map(|signed| signed.token)
285 }
286
287 pub fn validate_for_signed_token<H: Clone>(
290 self,
291 token: &UntrustedToken<H>,
292 ) -> Result<SignedToken<A, T, H>, ValidationError> {
293 let expected_alg = self.algorithm.name();
294 if expected_alg != token.algorithm() {
295 return Err(ValidationError::AlgorithmMismatch {
296 expected: expected_alg.into_owned(),
297 actual: token.algorithm().to_owned(),
298 });
299 }
300
301 let signature = token.signature_bytes();
302 if let Some(expected_len) = A::Signature::LENGTH {
303 if signature.len() != expected_len.get() {
304 return Err(ValidationError::InvalidSignatureLen {
305 expected: expected_len.get(),
306 actual: signature.len(),
307 });
308 }
309 }
310
311 let signature = A::Signature::try_from_slice(signature).map_err(ValidationError::MalformedSignature)?;
312 let claims = token.deserialize_claims_unchecked::<T>()?;
315 if !self.algorithm.verify_signature(&signature, self.verifying_key, &token.signed_data) {
316 return Err(ValidationError::InvalidSignature);
317 }
318
319 Ok(SignedToken { signature, token: Token::new(token.header().clone(), claims) })
320 }
321}
322
323fn encode_base64_buf(source: impl AsRef<[u8]>, buffer: &mut Vec<u8>) {
324 let source = source.as_ref();
325 let previous_len = buffer.len();
326 let claims_len = Base64UrlUnpadded::encoded_len(source);
327 buffer.resize(previous_len + claims_len, 0);
328 Base64UrlUnpadded::encode(source, &mut buffer[previous_len..])
329 .expect("miscalculated base64-encoded length; this should never happen");
330}