protobuf_web_token/
lib.rs

1use std::{
2    fmt::Display,
3    time::{Duration, SystemTime},
4};
5
6use ed25519_dalek::{Signature, Signer as _, SigningKey, Verifier as _, VerifyingKey};
7use prost::Message;
8
9use base64::{engine::general_purpose, Engine as _};
10
// JWT helpers, compiled only for tests (the tests below use them to compare token sizes).
#[cfg(test)]
mod jwt;

// Protobuf types (`Token`, `SignedToken`) generated at build time into `OUT_DIR/pwt.rs`.
mod proto {
    include!(concat!(env!("OUT_DIR"), "/pwt.rs"));
}
// Re-exports `ed25519_dalek` as `ed25519`, presumably so callers can build
// `SigningKey`/`VerifyingKey` values without depending on the crate themselves.
pub extern crate ed25519_dalek as ed25519;
18
19#[derive(Clone)]
20pub struct Signer {
21    key: SigningKey,
22}
23
24#[derive(Copy, Clone, PartialEq, Eq)]
25pub struct Verifier {
26    key: VerifyingKey,
27}
28
/// The decoded content of a PWT: its expiry time plus its claims.
///
/// Field order is significant for the derived `PartialOrd`/`Ord`: values compare by
/// `valid_until` first and by `claims` second.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TokenData<CLAIMS> {
    /// Point in time until which the token is considered valid.
    pub valid_until: SystemTime,
    /// The claims carried by the token.
    pub claims: CLAIMS,
}
34
/// Borrowed data segment of a string token (URL-safe base64 without padding).
struct Base64Claims<'s>(&'s str);

/// Borrowed signature segment of a string token (URL-safe base64 without padding).
struct Base64Signature<'s>(&'s str);

