1use std::io::Write;
2use diesel::expression::{AsExpression, Expression};
3use diesel::pg::Pg;
4use diesel::serialize::{self, IsNull, Output, ToSql};
5use diesel::deserialize::{self, FromSql};
6use diesel::sql_types::Cidr;
7use crate::{IpNetwork, Ipv4Network, Ipv6Network};
8use crate::postgres_common;
9
10impl FromSql<Cidr, Pg> for Ipv4Network {
11 fn from_sql(bytes: Option<&[u8]>) -> deserialize::Result<Self> {
12 let bytes = not_none!(bytes);
13 postgres_common::from_sql_ipv4_network(bytes)
14 }
15}
16
17impl FromSql<Cidr, Pg> for Ipv6Network {
18 fn from_sql(bytes: Option<&[u8]>) -> deserialize::Result<Self> {
19 let bytes = not_none!(bytes);
20 postgres_common::from_sql_ipv6_network(bytes)
21 }
22}
23
24impl FromSql<Cidr, Pg> for IpNetwork {
25 fn from_sql(bytes: Option<&[u8]>) -> deserialize::Result<Self> {
26 let bytes = not_none!(bytes);
27 match bytes[0] {
28 postgres_common::IPV4_TYPE => Ok(IpNetwork::V4(
29 postgres_common::from_sql_ipv4_network(bytes)?,
30 )),
31 postgres_common::IPV6_TYPE => Ok(IpNetwork::V6(
32 postgres_common::from_sql_ipv6_network(bytes)?,
33 )),
34 _ => Err("CIDR is not IP version 4 or 6".into()),
35 }
36 }
37}
38
39impl ToSql<Cidr, Pg> for Ipv4Network {
40 fn to_sql<W: Write>(&self, out: &mut Output<W, Pg>) -> serialize::Result {
41 let data = postgres_common::to_sql_ipv4_network(self);
42 out.write_all(&data).map(|_| IsNull::No).map_err(Into::into)
43 }
44}
45
46impl ToSql<Cidr, Pg> for Ipv6Network {
47 fn to_sql<W: Write>(&self, out: &mut Output<W, Pg>) -> serialize::Result {
48 let data = postgres_common::to_sql_ipv6_network(self);
49 out.write_all(&data).map(|_| IsNull::No).map_err(Into::into)
50 }
51}
52
53impl ToSql<Cidr, Pg> for IpNetwork {
54 fn to_sql<W: Write>(&self, out: &mut Output<W, Pg>) -> serialize::Result {
55 match self {
56 IpNetwork::V4(network) => ToSql::<Cidr, Pg>::to_sql(network, out),
57 IpNetwork::V6(network) => ToSql::<Cidr, Pg>::to_sql(network, out),
58 }
59 }
60}
61
// Diesel "foreign derives". Each proxy newtype wraps one of the network types
// and carries `#[diesel(foreign_derive)]`, which asks diesel to emit the
// `FromSqlRow` / `AsExpression` impls (for SQL type `Cidr`) for the wrapped type
// rather than for the proxy itself.
// The proxy structs are never constructed, hence `allow(dead_code)`.
#[allow(dead_code)]
mod foreign_derives {
    use super::*;

    // `IpNetwork` (either family) <-> `Cidr`.
    #[derive(FromSqlRow, AsExpression)]
    #[diesel(foreign_derive)]
    #[sql_type = "Cidr"]
    struct IpNetworkProxy(IpNetwork);

    // `Ipv4Network` <-> `Cidr`.
    #[derive(FromSqlRow, AsExpression)]
    #[diesel(foreign_derive)]
    #[sql_type = "Cidr"]
    struct Ipv4NetworkProxy(Ipv4Network);

