1use crate::{Error, Extension, Result};
2use arrow_array::{RecordBatch, RecordBatchIterator};
3use arrow_schema::{ArrowError, SchemaRef};
4use chrono::DateTime;
5use cql2::{Expr, ToDuckSQL};
6use duckdb::{Connection, Statement, types::Value};
7use geo::BoundingRect;
8use geojson::Geometry;
9use stac::api::{Direction, Search};
10use stac::{Collection, SpatialExtent, TemporalExtent, geoarrow::DATETIME_COLUMNS};
11use std::ops::{Deref, DerefMut};
12
/// Initial value of [`Client::use_hive_partitioning`] for a newly created client.
pub const DEFAULT_USE_HIVE_PARTITIONING: bool = false;

/// Initial value of [`Client::convert_wkb`] for a newly created client.
pub const DEFAULT_CONVERT_WKB: bool = true;

/// Description assigned to every collection produced by [`Client::collections`].
pub const DEFAULT_COLLECTION_DESCRIPTION: &str =
    "Auto-generated collection from stac-geoparquet extents";

/// Initial value of [`Client::union_by_name`] for a newly created client.
pub const DEFAULT_UNION_BY_NAME: bool = true;

/// Initial value of [`Client::remove_filename_column`] for a newly created client.
pub const DEFAULT_REMOVE_FILENAME_COLUMN: bool = true;
28
29#[derive(Debug)]
31pub struct Client {
32 connection: Connection,
33
34 pub use_hive_partitioning: bool,
36
37 pub convert_wkb: bool,
41
42 pub union_by_name: bool,
46
47 pub remove_filename_column: bool,
51}
52
53impl Client {
54 pub fn new() -> Result<Client> {
69 let connection = Connection::open_in_memory()?;
70 connection.execute("INSTALL spatial", [])?;
71 connection.execute("LOAD spatial", [])?;
72 connection.execute("INSTALL icu", [])?;
73 connection.execute("LOAD icu", [])?;
74 Ok(connection.into())
75 }
76
77 pub fn extensions(&self) -> Result<Vec<Extension>> {
88 let mut statement = self.prepare(
89 "SELECT extension_name, loaded, installed, install_path, description, extension_version, install_mode, installed_from FROM duckdb_extensions();",
90 )?;
91 let extensions = statement
92 .query_map([], |row| {
93 Ok(Extension {
94 name: row.get("extension_name")?,
95 loaded: row.get("loaded")?,
96 installed: row.get("installed")?,
97 install_path: row.get("install_path")?,
98 description: row.get("description")?,
99 version: row.get("extension_version")?,
100 install_mode: row.get("install_mode")?,
101 installed_from: row.get("installed_from")?,
102 })
103 })?
104 .collect::<std::result::Result<Vec<_>, duckdb::Error>>()?;
105 Ok(extensions)
106 }
107
108 pub fn collections(&self, href: &str) -> Result<Vec<Collection>> {
119 let start_datetime= if self.prepare(&format!(
120 "SELECT column_name FROM (DESCRIBE SELECT * from {}) where column_name = 'start_datetime'",
121 self.format_parquet_href(href)
122 ))?.query([])?.next()?.is_some() {
123 "strftime(min(coalesce(start_datetime, datetime)), '%xT%X%z')"
124 } else {
125 "strftime(min(datetime), '%xT%X%z')"
126 };
127 let end_datetime = if self
128 .prepare(&format!(
129 "SELECT column_name FROM (DESCRIBE SELECT * from {}) where column_name = 'end_datetime'",
130 self.format_parquet_href(href)
131 ))?
132 .query([])?
133 .next()?
134 .is_some()
135 {
136 "strftime(max(coalesce(end_datetime, datetime)), '%xT%X%z')"
137 } else {
138 "strftime(max(datetime), '%xT%X%z')"
139 };
140 let mut statement = self.prepare(&format!(
141 "SELECT DISTINCT collection FROM {}",
142 self.format_parquet_href(href)
143 ))?;
144 let mut collections = Vec::new();
145 for row in statement.query_map([], |row| row.get::<_, String>(0))? {
146 let collection_id = row?;
147 let mut statement = self.connection.prepare(&
148 format!("SELECT ST_AsGeoJSON(ST_Extent_Agg(geometry)), {}, {} FROM {} WHERE collection = $1", start_datetime, end_datetime,
149 self.format_parquet_href(href)
150 ))?;
151 let row = statement.query_row([&collection_id], |row| {
152 Ok((
153 row.get::<_, String>(0)?,
154 row.get::<_, String>(1)?,
155 row.get::<_, String>(2)?,
156 ))
157 })?;
158 let mut collection = Collection::new(collection_id, DEFAULT_COLLECTION_DESCRIPTION);
159 let geometry: geo::Geometry = Geometry::from_json_value(serde_json::from_str(&row.0)?)
160 .map_err(Box::new)?
161 .try_into()
162 .map_err(Box::new)?;
163 if let Some(bbox) = geometry.bounding_rect() {
164 collection.extent.spatial = SpatialExtent {
165 bbox: vec![bbox.into()],
166 };
167 }
168 collection.extent.temporal = TemporalExtent {
169 interval: vec![[
170 Some(DateTime::parse_from_str(&row.1, "%FT%T%#z")?.into()),
171 Some(DateTime::parse_from_str(&row.2, "%FT%T%#z")?.into()),
172 ]],
173 };
174 collections.push(collection);
175 }
176 Ok(collections)
177 }
178
179 pub fn search(&self, href: &str, search: Search) -> Result<stac::api::ItemCollection> {
190 let mut arrow_iter = self.search_to_arrow(href, search)?;
191 let Some(schema) = arrow_iter.schema() else {
192 return Ok(Default::default());
193 };
194
195 let first_batch = match arrow_iter.next() {
196 Some(batch) => batch?,
197 None => return Ok(Default::default()),
198 };
199
200 let batches = std::iter::once(Ok(first_batch))
201 .chain(arrow_iter)
202 .map(|batch| batch.map_err(|err| ArrowError::ExternalError(Box::new(err))));
203
204 let item_collection = stac::geoarrow::json::from_record_batch_reader(
205 RecordBatchIterator::new(batches, schema),
206 )?;
207 Ok(item_collection.into())
208 }
209
210 pub fn search_to_arrow<'conn>(
229 &'conn self,
230 href: &str,
231 search: Search,
232 ) -> Result<SearchArrowBatchIter<'conn>> {
233 if let Some((sql, params)) = self.build_query(href, search)? {
234 log::debug!("duckdb sql: {sql}");
235 let mut statement = self.prepare(&sql)?;
236 statement.execute(duckdb::params_from_iter(params))?;
237 Ok(SearchArrowBatchIter::new(
238 statement,
239 self.convert_wkb,
240 self.remove_filename_column,
241 ))
242 } else {
243 Ok(SearchArrowBatchIter::empty(
244 self.convert_wkb,
245 self.remove_filename_column,
246 ))
247 }
248 }
249
250 pub fn build_query(&self, href: &str, search: Search) -> Result<Option<(String, Vec<Value>)>> {
263 if search.items.query.is_some() {
266 return Err(Error::QueryNotImplemented);
267 }
268
269 let mut statement = self.prepare(&format!(
271 "SELECT column_name FROM (DESCRIBE SELECT * from {})",
272 self.format_parquet_href(href)
273 ))?;
274 let mut has_start_datetime = false;
275 let mut has_end_datetime = false;
276 let mut column_names = Vec::new();
277 let mut columns = Vec::new();
278 for row in statement.query_map([], |row| row.get::<_, String>(0))? {
279 let column = row?;
280 if column == "start_datetime" {
281 has_start_datetime = true;
282 }
283 if column == "end_datetime" {
284 has_end_datetime = true;
285 }
286
287 if let Some(fields) = search.fields.as_ref() {
288 if fields.exclude.contains(&column)
289 || !(fields.include.is_empty() || fields.include.contains(&column))
290 {
291 continue;
292 }
293 }
294
295 if column == "geometry" {
296 columns.push("ST_AsWKB(geometry) geometry".to_string());
297 } else if DATETIME_COLUMNS.contains(&column.as_str()) {
298 columns.push(format!("\"{column}\"::TIMESTAMPTZ {column}"))
299 } else {
300 columns.push(format!("\"{column}\""));
301 }
302 column_names.push(column);
303 }
304
305 let limit = search.items.limit;
307 let offset = search
308 .items
309 .additional_fields
310 .get("offset")
311 .and_then(|v| v.as_i64());
312
313 let mut order_by = Vec::with_capacity(search.sortby.len());
315 for sortby in &search.sortby {
316 order_by.push(format!(
317 "\"{}\" {}",
318 sortby.field,
319 match sortby.direction {
320 Direction::Ascending => "ASC",
321 Direction::Descending => "DESC",
322 }
323 ));
324 }
325
326 let mut wheres = Vec::new();
328 let mut params = Vec::new();
329 if !search.ids.is_empty() {
330 wheres.push(format!(
331 "id IN ({})",
332 (0..search.ids.len())
333 .map(|_| "?")
334 .collect::<Vec<_>>()
335 .join(",")
336 ));
337 params.extend(search.ids.into_iter().map(Value::Text));
338 }
339 if let Some(intersects) = search.intersects {
340 wheres.push("ST_Intersects(geometry, ST_GeomFromGeoJSON(?))".to_string());
341 params.push(Value::Text(intersects.to_string()));
342 }
343 if !search.collections.is_empty() {
344 wheres.push(format!(
345 "collection IN ({})",
346 (0..search.collections.len())
347 .map(|_| "?")
348 .collect::<Vec<_>>()
349 .join(",")
350 ));
351 params.extend(search.collections.into_iter().map(Value::Text));
352 }
353 if let Some(bbox) = search.items.bbox {
354 wheres.push("ST_Intersects(geometry, ST_GeomFromGeoJSON(?))".to_string());
355 params.push(Value::Text(bbox.to_geometry().to_string()));
356 }
357 if let Some(datetime) = search.items.datetime {
358 let interval = stac::datetime::parse(&datetime)?;
359 if let Some(start) = interval.0 {
360 wheres.push(format!(
361 "?::TIMESTAMPTZ <= {}",
362 if has_start_datetime {
363 "start_datetime"
364 } else {
365 "datetime"
366 }
367 ));
368 params.push(Value::Text(start.to_rfc3339()));
369 }
370 if let Some(end) = interval.1 {
371 wheres.push(format!(
372 "?::TIMESTAMPTZ >= {}", if has_end_datetime {
374 "end_datetime"
375 } else {
376 "datetime"
377 }
378 ));
379 params.push(Value::Text(end.to_rfc3339()));
380 }
381 }
382 if let Some(filter) = search.items.filter {
383 let expr: Expr = filter.try_into()?;
384 if expr_properties_match(&expr, &column_names) {
385 let sql = expr.to_ducksql().map_err(Box::new)?;
386 wheres.push(sql);
387 } else {
388 return Ok(None);
389 }
390 }
391
392 let mut suffix = String::new();
393 if !wheres.is_empty() {
394 suffix.push_str(&format!(" WHERE {}", wheres.join(" AND ")));
395 }
396 if !order_by.is_empty() {
397 suffix.push_str(&format!(" ORDER BY {}", order_by.join(", ")));
398 }
399 if let Some(limit) = limit {
400 suffix.push_str(&format!(" LIMIT {limit}"));
401 }
402 if let Some(offset) = offset {
403 suffix.push_str(&format!(" OFFSET {offset}"));
404 }
405
406 let sql = format!(
407 "SELECT {} FROM {}{}",
408 columns.join(","),
409 self.format_parquet_href(href),
410 suffix,
411 );
412 Ok(Some((sql, params)))
413 }
414
415 fn format_parquet_href(&self, href: &str) -> String {
416 format!(
417 "read_parquet('{}', hive_partitioning={}, union_by_name={})",
418 href,
419 if self.use_hive_partitioning {
420 "true"
421 } else {
422 "false"
423 },
424 if self.union_by_name { "true" } else { "false" }
425 )
426 }
427}
428
429fn expr_properties_match(expr: &Expr, properties: &[String]) -> bool {
430 use Expr::*;
431
432 match expr {
433 Property { property } => properties.contains(property),
434 Float(_) | Literal(_) | Bool(_) | Geometry(_) => true,
435 Operation { args, .. } => args
436 .iter()
437 .all(|expr| expr_properties_match(expr, properties)),
438 Interval { interval } => interval
439 .iter()
440 .all(|expr| expr_properties_match(expr, properties)),
441 Timestamp { timestamp } => expr_properties_match(timestamp, properties),
442 Date { date } => expr_properties_match(date, properties),
443 Array(exprs) => exprs
444 .iter()
445 .all(|expr| expr_properties_match(expr, properties)),
446 BBox { bbox } => bbox
447 .iter()
448 .all(|expr| expr_properties_match(expr, properties)),
449 Null => expr_properties_match(expr, properties),
450 }
451}
452
453impl Deref for Client {
454 type Target = Connection;
455
456 fn deref(&self) -> &Self::Target {
457 &self.connection
458 }
459}
460
461impl DerefMut for Client {
462 fn deref_mut(&mut self) -> &mut Self::Target {
463 &mut self.connection
464 }
465}
466
467impl From<Connection> for Client {
468 fn from(connection: Connection) -> Self {
469 Client {
470 connection,
471 use_hive_partitioning: DEFAULT_USE_HIVE_PARTITIONING,
472 convert_wkb: DEFAULT_CONVERT_WKB,
473 union_by_name: DEFAULT_UNION_BY_NAME,
474 remove_filename_column: DEFAULT_REMOVE_FILENAME_COLUMN,
475 }
476 }
477}
478
479pub struct SearchArrowBatchIter<'conn> {
481 statement: Option<Statement<'conn>>,
482 convert_wkb: bool,
483 remove_filename_column: bool,
484 schema: Option<SchemaRef>,
485}
486
487impl<'conn> SearchArrowBatchIter<'conn> {
488 fn new(statement: Statement<'conn>, convert_wkb: bool, remove_filename_column: bool) -> Self {
489 let schema = Some(statement.schema());
490 Self {
491 statement: Some(statement),
492 convert_wkb,
493 remove_filename_column,
494 schema,
495 }
496 }
497
498 fn empty(convert_wkb: bool, remove_filename_column: bool) -> Self {
499 Self {
500 statement: None,
501 convert_wkb,
502 remove_filename_column,
503 schema: None,
504 }
505 }
506
507 pub fn schema(&self) -> Option<SchemaRef> {
508 self.schema.clone()
509 }
510
511 fn finalize_batch(&self, record_batch: RecordBatch) -> Result<RecordBatch> {
512 let mut record_batch = if self.convert_wkb {
513 stac::geoarrow::with_native_geometry(record_batch, "geometry")?
514 } else {
515 stac::geoarrow::add_wkb_metadata(record_batch, "geometry")?
516 };
517 if self.remove_filename_column {
518 record_batch = remove_column(record_batch, "filename");
519 }
520 Ok(record_batch)
521 }
522}
523
524impl<'conn> Iterator for SearchArrowBatchIter<'conn> {
525 type Item = Result<RecordBatch>;
526
527 fn next(&mut self) -> Option<Self::Item> {
528 let statement = self.statement.as_ref()?;
529
530 match statement.step() {
531 Some(struct_array) => {
532 let record_batch = RecordBatch::from(&struct_array);
533 match self.finalize_batch(record_batch) {
534 Ok(batch) => Some(Ok(batch)),
535 Err(err) => {
536 self.statement = None;
537 Some(Err(err))
538 }
539 }
540 }
541 None => {
542 self.statement = None;
543 None
544 }
545 }
546 }
547}
548
549fn remove_column(mut record_batch: RecordBatch, name: &str) -> RecordBatch {
550 if let Some((index, _)) = record_batch.schema().column_with_name(name) {
551 record_batch.remove_column(index);
552 }
553 record_batch
554}
555
#[cfg(test)]
mod tests {
    use super::Client;
    use duckdb::Connection;
    use geo::Geometry;
    use rstest::{fixture, rstest};
    use stac::Bbox;
    use stac::api::{Items, Search, Sortby};
    use stac_validate::Validate;

