1use std::sync::Arc;
2use std::time::Duration;
3
4use serde::Serialize;
5
6use crate::encoding::base64url;
7use crate::{Error, Result};
8
9use super::config::JwtConfig;
10use super::error::JwtError;
11use super::signer::{HmacSigner, TokenSigner};
12
13pub struct JwtEncoder {
18 inner: Arc<JwtEncoderInner>,
19}
20
21struct JwtEncoderInner {
22 signer: Arc<dyn TokenSigner>,
23 default_expiry: Option<Duration>,
24 validation: super::validation::ValidationConfig,
25}
26
27impl JwtEncoder {
28 pub fn from_config(config: &JwtConfig) -> Self {
34 let signer = HmacSigner::new(config.secret.as_bytes());
35 Self {
36 inner: Arc::new(JwtEncoderInner {
37 signer: Arc::new(signer),
38 default_expiry: config.default_expiry.map(Duration::from_secs),
39 validation: super::validation::ValidationConfig {
40 leeway: Duration::from_secs(config.leeway),
41 require_issuer: config.issuer.clone(),
42 require_audience: config.audience.clone(),
43 },
44 }),
45 }
46 }
47
48 pub(super) fn verifier(&self) -> Arc<dyn super::signer::TokenVerifier> {
51 self.inner.signer.clone() as Arc<dyn super::signer::TokenVerifier>
54 }
55
56 pub(super) fn validation(&self) -> super::validation::ValidationConfig {
59 self.inner.validation.clone()
60 }
61
62 pub fn encode<T: Serialize>(&self, claims: &super::claims::Claims<T>) -> Result<String> {
75 let claims_json = if claims.exp.is_none() {
77 if let Some(default_exp) = self.inner.default_expiry {
78 let now = std::time::SystemTime::now()
79 .duration_since(std::time::UNIX_EPOCH)
80 .expect("system clock before UNIX epoch")
81 .as_secs();
82 let mut value = serde_json::to_value(claims).map_err(|_| {
83 Error::internal("failed to serialize token")
84 .chain(JwtError::SerializationFailed)
85 .with_code(JwtError::SerializationFailed.code())
86 })?;
87 value["exp"] = serde_json::Value::Number((now + default_exp.as_secs()).into());
88 serde_json::to_vec(&value)
89 } else {
90 serde_json::to_vec(claims)
91 }
92 } else {
93 serde_json::to_vec(claims)
94 }
95 .map_err(|_| {
96 Error::internal("unauthorized")
97 .chain(JwtError::SerializationFailed)
98 .with_code(JwtError::SerializationFailed.code())
99 })?;
100
101 let alg = self.inner.signer.algorithm_name();
102 let header = format!(r#"{{"alg":"{alg}","typ":"JWT"}}"#);
103 let header_b64 = base64url::encode(header.as_bytes());
104 let payload_b64 = base64url::encode(&claims_json);
105
106 let header_payload = format!("{header_b64}.{payload_b64}");
107 let signature = self.inner.signer.sign(header_payload.as_bytes())?;
108 let signature_b64 = base64url::encode(&signature);
109
110 Ok(format!("{header_payload}.{signature_b64}"))
111 }
112}
113
114impl Clone for JwtEncoder {
115 fn clone(&self) -> Self {
116 Self {
117 inner: self.inner.clone(),
118 }
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::super::claims::Claims;
    use super::*;
    use serde::Deserialize;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestClaims {
        role: String,
    }

    /// Config with a fixed HMAC secret, no default expiry, zero leeway, and no
    /// issuer/audience requirements.
    fn test_config() -> JwtConfig {
        JwtConfig {
            secret: "test-secret-key-at-least-32-bytes-long!".into(),
            default_expiry: None,
            leeway: 0,
            issuer: None,
            audience: None,
        }
    }

    /// Decodes the base64url segment at `index` of `token` and parses it as JSON.
    fn decode_segment(token: &str, index: usize) -> serde_json::Value {
        let segment = token.split('.').nth(index).expect("token segment missing");
        let bytes = base64url::decode(segment).unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    /// Current UNIX time in whole seconds, used to bound generated `exp` values.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn encode_produces_three_part_token() {
        let encoder = JwtEncoder::from_config(&test_config());
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(9999999999);
        let token = encoder.encode(&claims).unwrap();
        assert_eq!(token.split('.').count(), 3);
    }

    #[test]
    fn encode_header_contains_hs256() {
        let encoder = JwtEncoder::from_config(&test_config());
        let claims = Claims::new(()).with_exp(9999999999);
        let token = encoder.encode(&claims).unwrap();
        let header = decode_segment(&token, 0);
        assert_eq!(header["alg"], "HS256");
        assert_eq!(header["typ"], "JWT");
    }

    #[test]
    fn encode_with_default_expiry_auto_sets_exp() {
        let mut config = test_config();
        config.default_expiry = Some(3600);
        let encoder = JwtEncoder::from_config(&config);
        let before = now_secs();
        let token = encoder.encode(&Claims::new(())).unwrap();
        let after = now_secs();
        let exp = decode_segment(&token, 1)["exp"]
            .as_u64()
            .expect("exp should be an unsigned integer");
        // `exp` must equal "now + default_expiry", not merely be present.
        assert!(
            (before + 3600..=after + 3600).contains(&exp),
            "unexpected exp {exp}"
        );
    }

    #[test]
    fn encode_explicit_exp_not_overwritten() {
        let mut config = test_config();
        config.default_expiry = Some(3600);
        let encoder = JwtEncoder::from_config(&config);
        let claims = Claims::new(()).with_exp(42);
        let token = encoder.encode(&claims).unwrap();
        let payload = decode_segment(&token, 1);
        assert_eq!(payload["exp"], 42);
    }

    #[test]
    fn clone_produces_working_encoder() {
        let encoder = JwtEncoder::from_config(&test_config());
        let cloned = encoder.clone();
        let claims = Claims::new(()).with_exp(9999999999);
        assert!(cloned.encode(&claims).is_ok());
    }
}