    // `Ipv6Network` <-> `Cidr`.
    #[derive(FromSqlRow, AsExpression)]
    #[diesel(foreign_derive)]
    #[sql_type = "Cidr"]
    struct Ipv6NetworkProxy(Ipv6Network);
}
81
// PostgreSQL network-containment operators, restricted to the `Pg` backend.
// Each line declares an expression node type; the second argument is the SQL
// operator text (presumably written verbatim, padding spaces included, between
// the two operands). They are constructed via `PqCidrExtensionMethods`:
//   <<   IsContainedBy            -> is_contained_by
//   <<=  IsContainedByOrEquals    -> is_contained_by_or_equals
//   >>   Contains                 -> contains
//   >>=  ContainsOrEquals         -> contains_or_equals
//   &&   ContainsOrIsContainedBy  -> contains_or_is_contained_by
diesel_infix_operator!(IsContainedBy, " << ", backend: Pg);
diesel_infix_operator!(IsContainedByOrEquals, " <<= ", backend: Pg);
diesel_infix_operator!(Contains, " >> ", backend: Pg);
diesel_infix_operator!(ContainsOrEquals, " >>= ", backend: Pg);
diesel_infix_operator!(ContainsOrIsContainedBy, " && ", backend: Pg);
87
88pub trait PqCidrExtensionMethods: Expression<SqlType = Cidr> + Sized {
92 fn is_contained_by<T>(self, other: T) -> IsContainedBy<Self, T::Expression>
94 where
95 T: AsExpression<Self::SqlType>,
96 {
97 IsContainedBy::new(self, other.as_expression())
98 }
99
100 fn is_contained_by_or_equals<T>(self, other: T) -> IsContainedByOrEquals<Self, T::Expression>
102 where
103 T: AsExpression<Self::SqlType>,
104 {
105 IsContainedByOrEquals::new(self, other.as_expression())
106 }
107
108 fn contains<T>(self, other: T) -> Contains<Self, T::Expression>
110 where
111 T: AsExpression<Self::SqlType>,
112 {
113 Contains::new(self, other.as_expression())
114 }
115
116 fn contains_or_equals<T>(self, other: T) -> ContainsOrEquals<Self, T::Expression>
118 where
119 T: AsExpression<Self::SqlType>,
120 {
121 ContainsOrEquals::new(self, other.as_expression())
122 }
123
124 fn contains_or_is_contained_by<T>(
126 self,
127 other: T,
128 ) -> ContainsOrIsContainedBy<Self, T::Expression>
129 where
130 T: AsExpression<Self::SqlType>,
131 {
132 ContainsOrIsContainedBy::new(self, other.as_expression())
133 }
134}
135
136impl<T> PqCidrExtensionMethods for T where T: Expression<SqlType = Cidr> {}
137
/// PostgreSQL SQL functions taking a `cidr` argument, declared with diesel's
/// `sql_function!`.
pub mod functions {
    use diesel::sql_types::Cidr;

    // Note: `Integer` is not imported in this module. It is presumably brought
    // into scope by the `sql_function!` expansion; confirm if the macro changes.

