1use std::{error::Error, fmt::Display};
2
3use chrono::TimeDelta;
4use serde::{
5 de::{self, Visitor},
6 Deserializer, Serializer,
7};
8
9#[allow(dead_code)]
10pub fn serialize_duration<S>(duration: &TimeDelta, serializer: S) -> Result<S::Ok, S::Error>
11where
12 S: Serializer,
13{
14 match duration.num_nanoseconds() {
15 Some(ns) => serializer.serialize_str(format!("{}ns", ns).as_str()),
19 None => Err(serde::ser::Error::custom(
20 "duration nanoseconds out of range of i64",
21 )),
22 }
23}
24
25#[allow(dead_code)]
26pub fn serialize_duration_option<S>(
27 duration: &Option<TimeDelta>,
28 serializer: S,
29) -> Result<S::Ok, S::Error>
30where
31 S: Serializer,
32{
33 match duration {
34 Some(duration) => serialize_duration(duration, serializer),
35 None => serializer.serialize_none(),
36 }
37}
38
/// Reasons a duration string is rejected while deserializing.
#[derive(Debug)]
enum DurationDeserializationError {
    /// The string is empty, malformed, or describes a value that is out of range.
    InvalidDuration { duration: String },
    /// A component carried a unit suffix that is not recognized.
    InvalidUnit { duration: String, unit: String },
    /// A numeric component was not followed by any unit.
    MissingUnit { duration: String },
}

impl Error for DurationDeserializationError {}