    /// The stac-geoparquet fixture that most tests query.
    const ITEMS_HREF: &str = "data/100-sentinel-2-items.parquet";

    /// Glob covering every parquet fixture in the data directory.
    const ALL_HREF: &str = "data/*.parquet";

    /// Installs the required extensions a single time for the whole test run.
    #[fixture]
    #[once]
    fn install_extensions() {
        let connection = Connection::open_in_memory().unwrap();
        for sql in ["INSTALL icu", "INSTALL spatial"] {
            connection.execute(sql, []).unwrap();
        }
    }

    /// Provides a fresh client once the extensions have been installed.
    #[allow(unused_variables)]
    #[fixture]
    fn client(install_extensions: ()) -> Client {
        Client::new().unwrap()
    }

    #[rstest]
    fn extensions(client: Client) {
        let _ = client.extensions().unwrap();
    }

    #[rstest]
    #[tokio::test]
    async fn search(client: Client) {
        let found = client.search(ITEMS_HREF, Search::default()).unwrap();
        assert_eq!(found.items.len(), 100);
        found.items[0].validate().await.unwrap();
    }

    #[rstest]
    fn search_to_arrow(client: Client) {
        let batches = client
            .search_to_arrow(ITEMS_HREF, Search::default())
            .unwrap()
            .collect::<std::result::Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(batches.len(), 1);
    }