/// Raw bytes of the signed `Token` envelope (expiry time plus encoded claims).
struct BytesClaims(Vec<u8>);
40
/// Errors returned while decoding or verifying a PWT.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Error {
    /// The string token contains no `.` separating the data and signature segments.
    InvalidFormat,
    /// A segment of the string token is not URL-safe base64 without padding.
    InvalidBase64,
    /// The signature bytes could not be parsed as an Ed25519 signature.
    InvalidSignature,
    /// The signature is well-formed but does not verify against the token data.
    SignatureMismatch,
    /// The binary `SignedToken` wrapper, the `Token` envelope or the claims could not be
    /// decoded as the expected protobuf message.
    ProtobufDecodeError,
    /// The `Token` envelope has no `valid_until`, or it cannot be converted into a
    /// `SystemTime`.
    MissingValidUntil,
    /// The token's `valid_until` lies before the current time (only reported by the
    /// `*_and_check_expiry` verification methods).
    TokenExpired,
}
51
52impl Signer {
53    /// Creates a new `Signer` from an ed25519 `SigningKey` (private key)
54    pub fn new(key: SigningKey) -> Self {
55        Signer { key }
56    }
57
58    /// Creates a `Verifier` which can decode and verify PWT but can not sign them
59    pub fn as_verifier(&self) -> Verifier {
60        Verifier {
61            key: self.key.verifying_key(),
62        }
63    }
64
65    /// Encodes a `Message` into a PWT. Uses the URL-safe string representation:
66    /// {data_in_base64}.{signature_of_encoded_bytes_in_base64}.
67    pub fn sign<T: Message>(&self, data: &T, valid_for: Duration) -> String {
68        let proto_token = self.create_proto_token(data, valid_for);
69        let (base64, signature) = self.sign_proto_token(&proto_token);
70        format!("{base64}.{signature}")
71    }
72
73    /// Encodes a `Message` into a PWT. Uses the compact byte representation via a protobuf message with
74    /// 2 fields (data and signature).
75    pub fn sign_to_bytes<T: Message>(&self, data: &T, valid_for: Duration) -> Vec<u8> {
76        let proto_token = self.create_proto_token(data, valid_for);
77        let bytes = proto_token.encode_to_vec();
78        let signature = self.key.sign(&bytes);
79        proto::SignedToken {
80            data: bytes,
81            signature: signature.to_bytes().to_vec(),
82        }
83        .encode_to_vec()
84    }
85
86    fn create_proto_token<T: Message>(&self, data: &T, valid_for: Duration) -> proto::Token {
87        let bytes = data.encode_to_vec();
88        proto::Token {
89            valid_until: Some((SystemTime::now() + valid_for).into()),
90            claims: bytes,
91        }
92    }
93
94    fn sign_proto_token(&self, proto_token: &proto::Token) -> (String, String) {
95        let bytes = proto_token.encode_to_vec();
96        let signature = self.key.sign(&bytes);
97        let base64 = general_purpose::URL_SAFE_NO_PAD.encode(&bytes);
98        let signature = general_purpose::URL_SAFE_NO_PAD.encode(signature.to_bytes());
99        (base64, signature)
100    }
101}
102
103impl Verifier {
104    /// Creates a new `Verifier` from an ed25519 VerifyingKey
105    pub fn new(key: VerifyingKey) -> Self {
106        Self { key }
107    }
108
109    pub fn verify<CLAIMS: Message + Default>(
110        &self,
111        token: &str,
112    ) -> Result<TokenData<CLAIMS>, Error> {
113        let (claims, signature) = parse_token(token)?;
114        let bytes = claims.to_bytes()?;
115        self.verify_signature(&bytes, &signature)?;
116
117        let token_data = bytes.decode_metadata()?;
118        let claims =
119            CLAIMS::decode(token_data.claims.as_slice()).map_err(|_| Error::ProtobufDecodeError)?;
120        Ok(TokenData {
121            valid_until: token_data.valid_until,
122            claims,
123        })
124    }
125
126    pub fn verify_bytes<CLAIMS: Message + Default>(
127        &self,
128        token: &[u8],
129    ) -> Result<TokenData<CLAIMS>, Error> {
130        let proto::SignedToken { data, signature } =
131            proto::SignedToken::decode(token).map_err(|_| Error::ProtobufDecodeError)?;
132        let signature = Signature::from_slice(&signature).map_err(|_| Error::InvalidSignature)?;
133        self.key
134            .verify(&data, &signature)
135            .map_err(|_| Error::SignatureMismatch)?;
136
137        let token_data = BytesClaims(data).decode_metadata()?;
138        let claims =
139            CLAIMS::decode(token_data.claims.as_slice()).map_err(|_| Error::ProtobufDecodeError)?;
140        Ok(TokenData {
141            valid_until: token_data.valid_until,
142            claims,
143        })
144    }
145
146    pub fn verify_and_check_expiry<CLAIMS: Message + Default>(
147        &self,
148        token: &str,
149    ) -> Result<CLAIMS, Error> {
150        let (claims, signature) = parse_token(token)?;
151        let bytes = claims.to_bytes()?;
152        self.verify_signature(&bytes, &signature)?;
153
154        let token_data = bytes.decode_metadata()?;
155
156        let now = SystemTime::now();
157        if now > token_data.valid_until {
158            return Result::Err(Error::TokenExpired);
159        }
160
161        CLAIMS::decode(token_data.claims.as_slice()).map_err(|_| Error::ProtobufDecodeError)
162    }
163
164    pub fn verify_bytes_and_check_expiry<CLAIMS: Message + Default>(
165        &self,
166        token: &[u8],
167    ) -> Result<CLAIMS, Error> {
168        let proto::SignedToken { data, signature } =
169            proto::SignedToken::decode(token).map_err(|_| Error::ProtobufDecodeError)?;
170        let signature = Signature::from_slice(&signature).map_err(|_| Error::InvalidSignature)?;
171        self.key
172            .verify(&data, &signature)
173            .map_err(|_| Error::SignatureMismatch)?;
174
175        let token_data = BytesClaims(data).decode_metadata()?;
176
177        let now = SystemTime::now();
178        if now > token_data.valid_until {
179            return Result::Err(Error::TokenExpired);
180        }
181
182        CLAIMS::decode(token_data.claims.as_slice()).map_err(|_| Error::ProtobufDecodeError)
183    }
184
185    fn verify_signature(
186        &self,
187        bytes: &BytesClaims,
188        signature: &Base64Signature,
189    ) -> Result<(), Error> {
190        let signature = general_purpose::URL_SAFE_NO_PAD
191            .decode(signature.0)
192            .map_err(|_| Error::InvalidBase64)?;
193        let signature =
194            Signature::from_slice(signature.as_slice()).map_err(|_| Error::InvalidSignature)?;
195
196        self.key
197            .verify(&bytes.0, &signature)
198            .map_err(|_| Error::SignatureMismatch)?;
199        Ok(())
200    }
201}
202
203impl<'a> Base64Claims<'a> {
204    pub fn to_bytes(&'a self) -> Result<BytesClaims, Error> {
205        general_purpose::URL_SAFE_NO_PAD
206            .decode(self.0)
207            .map(BytesClaims)
208            .map_err(|_| Error::InvalidBase64)
209    }
210}
211
212impl BytesClaims {
213    pub fn decode_metadata(&self) -> Result<TokenData<Vec<u8>>, Error> {
214        let token =
215            proto::Token::decode(self.0.as_slice()).map_err(|_| Error::ProtobufDecodeError)?;
216        let valid_until: SystemTime = token
217            .valid_until
218            .ok_or(Error::MissingValidUntil)?
219            .try_into()
220            .map_err(|_| Error::MissingValidUntil)?;
221        Ok(TokenData {
222            valid_until,
223            claims: token.claims,
224        })
225    }
226}
227
228fn parse_token(token: &str) -> Result<(Base64Claims<'_>, Base64Signature<'_>), Error> {
229    let (data, signature) = token.split_once('.').ok_or(Error::InvalidFormat)?;
230    Ok((Base64Claims(data), Base64Signature(signature)))
231}
232
233pub fn decode<CLAIMS: Message + Default>(token: &str) -> Result<TokenData<CLAIMS>, Error> {
234    let (data, _signature) = token.split_once('.').ok_or(Error::InvalidFormat)?;
235    let bytes = general_purpose::URL_SAFE_NO_PAD
236        .decode(data)
237        .map_err(|_| Error::InvalidBase64)?;
238
239    let decoded_metadata =
240        proto::Token::decode(bytes.as_slice()).map_err(|_| Error::ProtobufDecodeError)?;
241    let valid_until = decoded_metadata
242        .valid_until
243        .ok_or(Error::MissingValidUntil)?
244        .try_into()
245        .map_err(|_| Error::MissingValidUntil)?;
246    let claims = CLAIMS::decode(decoded_metadata.claims.as_slice())
247        .map_err(|_| Error::ProtobufDecodeError)?;
248    Ok(TokenData {
249        valid_until,
250        claims,
251    })
252}
253
254impl Display for Error {
255    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
256        match self {
257            Error::InvalidFormat => f.write_str(
258                "Invalid Token Format. Expected two string segments seperated by a dot ('.')",
259            ),
260            Error::InvalidBase64 => f.write_str(
261                "A part of the token was not valid base64 (A-Z, a-z, 0-9, -, _, no padding)",
262            ),
263            Error::InvalidSignature => {
264                f.write_str("The signature is not a valid Ed25519 signature")
265            }
266            Error::SignatureMismatch => f.write_str(
267                "The signature does not match the given data (probably the token was manipulated)",
268            ),
269            Error::ProtobufDecodeError => f.write_str(
270                "The data encoded in the token did not match the expected protobuf format.",
271            ),
272            Error::MissingValidUntil => {
273                f.write_str("The data encoded in the token did not include an expiry time")
274            }
275            Error::TokenExpired => f.write_str("The token is expired"),
276        }
277    }
278}
279
280impl std::error::Error for Error {}
281
#[cfg(test)]
mod tests {
    use std::time::{Duration, SystemTime};

