1use std::collections::{HashMap, HashSet};
2
3use anyhow::Result;
4use arrow::{
5 array::{Array, BinaryArray, RecordBatchReader, StringArray},
6 compute::cast,
7 datatypes::DataType,
8 ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream},
9 json::ArrayWriter,
10};
11use gdal::{
12 ArrowArrayStream, Dataset, cpl::CslStringList, spatial_ref::SpatialRef, vector::LayerAccess,
13};
14use schemars::{JsonSchema, schema_for};
15use serde::Deserialize;
16use url::Url;
17
18use ogcapi_drivers::{CollectionTransactions, postgres::Db};
19use ogcapi_types::{
20 common::{Bbox, Collection, Crs, Exception, Extent, SpatialExtent},
21 processes::{
22 Execute, Format, InlineOrRefData, Input, InputValueNoObject, Output, Process,
23 TransmissionMode,
24 },
25};
26
27use crate::{ProcessResponseBody, Processor};
28
/// OGC API process (`gdal-loader`) that loads one layer of a GDAL-readable
/// source into a collection stored in PostGIS.
#[derive(Clone, Copy, Debug, Default)]
pub struct GdalLoader;
34
/// Inputs of the `gdal-loader` process.
///
/// Doc comments on the fields are picked up by `schemars` and therefore appear
/// in the input schema of the process description.
#[derive(Deserialize, Debug, JsonSchema)]
pub struct GdalLoaderInputs {
    /// GDAL data source to read. Values starting with `http` are read through
    /// `/vsicurl/`, values ending in `zip` through `/vsizip/`.
    pub input: String,

    /// Id of the collection that is (re)created and filled with the features.
    pub collection: String,

    /// Name of the layer to load from the source.
    pub filter: Option<String>,

    /// EPSG code of the source spatial reference. Overrides the layer's own
    /// spatial reference; `4326` is assumed when neither is available.
    pub s_srs: Option<u32>,

