edc_dataplane_proxy/service/
token.rs1use base64::engine::general_purpose::URL_SAFE_NO_PAD;
2use base64::Engine;
3use bon::Builder;
4use ed25519_compact::PublicKey;
5use jsonwebtoken::{jwk::JwkSet, Algorithm, DecodingKey, EncodingKey, TokenData};
6use secrecy::{ExposeSecret, SecretString};
7use serde::{de::DeserializeOwned, Serialize};
8use serde_json::json;
9use thiserror::Error;
10
11#[cfg(test)]
12use mockall::{automock, predicate::*};
13
14use crate::extensions::KeyFormat;
15
16#[cfg_attr(test, automock)]
17pub trait TokenManager {
18 fn issue<T: Serialize + 'static>(&self, claims: &T) -> Result<String, TokenError>;
19 fn validate<T: DeserializeOwned + 'static>(
20 &self,
21 token: &str,
22 ) -> Result<TokenData<T>, TokenError>;
23
24 fn keys(&self) -> Result<JwkSet, TokenError>;
25}
26
27#[derive(Builder, Clone)]
28pub struct TokenManagerImpl {
29 #[builder(into)]
30 encoding_key: SecretString,
31 #[builder(into)]
32 decoding_key: String,
33 #[builder(into)]
34 audience: String,
35 algorithm: Algorithm,
36 #[builder(into)]
37 kid: String,
38 #[builder(into)]
39 format: KeyFormat,
40 leeway: u64,
41}
42
43impl TokenManager for TokenManagerImpl {
44 fn issue<T: Serialize>(&self, claims: &T) -> Result<String, TokenError> {
45 let encoding_key = self.encoding_key()?;
46 let mut header = jsonwebtoken::Header::new(self.algorithm);
47 header.kid = Some(self.kid.clone());
48 let token =
49 jsonwebtoken::encode(&header, claims, &encoding_key).map_err(TokenError::Encode)?;
50
51 Ok(token)
52 }
53
54 fn validate<T: DeserializeOwned>(&self, token: &str) -> Result<TokenData<T>, TokenError> {
55 let decoding_key = self.decoding_key()?;
56 let mut validation = jsonwebtoken::Validation::new(self.algorithm);
57 validation.leeway = self.leeway;
58 validation.set_audience(&[&self.audience]);
59 jsonwebtoken::decode::<T>(token, &decoding_key, &validation).map_err(TokenError::Decode)
60 }
61
62 fn keys(&self) -> Result<JwkSet, TokenError> {
63 match self.algorithm {
64 Algorithm::EdDSA => {
65 let pk = PublicKey::from_pem(&self.decoding_key).map_err(TokenError::Ed25519)?;
66
67 let x_b64 = URL_SAFE_NO_PAD.encode(pk.as_ref());
68
69 let jwk = json!({
70 "kty": "OKP", "crv": "Ed25519", "x": x_b64, "use": "sig", "alg": "EdDSA", "kid": self.kid
76 });
77
78 Ok(JwkSet {
79 keys: vec![serde_json::from_value(jwk).unwrap()],
80 })
81 }
82 _ => todo!(),
83 }
84 }
85}
86
87impl TokenManagerImpl {
88 pub fn audience(&self) -> &str {
89 &self.audience
90 }
91
92 fn encoding_key(&self) -> Result<EncodingKey, TokenError> {
93 match (self.algorithm, &self.format) {
94 (Algorithm::EdDSA, KeyFormat::Pem) => {
95 EncodingKey::from_ed_pem(self.encoding_key.expose_secret().as_bytes())
96 .map_err(TokenError::Format)
97 }
98 _ => Err(TokenError::UnsupportedFormat(self.algorithm, self.format)),
99 }
100 }
101
102 fn decoding_key(&self) -> Result<DecodingKey, TokenError> {
103 match (self.algorithm, &self.format) {
104 (Algorithm::EdDSA, KeyFormat::Pem) => {
105 DecodingKey::from_ed_pem(self.decoding_key.as_bytes()).map_err(TokenError::Format)
106 }
107 _ => Err(TokenError::UnsupportedFormat(self.algorithm, self.format)),
108 }
109 }
110}
111
112#[derive(Error, Debug, PartialEq)]
113pub enum TokenError {
114 #[error("Error encoding token")]
115 Encode(jsonwebtoken::errors::Error),
116 #[error("Error decoding token")]
117 Decode(jsonwebtoken::errors::Error),
118 #[error("Error keys format")]
119 Format(jsonwebtoken::errors::Error),
120 #[error("Unsupported format: {0:?} {1:?}")]
121 UnsupportedFormat(Algorithm, KeyFormat),
122 #[error("Error ed25519: {0}")]
123 Ed25519(ed25519_compact::Error),
124}
125
#[cfg(test)]
mod tests {
    use crate::{
        extensions::KeyFormat,
        service::token::{TokenError, TokenManager},
    };

    use super::TokenManagerImpl;
    use ed25519_compact::{KeyPair, Seed};
    use jsonwebtoken::{errors::ErrorKind, Algorithm};
    use serde_json::{json, Value};

    /// Lifetime of test tokens, in seconds.
    ///
    /// `exp` must lie in the future. With a leeway of 0, a token whose `exp` is "now" is
    /// rejected as expired once the clock crosses into the next second, which made these
    /// tests flaky. jsonwebtoken checks expiry before audience, so it could also mask the
    /// `InvalidAudience` error asserted below.
    const TOKEN_TTL_SECS: i64 = 300;

    fn generate_key_pair() -> (String, String) {
        let key_pair = KeyPair::from_seed(Seed::default());
        (key_pair.sk.to_pem(), key_pair.pk.to_pem())
    }

    fn create_token_manager() -> TokenManagerImpl {
        let (private_key, public_key) = generate_key_pair();

        TokenManagerImpl::builder()
            .encoding_key(private_key)
            .decoding_key(public_key)
            .algorithm(Algorithm::EdDSA)
            .audience("audience")
            .format(KeyFormat::Pem)
            .leeway(0)
            .kid("kid")
            .build()
    }

    /// An expiry timestamp safely in the future.
    fn future_exp() -> i64 {
        chrono::Utc::now().timestamp() + TOKEN_TTL_SECS
    }

    #[test]
    fn issue_and_validate() {
        let manager = create_token_manager();
        let claims = json!({"iss": "test", "aud": "audience", "exp": future_exp()});

        let token = manager.issue(&claims).unwrap();
        let token_claims = manager.validate::<Value>(&token).unwrap();

        assert_eq!(token_claims.claims, claims);
    }

    #[test]
    fn issue_and_validate_wrong_aud() {
        let manager = create_token_manager();
        let claims = json!({"iss": "test", "aud": "wrong", "exp": future_exp()});

        let token = manager.issue(&claims).unwrap();
        let result = manager.validate::<Value>(&token).unwrap_err();

        match result {
            TokenError::Decode(err) => assert_eq!(err.kind(), &ErrorKind::InvalidAudience),
            other => panic!("expected TokenError::Decode, got {other:?}"),
        }
    }

    #[test]
    fn keys_publishes_single_jwk_with_kid() {
        let manager = create_token_manager();

        let jwks = manager.keys().unwrap();

        assert_eq!(jwks.keys.len(), 1);
        assert_eq!(jwks.keys[0].common.key_id.as_deref(), Some("kid"));
    }
}