impl Display for DurationDeserializationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidDuration { duration } => {
                write!(f, r#"invalid duration "{text}""#, text = duration)
            }
            Self::InvalidUnit { duration, unit } => write!(
                f,
                r#"unknown unit "{unit}" in duration "{text}""#,
                unit = unit,
                text = duration
            ),
            Self::MissingUnit { duration } => {
                write!(f, r#"missing unit in duration "{text}""#, text = duration)
            }
        }
    }
}
59
60struct DurationVisitor;
61
62impl<'de> Visitor<'de> for DurationVisitor {
63 type Value = TimeDelta;
64
65 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
66 formatter.write_str("an i64 representing a duration in nanoseconds or a duration string")
67 }
68
69 fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
70 where
71 E: serde::de::Error,
72 {
73 Ok(TimeDelta::nanoseconds(v))
74 }
75
76 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
77 where
78 E: de::Error,
79 {
80 i64::try_from(v)
81 .map(TimeDelta::nanoseconds)
82 .map_err(E::custom)
83 }
84
85 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
86 where
87 E: de::Error,
88 {
89 parse_duration_string(v).map_err(de::Error::custom)
90 }
91}
92
93#[allow(dead_code)]
94pub fn deserialize_duration<'de, D>(deserializer: D) -> Result<TimeDelta, D::Error>
95where
96 D: Deserializer<'de>,
97{
98 deserializer.deserialize_any(DurationVisitor)
99}
100
101struct DurationOptionVisitor;
106
107impl<'de> Visitor<'de> for DurationOptionVisitor {
108 type Value = Option<TimeDelta>;
109
110 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
111 formatter.write_str("an i64 representing a duration in nanoseconds or a duration string")
112 }
113
114 fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
115 where
116 E: de::Error,
117 {
118 Ok(Some(TimeDelta::nanoseconds(v)))
119 }
120
121 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
122 where
123 E: de::Error,
124 {
125 i64::try_from(v)
126 .map(|v| Some(TimeDelta::nanoseconds(v)))
127 .map_err(E::custom)
128 }
129
130 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
131 where
132 E: de::Error,
133 {
134 if v.is_empty() {
135 Ok(None)
136 } else {
137 parse_duration_string(v)
138 .map_err(de::Error::custom)
139 .map(Some)
140 }
141 }
142
143 fn visit_none<E>(self) -> Result<Self::Value, E>
144 where
145 E: de::Error,
146 {
147 Ok(None)
148 }
149
150 fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
151 where
152 D: Deserializer<'de>,
153 {
154 deserializer.deserialize_any(DurationOptionVisitor)
155 }
156
157 fn visit_unit<E>(self) -> Result<Self::Value, E>
158 where
159 E: de::Error,
160 {
161 Ok(None)
162 }
163}
164
165#[allow(dead_code)]
166pub fn deserialize_duration_option<'de, D>(deserializer: D) -> Result<Option<TimeDelta>, D::Error>
167where
168 D: Deserializer<'de>,
169{
170 deserializer.deserialize_option(DurationOptionVisitor)
171}
172
/// Yields items from `iter` while `predicate` accepts them, without consuming the first
/// rejected item.
///
/// Unlike [`Iterator::take_while`], the item that ends the run stays in the
/// [`Peekable`](std::iter::Peekable) and is available to the caller afterwards.
fn cautious_take_while<'a, I, F, T>(
    iter: &'a mut std::iter::Peekable<I>,
    predicate: F,
) -> impl Iterator<Item = T> + 'a
where
    I: Iterator<Item = T> + 'a,
    F: Fn(&T) -> bool + 'a,
{
    // `next_if` only advances when the peeked item satisfies the predicate.
    std::iter::from_fn(move || iter.next_if(&predicate))
}
196
197fn parse_duration_string(v: &str) -> Result<TimeDelta, DurationDeserializationError> {
198 let mut chars = v.chars().peekable();
199 let mut duration = TimeDelta::zero();
201 let mut is_negative = false;
202
203 if let Some(c) = chars.next_if(|&c| c == '-' || c == '+') {
205 is_negative = c == '-';
206 }
207
208 if chars.clone().nth(1).is_none() && chars.clone().next() == Some('0') {
210 return Ok(TimeDelta::zero());
211 }
212 if chars.peek().is_none() {
213 return Err(DurationDeserializationError::InvalidDuration {
214 duration: v.to_string(),
215 });
216 }
217
218 while chars.peek().is_some() {
219 let leading_int = cautious_take_while(chars.by_ref(), is_digit).collect::<String>();
220
221 chars.by_ref().next_if_eq(&'.'); let fraction_str = cautious_take_while(chars.by_ref(), is_digit).collect::<String>();
224
225 if leading_int.is_empty() && fraction_str.is_empty() {
226 return Err(DurationDeserializationError::InvalidDuration {
227 duration: v.to_string(),
228 });
229 }
230
231 let unit =
232 cautious_take_while(chars.by_ref(), |c| *c != '.' && !is_digit(c)).collect::<String>();
233
234 if unit.is_empty() {
235 return Err(DurationDeserializationError::MissingUnit {
236 duration: v.to_string(),
237 });
238 }
239
240 let multiplier = get_unit_multiplier(unit, v)?;
241
242 let ns = if leading_int.is_empty() {
243 0i64
244 } else {
245 match leading_int.parse() {
246 Ok(v) => v,
247 Err(_) => {
248 return Err(DurationDeserializationError::InvalidDuration {
249 duration: v.to_string(),
250 })
251 }
252 }
253 };
254 let ns = ns.checked_mul(multiplier).ok_or_else(|| {
255 DurationDeserializationError::InvalidDuration {
256 duration: v.to_string(),
257 }
258 })?;
259
260 let fraction = if fraction_str.is_empty() {
261 0i64
262 } else {
263 match fraction_str.parse() {
264 Ok(v) => v,
265 Err(_) => {
266 return Err(DurationDeserializationError::InvalidDuration {
267 duration: v.to_string(),
268 })
269 }
270 }
271 };
272 let ns = if fraction > 0 {
273 let fraction_length = u32::try_from(fraction_str.len()).map_err(|_| {
276 DurationDeserializationError::InvalidDuration {
277 duration: v.to_string(),
278 }
279 })?;
280
281 let scale = i64::pow(10, fraction_length);
282 let f = fraction as f64 * (multiplier as f64 / scale as f64);
283
284 ns.checked_add(f as i64).ok_or_else(|| {
285 DurationDeserializationError::InvalidDuration {
286 duration: v.to_string(),
287 }
288 })?
289 } else {
290 ns
291 };
292
293 duration = duration
294 .checked_add(&TimeDelta::nanoseconds(ns))
295 .ok_or_else(|| DurationDeserializationError::InvalidDuration {
296 duration: v.to_string(),
297 })?;
298 }
299
300 Ok(if is_negative { -duration } else { duration })
301}
302
/// Returns `true` only for the ASCII decimal digits `'0'..='9'`. Other Unicode digits
/// are rejected.
fn is_digit(c: &char) -> bool {
    c.is_ascii_digit()
}
306
307fn get_unit_multiplier(unit: String, duration: &str) -> Result<i64, DurationDeserializationError> {
308 let multiplier = match unit.as_str() {
309 "ns" => Ok(TimeDelta::nanoseconds(1)),
310 "us" => Ok(TimeDelta::microseconds(1)),
311 "\u{00B5}s" => Ok(TimeDelta::microseconds(1)), "\u{03BC}s" => Ok(TimeDelta::microseconds(1)), "ms" => Ok(TimeDelta::milliseconds(1)),
314 "s" => Ok(TimeDelta::seconds(1)),
315 "m" => Ok(TimeDelta::minutes(1)),
316 "h" => Ok(TimeDelta::hours(1)),
317 _ => Err(DurationDeserializationError::InvalidUnit {
318 duration: duration.to_string(),
319 unit,
320 }),
321 }?;
322 match multiplier.num_nanoseconds() {
323 Some(ns) => Ok(ns),
324 _ => Err(DurationDeserializationError::InvalidDuration {
325 duration: duration.to_string(),
326 }),
327 }
328}
329
#[cfg(test)]
mod tests {
    use chrono::TimeDelta;
    use serde::Deserialize;

    /// Target for `deserialize_duration_option`; a missing field falls back to `None`.
    #[derive(Debug, Deserialize, PartialEq, Eq)]
    struct OptionalDuration {
        #[serde(deserialize_with = "super::deserialize_duration_option", default)]
        duration: Option<chrono::TimeDelta>,
    }

