1use std::ops::Deref;
2
3#[cfg(feature = "sqlx")]
4use geozero::{wkb, ToWkb};
5
6#[cfg(feature = "sqlx")]
7use sqlx::{
8 encode::IsNull,
9 postgres::{PgHasArrayType, PgTypeInfo, PgValueRef},
10 Postgres, ValueRef,
11};
12
13#[derive(Clone, Debug)]
14pub struct Geometry(pub geo::Geometry<f64>);
15
16impl PartialEq for Geometry {
17 fn eq(&self, other: &Self) -> bool {
18 self.0 == other.0
19 }
20}
21
22impl Eq for Geometry {}
23
24impl Deref for Geometry {
25 type Target = geo::Geometry<f64>;
26
27 fn deref(&self) -> &Self::Target {
28 &self.0
29 }
30}
31
/// Tells sqlx that [`Geometry`] values are exchanged as the PostgreSQL type named
/// `geometry` (the PostGIS geometry type).
#[cfg(feature = "sqlx")]
impl sqlx::Type<Postgres> for Geometry {
    fn type_info() -> PgTypeInfo {
        const TYPE_NAME: &str = "geometry";
        PgTypeInfo::with_name(TYPE_NAME)
    }
}
38
/// Supplies the PostgreSQL array type used for arrays of [`Geometry`].
///
/// PostgreSQL names an array type by prefixing its element type name with `_`,
/// hence `_geometry`.
#[cfg(feature = "sqlx")]
impl PgHasArrayType for Geometry {
    fn array_type_info() -> PgTypeInfo {
        const ARRAY_TYPE_NAME: &str = "_geometry";
        PgTypeInfo::with_name(ARRAY_TYPE_NAME)
    }
}
45
46impl Geometry {
47 pub fn into_inner(self) -> geo::Geometry<f64> {
48 self.0
49 }
50}
51
/// Decodes a PostgreSQL `geometry` value into a [`Geometry`] using geozero's
/// `wkb::Decode`.
///
/// # Errors
///
/// * Returns [`sqlx::error::UnexpectedNullError`] when the value is SQL `NULL`.
/// * Propagates any error reported by geozero's decoder.
/// * Returns an error, rather than panicking, if the decoder reports success
///   without producing a geometry.
#[cfg(feature = "sqlx")]
impl<'de> sqlx::Decode<'de, Postgres> for Geometry {
    fn decode(value: PgValueRef<'de>) -> Result<Self, sqlx::error::BoxDynError> {
        if value.is_null() {
            return Err(Box::new(sqlx::error::UnexpectedNullError));
        }
        // Fully qualified so the call does not depend on `sqlx::Decode` being in scope.
        let decoded =
            <wkb::Decode<geo::Geometry<f64>> as sqlx::Decode<'de, Postgres>>::decode(value)?;
        // `wkb::Decode` stores its result in an `Option`. Nulls were rejected above, so
        // `None` here indicates a decoder inconsistency. Report it as a decode error
        // instead of panicking inside the database layer.
        let geometry = decoded
            .geometry
            .ok_or("geometry parsing failed without error for non-null value")?;
        Ok(Geometry(geometry))
    }
}
64
/// Encodes a [`Geometry`] as XY (2D) EWKB via geozero, with no SRID supplied
/// (`None`), and appends the bytes to the argument buffer.
///
/// # Panics
///
/// Panics if geozero returns an error while producing EWKB. The `encode_by_ref`
/// signature implemented here returns [`IsNull`] rather than a `Result`, so the
/// error cannot be propagated to the caller.
#[cfg(feature = "sqlx")]
impl<'en> sqlx::Encode<'en, Postgres> for Geometry {
    fn encode_by_ref(&self, buf: &mut sqlx::postgres::PgArgumentBuffer) -> IsNull {
        let ewkb = self
            .0
            .to_ewkb(geozero::CoordDimensions::xy(), None)
            .expect("failed to encode geometry as EWKB");
        buf.extend(ewkb);
        IsNull::No
    }
}
76
/// Serializes the geometry as the GeoJSON object produced by geozero's `ToJson`.
///
/// The GeoJSON text is parsed into a JSON object and then serialized, so any `serde`
/// serializer (not only `serde_json`) receives a structured map rather than a string.
///
/// # Errors
///
/// * Returns `S::Error` (built with [`serde::ser::Error::custom`]) if geozero fails to
///   produce GeoJSON, or if the produced text is not a JSON object.
/// * Propagates any error from the target serializer.
#[cfg(feature = "serde")]
impl serde::Serialize for Geometry {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use geozero::ToJson;
        use serde::ser::Error;

