1use chrono::{DateTime, Duration, Utc};
2use parity_scale_codec::{Decode, Encode, MaxEncodedLen};
3use scale_info::TypeInfo;
4use serde::{Deserialize, Serialize};
5
6use crate::{Claim, ValidationError};
7
8#[derive(Debug, Clone, Copy)]
29#[non_exhaustive]
30pub struct TimeOptions<F = fn() -> DateTime<Utc>> {
31 pub leeway: Duration,
33 pub clock_fn: F,
35}
36
37impl<F: Fn() -> DateTime<Utc>> TimeOptions<F> {
38 pub const fn new(leeway: Duration, clock_fn: F) -> Self {
40 Self { leeway, clock_fn }
41 }
42}
43
44impl TimeOptions {
45 #[cfg(feature = "clock")]
47 #[cfg_attr(docsrs, doc(cfg(feature = "clock")))]
48 pub fn from_leeway(leeway: Duration) -> Self {
49 Self { leeway, clock_fn: Utc::now }
50 }
51}
52
/// Default options: a 60-second leeway with [`Utc::now()`] as the clock.
#[cfg(feature = "clock")]
impl Default for TimeOptions {
    fn default() -> Self {
        const DEFAULT_LEEWAY_SECS: i64 = 60;
        Self::from_leeway(Duration::seconds(DEFAULT_LEEWAY_SECS))
    }
}
62
63#[derive(
65 Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize, Encode, Decode, TypeInfo, MaxEncodedLen,
66)]
67pub struct Empty {}
68
69#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Ord, PartialOrd)]
77pub struct Claims<T> {
78 #[serde(rename = "exp", default, skip_serializing_if = "Option::is_none", with = "self::serde_timestamp")]
80 pub expiration: Option<DateTime<Utc>>,
81
82 #[serde(rename = "nbf", default, skip_serializing_if = "Option::is_none", with = "self::serde_timestamp")]
84 pub not_before: Option<DateTime<Utc>>,
85
86 #[serde(rename = "iat", default, skip_serializing_if = "Option::is_none", with = "self::serde_timestamp")]
88 pub issued_at: Option<DateTime<Utc>>,
89
90 #[serde(flatten)]
92 pub custom: T,
93}
94
95impl Claims<Empty> {
96 pub const fn empty() -> Self {
98 Self { expiration: None, not_before: None, issued_at: None, custom: Empty {} }
99 }
100}
101
102impl<T> Claims<T> {
103 pub const fn new(custom_claims: T) -> Self {
105 Self { expiration: None, not_before: None, issued_at: None, custom: custom_claims }
106 }
107
108 #[must_use]
111 pub fn set_duration<F>(self, options: &TimeOptions<F>, duration: Duration) -> Self
112 where
113 F: Fn() -> DateTime<Utc>,
114 {
115 Self { expiration: Some((options.clock_fn)() + duration), ..self }
116 }
117
118 #[must_use]
121 pub fn set_duration_and_issuance<F>(self, options: &TimeOptions<F>, duration: Duration) -> Self
122 where
123 F: Fn() -> DateTime<Utc>,
124 {
125 let issued_at = (options.clock_fn)();
126 Self { expiration: Some(issued_at + duration), issued_at: Some(issued_at), ..self }
127 }
128
129 #[must_use]
131 pub fn set_not_before(self, moment: DateTime<Utc>) -> Self {
132 Self { not_before: Some(moment), ..self }
133 }
134
135 pub fn validate_expiration<F>(&self, options: &TimeOptions<F>) -> Result<&Self, ValidationError>
140 where
141 F: Fn() -> DateTime<Utc>,
142 {
143 self.expiration.map_or(Err(ValidationError::NoClaim(Claim::Expiration)), |expiration| {
144 let expiration_with_leeway =
145 expiration.checked_add_signed(options.leeway).unwrap_or(DateTime::<Utc>::MAX_UTC);
146 if (options.clock_fn)() > expiration_with_leeway {
147 Err(ValidationError::Expired)
148 } else {
149 Ok(self)
150 }
151 })
152 }
153
154 pub fn validate_maturity<F>(&self, options: &TimeOptions<F>) -> Result<&Self, ValidationError>
159 where
160 F: Fn() -> DateTime<Utc>,
161 {
162 self.not_before.map_or(Err(ValidationError::NoClaim(Claim::NotBefore)), |not_before| {
163 if (options.clock_fn)() < not_before - options.leeway {
164 Err(ValidationError::NotMature)
165 } else {
166 Ok(self)
167 }
168 })
169 }
170}
171
172mod serde_timestamp {
173 use chrono::{offset::TimeZone, DateTime, Utc};
174 use serde::{
175 de::{Error as DeError, Visitor},
176 Deserializer, Serializer,
177 };
178
179 use core::fmt;
180
181 struct TimestampVisitor;
182
183 impl<'de> Visitor<'de> for TimestampVisitor {
184 type Value = DateTime<Utc>;
185
186 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
187 formatter.write_str("UTC timestamp")
188 }
189
190 fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
191 where
192 E: DeError,
193 {
194 Utc.timestamp_opt(value, 0).single().ok_or_else(|| E::custom("UTC timestamp overflow"))
195 }
196
197 fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
198 where
199 E: DeError,
200 {
201 let value = i64::try_from(value).map_err(DeError::custom)?;
202 Utc.timestamp_opt(value, 0).single().ok_or_else(|| E::custom("UTC timestamp overflow"))
203 }
204
205 #[allow(clippy::cast_possible_truncation)]
206 fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
208 where
209 E: DeError,
210 {
211 Utc.timestamp_opt(value as i64, 0).single().ok_or_else(|| E::custom("UTC timestamp overflow"))
212 }
213 }
214
215 pub fn serialize<S: Serializer>(time: &Option<DateTime<Utc>>, serializer: S) -> Result<S::Ok, S::Error> {
216 serializer.serialize_i64(time.unwrap().timestamp())
218 }
219
220 pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Option<DateTime<Utc>>, D::Error> {
221 deserializer.deserialize_i64(TimestampVisitor).map(Some)
222 }
223}
224
#[cfg(all(test, feature = "clock"))]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use chrono::TimeZone;

    /// JSON serialization succeeds with no time claims, and after setting `exp` and `nbf`.
    #[test]
    fn empty_claims_can_be_serialized() {
        let serializes = |value: &Claims<Empty>| serde_json::to_string(value).is_ok();

        let mut claims = Claims::empty();
        assert!(serializes(&claims));
        claims.expiration = Some(Utc::now());
        assert!(serializes(&claims));
        claims.not_before = Some(Utc::now());
        assert!(serializes(&claims));
    }

    /// Same as above, but for CBOR.
    #[test]
    #[cfg(feature = "ciborium")]
    fn empty_claims_can_be_serialized_to_cbor() {
        let encodes = |value: &Claims<Empty>| ciborium::into_writer(value, &mut vec![]).is_ok();

        let mut claims = Claims::empty();
        assert!(encodes(&claims));
        claims.expiration = Some(Utc::now());
        assert!(encodes(&claims));
        claims.not_before = Some(Utc::now());
        assert!(encodes(&claims));
    }

    #[test]
    fn expired_claim() {
        let mut claims = Claims::empty();
        let default_options = TimeOptions::default();

        // Without an `exp` claim, validation reports the missing claim.
        assert_matches!(
            claims.validate_expiration(&default_options).unwrap_err(),
            ValidationError::NoClaim(Claim::Expiration)
        );

        // Adding the leeway to the maximum moment must not overflow.
        claims.expiration = Some(DateTime::<Utc>::MAX_UTC);
        assert!(claims.validate_expiration(&default_options).is_ok());

        // An hour in the past is outside the default leeway.
        claims.expiration = Some(Utc::now() - Duration::hours(1));
        assert_matches!(claims.validate_expiration(&default_options).unwrap_err(), ValidationError::Expired);

        // Ten seconds in the past is within the default leeway but not within a 5-second one.
        claims.expiration = Some(Utc::now() - Duration::seconds(10));
        assert!(claims.validate_expiration(&default_options).is_ok());
        let strict_options = TimeOptions::from_leeway(Duration::seconds(5));
        assert_matches!(claims.validate_expiration(&strict_options).unwrap_err(), ValidationError::Expired);

        // A frozen clock reporting exactly the expiration moment is accepted.
        let expiration = claims.expiration.unwrap();
        let frozen_options = TimeOptions::new(Duration::seconds(3), move || expiration);
        assert!(claims.validate_expiration(&frozen_options).is_ok());
    }

    #[test]
    fn immature_claim() {
        let mut claims = Claims::empty();
        let default_options = TimeOptions::default();

        // Without an `nbf` claim, validation reports the missing claim.
        assert_matches!(
            claims.validate_maturity(&default_options).unwrap_err(),
            ValidationError::NoClaim(Claim::NotBefore)
        );

        // An hour in the future is outside the default leeway.
        claims.not_before = Some(Utc::now() + Duration::hours(1));
        assert_matches!(claims.validate_maturity(&default_options).unwrap_err(), ValidationError::NotMature);

        // Ten seconds in the future is within the default leeway but not within a 5-second one.
        claims.not_before = Some(Utc::now() + Duration::seconds(10));
        assert!(claims.validate_maturity(&default_options).is_ok());
        let strict_options = TimeOptions::from_leeway(Duration::seconds(5));
        assert_matches!(claims.validate_maturity(&strict_options).unwrap_err(), ValidationError::NotMature);
    }

    #[test]
    fn float_timestamp() {
        let parsed: Claims<Empty> = serde_json::from_str(r#"{"exp": 1.691203462e+9}"#).unwrap();
        let expected = Utc.timestamp_opt(1_691_203_462, 0).single().unwrap();
        assert_eq!(parsed.expiration, Some(expected));
    }

    #[test]
    fn float_timestamp_errors() {
        for json in [r#"{"exp": 1e20}"#, r#"{"exp": -1e20}"#] {
            let message = serde_json::from_str::<Claims<Empty>>(json).unwrap_err().to_string();
            assert!(message.contains("UTC timestamp overflow"), "{message}");
        }
    }
}