edc_dataplane_proxy/service/
token.rs

1use base64::engine::general_purpose::URL_SAFE_NO_PAD;
2use base64::Engine;
3use bon::Builder;
4use ed25519_compact::PublicKey;
5use jsonwebtoken::{jwk::JwkSet, Algorithm, DecodingKey, EncodingKey, TokenData};
6use secrecy::{ExposeSecret, SecretString};
7use serde::{de::DeserializeOwned, Serialize};
8use serde_json::json;
9use thiserror::Error;
10
11#[cfg(test)]
12use mockall::{automock, predicate::*};
13
14use crate::extensions::KeyFormat;
15
16#[cfg_attr(test, automock)]
17pub trait TokenManager {
18    fn issue<T: Serialize + 'static>(&self, claims: &T) -> Result<String, TokenError>;
19    fn validate<T: DeserializeOwned + 'static>(
20        &self,
21        token: &str,
22    ) -> Result<TokenData<T>, TokenError>;
23
24    fn keys(&self) -> Result<JwkSet, TokenError>;
25}
26
27#[derive(Builder, Clone)]
28pub struct TokenManagerImpl {
29    #[builder(into)]
30    encoding_key: SecretString,
31    #[builder(into)]
32    decoding_key: String,
33    #[builder(into)]
34    audience: String,
35    algorithm: Algorithm,
36    #[builder(into)]
37    kid: String,
38    #[builder(into)]
39    format: KeyFormat,
40    leeway: u64,
41}
42
43impl TokenManager for TokenManagerImpl {
44    fn issue<T: Serialize>(&self, claims: &T) -> Result<String, TokenError> {
45        let encoding_key = self.encoding_key()?;
46        let mut header = jsonwebtoken::Header::new(self.algorithm);
47        header.kid = Some(self.kid.clone());
48        let token =
49            jsonwebtoken::encode(&header, claims, &encoding_key).map_err(TokenError::Encode)?;
50
51        Ok(token)
52    }
53
54    fn validate<T: DeserializeOwned>(&self, token: &str) -> Result<TokenData<T>, TokenError> {
55        let decoding_key = self.decoding_key()?;
56        let mut validation = jsonwebtoken::Validation::new(self.algorithm);
57        validation.leeway = self.leeway;
58        validation.set_audience(&[&self.audience]);
59        jsonwebtoken::decode::<T>(token, &decoding_key, &validation).map_err(TokenError::Decode)
60    }
61
62    fn keys(&self) -> Result<JwkSet, TokenError> {
63        match self.algorithm {
64            Algorithm::EdDSA => {
65                let pk = PublicKey::from_pem(&self.decoding_key).map_err(TokenError::Ed25519)?;
66
67                let x_b64 = URL_SAFE_NO_PAD.encode(pk.as_ref());
68
69                let jwk = json!({
70                    "kty": "OKP",                  // Key type for Ed25519
71                    "crv": "Ed25519",             // Curve name
72                    "x": x_b64,                   // Base64 URL-encoded key
73                    "use": "sig",                 // Typically for signing
74                    "alg": "EdDSA",                // Algorithm name
75                    "kid": self.kid
76                });
77
78                Ok(JwkSet {
79                    keys: vec![serde_json::from_value(jwk).unwrap()],
80                })
81            }
82            _ => todo!(),
83        }
84    }
85}
86
87impl TokenManagerImpl {
88    pub fn audience(&self) -> &str {
89        &self.audience
90    }
91
92    fn encoding_key(&self) -> Result<EncodingKey, TokenError> {
93        match (self.algorithm, &self.format) {
94            (Algorithm::EdDSA, KeyFormat::Pem) => {
95                EncodingKey::from_ed_pem(self.encoding_key.expose_secret().as_bytes())
96                    .map_err(TokenError::Format)
97            }
98            _ => Err(TokenError::UnsupportedFormat(self.algorithm, self.format)),
99        }
100    }
101
102    fn decoding_key(&self) -> Result<DecodingKey, TokenError> {
103        match (self.algorithm, &self.format) {
104            (Algorithm::EdDSA, KeyFormat::Pem) => {
105                DecodingKey::from_ed_pem(self.decoding_key.as_bytes()).map_err(TokenError::Format)
106            }
107            _ => Err(TokenError::UnsupportedFormat(self.algorithm, self.format)),
108        }
109    }
110}
111
112#[derive(Error, Debug, PartialEq)]
113pub enum TokenError {
114    #[error("Error encoding token")]
115    Encode(jsonwebtoken::errors::Error),
116    #[error("Error decoding token")]
117    Decode(jsonwebtoken::errors::Error),
118    #[error("Error keys format")]
119    Format(jsonwebtoken::errors::Error),
120    #[error("Unsupported format: {0:?} {1:?}")]
121    UnsupportedFormat(Algorithm, KeyFormat),
122    #[error("Error ed25519: {0}")]
123    Ed25519(ed25519_compact::Error),
124}
125
#[cfg(test)]
mod tests {
    use crate::{
        extensions::KeyFormat,
        service::token::{TokenError, TokenManager},
    };

    use super::TokenManagerImpl;
    use ed25519_compact::{KeyPair, Seed};
    use jsonwebtoken::{errors::ErrorKind, Algorithm};
    use serde_json::{json, Value};

    /// How far in the future (in seconds) test tokens expire.
    ///
    /// The manager under test uses `leeway = 0`, so an `exp` equal to the
    /// current second is rejected as expired as soon as the clock ticks over
    /// between issuing and validating. A margin keeps the tests deterministic.
    const EXP_MARGIN_SECS: i64 = 60;

    /// Expiry timestamp safely in the future.
    fn future_exp() -> i64 {
        chrono::Utc::now().timestamp() + EXP_MARGIN_SECS
    }

    /// Deterministic Ed25519 key pair (default seed) as `(private, public)` PEM.
    fn generate_key_pair() -> (String, String) {
        let key_pair = KeyPair::from_seed(Seed::default());
        (key_pair.sk.to_pem(), key_pair.pk.to_pem())
    }

    /// Token manager configured for EdDSA / PEM with audience `"audience"`.
    fn create_token_manager() -> TokenManagerImpl {
        let (private_key, public_key) = generate_key_pair();

        TokenManagerImpl::builder()
            .encoding_key(private_key)
            .decoding_key(public_key)
            .algorithm(Algorithm::EdDSA)
            .audience("audience")
            .format(KeyFormat::Pem)
            .leeway(0)
            .kid("kid")
            .build()
    }

    #[test]
    fn issue_and_validate() {
        let manager = create_token_manager();
        let claims = json!({"iss": "test", "aud": "audience", "exp": future_exp()});

        let token = manager.issue(&claims).unwrap();
        let token_claims = manager.validate::<Value>(&token).unwrap();

        assert_eq!(token_claims.claims, claims);
    }

    #[test]
    fn issue_and_validate_wrong_aud() {
        let manager = create_token_manager();
        // `exp` must be valid here too: an expired token would be reported as
        // such and the audience error asserted below would never surface.
        let claims = json!({"iss": "test", "aud": "wrong", "exp": future_exp()});

        let token = manager.issue(&claims).unwrap();
        let result = manager.validate::<Value>(&token).unwrap_err();

        match result {
            TokenError::Decode(err) => assert_eq!(err.kind(), &ErrorKind::InvalidAudience),
            other => panic!("Wrong type: {other:?}"),
        }
    }
}