Skip to main content

krishiv_sql/
connector_table.rs

1//! DataFusion `TableProviderFactory` implementations backed by
2//! [`krishiv_connectors::ConnectorRegistry`].
3
4use std::any::Any;
5use std::collections::HashSet;
6use std::path::PathBuf;
7use std::sync::{Arc, RwLock};
8
9use arrow::datatypes::{Schema, SchemaRef};
10use arrow::record_batch::RecordBatch;
11use async_trait::async_trait;
12use datafusion::catalog::TableProvider;
13use datafusion::catalog::TableProviderFactory;
14use datafusion::catalog::streaming::StreamingTable;
15use datafusion::datasource::MemTable;
16use datafusion::error::{DataFusionError, Result as DataFusionResult};
17use datafusion::logical_expr::CreateExternalTable;
18use datafusion::physical_plan::ExecutionPlan;
19use krishiv_connectors::{ConnectorConfig, ConnectorError, ConnectorRegistry, default_registry};
20
21use crate::kafka_table::{KafkaPartitionStream, kafka_auto_commit_interval_ms, project_batch};
22
23/// Reject paths that escape the warehouse root via traversal or absolutes.
24fn validate_path_under_warehouse(location: &str) -> DataFusionResult<()> {
25    let warehouse = std::env::var("KRISHIV_WAREHOUSE_ROOT").unwrap_or_else(|_| ".".to_string());
26    let base = PathBuf::from(&warehouse).canonicalize().map_err(|e| {
27        DataFusionError::External(Box::new(ConnectorError::Unsupported {
28            message: format!("warehouse root '{warehouse}' not accessible: {e}"),
29        }))
30    })?;
31    let candidate = PathBuf::from(location);
32    let resolved = if candidate.is_relative() {
33        base.join(&candidate)
34    } else {
35        candidate
36    };
37    let canonical = resolved.canonicalize().map_err(|e| {
38        DataFusionError::External(Box::new(ConnectorError::Unsupported {
39            message: format!("path '{location}' not accessible: {e}"),
40        }))
41    })?;
42    if !canonical.starts_with(&base) {
43        return Err(DataFusionError::External(Box::new(
44            ConnectorError::Unsupported {
45                message: format!("path '{location}' escapes warehouse root '{warehouse}'"),
46            },
47        )));
48    }
49    Ok(())
50}
51
52/// Shared registry instance for SQL DDL table factories.
53pub fn shared_connector_registry() -> Arc<ConnectorRegistry> {
54    Arc::new(default_registry())
55}
56
57/// Register PARQUET, S3, and KAFKA DDL factories on a DataFusion table-factory map.
58pub fn register_connector_table_factories(
59    table_factories: &mut std::collections::HashMap<String, Arc<dyn TableProviderFactory>>,
60    streaming_sources: Arc<RwLock<HashSet<String>>>,
61) {
62    let registry = shared_connector_registry();
63    table_factories.insert(
64        "PARQUET".to_string(),
65        Arc::new(ConnectorTableFactory::bounded(
66            "parquet",
67            Arc::clone(&registry),
68        )),
69    );
70    table_factories.insert(
71        "S3".to_string(),
72        Arc::new(ConnectorTableFactory::bounded("s3", registry)),
73    );
74    table_factories.insert(
75        "KAFKA".to_string(),
76        Arc::new(ConnectorTableFactory::streaming(streaming_sources)),
77    );
78}
79
80/// Build a [`ConnectorConfig`] from a `CREATE EXTERNAL TABLE` command.
81pub fn connector_config_from_ddl(
82    kind: &str,
83    cmd: &CreateExternalTable,
84) -> DataFusionResult<ConnectorConfig> {
85    let name = cmd.name.table().to_string();
86    Ok(match kind {
87        "parquet" => {
88            if !cmd.location.is_empty() {
89                validate_path_under_warehouse(&cmd.location)?;
90            }
91            ConnectorConfig::new(name, kind).with_property("path", cmd.location.clone())
92        }
93        "s3" => {
94            let mut cfg = ConnectorConfig::new(cmd.name.table(), kind)
95                .with_property("object_path", cmd.location.clone());
96            for (key, value) in &cmd.options {
97                if key == "base_path" {
98                    cfg = cfg.with_property("base_path", value.clone());
99                }
100            }
101            cfg
102        }
103        "kafka" => {
104            let mut cfg = ConnectorConfig::new(cmd.name.table(), kind)
105                .with_property("topic", cmd.location.clone())
106                .with_property("bootstrap.servers", "127.0.0.1:9092".to_string())
107                .with_property("group.id", "krishiv-sql".to_string());
108            for (key, value) in &cmd.options {
109                match key.as_str() {
110                    "bootstrap.servers" => {
111                        cfg = cfg.with_property("bootstrap.servers", value.clone());
112                    }
113                    "group.id" => {
114                        cfg = cfg.with_property("group.id", value.clone());
115                    }
116                    other => {
117                        cfg = cfg.with_property(other, value.clone());
118                    }
119                }
120            }
121            if let Some(ms) = kafka_auto_commit_interval_ms() {
122                cfg = cfg.with_property("auto.commit.interval.ms", ms.to_string());
123            }
124            cfg
125        }
126        _ => ConnectorConfig::new(name, kind).with_property("path", cmd.location.clone()),
127    })
128}
129
130fn connector_error(err: ConnectorError) -> DataFusionError {
131    DataFusionError::External(Box::new(err))
132}
133
134/// Factory for bounded connector sources opened through the registry.
135pub struct ConnectorTableFactory {
136    connector_kind: &'static str,
137    registry: Arc<ConnectorRegistry>,
138    streaming_sources: Option<Arc<RwLock<HashSet<String>>>>,
139}
140
141impl std::fmt::Debug for ConnectorTableFactory {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        f.debug_struct("ConnectorTableFactory")
144            .field("connector_kind", &self.connector_kind)
145            .finish_non_exhaustive()
146    }
147}
148
149impl ConnectorTableFactory {
150    pub fn bounded(connector_kind: &'static str, registry: Arc<ConnectorRegistry>) -> Self {
151        Self {
152            connector_kind,
153            registry,
154            streaming_sources: None,
155        }
156    }
157
158    pub fn streaming(streaming_sources: Arc<RwLock<HashSet<String>>>) -> Self {
159        Self {
160            connector_kind: "kafka",
161            registry: shared_connector_registry(),
162            streaming_sources: Some(streaming_sources),
163        }
164    }
165}
166
167#[async_trait]
168impl TableProviderFactory for ConnectorTableFactory {
169    async fn create(
170        &self,
171        _state: &dyn datafusion::catalog::Session,
172        cmd: &CreateExternalTable,
173    ) -> DataFusionResult<Arc<dyn TableProvider>> {
174        let config = connector_config_from_ddl(self.connector_kind, cmd)?;
175        self.registry
176            .validate_source(&config)
177            .map_err(connector_error)?;
178
179        if self.connector_kind == "kafka" {
180            return create_kafka_table_provider(cmd, &config, self.streaming_sources.as_ref())
181                .await;
182        }
183
184        let schema: SchemaRef = cmd.schema.as_ref().inner().clone();
185        Ok(Arc::new(BoundedConnectorProvider {
186            registry: Arc::clone(&self.registry),
187            config,
188            schema,
189        }))
190    }
191}
192
193async fn create_kafka_table_provider(
194    cmd: &CreateExternalTable,
195    config: &ConnectorConfig,
196    streaming_sources: Option<&Arc<RwLock<HashSet<String>>>>,
197) -> DataFusionResult<Arc<dyn TableProvider>> {
198    use krishiv_connectors::kafka::{KafkaConfig, KafkaSource};
199
200    let kafka_config = KafkaConfig::from_config(config).map_err(connector_error)?;
201    let schema: SchemaRef = cmd.schema.as_ref().inner().clone();
202    let source = KafkaSource::new(kafka_config).map_err(connector_error)?;
203    let partition = Arc::new(KafkaPartitionStream::new(schema.clone(), source));
204    let table = StreamingTable::try_new(schema, vec![partition])?;
205
206    if let Some(streaming_sources) = streaming_sources {
207        let table_name = cmd.name.table().to_string();
208        streaming_sources
209            .write()
210            .unwrap_or_else(|e| e.into_inner())
211            .insert(table_name);
212    }
213
214    Ok(Arc::new(table))
215}
216
217/// Bounded scan provider that materializes all connector batches at scan time.
218struct BoundedConnectorProvider {
219    registry: Arc<ConnectorRegistry>,
220    config: ConnectorConfig,
221    schema: SchemaRef,
222}
223
224impl std::fmt::Debug for BoundedConnectorProvider {
225    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226        f.debug_struct("BoundedConnectorProvider")
227            .field("config", &self.config)
228            .finish_non_exhaustive()
229    }
230}
231
232#[async_trait]
233impl TableProvider for BoundedConnectorProvider {
234    fn as_any(&self) -> &dyn Any {
235        self
236    }
237
238    fn schema(&self) -> SchemaRef {
239        Arc::clone(&self.schema)
240    }
241
242    fn table_type(&self) -> datafusion::logical_expr::TableType {
243        datafusion::logical_expr::TableType::Base
244    }
245
246    fn statistics(&self) -> Option<datafusion::physical_plan::Statistics> {
247        use datafusion::common::stats::Precision;
248        use datafusion::physical_plan::Statistics;
249        let row_count = self.registry.estimated_row_count(&self.config)?;
250        Some(Statistics {
251            num_rows: Precision::Inexact(row_count as usize),
252            ..Statistics::new_unknown(&self.schema)
253        })
254    }
255
256    async fn scan(
257        &self,
258        state: &dyn datafusion::catalog::Session,
259        projection: Option<&Vec<usize>>,
260        filters: &[datafusion::logical_expr::Expr],
261        limit: Option<usize>,
262    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
263        let mut source = self
264            .registry
265            .open_source(&self.config)
266            .await
267            .map_err(connector_error)?;
268
269        // T7: apply the user's projection and limit eagerly. The previous
270        // implementation drained the entire source into a `MemTable` and
271        // deferred the projection and limit to DataFusion's
272        // `MemTable::scan`. That is correct but forces the connector to
273        // materialise every row and every column before any predicate
274        // runs, defeating Parquet column-pruning and file-pruning for any
275        // sink that does not have a `DataSourceExec` shim. Eager
276        // projection and limit short-circuits here bring the connector's
277        // behaviour closer to the `DataSourceExec` path and significantly
278        // reduce memory pressure for large bounded sources.
279        //
280        // Filter pushdown to the connector remains a follow-up: the
281        // connector `Source` trait does not yet accept filter
282        // expressions, and DataFusion's physical-expression builder is
283        // version-sensitive. For now, filters are still applied by
284        // DataFusion's downstream `MemTable::scan` so the result is
285        // identical — just less memory-efficient than a connector that
286        // accepts pushdown filters.
287        let projection_columns: Option<Vec<String>> = projection.map(|idxs| {
288            idxs.iter()
289                .map(|&i| self.schema.field(i).name().clone())
290                .collect()
291        });
292        let mut batches: Vec<RecordBatch> = Vec::new();
293        let mut rows_accumulated: usize = 0;
294        let limit_threshold: Option<usize> = limit;
295        loop {
296            let batch = source.read_batch_dyn().await.map_err(connector_error)?;
297            let Some(batch) = batch else { break };
298            let batch = project_batch(&batch, &self.schema)
299                .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
300            // Project to the user-requested columns.
301            let batch = match &projection_columns {
302                Some(cols) => project_to_columns(&batch, cols)
303                    .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?,
304                None => batch,
305            };
306            if batch.num_rows() == 0 {
307                continue;
308            }
309            // Honour the limit by truncating the last batch.
310            let batch = match limit_threshold {
311                Some(threshold) if rows_accumulated + batch.num_rows() > threshold => {
312                    let take = threshold.saturating_sub(rows_accumulated);
313                    batch.slice(0, take)
314                }
315                _ => batch,
316            };
317            rows_accumulated += batch.num_rows();
318            batches.push(batch);
319            if let Some(threshold) = limit_threshold
320                && rows_accumulated >= threshold
321            {
322                break;
323            }
324        }
325
326        let table = MemTable::try_new(Arc::clone(&self.schema), vec![batches])?;
327        table.scan(state, projection, filters, limit).await
328    }
329}
330
331/// T7: project a batch down to the named columns.
332fn project_to_columns(
333    batch: &RecordBatch,
334    columns: &[String],
335) -> arrow::error::Result<RecordBatch> {
336    if columns.is_empty() {
337        return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
338    }
339    let mut cols = Vec::with_capacity(columns.len());
340    let mut fields = Vec::with_capacity(columns.len());
341    for name in columns {
342        let idx = batch.schema().index_of(name)?;
343        cols.push(batch.column(idx).clone());
344        fields.push(batch.schema().field(idx).clone());
345    }
346    RecordBatch::try_new(Arc::new(Schema::new(fields)), cols)
347}
348
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};

    use super::*;

    /// A config without a `path` property has no row-count estimate, so the
    /// provider must report no statistics.
    #[test]
    fn bounded_connector_provider_statistics_returns_none_for_unknown_table() {
        let registry = Arc::new(krishiv_connectors::ConnectorRegistry::new());
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
        let config = krishiv_connectors::ConnectorConfig::new("unknown", "parquet");
        let provider = BoundedConnectorProvider {
            registry,
            config,
            schema,
        };
        assert!(
            provider.statistics().is_none(),
            "no path in config → estimated_row_count returns None → statistics returns None"
        );
    }

    /// `extract_create_external_table_name` (defined two modules up,
    /// `super::super`) extracts the table name from
    /// `CREATE [OR REPLACE] EXTERNAL TABLE` statements and returns `None` for
    /// anything else.
    #[test]
    fn extract_create_external_table_name_parses_table_name() {
        assert_eq!(
            super::super::extract_create_external_table_name(
                "CREATE EXTERNAL TABLE my_table STORED AS PARQUET LOCATION 'data.parquet'"
            ),
            Some("my_table".to_string())
        );
        assert_eq!(
            super::super::extract_create_external_table_name("SELECT * FROM foo"),
            None
        );
        assert_eq!(
            super::super::extract_create_external_table_name(
                "CREATE OR REPLACE EXTERNAL TABLE orders STORED AS PARQUET LOCATION 'orders.parquet'"
            ),
            Some("orders".to_string())
        );
    }

    /// T7: `project_to_columns` must keep the requested column order and
    /// tolerate an empty column list, which yields a zero-column batch.
    #[test]
    fn project_to_columns_preserves_order_and_handles_empty() {
        use arrow::array::Int64Array;
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Int64, false),
            Field::new("c", DataType::Int64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2])) as _,
                Arc::new(Int64Array::from(vec![3, 4])) as _,
                Arc::new(Int64Array::from(vec![5, 6])) as _,
            ],
        )
        .unwrap();
        // Reorder: c, a
        let projected = super::project_to_columns(&batch, &[String::from("c"), String::from("a")])
            .expect("project must succeed");
        assert_eq!(projected.num_columns(), 2);
        assert_eq!(projected.schema().field(0).name(), "c");
        assert_eq!(projected.schema().field(1).name(), "a");
        // Empty projection: no columns selected.
        let no_op = super::project_to_columns(&batch, &[]).expect("no-op projection must succeed");
        assert_eq!(no_op.num_columns(), 0);
    }
}