gnostr_relay/
duration.rs

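//! Flexible `Duration` (de)serialization helpers.
//!
//! Deserialization accepts any of these forms:
//! - map: `{"secs": 1, "nanos": 1}`
//! - sequence: `[1, 1]` (seconds, then nanoseconds)
//! - unsigned integer, as whole seconds: `1`
//! - string parsed by [`duration_str`]: `"3m+1s"`
//!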
8use duration_str::parse;
9use serde::{
10    de::{Error, MapAccess, SeqAccess, Visitor},
11    Deserialize, Deserializer, Serialize, Serializer,
12};
13use std::{fmt, ops::Deref, time::Duration};
14
15/// Deserialize a `Duration`
16pub fn deserialize<'a, D>(d: D) -> Result<Duration, D::Error>
17where
18    D: Deserializer<'a>,
19{
20    d.deserialize_any(DurationVisitor)
21}
22
23/// Serialize a `Duration`
24pub fn serialize<S>(d: &Duration, s: S) -> Result<S::Ok, S::Error>
25where
26    S: Serializer,
27{
28    d.serialize(s)
29}
30
31#[derive(Serialize, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
32#[serde(into = "Duration")]
33pub struct NonZeroDuration(Duration);
34
35impl NonZeroDuration {
36    pub fn new(value: Duration) -> Option<Self> {
37        if value.is_zero() {
38            None
39        } else {
40            Some(Self(value))
41        }
42    }
43}
44
45// impl Into<Duration> for NonZeroDuration {
46//     fn into(self) -> Duration {
47//         self.0
48//     }
49// }
50
51impl From<NonZeroDuration> for Duration {
52    fn from(val: NonZeroDuration) -> Self {
53        val.0
54    }
55}
56
57impl TryFrom<Duration> for NonZeroDuration {
58    type Error = &'static str;
59    fn try_from(value: Duration) -> Result<Self, Self::Error> {
60        if value.is_zero() {
61            Err("duration can't be zero")
62        } else {
63            Ok(Self(value))
64        }
65    }
66}
67
68impl Deref for NonZeroDuration {
69    type Target = Duration;
70    fn deref(&self) -> &Self::Target {
71        &self.0
72    }
73}
74
75impl<'de> Deserialize<'de> for NonZeroDuration {
76    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
77    where
78        D: Deserializer<'de>,
79    {
80        deserializer
81            .deserialize_any(DurationVisitor)?
82            .try_into()
83            .map_err(D::Error::custom)
84    }
85}
86
87#[derive(Deserialize)]
88#[serde(field_identifier, rename_all = "lowercase")]
89enum Field {
90    Secs,
91    Nanos,
92}
93
94fn check_overflow<E>(secs: u64, nanos: u32) -> Result<(), E>
95where
96    E: Error,
97{
98    static NANOS_PER_SEC: u32 = 1_000_000_000;
99    match secs.checked_add((nanos / NANOS_PER_SEC) as u64) {
100        Some(_) => Ok(()),
101        None => Err(E::custom("overflow deserializing SystemTime epoch offset")),
102    }
103}
104
105// source: https://github.com/serde-rs/serde/blob/20a48c9580445b82e570c237159e4bce8b95831b/serde/src/de/impls.rs#L2037
106struct DurationVisitor;
107
108impl<'de> Visitor<'de> for DurationVisitor {
109    type Value = Duration;
110
111    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
112        formatter.write_str("struct Duration")
113    }
114
115    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
116    where
117        A: SeqAccess<'de>,
118    {
119        let secs: u64 = match seq.next_element()? {
120            Some(value) => value,
121            None => {
122                return Err(Error::invalid_length(0, &self));
123            }
124        };
125        let nanos: u32 = match seq.next_element()? {
126            Some(value) => value,
127            None => {
128                return Err(Error::invalid_length(1, &self));
129            }
130        };
131        check_overflow(secs, nanos)?;
132        Ok(Duration::new(secs, nanos))
133    }
134
135    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
136    where
137        A: MapAccess<'de>,
138    {
139        let mut secs: Option<u64> = None;
140        let mut nanos: Option<u32> = None;
141        while let Some(key) = map.next_key()? {
142            match key {
143                Field::Secs => {
144                    if secs.is_some() {
145                        return Err(<A::Error as Error>::duplicate_field("secs"));
146                    }
147                    secs = Some(map.next_value()?);
148                }
149                Field::Nanos => {
150                    if nanos.is_some() {
151                        return Err(<A::Error as Error>::duplicate_field("nanos"));
152                    }
153                    nanos = Some(map.next_value()?);
154                }
155            }
156        }
157        let secs = match secs {
158            Some(secs) => secs,
159            None => return Err(<A::Error as Error>::missing_field("secs")),
160        };
161        let nanos = match nanos {
162            Some(nanos) => nanos,
163            None => return Err(<A::Error as Error>::missing_field("nanos")),
164        };
165        check_overflow(secs, nanos)?;
166        Ok(Duration::new(secs, nanos))
167    }
168
169    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
170    where
171        E: Error,
172    {
173        check_overflow(v, 0)?;
174        Ok(Duration::from_secs(v))
175    }
176
177    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
178    where
179        E: Error,
180    {
181        parse(v).map_err(Error::custom)
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    /// Exercises the free `serialize`/`deserialize` functions via `serde(with)`.
    #[derive(Deserialize, Serialize)]
    struct Test {
        #[serde(with = "super")]
        time: Duration,
    }

    #[test]
    fn der() -> Result<()> {
        let t = serde_json::from_str::<Test>(r#"{"time": 1}"#)?;
        assert_eq!(t.time, Duration::from_secs(1));

        let t = serde_json::from_str::<Test>(r#"{"time": "1m"}"#)?;
        assert_eq!(t.time, Duration::from_secs(60));

        let t = serde_json::from_str::<Test>(r#"{"time": [1, 1]}"#)?;
        assert_eq!(t.time, Duration::new(1, 1));

        let t = serde_json::from_str::<Test>(r#"{"time": {"secs": 1, "nanos": 1}}"#)?;
        assert_eq!(t.time, Duration::new(1, 1));

        // Round trip: serialize, then parse the emitted form back.
        let t = serde_json::from_str::<Test>(r#"{"time": "1m"}"#)?;
        let json = serde_json::to_string(&t)?;
        let t = serde_json::from_str::<Test>(&json)?;
        assert_eq!(t.time, Duration::from_secs(60));

        let t = serde_json::from_str::<Test>(r#"{"time": 0}"#)?;
        assert_eq!(t.time, Duration::from_secs(0));
        Ok(())
    }

    #[test]
    fn der_rejects_invalid() {
        // Map form: missing and duplicated keys.
        assert!(serde_json::from_str::<Test>(r#"{"time": {"nanos": 1}}"#).is_err());
        assert!(
            serde_json::from_str::<Test>(r#"{"time": {"secs": 1, "secs": 2, "nanos": 0}}"#)
                .is_err()
        );
        // Sequence form: too few elements.
        assert!(serde_json::from_str::<Test>(r#"{"time": [1]}"#).is_err());
        // Carrying one second out of `nanos` overflows `secs`: must error, not panic.
        assert!(
            serde_json::from_str::<Test>(r#"{"time": [18446744073709551615, 1000000000]}"#)
                .is_err()
        );
        // Negative seconds are not a valid duration.
        assert!(serde_json::from_str::<Test>(r#"{"time": -1}"#).is_err());
    }

    #[derive(Deserialize, Serialize)]
    struct TestNonZero {
        time: NonZeroDuration,
    }

    #[test]
    fn non_zero() -> Result<()> {
        let t = serde_json::from_str::<TestNonZero>(r#"{"time": 1}"#)?;
        assert_eq!(t.time, Duration::from_secs(1).try_into().unwrap());

        let t = serde_json::from_str::<TestNonZero>(r#"{"time": "1m"}"#)?;
        assert_eq!(t.time, Duration::from_secs(60).try_into().unwrap());

        let t = serde_json::from_str::<TestNonZero>(r#"{"time": [1, 1]}"#)?;
        assert_eq!(t.time, Duration::new(1, 1).try_into().unwrap());

        let t = serde_json::from_str::<TestNonZero>(r#"{"time": {"secs": 1, "nanos": 1}}"#)?;
        assert_eq!(t.time, Duration::new(1, 1).try_into().unwrap());

        let t = serde_json::from_str::<TestNonZero>(r#"{"time": "1m"}"#)?;
        let json = serde_json::to_string(&t)?;
        let t = serde_json::from_str::<TestNonZero>(&json)?;
        assert_eq!(t.time, Duration::from_secs(60).try_into().unwrap());

        // Zero is rejected whichever input form produced it.
        assert!(serde_json::from_str::<TestNonZero>(r#"{"time": 0}"#).is_err());
        assert!(serde_json::from_str::<TestNonZero>(r#"{"time": "0s"}"#).is_err());
        assert!(serde_json::from_str::<TestNonZero>(r#"{"time": [0, 0]}"#).is_err());
        Ok(())
    }
}