Skip to main content

modo/auth/jwt/
encoder.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use serde::Serialize;
5
6use crate::encoding::base64url;
7use crate::{Error, Result};
8
9use super::config::JwtConfig;
10use super::error::JwtError;
11use super::signer::{HmacSigner, TokenSigner};
12
13/// JWT token encoder. Signs tokens using a [`TokenSigner`].
14///
15/// Register in `Registry` for handler access via `Service<JwtEncoder>`.
16/// Cloning is cheap — state is stored behind `Arc`.
17pub struct JwtEncoder {
18    inner: Arc<JwtEncoderInner>,
19}
20
21struct JwtEncoderInner {
22    signer: Arc<dyn TokenSigner>,
23    default_expiry: Option<Duration>,
24    validation: super::validation::ValidationConfig,
25}
26
27impl JwtEncoder {
28    /// Creates a `JwtEncoder` from YAML configuration.
29    ///
30    /// Uses `HmacSigner` (HS256) with the configured secret. The validation
31    /// config (leeway, issuer, audience) is stored so a matching `JwtDecoder`
32    /// can be created via `JwtDecoder::from(&encoder)`.
33    pub fn from_config(config: &JwtConfig) -> Self {
34        let signer = HmacSigner::new(config.secret.as_bytes());
35        Self {
36            inner: Arc::new(JwtEncoderInner {
37                signer: Arc::new(signer),
38                default_expiry: config.default_expiry.map(Duration::from_secs),
39                validation: super::validation::ValidationConfig {
40                    leeway: Duration::from_secs(config.leeway),
41                    require_issuer: config.issuer.clone(),
42                    require_audience: config.audience.clone(),
43                },
44            }),
45        }
46    }
47
48    /// Returns a reference to the inner signer (as verifier).
49    /// Used by `JwtDecoder::from(&encoder)` to share the same key.
50    pub(super) fn verifier(&self) -> Arc<dyn super::signer::TokenVerifier> {
51        // Trait upcasting: Arc<dyn TokenSigner> → Arc<dyn TokenVerifier>
52        // Stabilized in Rust 1.76.
53        self.inner.signer.clone() as Arc<dyn super::signer::TokenVerifier>
54    }
55
56    /// Returns a clone of the validation config.
57    /// Used by `JwtDecoder::from(&encoder)`.
58    pub(super) fn validation(&self) -> super::validation::ValidationConfig {
59        self.inner.validation.clone()
60    }
61
62    /// Encodes claims into a signed JWT token string.
63    ///
64    /// If `claims.exp` is `None` and `default_expiry` is configured,
65    /// `exp` is automatically set to `now + default_expiry` before signing.
66    /// An explicitly set `exp` is never overwritten.
67    ///
68    /// # Errors
69    ///
70    /// Returns `Error::internal` with [`JwtError::SerializationFailed`](super::JwtError::SerializationFailed)
71    /// if the claims cannot be serialized to JSON, or
72    /// [`JwtError::SigningFailed`](super::JwtError::SigningFailed) if the HMAC signing
73    /// operation fails.
74    pub fn encode<T: Serialize>(&self, claims: &super::claims::Claims<T>) -> Result<String> {
75        // Auto-fill exp if missing and default_expiry is configured
76        let claims_json = if claims.exp.is_none() {
77            if let Some(default_exp) = self.inner.default_expiry {
78                let now = std::time::SystemTime::now()
79                    .duration_since(std::time::UNIX_EPOCH)
80                    .expect("system clock before UNIX epoch")
81                    .as_secs();
82                let mut value = serde_json::to_value(claims).map_err(|_| {
83                    Error::internal("failed to serialize token")
84                        .chain(JwtError::SerializationFailed)
85                        .with_code(JwtError::SerializationFailed.code())
86                })?;
87                value["exp"] = serde_json::Value::Number((now + default_exp.as_secs()).into());
88                serde_json::to_vec(&value)
89            } else {
90                serde_json::to_vec(claims)
91            }
92        } else {
93            serde_json::to_vec(claims)
94        }
95        .map_err(|_| {
96            Error::internal("unauthorized")
97                .chain(JwtError::SerializationFailed)
98                .with_code(JwtError::SerializationFailed.code())
99        })?;
100
101        let alg = self.inner.signer.algorithm_name();
102        let header = format!(r#"{{"alg":"{alg}","typ":"JWT"}}"#);
103        let header_b64 = base64url::encode(header.as_bytes());
104        let payload_b64 = base64url::encode(&claims_json);
105
106        let header_payload = format!("{header_b64}.{payload_b64}");
107        let signature = self.inner.signer.sign(header_payload.as_bytes())?;
108        let signature_b64 = base64url::encode(&signature);
109
110        Ok(format!("{header_payload}.{signature_b64}"))
111    }
112}
113
114impl Clone for JwtEncoder {
115    fn clone(&self) -> Self {
116        Self {
117            inner: self.inner.clone(),
118        }
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::claims::Claims;
    use serde::Deserialize;
    use std::time::{SystemTime, UNIX_EPOCH};

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestClaims {
        role: String,
    }

    /// HS256 config with no default expiry, zero leeway, and no issuer/audience.
    fn test_config() -> JwtConfig {
        JwtConfig {
            secret: "test-secret-key-at-least-32-bytes-long!".into(),
            default_expiry: None,
            leeway: 0,
            issuer: None,
            audience: None,
        }
    }

    /// Base64url-decodes dot-separated segment `index` of `token` and parses it
    /// as JSON (0 = header, 1 = payload).
    fn decode_segment(token: &str, index: usize) -> serde_json::Value {
        let segment = token
            .split('.')
            .nth(index)
            .expect("token has too few segments");
        let bytes = base64url::decode(segment).unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    /// Current Unix time in whole seconds.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn encode_produces_three_part_token() {
        let encoder = JwtEncoder::from_config(&test_config());
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(9999999999);
        let token = encoder.encode(&claims).unwrap();
        assert_eq!(token.split('.').count(), 3);
    }

    #[test]
    fn encode_header_contains_hs256() {
        let encoder = JwtEncoder::from_config(&test_config());
        let claims = Claims::new(()).with_exp(9999999999);
        let token = encoder.encode(&claims).unwrap();
        let header = decode_segment(&token, 0);
        assert_eq!(header["alg"], "HS256");
        assert_eq!(header["typ"], "JWT");
    }

    #[test]
    fn encode_with_default_expiry_auto_sets_exp() {
        let mut config = test_config();
        config.default_expiry = Some(3600);
        let encoder = JwtEncoder::from_config(&config);
        // claims.exp is None, so it should be auto-filled with now + 3600.
        let claims = Claims::new(());
        let before = unix_now();
        let token = encoder.encode(&claims).unwrap();
        let after = unix_now();
        // Check the value, not mere presence: an unset `exp` serialized as
        // `null` would also satisfy a presence check.
        let exp = decode_segment(&token, 1)["exp"]
            .as_u64()
            .expect("exp should be auto-filled as an integer");
        assert!((before + 3600..=after + 3600).contains(&exp));
    }

    #[test]
    fn encode_without_default_expiry_leaves_exp_unset() {
        let encoder = JwtEncoder::from_config(&test_config());
        let token = encoder.encode(&Claims::new(())).unwrap();
        // Indexing a missing key yields `Null`, so this accepts both an omitted
        // and an explicit-null `exp`.
        assert!(decode_segment(&token, 1)["exp"].is_null());
    }

    #[test]
    fn encode_explicit_exp_not_overwritten() {
        let mut config = test_config();
        config.default_expiry = Some(3600);
        let encoder = JwtEncoder::from_config(&config);
        let claims = Claims::new(()).with_exp(42);
        let token = encoder.encode(&claims).unwrap();
        assert_eq!(decode_segment(&token, 1)["exp"], 42);
    }

    #[test]
    fn clone_produces_working_encoder() {
        let encoder = JwtEncoder::from_config(&test_config());
        let cloned = encoder.clone();
        let claims = Claims::new(()).with_exp(9999999999);
        let from_clone = cloned.encode(&claims).unwrap();
        // Both handles share one HMAC signer, and HMAC is deterministic, so
        // identical claims must yield identical tokens.
        assert_eq!(from_clone, encoder.encode(&claims).unwrap());
    }
}
199}