1use std::{
2 fmt::Display,
3 time::{Duration, SystemTime},
4};
5
6use ed25519_dalek::{Signature, Signer as _, SigningKey, Verifier as _, VerifyingKey};
7use prost::Message;
8
9use base64::{engine::general_purpose, Engine as _};
10
// JWT helpers used only by the size-comparison tests in `tests` below.
#[cfg(test)]
mod jwt;

/// Protobuf types generated into `OUT_DIR/pwt.rs`: the `Token` envelope
/// (expiry + encoded claims) and the `SignedToken` wrapper used by the binary format.
mod proto {
    include!(concat!(env!("OUT_DIR"), "/pwt.rs"));
}

// Re-exported as `ed25519`; presumably so downstream crates can name key types
// (e.g. `ed25519::SigningKey`) without depending on `ed25519_dalek` directly — confirm.
pub extern crate ed25519_dalek as ed25519;
18
/// Issues signed tokens (string or binary form) using an Ed25519 private key.
///
/// Use [`Signer::as_verifier`] to obtain the matching [`Verifier`].
#[derive(Clone)]
pub struct Signer {
    // Private signing key; never exposed by this type's public API.
    key: SigningKey,
}
23
24#[derive(Copy, Clone, PartialEq, Eq)]
25pub struct Verifier {
26 key: VerifyingKey,
27}
28
/// The contents of a token: its expiry time plus the decoded claims.
///
/// Returned by [`Verifier::verify`] / [`Verifier::verify_bytes`] (signature checked,
/// expiry NOT checked) and by [`decode`] (neither signature nor expiry checked).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TokenData<CLAIMS> {
    /// Point in time after which the token should be rejected as expired.
    pub valid_until: SystemTime,
    /// Application-defined claims carried by the token.
    pub claims: CLAIMS,
}
34
/// The (still base64-encoded) payload segment of a string token, before the `.`.
struct Base64Claims<'a>(&'a str);

/// The (still base64-encoded) signature segment of a string token, after the `.`.
struct Base64Signature<'a>(&'a str);

