keygate_jwt/algorithms/
es256.rs

1use std::convert::TryFrom;
2
3use base64ct::{Base64UrlUnpadded, Encoding};
4use p256::ecdsa::{self, signature::DigestVerifier as _, signature::RandomizedDigestSigner as _};
5use p256::pkcs8::{DecodePrivateKey, DecodePublicKey, EncodePrivateKey, EncodePublicKey};
6use p256::NonZeroScalar;
7use serde::{de::DeserializeOwned, Serialize};
8use sha2::{Digest, Sha256};
9
10use crate::claims::*;
11use crate::common::*;
12use crate::error::*;
13use crate::jwt_header::*;
14use crate::token::*;
15
16#[doc(hidden)]
17#[derive(Debug, Clone)]
18pub struct P256PublicKey(ecdsa::VerifyingKey);
19
20impl AsRef<ecdsa::VerifyingKey> for P256PublicKey {
21    fn as_ref(&self) -> &ecdsa::VerifyingKey {
22        &self.0
23    }
24}
25
26impl P256PublicKey {
27    pub fn from_bytes(raw: &[u8]) -> Result<Self, JWTError> {
28        let p256_pk =
29            ecdsa::VerifyingKey::from_sec1_bytes(raw).map_err(|_| JWTError::InvalidPublicKey)?;
30        Ok(P256PublicKey(p256_pk))
31    }
32
33    pub fn from_der(der: &[u8]) -> Result<Self, JWTError> {
34        let p256_pk = ecdsa::VerifyingKey::from_public_key_der(der)
35            .map_err(|_| JWTError::InvalidPublicKey)?;
36        Ok(P256PublicKey(p256_pk))
37    }
38
39    pub fn from_pem(pem: &str) -> Result<Self, JWTError> {
40        let p256_pk = ecdsa::VerifyingKey::from_public_key_pem(pem)
41            .map_err(|_| JWTError::InvalidPublicKey)?;
42        Ok(P256PublicKey(p256_pk))
43    }
44
45    pub fn to_bytes(&self) -> Vec<u8> {
46        self.0.to_encoded_point(true).as_bytes().to_vec()
47    }
48
49    pub fn to_bytes_uncompressed(&self) -> Vec<u8> {
50        self.0.to_encoded_point(false).as_bytes().to_vec()
51    }
52
53    pub fn to_der(&self) -> Result<Vec<u8>, JWTError> {
54        let p256_pk = p256::PublicKey::from(self.0);
55        Ok(p256_pk
56            .to_public_key_der()
57            .map_err(|_| JWTError::InvalidPublicKey)?
58            .as_ref()
59            .to_vec())
60    }
61
62    pub fn to_pem(&self) -> Result<String, JWTError> {
63        let p256_pk = p256::PublicKey::from(self.0);
64        p256_pk
65            .to_public_key_pem(Default::default())
66            .map_err(|_| JWTError::InvalidPublicKey)
67    }
68}
69
70#[doc(hidden)]
71pub struct P256KeyPair {
72    p256_sk: ecdsa::SigningKey,
73    metadata: Option<KeyMetadata>,
74}
75
76impl AsRef<ecdsa::SigningKey> for P256KeyPair {
77    fn as_ref(&self) -> &ecdsa::SigningKey {
78        &self.p256_sk
79    }
80}
81
82impl P256KeyPair {
83    pub fn from_bytes(raw: &[u8]) -> Result<Self, JWTError> {
84        let p256_sk =
85            ecdsa::SigningKey::from_bytes(raw.into()).map_err(|_| JWTError::InvalidKeyPair)?;
86        Ok(P256KeyPair {
87            p256_sk,
88            metadata: None,
89        })
90    }
91
92    pub fn from_der(der: &[u8]) -> Result<Self, JWTError> {
93        let p256_sk =
94            ecdsa::SigningKey::from_pkcs8_der(der).map_err(|_| JWTError::InvalidKeyPair)?;
95        Ok(P256KeyPair {
96            p256_sk,
97            metadata: None,
98        })
99    }
100
101    pub fn from_pem(pem: &str) -> Result<Self, JWTError> {
102        let p256_sk =
103            ecdsa::SigningKey::from_pkcs8_pem(pem).map_err(|_| JWTError::InvalidKeyPair)?;
104        Ok(P256KeyPair {
105            p256_sk,
106            metadata: None,
107        })
108    }
109
110    pub fn to_bytes(&self) -> Vec<u8> {
111        self.p256_sk.to_bytes().to_vec()
112    }
113
114    pub fn to_der(&self) -> Result<Vec<u8>, JWTError> {
115        let scalar = NonZeroScalar::from_repr(self.p256_sk.to_bytes());
116        if bool::from(scalar.is_none()) {
117            return Err(JWTError::InvalidKeyPair);
118        }
119        let p256_sk =
120            p256::SecretKey::from(NonZeroScalar::from_repr(scalar.unwrap().into()).unwrap());
121        Ok(p256_sk
122            .to_pkcs8_der()
123            .map_err(|_| JWTError::InvalidKeyPair)?
124            .as_bytes()
125            .to_vec())
126    }
127
128    pub fn to_pem(&self) -> Result<String, JWTError> {
129        let scalar = NonZeroScalar::from_repr(self.p256_sk.to_bytes());
130        if bool::from(scalar.is_none()) {
131            return Err(JWTError::InvalidKeyPair);
132        }
133        let p256_sk =
134            p256::SecretKey::from(NonZeroScalar::from_repr(scalar.unwrap().into()).unwrap());
135        Ok(p256_sk
136            .to_pkcs8_pem(Default::default())
137            .map_err(|_| JWTError::InvalidKeyPair)?
138            .to_string())
139    }
140
141    pub fn public_key(&self) -> P256PublicKey {
142        let p256_pk = self.p256_sk.verifying_key();
143        P256PublicKey(*p256_pk)
144    }
145
146    pub fn generate() -> Self {
147        let mut rng = rand::thread_rng();
148        let p256_sk = ecdsa::SigningKey::random(&mut rng);
149        P256KeyPair {
150            p256_sk,
151            metadata: None,
152        }
153    }
154}
155
156pub trait ECDSAP256KeyPairLike {
157    fn jwt_alg_name() -> &'static str;
158    fn key_pair(&self) -> &P256KeyPair;
159    fn key_id(&self) -> &Option<String>;
160    fn metadata(&self) -> &Option<KeyMetadata>;
161    fn attach_metadata(&mut self, metadata: KeyMetadata) -> Result<(), JWTError>;
162
163    fn sign<CustomClaims: Serialize + DeserializeOwned>(
164        &self,
165        claims: JWTClaims<CustomClaims>,
166    ) -> Result<String, JWTError> {
167        let jwt_header = JWTHeader::new(Self::jwt_alg_name().to_string(), self.key_id().clone())
168            .with_metadata(self.metadata());
169        Token::build(&jwt_header, claims, |authenticated| {
170            let mut digest = Sha256::new();
171            digest.update(authenticated.as_bytes());
172            let mut rng = rand::thread_rng();
173            let signature: ecdsa::Signature = self
174                .key_pair()
175                .as_ref()
176                .sign_digest_with_rng(&mut rng, digest);
177
178            Ok(signature.to_vec())
179        })
180    }
181}
182
183pub trait ECDSAP256PublicKeyLike {
184    fn jwt_alg_name() -> &'static str;
185    fn public_key(&self) -> &P256PublicKey;
186    fn key_id(&self) -> &Option<String>;
187    fn set_key_id(&mut self, key_id: String);
188
189    fn verify_token<CustomClaims: Serialize + DeserializeOwned>(
190        &self,
191        token: &str,
192        options: Option<VerificationOptions>,
193    ) -> Result<JWTClaims<CustomClaims>, JWTError> {
194        Token::verify(
195            Self::jwt_alg_name(),
196            token,
197            options,
198            |authenticated, signature| {
199                let ecdsa_signature = ecdsa::Signature::try_from(signature)
200                    .map_err(|_| JWTError::InvalidSignature)?;
201                let mut digest = Sha256::new();
202                digest.update(authenticated.as_bytes());
203                self.public_key()
204                    .as_ref()
205                    .verify_digest(digest, &ecdsa_signature)
206                    .map_err(|_| JWTError::InvalidSignature)?;
207                Ok(())
208            },
209        )
210    }
211
212    fn create_key_id(&mut self) -> Result<String, JWTError> {
213        let mut hasher = sha2::Sha256::new();
214        hasher.update(self.public_key().to_bytes());
215        let key_id = Base64UrlUnpadded::encode_string(&hasher.finalize());
216        self.set_key_id(key_id.clone());
217        Ok(key_id)
218    }
219}
220
221pub struct ES256KeyPair {
222    key_pair: P256KeyPair,
223    key_id: Option<String>,
224}
225
226#[derive(Debug, Clone)]
227pub struct ES256PublicKey {
228    pk: P256PublicKey,
229    key_id: Option<String>,
230}
231
232impl ECDSAP256KeyPairLike for ES256KeyPair {
233    fn jwt_alg_name() -> &'static str {
234        "ES256"
235    }
236
237    fn key_pair(&self) -> &P256KeyPair {
238        &self.key_pair
239    }
240
241    fn key_id(&self) -> &Option<String> {
242        &self.key_id
243    }
244
245    fn metadata(&self) -> &Option<KeyMetadata> {
246        &self.key_pair.metadata
247    }
248
249    fn attach_metadata(&mut self, metadata: KeyMetadata) -> Result<(), JWTError> {
250        self.key_pair.metadata = Some(metadata);
251        Ok(())
252    }
253}
254
255impl ES256KeyPair {
256    pub fn from_bytes(raw: &[u8]) -> Result<Self, JWTError> {
257        Ok(ES256KeyPair {
258            key_pair: P256KeyPair::from_bytes(raw)?,
259            key_id: None,
260        })
261    }
262
263    pub fn from_der(der: &[u8]) -> Result<Self, JWTError> {
264        Ok(ES256KeyPair {
265            key_pair: P256KeyPair::from_der(der)?,
266            key_id: None,
267        })
268    }
269
270    pub fn from_pem(pem: &str) -> Result<Self, JWTError> {
271        Ok(ES256KeyPair {
272            key_pair: P256KeyPair::from_pem(pem)?,
273            key_id: None,
274        })
275    }
276
277    pub fn to_bytes(&self) -> Vec<u8> {
278        self.key_pair.to_bytes()
279    }
280
281    pub fn to_der(&self) -> Result<Vec<u8>, JWTError> {
282        self.key_pair.to_der()
283    }
284
285    pub fn to_pem(&self) -> Result<String, JWTError> {
286        self.key_pair.to_pem()
287    }
288
289    pub fn public_key(&self) -> ES256PublicKey {
290        ES256PublicKey {
291            pk: self.key_pair.public_key(),
292            key_id: self.key_id.clone(),
293        }
294    }
295
296    pub fn generate() -> Self {
297        ES256KeyPair {
298            key_pair: P256KeyPair::generate(),
299            key_id: None,
300        }
301    }
302
303    pub fn with_key_id(mut self, key_id: &str) -> Self {
304        self.key_id = Some(key_id.to_string());
305        self
306    }
307}
308
309impl ECDSAP256PublicKeyLike for ES256PublicKey {
310    fn jwt_alg_name() -> &'static str {
311        "ES256"
312    }
313
314    fn public_key(&self) -> &P256PublicKey {
315        &self.pk
316    }
317
318    fn key_id(&self) -> &Option<String> {
319        &self.key_id
320    }
321
322    fn set_key_id(&mut self, key_id: String) {
323        self.key_id = Some(key_id);
324    }
325}
326
327impl ES256PublicKey {
328    pub fn from_bytes(raw: &[u8]) -> Result<Self, JWTError> {
329        Ok(ES256PublicKey {
330            pk: P256PublicKey::from_bytes(raw)?,
331            key_id: None,
332        })
333    }
334
335    pub fn from_der(der: &[u8]) -> Result<Self, JWTError> {
336        Ok(ES256PublicKey {
337            pk: P256PublicKey::from_der(der)?,
338            key_id: None,
339        })
340    }
341
342    pub fn from_pem(pem: &str) -> Result<Self, JWTError> {
343        Ok(ES256PublicKey {
344            pk: P256PublicKey::from_pem(pem)?,
345            key_id: None,
346        })
347    }
348
349    pub fn to_bytes(&self) -> Vec<u8> {
350        self.pk.to_bytes()
351    }
352
353    pub fn to_der(&self) -> Result<Vec<u8>, JWTError> {
354        self.pk.to_der()
355    }
356
357    pub fn to_pem(&self) -> Result<String, JWTError> {
358        self.pk.to_pem()
359    }
360
361    pub fn with_key_id(mut self, key_id: &str) -> Self {
362        self.key_id = Some(key_id.to_string());
363        self
364    }
365}