    #[rstest]
    fn search_ids(client: Client) {
        let search = Search::default().ids(vec![
            "S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429".to_string(),
        ]);
        let found = client.search(ITEMS_HREF, search).unwrap();
        assert_eq!(found.items.len(), 1);
        assert_eq!(
            found.items[0]["id"],
            "S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429"
        );
    }

    #[rstest]
    fn search_intersects(client: Client) {
        let point = Geometry::Point(geo::point! { x: -106., y: 40.5 });
        let found = client
            .search(ITEMS_HREF, Search::default().intersects(&point))
            .unwrap();
        assert_eq!(found.items.len(), 50);
    }

    #[rstest]
    fn search_collections(client: Client) {
        let known = Search::default().collections(vec!["sentinel-2-l2a".to_string()]);
        let found = client.search(ITEMS_HREF, known).unwrap();
        assert_eq!(found.items.len(), 100);

        let unknown = Search::default().collections(vec!["foobar".to_string()]);
        let found = client.search(ITEMS_HREF, unknown).unwrap();
        assert_eq!(found.items.len(), 0);
    }

    #[rstest]
    fn search_bbox(client: Client) {
        let search = Search::default().bbox(Bbox::new(-106.1, 40.5, -106.0, 40.6));
        let found = client.search(ITEMS_HREF, search).unwrap();
        assert_eq!(found.items.len(), 50);
    }

