Skip to main content

modo/auth/session/jwt/
encoder.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use serde::Serialize;
5
6use crate::encoding::base64url;
7use crate::{Error, Result};
8
9use super::config::JwtSessionsConfig;
10use super::error::JwtError;
11use super::signer::{HmacSigner, TokenSigner};
12
13/// JWT token encoder. Signs tokens using a [`TokenSigner`].
14///
15/// Register in `Registry` for handler access via `Service<JwtEncoder>`.
16/// Cloning is cheap — state is stored behind `Arc`.
17pub struct JwtEncoder {
18    inner: Arc<JwtEncoderInner>,
19}
20
21struct JwtEncoderInner {
22    signer: Arc<dyn TokenSigner>,
23    default_expiry: Option<Duration>,
24    validation: super::validation::ValidationConfig,
25}
26
27impl JwtEncoder {
28    /// Creates a `JwtEncoder` from YAML configuration.
29    ///
30    /// Uses `HmacSigner` (HS256) with the configured secret. The validation
31    /// config (issuer) is stored so a matching `JwtDecoder`
32    /// can be created via `JwtDecoder::from(&encoder)`.
33    pub fn from_config(config: &JwtSessionsConfig) -> Self {
34        let signer = HmacSigner::new(config.signing_secret.as_bytes());
35        Self {
36            inner: Arc::new(JwtEncoderInner {
37                signer: Arc::new(signer),
38                default_expiry: Some(Duration::from_secs(config.access_ttl_secs)),
39                validation: super::validation::ValidationConfig {
40                    leeway: Duration::ZERO,
41                    require_issuer: config.issuer.clone(),
42                    require_audience: None,
43                },
44            }),
45        }
46    }
47
48    /// Returns a reference to the inner signer (as verifier).
49    /// Used by `JwtDecoder::from(&encoder)` to share the same key.
50    pub(super) fn verifier(&self) -> Arc<dyn super::signer::TokenVerifier> {
51        // Trait upcasting: Arc<dyn TokenSigner> → Arc<dyn TokenVerifier>
52        // Stabilized in Rust 1.76.
53        self.inner.signer.clone() as Arc<dyn super::signer::TokenVerifier>
54    }
55
56    /// Returns a clone of the validation config.
57    /// Used by `JwtDecoder::from(&encoder)`.
58    pub(super) fn validation(&self) -> super::validation::ValidationConfig {
59        self.inner.validation.clone()
60    }
61
62    /// Encodes a serializable payload into a signed JWT token string.
63    ///
64    /// If the payload serializes to a JSON object without an `exp` field and
65    /// `default_expiry` is configured, `exp` is automatically set to
66    /// `now + default_expiry` before signing. An explicitly set `exp` field
67    /// is never overwritten.
68    ///
69    /// The system auth flow passes [`Claims`](super::claims::Claims) here.
70    /// Custom auth flows can pass any `Serialize` struct directly.
71    ///
72    /// # Errors
73    ///
74    /// Returns `Error::internal` with [`JwtError::SerializationFailed`](super::JwtError::SerializationFailed)
75    /// if the payload cannot be serialized to JSON, or
76    /// [`JwtError::SigningFailed`](super::JwtError::SigningFailed) if the HMAC signing
77    /// operation fails.
78    pub fn encode<T: Serialize>(&self, claims: &T) -> Result<String> {
79        // Auto-fill exp if missing and default_expiry is configured
80        let claims_json = if let Some(default_exp) = self.inner.default_expiry {
81            let mut value = serde_json::to_value(claims).map_err(|_| {
82                Error::internal("failed to serialize token")
83                    .chain(JwtError::SerializationFailed)
84                    .with_code(JwtError::SerializationFailed.code())
85            })?;
86            // Only inject exp when the payload has no exp field already
87            if value.get("exp").is_none() {
88                let now = std::time::SystemTime::now()
89                    .duration_since(std::time::UNIX_EPOCH)
90                    .expect("system clock before UNIX epoch")
91                    .as_secs();
92                value["exp"] = serde_json::Value::Number((now + default_exp.as_secs()).into());
93            }
94            serde_json::to_vec(&value)
95        } else {
96            serde_json::to_vec(claims)
97        }
98        .map_err(|_| {
99            Error::internal("unauthorized")
100                .chain(JwtError::SerializationFailed)
101                .with_code(JwtError::SerializationFailed.code())
102        })?;
103
104        let alg = self.inner.signer.algorithm_name();
105        let header = format!(r#"{{"alg":"{alg}","typ":"JWT"}}"#);
106        let header_b64 = base64url::encode(header.as_bytes());
107        let payload_b64 = base64url::encode(&claims_json);
108
109        let header_payload = format!("{header_b64}.{payload_b64}");
110        let signature = self.inner.signer.sign(header_payload.as_bytes())?;
111        let signature_b64 = base64url::encode(&signature);
112
113        Ok(format!("{header_payload}.{signature_b64}"))
114    }
115}
116
117impl Clone for JwtEncoder {
118    fn clone(&self) -> Self {
119        Self {
120            inner: self.inner.clone(),
121        }
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    use super::super::claims::Claims;

    /// Default config with a signing secret long enough for HMAC tests.
    fn test_config() -> JwtSessionsConfig {
        JwtSessionsConfig {
            signing_secret: "test-secret-key-at-least-32-bytes-long!".into(),
            ..JwtSessionsConfig::default()
        }
    }

    /// Encoder built from [`test_config`].
    fn test_encoder() -> JwtEncoder {
        JwtEncoder::from_config(&test_config())
    }

    /// Decodes the base64url JSON segment at `index` (0 = header, 1 = payload).
    fn decode_segment(token: &str, index: usize) -> serde_json::Value {
        let segment = token.split('.').nth(index).unwrap();
        let bytes = base64url::decode(segment).unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    #[test]
    fn encode_produces_three_part_token() {
        let token = test_encoder()
            .encode(&Claims::new().with_exp(9999999999))
            .unwrap();
        let parts: Vec<&str> = token.split('.').collect();
        assert_eq!(parts.len(), 3);
    }

    #[test]
    fn encode_header_contains_hs256() {
        let token = test_encoder()
            .encode(&Claims::new().with_exp(9999999999))
            .unwrap();
        let header = decode_segment(&token, 0);
        assert_eq!(header["alg"], "HS256");
        assert_eq!(header["typ"], "JWT");
    }

    #[test]
    fn encode_with_default_expiry_auto_sets_exp() {
        // The default config's access_ttl_secs doubles as the default expiry,
        // so a payload without `exp` must come back with one.
        let token = test_encoder().encode(&Claims::new()).unwrap();
        let payload = decode_segment(&token, 1);
        assert!(payload.get("exp").is_some());
    }

    #[test]
    fn encode_explicit_exp_not_overwritten() {
        let token = test_encoder().encode(&Claims::new().with_exp(42)).unwrap();
        let payload = decode_segment(&token, 1);
        assert_eq!(payload["exp"], 42);
    }

    #[test]
    fn encode_custom_struct_directly() {
        #[derive(Debug, Clone, Serialize, Deserialize)]
        struct CustomPayload {
            sub: String,
            role: String,
            exp: u64,
        }

        let custom = CustomPayload {
            sub: "user_1".into(),
            role: "admin".into(),
            exp: 9999999999,
        };
        let token = test_encoder().encode(&custom).unwrap();
        let parts: Vec<&str> = token.split('.').collect();
        assert_eq!(parts.len(), 3);
    }

    #[test]
    fn clone_produces_working_encoder() {
        let original = test_encoder();
        let copy = original.clone();
        let result = copy.encode(&Claims::new().with_exp(9999999999));
        assert!(result.is_ok());
    }
}