1use std::any::Any;
5use std::collections::HashSet;
6use std::path::PathBuf;
7use std::sync::{Arc, RwLock};
8
9use arrow::datatypes::{Schema, SchemaRef};
10use arrow::record_batch::RecordBatch;
11use async_trait::async_trait;
12use datafusion::catalog::TableProvider;
13use datafusion::catalog::TableProviderFactory;
14use datafusion::catalog::streaming::StreamingTable;
15use datafusion::datasource::MemTable;
16use datafusion::error::{DataFusionError, Result as DataFusionResult};
17use datafusion::logical_expr::CreateExternalTable;
18use datafusion::physical_plan::ExecutionPlan;
19use krishiv_connectors::{ConnectorConfig, ConnectorError, ConnectorRegistry, default_registry};
20
21use crate::kafka_table::{KafkaPartitionStream, kafka_auto_commit_interval_ms, project_batch};
22
23fn validate_path_under_warehouse(location: &str) -> DataFusionResult<()> {
25 let warehouse = std::env::var("KRISHIV_WAREHOUSE_ROOT").unwrap_or_else(|_| ".".to_string());
26 let base = PathBuf::from(&warehouse).canonicalize().map_err(|e| {
27 DataFusionError::External(Box::new(ConnectorError::Unsupported {
28 message: format!("warehouse root '{warehouse}' not accessible: {e}"),
29 }))
30 })?;
31 let candidate = PathBuf::from(location);
32 let resolved = if candidate.is_relative() {
33 base.join(&candidate)
34 } else {
35 candidate
36 };
37 let canonical = resolved.canonicalize().map_err(|e| {
38 DataFusionError::External(Box::new(ConnectorError::Unsupported {
39 message: format!("path '{location}' not accessible: {e}"),
40 }))
41 })?;
42 if !canonical.starts_with(&base) {
43 return Err(DataFusionError::External(Box::new(
44 ConnectorError::Unsupported {
45 message: format!("path '{location}' escapes warehouse root '{warehouse}'"),
46 },
47 )));
48 }
49 Ok(())
50}
51
52pub fn shared_connector_registry() -> Arc<ConnectorRegistry> {
54 Arc::new(default_registry())
55}
56
57pub fn register_connector_table_factories(
59 table_factories: &mut std::collections::HashMap<String, Arc<dyn TableProviderFactory>>,
60 streaming_sources: Arc<RwLock<HashSet<String>>>,
61) {
62 let registry = shared_connector_registry();
63 table_factories.insert(
64 "PARQUET".to_string(),
65 Arc::new(ConnectorTableFactory::bounded(
66 "parquet",
67 Arc::clone(®istry),
68 )),
69 );
70 table_factories.insert(
71 "S3".to_string(),
72 Arc::new(ConnectorTableFactory::bounded("s3", registry)),
73 );
74 table_factories.insert(
75 "KAFKA".to_string(),
76 Arc::new(ConnectorTableFactory::streaming(streaming_sources)),
77 );
78}
79
80pub fn connector_config_from_ddl(
82 kind: &str,
83 cmd: &CreateExternalTable,
84) -> DataFusionResult<ConnectorConfig> {
85 let name = cmd.name.table().to_string();
86 Ok(match kind {
87 "parquet" => {
88 if !cmd.location.is_empty() {
89 validate_path_under_warehouse(&cmd.location)?;
90 }
91 ConnectorConfig::new(name, kind).with_property("path", cmd.location.clone())
92 }
93 "s3" => {
94 let mut cfg = ConnectorConfig::new(cmd.name.table(), kind)
95 .with_property("object_path", cmd.location.clone());
96 for (key, value) in &cmd.options {
97 if key == "base_path" {
98 cfg = cfg.with_property("base_path", value.clone());
99 }
100 }
101 cfg
102 }
103 "kafka" => {
104 let mut cfg = ConnectorConfig::new(cmd.name.table(), kind)
105 .with_property("topic", cmd.location.clone())
106 .with_property("bootstrap.servers", "127.0.0.1:9092".to_string())
107 .with_property("group.id", "krishiv-sql".to_string());
108 for (key, value) in &cmd.options {
109 match key.as_str() {
110 "bootstrap.servers" => {
111 cfg = cfg.with_property("bootstrap.servers", value.clone());
112 }
113 "group.id" => {
114 cfg = cfg.with_property("group.id", value.clone());
115 }
116 other => {
117 cfg = cfg.with_property(other, value.clone());
118 }
119 }
120 }
121 if let Some(ms) = kafka_auto_commit_interval_ms() {
122 cfg = cfg.with_property("auto.commit.interval.ms", ms.to_string());
123 }
124 cfg
125 }
126 _ => ConnectorConfig::new(name, kind).with_property("path", cmd.location.clone()),
127 })
128}
129
130fn connector_error(err: ConnectorError) -> DataFusionError {
131 DataFusionError::External(Box::new(err))
132}
133
134pub struct ConnectorTableFactory {
136 connector_kind: &'static str,
137 registry: Arc<ConnectorRegistry>,
138 streaming_sources: Option<Arc<RwLock<HashSet<String>>>>,
139}
140
141impl std::fmt::Debug for ConnectorTableFactory {
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 f.debug_struct("ConnectorTableFactory")
144 .field("connector_kind", &self.connector_kind)
145 .finish_non_exhaustive()
146 }
147}
148
149impl ConnectorTableFactory {
150 pub fn bounded(connector_kind: &'static str, registry: Arc<ConnectorRegistry>) -> Self {
151 Self {
152 connector_kind,
153 registry,
154 streaming_sources: None,
155 }
156 }
157
158 pub fn streaming(streaming_sources: Arc<RwLock<HashSet<String>>>) -> Self {
159 Self {
160 connector_kind: "kafka",
161 registry: shared_connector_registry(),
162 streaming_sources: Some(streaming_sources),
163 }
164 }
165}
166
167#[async_trait]
168impl TableProviderFactory for ConnectorTableFactory {
169 async fn create(
170 &self,
171 _state: &dyn datafusion::catalog::Session,
172 cmd: &CreateExternalTable,
173 ) -> DataFusionResult<Arc<dyn TableProvider>> {
174 let config = connector_config_from_ddl(self.connector_kind, cmd)?;
175 self.registry
176 .validate_source(&config)
177 .map_err(connector_error)?;
178
179 if self.connector_kind == "kafka" {
180 return create_kafka_table_provider(cmd, &config, self.streaming_sources.as_ref())
181 .await;
182 }
183
184 let schema: SchemaRef = cmd.schema.as_ref().inner().clone();
185 Ok(Arc::new(BoundedConnectorProvider {
186 registry: Arc::clone(&self.registry),
187 config,
188 schema,
189 }))
190 }
191}
192
193async fn create_kafka_table_provider(
194 cmd: &CreateExternalTable,
195 config: &ConnectorConfig,
196 streaming_sources: Option<&Arc<RwLock<HashSet<String>>>>,
197) -> DataFusionResult<Arc<dyn TableProvider>> {
198 use krishiv_connectors::kafka::{KafkaConfig, KafkaSource};
199
200 let kafka_config = KafkaConfig::from_config(config).map_err(connector_error)?;
201 let schema: SchemaRef = cmd.schema.as_ref().inner().clone();
202 let source = KafkaSource::new(kafka_config).map_err(connector_error)?;
203 let partition = Arc::new(KafkaPartitionStream::new(schema.clone(), source));
204 let table = StreamingTable::try_new(schema, vec![partition])?;
205
206 if let Some(streaming_sources) = streaming_sources {
207 let table_name = cmd.name.table().to_string();
208 streaming_sources
209 .write()
210 .unwrap_or_else(|e| e.into_inner())
211 .insert(table_name);
212 }
213
214 Ok(Arc::new(table))
215}
216
217struct BoundedConnectorProvider {
219 registry: Arc<ConnectorRegistry>,
220 config: ConnectorConfig,
221 schema: SchemaRef,
222}
223
224impl std::fmt::Debug for BoundedConnectorProvider {
225 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226 f.debug_struct("BoundedConnectorProvider")
227 .field("config", &self.config)
228 .finish_non_exhaustive()
229 }
230}
231
232#[async_trait]
233impl TableProvider for BoundedConnectorProvider {
234 fn as_any(&self) -> &dyn Any {
235 self
236 }
237
238 fn schema(&self) -> SchemaRef {
239 Arc::clone(&self.schema)
240 }
241
242 fn table_type(&self) -> datafusion::logical_expr::TableType {
243 datafusion::logical_expr::TableType::Base
244 }
245
246 fn statistics(&self) -> Option<datafusion::physical_plan::Statistics> {
247 use datafusion::common::stats::Precision;
248 use datafusion::physical_plan::Statistics;
249 let row_count = self.registry.estimated_row_count(&self.config)?;
250 Some(Statistics {
251 num_rows: Precision::Inexact(row_count as usize),
252 ..Statistics::new_unknown(&self.schema)
253 })
254 }
255
256 async fn scan(
257 &self,
258 state: &dyn datafusion::catalog::Session,
259 projection: Option<&Vec<usize>>,
260 filters: &[datafusion::logical_expr::Expr],
261 limit: Option<usize>,
262 ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
263 let mut source = self
264 .registry
265 .open_source(&self.config)
266 .await
267 .map_err(connector_error)?;
268
269 let projection_columns: Option<Vec<String>> = projection.map(|idxs| {
288 idxs.iter()
289 .map(|&i| self.schema.field(i).name().clone())
290 .collect()
291 });
292 let mut batches: Vec<RecordBatch> = Vec::new();
293 let mut rows_accumulated: usize = 0;
294 let limit_threshold: Option<usize> = limit;
295 loop {
296 let batch = source.read_batch_dyn().await.map_err(connector_error)?;
297 let Some(batch) = batch else { break };
298 let batch = project_batch(&batch, &self.schema)
299 .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
300 let batch = match &projection_columns {
302 Some(cols) => project_to_columns(&batch, cols)
303 .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?,
304 None => batch,
305 };
306 if batch.num_rows() == 0 {
307 continue;
308 }
309 let batch = match limit_threshold {
311 Some(threshold) if rows_accumulated + batch.num_rows() > threshold => {
312 let take = threshold.saturating_sub(rows_accumulated);
313 batch.slice(0, take)
314 }
315 _ => batch,
316 };
317 rows_accumulated += batch.num_rows();
318 batches.push(batch);
319 if let Some(threshold) = limit_threshold
320 && rows_accumulated >= threshold
321 {
322 break;
323 }
324 }
325
326 let table = MemTable::try_new(Arc::clone(&self.schema), vec![batches])?;
327 table.scan(state, projection, filters, limit).await
328 }
329}
330
331fn project_to_columns(
333 batch: &RecordBatch,
334 columns: &[String],
335) -> arrow::error::Result<RecordBatch> {
336 if columns.is_empty() {
337 return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
338 }
339 let mut cols = Vec::with_capacity(columns.len());
340 let mut fields = Vec::with_capacity(columns.len());
341 for name in columns {
342 let idx = batch.schema().index_of(name)?;
343 cols.push(batch.column(idx).clone());
344 fields.push(batch.schema().field(idx).clone());
345 }
346 RecordBatch::try_new(Arc::new(Schema::new(fields)), cols)
347}
348
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};

    use super::*;

    /// A provider whose config carries no properties, backed by a registry
    /// built with `ConnectorRegistry::new()`, reports no statistics.
    #[test]
    fn bounded_connector_provider_statistics_returns_none_for_unknown_table() {
        let provider = BoundedConnectorProvider {
            registry: Arc::new(krishiv_connectors::ConnectorRegistry::new()),
            config: krishiv_connectors::ConnectorConfig::new("unknown", "parquet"),
            schema: Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])),
        };
        let stats = provider.statistics();
        assert!(
            stats.is_none(),
            "no path in config → estimated_row_count returns None → statistics returns None"
        );
    }

    /// The table name is extracted from `CREATE [OR REPLACE] EXTERNAL TABLE`
    /// statements, and other statements yield `None`.
    #[test]
    fn extract_create_external_table_name_parses_table_name() {
        use super::super::extract_create_external_table_name as table_name_of;

        let plain = table_name_of(
            "CREATE EXTERNAL TABLE my_table STORED AS PARQUET LOCATION 'data.parquet'",
        );
        assert_eq!(plain, Some("my_table".to_string()));

        let not_ddl = table_name_of("SELECT * FROM foo");
        assert_eq!(not_ddl, None);

        let or_replace = table_name_of(
            "CREATE OR REPLACE EXTERNAL TABLE orders STORED AS PARQUET LOCATION 'orders.parquet'",
        );
        assert_eq!(or_replace, Some("orders".to_string()));
    }

    /// Projection follows the requested column order, and an empty column list
    /// produces a batch with no columns.
    #[test]
    fn project_to_columns_preserves_order_and_handles_empty() {
        use arrow::array::Int64Array;

        let fields = vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Int64, false),
            Field::new("c", DataType::Int64, false),
        ];
        let schema = Arc::new(Schema::new(fields));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2])) as _,
                Arc::new(Int64Array::from(vec![3, 4])) as _,
                Arc::new(Int64Array::from(vec![5, 6])) as _,
            ],
        )
        .unwrap();

        let wanted = [String::from("c"), String::from("a")];
        let projected = project_to_columns(&batch, &wanted).expect("project must succeed");
        assert_eq!(projected.num_columns(), 2);
        assert_eq!(projected.schema().field(0).name(), "c");
        assert_eq!(projected.schema().field(1).name(), "a");

        let no_op = project_to_columns(&batch, &[]).expect("no-op projection must succeed");
        assert_eq!(no_op.num_columns(), 0);
    }
}