/// Raw bytes of an encoded `proto::Token` — exactly the bytes the signature covers.
struct BytesClaims(Vec<u8>);
40
/// Everything that can go wrong while parsing, verifying or decoding a token.
#[derive(Debug, PartialEq, Eq)]
pub enum Error {
    /// A string token did not contain a `.` separating payload and signature.
    InvalidFormat,
    /// The payload or signature segment was not valid URL-safe, unpadded base64.
    InvalidBase64,
    /// The signature bytes could not be parsed as an Ed25519 signature (e.g. wrong length).
    InvalidSignature,
    /// The signature is well-formed but does not match the signed data.
    SignatureMismatch,
    /// The envelope, the binary wrapper, or the claims failed to decode as protobuf.
    ProtobufDecodeError,
    /// The token has no expiry timestamp, or it cannot be represented as a `SystemTime`.
    MissingValidUntil,
    /// The expiry time has passed (reported only by the `*_and_check_expiry` methods).
    TokenExpired,
}
51
52impl Signer {
53 pub fn new(key: SigningKey) -> Self {
55 Signer { key }
56 }
57
58 pub fn as_verifier(&self) -> Verifier {
60 Verifier {
61 key: self.key.verifying_key(),
62 }
63 }
64
65 pub fn sign<T: Message>(&self, data: &T, valid_for: Duration) -> String {
68 let proto_token = self.create_proto_token(data, valid_for);
69 let (base64, signature) = self.sign_proto_token(&proto_token);
70 format!("{base64}.{signature}")
71 }
72
73 pub fn sign_to_bytes<T: Message>(&self, data: &T, valid_for: Duration) -> Vec<u8> {
76 let proto_token = self.create_proto_token(data, valid_for);
77 let bytes = proto_token.encode_to_vec();
78 let signature = self.key.sign(&bytes);
79 proto::SignedToken {
80 data: bytes,
81 signature: signature.to_bytes().to_vec(),
82 }
83 .encode_to_vec()
84 }
85
86 fn create_proto_token<T: Message>(&self, data: &T, valid_for: Duration) -> proto::Token {
87 let bytes = data.encode_to_vec();
88 proto::Token {
89 valid_until: Some((SystemTime::now() + valid_for).into()),
90 claims: bytes,
91 }
92 }
93
94 fn sign_proto_token(&self, proto_token: &proto::Token) -> (String, String) {
95 let bytes = proto_token.encode_to_vec();
96 let signature = self.key.sign(&bytes);
97 let base64 = general_purpose::URL_SAFE_NO_PAD.encode(&bytes);
98 let signature = general_purpose::URL_SAFE_NO_PAD.encode(signature.to_bytes());
99 (base64, signature)
100 }
101}
102
103impl Verifier {
104 pub fn new(key: VerifyingKey) -> Self {
106 Self { key }
107 }
108
109 pub fn verify<CLAIMS: Message + Default>(
110 &self,
111 token: &str,
112 ) -> Result<TokenData<CLAIMS>, Error> {
113 let (claims, signature) = parse_token(token)?;
114 let bytes = claims.to_bytes()?;
115 self.verify_signature(&bytes, &signature)?;
116
117 let token_data = bytes.decode_metadata()?;
118 let claims =
119 CLAIMS::decode(token_data.claims.as_slice()).map_err(|_| Error::ProtobufDecodeError)?;
120 Ok(TokenData {
121 valid_until: token_data.valid_until,
122 claims,
123 })
124 }
125
126 pub fn verify_bytes<CLAIMS: Message + Default>(
127 &self,
128 token: &[u8],
129 ) -> Result<TokenData<CLAIMS>, Error> {
130 let proto::SignedToken { data, signature } =
131 proto::SignedToken::decode(token).map_err(|_| Error::ProtobufDecodeError)?;
132 let signature = Signature::from_slice(&signature).map_err(|_| Error::InvalidSignature)?;
133 self.key
134 .verify(&data, &signature)
135 .map_err(|_| Error::SignatureMismatch)?;
136
137 let token_data = BytesClaims(data).decode_metadata()?;
138 let claims =
139 CLAIMS::decode(token_data.claims.as_slice()).map_err(|_| Error::ProtobufDecodeError)?;
140 Ok(TokenData {
141 valid_until: token_data.valid_until,
142 claims,
143 })
144 }
145
146 pub fn verify_and_check_expiry<CLAIMS: Message + Default>(
147 &self,
148 token: &str,
149 ) -> Result<CLAIMS, Error> {
150 let (claims, signature) = parse_token(token)?;
151 let bytes = claims.to_bytes()?;
152 self.verify_signature(&bytes, &signature)?;
153
154 let token_data = bytes.decode_metadata()?;
155
156 let now = SystemTime::now();
157 if now > token_data.valid_until {
158 return Result::Err(Error::TokenExpired);
159 }
160
161 CLAIMS::decode(token_data.claims.as_slice()).map_err(|_| Error::ProtobufDecodeError)
162 }
163
164 pub fn verify_bytes_and_check_expiry<CLAIMS: Message + Default>(
165 &self,
166 token: &[u8],
167 ) -> Result<CLAIMS, Error> {
168 let proto::SignedToken { data, signature } =
169 proto::SignedToken::decode(token).map_err(|_| Error::ProtobufDecodeError)?;
170 let signature = Signature::from_slice(&signature).map_err(|_| Error::InvalidSignature)?;
171 self.key
172 .verify(&data, &signature)
173 .map_err(|_| Error::SignatureMismatch)?;
174
175 let token_data = BytesClaims(data).decode_metadata()?;
176
177 let now = SystemTime::now();
178 if now > token_data.valid_until {
179 return Result::Err(Error::TokenExpired);
180 }
181
182 CLAIMS::decode(token_data.claims.as_slice()).map_err(|_| Error::ProtobufDecodeError)
183 }
184
185 fn verify_signature(
186 &self,
187 bytes: &BytesClaims,
188 signature: &Base64Signature,
189 ) -> Result<(), Error> {
190 let signature = general_purpose::URL_SAFE_NO_PAD
191 .decode(signature.0)
192 .map_err(|_| Error::InvalidBase64)?;
193 let signature =
194 Signature::from_slice(signature.as_slice()).map_err(|_| Error::InvalidSignature)?;
195
196 self.key
197 .verify(&bytes.0, &signature)
198 .map_err(|_| Error::SignatureMismatch)?;
199 Ok(())
200 }
201}
202
203impl<'a> Base64Claims<'a> {
204 pub fn to_bytes(&'a self) -> Result<BytesClaims, Error> {
205 general_purpose::URL_SAFE_NO_PAD
206 .decode(self.0)
207 .map(BytesClaims)
208 .map_err(|_| Error::InvalidBase64)
209 }
210}
211
212impl BytesClaims {
213 pub fn decode_metadata(&self) -> Result<TokenData<Vec<u8>>, Error> {
214 let token =
215 proto::Token::decode(self.0.as_slice()).map_err(|_| Error::ProtobufDecodeError)?;
216 let valid_until: SystemTime = token
217 .valid_until
218 .ok_or(Error::MissingValidUntil)?
219 .try_into()
220 .map_err(|_| Error::MissingValidUntil)?;
221 Ok(TokenData {
222 valid_until,
223 claims: token.claims,
224 })
225 }
226}
227
228fn parse_token(token: &str) -> Result<(Base64Claims<'_>, Base64Signature<'_>), Error> {
229 let (data, signature) = token.split_once('.').ok_or(Error::InvalidFormat)?;
230 Ok((Base64Claims(data), Base64Signature(signature)))
231}
232
233pub fn decode<CLAIMS: Message + Default>(token: &str) -> Result<TokenData<CLAIMS>, Error> {
234 let (data, _signature) = token.split_once('.').ok_or(Error::InvalidFormat)?;
235 let bytes = general_purpose::URL_SAFE_NO_PAD
236 .decode(data)
237 .map_err(|_| Error::InvalidBase64)?;
238
239 let decoded_metadata =
240 proto::Token::decode(bytes.as_slice()).map_err(|_| Error::ProtobufDecodeError)?;
241 let valid_until = decoded_metadata
242 .valid_until
243 .ok_or(Error::MissingValidUntil)?
244 .try_into()
245 .map_err(|_| Error::MissingValidUntil)?;
246 let claims = CLAIMS::decode(decoded_metadata.claims.as_slice())
247 .map_err(|_| Error::ProtobufDecodeError)?;
248 Ok(TokenData {
249 valid_until,
250 claims,
251 })
252}
253
254impl Display for Error {
255 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
256 match self {
257 Error::InvalidFormat => f.write_str(
258 "Invalid Token Format. Expected two string segments seperated by a dot ('.')",
259 ),
260 Error::InvalidBase64 => f.write_str(
261 "A part of the token was not valid base64 (A-Z, a-z, 0-9, -, _, no padding)",
262 ),
263 Error::InvalidSignature => {
264 f.write_str("The signature is not a valid Ed25519 signature")
265 }
266 Error::SignatureMismatch => f.write_str(
267 "The signature does not match the given data (probably the token was manipulated)",
268 ),
269 Error::ProtobufDecodeError => f.write_str(
270 "The data encoded in the token did not match the expected protobuf format.",
271 ),
272 Error::MissingValidUntil => {
273 f.write_str("The data encoded in the token did not include an expiry time")
274 }
275 Error::TokenExpired => f.write_str("The token is expired"),
276 }
277 }
278}
279
280impl std::error::Error for Error {}
281
#[cfg(test)]
mod tests {
    use std::time::{Duration, SystemTime};

