zarr_datafusion/datasource/
zarr.rs1use arrow::datatypes::SchemaRef;
2use async_trait::async_trait;
3use datafusion::catalog::Session;
4use datafusion::common::stats::{ColumnStatistics, Precision, Statistics};
5use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
6use datafusion::{datasource::TableProvider, error::Result, physical_plan::ExecutionPlan};
7use std::sync::Arc;
8use tracing::{debug, info};
9use zarrs::storage::AsyncReadableListableStorage;
10use zarrs_object_store::object_store::path::Path as ObjectPath;
11
12use crate::physical_plan::zarr_exec::ZarrExec;
13use crate::reader::filter::parse_coord_filters;
14use crate::reader::schema_inference::ZarrStoreMeta;
15
16pub type CachedRemoteStore = Option<(AsyncReadableListableStorage, ObjectPath, ZarrStoreMeta)>;
18
19pub struct ZarrTable {
20 schema: SchemaRef,
21 path: String,
22 cached_remote: CachedRemoteStore,
24 store_meta: Option<ZarrStoreMeta>,
26}
27
28impl std::fmt::Debug for ZarrTable {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 f.debug_struct("ZarrTable")
31 .field("schema", &self.schema)
32 .field("path", &self.path)
33 .field(
34 "cached_remote",
35 &self.cached_remote.as_ref().map(|(_, p, _)| p),
36 )
37 .field(
38 "total_rows",
39 &self.store_meta.as_ref().map(|m| m.total_rows),
40 )
41 .finish()
42 }
43}
44
45impl ZarrTable {
46 pub fn new(schema: SchemaRef, path: impl Into<String>) -> Self {
47 Self {
48 schema,
49 path: path.into(),
50 cached_remote: None,
51 store_meta: None,
52 }
53 }
54
55 pub fn with_metadata(
57 schema: SchemaRef,
58 path: impl Into<String>,
59 metadata: ZarrStoreMeta,
60 ) -> Self {
61 Self {
62 schema,
63 path: path.into(),
64 cached_remote: None,
65 store_meta: Some(metadata),
66 }
67 }
68
69 pub fn with_cached_remote(
71 schema: SchemaRef,
72 path: impl Into<String>,
73 store: AsyncReadableListableStorage,
74 prefix: ObjectPath,
75 metadata: ZarrStoreMeta,
76 ) -> Self {
77 Self {
78 schema,
79 path: path.into(),
80 cached_remote: Some((store, prefix, metadata.clone())),
81 store_meta: Some(metadata),
82 }
83 }
84}
85
86#[async_trait]
87impl TableProvider for ZarrTable {
88 fn as_any(&self) -> &dyn std::any::Any {
89 self
90 }
91
92 fn schema(&self) -> SchemaRef {
93 self.schema.clone()
94 }
95
96 fn table_type(&self) -> datafusion::datasource::TableType {
97 datafusion::datasource::TableType::Base
98 }
99
100 fn supports_filters_pushdown(
106 &self,
107 filters: &[&Expr],
108 ) -> Result<Vec<TableProviderFilterPushDown>> {
109 Ok(filters
110 .iter()
111 .map(|_| TableProviderFilterPushDown::Inexact)
112 .collect())
113 }
114
115 async fn scan(
116 &self,
117 _state: &dyn Session,
118 projection: Option<&Vec<usize>>,
119 filters: &[datafusion::logical_expr::Expr],
120 limit: Option<usize>,
121 ) -> Result<Arc<dyn ExecutionPlan>> {
122 let total_columns = self.schema.fields().len();
124 if let Some(indices) = projection {
125 let projected_names: Vec<_> = indices
126 .iter()
127 .map(|&i| self.schema.field(i).name().as_str())
128 .collect();
129 info!(
130 projected = indices.len(),
131 total = total_columns,
132 columns = ?projected_names,
133 "Projection pushdown"
134 );
135 } else {
136 info!(
137 projected = total_columns,
138 total = total_columns,
139 "No projection pushdown (all columns)"
140 );
141 }
142
143 if let Some(limit) = limit {
145 info!(limit, "Limit pushdown");
146 }
147
148 debug!(
150 num_filters = filters.len(),
151 filters = ?filters,
152 "Filters passed to scan()"
153 );
154 let coord_filters = if let Some(meta) = &self.store_meta {
155 let coord_names: Vec<String> = meta.coords.iter().map(|c| c.name.clone()).collect();
156 debug!(?coord_names, "Coordinate names from metadata");
157 let parsed = parse_coord_filters(filters, &coord_names);
158 if !parsed.is_empty() {
159 info!(
160 num_filters = parsed.len(),
161 coords = ?parsed.filters.keys().collect::<Vec<_>>(),
162 "Filter pushdown"
163 );
164 Some(parsed)
165 } else {
166 None
167 }
168 } else {
169 None
171 };
172
173 Ok(Arc::new(ZarrExec::new(
174 self.schema.clone(),
175 self.path.clone(),
176 projection.cloned(),
177 limit,
178 self.cached_remote.clone(),
179 coord_filters,
180 )))
181 }
182
183 fn statistics(&self) -> Option<Statistics> {
192 let meta = self.store_meta.as_ref()?;
193
194 let column_statistics: Vec<ColumnStatistics> = self
196 .schema
197 .fields()
198 .iter()
199 .map(|field| {
200 let field_name = field.name();
201
202 if let Some(coord) = meta.coords.iter().find(|c| &c.name == field_name) {
204 if let Some((min, max)) = coord.coord_min_max {
205 let distinct_count = coord.shape[0] as usize;
207
208 let (min_value, max_value) = match field.data_type() {
211 arrow::datatypes::DataType::Dictionary(_, value_type) => {
212 scalar_values_from_f64(min, max, value_type.as_ref())
213 }
214 dt => scalar_values_from_f64(min, max, dt),
215 };
216
217 info!(
218 coord = %field_name,
219 min = %min_value,
220 max = %max_value,
221 distinct = distinct_count,
222 "Coordinate statistics"
223 );
224
225 return ColumnStatistics {
226 null_count: Precision::Exact(0),
227 min_value: Precision::Exact(min_value),
228 max_value: Precision::Exact(max_value),
229 distinct_count: Precision::Exact(distinct_count),
230 ..Default::default()
231 };
232 }
233 }
234
235 ColumnStatistics {
237 null_count: Precision::Exact(0),
238 ..Default::default()
239 }
240 })
241 .collect();
242
243 info!(
244 total_rows = meta.total_rows,
245 num_columns = column_statistics.len(),
246 "Providing statistics for query optimization"
247 );
248
249 Some(Statistics {
250 num_rows: Precision::Exact(meta.total_rows),
251 total_byte_size: Precision::Absent,
252 column_statistics,
253 })
254 }
255}
256
257fn scalar_values_from_f64(
259 min: f64,
260 max: f64,
261 data_type: &arrow::datatypes::DataType,
262) -> (
263 datafusion::common::ScalarValue,
264 datafusion::common::ScalarValue,
265) {
266 use arrow::datatypes::DataType;
267 use datafusion::common::ScalarValue;
268
269 match data_type {
270 DataType::Float64 => (
271 ScalarValue::Float64(Some(min)),
272 ScalarValue::Float64(Some(max)),
273 ),
274 DataType::Float32 => (
275 ScalarValue::Float32(Some(min as f32)),
276 ScalarValue::Float32(Some(max as f32)),
277 ),
278 DataType::Int64 => (
279 ScalarValue::Int64(Some(min as i64)),
280 ScalarValue::Int64(Some(max as i64)),
281 ),
282 DataType::Int32 => (
283 ScalarValue::Int32(Some(min as i32)),
284 ScalarValue::Int32(Some(max as i32)),
285 ),
286 DataType::Int16 => (
287 ScalarValue::Int16(Some(min as i16)),
288 ScalarValue::Int16(Some(max as i16)),
289 ),
290 DataType::UInt64 => (
291 ScalarValue::UInt64(Some(min as u64)),
292 ScalarValue::UInt64(Some(max as u64)),
293 ),
294 DataType::UInt32 => (
295 ScalarValue::UInt32(Some(min as u32)),
296 ScalarValue::UInt32(Some(max as u32)),
297 ),
298 _ => (
300 ScalarValue::Float64(Some(min)),
301 ScalarValue::Float64(Some(max)),
302 ),
303 }
304}