1use crate::{Error, Extension, Result};
2use arrow_array::{RecordBatch, RecordBatchIterator};
3use chrono::DateTime;
4use cql2::{Expr, ToDuckSQL};
5use duckdb::{Connection, types::Value};
6use geo::BoundingRect;
7use geojson::Geometry;
8use stac::{Collection, SpatialExtent, TemporalExtent, geoarrow::DATETIME_COLUMNS};
9use stac_api::{Direction, Search};
10use std::ops::{Deref, DerefMut};
11
/// Default for [`Client::use_hive_partitioning`]: `read_parquet` is called with
/// `hive_partitioning=false`.
pub const DEFAULT_USE_HIVE_PARTITIONING: bool = false;

/// Default for [`Client::union_by_name`]: `read_parquet` is called with
/// `union_by_name=true`, so files with differing schemas can be read together.
pub const DEFAULT_UNION_BY_NAME: bool = true;

/// Default for [`Client::convert_wkb`]: WKB geometries in search results are
/// converted to native geoarrow geometries.
pub const DEFAULT_CONVERT_WKB: bool = true;

/// Description assigned to every collection produced by [`Client::collections`].
pub const DEFAULT_COLLECTION_DESCRIPTION: &str =
    "Auto-generated collection from stac-geoparquet extents";
24
/// A client for querying stac-geoparquet data with DuckDB.
///
/// The underlying DuckDB [`Connection`] is reachable through `Deref`/`DerefMut`,
/// so arbitrary SQL can also be run directly on a `Client`.
#[derive(Debug)]
pub struct Client {
    /// The DuckDB connection all queries are executed on.
    connection: Connection,

    /// Passed as the `hive_partitioning` option of DuckDB's `read_parquet`.
    pub use_hive_partitioning: bool,

    /// When `true`, the `geometry` column of search results is converted from
    /// WKB to native geoarrow geometry; when `false`, it is left as WKB and only
    /// tagged with geoarrow WKB extension metadata.
    pub convert_wkb: bool,