    /// PostgreSQL connection URL of the target database.
    pub database_url: String,
}
53
54impl GdalLoaderInputs {
55 pub fn execute_input(&self) -> HashMap<String, Input> {
56 let mut input = HashMap::from_iter([
57 (
58 "input".to_string(),
59 Input::InlineOrRefData(InlineOrRefData::InputValueNoObject(
60 InputValueNoObject::String(self.input.to_owned()),
61 )),
62 ),
63 (
64 "collection".to_string(),
65 Input::InlineOrRefData(InlineOrRefData::InputValueNoObject(
66 InputValueNoObject::String(self.collection.to_owned()),
67 )),
68 ),
69 (
70 "database_url".to_string(),
71 Input::InlineOrRefData(InlineOrRefData::InputValueNoObject(
72 InputValueNoObject::String(self.database_url.to_owned()),
73 )),
74 ),
75 ]);
76
77 if let Some(filter) = &self.filter {
78 input.insert(
79 "filter".to_owned(),
80 Input::InlineOrRefData(InlineOrRefData::InputValueNoObject(
81 InputValueNoObject::String(filter.to_owned()),
82 )),
83 );
84 }
85
86 if let Some(s_srs) = &self.s_srs {
87 input.insert(
88 "s_srs".to_owned(),
89 Input::InlineOrRefData(InlineOrRefData::InputValueNoObject(
90 InputValueNoObject::Integer(*s_srs as i64),
91 )),
92 );
93 }
94
95 input
96 }
97}
98
/// Outputs of the `gdal-loader` process.
#[derive(Clone, Debug, JsonSchema)]
pub struct GdalLoaderOutputs {
    /// Id of the collection the features were loaded into.
    pub collection: String,
}
104
105impl GdalLoaderOutputs {
106 pub fn execute_output() -> HashMap<String, Output> {
107 HashMap::from([(
108 "greeting".to_string(),
109 Output {
110 format: Some(Format {
111 media_type: Some("text/plain".to_string()),
112 encoding: Some("utf8".to_string()),
113 schema: None,
114 }),
115 transmission_mode: TransmissionMode::Value,
116 },
117 )])
118 }
119}
120
121impl TryFrom<ProcessResponseBody> for GdalLoaderOutputs {
122 type Error = Exception;
123
124 fn try_from(value: ProcessResponseBody) -> Result<Self, Self::Error> {
125 if let ProcessResponseBody::Requested(buf) = value {
126 Ok(GdalLoaderOutputs {
127 collection: String::from_utf8(buf).unwrap(),
128 })
129 } else {
130 Err(Exception::new("500"))
131 }
132 }
133}
134
135#[async_trait::async_trait]
136impl Processor for GdalLoader {
137 fn id(&self) -> &'static str {
138 "gdal-loader"
139 }
140
141 fn version(&self) -> &'static str {
142 "0.1.0"
143 }
144
145 fn process(&self) -> Result<Process> {
146 Process::try_new(
147 self.id(),
148 self.version(),
149 &schema_for!(GdalLoaderInputs).schema,
150 &schema_for!(GdalLoaderOutputs).schema,
151 )
152 .map_err(Into::into)
153 }
154
155 async fn execute(&self, execute: Execute) -> Result<ProcessResponseBody> {
156 let value = serde_json::to_value(execute.inputs)?;
158 let mut inputs: GdalLoaderInputs = serde_json::from_value(value)?;
159
160 if inputs.input.to_lowercase().starts_with("http") {
162 inputs.input = format!("/vsicurl/{}", inputs.input);
163 }
164 if inputs.input.to_lowercase().ends_with("zip") {
165 inputs.input = format!("/vsizip/{}", inputs.input);
166 }
167
168 let collection = {
170 let dataset = Dataset::open(&inputs.input)?;
171
172 if dataset.layer_count() >= 1 && inputs.filter.is_none() {
174 inputs.filter = Some(dataset.layer(0)?.name());
175 }
176
177 if inputs.filter.is_none() {
178 return Err(Exception::new(format!(
179 "Found multiple layers! Use the 'filter' option to specifiy one of:\n\t- {}",
180 dataset
181 .layers()
182 .map(|l| l.name())
183 .collect::<Vec<String>>()
184 .join("\n\t- ")
185 ))
186 .into());
187 }
188
189 let layer = dataset.layer_by_name(inputs.filter.as_ref().unwrap())?;
190
191 let spatial_ref_src = match inputs.s_srs {
193 Some(epsg) => SpatialRef::from_epsg(epsg)?,
194 None => match layer.spatial_ref() {
195 Some(srs) => srs,
196 None => {
197 println!("Unknown spatial reference, falling back to `4326`");
198 SpatialRef::from_epsg(4326)?
199 }
200 },
201 };
202
203 let storage_crs = Crs::from_srid(spatial_ref_src.auth_code()?);
204
205 Collection {
207 id: inputs.collection.clone(),
208 crs: Vec::from_iter(HashSet::from([
209 Crs::default(),
210 storage_crs.clone(),
211 Crs::from_epsg(3857),
212 ])),
213 extent: layer.try_get_extent().unwrap().map(|e| Extent {
214 spatial: Some(SpatialExtent {
215 bbox: vec![Bbox::Bbox2D([e.MinX, e.MinY, e.MaxX, e.MaxY])],
216 crs: storage_crs.to_owned(),
217 }),
218 temporal: None,
219 }),
220 storage_crs: Some(storage_crs.to_owned()),
221 ..Default::default()
222 }
223 };
224
225 let db = Db::setup(&Url::parse(&inputs.database_url)?).await?;
227
228 db.delete_collection(&collection.id).await.unwrap();
229 db.create_collection(&collection).await.unwrap();
230
231 if let Some((geometry_type, dimensions)) = {
233 let dataset = Dataset::open(&inputs.input)?;
234 let layer = dataset
235 .layer_by_name(inputs.filter.as_ref().unwrap())
236 .unwrap();
237
238 match layer.defn().geom_fields().next().unwrap().field_type() {
239 0 => {
240 panic!("Unknown gemetry type.")
241 }
242 1 => Some(("POINT", 2)),
243 2 => Some(("LINESTRING", 2)),
244 3 => Some(("POLYGON", 2)),
245 4 => Some(("MULTIPOINT", 2)),
246 5 => Some(("MULTILINESTRING", 2)),
247 6 => Some(("MULTIPOLYGON", 2)),
248 2147483653 => Some(("MULTILINESTRINGZ", 3)),
249 2147483654 => Some(("MULTIPOLYGONZ", 3)),
250 i => {
251 panic!("Unmaped geometry type `{i}`");
252 }
253 }
254 } {
255 sqlx::query("SELECT DropGeometryColumn ('items', $1, 'geom')")
256 .bind(&collection.id)
257 .execute(&db.pool)
258 .await?;
259
260 sqlx::query("SELECT AddGeometryColumn ('items', $1, 'geom', $2, $3, $4)")
261 .bind(&collection.id)
262 .bind(collection.storage_crs.unwrap().as_srid())
263 .bind(geometry_type)
264 .bind(dimensions)
265 .execute(&db.pool)
266 .await?;
267
268 sqlx::query(&format!(
269 r#"CREATE INDEX ON items."{}" USING gist (geom)"#,
270 &collection.id
271 ))
272 .execute(&db.pool)
273 .await?;
274 }
275
276 let dataset = Dataset::open(&inputs.input)?;
280 let mut layer = dataset.layer_by_name(inputs.filter.as_ref().unwrap())?;
281
282 let mut output_stream = FFI_ArrowArrayStream::empty();
284
285 let output_stream_ptr = &mut output_stream as *mut FFI_ArrowArrayStream;
287
288 let gdal_pointer: *mut ArrowArrayStream = output_stream_ptr.cast();
292
293 unsafe { layer.read_arrow_stream(gdal_pointer, &CslStringList::new())? }
295
296 let arrow_stream_reader = ArrowArrayStreamReader::try_new(output_stream)?;
297 let schema = arrow_stream_reader.schema();
298
299 let fid_column_index = schema.column_with_name("OGC_FID").unwrap().0;
301 let mut geom_column_index = schema.column_with_name("wkb_geometry").unwrap().0;
302
303 if fid_column_index < geom_column_index {
305 geom_column_index -= 1;
306 }
307
308 let id = &collection.id;
309 let pool = &db.pool;
310
311 for result in arrow_stream_reader {
312 let mut batch = result?;
313 println!("Got some batch with {} features", batch.num_rows());
314
315 let fid_column = batch.remove_column(fid_column_index);
317 let fid_column = cast(&fid_column, &DataType::Utf8)?;
318 let fid_array = fid_column.as_any().downcast_ref::<StringArray>().unwrap();
319 let mut fid_vec = Vec::with_capacity(fid_array.len());
320 (0..fid_array.len()).for_each(|i| fid_vec.push(fid_array.value(i).to_owned()));
321
322 let geom_column = batch.remove_column(geom_column_index);
324 let geom_array = geom_column.as_any().downcast_ref::<BinaryArray>().unwrap();
325 let mut geom_vec = Vec::with_capacity(geom_array.len());
326 (0..geom_array.len()).for_each(|i| geom_vec.push(geom_array.value(i).to_owned()));
327
328 let buf = Vec::new();
330 let mut writer = ArrayWriter::new(buf);
331 writer.write_batches(&[&batch])?;
332 writer.finish()?;
333
334 let properties = String::from_utf8(writer.into_inner())?;
335
336 tokio::task::block_in_place(move || {
337 tokio::runtime::Handle::current().block_on(
338 sqlx::query(&format!(
339 r#"
340 INSERT INTO items."{}" (id, properties, geom)
341 SELECT * FROM UNNEST(
342 $1::text[],
343 (SELECT
344 array_agg(properties)
345 FROM (
346 SELECT jsonb_array_elements($2::jsonb) AS properties
347 )),
348 $3::bytea[]
349 )
350 "#,
351 id
352 ))
353 .bind(fid_vec)
354 .bind(properties)
355 .bind(geom_vec)
356 .execute(pool),
357 )
358 })?;
359 }
360
361 Ok(ProcessResponseBody::Requested(
362 inputs.collection.as_bytes().to_owned(),
363 ))
364 }
365}
366
#[cfg(test)]
mod tests {
    use ogcapi_types::processes::Execute;