    #[rstest]
    fn search_datetime(client: Client) {
        let after = Search::default().datetime("2024-12-02T00:00:00Z/..");
        let found = client.search(ITEMS_HREF, after).unwrap();
        assert_eq!(found.items.len(), 1);

        let before = Search::default().datetime("../2024-12-02T00:00:00Z");
        let found = client.search(ITEMS_HREF, before).unwrap();
        assert_eq!(found.items.len(), 99);
    }

    #[rstest]
    fn search_datetime_empty_interval(client: Client) {
        let search = Search::default().datetime("2024-12-02T00:00:00Z/");
        let found = client.search(ITEMS_HREF, search).unwrap();
        assert_eq!(found.items.len(), 1);
    }

    #[rstest]
    fn search_limit(client: Client) {
        let found = client
            .search(ITEMS_HREF, Search::default().limit(42))
            .unwrap();
        assert_eq!(found.items.len(), 42);
    }

    #[rstest]
    fn search_offset(client: Client) {
        let mut search = Search::default().limit(1);
        search
            .items
            .additional_fields
            .insert("offset".to_string(), 1.into());
        let found = client.search(ITEMS_HREF, search).unwrap();
        assert_eq!(
            found.items[0]["id"],
            "S2A_MSIL2A_20241201T175721_R141_T13TDE_20241201T213150"
        );
    }

