1use async_trait::async_trait;
2use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
3use serde::{de::DeserializeOwned, Serialize};
4use std::marker::PhantomData;
5use tide::{Middleware, Next, Request, Response, StatusCode};
6
7pub fn jwtsign<Claims: Serialize + DeserializeOwned + Send + Sync + 'static>(
8 claims: &Claims,
9 key: &EncodingKey,
10) -> Result<String, jsonwebtoken::errors::Error> {
11 encode(&Header::default(), claims, key)
12}
13
14pub fn jwtsign_secret<Claims: Serialize + DeserializeOwned + Send + Sync + 'static>(
15 claims: &Claims,
16 key: &str,
17) -> Result<String, jsonwebtoken::errors::Error> {
18 encode(
19 &Header::default(),
20 claims,
21 &EncodingKey::from_base64_secret(key)?,
22 )
23}
24
25pub fn jwtsign_with<Claims: Serialize + DeserializeOwned + Send + Sync + 'static>(
26 header: &Header,
27 claims: &Claims,
28 key: &EncodingKey,
29) -> Result<String, jsonwebtoken::errors::Error> {
30 encode(header, claims, key)
31}
32
33pub struct JwtAuthenticationDecoder<Claims: DeserializeOwned + Send + Sync + 'static> {
34 validation: Validation,
35 key: DecodingKey,
36 _claims: PhantomData<Claims>,
37}
38
39impl<Claims: DeserializeOwned + Send + Sync + 'static> JwtAuthenticationDecoder<Claims> {
40 pub fn default(key: DecodingKey) -> Self {
41 Self::new(Validation::default(), key)
42 }
43
44 pub fn new(validation: Validation, key: DecodingKey) -> Self {
45 Self {
46 validation,
47 key,
48 _claims: PhantomData::default(),
49 }
50 }
51}
52
53#[async_trait]
54impl<State, Claims> Middleware<State> for JwtAuthenticationDecoder<Claims>
55where
56 State: Clone + Send + Sync + 'static,
57 Claims: DeserializeOwned + Send + Sync + 'static,
58{
59 async fn handle(&self, mut req: Request<State>, next: Next<'_, State>) -> tide::Result {
60 let header = req.header("Authorization");
61 if header.is_none() {
62 return Ok(next.run(req).await);
63 }
64
65 let values: Vec<_> = header.unwrap().into_iter().collect();
66
67 if values.is_empty() {
68 return Ok(next.run(req).await);
69 }
70
71 if values.len() > 1 {
72 return Ok(Response::new(StatusCode::Unauthorized));
73 }
74
75 for value in values {
76 let value = value.as_str();
77 if !value.starts_with("Bearer") {
78 continue;
79 }
80
81 let token = &value["Bearer ".len()..];
82 println!("found authorization token: {token}");
83 let data = match decode::<Claims>(token, &self.key, &self.validation) {
84 Ok(c) => c,
85 Err(_) => {
86 return Ok(Response::new(StatusCode::Unauthorized));
87 }
88 };
89
90 req.set_ext(data.claims);
91 break;
92 }
93
94 Ok(next.run(req).await)
95 }
96}