datafusion_remote_table/
transform.rs1use crate::{DFResult, RemoteDbType, RemoteSchemaRef};
2use datafusion::arrow::array::RecordBatch;
3use datafusion::arrow::datatypes::SchemaRef;
4use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion};
5use datafusion::common::{DataFusionError, project_schema};
6use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream};
7use datafusion::logical_expr::TableProviderFilterPushDown;
8use datafusion::prelude::Expr;
9use futures::{Stream, StreamExt};
10use std::any::Any;
11use std::fmt::Debug;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::task::{Context, Poll};
15
16pub struct TransformArgs<'a> {
17 pub db_type: RemoteDbType,
18 pub table_schema: &'a SchemaRef,
19 pub remote_schema: &'a RemoteSchemaRef,
20}
21
22pub trait Transform: Debug + Send + Sync {
23 fn as_any(&self) -> &dyn Any;
24
25 fn transform(&self, batch: RecordBatch, args: TransformArgs) -> DFResult<RecordBatch>;
26
27 fn support_filter_pushdown(
28 &self,
29 filter: &Expr,
30 args: TransformArgs,
31 ) -> DFResult<TableProviderFilterPushDown>;
32
33 fn unparse_filter(&self, filter: &Expr, args: TransformArgs) -> DFResult<String>;
34}
35
/// The pass-through [`Transform`]: batches are returned unchanged.
///
/// `transform_schema` recognises this type through `as_any` and skips the
/// schema probe for it.
#[derive(Debug)]
pub struct DefaultTransform {}
38
39impl Transform for DefaultTransform {
40 fn as_any(&self) -> &dyn Any {
41 self
42 }
43
44 fn transform(&self, batch: RecordBatch, _args: TransformArgs) -> DFResult<RecordBatch> {
45 Ok(batch)
46 }
47
48 fn support_filter_pushdown(
49 &self,
50 filter: &Expr,
51 args: TransformArgs,
52 ) -> DFResult<TableProviderFilterPushDown> {
53 let unparser = match args.db_type.create_unparser() {
54 Ok(unparser) => unparser,
55 Err(_) => return Ok(TableProviderFilterPushDown::Unsupported),
56 };
57 if unparser.expr_to_sql(filter).is_err() {
58 return Ok(TableProviderFilterPushDown::Unsupported);
59 }
60
61 let mut pushdown = TableProviderFilterPushDown::Exact;
62 filter
63 .apply(|e| {
64 if matches!(e, Expr::ScalarFunction(_)) {
65 pushdown = TableProviderFilterPushDown::Unsupported;
66 }
67 Ok(TreeNodeRecursion::Continue)
68 })
69 .expect("won't fail");
70
71 Ok(pushdown)
72 }
73
74 fn unparse_filter(&self, filter: &Expr, args: TransformArgs) -> DFResult<String> {
75 let unparser = args.db_type.create_unparser()?;
76 let ast = unparser.expr_to_sql(filter)?;
77 Ok(format!("{ast}"))
78 }
79}
80
81pub(crate) struct TransformStream {
82 input: SendableRecordBatchStream,
83 transform: Arc<dyn Transform>,
84 table_schema: SchemaRef,
85 projection: Option<Vec<usize>>,
86 projected_transformed_schema: SchemaRef,
87 remote_schema: RemoteSchemaRef,
88 db_type: RemoteDbType,
89}
90
91impl TransformStream {
92 pub fn try_new(
93 input: SendableRecordBatchStream,
94 transform: Arc<dyn Transform>,
95 table_schema: SchemaRef,
96 projection: Option<Vec<usize>>,
97 remote_schema: RemoteSchemaRef,
98 db_type: RemoteDbType,
99 ) -> DFResult<Self> {
100 let input_schema = input.schema();
101 if input.schema() != table_schema {
102 return Err(DataFusionError::Execution(format!(
103 "Transform stream input schema is not equals to table schema, input schema: {input_schema:?}, table schema: {table_schema:?}"
104 )));
105 }
106 let transformed_table_schema = transform_schema(
107 transform.as_ref(),
108 table_schema.clone(),
109 Some(&remote_schema),
110 db_type,
111 )?;
112 let projected_transformed_schema =
113 project_schema(&transformed_table_schema, projection.as_ref())?;
114 Ok(Self {
115 input,
116 transform,
117 table_schema,
118 projection,
119 projected_transformed_schema,
120 remote_schema,
121 db_type,
122 })
123 }
124}
125
126impl Stream for TransformStream {
127 type Item = DFResult<RecordBatch>;
128 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
129 match self.input.poll_next_unpin(cx) {
130 Poll::Ready(Some(Ok(batch))) => {
131 let args = TransformArgs {
132 db_type: self.db_type,
133 table_schema: &self.table_schema,
134 remote_schema: &self.remote_schema,
135 };
136 match self.transform.transform(batch, args) {
137 Ok(transformed_batch) => {
138 let projected_batch = if let Some(projection) = &self.projection {
139 match transformed_batch.project(projection) {
140 Ok(batch) => batch,
141 Err(e) => return Poll::Ready(Some(Err(DataFusionError::from(e)))),
142 }
143 } else {
144 transformed_batch
145 };
146 Poll::Ready(Some(Ok(projected_batch)))
147 }
148 Err(e) => Poll::Ready(Some(Err(e))),
149 }
150 }
151 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
152 Poll::Ready(None) => Poll::Ready(None),
153 Poll::Pending => Poll::Pending,
154 }
155 }
156}
157
158impl RecordBatchStream for TransformStream {
159 fn schema(&self) -> SchemaRef {
160 self.projected_transformed_schema.clone()
161 }
162}
163
164pub fn transform_schema(
165 transform: &dyn Transform,
166 schema: SchemaRef,
167 remote_schema: Option<&RemoteSchemaRef>,
168 db_type: RemoteDbType,
169) -> DFResult<SchemaRef> {
170 if transform.as_any().is::<DefaultTransform>() {
171 Ok(schema)
172 } else {
173 let Some(remote_schema) = remote_schema else {
174 return Err(DataFusionError::Execution(
175 "remote_schema is required for non-default transform".to_string(),
176 ));
177 };
178 let empty_batch = RecordBatch::new_empty(schema.clone());
179 let args = TransformArgs {
180 db_type,
181 table_schema: &schema,
182 remote_schema,
183 };
184 let transformed_batch = transform.transform(empty_batch, args)?;
185 Ok(transformed_batch.schema())
186 }
187}