nomad_api_types/
duration.rs

1use std::{error::Error, fmt::Display};
2
3use chrono::TimeDelta;
4use serde::{
5    de::{self, Visitor},
6    Deserializer, Serializer,
7};
8
9#[allow(dead_code)]
10pub fn serialize_duration<S>(duration: &TimeDelta, serializer: S) -> Result<S::Ok, S::Error>
11where
12    S: Serializer,
13{
14    match duration.num_nanoseconds() {
15        // TODO: implement some nicer printing instead of directly using ns.
16        // A duration of 10 minutes should be serialized as "10m" and not
17        // "600000000000ns".
18        Some(ns) => serializer.serialize_str(format!("{}ns", ns).as_str()),
19        None => Err(serde::ser::Error::custom(
20            "duration nanoseconds out of range of i64",
21        )),
22    }
23}
24
25#[allow(dead_code)]
26pub fn serialize_duration_option<S>(
27    duration: &Option<TimeDelta>,
28    serializer: S,
29) -> Result<S::Ok, S::Error>
30where
31    S: Serializer,
32{
33    match duration {
34        Some(duration) => serialize_duration(duration, serializer),
35        None => serializer.serialize_none(),
36    }
37}
38
/// Error produced when a duration string cannot be parsed.
///
/// Every variant keeps the complete input (`duration`) so the rendered message
/// shows which value was rejected.
#[derive(Debug)]
enum DurationDeserializationError {
    /// The string is malformed, or its value does not fit the supported range.
    InvalidDuration { duration: String },
    /// A number was followed by a unit suffix that is not recognised.
    InvalidUnit { duration: String, unit: String },
    /// A number was not followed by any unit suffix.
    MissingUnit { duration: String },
}

impl Display for DurationDeserializationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DurationDeserializationError::MissingUnit { duration } => {
                write!(f, "missing unit in duration \"{}\"", duration)
            }
            DurationDeserializationError::InvalidUnit { duration, unit } => {
                write!(f, "unknown unit \"{}\" in duration \"{}\"", unit, duration)
            }
            DurationDeserializationError::InvalidDuration { duration } => {
                write!(f, "invalid duration \"{}\"", duration)
            }
        }
    }
}