    use ed25519::pkcs8::DecodePrivateKey;
    use serde::Serialize;

    use super::*;
    use crate::jwt;

    /// Test-only protobuf messages (`Simple`, `Complex`, `Nested`, `Role`) from `OUT_DIR/test.rs`.
    mod proto {
        include!(concat!(env!("OUT_DIR"), "/test.rs"));
    }

    /// Serde mirror of `proto::Simple`, used to build the equivalent JWT for size comparisons.
    #[derive(Debug, Clone, Serialize)]
    struct Simple {
        some_claim: String,
    }

    /// Loads the Ed25519 test key from `test_resources/private.pem`. The path is relative to
    /// the package root, which is the working directory under `cargo test`.
    fn init_signer() -> Signer {
        let pem = std::fs::read_to_string("test_resources/private.pem")
            .expect("test key test_resources/private.pem should be readable UTF-8");
        let key = SigningKey::from_pkcs8_pem(&pem)
            .expect("test_resources/private.pem should contain a PKCS#8 Ed25519 key");
        Signer::new(key)
    }

    #[test]
    fn happy_case() {
        let pwt_signer = init_signer();
        let simple = proto::Simple {
            some_claim: "testabcd".to_string(),
        };
        let pwt = pwt_signer.sign(&simple, Duration::from_secs(5));
        assert_eq!(
            pwt_signer
                .as_verifier()
                .verify_and_check_expiry::<proto::Simple>(&pwt),
            Result::Ok(simple)
        );
    }

