1use std::error::Error;
2use postgres_types::{FromSql, IsNull, Kind, ToSql, Type};
3use postgres_types::private::BytesMut;
4use postgres_protocol::{self as protocol, types};
5
6use crate::{BoundSided, BoundType, Normalizable, Range, RangeBound};
7
8impl<'a, T> FromSql<'a> for Range<T>
9where
10 T: PartialOrd + Normalizable + FromSql<'a>,
11{
12 fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Range<T>, Box<dyn Error + Sync + Send>> {
13 let element_type = match *ty.kind() {
14 Kind::Range(ref ty) => ty,
15 _ => panic!("unexpected type {:?}", ty),
16 };
17
18 match types::range_from_sql(raw)? {
19 types::Range::Empty => Ok(Range::empty()),
20 types::Range::Nonempty(lower, upper) => {
21 let lower = bound_from_sql(lower, element_type)?;
22 let upper = bound_from_sql(upper, element_type)?;
23 Ok(Range::new(lower, upper))
24 }
25 }
26 }
27
28 fn accepts(ty: &Type) -> bool {
29 match *ty.kind() {
30 Kind::Range(ref inner) => <T as FromSql>::accepts(inner),
31 _ => false,
32 }
33 }
34}
35
36fn bound_from_sql<'a, T, S>(bound: types::RangeBound<Option<&'a [u8]>>, ty: &Type) -> Result<Option<RangeBound<S, T>>, Box<dyn Error + Sync + Send>>
37where
38 T: PartialOrd + Normalizable + FromSql<'a>,
39 S: BoundSided,
40{
41 match bound {
42 types::RangeBound::Exclusive(value) => {
43 let value = match value {
44 Some(value) => T::from_sql(ty, value)?,
45 None => T::from_sql_null(ty)?,
46 };
47 Ok(Some(RangeBound::new(value, BoundType::Exclusive)))
48 }
49 types::RangeBound::Inclusive(value) => {
50 let value = match value {
51 Some(value) => T::from_sql(ty, value)?,
52 None => T::from_sql_null(ty)?,
53 };
54 Ok(Some(RangeBound::new(value, BoundType::Inclusive)))
55 }
56 types::RangeBound::Unbounded => Ok(None),
57 }
58}
59
60impl<T> ToSql for Range<T>
61where
62 T: PartialOrd + Normalizable + ToSql,
63{
64 fn to_sql(&self, ty: &Type, buf: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
65 let element_type = match *ty.kind() {
66 Kind::Range(ref ty) => ty,
67 _ => panic!("unexpected type {:?}", ty),
68 };
69
70 if self.is_empty() {
71 types::empty_range_to_sql(buf);
72 } else {
73 types::range_to_sql(
74 |buf| bound_to_sql(self.lower(), element_type, buf),
75 |buf| bound_to_sql(self.upper(), element_type, buf),
76 buf,
77 )?;
78 }
79
80 Ok(IsNull::No)
81 }
82
83 fn accepts(ty: &Type) -> bool {
84 match *ty.kind() {
85 Kind::Range(ref inner) => <T as ToSql>::accepts(inner),
86 _ => false,
87 }
88 }
89
90 to_sql_checked!();
91}
92
93fn bound_to_sql<S, T>(bound: Option<&RangeBound<S, T>>, ty: &Type, buf: &mut BytesMut) -> Result<types::RangeBound<protocol::IsNull>, Box<dyn Error + Sync + Send>>
94where
95 S: BoundSided,
96 T: ToSql,
97{
98 match bound {
99 Some(bound) => {
100 let null = match bound.value.to_sql(ty, buf)? {
101 IsNull::Yes => protocol::IsNull::Yes,
102 IsNull::No => protocol::IsNull::No,
103 };
104
105 match bound.type_ {
106 BoundType::Exclusive => Ok(types::RangeBound::Exclusive(null)),
107 BoundType::Inclusive => Ok(types::RangeBound::Inclusive(null)),
108 }
109 }
110 None => Ok(types::RangeBound::Unbounded),
111 }
112}
113
#[cfg(test)]
mod test {
    use std::fmt;

    use postgres::types::{FromSql, ToSql};
    use postgres::{Client, NoTls};
    #[cfg(feature = "with-chrono-0_4")]
    use chrono_04::{Duration, TimeZone, Utc};