// No variant wraps an underlying error, so the default `source()` (None) applies.
impl Error for DurationDeserializationError {}
59
60struct DurationVisitor;
61
62impl<'de> Visitor<'de> for DurationVisitor {
63    type Value = TimeDelta;
64
65    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
66        formatter.write_str("an i64 representing a duration in nanoseconds or a duration string")
67    }
68
69    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
70    where
71        E: serde::de::Error,
72    {
73        Ok(TimeDelta::nanoseconds(v))
74    }
75
76    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
77    where
78        E: de::Error,
79    {
80        i64::try_from(v)
81            .map(TimeDelta::nanoseconds)
82            .map_err(E::custom)
83    }
84
85    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
86    where
87        E: de::Error,
88    {
89        parse_duration_string(v).map_err(de::Error::custom)
90    }
91}
92
93#[allow(dead_code)]
94pub fn deserialize_duration<'de, D>(deserializer: D) -> Result<TimeDelta, D::Error>
95where
96    D: Deserializer<'de>,
97{
98    deserializer.deserialize_any(DurationVisitor)
99}
100
101// /// Small helper struct that makes deserializing Option<TimeDelta> easier.
102// #[derive(Debug, Deserialize)]
103// struct WrappedTimeDelta(#[serde(deserialize_with = "deserialize_duration")] TimeDelta);
104
105struct DurationOptionVisitor;
106
107impl<'de> Visitor<'de> for DurationOptionVisitor {
108    type Value = Option<TimeDelta>;
109
110    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
111        formatter.write_str("an i64 representing a duration in nanoseconds or a duration string")
112    }
113
114    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
115    where
116        E: de::Error,
117    {
118        Ok(Some(TimeDelta::nanoseconds(v)))
119    }
120
121    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
122    where
123        E: de::Error,
124    {
125        i64::try_from(v)
126            .map(|v| Some(TimeDelta::nanoseconds(v)))
127            .map_err(E::custom)
128    }
129
130    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
131    where
132        E: de::Error,
133    {
134        if v.is_empty() {
135            Ok(None)
136        } else {
137            parse_duration_string(v)
138                .map_err(de::Error::custom)
139                .map(Some)
140        }
141    }
142
143    fn visit_none<E>(self) -> Result<Self::Value, E>
144    where
145        E: de::Error,
146    {
147        Ok(None)
148    }
149
150    fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
151    where
152        D: Deserializer<'de>,
153    {
154        deserializer.deserialize_any(DurationOptionVisitor)
155    }
156
157    fn visit_unit<E>(self) -> Result<Self::Value, E>
158    where
159        E: de::Error,
160    {
161        Ok(None)
162    }
163}
164
165#[allow(dead_code)]
166pub fn deserialize_duration_option<'de, D>(deserializer: D) -> Result<Option<TimeDelta>, D::Error>
167where
168    D: Deserializer<'de>,
169{
170    deserializer.deserialize_option(DurationOptionVisitor)
171}
172
/// Lazily yields items from `iter` while `predicate` holds, without consuming
/// the first item that fails it.
///
/// [`Iterator::take_while`] pulls the first non-matching element and drops it.
/// This version only peeks at that element, so it stays available to whoever
/// reads the `Peekable` next.
fn cautious_take_while<'a, I, F, T>(
    iter: &'a mut std::iter::Peekable<I>,
    predicate: F,
) -> impl Iterator<Item = T> + 'a
where
    I: Iterator<Item = T> + 'a,
    F: Fn(&T) -> bool + 'a,
{
    std::iter::from_fn(move || iter.next_if(|item| predicate(item)))
}
196
197fn parse_duration_string(v: &str) -> Result<TimeDelta, DurationDeserializationError> {
198    let mut chars = v.chars().peekable();
199    // let mut nanoseconds = 0i64;
200    let mut duration = TimeDelta::zero();
201    let mut is_negative = false;
202
203    // Consume [-+]?
204    if let Some(c) = chars.next_if(|&c| c == '-' || c == '+') {
205        is_negative = c == '-';
206    }
207
208    // Special case: if all that is left is "0", this is zero.
209    if chars.clone().nth(1).is_none() && chars.clone().next() == Some('0') {
210        return Ok(TimeDelta::zero());
211    }
212    if chars.peek().is_none() {
213        return Err(DurationDeserializationError::InvalidDuration {
214            duration: v.to_string(),
215        });
216    }
217
218    while chars.peek().is_some() {
219        let leading_int = cautious_take_while(chars.by_ref(), is_digit).collect::<String>();
220
221        chars.by_ref().next_if_eq(&'.'); // consume optional decimal point
222
223        let fraction_str = cautious_take_while(chars.by_ref(), is_digit).collect::<String>();
224
225        if leading_int.is_empty() && fraction_str.is_empty() {
226            return Err(DurationDeserializationError::InvalidDuration {
227                duration: v.to_string(),
228            });
229        }
230
231        let unit =
232            cautious_take_while(chars.by_ref(), |c| *c != '.' && !is_digit(c)).collect::<String>();
233
234        if unit.is_empty() {
235            return Err(DurationDeserializationError::MissingUnit {
236                duration: v.to_string(),
237            });
238        }
239
240        let multiplier = get_unit_multiplier(unit, v)?;
241
242        let ns = if leading_int.is_empty() {
243            0i64
244        } else {
245            match leading_int.parse() {
246                Ok(v) => v,
247                Err(_) => {
248                    return Err(DurationDeserializationError::InvalidDuration {
249                        duration: v.to_string(),
250                    })
251                }
252            }
253        };
254        let ns = ns.checked_mul(multiplier).ok_or_else(|| {
255            DurationDeserializationError::InvalidDuration {
256                duration: v.to_string(),
257            }
258        })?;
259
260        let fraction = if fraction_str.is_empty() {
261            0i64
262        } else {
263            match fraction_str.parse() {
264                Ok(v) => v,
265                Err(_) => {
266                    return Err(DurationDeserializationError::InvalidDuration {
267                        duration: v.to_string(),
268                    })
269                }
270            }
271        };
272        let ns = if fraction > 0 {
273            // float64 is needed to be nanosecond accurate for fractions of hours.
274            // v >= 0 && (f*unit/scale) <= 3.6e+12 (ns/h, h is the largest unit)
275            let fraction_length = u32::try_from(fraction_str.len()).map_err(|_| {
276                DurationDeserializationError::InvalidDuration {
277                    duration: v.to_string(),
278                }
279            })?;
280
281            let scale = i64::pow(10, fraction_length);
282            let f = fraction as f64 * (multiplier as f64 / scale as f64);
283
284            ns.checked_add(f as i64).ok_or_else(|| {
285                DurationDeserializationError::InvalidDuration {
286                    duration: v.to_string(),
287                }
288            })?
289        } else {
290            ns
291        };
292
293        duration = duration
294            .checked_add(&TimeDelta::nanoseconds(ns))
295            .ok_or_else(|| DurationDeserializationError::InvalidDuration {
296                duration: v.to_string(),
297            })?;
298    }
299
300    Ok(if is_negative { -duration } else { duration })
301}
302
/// Returns `true` if `c` is an ASCII decimal digit (`'0'..='9'`).
///
/// Only ASCII digits count; other Unicode decimal digits return `false`.
fn is_digit(c: &char) -> bool {
    c.is_ascii_digit()
}
306
307fn get_unit_multiplier(unit: String, duration: &str) -> Result<i64, DurationDeserializationError> {
308    let multiplier = match unit.as_str() {
309        "ns" => Ok(TimeDelta::nanoseconds(1)),
310        "us" => Ok(TimeDelta::microseconds(1)),
311        "\u{00B5}s" => Ok(TimeDelta::microseconds(1)), // U+00B5 = micro symbol // "µs" => Ok(TimeDelta::microseconds(1)), // U+00B5 = micro symbol
312        "\u{03BC}s" => Ok(TimeDelta::microseconds(1)), // U+03BC = Greek letter mu // "μs" => Ok(TimeDelta::microseconds(1)), // U+03BC = Greek letter mu
313        "ms" => Ok(TimeDelta::milliseconds(1)),
314        "s" => Ok(TimeDelta::seconds(1)),
315        "m" => Ok(TimeDelta::minutes(1)),
316        "h" => Ok(TimeDelta::hours(1)),
317        _ => Err(DurationDeserializationError::InvalidUnit {
318            duration: duration.to_string(),
319            unit,
320        }),
321    }?;
322    match multiplier.num_nanoseconds() {
323        Some(ns) => Ok(ns),
324        _ => Err(DurationDeserializationError::InvalidDuration {
325            duration: duration.to_string(),
326        }),
327    }
328}
329
#[cfg(test)]
mod tests {
    //! Parsing, serialization and round-trip tests for the duration
    //! (de)serializers, driven through `serde_json`.