    #[test]
    fn happy_case_bytes() {
        let pwt_signer = init_signer();
        let simple = proto::Simple {
            some_claim: "testabcd".to_string(),
        };
        let pwt = pwt_signer.sign_to_bytes(&simple, Duration::from_secs(5));
        assert_eq!(
            pwt_signer
                .as_verifier()
                .verify_bytes_and_check_expiry::<proto::Simple>(&pwt),
            Result::Ok(simple)
        );
    }

    /// Pairing one token's payload with another token's signature must be rejected.
    #[test]
    fn signature_is_verified_and_prevents_tampering() {
        let pwt_signer = init_signer();
        let proto_token = pwt_signer.create_proto_token(
            &proto::Simple {
                some_claim: "test contents".to_string(),
            },
            Duration::from_secs(5),
        );
        let (_data, signature) = pwt_signer.sign_proto_token(&proto_token);
        let other_proto_token = pwt_signer.create_proto_token(
            &proto::Simple {
                some_claim: "tampered contents".to_string(),
            },
            Duration::from_secs(5),
        );
        let (other_data, _) = pwt_signer.sign_proto_token(&other_proto_token);

        let tampered_token = format!("{other_data}.{signature}");

        assert_eq!(
            pwt_signer
                .as_verifier()
                .verify::<proto::Simple>(&tampered_token),
            Result::Err(Error::SignatureMismatch)
        );
    }

    /// Binary-format counterpart of `signature_is_verified_and_prevents_tampering`.
    #[test]
    fn signature_is_verified_and_prevents_tampering_bytes() {
        let pwt_signer = init_signer();
        let proto_token = pwt_signer.create_proto_token(
            &proto::Simple {
                some_claim: "test contents".to_string(),
            },
            Duration::from_secs(5),
        );

        let data = proto_token.encode_to_vec();
        let signature = pwt_signer.key.sign(&data);
        let other_proto_token = pwt_signer.create_proto_token(
            &proto::Simple {
                some_claim: "tampered contents".to_string(),
            },
            Duration::from_secs(5),
        );
        let other_data = other_proto_token.encode_to_vec();

        let tampered_token = super::proto::SignedToken {
            data: other_data,
            signature: signature.to_bytes().to_vec(),
        }
        .encode_to_vec();

        assert_eq!(
            pwt_signer
                .as_verifier()
                .verify_bytes::<proto::Simple>(&tampered_token),
            Result::Err(Error::SignatureMismatch)
        );
    }

    #[test]
    fn invalid_format() {
        let pwt_signer = init_signer();
        assert_eq!(
            pwt_signer.as_verifier().verify::<()>("invalid"),
            Result::Err(Error::InvalidFormat)
        );
    }

    #[test]
    fn invalid_base64() {
        let pwt_signer = init_signer();
        assert_eq!(
            pwt_signer.as_verifier().verify::<()>("invalid.base64"),
            Result::Err(Error::InvalidBase64)
        );
    }

    #[test]
    fn invalid_signature() {
        let pwt_signer = init_signer();
        let base64 = general_purpose::URL_SAFE_NO_PAD.encode("base64");
        assert_eq!(
            pwt_signer
                .as_verifier()
                .verify::<()>(&format!("{base64}.{base64}")),
            Result::Err(Error::InvalidSignature)
        );
    }

    #[test]
    fn protobuf_decode_mismatch() {
        let pwt_signer = init_signer();
        let pwt = pwt_signer.sign(
            &proto::Simple {
                some_claim: "test contents".to_string(),
            },
            Duration::from_secs(5),
        );
        assert_eq!(
            pwt_signer.as_verifier().verify::<proto::Complex>(&pwt),
            Result::Err(Error::ProtobufDecodeError)
        );
    }