    #[rstest]
    fn search_sortby(client: Client) {
        let ascending = Search::default()
            .sortby(vec![Sortby::asc("datetime")])
            .limit(1);
        let found = client.search(ITEMS_HREF, ascending).unwrap();
        assert_eq!(
            found.items[0]["id"],
            "S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429"
        );

        let descending = Search::default()
            .sortby(vec![Sortby::desc("datetime")])
            .limit(1);
        let found = client.search(ITEMS_HREF, descending).unwrap();
        assert_eq!(
            found.items[0]["id"],
            "S2B_MSIL2A_20241203T174629_R098_T13TDE_20241203T211406"
        );
    }

    #[rstest]
    fn search_fields(client: Client) {
        let search = Search::default().fields("+id".parse().unwrap()).limit(1);
        let found = client.search(ITEMS_HREF, search).unwrap();
        assert_eq!(found.items[0].len(), 1);
    }

    #[rstest]
    fn collections(client: Client) {
        let collections = client.collections(ITEMS_HREF).unwrap();
        assert_eq!(collections.len(), 1);
    }

    #[rstest]
    fn no_convert_wkb(mut client: Client) {
        client.convert_wkb = false;
        let batches = client
            .search_to_arrow(ITEMS_HREF, Search::default())
            .unwrap()
            .collect::<std::result::Result<Vec<_>, _>>()
            .unwrap();
        let schema = batches[0].schema();
        let geometry_field = schema.field_with_name("geometry").unwrap();
        assert_eq!(
            geometry_field.metadata()["ARROW:extension:name"],
            "geoarrow.wkb"
        );
    }

