1use std::ops::Deref;
2
3#[cfg(feature = "sqlx")]
4use geozero::{wkb, ToWkb};
5
6#[cfg(feature = "sqlx")]
7use sqlx::{
8    encode::IsNull,
9    postgres::{PgHasArrayType, PgTypeInfo, PgValueRef},
10    Postgres, ValueRef,
11};
12
13#[derive(Clone, Debug)]
14pub struct Geometry(pub geo::Geometry<f64>);
15
16impl PartialEq for Geometry {
17    fn eq(&self, other: &Self) -> bool {
18        self.0 == other.0
19    }
20}
21
22impl Eq for Geometry {}
23
24impl Deref for Geometry {
25    type Target = geo::Geometry<f64>;
26
27    fn deref(&self) -> &Self::Target {
28        &self.0
29    }
30}
31
#[cfg(feature = "sqlx")]
impl sqlx::Type<Postgres> for Geometry {
    /// Declares the Postgres column type by name: the PostGIS `geometry` type.
    fn type_info() -> PgTypeInfo {
        const PG_TYPE_NAME: &str = "geometry";
        PgTypeInfo::with_name(PG_TYPE_NAME)
    }
}
38
#[cfg(feature = "sqlx")]
impl PgHasArrayType for Geometry {
    /// Declares the array type by name; Postgres names array types with a leading `_`.
    fn array_type_info() -> PgTypeInfo {
        const PG_ARRAY_TYPE_NAME: &str = "_geometry";
        PgTypeInfo::with_name(PG_ARRAY_TYPE_NAME)
    }
}
45
46impl Geometry {
47    pub fn into_inner(self) -> geo::Geometry<f64> {
48        self.0
49    }
50}
51
#[cfg(feature = "sqlx")]
impl<'de> sqlx::Decode<'de, Postgres> for Geometry {
    /// Decodes a PostGIS `geometry` value using geozero's `wkb::Decode`.
    ///
    /// # Errors
    ///
    /// - Returns [`sqlx::error::UnexpectedNullError`] for SQL `NULL`.
    /// - Propagates any error raised by the WKB decoder.
    /// - Returns an error, rather than panicking, if the decoder reports success
    ///   without producing a geometry.
    fn decode(value: PgValueRef<'de>) -> Result<Self, sqlx::error::BoxDynError> {
        if value.is_null() {
            return Err(Box::new(sqlx::error::UnexpectedNullError));
        }
        let decoded = wkb::Decode::<geo::Geometry<f64>>::decode(value)?;
        // `geometry` is an `Option`; a `None` here is unexpected for a non-null value,
        // but a decode path that already returns `Result` should report it, not panic.
        decoded
            .geometry
            .map(Geometry)
            .ok_or_else(|| "geometry parsing failed without error for non-null value".into())
    }
}
64
#[cfg(feature = "sqlx")]
impl<'en> sqlx::Encode<'en, Postgres> for Geometry {
    /// Appends the geometry to `buf` as XY EWKB.
    ///
    /// No SRID is embedded (the SRID argument to `to_ewkb` is `None`). A target
    /// column with an SRID constraint may therefore reject the value.
    ///
    /// # Panics
    ///
    /// Panics if geozero fails to serialize the geometry. This `encode_by_ref`
    /// signature returns only `IsNull`, so it has no way to report an encode error.
    fn encode_by_ref(&self, buf: &mut sqlx::postgres::PgArgumentBuffer) -> IsNull {
        let ewkb = self
            .0
            .to_ewkb(geozero::CoordDimensions::xy(), None)
            .expect("failed to encode geometry as EWKB");
        buf.extend(ewkb);
        IsNull::No
    }
}
76
#[cfg(feature = "serde")]
impl serde::Serialize for Geometry {
    /// Serializes the geometry as a GeoJSON geometry object.
    ///
    /// geozero renders GeoJSON text. That text is parsed into a key-sorted
    /// `BTreeMap` and re-emitted as a map, so keys appear in alphabetical order.
    ///
    /// # Errors
    ///
    /// Fails with a custom serializer error if geozero cannot render the geometry
    /// or its output is not a JSON object.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use geozero::ToJson;
        use serde::ser::Error;
        use serde_json::Value;
        use std::collections::BTreeMap;