    #[test]
    fn size_is_smaller_than_jwt() {
        let jwt_signer = jwt::init_jwt_signer();
        let pwt_signer = init_signer();

        let pwt = pwt_signer.sign(
            &proto::Simple {
                some_claim: "test contents".to_string(),
            },
            Duration::from_secs(300),
        );
        println!("{pwt}");
        let jwt = jwt::jwt_encode(
            &jwt_signer,
            Simple {
                some_claim: "test contents".to_string(),
            },
            300,
        );
        let pwt_len = f64::from(u32::try_from(pwt.len()).unwrap());
        let jwt_len = f64::from(u32::try_from(jwt.len()).unwrap());
        assert!(
            pwt_len * 1.2 < jwt_len,
            "{pwt} was not small enough in comparison to {jwt}"
        );
    }

    /// Serde mirror of `proto::Complex` for the JWT side of the size comparison.
    #[derive(Debug, Clone, Serialize)]
    struct Complex {
        email: String,
        user_name: String,
        user_id: String,
        valid_until: SystemTime,
        roles: Vec<String>,
        nested: Nested,
    }

    /// Serde mirror of `proto::Nested`.
    #[derive(Debug, Clone, Serialize)]
    struct Nested {
        team_id: String,
        team_name: String,
    }

    #[test]
    fn size_is_smaller_than_jwt_complex() {
        let jwt_signer = jwt::init_jwt_signer();
        let pwt_signer = init_signer();
        let now = SystemTime::now();

        let pwt = pwt_signer.sign(
            &proto::Complex {
                email: "andreas.molitor@andrena.de".to_string(),
                user_name: "Andreas Molitor".to_string(),
                user_id: 123456789,
                roles: vec![
                    proto::Role::ReadFeatureFoo.into(),
                    proto::Role::WriteFeatureFoo.into(),
                    proto::Role::ReadFeatureBar.into(),
                ],
                nested: Some(proto::Nested {
                    team_id: 3432535236263,
                    team_name: "andrena".to_string(),
                }),
            },
            Duration::from_secs(300),
        );
        let jwt = jwt::jwt_encode(
            &jwt_signer,
            Complex {
                email: "andreas.molitor@andrena.de".to_string(),
                user_name: "Andreas Molitor".to_string(),
                user_id: "123456789".to_string(),
                valid_until: now + Duration::from_secs(5),
                roles: vec![
                    "ReadFeatureFoo".to_string(),
                    "WriteFeatureFoo".to_string(),
                    "ReadFeatureBar".to_string(),
                ],
                nested: Nested {
                    team_id: "3432535236263".to_string(),
                    team_name: "andrena".to_string(),
                },
            },
            300,
        );
        let pwt_len = f64::from(u32::try_from(pwt.len()).unwrap());
        let jwt_len = f64::from(u32::try_from(jwt.len()).unwrap());
        assert!(
            pwt_len * 2.0 < jwt_len,
            "{pwt} was not small enough in comparison to {jwt}"
        );
    }

    /// Writes sample tokens for random claims to `fuzz/rust.json`. This is presumably
    /// consumed by other implementations to cross-check verification.
    /// Ignored by default; run explicitly with `cargo test -- --ignored`.
    ///
    /// NOTE(review): `sign` and `sign_to_bytes` each read the clock, so `timestamp` is taken
    /// from the string token; the binary token's expiry can differ very slightly.
    #[test]
    #[ignore]
    fn generate_fuzz_outputs() -> Result<(), Box<dyn std::error::Error>> {
        use rand::distributions::{Alphanumeric, DistString};

        let pwt_signer = init_signer();
        let mut fuzz_output = Vec::new();

        for i in 1..100 {
            let random_string = Alphanumeric.sample_string(&mut rand::thread_rng(), i);
            let pwt = pwt_signer.sign(
                &proto::Simple {
                    some_claim: random_string.clone(),
                },
                Duration::from_secs(500),
            );
            let pwt_bytes = pwt_signer.sign_to_bytes(
                &proto::Simple {
                    some_claim: random_string.clone(),
                },
                Duration::from_secs(500),
            );
            let data: TokenData<proto::Simple> = pwt_signer.as_verifier().verify(&pwt)?;
            let timestamp = data
                .valid_until
                .duration_since(SystemTime::UNIX_EPOCH)?
                .as_secs();
            let json = serde_json::json!({
                "input": random_string,
                "output": pwt,
                "output_binary": pwt_bytes,
                "timestamp": timestamp
            });
            fuzz_output.push(json);
        }
        let file_contents = serde_json::to_string_pretty(&fuzz_output)?;
        std::fs::create_dir_all("fuzz")?;
        std::fs::write("fuzz/rust.json", file_contents)?;
        Ok(())
    }
}