        let json = self.0.to_json().map_err(Error::custom)?;
        // Parsing into a `Map` (not a bare `Value`) keeps the original requirement that
        // the GeoJSON be an object.
        let object: serde_json::Map<String, serde_json::Value> =
            serde_json::from_str(&json).map_err(Error::custom)?;
        // Delegate to `Map`'s own `Serialize` impl instead of re-emitting entries by hand.
        serde::Serialize::serialize(&object, serializer)
    }
}
98
#[cfg(all(test, feature = "sqlx"))]
mod sqlx_tests {
    use super::Geometry;
    use geo::{line_string, LineString, MultiLineString, Polygon};

    /// Connection string of the local PostgreSQL/PostGIS instance these tests expect.
    const DATABASE_URL: &str = "postgres://postgres:password@localhost/postgres";

    /// Writes `data_to` into a freshly created `test` table whose `geom` column is
    /// `GEOMETRY(<type_name>, 26910)`, then reads the stored value back.
    ///
    /// All statements run inside a transaction that is never committed, so the
    /// created table is not persisted.
    async fn pg_roundtrip(data_to: &Geometry, type_name: &str) -> Geometry {
        use sqlx::postgres::PgPoolOptions;

        let pool = PgPoolOptions::new()
            .max_connections(5)
            .connect(DATABASE_URL)
            .await
            .unwrap();
        let mut tx = pool.begin().await.unwrap();

        let create_table = format!(
            "CREATE TABLE test ( id SERIAL PRIMARY KEY, geom GEOMETRY({type_name}, 26910) )"
        );
        sqlx::query(&create_table)
            .execute(&mut *tx)
            .await
            .unwrap();

        sqlx::query("INSERT INTO test (geom) VALUES ($1)")
            .bind(data_to)
            .execute(&mut *tx)
            .await
            .unwrap();

        let row: (Geometry,) = sqlx::query_as("SELECT geom FROM test")
            .fetch_one(&mut *tx)
            .await
            .unwrap();
        let (data_from,) = row;

        data_from
    }

    #[tokio::test]
    async fn point() {
        let data_to = Geometry(geo::Geometry::Point(geo::Point::new(0., 1.)));
        let data_from = pg_roundtrip(&data_to, "Point").await;
        assert_eq!(data_to, data_from);
    }

    #[tokio::test]
    async fn line() {
        let open_line_string: LineString<f64> = line_string![(x: 0., y: 0.), (x: 5., y: 0.)];
        let multi = MultiLineString(vec![open_line_string]);
        let data_to = Geometry(geo::Geometry::MultiLineString(multi));
        let data_from = pg_roundtrip(&data_to, "MultiLineString").await;
        assert_eq!(data_to, data_from);
    }

    #[tokio::test]
    async fn polygon() {
        let exterior = LineString::from(vec![(0., 0.), (1., 1.), (1., 0.), (0., 0.)]);
        let hole = LineString::from(vec![(0.1, 0.1), (0.9, 0.9), (0.9, 0.1), (0.1, 0.1)]);
        let data_to = Geometry(geo::Geometry::Polygon(Polygon::new(exterior, vec![hole])));
        let data_from = pg_roundtrip(&data_to, "Polygon").await;
        assert_eq!(data_to, data_from);
    }
}