modo/auth/session/jwt/
encoder.rs1use std::sync::Arc;
2use std::time::Duration;
3
4use serde::Serialize;
5
6use crate::encoding::base64url;
7use crate::{Error, Result};
8
9use super::config::JwtSessionsConfig;
10use super::error::JwtError;
11use super::signer::{HmacSigner, TokenSigner};
12
13pub struct JwtEncoder {
18 inner: Arc<JwtEncoderInner>,
19}
20
21struct JwtEncoderInner {
22 signer: Arc<dyn TokenSigner>,
23 default_expiry: Option<Duration>,
24 validation: super::validation::ValidationConfig,
25}
26
27impl JwtEncoder {
28 pub fn from_config(config: &JwtSessionsConfig) -> Self {
34 let signer = HmacSigner::new(config.signing_secret.as_bytes());
35 Self {
36 inner: Arc::new(JwtEncoderInner {
37 signer: Arc::new(signer),
38 default_expiry: Some(Duration::from_secs(config.access_ttl_secs)),
39 validation: super::validation::ValidationConfig {
40 leeway: Duration::ZERO,
41 require_issuer: config.issuer.clone(),
42 require_audience: None,
43 },
44 }),
45 }
46 }
47
48 pub(super) fn verifier(&self) -> Arc<dyn super::signer::TokenVerifier> {
51 self.inner.signer.clone() as Arc<dyn super::signer::TokenVerifier>
54 }
55
56 pub(super) fn validation(&self) -> super::validation::ValidationConfig {
59 self.inner.validation.clone()
60 }
61
62 pub fn encode<T: Serialize>(&self, claims: &T) -> Result<String> {
79 let claims_json = if let Some(default_exp) = self.inner.default_expiry {
81 let mut value = serde_json::to_value(claims).map_err(|_| {
82 Error::internal("failed to serialize token")
83 .chain(JwtError::SerializationFailed)
84 .with_code(JwtError::SerializationFailed.code())
85 })?;
86 if value.get("exp").is_none() {
88 let now = std::time::SystemTime::now()
89 .duration_since(std::time::UNIX_EPOCH)
90 .expect("system clock before UNIX epoch")
91 .as_secs();
92 value["exp"] = serde_json::Value::Number((now + default_exp.as_secs()).into());
93 }
94 serde_json::to_vec(&value)
95 } else {
96 serde_json::to_vec(claims)
97 }
98 .map_err(|_| {
99 Error::internal("unauthorized")
100 .chain(JwtError::SerializationFailed)
101 .with_code(JwtError::SerializationFailed.code())
102 })?;
103
104 let alg = self.inner.signer.algorithm_name();
105 let header = format!(r#"{{"alg":"{alg}","typ":"JWT"}}"#);
106 let header_b64 = base64url::encode(header.as_bytes());
107 let payload_b64 = base64url::encode(&claims_json);
108
109 let header_payload = format!("{header_b64}.{payload_b64}");
110 let signature = self.inner.signer.sign(header_payload.as_bytes())?;
111 let signature_b64 = base64url::encode(&signature);
112
113 Ok(format!("{header_payload}.{signature_b64}"))
114 }
115}
116
117impl Clone for JwtEncoder {
118 fn clone(&self) -> Self {
119 Self {
120 inner: self.inner.clone(),
121 }
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    use super::super::claims::Claims;

    /// Default configuration with a fixed signing secret.
    fn test_config() -> JwtSessionsConfig {
        JwtSessionsConfig {
            signing_secret: "test-secret-key-at-least-32-bytes-long!".into(),
            ..JwtSessionsConfig::default()
        }
    }

    /// Encoder built from [`test_config`].
    fn test_encoder() -> JwtEncoder {
        JwtEncoder::from_config(&test_config())
    }

    /// Decodes the dot-separated token segment at `index` and parses it as JSON.
    fn segment_json(token: &str, index: usize) -> serde_json::Value {
        let segment = token.split('.').nth(index).unwrap();
        let bytes = base64url::decode(segment).unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    #[test]
    fn encode_produces_three_part_token() {
        let token = test_encoder()
            .encode(&Claims::new().with_exp(9999999999))
            .unwrap();
        let parts = token.split('.').count();
        assert_eq!(parts, 3);
    }

    #[test]
    fn encode_header_contains_hs256() {
        let token = test_encoder()
            .encode(&Claims::new().with_exp(9999999999))
            .unwrap();
        let header = segment_json(&token, 0);
        assert_eq!(header["alg"], "HS256");
        assert_eq!(header["typ"], "JWT");
    }

    #[test]
    fn encode_with_default_expiry_auto_sets_exp() {
        // No `exp` on the claims: the encoder must supply one.
        let token = test_encoder().encode(&Claims::new()).unwrap();
        let payload = segment_json(&token, 1);
        assert!(payload.get("exp").is_some());
    }

    #[test]
    fn encode_explicit_exp_not_overwritten() {
        let token = test_encoder()
            .encode(&Claims::new().with_exp(42))
            .unwrap();
        let payload = segment_json(&token, 1);
        assert_eq!(payload["exp"], 42);
    }

    #[test]
    fn encode_custom_struct_directly() {
        #[derive(Debug, Clone, Serialize, Deserialize)]
        struct CustomPayload {
            sub: String,
            role: String,
            exp: u64,
        }

        let payload = CustomPayload {
            sub: "user_1".into(),
            role: "admin".into(),
            exp: 9999999999,
        };
        let token = test_encoder().encode(&payload).unwrap();
        let parts = token.split('.').count();
        assert_eq!(parts, 3);
    }

    #[test]
    fn clone_produces_working_encoder() {
        let original = test_encoder();
        let cloned = original.clone();
        let claims = Claims::new().with_exp(9999999999);
        assert!(cloned.encode(&claims).is_ok());
    }
}