    use crate::{
        Processor,
        gdal_loader::{GdalLoader, GdalLoaderInputs, GdalLoaderOutputs},
    };

    /// End-to-end run of the loader against a local GeoJSON sample.
    ///
    /// Requires `../data/ne_10m_railroads_north_america.geojson` and a PostGIS
    /// database reachable at `localhost:5433`. Runs on a multi-threaded runtime
    /// because the loader uses `tokio::task::block_in_place`.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_loader() {
        let loader = GdalLoader;
        assert_eq!(loader.id(), "gdal-loader");

        let description = loader.process().unwrap();
        println!(
            "Process:\n{}",
            serde_json::to_string_pretty(&description).unwrap()
        );

        let loader_inputs = GdalLoaderInputs {
            input: "../data/ne_10m_railroads_north_america.geojson".to_owned(),
            collection: "streets-gdal".to_string(),
            filter: None,
            s_srs: None,
            database_url: "postgresql://postgres:password@localhost:5433/ogcapi".to_string(),
        };

        let request = Execute {
            inputs: loader_inputs.execute_input(),
            ..Default::default()
        };

        let response = loader.execute(request).await.unwrap();
        let outputs = GdalLoaderOutputs::try_from(response).unwrap();
        assert_eq!(outputs.collection, "streets-gdal");
    }
}