    use ed25519::pkcs8::DecodePrivateKey;
    use serde::Serialize;

    use super::*;
    use crate::jwt;

    // Test-only protobuf messages generated into `OUT_DIR/test.rs`. Inside this module this
    // `proto` shadows the crate's `proto`, which is therefore reached via `super::proto`.
    mod proto {
        include!(concat!(env!("OUT_DIR"), "/test.rs"));
    }

    /// Serde counterpart of `proto::Simple`, used to build an equivalent JWT.
    #[derive(Debug, Clone, Serialize)]
    struct Simple {
        some_claim: String,
    }

    /// Loads the PKCS#8 PEM test key from `test_resources/private.pem`.
    fn init_signer() -> Signer {
        let pem = std::fs::read_to_string("test_resources/private.pem")
            .expect("test_resources/private.pem must exist and be UTF-8");
        let key = SigningKey::from_pkcs8_pem(&pem)
            .expect("test_resources/private.pem must hold a PKCS#8 Ed25519 key");
        Signer::new(key)
    }

    #[test]
    fn happy_case() {
        let pwt_signer = init_signer();
        let simple = proto::Simple {
            some_claim: "testabcd".to_string(),
        };
        let pwt = pwt_signer.sign(&simple, Duration::from_secs(5));
        assert_eq!(
            pwt_signer
                .as_verifier()
                .verify_and_check_expiry::<proto::Simple>(&pwt),
            Result::Ok(simple)
        );
    }

