stac_duckdb/
client.rs

1use crate::{Error, Extension, Result};
2use arrow_array::{RecordBatch, RecordBatchIterator};
3use chrono::DateTime;
4use cql2::{Expr, ToDuckSQL};
5use duckdb::{Connection, types::Value};
6use geo::BoundingRect;
7use geojson::Geometry;
8use stac::{Collection, SpatialExtent, TemporalExtent, geoarrow::DATETIME_COLUMNS};
9use stac_api::{Direction, Search};
10use std::ops::{Deref, DerefMut};
11
/// Initial value of [`Client::use_hive_partitioning`] for a client built from a connection.
pub const DEFAULT_USE_HIVE_PARTITIONING: bool = false;

/// Initial value of [`Client::convert_wkb`] for a client built from a connection.
pub const DEFAULT_CONVERT_WKB: bool = true;

/// Description assigned to every collection synthesized by [`Client::collections`].
pub const DEFAULT_COLLECTION_DESCRIPTION: &str =
    "Auto-generated collection from stac-geoparquet extents";

/// Initial value of [`Client::union_by_name`] for a client built from a connection.
pub const DEFAULT_UNION_BY_NAME: bool = true;
24
25/// A client for making DuckDB requests for STAC objects.
26#[derive(Debug)]
27pub struct Client {
28    connection: Connection,
29
30    /// Whether to use hive partitioning
31    pub use_hive_partitioning: bool,
32
33    /// Whether to convert WKB to native geometries.
34    ///
35    /// If False, WKB metadata will be added.
36    pub convert_wkb: bool,
37
38    /// Whether to use `union_by_name` when querying.
39    ///
40    /// Defaults to true.
41    pub union_by_name: bool,
42}
43
44impl Client {
45    /// Creates a new client with an in-memory DuckDB connection.
46    ///
47    /// This function will install the spatial extension. If you'd like to
48    /// manage your own extensions (e.g. if your extensions are stored in a
49    /// different location), set things up then use `connection.into()` to get a
50    /// new `Client`.
51    ///
52    /// # Examples
53    ///
54    /// ```
55    /// use stac_duckdb::Client;
56    ///
57    /// let client = Client::new().unwrap();
58    /// ```
59    pub fn new() -> Result<Client> {
60        let connection = Connection::open_in_memory()?;
61        connection.execute("INSTALL spatial", [])?;
62        connection.execute("LOAD spatial", [])?;
63        connection.execute("INSTALL icu", [])?;
64        connection.execute("LOAD icu", [])?;
65        Ok(connection.into())
66    }
67
68    /// Returns a vector of all extensions.
69    ///
70    /// # Examples
71    ///
72    /// ```
73    /// use stac_duckdb::Client;
74    ///
75    /// let client = Client::new().unwrap();
76    /// let extensions = client.extensions().unwrap();
77    /// ```
78    pub fn extensions(&self) -> Result<Vec<Extension>> {
79        let mut statement = self.prepare(
80            "SELECT extension_name, loaded, installed, install_path, description, extension_version, install_mode, installed_from FROM duckdb_extensions();",
81        )?;
82        let extensions = statement
83            .query_map([], |row| {
84                Ok(Extension {
85                    name: row.get("extension_name")?,
86                    loaded: row.get("loaded")?,
87                    installed: row.get("installed")?,
88                    install_path: row.get("install_path")?,
89                    description: row.get("description")?,
90                    version: row.get("extension_version")?,
91                    install_mode: row.get("install_mode")?,
92                    installed_from: row.get("installed_from")?,
93                })
94            })?
95            .collect::<std::result::Result<Vec<_>, duckdb::Error>>()?;
96        Ok(extensions)
97    }
98
99    /// Returns one or more [stac::Collection] from the items in the stac-geoparquet file.
100    ///
101    /// # Examples
102    ///
103    /// ```
104    /// use stac_duckdb::Client;
105    ///
106    /// let client = Client::new().unwrap();
107    /// let collections = client.collections("data/100-sentinel-2-items.parquet").unwrap();
108    /// ```
109    pub fn collections(&self, href: &str) -> Result<Vec<Collection>> {
110        let start_datetime= if self.prepare(&format!(
111            "SELECT column_name FROM (DESCRIBE SELECT * from {}) where column_name = 'start_datetime'",
112            self.format_parquet_href(href)
113        ))?.query([])?.next()?.is_some() {
114            "strftime(min(coalesce(start_datetime, datetime)), '%xT%X%z')"
115        } else {
116            "strftime(min(datetime), '%xT%X%z')"
117        };
118        let end_datetime = if self
119            .prepare(&format!(
120            "SELECT column_name FROM (DESCRIBE SELECT * from {}) where column_name = 'end_datetime'",
121            self.format_parquet_href(href)
122        ))?
123            .query([])?
124            .next()?
125            .is_some()
126        {
127            "strftime(max(coalesce(end_datetime, datetime)), '%xT%X%z')"
128        } else {
129            "strftime(max(datetime), '%xT%X%z')"
130        };
131        let mut statement = self.prepare(&format!(
132            "SELECT DISTINCT collection FROM {}",
133            self.format_parquet_href(href)
134        ))?;
135        let mut collections = Vec::new();
136        for row in statement.query_map([], |row| row.get::<_, String>(0))? {
137            let collection_id = row?;
138            let mut statement = self.connection.prepare(&
139                format!("SELECT ST_AsGeoJSON(ST_Extent_Agg(geometry)), {}, {} FROM {} WHERE collection = $1", start_datetime, end_datetime,
140                self.format_parquet_href(href)
141            ))?;
142            let row = statement.query_row([&collection_id], |row| {
143                Ok((
144                    row.get::<_, String>(0)?,
145                    row.get::<_, String>(1)?,
146                    row.get::<_, String>(2)?,
147                ))
148            })?;
149            let mut collection = Collection::new(collection_id, DEFAULT_COLLECTION_DESCRIPTION);
150            let geometry: geo::Geometry = Geometry::from_json_value(serde_json::from_str(&row.0)?)
151                .map_err(Box::new)?
152                .try_into()
153                .map_err(Box::new)?;
154            if let Some(bbox) = geometry.bounding_rect() {
155                collection.extent.spatial = SpatialExtent {
156                    bbox: vec![bbox.into()],
157                };
158            }
159            collection.extent.temporal = TemporalExtent {
160                interval: vec![[
161                    Some(DateTime::parse_from_str(&row.1, "%FT%T%#z")?.into()),
162                    Some(DateTime::parse_from_str(&row.2, "%FT%T%#z")?.into()),
163                ]],
164            };
165            collections.push(collection);
166        }
167        Ok(collections)
168    }
169
170    /// Searches a single stac-geoparquet file.
171    ///
172    /// # Examples
173    ///
174    /// ```
175    /// use stac_duckdb::Client;
176    ///
177    /// let client = Client::new().unwrap();
178    /// let item_collection = client.search("data/100-sentinel-2-items.parquet", Default::default()).unwrap();
179    /// ```
180    pub fn search(&self, href: &str, search: Search) -> Result<stac_api::ItemCollection> {
181        let record_batches = self.search_to_arrow(href, search)?;
182        if record_batches.is_empty() {
183            Ok(Default::default())
184        } else {
185            let schema = record_batches[0].schema();
186            let item_collection = stac::geoarrow::json::from_record_batch_reader(
187                RecordBatchIterator::new(record_batches.into_iter().map(Ok), schema),
188            )?;
189            Ok(item_collection.into())
190        }
191    }
192
193    /// Searches to an iterator of record batches.
194    ///
195    /// # Examples
196    ///
197    /// ```
198    /// use stac_duckdb::Client;
199    ///
200    /// let client = Client::new().unwrap();
201    /// let record_batches = client.search_to_arrow("data/100-sentinel-2-items.parquet", Default::default()).unwrap();
202    /// ```
203    pub fn search_to_arrow(&self, href: &str, search: Search) -> Result<Vec<RecordBatch>> {
204        // TODO can we return an iterator?
205        if let Some((sql, params)) = self.build_query(href, search)? {
206            log::debug!("duckdb sql: {sql}");
207            let mut statement = self.prepare(&sql)?;
208            statement
209                .query_arrow(duckdb::params_from_iter(params))?
210                .map(|record_batch| {
211                    let record_batch = if self.convert_wkb {
212                        stac::geoarrow::with_native_geometry(record_batch, "geometry")?
213                    } else {
214                        stac::geoarrow::add_wkb_metadata(record_batch, "geometry")?
215                    };
216                    Ok(record_batch)
217                })
218                .collect::<Result<_>>()
219        } else {
220            Ok(Vec::new())
221        }
222    }
223
224    /// Returns the SQL query string and parameters for this href and search object.
225    ///
226    /// Returns `None` if we can _know_ that the query will return nothing.
227    ///
228    /// # Examples
229    ///
230    /// ```
231    /// use stac_duckdb::Client;
232    ///
233    /// let client = Client::new().unwrap();
234    /// let (sql, params) = client.build_query("data/100-sentinel-2-items.parquet", Default::default()).unwrap().unwrap();
235    /// ```
236    pub fn build_query(&self, href: &str, search: Search) -> Result<Option<(String, Vec<Value>)>> {
237        // Note that we pull out some fields early so we can avoid closing some search strings below.
238
239        if search.items.query.is_some() {
240            return Err(Error::QueryNotImplemented);
241        }
242
243        // Check which columns we'll be selecting
244        let mut statement = self.prepare(&format!(
245            "SELECT column_name FROM (DESCRIBE SELECT * from {})",
246            self.format_parquet_href(href)
247        ))?;
248        let mut has_start_datetime = false;
249        let mut has_end_datetime = false;
250        let mut column_names = Vec::new();
251        let mut columns = Vec::new();
252        for row in statement.query_map([], |row| row.get::<_, String>(0))? {
253            let column = row?;
254            if column == "start_datetime" {
255                has_start_datetime = true;
256            }
257            if column == "end_datetime" {
258                has_end_datetime = true;
259            }
260
261            if let Some(fields) = search.fields.as_ref() {
262                if fields.exclude.contains(&column)
263                    || !(fields.include.is_empty() || fields.include.contains(&column))
264                {
265                    continue;
266                }
267            }
268
269            if column == "geometry" {
270                columns.push("ST_AsWKB(geometry) geometry".to_string());
271            } else if DATETIME_COLUMNS.contains(&column.as_str()) {
272                columns.push(format!("\"{column}\"::TIMESTAMPTZ {column}"))
273            } else {
274                columns.push(format!("\"{column}\""));
275            }
276            column_names.push(column);
277        }
278
279        // Get limit and offset
280        let limit = search.items.limit;
281        let offset = search
282            .items
283            .additional_fields
284            .get("offset")
285            .and_then(|v| v.as_i64());
286
287        // Build order_by
288        let mut order_by = Vec::with_capacity(search.sortby.len());
289        for sortby in &search.sortby {
290            order_by.push(format!(
291                "\"{}\" {}",
292                sortby.field,
293                match sortby.direction {
294                    Direction::Ascending => "ASC",
295                    Direction::Descending => "DESC",
296                }
297            ));
298        }
299
300        // Build wheres and params
301        let mut wheres = Vec::new();
302        let mut params = Vec::new();
303        if !search.ids.is_empty() {
304            wheres.push(format!(
305                "id IN ({})",
306                (0..search.ids.len())
307                    .map(|_| "?")
308                    .collect::<Vec<_>>()
309                    .join(",")
310            ));
311            params.extend(search.ids.into_iter().map(Value::Text));
312        }
313        if let Some(intersects) = search.intersects {
314            wheres.push("ST_Intersects(geometry, ST_GeomFromGeoJSON(?))".to_string());
315            params.push(Value::Text(intersects.to_string()));
316        }
317        if !search.collections.is_empty() {
318            wheres.push(format!(
319                "collection IN ({})",
320                (0..search.collections.len())
321                    .map(|_| "?")
322                    .collect::<Vec<_>>()
323                    .join(",")
324            ));
325            params.extend(search.collections.into_iter().map(Value::Text));
326        }
327        if let Some(bbox) = search.items.bbox {
328            wheres.push("ST_Intersects(geometry, ST_GeomFromGeoJSON(?))".to_string());
329            params.push(Value::Text(bbox.to_geometry().to_string()));
330        }
331        if let Some(datetime) = search.items.datetime {
332            let interval = stac::datetime::parse(&datetime)?;
333            if let Some(start) = interval.0 {
334                wheres.push(format!(
335                    "?::TIMESTAMPTZ <= {}",
336                    if has_start_datetime {
337                        "start_datetime"
338                    } else {
339                        "datetime"
340                    }
341                ));
342                params.push(Value::Text(start.to_rfc3339()));
343            }
344            if let Some(end) = interval.1 {
345                wheres.push(format!(
346                    "?::TIMESTAMPTZ >= {}", // Inclusive, https://github.com/radiantearth/stac-spec/pull/1280
347                    if has_end_datetime {
348                        "end_datetime"
349                    } else {
350                        "datetime"
351                    }
352                ));
353                params.push(Value::Text(end.to_rfc3339()));
354            }
355        }
356        if let Some(filter) = search.items.filter {
357            let expr: Expr = filter.try_into()?;
358            if expr_properties_match(&expr, &column_names) {
359                let sql = expr.to_ducksql().map_err(Box::new)?;
360                wheres.push(sql);
361            } else {
362                return Ok(None);
363            }
364        }
365
366        let mut suffix = String::new();
367        if !wheres.is_empty() {
368            suffix.push_str(&format!(" WHERE {}", wheres.join(" AND ")));
369        }
370        if !order_by.is_empty() {
371            suffix.push_str(&format!(" ORDER BY {}", order_by.join(", ")));
372        }
373        if let Some(limit) = limit {
374            suffix.push_str(&format!(" LIMIT {limit}"));
375        }
376        if let Some(offset) = offset {
377            suffix.push_str(&format!(" OFFSET {offset}"));
378        }
379
380        let sql = format!(
381            "SELECT {} FROM {}{}",
382            columns.join(","),
383            self.format_parquet_href(href),
384            suffix,
385        );
386        Ok(Some((sql, params)))
387    }
388
389    fn format_parquet_href(&self, href: &str) -> String {
390        format!(
391            "read_parquet('{}', filename=true, hive_partitioning={}, union_by_name={})",
392            href,
393            if self.use_hive_partitioning {
394                "true"
395            } else {
396                "false"
397            },
398            if self.union_by_name { "true" } else { "false" }
399        )
400    }
401}
402
403fn expr_properties_match(expr: &Expr, properties: &[String]) -> bool {
404    use Expr::*;
405
406    match expr {
407        Property { property } => properties.contains(property),
408        Float(_) | Literal(_) | Bool(_) | Geometry(_) => true,
409        Operation { args, .. } => args
410            .iter()
411            .all(|expr| expr_properties_match(expr, properties)),
412        Interval { interval } => interval
413            .iter()
414            .all(|expr| expr_properties_match(expr, properties)),
415        Timestamp { timestamp } => expr_properties_match(timestamp, properties),
416        Date { date } => expr_properties_match(date, properties),
417        Array(exprs) => exprs
418            .iter()
419            .all(|expr| expr_properties_match(expr, properties)),
420        BBox { bbox } => bbox
421            .iter()
422            .all(|expr| expr_properties_match(expr, properties)),
423        Null => expr_properties_match(expr, properties),
424    }
425}
426
427impl Deref for Client {
428    type Target = Connection;
429
430    fn deref(&self) -> &Self::Target {
431        &self.connection
432    }
433}
434
435impl DerefMut for Client {
436    fn deref_mut(&mut self) -> &mut Self::Target {
437        &mut self.connection
438    }
439}
440
441impl From<Connection> for Client {
442    fn from(connection: Connection) -> Self {
443        Client {
444            connection,
445            use_hive_partitioning: DEFAULT_USE_HIVE_PARTITIONING,
446            convert_wkb: DEFAULT_CONVERT_WKB,
447            union_by_name: DEFAULT_UNION_BY_NAME,
448        }
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::Client;
    use duckdb::Connection;
    use geo::Geometry;
    use rstest::{fixture, rstest};
    use stac::Bbox;
    use stac_api::{Items, Search, Sortby};
    use stac_validate::Validate;

    /// The stac-geoparquet fixture that most tests search.
    const SENTINEL_2_ITEMS: &str = "data/100-sentinel-2-items.parquet";

    /// A glob that matches every parquet fixture.
    const ALL_PARQUET: &str = "data/*.parquet";

    /// Searches the sentinel-2 fixture, panicking if the search fails.
    fn search_items(client: &Client, search: Search) -> stac_api::ItemCollection {
        client.search(SENTINEL_2_ITEMS, search).unwrap()
    }

    /// Builds a search that only sets a CQL2 text filter.
    fn filter_search(filter: &str) -> Search {
        Search {
            items: Items {
                filter: Some(filter.parse().unwrap()),
                ..Default::default()
            },
            ..Default::default()
        }
    }

    /// Installs the extensions once for the whole test run.
    #[fixture]
    #[once]
    fn install_extensions() {
        let connection = Connection::open_in_memory().unwrap();
        for extension in ["icu", "spatial"] {
            connection
                .execute(&format!("INSTALL {extension}"), [])
                .unwrap();
        }
    }

    // The parameter is never read; it is only there so that the once-fixture
    // has run before a client is created.
    #[allow(unused_variables)]
    #[fixture]
    fn client(install_extensions: ()) -> Client {
        Client::new().unwrap()
    }

    #[rstest]
    fn extensions(client: Client) {
        let _ = client.extensions().unwrap();
    }

    #[rstest]
    #[tokio::test]
    async fn search(client: Client) {
        let item_collection = search_items(&client, Search::default());
        assert_eq!(item_collection.items.len(), 100);
        item_collection.items[0].validate().await.unwrap();
    }

    #[rstest]
    fn search_to_arrow(client: Client) {
        let record_batches = client
            .search_to_arrow(SENTINEL_2_ITEMS, Search::default())
            .unwrap();
        assert_eq!(record_batches.len(), 1);
    }

    #[rstest]
    fn search_ids(client: Client) {
        let search = Search::default().ids(vec![
            "S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429".to_string(),
        ]);
        let item_collection = search_items(&client, search);
        assert_eq!(item_collection.items.len(), 1);
        assert_eq!(
            item_collection.items[0]["id"],
            "S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429"
        );
    }

    #[rstest]
    fn search_intersects(client: Client) {
        let point = Geometry::Point(geo::point! { x: -106., y: 40.5 });
        let item_collection = search_items(&client, Search::default().intersects(&point));
        assert_eq!(item_collection.items.len(), 50);
    }

    #[rstest]
    fn search_collections(client: Client) {
        let known = Search::default().collections(vec!["sentinel-2-l2a".to_string()]);
        assert_eq!(search_items(&client, known).items.len(), 100);

        let unknown = Search::default().collections(vec!["foobar".to_string()]);
        assert_eq!(search_items(&client, unknown).items.len(), 0);
    }

    #[rstest]
    fn search_bbox(client: Client) {
        let search = Search::default().bbox(Bbox::new(-106.1, 40.5, -106.0, 40.6));
        let item_collection = search_items(&client, search);
        assert_eq!(item_collection.items.len(), 50);
    }

    #[rstest]
    fn search_datetime(client: Client) {
        let after = Search::default().datetime("2024-12-02T00:00:00Z/..");
        assert_eq!(search_items(&client, after).items.len(), 1);

        let before = Search::default().datetime("../2024-12-02T00:00:00Z");
        assert_eq!(search_items(&client, before).items.len(), 99);
    }

    #[rstest]
    fn search_datetime_empty_interval(client: Client) {
        let search = Search::default().datetime("2024-12-02T00:00:00Z/");
        let item_collection = search_items(&client, search);
        assert_eq!(item_collection.items.len(), 1);
    }

    #[rstest]
    fn search_limit(client: Client) {
        let item_collection = search_items(&client, Search::default().limit(42));
        assert_eq!(item_collection.items.len(), 42);
    }

    #[rstest]
    fn search_offset(client: Client) {
        let mut search = Search::default().limit(1);
        search
            .items
            .additional_fields
            .insert("offset".to_string(), 1.into());
        let item_collection = search_items(&client, search);
        assert_eq!(
            item_collection.items[0]["id"],
            "S2A_MSIL2A_20241201T175721_R141_T13TDE_20241201T213150"
        );
    }

    #[rstest]
    fn search_sortby(client: Client) {
        let ascending = Search::default()
            .sortby(vec![Sortby::asc("datetime")])
            .limit(1);
        let item_collection = search_items(&client, ascending);
        assert_eq!(
            item_collection.items[0]["id"],
            "S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429"
        );

        let descending = Search::default()
            .sortby(vec![Sortby::desc("datetime")])
            .limit(1);
        let item_collection = search_items(&client, descending);
        assert_eq!(
            item_collection.items[0]["id"],
            "S2B_MSIL2A_20241203T174629_R098_T13TDE_20241203T211406"
        );
    }

    #[rstest]
    fn search_fields(client: Client) {
        let search = Search::default().fields("+id".parse().unwrap()).limit(1);
        let item_collection = search_items(&client, search);
        assert_eq!(item_collection.items[0].len(), 1);
    }

    #[rstest]
    fn collections(client: Client) {
        let collections = client.collections(SENTINEL_2_ITEMS).unwrap();
        assert_eq!(collections.len(), 1);
    }

    #[rstest]
    fn no_convert_wkb(mut client: Client) {
        client.convert_wkb = false;
        let record_batches = client
            .search_to_arrow(SENTINEL_2_ITEMS, Search::default())
            .unwrap();
        let schema = record_batches[0].schema();
        assert_eq!(
            schema.field_with_name("geometry").unwrap().metadata()["ARROW:extension:name"],
            "geoarrow.wkb"
        );
    }

    #[rstest]
    fn filter(client: Client) {
        let item_collection = search_items(&client, filter_search("sat:relative_orbit = 98"));
        assert_eq!(item_collection.items.len(), 49);
    }

    #[rstest]
    fn filter_no_column(client: Client) {
        let item_collection = search_items(&client, filter_search("foo:bar = 42"));
        assert_eq!(item_collection.items.len(), 0);
    }

    #[rstest]
    fn sortby_property(client: Client) {
        let search = Search {
            items: Items {
                sortby: vec!["eo:cloud_cover".parse().unwrap()],
                ..Default::default()
            },
            ..Default::default()
        };
        let item_collection = search_items(&client, search);
        assert_eq!(item_collection.items.len(), 100);
    }

    #[rstest]
    fn union_by_name(client: Client) {
        let _ = client.search(ALL_PARQUET, Default::default()).unwrap();
    }

    #[rstest]
    fn no_union_by_name(mut client: Client) {
        client.union_by_name = false;
        let _ = client
            .search(ALL_PARQUET, Default::default())
            .unwrap_err();
    }
}