// postgres_range/impls.rs

1use std::error::Error;
2use postgres_types::{FromSql, IsNull, Kind, ToSql, Type};
3use postgres_types::private::BytesMut;
4use postgres_protocol::{self as protocol, types};
5
6use crate::{BoundSided, BoundType, Normalizable, Range, RangeBound};
7
8impl<'a, T> FromSql<'a> for Range<T>
9where
10    T: PartialOrd + Normalizable + FromSql<'a>,
11{
12    fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Range<T>, Box<dyn Error + Sync + Send>> {
13        let element_type = match *ty.kind() {
14            Kind::Range(ref ty) => ty,
15            _ => panic!("unexpected type {:?}", ty),
16        };
17
18        match types::range_from_sql(raw)? {
19            types::Range::Empty => Ok(Range::empty()),
20            types::Range::Nonempty(lower, upper) => {
21                let lower = bound_from_sql(lower, element_type)?;
22                let upper = bound_from_sql(upper, element_type)?;
23                Ok(Range::new(lower, upper))
24            }
25        }
26    }
27
28    fn accepts(ty: &Type) -> bool {
29        match *ty.kind() {
30            Kind::Range(ref inner) => <T as FromSql>::accepts(inner),
31            _ => false,
32        }
33    }
34}
35
36fn bound_from_sql<'a, T, S>(bound: types::RangeBound<Option<&'a [u8]>>, ty: &Type) -> Result<Option<RangeBound<S, T>>, Box<dyn Error + Sync + Send>>
37where
38    T: PartialOrd + Normalizable + FromSql<'a>,
39    S: BoundSided,
40{
41    match bound {
42        types::RangeBound::Exclusive(value) => {
43            let value = match value {
44                Some(value) => T::from_sql(ty, value)?,
45                None => T::from_sql_null(ty)?,
46            };
47            Ok(Some(RangeBound::new(value, BoundType::Exclusive)))
48        }
49        types::RangeBound::Inclusive(value) => {
50            let value = match value {
51                Some(value) => T::from_sql(ty, value)?,
52                None => T::from_sql_null(ty)?,
53            };
54            Ok(Some(RangeBound::new(value, BoundType::Inclusive)))
55        }
56        types::RangeBound::Unbounded => Ok(None),
57    }
58}
59
60impl<T> ToSql for Range<T>
61where
62    T: PartialOrd + Normalizable + ToSql,
63{
64    fn to_sql(&self, ty: &Type, buf: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
65        let element_type = match *ty.kind() {
66            Kind::Range(ref ty) => ty,
67            _ => panic!("unexpected type {:?}", ty),
68        };
69
70        if self.is_empty() {
71            types::empty_range_to_sql(buf);
72        } else {
73            types::range_to_sql(
74                |buf| bound_to_sql(self.lower(), element_type, buf),
75                |buf| bound_to_sql(self.upper(), element_type, buf),
76                buf,
77            )?;
78        }
79
80        Ok(IsNull::No)
81    }
82
83    fn accepts(ty: &Type) -> bool {
84        match *ty.kind() {
85            Kind::Range(ref inner) => <T as ToSql>::accepts(inner),
86            _ => false,
87        }
88    }
89
90    to_sql_checked!();
91}
92
93fn bound_to_sql<S, T>(bound: Option<&RangeBound<S, T>>, ty: &Type, buf: &mut BytesMut) -> Result<types::RangeBound<protocol::IsNull>, Box<dyn Error + Sync + Send>>
94where
95    S: BoundSided,
96    T: ToSql,
97{
98    match bound {
99        Some(bound) => {
100            let null = match bound.value.to_sql(ty, buf)? {
101                IsNull::Yes => protocol::IsNull::Yes,
102                IsNull::No => protocol::IsNull::No,
103            };
104
105            match bound.type_ {
106                BoundType::Exclusive => Ok(types::RangeBound::Exclusive(null)),
107                BoundType::Inclusive => Ok(types::RangeBound::Inclusive(null)),
108            }
109        }
110        None => Ok(types::RangeBound::Unbounded),
111    }
112}
113
#[cfg(test)]
mod test {
    use std::fmt;