    #[test]
    fn happy_case_bytes() {
        let pwt_signer = init_signer();
        let simple = proto::Simple {
            some_claim: "testabcd".to_string(),
        };
        let pwt = pwt_signer.sign_to_bytes(&simple, Duration::from_secs(5));
        println!("binary token ({} bytes): {pwt:?}", pwt.len());
        assert_eq!(
            pwt_signer
                .as_verifier()
                .verify_bytes_and_check_expiry::<proto::Simple>(&pwt),
            Result::Ok(simple)
        );
    }

    #[test]
    fn signature_is_verified_and_prevents_tampering() {
        let pwt_signer = init_signer();
        let proto_token = pwt_signer.create_proto_token(
            &proto::Simple {
                some_claim: "test contents".to_string(),
            },
            Duration::from_secs(5),
        );
        let (_data, signature) = pwt_signer.sign_proto_token(&proto_token);
        let other_proto_token = pwt_signer.create_proto_token(
            &proto::Simple {
                some_claim: "tampered contents".to_string(),
            },
            Duration::from_secs(5),
        );
        let (other_data, _) = pwt_signer.sign_proto_token(&other_proto_token);

        // Valid signature of one token glued onto the data of another.
        let tampered_token = format!("{other_data}.{signature}");

        assert_eq!(
            pwt_signer
                .as_verifier()
                .verify::<proto::Simple>(&tampered_token),
            Result::Err(Error::SignatureMismatch)
        );
    }

    #[test]
    fn signature_is_verified_and_prevents_tampering_bytes() {
        let pwt_signer = init_signer();
        let proto_token = pwt_signer.create_proto_token(
            &proto::Simple {
                some_claim: "test contents".to_string(),
            },
            Duration::from_secs(5),
        );

        let data = proto_token.encode_to_vec();
        let signature = pwt_signer.key.sign(&data);
        let other_proto_token = pwt_signer.create_proto_token(
            &proto::Simple {
                some_claim: "tampered contents".to_string(),
            },
            Duration::from_secs(5),
        );
        let other_data = other_proto_token.encode_to_vec();

        let tampered_token = super::proto::SignedToken {
            data: other_data,
            signature: signature.to_bytes().to_vec(),
        }
        .encode_to_vec();

        assert_eq!(
            pwt_signer
                .as_verifier()
                .verify_bytes::<proto::Simple>(&tampered_token),
            Result::Err(Error::SignatureMismatch)
        );
    }

    #[test]
    fn invalid_format() {
        let pwt_signer = init_signer();
        assert_eq!(
            pwt_signer.as_verifier().verify::<()>("invalid"),
            Result::Err(Error::InvalidFormat)
        );
    }

    #[test]
    fn invalid_base64() {
        let pwt_signer = init_signer();
        assert_eq!(
            pwt_signer.as_verifier().verify::<()>("invalid.base64"),
            Result::Err(Error::InvalidBase64)
        );
    }

    #[test]
    fn invalid_signature() {
        let pwt_signer = init_signer();
        let base64 = general_purpose::URL_SAFE_NO_PAD.encode("base64");
        assert_eq!(
            pwt_signer
                .as_verifier()
                .verify::<()>(&format!("{base64}.{base64}")),
            Result::Err(Error::InvalidSignature)
        );
    }

    #[test]
    fn protobuf_decode_mismatch() {
        let pwt_signer = init_signer();
        let pwt = pwt_signer.sign(
            &proto::Simple {
                some_claim: "test contents".to_string(),
            },
            Duration::from_secs(5),
        );
        assert_eq!(
            pwt_signer.as_verifier().verify::<proto::Complex>(&pwt),
            Result::Err(Error::ProtobufDecodeError)
        );
    }