    #[rstest]
    fn filter(client: Client) {
        let items = Items {
            filter: Some("sat:relative_orbit = 98".parse().unwrap()),
            ..Default::default()
        };
        let search = Search {
            items,
            ..Default::default()
        };
        let found = client.search(ITEMS_HREF, search).unwrap();
        assert_eq!(found.items.len(), 49);
    }

    #[rstest]
    fn filter_no_column(client: Client) {
        let items = Items {
            filter: Some("foo:bar = 42".parse().unwrap()),
            ..Default::default()
        };
        let search = Search {
            items,
            ..Default::default()
        };
        let found = client.search(ITEMS_HREF, search).unwrap();
        assert_eq!(found.items.len(), 0);
    }

    #[rstest]
    fn sortby_property(client: Client) {
        let items = Items {
            sortby: vec!["eo:cloud_cover".parse().unwrap()],
            ..Default::default()
        };
        let search = Search {
            items,
            ..Default::default()
        };
        let found = client.search(ITEMS_HREF, search).unwrap();
        assert_eq!(found.items.len(), 100);
    }

    #[rstest]
    fn union_by_name(client: Client) {
        let _ = client.search(ALL_HREF, Default::default()).unwrap();
    }

    #[rstest]
    fn no_union_by_name(mut client: Client) {
        client.union_by_name = false;
        let _ = client.search(ALL_HREF, Default::default()).unwrap_err();
    }

    #[rstest]
    fn remove_filename_column(client: Client) {
        let found = client.search(ITEMS_HREF, Default::default()).unwrap();
        for item in found.items {
            let properties = item["properties"].as_object().unwrap();
            assert!(!properties.contains_key("filename"));
        }
    }
}