    /// Target for `deserialize_duration`.
    #[derive(Debug, Deserialize, PartialEq, Eq)]
    struct Duration {
        #[serde(deserialize_with = "super::deserialize_duration")]
        duration: chrono::TimeDelta,
    }

    /// Runs `json` through both deserializers and returns the parsed nanosecond counts
    /// as `(required, optional)`.
    fn nanos_via_both(json: &str) -> (Option<i64>, Option<i64>) {
        let required = serde_json::from_str::<Duration>(json).unwrap();
        let optional = serde_json::from_str::<OptionalDuration>(json).unwrap();
        (
            required.duration.num_nanoseconds(),
            optional.duration.unwrap().num_nanoseconds(),
        )
    }

    #[test]
    fn deserialize_duration_parses_u64() {
        let (required, optional) = nanos_via_both(r#"{"duration":1234}"#);

        assert_eq!(required, Some(1234));
        assert_eq!(optional, Some(1234));
    }

    #[test]
    fn deserialize_duration_parses_i64() {
        let (required, optional) = nanos_via_both(r#"{"duration":-1234}"#);

        assert_eq!(required, Some(-1234));
        assert_eq!(optional, Some(-1234));
    }

    #[test]
    fn deserialize_duration_parses_positive_duration_string() {
        // 1h 1m 1s 1ms 1us 1ns
        let expected: i64 = 3_661_001_001_001;

        let (required, optional) = nanos_via_both(r#"{"duration":"+1h1m1s1ms1us1ns"}"#);

        assert_eq!(required, Some(expected));
        assert_eq!(optional, Some(expected));
    }

    #[test]
    fn deserialize_duration_parses_negative_duration_string() {
        // -(1h 1m 1s 1ms 1us 1ns)
        let expected: i64 = -3_661_001_001_001;

        let (required, optional) = nanos_via_both(r#"{"duration":"-1h1m1s1ms1us1ns"}"#);

        assert_eq!(required, Some(expected));
        assert_eq!(optional, Some(expected));
    }

    #[test]
    fn deserialize_duration_parses_fractional_duration_string() {
        // 1.5h + 15m = 1h 45m = 6300s
        let expected: i64 = 6_300_000_000_000;

        let (required, optional) = nanos_via_both(r#"{"duration":"1.5h15m"}"#);

        assert_eq!(required, Some(expected));
        assert_eq!(optional, Some(expected));
    }

    #[test]
    fn deserialize_duration_parses_duration_strings() {
        let cases = [
            ("0s", TimeDelta::nanoseconds(0)),
            ("1ns", TimeDelta::nanoseconds(1)),
            ("1.1\u{00B5}s", TimeDelta::nanoseconds(1100)),
            ("1.1\u{03BC}s", TimeDelta::nanoseconds(1100)),
            ("1.1us", TimeDelta::nanoseconds(1100)),
            ("2.2ms", TimeDelta::microseconds(2200)),
            ("3.3s", TimeDelta::milliseconds(3300)),
            ("4m5s", TimeDelta::minutes(4) + TimeDelta::seconds(5)),
            (
                "4m5.001s",
                TimeDelta::minutes(4) + TimeDelta::milliseconds(5001),
            ),
            (
                "5h6m7.001s",
                TimeDelta::hours(5) + TimeDelta::minutes(6) + TimeDelta::milliseconds(7001),
            ),
            (
                "8m0.000000001s",
                TimeDelta::minutes(8) + TimeDelta::nanoseconds(1),
            ),
            ("2562047h47m16.854775807s", TimeDelta::nanoseconds(i64::MAX)),
            (
                "-2562047h47m16.854775808s",
                TimeDelta::nanoseconds(i64::MIN),
            ),
        ];

        for (input, expected) in cases {
            let json = format!(r#"{{"duration":"{}"}}"#, input);

            let required = serde_json::from_str::<Duration>(&json).unwrap().duration;
            let optional = serde_json::from_str::<OptionalDuration>(&json)
                .unwrap()
                .duration
                .unwrap();

            assert_eq!(required, expected, "input {:?}", input);
            assert_eq!(optional, expected, "input {:?}", input);
        }
    }

    #[test]
    fn deserialize_duration_option_returns_none_when_null() {
        let parsed = serde_json::from_str::<OptionalDuration>(r#"{"duration":null}"#).unwrap();

        assert_eq!(parsed, OptionalDuration { duration: None });
    }

    #[test]
    fn deserialize_duration_option_returns_none_when_missing() {
        let parsed = serde_json::from_str::<OptionalDuration>(r#"{}"#).unwrap();

        assert_eq!(parsed, OptionalDuration { duration: None });
    }

    #[test]
    fn deserialize_duration_option_returns_none_for_empty_string() {
        let parsed = serde_json::from_str::<OptionalDuration>(r#"{"duration":""}"#).unwrap();

        assert_eq!(parsed, OptionalDuration { duration: None });
    }
}