stac_duckdb/
client.rs

1use crate::{Error, Extension, Result};
2use arrow_array::{RecordBatch, RecordBatchIterator};
3use chrono::DateTime;
4use cql2::{Expr, ToDuckSQL};
5use duckdb::{Connection, types::Value};
6use geo::BoundingRect;
7use geojson::Geometry;
8use stac::{Collection, SpatialExtent, TemporalExtent, geoarrow::DATETIME_COLUMNS};
9use stac_api::{Direction, Search};
10use std::ops::{Deref, DerefMut};
11
/// The default value for [`Client::use_hive_partitioning`].
pub const DEFAULT_USE_HIVE_PARTITIONING: bool = false;

/// The default value for [`Client::convert_wkb`].
pub const DEFAULT_CONVERT_WKB: bool = true;

/// The description given to every collection generated by
/// [`Client::collections`].
pub const DEFAULT_COLLECTION_DESCRIPTION: &str =
    "Auto-generated collection from stac-geoparquet extents";

/// The default value for [`Client::union_by_name`].
pub const DEFAULT_UNION_BY_NAME: bool = true;
24
/// A client for making DuckDB requests for STAC objects.
///
/// The client wraps a DuckDB [`Connection`] and dereferences to it, so any
/// connection method can be called directly on the client.
#[derive(Debug)]
pub struct Client {
    // The wrapped DuckDB connection that every query is executed on.
    connection: Connection,

    /// Whether to use hive partitioning.
    ///
    /// The value is forwarded as the `hive_partitioning` argument of
    /// `read_parquet`.
    pub use_hive_partitioning: bool,

    /// Whether to convert WKB to native geometries.
    ///
    /// If `false`, the geometry column is left as WKB and WKB metadata is
    /// added to it instead.
    pub convert_wkb: bool,