    // PostgreSQL `family(x)`: the address family of `x`, typed here as `Integer`.
    sql_function! {
        fn family(x: Cidr) -> Integer;
    }
    // PostgreSQL `masklen(x)`: the netmask length of `x`, typed here as `Integer`.
    sql_function! {
        fn masklen(x: Cidr) -> Integer;
    }
}
151
152pub mod helper_types {
153 pub type Family<Expr> = super::functions::family::HelperType<Expr>;
154 pub type Masklen<Expr> = super::functions::masklen::HelperType<Expr>;
155}
156
157pub mod dsl {
158 pub use super::functions::*;
159 pub use super::helper_types::*;
160}
161
// Unit tests: binary round-trips through `ToSql`/`FromSql`, type-checking of the
// containment operator methods, and the SQL rendered for `family`/`masklen`.
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, Ipv6Addr};
    use diesel::sql_types::Cidr;
    use diesel::pg::Pg;
    use diesel::serialize::{Output, ToSql};
    use diesel::deserialize::FromSql;
    use diesel::prelude::*;
    use diesel::debug_query;
    use super::PqCidrExtensionMethods;
    use super::{IpNetwork, Ipv4Network, Ipv6Network};
    use super::dsl::*;

    // Schema used only to build typed queries; no test here opens a connection.
    table! {
        test {
            id -> Integer,
            ip_network -> Cidr,
            ipv4_network -> Cidr,
            ipv6_network -> Cidr,
        }
    }

    // Not constructed by any test. Presumably kept so the `Insertable` derive
    // checks at compile time that all three network types bind to `Cidr` columns.
    #[derive(Insertable)]
    #[table_name = "test"]
    pub struct NewPost {
        pub id: i32,
        pub ip_network: IpNetwork,
        pub ipv4_network: Ipv4Network,
        pub ipv6_network: Ipv6Network,
    }

    /// Builds an `Output` that serializes into a fresh `Vec<u8>`.
    ///
    /// NOTE(review): the metadata-lookup argument comes from
    /// `MaybeUninit::uninit().assume_init()`. If that parameter is a reference,
    /// which the `'static` lifetime on `Output` suggests, materializing it
    /// uninitialized is undefined behavior even if it is never read. It only
    /// appears to work because the `Cidr` `ToSql` impls in this file never
    /// consult the lookup. Replace it with a real lookup or a diesel-provided
    /// test helper when one is available.
    fn test_output() -> Output<'static, Vec<u8>, Pg> {
        let uninit = std::mem::MaybeUninit::uninit();
        let fake_metadata_lookup = unsafe { uninit.assume_init() };
        Output::new(Vec::new(), fake_metadata_lookup)
    }

    // Round-trips a /32 IPv4 network through `ToSql` then `FromSql`.
    #[test]
    fn ipv4_network() {
        let mut bytes = test_output();
        let ipv4_network = Ipv4Network::new(Ipv4Addr::new(1, 2, 3, 4), 32).unwrap();
        ToSql::<Cidr, Pg>::to_sql(&ipv4_network, &mut bytes).unwrap();
        let converted: Ipv4Network = FromSql::<Cidr, Pg>::from_sql(Some(bytes.as_ref())).unwrap();
        assert_eq!(ipv4_network, converted);
    }

    // Round-trips a /128 IPv6 network through `ToSql` then `FromSql`.
    #[test]
    fn ipv6_network() {
        let mut bytes = test_output();
        let ipv6_network = Ipv6Network::new(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), 128).unwrap();
        ToSql::<Cidr, Pg>::to_sql(&ipv6_network, &mut bytes).unwrap();
        let converted: Ipv6Network = FromSql::<Cidr, Pg>::from_sql(Some(bytes.as_ref())).unwrap();
        assert_eq!(ipv6_network, converted);
    }

    // Round-trips the `IpNetwork` enum. Only the `V6` variant is exercised here.
    #[test]
    fn ip_network() {
        let mut bytes = test_output();
        let ip_network =
            IpNetwork::V6(Ipv6Network::new(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), 128).unwrap());
        ToSql::<Cidr, Pg>::to_sql(&ip_network, &mut bytes).unwrap();
        let converted: IpNetwork = FromSql::<Cidr, Pg>::from_sql(Some(bytes.as_ref())).unwrap();
        assert_eq!(ip_network, converted);
    }

    // Compile-time check only: each operator method accepts `&IpNetwork` on a
    // `Cidr` column. The built expressions are discarded and no SQL is asserted.
    #[test]
    fn operators() {
        let ip = IpNetwork::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap();
        test::ip_network.is_contained_by(&ip);
        test::ip_network.is_contained_by_or_equals(&ip);
        test::ip_network.contains(&ip);
        test::ip_network.contains_or_equals(&ip);
        test::ip_network.contains_or_is_contained_by(&ip);
    }

    // Checks the SQL text generated for `family(column)`.
    #[test]
    fn function_family() {
        let query = test::table.select(family(test::ip_network));
        let string_query = debug_query::<Pg, _>(&query).to_string();
        assert_eq!(
            "SELECT family(\"test\".\"ip_network\") FROM \"test\" -- binds: []",
            string_query
        );
    }

    // Checks the SQL text generated for `masklen(column)`.
    #[test]
    fn function_masklen() {
        let query = test::table.select(masklen(test::ip_network));
        let string_query = debug_query::<Pg, _>(&query).to_string();
        assert_eq!(
            "SELECT masklen(\"test\".\"ip_network\") FROM \"test\" -- binds: []",
            string_query
        );
    }
}