    /// Passed as the `union_by_name` option of DuckDB's `read_parquet`.
    pub union_by_name: bool,
}
43
44impl Client {
45 pub fn new() -> Result<Client> {
60 let connection = Connection::open_in_memory()?;
61 connection.execute("INSTALL spatial", [])?;
62 connection.execute("LOAD spatial", [])?;
63 connection.execute("INSTALL icu", [])?;
64 connection.execute("LOAD icu", [])?;
65 Ok(connection.into())
66 }
67
68 pub fn extensions(&self) -> Result<Vec<Extension>> {
79 let mut statement = self.prepare(
80 "SELECT extension_name, loaded, installed, install_path, description, extension_version, install_mode, installed_from FROM duckdb_extensions();",
81 )?;
82 let extensions = statement
83 .query_map([], |row| {
84 Ok(Extension {
85 name: row.get("extension_name")?,
86 loaded: row.get("loaded")?,
87 installed: row.get("installed")?,
88 install_path: row.get("install_path")?,
89 description: row.get("description")?,
90 version: row.get("extension_version")?,
91 install_mode: row.get("install_mode")?,
92 installed_from: row.get("installed_from")?,
93 })
94 })?
95 .collect::<std::result::Result<Vec<_>, duckdb::Error>>()?;
96 Ok(extensions)
97 }
98
99 pub fn collections(&self, href: &str) -> Result<Vec<Collection>> {
110 let start_datetime= if self.prepare(&format!(
111 "SELECT column_name FROM (DESCRIBE SELECT * from {}) where column_name = 'start_datetime'",
112 self.format_parquet_href(href)
113 ))?.query([])?.next()?.is_some() {
114 "strftime(min(coalesce(start_datetime, datetime)), '%xT%X%z')"
115 } else {
116 "strftime(min(datetime), '%xT%X%z')"
117 };
118 let end_datetime = if self
119 .prepare(&format!(
120 "SELECT column_name FROM (DESCRIBE SELECT * from {}) where column_name = 'end_datetime'",
121 self.format_parquet_href(href)
122 ))?
123 .query([])?
124 .next()?
125 .is_some()
126 {
127 "strftime(max(coalesce(end_datetime, datetime)), '%xT%X%z')"
128 } else {
129 "strftime(max(datetime), '%xT%X%z')"
130 };
131 let mut statement = self.prepare(&format!(
132 "SELECT DISTINCT collection FROM {}",
133 self.format_parquet_href(href)
134 ))?;
135 let mut collections = Vec::new();
136 for row in statement.query_map([], |row| row.get::<_, String>(0))? {
137 let collection_id = row?;
138 let mut statement = self.connection.prepare(&
139 format!("SELECT ST_AsGeoJSON(ST_Extent_Agg(geometry)), {}, {} FROM {} WHERE collection = $1", start_datetime, end_datetime,
140 self.format_parquet_href(href)
141 ))?;
142 let row = statement.query_row([&collection_id], |row| {
143 Ok((
144 row.get::<_, String>(0)?,
145 row.get::<_, String>(1)?,
146 row.get::<_, String>(2)?,
147 ))
148 })?;
149 let mut collection = Collection::new(collection_id, DEFAULT_COLLECTION_DESCRIPTION);
150 let geometry: geo::Geometry = Geometry::from_json_value(serde_json::from_str(&row.0)?)
151 .map_err(Box::new)?
152 .try_into()
153 .map_err(Box::new)?;
154 if let Some(bbox) = geometry.bounding_rect() {
155 collection.extent.spatial = SpatialExtent {
156 bbox: vec![bbox.into()],
157 };
158 }
159 collection.extent.temporal = TemporalExtent {
160 interval: vec![[
161 Some(DateTime::parse_from_str(&row.1, "%FT%T%#z")?.into()),
162 Some(DateTime::parse_from_str(&row.2, "%FT%T%#z")?.into()),
163 ]],
164 };
165 collections.push(collection);
166 }
167 Ok(collections)
168 }
169
170 pub fn search(&self, href: &str, search: Search) -> Result<stac_api::ItemCollection> {
181 let record_batches = self.search_to_arrow(href, search)?;
182 if record_batches.is_empty() {
183 Ok(Default::default())
184 } else {
185 let schema = record_batches[0].schema();
186 let item_collection = stac::geoarrow::json::from_record_batch_reader(
187 RecordBatchIterator::new(record_batches.into_iter().map(Ok), schema),
188 )?;
189 Ok(item_collection.into())
190 }
191 }
192
193 pub fn search_to_arrow(&self, href: &str, search: Search) -> Result<Vec<RecordBatch>> {
204 if search.items.query.is_some() {
209 return Err(Error::QueryNotImplemented);
210 }
211
212 let mut statement = self.prepare(&format!(
214 "SELECT column_name FROM (DESCRIBE SELECT * from {})",
215 self.format_parquet_href(href)
216 ))?;
217 let mut has_start_datetime = false;
218 let mut has_end_datetime = false;
219 let mut column_names = Vec::new();
220 let mut columns = Vec::new();
221 for row in statement.query_map([], |row| row.get::<_, String>(0))? {
222 let column = row?;
223 if column == "start_datetime" {
224 has_start_datetime = true;
225 }
226 if column == "end_datetime" {
227 has_end_datetime = true;
228 }
229
230 if let Some(fields) = search.fields.as_ref() {
231 if fields.exclude.contains(&column)
232 || !(fields.include.is_empty() || fields.include.contains(&column))
233 {
234 continue;
235 }
236 }
237
238 if column == "geometry" {
239 columns.push("ST_AsWKB(geometry) geometry".to_string());
240 } else if DATETIME_COLUMNS.contains(&column.as_str()) {
241 columns.push(format!("\"{column}\"::TIMESTAMPTZ {column}"))
242 } else {
243 columns.push(format!("\"{column}\""));
244 }
245 column_names.push(column);
246 }
247
248 let limit = search.items.limit;
250 let offset = search
251 .items
252 .additional_fields
253 .get("offset")
254 .and_then(|v| v.as_i64());
255
256 let mut order_by = Vec::with_capacity(search.sortby.len());
258 for sortby in &search.sortby {
259 order_by.push(format!(
260 "\"{}\" {}",
261 sortby.field,
262 match sortby.direction {
263 Direction::Ascending => "ASC",
264 Direction::Descending => "DESC",
265 }
266 ));
267 }
268
269 let mut wheres = Vec::new();
271 let mut params = Vec::new();
272 if !search.ids.is_empty() {
273 wheres.push(format!(
274 "id IN ({})",
275 (0..search.ids.len())
276 .map(|_| "?")
277 .collect::<Vec<_>>()
278 .join(",")
279 ));
280 params.extend(search.ids.into_iter().map(Value::Text));
281 }
282 if let Some(intersects) = search.intersects {
283 wheres.push("ST_Intersects(geometry, ST_GeomFromGeoJSON(?))".to_string());
284 params.push(Value::Text(intersects.to_string()));
285 }
286 if !search.collections.is_empty() {
287 wheres.push(format!(
288 "collection IN ({})",
289 (0..search.collections.len())
290 .map(|_| "?")
291 .collect::<Vec<_>>()
292 .join(",")
293 ));
294 params.extend(search.collections.into_iter().map(Value::Text));
295 }
296 if let Some(bbox) = search.items.bbox {
297 wheres.push("ST_Intersects(geometry, ST_GeomFromGeoJSON(?))".to_string());
298 params.push(Value::Text(bbox.to_geometry().to_string()));
299 }
300 if let Some(datetime) = search.items.datetime {
301 let interval = stac::datetime::parse(&datetime)?;
302 if let Some(start) = interval.0 {
303 wheres.push(format!(
304 "?::TIMESTAMPTZ <= {}",
305 if has_start_datetime {
306 "start_datetime"
307 } else {
308 "datetime"
309 }
310 ));
311 params.push(Value::Text(start.to_rfc3339()));
312 }
313 if let Some(end) = interval.1 {
314 wheres.push(format!(
315 "?::TIMESTAMPTZ >= {}", if has_end_datetime {
317 "end_datetime"
318 } else {
319 "datetime"
320 }
321 ));
322 params.push(Value::Text(end.to_rfc3339()));
323 }
324 }
325 if let Some(filter) = search.items.filter {
326 let expr: Expr = filter.try_into()?;
327 if expr_properties_match(&expr, &column_names) {
328 let sql = expr.to_ducksql().map_err(Box::new)?;
329 wheres.push(sql);
330 } else {
331 return Ok(Vec::new());
332 }
333 }
334
335 let mut suffix = String::new();
336 if !wheres.is_empty() {
337 suffix.push_str(&format!(" WHERE {}", wheres.join(" AND ")));
338 }
339 if !order_by.is_empty() {
340 suffix.push_str(&format!(" ORDER BY {}", order_by.join(", ")));
341 }
342 if let Some(limit) = limit {
343 suffix.push_str(&format!(" LIMIT {limit}"));
344 }
345 if let Some(offset) = offset {
346 suffix.push_str(&format!(" OFFSET {offset}"));
347 }
348
349 let sql = format!(
350 "SELECT {} FROM {}{}",
351 columns.join(","),
352 self.format_parquet_href(href),
353 suffix,
354 );
355 log::debug!("duckdb sql: {sql}");
356 let mut statement = self.prepare(&sql)?;
357 statement
358 .query_arrow(duckdb::params_from_iter(params))?
359 .map(|record_batch| {
360 let record_batch = if self.convert_wkb {
361 stac::geoarrow::with_native_geometry(record_batch, "geometry")?
362 } else {
363 stac::geoarrow::add_wkb_metadata(record_batch, "geometry")?
364 };
365 Ok(record_batch)
366 })
367 .collect::<Result<_>>()
368 }
369
370 fn format_parquet_href(&self, href: &str) -> String {
371 format!(
372 "read_parquet('{}', filename=true, hive_partitioning={}, union_by_name={})",
373 href,
374 if self.use_hive_partitioning {
375 "true"
376 } else {
377 "false"
378 },
379 if self.union_by_name { "true" } else { "false" }
380 )
381 }
382}
383
384fn expr_properties_match(expr: &Expr, properties: &[String]) -> bool {
385 use Expr::*;
386
387 match expr {
388 Property { property } => properties.contains(property),
389 Float(_) | Literal(_) | Bool(_) | Geometry(_) => true,
390 Operation { args, .. } => args
391 .iter()
392 .all(|expr| expr_properties_match(expr, properties)),
393 Interval { interval } => interval
394 .iter()
395 .all(|expr| expr_properties_match(expr, properties)),
396 Timestamp { timestamp } => expr_properties_match(timestamp, properties),
397 Date { date } => expr_properties_match(date, properties),
398 Array(exprs) => exprs
399 .iter()
400 .all(|expr| expr_properties_match(expr, properties)),
401 BBox { bbox } => bbox
402 .iter()
403 .all(|expr| expr_properties_match(expr, properties)),
404 Null => expr_properties_match(expr, properties),
405 }
406}
407
408impl Deref for Client {
409 type Target = Connection;
410
411 fn deref(&self) -> &Self::Target {
412 &self.connection
413 }
414}
415
416impl DerefMut for Client {
417 fn deref_mut(&mut self) -> &mut Self::Target {
418 &mut self.connection
419 }
420}
421
422impl From<Connection> for Client {
423 fn from(connection: Connection) -> Self {
424 Client {
425 connection,
426 use_hive_partitioning: DEFAULT_USE_HIVE_PARTITIONING,
427 convert_wkb: DEFAULT_CONVERT_WKB,
428 union_by_name: DEFAULT_UNION_BY_NAME,
429 }
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::Client;
    use duckdb::Connection;
    use geo::Geometry;
    use rstest::{fixture, rstest};
    use stac::Bbox;
    use stac_api::{Items, Search, Sortby};
    use stac_validate::Validate;

    /// Installs the DuckDB `spatial` extension on a throwaway in-memory
    /// connection; `#[once]` makes rstest evaluate it a single time.
    #[fixture]
    #[once]
    fn install_spatial() {
        let connection = Connection::open_in_memory().unwrap();
        connection.execute("INSTALL spatial", []).unwrap();
    }

    // `install_spatial` is requested only for its side effect, hence the allow.
    #[allow(unused_variables)]
    #[fixture]
    fn client(install_spatial: ()) -> Client {
        Client::new().unwrap()
    }

    // Constructing a client succeeds once `spatial` is installed.
    #[allow(unused_variables)]
    #[rstest]
    fn new(install_spatial: ()) {
        Client::new().unwrap();
    }

    // Listing extensions does not error.
    #[rstest]
    fn extensions(client: Client) {
        let _ = client.extensions().unwrap();
    }

    // A default search returns all 100 items, and the first one validates.
    #[rstest]
    #[tokio::test]
    async fn search(client: Client) {
        let item_collection = client
            .search("data/100-sentinel-2-items.parquet", Search::default())
            .unwrap();
        assert_eq!(item_collection.items.len(), 100);
        item_collection.items[0].validate().await.unwrap();
    }

    // The fixture file comes back as a single record batch.
    #[rstest]
    fn search_to_arrow(client: Client) {
        let record_batches = client
            .search_to_arrow("data/100-sentinel-2-items.parquet", Search::default())
            .unwrap();
        assert_eq!(record_batches.len(), 1);
    }

    // Searching by id returns exactly that item.
    #[rstest]
    fn search_ids(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().ids(vec![
                    "S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429".to_string(),
                ]),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 1);
        assert_eq!(
            item_collection.items[0]["id"],
            "S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429"
        );
    }

    // A point `intersects` filter matches 50 of the fixture items.
    #[rstest]
    fn search_intersects(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().intersects(&Geometry::Point(geo::point! { x: -106., y: 40.5 })),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 50);
    }

    // Filtering by collection: a matching id keeps everything, an unknown id matches nothing.
    #[rstest]
    fn search_collections(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().collections(vec!["sentinel-2-l2a".to_string()]),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 100);

        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().collections(vec!["foobar".to_string()]),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 0);
    }

    // A small bbox matches 50 of the fixture items.
    #[rstest]
    fn search_bbox(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().bbox(Bbox::new(-106.1, 40.5, -106.0, 40.6)),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 50);
    }

    // Open-ended datetime intervals on either side split the items 1 / 99.
    #[rstest]
    fn search_datetime(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().datetime("2024-12-02T00:00:00Z/.."),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 1);
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().datetime("../2024-12-02T00:00:00Z"),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 99);
    }

    // An empty end bound ("start/") behaves like an open end ("start/..").
    #[rstest]
    fn search_datetime_empty_interval(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().datetime("2024-12-02T00:00:00Z/"),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 1);
    }

    // `limit` caps the number of returned items.
    #[rstest]
    fn search_limit(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().limit(42),
            )
            .unwrap();
        assert_eq!(item_collection.items.len(), 42);
    }

    // `offset` is passed through the additional fields and skips rows.
    #[rstest]
    fn search_offset(client: Client) {
        let mut search = Search::default().limit(1);
        search
            .items
            .additional_fields
            .insert("offset".to_string(), 1.into());
        let item_collection = client
            .search("data/100-sentinel-2-items.parquet", search)
            .unwrap();
        assert_eq!(
            item_collection.items[0]["id"],
            "S2A_MSIL2A_20241201T175721_R141_T13TDE_20241201T213150"
        );
    }

    // Ascending and descending sorts on `datetime` yield the earliest and latest items.
    #[rstest]
    fn search_sortby(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default()
                    .sortby(vec![Sortby::asc("datetime")])
                    .limit(1),
            )
            .unwrap();
        assert_eq!(
            item_collection.items[0]["id"],
            "S2A_MSIL2A_20240326T174951_R141_T13TDE_20240329T224429"
        );

        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default()
                    .sortby(vec![Sortby::desc("datetime")])
                    .limit(1),
            )
            .unwrap();
        assert_eq!(
            item_collection.items[0]["id"],
            "S2B_MSIL2A_20241203T174629_R098_T13TDE_20241203T211406"
        );
    }

    // Including only `id` returns items with a single key.
    #[rstest]
    fn search_fields(client: Client) {
        let item_collection = client
            .search(
                "data/100-sentinel-2-items.parquet",
                Search::default().fields("+id".parse().unwrap()).limit(1),
            )
            .unwrap();
        assert_eq!(item_collection.items[0].len(), 1);
    }

    // The fixture file holds a single collection.
    #[rstest]
    fn collections(client: Client) {
        let collections = client
            .collections("data/100-sentinel-2-items.parquet")
            .unwrap();
        assert_eq!(collections.len(), 1);
    }

    // With `convert_wkb` off, the geometry field keeps the geoarrow.wkb extension name.
    #[rstest]
    fn no_convert_wkb(mut client: Client) {
        client.convert_wkb = false;
        let record_batches = client
            .search_to_arrow("data/100-sentinel-2-items.parquet", Search::default())
            .unwrap();
        let schema = record_batches[0].schema();
        assert_eq!(
            schema.field_with_name("geometry").unwrap().metadata()["ARROW:extension:name"],
            "geoarrow.wkb"
        );
    }

    // A CQL2 filter on an existing property narrows the results.
    #[rstest]
    fn filter(client: Client) {
        let search = Search {
            items: Items {
                filter: Some("sat:relative_orbit = 98".parse().unwrap()),
                ..Default::default()
            },
            ..Default::default()
        };
        let item_collection = client
            .search("data/100-sentinel-2-items.parquet", search)
            .unwrap();
        assert_eq!(item_collection.items.len(), 49);
    }

    // A CQL2 filter on a missing property yields no items rather than an error.
    #[rstest]
    fn filter_no_column(client: Client) {
        let search = Search {
            items: Items {
                filter: Some("foo:bar = 42".parse().unwrap()),
                ..Default::default()
            },
            ..Default::default()
        };
        let item_collection = client
            .search("data/100-sentinel-2-items.parquet", search)
            .unwrap();
        assert_eq!(item_collection.items.len(), 0);
    }

    // Sorting by a property whose name contains ':' works (it is quoted in SQL).
    #[rstest]
    fn sortby_property(client: Client) {
        let search = Search {
            items: Items {
                sortby: vec!["eo:cloud_cover".parse().unwrap()],
                ..Default::default()
            },
            ..Default::default()
        };
        let item_collection = client
            .search("data/100-sentinel-2-items.parquet", search)
            .unwrap();
        assert_eq!(item_collection.items.len(), 100);
    }

    // With the default `union_by_name`, the glob over `data/` can be searched.
    #[rstest]
    fn union_by_name(client: Client) {
        let _ = client.search("data/*.parquet", Default::default()).unwrap();
    }

    // Without `union_by_name`, the same glob is expected to fail.
    #[rstest]
    fn no_union_by_name(mut client: Client) {
        client.union_by_name = false;
        let _ = client
            .search("data/*.parquet", Default::default())
            .unwrap_err();
    }
}