    /// Whether to use `union_by_name` when querying.
    ///
    /// The value is forwarded as the `union_by_name` argument of
    /// `read_parquet`. Defaults to true.
    pub union_by_name: bool,
}
43
44impl Client {
45    /// Creates a new client with an in-memory DuckDB connection.
46    ///
47    /// This function will install the spatial extension. If you'd like to
48    /// manage your own extensions (e.g. if your extensions are stored in a
49    /// different location), set things up then use `connection.into()` to get a
50    /// new `Client`.
51    ///
52    /// # Examples
53    ///
54    /// ```
55    /// use stac_duckdb::Client;
56    ///
57    /// let client = Client::new().unwrap();
58    /// ```
59    pub fn new() -> Result<Client> {
60        let connection = Connection::open_in_memory()?;
61        connection.execute("INSTALL spatial", [])?;
62        connection.execute("LOAD spatial", [])?;
63        connection.execute("INSTALL icu", [])?;
64        connection.execute("LOAD icu", [])?;
65        Ok(connection.into())
66    }
67
68    /// Returns a vector of all extensions.
69    ///
70    /// # Examples
71    ///
72    /// ```
73    /// use stac_duckdb::Client;
74    ///
75    /// let client = Client::new().unwrap();
76    /// let extensions = client.extensions().unwrap();
77    /// ```
78    pub fn extensions(&self) -> Result<Vec<Extension>> {
79        let mut statement = self.prepare(
80            "SELECT extension_name, loaded, installed, install_path, description, extension_version, install_mode, installed_from FROM duckdb_extensions();",
81        )?;
82        let extensions = statement
83            .query_map([], |row| {
84                Ok(Extension {
85                    name: row.get("extension_name")?,
86                    loaded: row.get("loaded")?,
87                    installed: row.get("installed")?,
88                    install_path: row.get("install_path")?,
89                    description: row.get("description")?,
90                    version: row.get("extension_version")?,
91                    install_mode: row.get("install_mode")?,
92                    installed_from: row.get("installed_from")?,
93                })
94            })?
95            .collect::<std::result::Result<Vec<_>, duckdb::Error>>()?;
96        Ok(extensions)
97    }
98
99    /// Returns one or more [stac::Collection] from the items in the stac-geoparquet file.
100    ///
101    /// # Examples
102    ///
103    /// ```
104    /// use stac_duckdb::Client;
105    ///
106    /// let client = Client::new().unwrap();
107    /// let collections = client.collections("data/100-sentinel-2-items.parquet").unwrap();
108    /// ```
109    pub fn collections(&self, href: &str) -> Result<Vec<Collection>> {
110        let start_datetime= if self.prepare(&format!(
111            "SELECT column_name FROM (DESCRIBE SELECT * from {}) where column_name = 'start_datetime'",
112            self.format_parquet_href(href)
113        ))?.query([])?.next()?.is_some() {
114            "strftime(min(coalesce(start_datetime, datetime)), '%xT%X%z')"
115        } else {
116            "strftime(min(datetime), '%xT%X%z')"
117        };
118        let end_datetime = if self
119            .prepare(&format!(
120            "SELECT column_name FROM (DESCRIBE SELECT * from {}) where column_name = 'end_datetime'",
121            self.format_parquet_href(href)
122        ))?
123            .query([])?
124            .next()?
125            .is_some()
126        {
127            "strftime(max(coalesce(end_datetime, datetime)), '%xT%X%z')"
128        } else {
129            "strftime(max(datetime), '%xT%X%z')"
130        };
131        let mut statement = self.prepare(&format!(
132            "SELECT DISTINCT collection FROM {}",
133            self.format_parquet_href(href)
134        ))?;
135        let mut collections = Vec::new();
136        for row in statement.query_map([], |row| row.get::<_, String>(0))? {
137            let collection_id = row?;
138            let mut statement = self.connection.prepare(&
139                format!("SELECT ST_AsGeoJSON(ST_Extent_Agg(geometry)), {}, {} FROM {} WHERE collection = $1", start_datetime, end_datetime,
140                self.format_parquet_href(href)
141            ))?;
142            let row = statement.query_row([&collection_id], |row| {
143                Ok((
144                    row.get::<_, String>(0)?,
145                    row.get::<_, String>(1)?,
146                    row.get::<_, String>(2)?,
147                ))
148            })?;
149            let mut collection = Collection::new(collection_id, DEFAULT_COLLECTION_DESCRIPTION);
150            let geometry: geo::Geometry = Geometry::from_json_value(serde_json::from_str(&row.0)?)
151                .map_err(Box::new)?
152                .try_into()
153                .map_err(Box::new)?;
154            if let Some(bbox) = geometry.bounding_rect() {
155                collection.extent.spatial = SpatialExtent {
156                    bbox: vec![bbox.into()],
157                };
158            }
159            collection.extent.temporal = TemporalExtent {
160                interval: vec![[
161                    Some(DateTime::parse_from_str(&row.1, "%FT%T%#z")?.into()),
162                    Some(DateTime::parse_from_str(&row.2, "%FT%T%#z")?.into()),
163                ]],
164            };
165            collections.push(collection);
166        }
167        Ok(collections)
168    }
169
170    /// Searches a single stac-geoparquet file.
171    ///
172    /// # Examples
173    ///
174    /// ```
175    /// use stac_duckdb::Client;
176    ///
177    /// let client = Client::new().unwrap();
178    /// let item_collection = client.search("data/100-sentinel-2-items.parquet", Default::default()).unwrap();
179    /// ```
180    pub fn search(&self, href: &str, search: Search) -> Result<stac_api::ItemCollection> {
181        let record_batches = self.search_to_arrow(href, search)?;
182        if record_batches.is_empty() {
183            Ok(Default::default())
184        } else {
185            let schema = record_batches[0].schema();
186            let item_collection = stac::geoarrow::json::from_record_batch_reader(
187                RecordBatchIterator::new(record_batches.into_iter().map(Ok), schema),
188            )?;
189            Ok(item_collection.into())
190        }
191    }
192
193    /// Searches to an iterator of record batches.
194    ///
195    /// # Examples
196    ///
197    /// ```
198    /// use stac_duckdb::Client;
199    ///
200    /// let client = Client::new().unwrap();
201    /// let record_batches = client.search_to_arrow("data/100-sentinel-2-items.parquet", Default::default()).unwrap();
202    /// ```
203    pub fn search_to_arrow(&self, href: &str, search: Search) -> Result<Vec<RecordBatch>> {
204        // TODO can we return an iterator?
205
206        // Note that we pull out some fields early so we can avoid closing some search strings below.
207
208        if search.items.query.is_some() {
209            return Err(Error::QueryNotImplemented);
210        }
211
212        // Check which columns we'll be selecting
213        let mut statement = self.prepare(&format!(
214            "SELECT column_name FROM (DESCRIBE SELECT * from {})",
215            self.format_parquet_href(href)
216        ))?;
217        let mut has_start_datetime = false;
218        let mut has_end_datetime = false;
219        let mut column_names = Vec::new();
220        let mut columns = Vec::new();
221        for row in statement.query_map([], |row| row.get::<_, String>(0))? {
222            let column = row?;
223            if column == "start_datetime" {
224                has_start_datetime = true;
225            }
226            if column == "end_datetime" {
227                has_end_datetime = true;
228            }
229
230            if let Some(fields) = search.fields.as_ref() {
231                if fields.exclude.contains(&column)
232                    || !(fields.include.is_empty() || fields.include.contains(&column))
233                {
234                    continue;
235                }
236            }
237
238            if column == "geometry" {
239                columns.push("ST_AsWKB(geometry) geometry".to_string());
240            } else if DATETIME_COLUMNS.contains(&column.as_str()) {
241                columns.push(format!("\"{column}\"::TIMESTAMPTZ {column}"))
242            } else {
243                columns.push(format!("\"{column}\""));
244            }
245            column_names.push(column);
246        }
247
248        // Get limit and offset
249        let limit = search.items.limit;
250        let offset = search
251            .items
252            .additional_fields
253            .get("offset")
254            .and_then(|v| v.as_i64());
255
256        // Build order_by
257        let mut order_by = Vec::with_capacity(search.sortby.len());
258        for sortby in &search.sortby {
259            order_by.push(format!(
260                "\"{}\" {}",
261                sortby.field,
262                match sortby.direction {
263                    Direction::Ascending => "ASC",
264                    Direction::Descending => "DESC",
265                }
266            ));
267        }
268
269        // Build wheres and params
270        let mut wheres = Vec::new();
271        let mut params = Vec::new();
272        if !search.ids.is_empty() {
273            wheres.push(format!(
274                "id IN ({})",
275                (0..search.ids.len())
276                    .map(|_| "?")
277                    .collect::<Vec<_>>()
278                    .join(",")
279            ));
280            params.extend(search.ids.into_iter().map(Value::Text));
281        }
282        if let Some(intersects) = search.intersects {
283            wheres.push("ST_Intersects(geometry, ST_GeomFromGeoJSON(?))".to_string());
284            params.push(Value::Text(intersects.to_string()));
285        }
286        if !search.collections.is_empty() {
287            wheres.push(format!(
288                "collection IN ({})",
289                (0..search.collections.len())
290                    .map(|_| "?")
291                    .collect::<Vec<_>>()
292                    .join(",")
293            ));
294            params.extend(search.collections.into_iter().map(Value::Text));
295        }
296        if let Some(bbox) = search.items.bbox {
297            wheres.push("ST_Intersects(geometry, ST_GeomFromGeoJSON(?))".to_string());
298            params.push(Value::Text(bbox.to_geometry().to_string()));
299        }
300        if let Some(datetime) = search.items.datetime {
301            let interval = stac::datetime::parse(&datetime)?;
302            if let Some(start) = interval.0 {
303                wheres.push(format!(
304                    "?::TIMESTAMPTZ <= {}",
305                    if has_start_datetime {
306                        "start_datetime"
307                    } else {
308                        "datetime"
309                    }
310                ));
311                params.push(Value::Text(start.to_rfc3339()));
312            }
313            if let Some(end) = interval.1 {
314                wheres.push(format!(
315                    "?::TIMESTAMPTZ >= {}", // Inclusive, https://github.com/radiantearth/stac-spec/pull/1280
316                    if has_end_datetime {
317                        "end_datetime"
318                    } else {
319                        "datetime"
320                    }
321                ));
322                params.push(Value::Text(end.to_rfc3339()));
323            }
324        }
325        if let Some(filter) = search.items.filter {
326            let expr: Expr = filter.try_into()?;
327            if expr_properties_match(&expr, &column_names) {
328                let sql = expr.to_ducksql().map_err(Box::new)?;
329                wheres.push(sql);
330            } else {
331                return Ok(Vec::new());
332            }
333        }
334
335        let mut suffix = String::new();
336        if !wheres.is_empty() {
337            suffix.push_str(&format!(" WHERE {}", wheres.join(" AND ")));
338        }
339        if !order_by.is_empty() {
340            suffix.push_str(&format!(" ORDER BY {}", order_by.join(", ")));
341        }
342        if let Some(limit) = limit {
343            suffix.push_str(&format!(" LIMIT {limit}"));
344        }
345        if let Some(offset) = offset {
346            suffix.push_str(&format!(" OFFSET {offset}"));
347        }
348
349        let sql = format!(
350            "SELECT {} FROM {}{}",
351            columns.join(","),
352            self.format_parquet_href(href),
353            suffix,
354        );
355        log::debug!("duckdb sql: {sql}");
356        let mut statement = self.prepare(&sql)?;
357        statement
358            .query_arrow(duckdb::params_from_iter(params))?
359            .map(|record_batch| {
360                let record_batch = if self.convert_wkb {
361                    stac::geoarrow::with_native_geometry(record_batch, "geometry")?
362                } else {
363                    stac::geoarrow::add_wkb_metadata(record_batch, "geometry")?
364                };
365                Ok(record_batch)
366            })
367            .collect::<Result<_>>()
368    }
369
370    fn format_parquet_href(&self, href: &str) -> String {
371        format!(
372            "read_parquet('{}', filename=true, hive_partitioning={}, union_by_name={})",
373            href,
374            if self.use_hive_partitioning {
375                "true"
376            } else {
377                "false"
378            },
379            if self.union_by_name { "true" } else { "false" }
380        )
381    }
382}
383
384fn expr_properties_match(expr: &Expr, properties: &[String]) -> bool {
385    use Expr::*;
386
387    match expr {
388        Property { property } => properties.contains(property),
389        Float(_) | Literal(_) | Bool(_) | Geometry(_) => true,
390        Operation { args, .. } => args
391            .iter()
392            .all(|expr| expr_properties_match(expr, properties)),
393        Interval { interval } => interval
394            .iter()
395            .all(|expr| expr_properties_match(expr, properties)),
396        Timestamp { timestamp } => expr_properties_match(timestamp, properties),
397        Date { date } => expr_properties_match(date, properties),
398        Array(exprs) => exprs
399            .iter()
400            .all(|expr| expr_properties_match(expr, properties)),
401        BBox { bbox } => bbox
402            .iter()
403            .all(|expr| expr_properties_match(expr, properties)),
404        Null => expr_properties_match(expr, properties),
405    }
406}
407
408impl Deref for Client {
409    type Target = Connection;
410
411    fn deref(&self) -> &Self::Target {
412        &self.connection
413    }
414}
415
416impl DerefMut for Client {
417    fn deref_mut(&mut self) -> &mut Self::Target {
418        &mut self.connection
419    }
420}
421
422impl From<Connection> for Client {
423    fn from(connection: Connection) -> Self {
424        Client {
425            connection,
426            use_hive_partitioning: DEFAULT_USE_HIVE_PARTITIONING,
427            convert_wkb: DEFAULT_CONVERT_WKB,
428            union_by_name: DEFAULT_UNION_BY_NAME,
429        }
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::Client;
    use duckdb::Connection;
    use geo::Geometry;
    use rstest::{fixture, rstest};
    use stac::Bbox;
    use stac_api::{Items, Search, Sortby};
    use stac_validate::Validate;

    /// Installs, once per test run, every extension that `Client::new`
    /// installs.
    ///
    /// NOTE(review): presumably this exists so that tests running in parallel
    /// do not all try to install the extensions at the same time.
    #[fixture]
    #[once]
    fn install_extensions() {
        let connection = Connection::open_in_memory().unwrap();
        connection.execute("INSTALL spatial", []).unwrap();
        connection.execute("INSTALL icu", []).unwrap();
    }

    /// A fresh client, created after the extensions have been installed.
    // The fixture argument is only requested for its side effect.
    #[allow(unused_variables)]
    #[fixture]
    fn client(install_extensions: ()) -> Client {
        Client::new().unwrap()
    }

    // The fixture argument is only requested for its side effect.
    #[allow(unused_variables)]
    #[rstest]
    fn new(install_extensions: ()) {
        Client::new().unwrap();
    }

    #[rstest]
    fn extensions(client: Client) {
        let _ = client.extensions().unwrap();
    }

    #[rstest]
    #[tokio::test]
    async fn search(client: Client) {
        let item_collection = client
            .search("data/100-sentinel-2-items.parquet", Search::default())
            .unwrap();
        assert_eq!(item_collection.items.len(), 100);
        item_collection.items[0].validate().await.unwrap();
    }

    #[rstest]
    fn search_to_arrow(client: Client) {
        let record_batches = client
            .search_to_arrow("data/100-sentinel-2-items.parquet", Search::default())
            .unwrap();
        assert_eq!(record_batches.len(), 1);
    }

    #[rstest]
    fn search_ids(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().ids(vec![
                    "S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429".to_string(),
                ]),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 1);
        assert_eq!(
            item_collection.items[0]["id"],
            "S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429"
        );
    }

    #[rstest]
    fn search_intersects(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().intersects(&Geometry::Point(geo::point! { x: -106., y: 40.5 })),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 50);
    }

    #[rstest]
    fn search_collections(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().collections(vec!["sentinel-2-l2a".to_string()]),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 100);

        // An unknown collection matches nothing.
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().collections(vec!["foobar".to_string()]),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 0);
    }

    #[rstest]
    fn search_bbox(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().bbox(Bbox::new(-106.1, 40.5, -106.0, 40.6)),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 50);
    }

    #[rstest]
    fn search_datetime(client: Client) {
        // Open-ended on the right.
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().datetime("2024-12-02T00:00:00Z/.."),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 1);
        // Open-ended on the left.
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().datetime("../2024-12-02T00:00:00Z"),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 99);
    }

    #[rstest]
    fn search_datetime_empty_interval(client: Client) {
        // An empty end is treated like an open end.
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().datetime("2024-12-02T00:00:00Z/"),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 1);
    }

    #[rstest]
    fn search_limit(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().limit(42),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 42);
    }

    #[rstest]
    fn search_offset(client: Client) {
        // The offset is passed through the additional fields.
        let mut search = Search::default().limit(1);
        search
            .items
            .additional_fields
            .insert("offset".to_string(), 1.into());
        let item_collection = client
            .search("data/100-sentinel-2-items.parquet", search)
            .unwrap();
        assert_eq!(
            item_collection.items[0]["id"],
            "S2A_MSIL2A_20241201T175721_R141_T13TDE_20241201T213150"
        );
    }

    #[rstest]
    fn search_sortby(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default()
                    .sortby(vec![Sortby::asc("datetime")])
                    .limit(1),
            )
            .unwrap();
        assert_eq!(
            item_collection.items[0]["id"],
            "S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429"
        );

        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default()
                    .sortby(vec![Sortby::desc("datetime")])
                    .limit(1),
            )
            .unwrap();
        assert_eq!(
            item_collection.items[0]["id"],
            "S2B_MSIL2A_20241203T174629_R098_T13TDE_20241203T211406"
        );
    }

    #[rstest]
    fn search_fields(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().fields("+id".parse().unwrap()).limit(1),
            )
            .unwrap();
        assert_eq!(item_collection.items[0].len(), 1);
    }

    #[rstest]
    fn collections(client: Client) {
        let collections = client
            .collections("data/100-sentinel-2-items.parquet")
            .unwrap();
        assert_eq!(collections.len(), 1);
    }

    #[rstest]
    fn no_convert_wkb(mut client: Client) {
        client.convert_wkb = false;
        let record_batches = client
            .search_to_arrow("data/100-sentinel-2-items.parquet", Search::default())
            .unwrap();
        let schema = record_batches[0].schema();
        assert_eq!(
            schema.field_with_name("geometry").unwrap().metadata()["ARROW:extension:name"],
            "geoarrow.wkb"
        );
    }

    #[rstest]
    fn filter(client: Client) {
        let search = Search {
            items: Items {
                filter: Some("sat:relative_orbit = 98".parse().unwrap()),
                ..Default::default()
            },
            ..Default::default()
        };
        let item_collection = client
            .search("data/100-sentinel-2-items.parquet", search)
            .unwrap();
        assert_eq!(item_collection.items.len(), 49);
    }

    #[rstest]
    fn filter_no_column(client: Client) {
        // Filtering on a column the file does not have matches nothing.
        let search = Search {
            items: Items {
                filter: Some("foo:bar = 42".parse().unwrap()),
                ..Default::default()
            },
            ..Default::default()
        };
        let item_collection = client
            .search("data/100-sentinel-2-items.parquet", search)
            .unwrap();
        assert_eq!(item_collection.items.len(), 0);
    }

    #[rstest]
    fn sortby_property(client: Client) {
        let search = Search {
            items: Items {
                sortby: vec!["eo:cloud_cover".parse().unwrap()],
                ..Default::default()
            },
            ..Default::default()
        };
        let item_collection = client
            .search("data/100-sentinel-2-items.parquet", search)
            .unwrap();
        assert_eq!(item_collection.items.len(), 100);
    }

    #[rstest]
    fn union_by_name(client: Client) {
        let _ = client.search("data/*.parquet", Default::default()).unwrap();
    }

    #[rstest]
    fn no_union_by_name(mut client: Client) {
        // The glob search is expected to fail without `union_by_name`.
        client.union_by_name = false;
        let _ = client
            .search("data/*.parquet", Default::default())
            .unwrap_err();
    }
}