modo/auth/session/jwt/
encoder.rs1use std::sync::Arc;
2use std::time::Duration;
3
4use serde::Serialize;
5
6use crate::encoding::base64url;
7use crate::{Error, Result};
8
9use super::config::JwtSessionsConfig;
10use super::error::JwtError;
11use super::signer::{HmacSigner, TokenSigner};
12
13pub struct JwtEncoder {
18 inner: Arc<JwtEncoderInner>,
19}
20
21struct JwtEncoderInner {
22 signer: Arc<dyn TokenSigner>,
23 default_expiry: Option<Duration>,
24 validation: super::validation::ValidationConfig,
25}
26
27impl JwtEncoder {
28 pub fn from_config(config: &JwtSessionsConfig) -> Self {
34 let signer = HmacSigner::new(config.signing_secret.as_bytes());
35 Self {
36 inner: Arc::new(JwtEncoderInner {
37 signer: Arc::new(signer),
38 default_expiry: Some(Duration::from_secs(config.access_ttl_secs)),
39 validation: super::validation::ValidationConfig {
40 leeway: Duration::ZERO,
41 require_issuer: config.issuer.clone(),
42 require_audience: None,
43 },
44 }),
45 }
46 }
47
48 pub(super) fn verifier(&self) -> Arc<dyn super::signer::TokenVerifier> {
51 self.inner.signer.clone() as Arc<dyn super::signer::TokenVerifier>
54 }
55
56 pub(super) fn validation(&self) -> super::validation::ValidationConfig {
59 self.inner.validation.clone()
60 }
61
62 pub fn encode<T: Serialize>(&self, claims: &T) -> Result<String> {
79 let claims_json = if let Some(default_exp) = self.inner.default_expiry {
80 let mut value = serde_json::to_value(claims).map_err(|_| {
81 Error::internal("failed to serialize token")
82 .chain(JwtError::SerializationFailed)
83 .with_code(JwtError::SerializationFailed.code())
84 })?;
85 if value.get("exp").is_none() {
86 let now = std::time::SystemTime::now()
87 .duration_since(std::time::UNIX_EPOCH)
88 .expect("system clock before UNIX epoch")
89 .as_secs();
90 value["exp"] = serde_json::Value::Number((now + default_exp.as_secs()).into());
91 }
92 serde_json::to_vec(&value)
93 } else {
94 serde_json::to_vec(claims)
95 }
96 .map_err(|_| {
97 Error::internal("unauthorized")
98 .chain(JwtError::SerializationFailed)
99 .with_code(JwtError::SerializationFailed.code())
100 })?;
101
102 let alg = self.inner.signer.algorithm_name();
103 let header = format!(r#"{{"alg":"{alg}","typ":"JWT"}}"#);
104 let header_b64 = base64url::encode(header.as_bytes());
105 let payload_b64 = base64url::encode(&claims_json);
106
107 let header_payload = format!("{header_b64}.{payload_b64}");
108 let signature = self.inner.signer.sign(header_payload.as_bytes())?;
109 let signature_b64 = base64url::encode(&signature);
110
111 Ok(format!("{header_payload}.{signature_b64}"))
112 }
113}
114
115impl Clone for JwtEncoder {
116 fn clone(&self) -> Self {
117 Self {
118 inner: self.inner.clone(),
119 }
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    use super::super::claims::Claims;

    fn test_config() -> JwtSessionsConfig {
        JwtSessionsConfig {
            signing_secret: "test-secret-key-at-least-32-bytes-long!".into(),
            ..JwtSessionsConfig::default()
        }
    }

    /// Base64url-decodes segment `index` of a compact JWT and parses it as JSON.
    fn decode_segment(token: &str, index: usize) -> serde_json::Value {
        let segment = token
            .split('.')
            .nth(index)
            .expect("token has too few segments");
        let bytes = base64url::decode(segment).expect("segment is not valid base64url");
        serde_json::from_slice(&bytes).expect("segment is not valid JSON")
    }

    /// Current UNIX time in whole seconds.
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn encode_produces_three_part_token() {
        let encoder = JwtEncoder::from_config(&test_config());
        let claims = Claims::new().with_exp(9999999999);
        let token = encoder.encode(&claims).unwrap();
        assert_eq!(token.split('.').count(), 3);
    }

    #[test]
    fn encode_header_contains_hs256() {
        let encoder = JwtEncoder::from_config(&test_config());
        let claims = Claims::new().with_exp(9999999999);
        let token = encoder.encode(&claims).unwrap();
        let header = decode_segment(&token, 0);
        assert_eq!(header["alg"], "HS256");
        assert_eq!(header["typ"], "JWT");
    }

    #[test]
    fn encode_with_default_expiry_auto_sets_exp() {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let before = unix_now();
        let token = encoder.encode(&Claims::new()).unwrap();
        let after = unix_now();

        // `exp` must be a real NumericDate, not merely present (a `null` would pass
        // `is_some()`), and must equal "now + access TTL" at encoding time.
        let exp = decode_segment(&token, 1)["exp"]
            .as_u64()
            .expect("exp should be a non-negative integer");
        assert!(exp >= before.saturating_add(config.access_ttl_secs));
        assert!(exp <= after.saturating_add(config.access_ttl_secs));
    }

    #[test]
    fn encode_explicit_exp_not_overwritten() {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let claims = Claims::new().with_exp(42);
        let token = encoder.encode(&claims).unwrap();
        let payload = decode_segment(&token, 1);
        assert_eq!(payload["exp"], 42);
    }

    #[test]
    fn encode_custom_struct_directly() {
        #[derive(Debug, Clone, Serialize, Deserialize)]
        struct CustomPayload {
            sub: String,
            role: String,
            exp: u64,
        }

        let encoder = JwtEncoder::from_config(&test_config());
        let payload = CustomPayload {
            sub: "user_1".into(),
            role: "admin".into(),
            exp: 9999999999,
        };
        let token = encoder.encode(&payload).unwrap();
        assert_eq!(token.split('.').count(), 3);

        // The custom fields survive the round trip and the explicit exp is kept.
        let decoded: CustomPayload = serde_json::from_value(decode_segment(&token, 1)).unwrap();
        assert_eq!(decoded.sub, "user_1");
        assert_eq!(decoded.role, "admin");
        assert_eq!(decoded.exp, 9999999999);
    }

    #[test]
    fn clone_produces_working_encoder() {
        let encoder = JwtEncoder::from_config(&test_config());
        let cloned = encoder.clone();
        let claims = Claims::new().with_exp(9999999999);
        assert!(cloned.encode(&claims).is_ok());

        // Clones share the signing key, so fully deterministic claims sign identically.
        let fixed = serde_json::json!({ "sub": "user_1", "exp": 9999999999u64 });
        assert_eq!(
            encoder.encode(&fixed).unwrap(),
            cloned.encode(&fixed).unwrap()
        );
    }
}