    #[test]
    fn size_is_smaller_than_jwt() {
        let jwt_signer = jwt::init_jwt_signer();
        let pwt_signer = init_signer();

        let pwt = pwt_signer.sign(
            &proto::Simple {
                some_claim: "test contents".to_string(),
            },
            Duration::from_secs(300),
        );
        println!("{pwt}");
        let jwt = jwt::jwt_encode(
            &jwt_signer,
            Simple {
                some_claim: "test contents".to_string(),
            },
            300,
        );
        let pwt_len = f64::from(u32::try_from(pwt.len()).unwrap());
        let jwt_len = f64::from(u32::try_from(jwt.len()).unwrap());
        assert!(
            pwt_len * 1.2 < jwt_len,
            "{pwt} was not small enough in comparison to {jwt}"
        );
    }

    /// Serde counterpart of `proto::Complex`, used to build an equivalent JWT.
    #[derive(Debug, Clone, Serialize)]
    struct Complex {
        email: String,
        user_name: String,
        user_id: String,
        valid_until: SystemTime,
        roles: Vec<String>,
        nested: Nested,
    }

    /// Serde counterpart of `proto::Nested`.
    #[derive(Debug, Clone, Serialize)]
    struct Nested {
        team_id: String,
        team_name: String,
    }

    #[test]
    fn size_is_smaller_than_jwt_complex() {
        let jwt_signer = jwt::init_jwt_signer();
        let pwt_signer = init_signer();
        let now = SystemTime::now();

        let pwt = pwt_signer.sign(
            &proto::Complex {
                email: "andreas.molitor@andrena.de".to_string(),
                user_name: "Andreas Molitor".to_string(),
                user_id: 123456789,
                roles: vec![
                    proto::Role::ReadFeatureFoo.into(),
                    proto::Role::WriteFeatureFoo.into(),
                    proto::Role::ReadFeatureBar.into(),
                ],
                nested: Some(proto::Nested {
                    team_id: 3432535236263,
                    team_name: "andrena".to_string(),
                }),
            },
            Duration::from_secs(300),
        );
        let jwt = jwt::jwt_encode(
            &jwt_signer,
            Complex {
                email: "andreas.molitor@andrena.de".to_string(),
                user_name: "Andreas Molitor".to_string(),
                user_id: "123456789".to_string(),
                // Same 300 s lifetime as the PWT above, for a like-for-like comparison.
                valid_until: (now + Duration::from_secs(300)),
                roles: vec![
                    "ReadFeatureFoo".to_string(),
                    "WriteFeatureFoo".to_string(),
                    "ReadFeatureBar".to_string(),
                ],
                nested: Nested {
                    team_id: "3432535236263".to_string(),
                    team_name: "andrena".to_string(),
                },
            },
            300,
        );
        let pwt_len = f64::from(u32::try_from(pwt.len()).unwrap());
        let jwt_len = f64::from(u32::try_from(jwt.len()).unwrap());
        assert!(
            pwt_len * 2.0 < jwt_len,
            "{pwt} was not small enough in comparison to {jwt}"
        );
    }

    #[test]
    #[ignore] // generate only if specifically requested (with cargo test -- --ignored)
    fn generate_fuzz_outputs() -> Result<(), Box<dyn std::error::Error>> {
        use rand::distributions::{Alphanumeric, DistString};

        let pwt_signer = init_signer();
        let mut fuzz_output = Vec::new();

        for i in 1..100 {
            let random_string = Alphanumeric.sample_string(&mut rand::thread_rng(), i);
            let simple = proto::Simple {
                some_claim: random_string.clone(),
            };
            // The string and binary tokens are signed by separate calls, so their
            // `valid_until` values may differ slightly; `timestamp` below is taken from the
            // string token.
            let pwt = pwt_signer.sign(&simple, Duration::from_secs(500));
            let pwt_bytes = pwt_signer.sign_to_bytes(&simple, Duration::from_secs(500));
            let data: TokenData<proto::Simple> = pwt_signer.as_verifier().verify(&pwt)?;
            let timestamp = data
                .valid_until
                .duration_since(SystemTime::UNIX_EPOCH)?
                .as_secs();
            let json = serde_json::json!({
                "input": random_string,
                "output": pwt,
                "output_binary": pwt_bytes,
                "timestamp": timestamp
            });
            fuzz_output.push(json);
        }
        let file_contents = serde_json::to_string_pretty(&fuzz_output)?;
        std::fs::create_dir_all("fuzz")?;
        std::fs::write("fuzz/rust.json", file_contents)?;
        Ok(())
    }
}