use std::any::Any;
use std::sync::Arc;

use async_trait::async_trait;
use datafusion::arrow::datatypes::{Field, Schema, SchemaRef as ArrowSchemaRef};
use datafusion::catalog::Session;
use datafusion::datasource::sink::DataSinkExec;
use datafusion::datasource::{TableProvider, TableType};
use datafusion::error::Result as DFResult;
use datafusion::logical_expr::dml::InsertOp;
use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
use datafusion::physical_plan::ExecutionPlan;
use paimon::table::Table;

use crate::error::to_datafusion_error;
use crate::filter_pushdown::{build_pushed_predicate, classify_filter_pushdown};
use crate::physical_plan::{PaimonDataSink, PaimonTableScan};
use crate::runtime::await_with_runtime;

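/// A DataFusion [`TableProvider`] backed by a Paimon [`Table`].
///
/// Registering the provider with a `SessionContext` makes the table queryable
/// and writable through SQL, as the tests below exercise:
///
/// ```ignore
/// let provider = PaimonTableProvider::try_new(table)?;
/// ctx.register_table("t", Arc::new(provider))?;
/// ctx.sql("SELECT * FROM t").await?;
/// ```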
#[derive(Debug, Clone)]
pub struct PaimonTableProvider {
    table: Table,
    schema: ArrowSchemaRef,
}

impl PaimonTableProvider {
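    /// Creates a provider for `table`, deriving an Arrow schema from the
    /// Paimon table schema.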
    pub fn try_new(table: Table) -> DFResult<Self> {
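        // With data evolution enabled, readers expose an extra nullable
        // BigInt row-id column, so append it before building the Arrow schema.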
        let mut fields = table.schema().fields().to_vec();
        let core_options = paimon::spec::CoreOptions::new(table.schema().options());
        if core_options.data_evolution_enabled() {
            fields.push(paimon::spec::DataField::new(
                paimon::spec::ROW_ID_FIELD_ID,
                paimon::spec::ROW_ID_FIELD_NAME.to_string(),
                paimon::spec::DataType::BigInt(paimon::spec::BigIntType::with_nullable(true)),
            ));
        }
        let schema =
            paimon::arrow::build_target_arrow_schema(&fields).map_err(to_datafusion_error)?;
        Ok(Self { table, schema })
    }

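    /// Returns the underlying Paimon table.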
    pub fn table(&self) -> &Table {
        &self.table
    }
}

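/// Distributes `items` across `num_buckets` buckets in round-robin order,
/// e.g. five items over three buckets yield `[[0, 3], [1, 4], [2]]`.
/// `num_buckets` must be at least 1.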
pub(crate) fn bucket_round_robin<T>(items: Vec<T>, num_buckets: usize) -> Vec<Vec<T>> {
    let mut buckets: Vec<Vec<T>> = (0..num_buckets).map(|_| Vec::new()).collect();
    for (i, item) in items.into_iter().enumerate() {
        buckets[i % num_buckets].push(item);
    }
    buckets
}

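/// Builds a [`PaimonTableScan`] from a planned Paimon scan, applying the
/// projection and spreading the plan's splits over at most
/// `target_partitions` output partitions.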
pub(crate) fn build_paimon_scan(
    table: &Table,
    schema: &ArrowSchemaRef,
    plan: &paimon::table::Plan,
    projection: Option<&Vec<usize>>,
    pushed_predicate: Option<paimon::spec::Predicate>,
    limit: Option<usize>,
    target_partitions: usize,
) -> DFResult<Arc<dyn ExecutionPlan>> {
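    // Resolve the projected Arrow schema plus the column names to request from
    // the Paimon reader; with no projection, all columns are read.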
    let (projected_schema, projected_columns) = if let Some(indices) = projection {
        let fields: Vec<Field> = indices.iter().map(|&i| schema.field(i).clone()).collect();
        let column_names: Vec<String> = fields.iter().map(|f| f.name().clone()).collect();
        (Arc::new(Schema::new(fields)), Some(column_names))
    } else {
        let column_names: Vec<String> = schema.fields().iter().map(|f| f.name().clone()).collect();
        (schema.clone(), Some(column_names))
    };

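    // Always produce at least one (possibly empty) partition so the plan has a
    // valid partitioning, and never create more partitions than splits.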
    let splits = plan.splits().to_vec();
    let planned_partitions: Vec<Arc<[_]>> = if splits.is_empty() {
        vec![Arc::from(Vec::new())]
    } else {
        let num_partitions = splits.len().min(target_partitions.max(1));
        bucket_round_robin(splits, num_partitions)
            .into_iter()
            .map(Arc::from)
            .collect()
    };

    Ok(Arc::new(PaimonTableScan::new(
        projected_schema,
        table.clone(),
        projected_columns,
        pushed_predicate,
        planned_partitions,
        limit,
    )))
}

#[async_trait]
impl TableProvider for PaimonTableProvider {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> ArrowSchemaRef {
        self.schema.clone()
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    async fn scan(
        &self,
        state: &dyn Session,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        let pushed_predicate = build_pushed_predicate(filters, self.table.schema().fields());
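        // Hand the translated Paimon predicate to the read builder so splits
        // can be pruned during planning.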
        let mut read_builder = self.table.new_read_builder();
        if let Some(filter) = pushed_predicate.clone() {
            read_builder.with_filter(filter);
        }
        if let Some(limit) = limit {
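            // The limit is a pushdown hint; DataFusion still enforces the
            // exact limit above this scan.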
            read_builder.with_limit(limit);
        }
        let scan = read_builder.new_scan();
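        // `await_with_runtime` drives the Paimon planning future on the
        // crate's helper runtime (see `crate::runtime`).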
        let plan = await_with_runtime(scan.plan())
            .await
            .map_err(to_datafusion_error)?;

        let target = state.config_options().execution.target_partitions;
        build_paimon_scan(
            &self.table,
            &self.schema,
            &plan,
            projection,
            pushed_predicate,
            limit,
            target,
        )
    }

    async fn insert_into(
        &self,
        _state: &dyn Session,
        input: Arc<dyn ExecutionPlan>,
        insert_op: InsertOp,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
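        // Map the DataFusion insert operation onto Paimon's overwrite flag;
        // anything other than plain append or overwrite is rejected.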
        let overwrite = match insert_op {
            InsertOp::Append => false,
            InsertOp::Overwrite => true,
            other => {
                return Err(datafusion::error::DataFusionError::NotImplemented(format!(
                    "{other} is not supported for Paimon tables"
                )));
            }
        };
        let sink = PaimonDataSink::new(self.table.clone(), self.schema.clone(), overwrite);
        Ok(Arc::new(DataSinkExec::new(input, Arc::new(sink), None)))
    }

    fn supports_filters_pushdown(
        &self,
        filters: &[&Expr],
    ) -> DFResult<Vec<TableProviderFilterPushDown>> {
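        // Classify each filter individually against the table's fields and
        // partition keys.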
        let fields = self.table.schema().fields();
        let partition_keys = self.table.schema().partition_keys();

        Ok(filters
            .iter()
            .map(|filter| classify_filter_pushdown(filter, fields, partition_keys))
            .collect())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeSet;
    use std::sync::Arc;

    use datafusion::datasource::TableProvider;
    use datafusion::logical_expr::{col, lit, Expr};
    use datafusion::prelude::{SessionConfig, SessionContext};
    use paimon::catalog::Identifier;
    use paimon::{Catalog, CatalogOptions, DataSplit, FileSystemCatalog, Options};

    use crate::physical_plan::PaimonTableScan;

    #[test]
    fn test_bucket_round_robin_distributes_evenly() {
        let result = bucket_round_robin(vec![0, 1, 2, 3, 4], 3);
        assert_eq!(result, vec![vec![0, 3], vec![1, 4], vec![2]]);
    }

    #[test]
    fn test_bucket_round_robin_fewer_items_than_buckets() {
        let result = bucket_round_robin(vec![10, 20], 3);
        assert_eq!(result, vec![vec![10], vec![20], vec![]]);
    }

    #[test]
    fn test_bucket_round_robin_single_bucket() {
        let result = bucket_round_robin(vec![1, 2, 3], 1);
        assert_eq!(result, vec![vec![1, 2, 3]]);
    }

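    /// Resolves the warehouse path for the catalog-backed tests, honoring the
    /// `PAIMON_TEST_WAREHOUSE` environment variable when set.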
    fn get_test_warehouse() -> String {
        std::env::var("PAIMON_TEST_WAREHOUSE")
            .unwrap_or_else(|_| "/tmp/paimon-warehouse".to_string())
    }

    fn create_catalog() -> FileSystemCatalog {
        let warehouse = get_test_warehouse();
        let mut options = Options::new();
        options.set(CatalogOptions::WAREHOUSE, warehouse);
        FileSystemCatalog::new(options).expect("Failed to create catalog")
    }

    async fn create_provider(table_name: &str) -> PaimonTableProvider {
        let catalog = create_catalog();
        let identifier = Identifier::new("default", table_name);
        let table = catalog
            .get_table(&identifier)
            .await
            .expect("Failed to get table");

        PaimonTableProvider::try_new(table).expect("Failed to create table provider")
    }

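    /// Runs `scan()` with the given filters and returns the planned split
    /// partitions of the resulting `PaimonTableScan`.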
    async fn plan_partitions(
        provider: &PaimonTableProvider,
        filters: Vec<Expr>,
    ) -> Vec<Arc<[DataSplit]>> {
        let config = SessionConfig::new().with_target_partitions(8);
        let ctx = SessionContext::new_with_config(config);
        let state = ctx.state();
        let plan = provider
            .scan(&state, None, &filters, None)
            .await
            .expect("scan() should succeed");
        let scan = plan
            .as_any()
            .downcast_ref::<PaimonTableScan>()
            .expect("Expected PaimonTableScan");

        scan.planned_partitions().to_vec()
    }

    fn extract_dt_partition_set(planned_partitions: &[Arc<[DataSplit]>]) -> BTreeSet<String> {
        planned_partitions
            .iter()
            .flat_map(|splits| splits.iter())
            .map(|split| {
                split
                    .partition()
                    .get_string(0)
                    .expect("Failed to decode dt")
                    .to_string()
            })
            .collect()
    }

    fn extract_dt_hr_partition_set(
        planned_partitions: &[Arc<[DataSplit]>],
    ) -> BTreeSet<(String, i32)> {
        planned_partitions
            .iter()
            .flat_map(|splits| splits.iter())
            .map(|split| {
                let partition = split.partition();
                (
                    partition
                        .get_string(0)
                        .expect("Failed to decode dt")
                        .to_string(),
                    partition.get_int(1).expect("Failed to decode hr"),
                )
            })
            .collect()
    }

    #[tokio::test]
    async fn test_scan_partition_filter_plans_matching_partition_set() {
        let provider = create_provider("partitioned_log_table").await;
        let planned_partitions =
            plan_partitions(&provider, vec![col("dt").eq(lit("2024-01-01"))]).await;

        assert_eq!(
            extract_dt_partition_set(&planned_partitions),
            BTreeSet::from(["2024-01-01".to_string()]),
        );
    }

    #[tokio::test]
    async fn test_scan_mixed_and_filter_keeps_partition_pruning() {
        let provider = create_provider("partitioned_log_table").await;
        let planned_partitions = plan_partitions(
            &provider,
            vec![col("dt").eq(lit("2024-01-01")).and(col("id").gt(lit(1)))],
        )
        .await;

        assert_eq!(
            extract_dt_partition_set(&planned_partitions),
            BTreeSet::from(["2024-01-01".to_string()]),
        );
    }

    #[tokio::test]
    async fn test_scan_multi_partition_filter_plans_exact_partition_set() {
        let provider = create_provider("multi_partitioned_log_table").await;

        let dt_only_partitions =
            plan_partitions(&provider, vec![col("dt").eq(lit("2024-01-01"))]).await;
        let dt_hr_partitions = plan_partitions(
            &provider,
            vec![col("dt").eq(lit("2024-01-01")).and(col("hr").eq(lit(10)))],
        )
        .await;

        assert_eq!(
            extract_dt_hr_partition_set(&dt_only_partitions),
            BTreeSet::from([
                ("2024-01-01".to_string(), 10),
                ("2024-01-01".to_string(), 20),
            ]),
        );
        assert_eq!(
            extract_dt_hr_partition_set(&dt_hr_partitions),
            BTreeSet::from([("2024-01-01".to_string(), 10)]),
        );
    }

    #[tokio::test]
    async fn test_scan_keeps_pushed_predicate_for_execute() {
        let provider = create_provider("partitioned_log_table").await;
        let filter = col("id").gt(lit(1));

        let config = SessionConfig::new().with_target_partitions(8);
        let ctx = SessionContext::new_with_config(config);
        let state = ctx.state();
        let plan = provider
            .scan(&state, None, std::slice::from_ref(&filter), None)
            .await
            .expect("scan() should succeed");
        let scan = plan
            .as_any()
            .downcast_ref::<PaimonTableScan>()
            .expect("Expected PaimonTableScan");

        let expected = build_pushed_predicate(&[filter], provider.table().schema().fields())
            .expect("data filter should translate");

        assert_eq!(scan.pushed_predicate(), Some(&expected));
    }

    #[tokio::test]
    async fn test_insert_into_and_read_back() {
        use paimon::io::FileIOBuilder;
        use paimon::spec::{DataType, IntType, Schema as PaimonSchema, TableSchema};

        let file_io = FileIOBuilder::new("memory").build().unwrap();
        let table_path = "memory:/test_df_insert_into";
        file_io
            .mkdirs(&format!("{table_path}/snapshot/"))
            .await
            .unwrap();
        file_io
            .mkdirs(&format!("{table_path}/manifest/"))
            .await
            .unwrap();

        let schema = PaimonSchema::builder()
            .column("id", DataType::Int(IntType::new()))
            .column("value", DataType::Int(IntType::new()))
            .build()
            .unwrap();
        let table_schema = TableSchema::new(0, &schema);
        let table = paimon::table::Table::new(
            file_io,
            Identifier::new("default", "test_insert"),
            table_path.to_string(),
            table_schema,
            None,
        );

        let provider = PaimonTableProvider::try_new(table).unwrap();
        let ctx = SessionContext::new();
        ctx.register_table("t", Arc::new(provider)).unwrap();

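        // Run the INSERT; collect() drives the sink plan to completion.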
        let result = ctx
            .sql("INSERT INTO t VALUES (1, 10), (2, 20), (3, 30)")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

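        // The INSERT result reports the affected row count as a UInt64 column.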
        let count_array = result[0]
            .column(0)
            .as_any()
            .downcast_ref::<datafusion::arrow::array::UInt64Array>()
            .unwrap();
        assert_eq!(count_array.value(0), 3);

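        // Read the rows back to verify the committed data is visible.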
        let batches = ctx
            .sql("SELECT id, value FROM t ORDER BY id")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let mut rows = Vec::new();
        for batch in &batches {
            let ids = batch
                .column(0)
                .as_any()
                .downcast_ref::<datafusion::arrow::array::Int32Array>()
                .unwrap();
            let vals = batch
                .column(1)
                .as_any()
                .downcast_ref::<datafusion::arrow::array::Int32Array>()
                .unwrap();
            for i in 0..batch.num_rows() {
                rows.push((ids.value(i), vals.value(i)));
            }
        }
        assert_eq!(rows, vec![(1, 10), (2, 20), (3, 30)]);
    }

    #[tokio::test]
    async fn test_insert_overwrite() {
        use paimon::io::FileIOBuilder;
        use paimon::spec::{DataType, IntType, Schema as PaimonSchema, TableSchema, VarCharType};

        let file_io = FileIOBuilder::new("memory").build().unwrap();
        let table_path = "memory:/test_df_insert_overwrite";
        file_io
            .mkdirs(&format!("{table_path}/snapshot/"))
            .await
            .unwrap();
        file_io
            .mkdirs(&format!("{table_path}/manifest/"))
            .await
            .unwrap();

        let schema = PaimonSchema::builder()
            .column("pt", DataType::VarChar(VarCharType::string_type()))
            .column("id", DataType::Int(IntType::new()))
            .partition_keys(["pt"])
            .build()
            .unwrap();
        let table_schema = TableSchema::new(0, &schema);
        let table = paimon::table::Table::new(
            file_io,
            Identifier::new("default", "test_overwrite"),
            table_path.to_string(),
            table_schema,
            None,
        );

        let provider = PaimonTableProvider::try_new(table).unwrap();
        let ctx = SessionContext::new();
        ctx.register_table("t", Arc::new(provider)).unwrap();

        ctx.sql("INSERT INTO t VALUES ('a', 1), ('a', 2), ('b', 3), ('b', 4)")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

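        // INSERT OVERWRITE on a partitioned table replaces only the partitions
        // present in the new rows ('a'); partition 'b' is left intact.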
        ctx.sql("INSERT OVERWRITE t VALUES ('a', 10), ('a', 20)")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let batches = ctx
            .sql("SELECT pt, id FROM t ORDER BY pt, id")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let mut rows = Vec::new();
        for batch in &batches {
            let pts = batch
                .column(0)
                .as_any()
                .downcast_ref::<datafusion::arrow::array::StringArray>()
                .unwrap();
            let ids = batch
                .column(1)
                .as_any()
                .downcast_ref::<datafusion::arrow::array::Int32Array>()
                .unwrap();
            for i in 0..batch.num_rows() {
                rows.push((pts.value(i).to_string(), ids.value(i)));
            }
        }
        assert_eq!(
            rows,
            vec![
                ("a".to_string(), 10),
                ("a".to_string(), 20),
                ("b".to_string(), 3),
                ("b".to_string(), 4),
            ]
        );
    }

    #[tokio::test]
    async fn test_insert_overwrite_unpartitioned() {
        use paimon::io::FileIOBuilder;
        use paimon::spec::{DataType, IntType, Schema as PaimonSchema, TableSchema};

        let file_io = FileIOBuilder::new("memory").build().unwrap();
        let table_path = "memory:/test_df_insert_overwrite_unpart";
        file_io
            .mkdirs(&format!("{table_path}/snapshot/"))
            .await
            .unwrap();
        file_io
            .mkdirs(&format!("{table_path}/manifest/"))
            .await
            .unwrap();

        let schema = PaimonSchema::builder()
            .column("id", DataType::Int(IntType::new()))
            .column("value", DataType::Int(IntType::new()))
            .build()
            .unwrap();
        let table_schema = TableSchema::new(0, &schema);
        let table = paimon::table::Table::new(
            file_io,
            Identifier::new("default", "test_overwrite_unpart"),
            table_path.to_string(),
            table_schema,
            None,
        );

        let provider = PaimonTableProvider::try_new(table).unwrap();
        let ctx = SessionContext::new();
        ctx.register_table("t", Arc::new(provider)).unwrap();

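        // Seed the table with initial rows before overwriting.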
        ctx.sql("INSERT INTO t VALUES (1, 10), (2, 20), (3, 30)")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

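        // Without partition columns, INSERT OVERWRITE replaces the whole table.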
        ctx.sql("INSERT OVERWRITE t VALUES (4, 40), (5, 50)")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let batches = ctx
            .sql("SELECT id, value FROM t ORDER BY id")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        let mut rows = Vec::new();
        for batch in &batches {
            let ids = batch
                .column(0)
                .as_any()
                .downcast_ref::<datafusion::arrow::array::Int32Array>()
                .unwrap();
            let vals = batch
                .column(1)
                .as_any()
                .downcast_ref::<datafusion::arrow::array::Int32Array>()
                .unwrap();
            for i in 0..batch.num_rows() {
                rows.push((ids.value(i), vals.value(i)));
            }
        }
        assert_eq!(rows, vec![(4, 40), (5, 50)]);
    }
}