    use chrono::TimeDelta;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Deserialize, PartialEq, Eq)]
    struct OptionalDuration {
        #[serde(deserialize_with = "super::deserialize_duration_option", default)]
        duration: Option<chrono::TimeDelta>,
    }

    #[derive(Debug, Deserialize, PartialEq, Eq)]
    struct Duration {
        #[serde(deserialize_with = "super::deserialize_duration")]
        duration: chrono::TimeDelta,
    }

    /// Serialization counterpart of `Duration`.
    #[derive(Debug, Serialize)]
    struct SerializableDuration {
        #[serde(serialize_with = "super::serialize_duration")]
        duration: chrono::TimeDelta,
    }

    /// Serialization counterpart of `OptionalDuration`.
    #[derive(Debug, Serialize)]
    struct SerializableOptionalDuration {
        #[serde(serialize_with = "super::serialize_duration_option")]
        duration: Option<chrono::TimeDelta>,
    }

    #[test]
    pub fn deserialize_duration_parses_u64() {
        let input = r#"{"duration":1234}"#;

        let x1 = serde_json::from_str::<Duration>(input).unwrap();
        let x2 = serde_json::from_str::<OptionalDuration>(input).unwrap();

        assert_eq!(x1.duration.num_nanoseconds(), Some(1234));
        assert_eq!(x2.duration.unwrap().num_nanoseconds(), Some(1234));
    }

    #[test]
    pub fn deserialize_duration_parses_i64() {
        let input = r#"{"duration":-1234}"#;

        let x1 = serde_json::from_str::<Duration>(input).unwrap();
        let x2 = serde_json::from_str::<OptionalDuration>(input).unwrap();

        assert_eq!(x1.duration.num_nanoseconds(), Some(-1234));
        assert_eq!(x2.duration.unwrap().num_nanoseconds(), Some(-1234));
    }

    #[test]
    pub fn deserialize_duration_rejects_u64_beyond_i64() {
        let input = format!(r#"{{"duration":{}}}"#, u64::MAX);

        assert!(serde_json::from_str::<Duration>(&input).is_err());
        assert!(serde_json::from_str::<OptionalDuration>(&input).is_err());
    }

    #[test]
    pub fn deserialize_duration_parses_positive_duration_string() {
        let input = r#"{"duration":"+1h1m1s1ms1us1ns"}"#;
        let expected = 1; // hours
        let expected = expected * 60 + 1; // minutes
        let expected = expected * 60 + 1; // seconds
        let expected = expected * 1000 + 1; // milliseconds
        let expected = expected * 1000 + 1; // microseconds
        let expected = expected * 1000 + 1; // nanoseconds

        let x1 = serde_json::from_str::<Duration>(input).unwrap();
        let x2 = serde_json::from_str::<OptionalDuration>(input).unwrap();

        assert_eq!(x1.duration.num_nanoseconds(), Some(expected));
        assert_eq!(x2.duration.unwrap().num_nanoseconds(), Some(expected));
    }

    #[test]
    pub fn deserialize_duration_parses_negative_duration_string() {
        let input = r#"{"duration":"-1h1m1s1ms1us1ns"}"#;
        let expected = 1; // hours
        let expected = expected * 60 + 1; // minutes
        let expected = expected * 60 + 1; // seconds
        let expected = expected * 1000 + 1; // milliseconds
        let expected = expected * 1000 + 1; // microseconds
        let expected = expected * 1000 + 1; // nanoseconds
        let expected = -expected; // negative

        let x1 = serde_json::from_str::<Duration>(input).unwrap();
        let x2 = serde_json::from_str::<OptionalDuration>(input).unwrap();

        assert_eq!(x1.duration.num_nanoseconds(), Some(expected));
        assert_eq!(x2.duration.unwrap().num_nanoseconds(), Some(expected));
    }

    #[test]
    pub fn deserialize_duration_parses_fractional_duration_string() {
        let input = r#"{"duration":"1.5h15m"}"#;
        let expected = 1; // hours
        let expected = expected * 60 + 45; // minutes
        let expected = expected * 60; // seconds
        let expected = expected * 1000; // milliseconds
        let expected = expected * 1000; // microseconds
        let expected = expected * 1000; // nanoseconds

        let x1 = serde_json::from_str::<Duration>(input).unwrap();
        let x2 = serde_json::from_str::<OptionalDuration>(input).unwrap();

        assert_eq!(x1.duration.num_nanoseconds(), Some(expected));
        assert_eq!(x2.duration.unwrap().num_nanoseconds(), Some(expected));
    }

    #[test]
    pub fn deserialize_duration_parses_duration_strings() {
        // These test cases are taken from the go time.Duration implementation
        // see https://cs.opensource.google/go/go/+/master:src/time/time_test.go;l=613
        let test_pairs = [
            ("0s", TimeDelta::nanoseconds(0)),
            ("1ns", TimeDelta::nanoseconds(1)),
            ("1.1µs", TimeDelta::nanoseconds(1100)), // u00B5 (Micro Sign)
            ("1.1μs", TimeDelta::nanoseconds(1100)), // u03BC (Greek Small Letter Mu)
            ("1.1us", TimeDelta::nanoseconds(1100)),
            ("2.2ms", TimeDelta::microseconds(2200)),
            ("3.3s", TimeDelta::milliseconds(3300)),
            ("4m5s", TimeDelta::minutes(4) + TimeDelta::seconds(5)),
            (
                "4m5.001s",
                TimeDelta::minutes(4) + TimeDelta::milliseconds(5001),
            ),
            (
                "5h6m7.001s",
                TimeDelta::hours(5) + TimeDelta::minutes(6) + TimeDelta::milliseconds(7001),
            ),
            (
                "8m0.000000001s",
                TimeDelta::minutes(8) + TimeDelta::nanoseconds(1),
            ),
            ("2562047h47m16.854775807s", TimeDelta::nanoseconds(i64::MAX)),
            (
                "-2562047h47m16.854775808s",
                TimeDelta::nanoseconds(i64::MIN),
            ),
        ];

        for (input, expected) in test_pairs {
            let json_input = format!(r#"{{"duration":"{}"}}"#, input);
            assert_eq!(
                serde_json::from_str::<Duration>(json_input.as_str())
                    .unwrap()
                    .duration,
                expected
            );
            assert_eq!(
                serde_json::from_str::<OptionalDuration>(json_input.as_str())
                    .unwrap()
                    .duration
                    .unwrap(),
                expected
            );
        }
    }

    #[test]
    pub fn deserialize_duration_parses_signed_zero_and_bare_fractions() {
        let test_pairs = [
            ("0", TimeDelta::zero()),
            ("+0", TimeDelta::zero()),
            ("-0", TimeDelta::zero()),
            (".5s", TimeDelta::milliseconds(500)),
            ("1.s", TimeDelta::seconds(1)),
            ("-1.5m", -TimeDelta::seconds(90)),
        ];

        for (input, expected) in test_pairs {
            let json_input = format!(r#"{{"duration":"{}"}}"#, input);
            assert_eq!(
                serde_json::from_str::<Duration>(&json_input).unwrap().duration,
                expected,
                "input {:?}",
                input
            );
            assert_eq!(
                serde_json::from_str::<OptionalDuration>(&json_input)
                    .unwrap()
                    .duration,
                Some(expected),
                "input {:?}",
                input
            );
        }
    }

    #[test]
    pub fn deserialize_duration_rejects_malformed_strings() {
        for input in ["-", "+", "3", "1.5", ".", "s", "1d", "1h-1m", "1sx"] {
            let json_input = format!(r#"{{"duration":"{}"}}"#, input);
            assert!(
                serde_json::from_str::<Duration>(&json_input).is_err(),
                "accepted {:?}",
                input
            );
            assert!(
                serde_json::from_str::<OptionalDuration>(&json_input).is_err(),
                "accepted {:?}",
                input
            );
        }

        // The empty string only means "no duration" for the optional variant.
        assert!(serde_json::from_str::<Duration>(r#"{"duration":""}"#).is_err());
    }

    #[test]
    pub fn deserialize_duration_option_returns_none_when_null() {
        let input = r#"{"duration":null}"#;

        let x = serde_json::from_str::<OptionalDuration>(input).unwrap();

        assert_eq!(x, OptionalDuration { duration: None })
    }

    #[test]
    pub fn deserialize_duration_option_returns_none_when_missing() {
        let input = r#"{}"#;

        let x = serde_json::from_str::<OptionalDuration>(input).unwrap();

        assert_eq!(x, OptionalDuration { duration: None })
    }

    #[test]
    pub fn deserialize_duration_option_returns_none_for_empty_string() {
        let input = r#"{"duration":""}"#;

        let x = serde_json::from_str::<OptionalDuration>(input).unwrap();

        assert_eq!(x, OptionalDuration { duration: None })
    }

    #[test]
    pub fn serialize_duration_writes_nanosecond_string() {
        let value = SerializableDuration {
            duration: TimeDelta::minutes(10),
        };

        assert_eq!(
            serde_json::to_string(&value).unwrap(),
            r#"{"duration":"600000000000ns"}"#
        );
    }

    #[test]
    pub fn serialize_duration_option_writes_null_or_string() {
        let none = SerializableOptionalDuration { duration: None };
        assert_eq!(serde_json::to_string(&none).unwrap(), r#"{"duration":null}"#);

        let some = SerializableOptionalDuration {
            duration: Some(TimeDelta::nanoseconds(-5)),
        };
        assert_eq!(serde_json::to_string(&some).unwrap(), r#"{"duration":"-5ns"}"#);
    }

    #[test]
    pub fn serialized_duration_round_trips() {
        let original = TimeDelta::hours(1) + TimeDelta::nanoseconds(7);

        let json = serde_json::to_string(&SerializableDuration { duration: original }).unwrap();
        assert_eq!(
            serde_json::from_str::<Duration>(&json).unwrap().duration,
            original
        );

        let json = serde_json::to_string(&SerializableOptionalDuration {
            duration: Some(original),
        })
        .unwrap();
        assert_eq!(
            serde_json::from_str::<OptionalDuration>(&json).unwrap().duration,
            Some(original)
        );
    }
}