1use duration_str::parse;
9use serde::{
10 de::{Error, MapAccess, SeqAccess, Visitor},
11 Deserialize, Deserializer, Serialize, Serializer,
12};
13use std::{fmt, ops::Deref, time::Duration};
14
15pub fn deserialize<'a, D>(d: D) -> Result<Duration, D::Error>
17where
18 D: Deserializer<'a>,
19{
20 d.deserialize_any(DurationVisitor)
21}
22
23pub fn serialize<S>(d: &Duration, s: S) -> Result<S::Ok, S::Error>
25where
26 S: Serializer,
27{
28 d.serialize(s)
29}
30
/// A [`Duration`] wrapper for values that must not be zero.
///
/// The inner field is private; the constructors visible in this file
/// (`NonZeroDuration::new`, `TryFrom<Duration>` and the `Deserialize` impl)
/// all reject a zero duration.
///
/// Serialization converts the value into the wrapped `Duration` first
/// (`#[serde(into = "Duration")]`), so it is written exactly like a plain
/// `Duration`.
#[derive(Serialize, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
#[serde(into = "Duration")]
pub struct NonZeroDuration(Duration);
34
35impl NonZeroDuration {
36 pub fn new(value: Duration) -> Option<Self> {
37 if value.is_zero() {
38 None
39 } else {
40 Some(Self(value))
41 }
42 }
43}
44
45impl From<NonZeroDuration> for Duration {
52 fn from(val: NonZeroDuration) -> Self {
53 val.0
54 }
55}
56
57impl TryFrom<Duration> for NonZeroDuration {
58 type Error = &'static str;
59 fn try_from(value: Duration) -> Result<Self, Self::Error> {
60 if value.is_zero() {
61 Err("duration can't be zero")
62 } else {
63 Ok(Self(value))
64 }
65 }
66}
67
68impl Deref for NonZeroDuration {
69 type Target = Duration;
70 fn deref(&self) -> &Self::Target {
71 &self.0
72 }
73}
74
75impl<'de> Deserialize<'de> for NonZeroDuration {
76 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
77 where
78 D: Deserializer<'de>,
79 {
80 deserializer
81 .deserialize_any(DurationVisitor)?
82 .try_into()
83 .map_err(D::Error::custom)
84 }
85}
86
/// Keys accepted in the map representation of a duration.
///
/// With `field_identifier` and `rename_all = "lowercase"` the variants are
/// matched against the key names `"secs"` and `"nanos"`.
#[derive(Deserialize)]
#[serde(field_identifier, rename_all = "lowercase")]
enum Field {
    /// Whole seconds (key `"secs"`).
    Secs,
    /// Nanoseconds component (key `"nanos"`).
    Nanos,
}
93
94fn check_overflow<E>(secs: u64, nanos: u32) -> Result<(), E>
95where
96 E: Error,
97{
98 static NANOS_PER_SEC: u32 = 1_000_000_000;
99 match secs.checked_add((nanos / NANOS_PER_SEC) as u64) {
100 Some(_) => Ok(()),
101 None => Err(E::custom("overflow deserializing SystemTime epoch offset")),
102 }
103}
104
/// Stateless serde visitor that produces a `Duration`; its `Visitor` impl
/// defines which input shapes are accepted.
struct DurationVisitor;
107
108impl<'de> Visitor<'de> for DurationVisitor {
109 type Value = Duration;
110
111 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
112 formatter.write_str("struct Duration")
113 }
114
115 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
116 where
117 A: SeqAccess<'de>,
118 {
119 let secs: u64 = match seq.next_element()? {
120 Some(value) => value,
121 None => {
122 return Err(Error::invalid_length(0, &self));
123 }
124 };
125 let nanos: u32 = match seq.next_element()? {
126 Some(value) => value,
127 None => {
128 return Err(Error::invalid_length(1, &self));
129 }
130 };
131 check_overflow(secs, nanos)?;
132 Ok(Duration::new(secs, nanos))
133 }
134
135 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
136 where
137 A: MapAccess<'de>,
138 {
139 let mut secs: Option<u64> = None;
140 let mut nanos: Option<u32> = None;
141 while let Some(key) = map.next_key()? {
142 match key {
143 Field::Secs => {
144 if secs.is_some() {
145 return Err(<A::Error as Error>::duplicate_field("secs"));
146 }
147 secs = Some(map.next_value()?);
148 }
149 Field::Nanos => {
150 if nanos.is_some() {
151 return Err(<A::Error as Error>::duplicate_field("nanos"));
152 }
153 nanos = Some(map.next_value()?);
154 }
155 }
156 }
157 let secs = match secs {
158 Some(secs) => secs,
159 None => return Err(<A::Error as Error>::missing_field("secs")),
160 };
161 let nanos = match nanos {
162 Some(nanos) => nanos,
163 None => return Err(<A::Error as Error>::missing_field("nanos")),
164 };
165 check_overflow(secs, nanos)?;
166 Ok(Duration::new(secs, nanos))
167 }
168
169 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
170 where
171 E: Error,
172 {
173 check_overflow(v, 0)?;
174 Ok(Duration::from_secs(v))
175 }
176
177 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
178 where
179 E: Error,
180 {
181 parse(v).map_err(Error::custom)
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    /// Routes `time` through this module's `serialize`/`deserialize` functions.
    #[derive(Deserialize, Serialize)]
    struct Test {
        #[serde(with = "super")]
        time: Duration,
    }

    #[test]
    fn der() -> Result<()> {
        // Every accepted input shape: integer, string, sequence, map.
        let accepted = [
            (r#"{"time": 1}"#, Duration::from_secs(1)),
            (r#"{"time": "1m"}"#, Duration::from_secs(60)),
            (r#"{"time": [1, 1]}"#, Duration::new(1, 1)),
            (r#"{"time": {"secs": 1, "nanos": 1}}"#, Duration::new(1, 1)),
        ];
        for (json, expected) in accepted {
            let parsed: Test = serde_json::from_str(json)?;
            assert_eq!(parsed.time, expected);
        }

        // Serialized output must be readable again.
        let original: Test = serde_json::from_str(r#"{"time": "1m"}"#)?;
        let encoded = serde_json::to_string(&original)?;
        let decoded: Test = serde_json::from_str(&encoded)?;
        assert_eq!(decoded.time, Duration::from_secs(60));

        // Zero is valid for a plain `Duration`.
        let zero: Test = serde_json::from_str(r#"{"time": 0}"#)?;
        assert_eq!(zero.time, Duration::from_secs(0));
        Ok(())
    }

    /// Uses `NonZeroDuration`'s own `Serialize`/`Deserialize` impls.
    #[derive(Deserialize, Serialize)]
    struct TestNonZero {
        time: NonZeroDuration,
    }

    #[test]
    fn non_zero() -> Result<()> {
        // Same input shapes as for a plain `Duration`.
        let accepted = [
            (r#"{"time": 1}"#, Duration::from_secs(1)),
            (r#"{"time": "1m"}"#, Duration::from_secs(60)),
            (r#"{"time": [1, 1]}"#, Duration::new(1, 1)),
            (r#"{"time": {"secs": 1, "nanos": 1}}"#, Duration::new(1, 1)),
        ];
        for (json, expected) in accepted {
            let parsed: TestNonZero = serde_json::from_str(json)?;
            assert_eq!(parsed.time, NonZeroDuration::try_from(expected).unwrap());
        }

        // Serialized output must be readable again.
        let original: TestNonZero = serde_json::from_str(r#"{"time": "1m"}"#)?;
        let encoded = serde_json::to_string(&original)?;
        let decoded: TestNonZero = serde_json::from_str(&encoded)?;
        assert_eq!(
            decoded.time,
            NonZeroDuration::try_from(Duration::from_secs(60)).unwrap()
        );

        // Zero must be rejected.
        assert!(serde_json::from_str::<TestNonZero>(r#"{"time": 0}"#).is_err());
        Ok(())
    }
}