sqlx_core/postgres/types/
ipnetwork.rs

1use std::net::{Ipv4Addr, Ipv6Addr};
2
3use ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network};
4
5use crate::decode::Decode;
6use crate::encode::{Encode, IsNull};
7use crate::error::BoxDynError;
8use crate::postgres::{
9    PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres,
10};
11use crate::types::Type;
12
// https://github.com/rust-lang/rust/search?q=AF_INET&unscoped_q=AF_INET

// The address-family byte of the binary `inet`/`cidr` representation is written and
// validated by the *server*, which derives it from its own `AF_INET`. The value of
// `AF_INET` on the machine running this client is therefore irrelevant, and taking it
// from the client platform yields a family byte the server rejects whenever the two
// disagree (for example on targets that are neither unix nor windows).
//
// `AF_INET` is 2 on the platforms PostgreSQL servers run on, so the wire values are
// pinned here instead of being looked up per client target.
const AF_INET: u8 = 2;

// https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/include/utils/inet.h#L39

/// Family byte used for IPv4 values (`AF_INET + 0` in PostgreSQL's `inet.h`).
const PGSQL_AF_INET: u8 = AF_INET;
/// Family byte used for IPv6 values (`AF_INET + 1` in PostgreSQL's `inet.h`).
const PGSQL_AF_INET6: u8 = AF_INET + 1;
28
29impl Type<Postgres> for IpNetwork {
30    fn type_info() -> PgTypeInfo {
31        PgTypeInfo::INET
32    }
33
34    fn compatible(ty: &PgTypeInfo) -> bool {
35        *ty == PgTypeInfo::CIDR || *ty == PgTypeInfo::INET
36    }
37}
38
39impl PgHasArrayType for IpNetwork {
40    fn array_type_info() -> PgTypeInfo {
41        PgTypeInfo::INET_ARRAY
42    }
43
44    fn array_compatible(ty: &PgTypeInfo) -> bool {
45        *ty == PgTypeInfo::CIDR_ARRAY || *ty == PgTypeInfo::INET_ARRAY
46    }
47}
48
49impl Encode<'_, Postgres> for IpNetwork {
50    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
51        // https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/backend/utils/adt/network.c#L293
52        // https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/backend/utils/adt/network.c#L271
53
54        match self {
55            IpNetwork::V4(net) => {
56                buf.push(PGSQL_AF_INET); // ip_family
57                buf.push(net.prefix()); // ip_bits
58                buf.push(0); // is_cidr
59                buf.push(4); // nb (number of bytes)
60                buf.extend_from_slice(&net.ip().octets()) // address
61            }
62
63            IpNetwork::V6(net) => {
64                buf.push(PGSQL_AF_INET6); // ip_family
65                buf.push(net.prefix()); // ip_bits
66                buf.push(0); // is_cidr
67                buf.push(16); // nb (number of bytes)
68                buf.extend_from_slice(&net.ip().octets()); // address
69            }
70        }
71
72        IsNull::No
73    }
74
75    fn size_hint(&self) -> usize {
76        match self {
77            IpNetwork::V4(_) => 8,
78            IpNetwork::V6(_) => 20,
79        }
80    }
81}
82
83impl Decode<'_, Postgres> for IpNetwork {
84    fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
85        let bytes = match value.format() {
86            PgValueFormat::Binary => value.as_bytes()?,
87            PgValueFormat::Text => {
88                return Ok(value.as_str()?.parse()?);
89            }
90        };
91
92        if bytes.len() >= 8 {
93            let family = bytes[0];
94            let prefix = bytes[1];
95            let _is_cidr = bytes[2] != 0;
96            let len = bytes[3];
97
98            match family {
99                PGSQL_AF_INET => {
100                    if bytes.len() == 8 && len == 4 {
101                        let inet = Ipv4Network::new(
102                            Ipv4Addr::new(bytes[4], bytes[5], bytes[6], bytes[7]),
103                            prefix,
104                        )?;
105
106                        return Ok(IpNetwork::V4(inet));
107                    }
108                }
109
110                PGSQL_AF_INET6 => {
111                    if bytes.len() == 20 && len == 16 {
112                        let inet = Ipv6Network::new(
113                            Ipv6Addr::from([
114                                bytes[4], bytes[5], bytes[6], bytes[7], bytes[8], bytes[9],
115                                bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
116                                bytes[16], bytes[17], bytes[18], bytes[19],
117                            ]),
118                            prefix,
119                        )?;
120
121                        return Ok(IpNetwork::V6(inet));
122                    }
123                }
124
125                _ => {
126                    return Err(format!("unknown ip family {}", family).into());
127                }
128            }
129        }
130
131        Err("invalid data received when expecting an INET".into())
132    }
133}