    /// Round-trips every bound combination of a range over `$low`/`$high`,
    /// plus the empty range and SQL `NULL`, through `test_type`.
    ///
    /// `$low_str`/`$high_str` are the SQL text forms of `$low`/`$high` and are
    /// spliced verbatim into the range literal. `$t` is accepted for
    /// readability at call sites but is not used by the expansion; the
    /// element type is inferred from `$low`/`$high`.
    macro_rules! test_range {
        ($name:expr, $t:ty, $low:expr, $low_str:expr, $high:expr, $high_str:expr) => {{
            let tests = &[
                (Some(range!('(',; ')')), "'(,)'".to_string()),
                (Some(range!('[' $low,; ')')), format!("'[{},)'", $low_str)),
                (Some(range!('(' $low,; ')')), format!("'({},)'", $low_str)),
                (Some(range!('(', $high; ']')), format!("'(,{}]'", $high_str)),
                (Some(range!('(', $high; ')')), format!("'(,{})'", $high_str)),
                (Some(range!('[' $low, $high; ']')),
                 format!("'[{},{}]'", $low_str, $high_str)),
                (Some(range!('[' $low, $high; ')')),
                 format!("'[{},{})'", $low_str, $high_str)),
                (Some(range!('(' $low, $high; ']')),
                 format!("'({},{}]'", $low_str, $high_str)),
                (Some(range!('(' $low, $high; ')')),
                 format!("'({},{})'", $low_str, $high_str)),
                (Some(range!(empty)), "'empty'".to_string()),
                (None, "NULL".to_string()),
            ];
            test_type($name, tests);
        }};
    }

    /// Checks each `(value, sql)` pair in both directions against a live
    /// server at `postgres://postgres@localhost`.
    ///
    /// Decoding `SELECT <sql>::<sql_type>` must yield `value`, and binding
    /// `value` as `$1::<sql_type>` must come back equal to `value`.
    fn test_type<T, S>(sql_type: &str, checks: &[(T, S)])
    where
        for<'a> T: Sync + PartialEq + FromSql<'a> + ToSql,
        S: fmt::Display,
    {
        let mut conn = Client::connect("postgres://postgres@localhost", NoTls).unwrap();
        for (val, repr) in checks {
            // SQL text literal -> Rust value.
            let stmt = conn
                .prepare(&format!("SELECT {}::{}", repr, sql_type))
                .unwrap();
            let rows = conn.query(&stmt, &[]).unwrap();
            let result: T = rows.iter().next().unwrap().get(0);
            assert!(val == &result);

            // Rust value -> bound parameter -> Rust value.
            let stmt = conn.prepare(&format!("SELECT $1::{}", sql_type)).unwrap();
            let rows = conn.query(&stmt, &[val]).unwrap();
            let result: T = rows.iter().next().unwrap().get(0);
            assert!(val == &result);
        }
    }

    #[test]
    fn test_int4range_params() {
        test_range!("INT4RANGE", i32, 100i32, "100", 200i32, "200")
    }

    #[test]
    fn test_int8range_params() {
        test_range!("INT8RANGE", i64, 100i64, "100", 200i64, "200")
    }

    #[test]
    #[cfg(feature = "with-chrono-0_4")]
    fn test_tsrange_params() {
        let low = Utc.timestamp_opt(0, 0).unwrap();
        let high = low + Duration::days(10);
        test_range!(
            "TSRANGE",
            NaiveDateTime,
            low.naive_utc(),
            "1970-01-01",
            high.naive_utc(),
            "1970-01-11"
        );
    }

    #[test]
    #[cfg(feature = "with-chrono-0_4")]
    fn test_tstzrange_params() {
        let low = Utc.timestamp_opt(0, 0).unwrap();
        let high = low + Duration::days(10);
        // Spell out the UTC offset: a timestamptz literal without one is
        // interpreted in the session's TimeZone setting, which would make the
        // comparison against these UTC values fail on non-UTC servers.
        test_range!(
            "TSTZRANGE",
            DateTime<Utc>,
            low,
            "1970-01-01 00:00:00+00",
            high,
            "1970-01-11 00:00:00+00"
        );
    }
}