    use postgres::{Client, NoTls};
    use postgres::types::{FromSql, ToSql};
    #[cfg(feature = "with-chrono-0_4")]
    use chrono_04::{TimeZone, Utc, Duration};

    // Round-trips a fixed matrix of range shapes for one range type.
    //
    // Arguments: SQL type name, element Rust type (not referenced by the
    // expansion), a lower value and its SQL literal text, and an upper value
    // and its SQL literal text. The matrix covers fully unbounded,
    // half-bounded, fully bounded with every inclusive/exclusive combination,
    // `empty`, and SQL `NULL`.
    macro_rules! test_range {
        ($name:expr, $t:ty, $low:expr, $low_str:expr, $high:expr, $high_str:expr) => ({
            let tests = &[(Some(range!('(',; ')')), "'(,)'".to_string()),
                         (Some(range!('[' $low,; ')')), format!("'[{},)'", $low_str)),
                         (Some(range!('(' $low,; ')')), format!("'({},)'", $low_str)),
                         (Some(range!('(', $high; ']')), format!("'(,{}]'", $high_str)),
                         (Some(range!('(', $high; ')')), format!("'(,{})'", $high_str)),
                         (Some(range!('[' $low, $high; ']')),
                          format!("'[{},{}]'", $low_str, $high_str)),
                         (Some(range!('[' $low, $high; ')')),
                          format!("'[{},{})'", $low_str, $high_str)),
                         (Some(range!('(' $low, $high; ']')),
                          format!("'({},{}]'", $low_str, $high_str)),
                         (Some(range!('(' $low, $high; ')')),
                          format!("'({},{})'", $low_str, $high_str)),
                         (Some(range!(empty)), "'empty'".to_string()),
                         (None, "NULL".to_string())];
            test_type($name, tests);
        })
    }

    /// Checks both directions of the codec for every `(value, literal)` pair:
    /// the server's parse of `literal` must decode to `value`, and `value`
    /// bound as a parameter must come back unchanged.
    ///
    /// Requires a PostgreSQL server reachable at `postgres://postgres@localhost`.
    fn test_type<T, S>(sql_type: &str, checks: &[(T, S)])
    where
        for<'a> T: Sync + PartialEq + FromSql<'a> + ToSql,
        S: fmt::Display,
    {
        let mut conn = Client::connect("postgres://postgres@localhost", NoTls).unwrap();
        for (val, repr) in checks {
            // Decode path: let the server parse the SQL literal.
            let stmt = conn.prepare(&format!("SELECT {}::{}", repr, sql_type)).unwrap();
            let result: T = conn.query_one(&stmt, &[]).unwrap().get(0);
            assert!(val == &result);

            // Encode path: bind `val` as a parameter and read the echo back.
            let stmt = conn.prepare(&format!("SELECT $1::{}", sql_type)).unwrap();
            let result: T = conn.query_one(&stmt, &[val]).unwrap().get(0);
            assert!(val == &result);
        }
    }

    #[test]
    fn test_int4range_params() {
        test_range!("INT4RANGE", i32, 100i32, "100", 200i32, "200")
    }

    #[test]
    fn test_int8range_params() {
        test_range!("INT8RANGE", i64, 100i64, "100", 200i64, "200")
    }

    #[test]
    #[cfg(feature = "with-chrono-0_4")]
    fn test_tsrange_params() {
        // `TimeZone::timestamp` is deprecated in chrono 0.4; the epoch is
        // always representable, so unwrapping `timestamp_opt` cannot fail.
        let low = Utc.timestamp_opt(0, 0).unwrap();
        let high = low + Duration::days(10);
        test_range!("TSRANGE", NaiveDateTime, low.naive_utc(), "1970-01-01", high.naive_utc(), "1970-01-11");
    }

    #[test]
    #[cfg(feature = "with-chrono-0_4")]
    fn test_tstzrange_params() {
        let low = Utc.timestamp_opt(0, 0).unwrap();
        let high = low + Duration::days(10);
        test_range!("TSTZRANGE", DateTime<Utc>, low, "1970-01-01", high, "1970-01-11");
    }
}