ip_network/
diesel_support.rs

1use std::io::Write;
2use diesel::expression::{AsExpression, Expression};
3use diesel::pg::Pg;
4use diesel::serialize::{self, IsNull, Output, ToSql};
5use diesel::deserialize::{self, FromSql};
6use diesel::sql_types::Cidr;
7use crate::{IpNetwork, Ipv4Network, Ipv6Network};
8use crate::postgres_common;
9
10impl FromSql<Cidr, Pg> for Ipv4Network {
11    fn from_sql(bytes: Option<&[u8]>) -> deserialize::Result<Self> {
12        let bytes = not_none!(bytes);
13        postgres_common::from_sql_ipv4_network(bytes)
14    }
15}
16
17impl FromSql<Cidr, Pg> for Ipv6Network {
18    fn from_sql(bytes: Option<&[u8]>) -> deserialize::Result<Self> {
19        let bytes = not_none!(bytes);
20        postgres_common::from_sql_ipv6_network(bytes)
21    }
22}
23
24impl FromSql<Cidr, Pg> for IpNetwork {
25    fn from_sql(bytes: Option<&[u8]>) -> deserialize::Result<Self> {
26        let bytes = not_none!(bytes);
27        match bytes[0] {
28            postgres_common::IPV4_TYPE => Ok(IpNetwork::V4(
29                postgres_common::from_sql_ipv4_network(bytes)?,
30            )),
31            postgres_common::IPV6_TYPE => Ok(IpNetwork::V6(
32                postgres_common::from_sql_ipv6_network(bytes)?,
33            )),
34            _ => Err("CIDR is not IP version 4 or 6".into()),
35        }
36    }
37}
38
39impl ToSql<Cidr, Pg> for Ipv4Network {
40    fn to_sql<W: Write>(&self, out: &mut Output<W, Pg>) -> serialize::Result {
41        let data = postgres_common::to_sql_ipv4_network(self);
42        out.write_all(&data).map(|_| IsNull::No).map_err(Into::into)
43    }
44}
45
46impl ToSql<Cidr, Pg> for Ipv6Network {
47    fn to_sql<W: Write>(&self, out: &mut Output<W, Pg>) -> serialize::Result {
48        let data = postgres_common::to_sql_ipv6_network(self);
49        out.write_all(&data).map(|_| IsNull::No).map_err(Into::into)
50    }
51}
52
53impl ToSql<Cidr, Pg> for IpNetwork {
54    fn to_sql<W: Write>(&self, out: &mut Output<W, Pg>) -> serialize::Result {
55        match self {
56            IpNetwork::V4(network) => ToSql::<Cidr, Pg>::to_sql(network, out),
57            IpNetwork::V6(network) => ToSql::<Cidr, Pg>::to_sql(network, out),
58        }
59    }
60}
61
62#[allow(dead_code)]
63mod foreign_derives {
64    use super::*;
65
66    #[derive(FromSqlRow, AsExpression)]
67    #[diesel(foreign_derive)]
68    #[sql_type = "Cidr"]
69    struct IpNetworkProxy(IpNetwork);
70
71    #[derive(FromSqlRow, AsExpression)]
72    #[diesel(foreign_derive)]
73    #[sql_type = "Cidr"]
74    struct Ipv4NetworkProxy(Ipv4Network);
75
76    #[derive(FromSqlRow, AsExpression)]
77    #[diesel(foreign_derive)]
78    #[sql_type = "Cidr"]
79    struct Ipv6NetworkProxy(Ipv6Network);
80}
81
82diesel_infix_operator!(IsContainedBy, " << ", backend: Pg);
83diesel_infix_operator!(IsContainedByOrEquals, " <<= ", backend: Pg);
84diesel_infix_operator!(Contains, " >> ", backend: Pg);
85diesel_infix_operator!(ContainsOrEquals, " >>= ", backend: Pg);
86diesel_infix_operator!(ContainsOrIsContainedBy, " && ", backend: Pg);
87
88/// Support for PostgreSQL Network Address Operators for Diesel
89///
90/// See [PostgreSQL documentation for details](https://www.postgresql.org/docs/current/static/functions-net.html).
91pub trait PqCidrExtensionMethods: Expression<SqlType = Cidr> + Sized {
92    /// Creates a SQL `<<` expression.
93    fn is_contained_by<T>(self, other: T) -> IsContainedBy<Self, T::Expression>
94    where
95        T: AsExpression<Self::SqlType>,
96    {
97        IsContainedBy::new(self, other.as_expression())
98    }
99
100    /// Creates a SQL `<<=` expression.
101    fn is_contained_by_or_equals<T>(self, other: T) -> IsContainedByOrEquals<Self, T::Expression>
102    where
103        T: AsExpression<Self::SqlType>,
104    {
105        IsContainedByOrEquals::new(self, other.as_expression())
106    }
107
108    /// Creates a SQL `>>` expression.
109    fn contains<T>(self, other: T) -> Contains<Self, T::Expression>
110    where
111        T: AsExpression<Self::SqlType>,
112    {
113        Contains::new(self, other.as_expression())
114    }
115
116    /// Creates a SQL `>>=` expression.
117    fn contains_or_equals<T>(self, other: T) -> ContainsOrEquals<Self, T::Expression>
118    where
119        T: AsExpression<Self::SqlType>,
120    {
121        ContainsOrEquals::new(self, other.as_expression())
122    }
123
124    /// Creates a SQL `&&` expression.
125    fn contains_or_is_contained_by<T>(
126        self,
127        other: T,
128    ) -> ContainsOrIsContainedBy<Self, T::Expression>
129    where
130        T: AsExpression<Self::SqlType>,
131    {
132        ContainsOrIsContainedBy::new(self, other.as_expression())
133    }
134}
135
136impl<T> PqCidrExtensionMethods for T where T: Expression<SqlType = Cidr> {}
137
138/// CIDR functions.
139pub mod functions {
140    use diesel::sql_types::Cidr;
141
142    sql_function! {
143        /// Extract family of address; 4 for IPv4, 6 for IPv6.
144        fn family(x: Cidr) -> Integer;
145    }
146    sql_function! {
147        /// Extract netmask length.
148        fn masklen(x: Cidr) -> Integer;
149    }
150}
151
152pub mod helper_types {
153    pub type Family<Expr> = super::functions::family::HelperType<Expr>;
154    pub type Masklen<Expr> = super::functions::masklen::HelperType<Expr>;
155}
156
157pub mod dsl {
158    pub use super::functions::*;
159    pub use super::helper_types::*;
160}
161
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, Ipv6Addr};
    use diesel::sql_types::Cidr;
    use diesel::pg::Pg;
    use diesel::serialize::{Output, ToSql};
    use diesel::deserialize::FromSql;
    use diesel::prelude::*;
    use diesel::debug_query;
    use super::PqCidrExtensionMethods;
    use super::{IpNetwork, Ipv4Network, Ipv6Network};
    use super::dsl::*;

    // Table with one column per supported Rust type; used to type-check the derives,
    // operators and SQL functions.
    table! {
        test {
            id -> Integer,
            ip_network -> Cidr,
            ipv4_network -> Cidr,
            ipv6_network -> Cidr,
        }
    }

    // Never inserted anywhere: compiling the `Insertable` derive is the test.
    #[derive(Insertable)]
    #[table_name = "test"]
    pub struct NewPost {
        pub id: i32,
        pub ip_network: IpNetwork,
        pub ipv4_network: Ipv4Network,
        pub ipv6_network: Ipv6Network,
    }

    /// Creates an empty in-memory `Output` to serialize into.
    fn test_output() -> Output<'static, Vec<u8>, Pg> {
        // NOTE(review): the metadata lookup handed to `Output::new` is never initialised.
        // This presumably relies on the `ToSql` impls under test never touching it;
        // confirm before reusing this helper for anything else.
        let fake_metadata_lookup = unsafe { std::mem::MaybeUninit::uninit().assume_init() };
        Output::new(Vec::new(), fake_metadata_lookup)
    }

    /// An IPv4 network survives a `ToSql` -> `FromSql` round trip unchanged.
    #[test]
    fn ipv4_network() {
        let original = Ipv4Network::new(Ipv4Addr::new(1, 2, 3, 4), 32).unwrap();
        let mut buffer = test_output();
        ToSql::<Cidr, Pg>::to_sql(&original, &mut buffer).unwrap();
        let decoded: Ipv4Network = FromSql::<Cidr, Pg>::from_sql(Some(buffer.as_ref())).unwrap();
        assert_eq!(original, decoded);
    }

    /// An IPv6 network survives a `ToSql` -> `FromSql` round trip unchanged.
    #[test]
    fn ipv6_network() {
        let original = Ipv6Network::new(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), 128).unwrap();
        let mut buffer = test_output();
        ToSql::<Cidr, Pg>::to_sql(&original, &mut buffer).unwrap();
        let decoded: Ipv6Network = FromSql::<Cidr, Pg>::from_sql(Some(buffer.as_ref())).unwrap();
        assert_eq!(original, decoded);
    }

    /// The `IpNetwork` enum (here holding an IPv6 network) round-trips as well.
    #[test]
    fn ip_network() {
        let original =
            IpNetwork::V6(Ipv6Network::new(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), 128).unwrap());
        let mut buffer = test_output();
        ToSql::<Cidr, Pg>::to_sql(&original, &mut buffer).unwrap();
        let decoded: IpNetwork = FromSql::<Cidr, Pg>::from_sql(Some(buffer.as_ref())).unwrap();
        assert_eq!(original, decoded);
    }

    /// Compile-only check: every operator method accepts a `Cidr` column and `&IpNetwork`.
    #[test]
    fn operators() {
        let loopback = IpNetwork::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap();
        test::ip_network.is_contained_by(&loopback);
        test::ip_network.is_contained_by_or_equals(&loopback);
        test::ip_network.contains(&loopback);
        test::ip_network.contains_or_equals(&loopback);
        test::ip_network.contains_or_is_contained_by(&loopback);
    }

    /// `family(...)` renders as a plain SQL function call.
    #[test]
    fn function_family() {
        let select_family = test::table.select(family(test::ip_network));
        let rendered = debug_query::<Pg, _>(&select_family).to_string();
        assert_eq!(
            "SELECT family(\"test\".\"ip_network\") FROM \"test\" -- binds: []",
            rendered
        );
    }

    /// `masklen(...)` renders as a plain SQL function call.
    #[test]
    fn function_masklen() {
        let select_masklen = test::table.select(masklen(test::ip_network));
        let rendered = debug_query::<Pg, _>(&select_masklen).to_string();
        assert_eq!(
            "SELECT masklen(\"test\".\"ip_network\") FROM \"test\" -- binds: []",
            rendered
        );
    }
}