        let geojson_text = self.0.to_json().map_err(S::Error::custom)?;
        let members = serde_json::from_str::<BTreeMap<String, Value>>(&geojson_text)
            .map_err(S::Error::custom)?;
        // `collect_map` sizes the map from the exact-size iterator and emits each entry.
        serializer.collect_map(members)
    }
}
98
#[cfg(all(test, feature = "sqlx"))]
mod sqlx_tests {
    //! Round-trip tests for the `sqlx` `Encode`/`Decode` impls.
    //!
    //! They need a reachable PostGIS-enabled Postgres. The connection string is
    //! read from `DATABASE_URL`, falling back to [`DEFAULT_DATABASE_URL`].

    use super::Geometry;
    use geo::{line_string, LineString, MultiLineString, Polygon};

    /// Connection string used when `DATABASE_URL` is not set.
    const DEFAULT_DATABASE_URL: &str = "postgres://postgres:password@localhost/postgres";

    /// Writes `data_to` into a scratch `GEOMETRY(type_name)` column and reads it back.
    ///
    /// Everything runs inside a transaction that is never committed, so the scratch
    /// table is rolled back when `tx` is dropped.
    async fn pg_roundtrip(data_to: &Geometry, type_name: &str) -> Geometry {
        use sqlx::postgres::PgPoolOptions;

        let url = std::env::var("DATABASE_URL")
            .unwrap_or_else(|_| DEFAULT_DATABASE_URL.to_owned());
        // A single round-trip only ever needs one connection.
        let pool = PgPoolOptions::new()
            .max_connections(1)
            .connect(&url)
            .await
            .unwrap();
        let mut tx = pool.begin().await.unwrap();

        // TEMPORARY keeps the table session-local, so tests running concurrently do
        // not contend for the same table name. The typmod has no SRID because
        // `Encode` writes EWKB without one, and PostGIS rejects a geometry whose SRID
        // differs from the column's declared SRID.
        sqlx::query(&format!(
            "CREATE TEMPORARY TABLE test ( id SERIAL PRIMARY KEY, geom GEOMETRY({type_name}) )"
        ))
        .execute(&mut *tx)
        .await
        .unwrap();

        sqlx::query("INSERT INTO test (geom) VALUES ($1)")
            .bind(data_to)
            .execute(&mut *tx)
            .await
            .unwrap();

        let (data_from,): (Geometry,) = sqlx::query_as("SELECT geom FROM test")
            .fetch_one(&mut *tx)
            .await
            .unwrap();

        data_from
    }

    #[tokio::test]
    async fn point() {
        let data_to = Geometry(geo::Geometry::Point((0., 1.).into()));
        let data_from = pg_roundtrip(&data_to, "Point").await;
        assert_eq!(data_to, data_from);
    }

    #[tokio::test]
    async fn line() {
        let open_line_string: LineString<f64> = line_string![(x: 0., y: 0.), (x: 5., y: 0.)];
        let data_to = Geometry(geo::Geometry::MultiLineString(MultiLineString(vec![
            open_line_string,
        ])));
        let data_from = pg_roundtrip(&data_to, "MultiLineString").await;
        assert_eq!(data_to, data_from);
    }

    #[tokio::test]
    async fn polygon() {
        let polygon = Polygon::new(
            LineString::from(vec![(0., 0.), (1., 1.), (1., 0.), (0., 0.)]),
            vec![LineString::from(vec![
                (0.1, 0.1),
                (0.9, 0.9),
                (0.9, 0.1),
                (0.1, 0.1),
            ])],
        );
        let data_to = Geometry(geo::Geometry::Polygon(polygon));
        let data_from = pg_roundtrip(&data_to, "Polygon").await;
        assert_eq!(data_to, data_from);
    }
}