Skip to main content

ogcapi_processes/
gdal_loader.rs

1use std::collections::{HashMap, HashSet};
2
3use anyhow::Result;
4use arrow::{
5    array::{Array, BinaryArray, RecordBatchReader, StringArray},
6    compute::cast,
7    datatypes::DataType,
8    ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream},
9    json::ArrayWriter,
10};
11use gdal::{
12    ArrowArrayStream, Dataset, cpl::CslStringList, spatial_ref::SpatialRef, vector::LayerAccess,
13};
14use schemars::{JsonSchema, schema_for};
15use serde::Deserialize;
16use url::Url;
17
18use ogcapi_drivers::{CollectionTransactions, postgres::Db};
19use ogcapi_types::{
20    common::{Bbox, Collection, Crs, Exception, Extent, SpatialExtent},
21    processes::{
22        Execute, Format, InlineOrRefData, Input, InputValueNoObject, Output, Process,
23        TransmissionMode,
24    },
25};
26
27use crate::{ProcessResponseBody, Processor};
28
/// GDAL loader `Processor`
///
/// Process to load vector data.
///
/// The type carries no state; every run is configured through the `Execute` request alone.
#[derive(Clone, Debug, Default)]
pub struct GdalLoader;
34
35/// Inputs for the `gdal-loader` process
36#[derive(Deserialize, Debug, JsonSchema)]
37pub struct GdalLoaderInputs {
38    /// Input file
39    pub input: String,
40
41    /// Set the collection name, defaults to layer name or `osm`
42    pub collection: String,
43
44    /// Filter input by layer name, imports all if not present
45    pub filter: Option<String>,
46
47    /// Source srs, if omitted tries to derive from the input layer
48    pub s_srs: Option<u32>,
49
50    /// Postgres database url
51    pub database_url: String,
52}
53
54impl GdalLoaderInputs {
55    pub fn execute_input(&self) -> HashMap<String, Input> {
56        let mut input = HashMap::from_iter([
57            (
58                "input".to_string(),
59                Input::InlineOrRefData(InlineOrRefData::InputValueNoObject(
60                    InputValueNoObject::String(self.input.to_owned()),
61                )),
62            ),
63            (
64                "collection".to_string(),
65                Input::InlineOrRefData(InlineOrRefData::InputValueNoObject(
66                    InputValueNoObject::String(self.collection.to_owned()),
67                )),
68            ),
69            (
70                "database_url".to_string(),
71                Input::InlineOrRefData(InlineOrRefData::InputValueNoObject(
72                    InputValueNoObject::String(self.database_url.to_owned()),
73                )),
74            ),
75        ]);
76
77        if let Some(filter) = &self.filter {
78            input.insert(
79                "filter".to_owned(),
80                Input::InlineOrRefData(InlineOrRefData::InputValueNoObject(
81                    InputValueNoObject::String(filter.to_owned()),
82                )),
83            );
84        }
85
86        if let Some(s_srs) = &self.s_srs {
87            input.insert(
88                "s_srs".to_owned(),
89                Input::InlineOrRefData(InlineOrRefData::InputValueNoObject(
90                    InputValueNoObject::Integer(*s_srs as i64),
91                )),
92            );
93        }
94
95        input
96    }
97}
98
99/// Outputs for the `gdal-loader` process
100#[derive(Clone, Debug, JsonSchema)]
101pub struct GdalLoaderOutputs {
102    pub collection: String,
103}
104
105impl GdalLoaderOutputs {
106    pub fn execute_output() -> HashMap<String, Output> {
107        HashMap::from([(
108            "greeting".to_string(),
109            Output {
110                format: Some(Format {
111                    media_type: Some("text/plain".to_string()),
112                    encoding: Some("utf8".to_string()),
113                    schema: None,
114                }),
115                transmission_mode: TransmissionMode::Value,
116            },
117        )])
118    }
119}
120
121impl TryFrom<ProcessResponseBody> for GdalLoaderOutputs {
122    type Error = Exception;
123
124    fn try_from(value: ProcessResponseBody) -> Result<Self, Self::Error> {
125        if let ProcessResponseBody::Requested(buf) = value {
126            Ok(GdalLoaderOutputs {
127                collection: String::from_utf8(buf).unwrap(),
128            })
129        } else {
130            Err(Exception::new("500"))
131        }
132    }
133}
134
135#[async_trait::async_trait]
136impl Processor for GdalLoader {
137    fn id(&self) -> &'static str {
138        "gdal-loader"
139    }
140
141    fn version(&self) -> &'static str {
142        "0.1.0"
143    }
144
145    fn process(&self) -> Result<Process> {
146        Process::try_new(
147            self.id(),
148            self.version(),
149            &schema_for!(GdalLoaderInputs).schema,
150            &schema_for!(GdalLoaderOutputs).schema,
151        )
152        .map_err(Into::into)
153    }
154
155    async fn execute(&self, execute: Execute) -> Result<ProcessResponseBody> {
156        // Parse input
157        let value = serde_json::to_value(execute.inputs)?;
158        let mut inputs: GdalLoaderInputs = serde_json::from_value(value)?;
159
160        // Handle http & zip
161        if inputs.input.to_lowercase().starts_with("http") {
162            inputs.input = format!("/vsicurl/{}", inputs.input);
163        }
164        if inputs.input.to_lowercase().ends_with("zip") {
165            inputs.input = format!("/vsizip/{}", inputs.input);
166        }
167
168        // Get collection
169        let collection = {
170            let dataset = Dataset::open(&inputs.input)?;
171
172            // Get layer
173            if dataset.layer_count() >= 1 && inputs.filter.is_none() {
174                inputs.filter = Some(dataset.layer(0)?.name());
175            }
176
177            if inputs.filter.is_none() {
178                return Err(Exception::new(format!(
179                    "Found multiple layers! Use the 'filter' option to specifiy one of:\n\t- {}",
180                    dataset
181                        .layers()
182                        .map(|l| l.name())
183                        .collect::<Vec<String>>()
184                        .join("\n\t- ")
185                ))
186                .into());
187            }
188
189            let layer = dataset.layer_by_name(inputs.filter.as_ref().unwrap())?;
190
191            // Get coordinate reference system
192            let spatial_ref_src = match inputs.s_srs {
193                Some(epsg) => SpatialRef::from_epsg(epsg)?,
194                None => match layer.spatial_ref() {
195                    Some(srs) => srs,
196                    None => {
197                        println!("Unknown spatial reference, falling back to `4326`");
198                        SpatialRef::from_epsg(4326)?
199                    }
200                },
201            };
202
203            let storage_crs = Crs::from_srid(spatial_ref_src.auth_code()?);
204
205            // Create collection (overwrite/delete existing)
206            Collection {
207                id: inputs.collection.clone(),
208                crs: Vec::from_iter(HashSet::from([
209                    Crs::default(),
210                    storage_crs.clone(),
211                    Crs::from_epsg(3857),
212                ])),
213                extent: layer.try_get_extent().unwrap().map(|e| Extent {
214                    spatial: Some(SpatialExtent {
215                        bbox: vec![Bbox::Bbox2D([e.MinX, e.MinY, e.MaxX, e.MaxY])],
216                        crs: storage_crs.to_owned(),
217                    }),
218                    temporal: None,
219                }),
220                storage_crs: Some(storage_crs.to_owned()),
221                ..Default::default()
222            }
223        };
224
225        // Setup driver
226        let db = Db::setup(&Url::parse(&inputs.database_url)?).await?;
227
228        db.delete_collection(&collection.id).await.unwrap();
229        db.create_collection(&collection).await.unwrap();
230
231        // Set concrete geometry type if possible https://github.com/georust/gdal/blob/00adecc94361228a2197224205fc9260d14d7549/gdal-sys/prebuilt-bindings/gdal_3.4.rs#L3454
232        if let Some((geometry_type, dimensions)) = {
233            let dataset = Dataset::open(&inputs.input)?;
234            let layer = dataset
235                .layer_by_name(inputs.filter.as_ref().unwrap())
236                .unwrap();
237
238            match layer.defn().geom_fields().next().unwrap().field_type() {
239                0 => {
240                    panic!("Unknown gemetry type.")
241                }
242                1 => Some(("POINT", 2)),
243                2 => Some(("LINESTRING", 2)),
244                3 => Some(("POLYGON", 2)),
245                4 => Some(("MULTIPOINT", 2)),
246                5 => Some(("MULTILINESTRING", 2)),
247                6 => Some(("MULTIPOLYGON", 2)),
248                2147483653 => Some(("MULTILINESTRINGZ", 3)),
249                2147483654 => Some(("MULTIPOLYGONZ", 3)),
250                i => {
251                    panic!("Unmaped geometry type `{i}`");
252                }
253            }
254        } {
255            sqlx::query("SELECT DropGeometryColumn ('items', $1, 'geom')")
256                .bind(&collection.id)
257                .execute(&db.pool)
258                .await?;
259
260            sqlx::query("SELECT AddGeometryColumn ('items', $1, 'geom', $2, $3, $4)")
261                .bind(&collection.id)
262                .bind(collection.storage_crs.unwrap().as_srid())
263                .bind(geometry_type)
264                .bind(dimensions)
265                .execute(&db.pool)
266                .await?;
267
268            sqlx::query(&format!(
269                r#"CREATE INDEX ON items."{}" USING gist (geom)"#,
270                &collection.id
271            ))
272            .execute(&db.pool)
273            .await?;
274        }
275
276        // Load features
277        // let _count = layer.lock().unwrap().feature_count();
278
279        let dataset = Dataset::open(&inputs.input)?;
280        let mut layer = dataset.layer_by_name(inputs.filter.as_ref().unwrap())?;
281
282        // Instantiate an `ArrowArrayStream` for OGR to write into
283        let mut output_stream = FFI_ArrowArrayStream::empty();
284
285        // Take a pointer to it
286        let output_stream_ptr = &mut output_stream as *mut FFI_ArrowArrayStream;
287
288        // GDAL includes its own copy of the ArrowArrayStream struct definition. These are guaranteed
289        // to be the same across implementations, but we need to manually cast between the two for Rust
290        // to allow it.
291        let gdal_pointer: *mut ArrowArrayStream = output_stream_ptr.cast();
292
293        // Read the layer's data into our provisioned pointer
294        unsafe { layer.read_arrow_stream(gdal_pointer, &CslStringList::new())? }
295
296        let arrow_stream_reader = ArrowArrayStreamReader::try_new(output_stream)?;
297        let schema = arrow_stream_reader.schema();
298
299        // Get the index of the fid and geom column
300        let fid_column_index = schema.column_with_name("OGC_FID").unwrap().0;
301        let mut geom_column_index = schema.column_with_name("wkb_geometry").unwrap().0;
302
303        // adjust for later column removal
304        if fid_column_index < geom_column_index {
305            geom_column_index -= 1;
306        }
307
308        let id = &collection.id;
309        let pool = &db.pool;
310
311        for result in arrow_stream_reader {
312            let mut batch = result?;
313            println!("Got some batch with {} features", batch.num_rows());
314
315            // Get the id column
316            let fid_column = batch.remove_column(fid_column_index);
317            let fid_column = cast(&fid_column, &DataType::Utf8)?;
318            let fid_array = fid_column.as_any().downcast_ref::<StringArray>().unwrap();
319            let mut fid_vec = Vec::with_capacity(fid_array.len());
320            (0..fid_array.len()).for_each(|i| fid_vec.push(fid_array.value(i).to_owned()));
321
322            // Get the geometry column
323            let geom_column = batch.remove_column(geom_column_index);
324            let geom_array = geom_column.as_any().downcast_ref::<BinaryArray>().unwrap();
325            let mut geom_vec = Vec::with_capacity(geom_array.len());
326            (0..geom_array.len()).for_each(|i| geom_vec.push(geom_array.value(i).to_owned()));
327
328            // Get the properties
329            let buf = Vec::new();
330            let mut writer = ArrayWriter::new(buf);
331            writer.write_batches(&[&batch])?;
332            writer.finish()?;
333
334            let properties = String::from_utf8(writer.into_inner())?;
335
336            tokio::task::block_in_place(move || {
337                tokio::runtime::Handle::current().block_on(
338                    sqlx::query(&format!(
339                        r#"
340                INSERT INTO items."{}" (id, properties, geom)
341                SELECT * FROM UNNEST(
342                    $1::text[],
343                    (SELECT
344                        array_agg(properties)
345                    FROM (
346                        SELECT jsonb_array_elements($2::jsonb) AS properties
347                    )),
348                    $3::bytea[]
349                )
350                "#,
351                        id
352                    ))
353                    .bind(fid_vec)
354                    .bind(properties)
355                    .bind(geom_vec)
356                    .execute(pool),
357                )
358            })?;
359        }
360
361        Ok(ProcessResponseBody::Requested(
362            inputs.collection.as_bytes().to_owned(),
363        ))
364    }
365}
366
#[cfg(test)]
mod tests {
    use ogcapi_types::processes::Execute;

    use crate::{
        Processor,
        gdal_loader::{GdalLoader, GdalLoaderInputs, GdalLoaderOutputs},
    };

    /// End-to-end run of the loader against a local GeoJSON file and the Postgres instance
    /// named in `database_url`.
    ///
    /// Uses the multi-threaded runtime flavor because `execute` calls
    /// `tokio::task::block_in_place`.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_loader() {
        let processor = GdalLoader;
        assert_eq!(processor.id(), "gdal-loader");

        // Print the process description for manual inspection
        let description = serde_json::to_string_pretty(&processor.process().unwrap()).unwrap();
        println!("Process:\n{}", description);

        let request = Execute {
            inputs: GdalLoaderInputs {
                input: "../data/ne_10m_railroads_north_america.geojson".to_owned(),
                collection: "streets-gdal".to_string(),
                filter: None,
                s_srs: None,
                database_url: "postgresql://postgres:password@localhost:5433/ogcapi".to_string(),
            }
            .execute_input(),
            ..Default::default()
        };

        let response = processor.execute(request).await.unwrap();
        let output = GdalLoaderOutputs::try_from(response).unwrap();
        assert_eq!(output.collection, "streets-gdal");
    }
}