// pg_interval_sql_json_binding/lib.rs

1extern crate pg_interval;
2extern crate postgres_types;
3extern crate serde;
4
5use std::error::Error;
6use std::str::FromStr;
7use postgres_types::{accepts, FromSql, IsNull, to_sql_checked, ToSql, Type};
8use postgres_types::private::BytesMut;
9use serde::{Deserialize, Serialize};
10use serde::ser::{SerializeStruct};
11use std::convert::TryInto;
12use std::fmt::{Display, Formatter, Write};
13
14#[derive(Debug, PartialEq, Eq)]
15pub struct ParseError {
16    pg: pg_interval::ParseError,
17}
18
19impl From<pg_interval::ParseError> for ParseError {
20    fn from(pg: pg_interval::ParseError) -> ParseError {
21        ParseError {
22            pg
23        }
24    }
25}
26
27impl Into<pg_interval::ParseError> for ParseError {
28    fn into(self) -> pg_interval::ParseError {
29        self.pg
30    }
31}
32
33impl Display for ParseError {
34    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
35        match &self.pg {
36            pg_interval::ParseError::InvalidInterval(s) => write!(f, "{}", s),
37            pg_interval::ParseError::InvalidTime(s) => write!(f, "{}", s),
38            pg_interval::ParseError::InvalidYearMonth(s) => write!(f, "{}", s),
39            pg_interval::ParseError::ParseIntErr(s) => write!(f, "{}", s),
40            pg_interval::ParseError::ParseFloatErr(s) => write!(f, "{}", s),
41        }
42    }
43}
44
45impl Error for ParseError {
46}
47
48#[derive(Debug)]
49pub struct Interval {
50    pg: pg_interval::Interval,
51}
52
53impl Interval {
54    pub fn new(interval: &str) -> Result<Interval, ParseError> {
55        Ok(Interval {
56            pg: pg_interval::Interval::from_postgres(&interval)?,
57        })
58    }
59
60    pub fn inner(&self) -> &pg_interval::Interval {
61        &self.pg
62    }
63
64    pub fn bytes(&self) -> Vec<u8> {
65        let mut buf = vec![0u8, 16];
66        buf[0..8].copy_from_slice(&self.pg.microseconds.to_be_bytes());
67        buf[8..12].copy_from_slice(&self.pg.days.to_be_bytes());
68        buf[12..16].copy_from_slice(&self.pg.months.to_be_bytes());
69        buf
70    }
71}
72
73impl FromStr for Interval {
74    type Err = ParseError;
75
76    fn from_str(s: &str) -> Result<Self, Self::Err> {
77        Interval::new(s)
78    }
79}
80
81impl Display for Interval {
82    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
83        let years = self.pg.months / 12;
84        let months = self.pg.months % 12;
85        let days = self.pg.days;
86        let hours = self.pg.microseconds / 3_600_000_000;
87        let minutes = (self.pg.microseconds % 3_600_000_000) / 60_000_000;
88        let seconds = (self.pg.microseconds % 60_000_000) / 1_000_000;
89        let milliseconds = self.pg.microseconds % 1_000_000 / 1_000;
90        let microseconds = self.pg.microseconds % 1_000;
91        let mut buf = String::new();
92        if years > 0 {
93            write!(buf, "{} years ", years)?;
94        }
95        if months > 0 {
96            write!(buf, "{} mons ", months)?;
97        }
98        if days > 0 {
99            write!(buf, "{} days ", days)?;
100        }
101        if hours > 0 {
102            write!(buf, "{} hours ", hours)?;
103        }
104        if minutes > 0 {
105            write!(buf, "{} minutes ", minutes)?;
106        }
107        if seconds > 0 {
108            write!(buf, "{} seconds ", seconds)?;
109        }
110        if milliseconds > 0 {
111            write!(buf, "{} milliseconds ", milliseconds)?;
112        }
113        if microseconds > 0 {
114            write!(buf, "{} microseconds ", microseconds)?;
115        }
116        if buf.is_empty() {
117            write!(buf, "0 seconds")
118        } else {
119            write!(f, "{}", &buf.as_str()[..buf.len() - 1])
120        }
121    }
122}
123
124impl<'a> FromSql<'a> for Interval {
125    fn from_sql(_: &Type, raw: &[u8]) -> Result<Interval, Box<dyn Error + Sync + Send>> {
126        Ok(Interval {
127            pg: pg_interval::Interval {
128                months: i32::from_be_bytes(raw[12..16].try_into().unwrap()),
129                days: i32::from_be_bytes(raw[8..12].try_into().unwrap()),
130                microseconds: i64::from_be_bytes(raw[0..8].try_into().unwrap()),
131            }
132        })
133    }
134
135    accepts!(INTERVAL);
136}
137
138impl ToSql for Interval {
139    fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
140        w.extend_from_slice(self.pg.microseconds.to_be_bytes().as_slice());
141        w.extend_from_slice(self.pg.days.to_be_bytes().as_slice());
142        w.extend_from_slice(self.pg.months.to_be_bytes().as_slice());
143        Ok(IsNull::No)
144    }
145
146    accepts!(INTERVAL);
147    to_sql_checked!();
148}
149
150impl Serialize for Interval {
151    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
152        where
153            S: serde::Serializer,
154    {
155        let mut state = serializer.serialize_struct("Interval", 3)?;
156        state.serialize_field("m", &self.pg.months)?;
157        state.serialize_field("d", &self.pg.days)?;
158        state.serialize_field("us", &self.pg.microseconds)?;
159        state.end()
160    }
161}
162
163impl<'de> Deserialize<'de> for Interval {
164    fn deserialize<D>(deserializer: D) -> Result<Interval, D::Error>
165        where
166            D: serde::Deserializer<'de>,
167    {
168        struct IntervalVisitor;
169
170        impl<'de> serde::de::Visitor<'de> for IntervalVisitor {
171            type Value = Interval;
172
173            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
174                formatter.write_str("a string representing an interval")
175            }
176
177            fn visit_map<V>(self, mut visitor: V) -> Result<Interval, V::Error>
178                where
179                    V: serde::de::MapAccess<'de>,
180            {
181                let mut months = None;
182                let mut days = None;
183                let mut microseconds = None;
184
185                while let Some(key) = visitor.next_key()? {
186                    match key {
187                        "m" => {
188                            if months.is_some() {
189                                return Err(serde::de::Error::duplicate_field("m"));
190                            }
191                            months = Some(visitor.next_value()?);
192                        }
193                        "d" => {
194                            if days.is_some() {
195                                return Err(serde::de::Error::duplicate_field("d"));
196                            }
197                            days = Some(visitor.next_value()?);
198                        }
199                        "us" => {
200                            if microseconds.is_some() {
201                                return Err(serde::de::Error::duplicate_field("us"));
202                            }
203                            microseconds = Some(visitor.next_value()?);
204                        }
205                        _ => {
206                            return Err(serde::de::Error::unknown_field(key, &["m", "d", "us"]));
207                        }
208                    }
209                }
210
211                let months = months.ok_or_else(|| serde::de::Error::missing_field("m"))?;
212                let days = days.ok_or_else(|| serde::de::Error::missing_field("d"))?;
213                let microseconds = microseconds.ok_or_else(|| serde::de::Error::missing_field("us"))?;
214
215                Ok(Interval {
216                    pg: pg_interval::Interval {
217                        months,
218                        days,
219                        microseconds,
220                    }
221                })
222            }
223        }
224
225        deserializer.deserialize_struct("Interval", &["m", "d", "us"], IntervalVisitor)
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an [`Interval`] directly from its raw months / days / microseconds.
    fn interval(months: i32, days: i32, microseconds: i64) -> Interval {
        Interval {
            pg: pg_interval::Interval {
                months,
                days,
                microseconds,
            },
        }
    }

    /// Binary form of `interval(1, 2, 3_000_000)`: 8-byte microseconds, 4-byte days,
    /// 4-byte months, all big-endian.
    const ONE_MON_TWO_DAYS_THREE_SECS: [u8; 16] =
        [0, 0, 0, 0, 0, 45, 198, 192, 0, 0, 0, 2, 0, 0, 0, 1];

    #[test]
    fn test_interval_from_str() {
        let parsed = Interval::new("1 mons 2 days 3 seconds").unwrap();
        let pg = parsed.inner();
        assert_eq!((pg.months, pg.days, pg.microseconds), (1, 2, 3_000_000));
    }

    #[test]
    fn test_interval_from_sql() {
        let decoded = Interval::from_sql(&Type::INTERVAL, &ONE_MON_TWO_DAYS_THREE_SECS).unwrap();
        let pg = decoded.inner();
        assert_eq!((pg.months, pg.days, pg.microseconds), (1, 2, 3_000_000));
    }

    #[test]
    fn test_interval_to_sql() {
        let mut encoded = BytesMut::new();
        interval(1, 2, 3_000_000)
            .to_sql(&Type::INTERVAL, &mut encoded)
            .unwrap();
        assert_eq!(encoded.as_ref(), &ONE_MON_TWO_DAYS_THREE_SECS);
    }

    #[test]
    fn test_interval_display() {
        let everything = interval(
            14,
            3,
            4 * 3_600_000_000 + 5 * 60_000_000 + 6 * 1_000_000 + 7 * 1_000 + 8,
        );
        assert_eq!(
            format!("{}", everything),
            "1 years 2 mons 3 days 4 hours 5 minutes 6 seconds 7 milliseconds 8 microseconds"
        );

        assert_eq!(interval(1, 2, 3_000_000).to_string(), "1 mons 2 days 3 seconds");
    }

    #[test]
    fn test_interval_serialize() {
        let json = serde_json::to_string(&interval(1, 2, 3)).unwrap();
        assert_eq!(json, r#"{"m":1,"d":2,"us":3}"#);
    }

    #[test]
    fn test_interval_deserialize() {
        let parsed: Interval = serde_json::from_str(r#"{"m":1,"d":2,"us":3}"#).unwrap();
        let pg = parsed.inner();
        assert_eq!((pg.months, pg.days, pg.microseconds), (1, 2, 3));
    }

    #[test]
    fn test_anyhow_error_propagation() {
        fn parse(text: &str) -> anyhow::Result<Interval> {
            Ok(Interval::new(text)?)
        }

        let err = parse("1 monthss").unwrap_err();
        assert_eq!(err.to_string(), "Unknown or duplicate deliminator \"monthss\"");
    }
}