datafusion_remote_table/
transform.rs1use crate::{DFResult, RemoteDbType, RemoteSchemaRef};
2use arrow::array::RecordBatch;
3use arrow::datatypes::SchemaRef;
4use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
5use datafusion_common::{DataFusionError, project_schema};
6use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream};
7use datafusion_expr::Expr;
8use datafusion_expr::TableProviderFilterPushDown;
9use futures::{Stream, StreamExt};
10use std::any::Any;
11use std::fmt::Debug;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::task::{Context, Poll};
15
16#[derive(Debug, Clone)]
17pub struct TransformArgs<'a> {
18 pub db_type: RemoteDbType,
19 pub table_schema: &'a SchemaRef,
20 pub remote_schema: &'a RemoteSchemaRef,
21}
22
23pub trait Transform: Debug + Send + Sync {
24 fn as_any(&self) -> &dyn Any;
25
26 fn transform(&self, batch: RecordBatch, args: TransformArgs) -> DFResult<RecordBatch>;
27
28 fn support_filter_pushdown(
29 &self,
30 filter: &Expr,
31 args: TransformArgs,
32 ) -> DFResult<TableProviderFilterPushDown>;
33
34 fn unparse_filter(&self, filter: &Expr, args: TransformArgs) -> DFResult<String>;
35}
36
/// Pass-through [`Transform`]: batches are returned unchanged.
///
/// `transform_schema` recognises this type and returns the table schema
/// directly, without probing the transform with an empty batch.
///
/// The type has no fields, so it derives `Clone`, `Copy` and `Default`. Callers
/// can write `DefaultTransform::default()` as well as the `DefaultTransform {}`
/// literal.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultTransform {}
39
40impl Transform for DefaultTransform {
41 fn as_any(&self) -> &dyn Any {
42 self
43 }
44
45 fn transform(&self, batch: RecordBatch, _args: TransformArgs) -> DFResult<RecordBatch> {
46 Ok(batch)
47 }
48
49 fn support_filter_pushdown(
50 &self,
51 filter: &Expr,
52 args: TransformArgs,
53 ) -> DFResult<TableProviderFilterPushDown> {
54 let unparser = match args.db_type.create_unparser() {
55 Ok(unparser) => unparser,
56 Err(_) => return Ok(TableProviderFilterPushDown::Unsupported),
57 };
58 if unparser.expr_to_sql(filter).is_err() {
59 return Ok(TableProviderFilterPushDown::Unsupported);
60 }
61
62 let mut pushdown = TableProviderFilterPushDown::Exact;
63 filter
64 .apply(|e| {
65 if matches!(e, Expr::ScalarFunction(_)) {
66 pushdown = TableProviderFilterPushDown::Unsupported;
67 }
68 Ok(TreeNodeRecursion::Continue)
69 })
70 .expect("won't fail");
71
72 Ok(pushdown)
73 }
74
75 fn unparse_filter(&self, filter: &Expr, args: TransformArgs) -> DFResult<String> {
76 let unparser = args.db_type.create_unparser()?;
77 let ast = unparser.expr_to_sql(filter)?;
78 Ok(format!("{ast}"))
79 }
80}
81
82pub(crate) struct TransformStream {
83 input: SendableRecordBatchStream,
84 transform: Arc<dyn Transform>,
85 table_schema: SchemaRef,
86 projection: Option<Vec<usize>>,
87 projected_transformed_schema: SchemaRef,
88 remote_schema: RemoteSchemaRef,
89 db_type: RemoteDbType,
90}
91
92impl TransformStream {
93 pub fn try_new(
94 input: SendableRecordBatchStream,
95 transform: Arc<dyn Transform>,
96 table_schema: SchemaRef,
97 projection: Option<Vec<usize>>,
98 remote_schema: RemoteSchemaRef,
99 db_type: RemoteDbType,
100 ) -> DFResult<Self> {
101 let input_schema = input.schema();
102 if input.schema() != table_schema {
103 return Err(DataFusionError::Execution(format!(
104 "Transform stream input schema is not equals to table schema, input schema: {input_schema:?}, table schema: {table_schema:?}"
105 )));
106 }
107 let transformed_table_schema = transform_schema(
108 transform.as_ref(),
109 table_schema.clone(),
110 Some(&remote_schema),
111 db_type,
112 )?;
113 let projected_transformed_schema =
114 project_schema(&transformed_table_schema, projection.as_ref())?;
115 Ok(Self {
116 input,
117 transform,
118 table_schema,
119 projection,
120 projected_transformed_schema,
121 remote_schema,
122 db_type,
123 })
124 }
125}
126
127impl Stream for TransformStream {
128 type Item = DFResult<RecordBatch>;
129 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
130 match self.input.poll_next_unpin(cx) {
131 Poll::Ready(Some(Ok(batch))) => {
132 let args = TransformArgs {
133 db_type: self.db_type,
134 table_schema: &self.table_schema,
135 remote_schema: &self.remote_schema,
136 };
137 match self.transform.transform(batch, args) {
138 Ok(transformed_batch) => {
139 let projected_batch = if let Some(projection) = &self.projection {
140 match transformed_batch.project(projection) {
141 Ok(batch) => batch,
142 Err(e) => return Poll::Ready(Some(Err(DataFusionError::from(e)))),
143 }
144 } else {
145 transformed_batch
146 };
147 Poll::Ready(Some(Ok(projected_batch)))
148 }
149 Err(e) => Poll::Ready(Some(Err(e))),
150 }
151 }
152 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
153 Poll::Ready(None) => Poll::Ready(None),
154 Poll::Pending => Poll::Pending,
155 }
156 }
157}
158
159impl RecordBatchStream for TransformStream {
160 fn schema(&self) -> SchemaRef {
161 self.projected_transformed_schema.clone()
162 }
163}
164
165pub fn transform_schema(
166 transform: &dyn Transform,
167 schema: SchemaRef,
168 remote_schema: Option<&RemoteSchemaRef>,
169 db_type: RemoteDbType,
170) -> DFResult<SchemaRef> {
171 if transform.as_any().is::<DefaultTransform>() {
172 Ok(schema)
173 } else {
174 let Some(remote_schema) = remote_schema else {
175 return Err(DataFusionError::Execution(
176 "remote_schema is required for non-default transform".to_string(),
177 ));
178 };
179 let empty_batch = RecordBatch::new_empty(schema.clone());
180 let args = TransformArgs {
181 db_type,
182 table_schema: &schema,
183 remote_schema,
184 };
185 let transformed_batch = transform.transform(empty_batch, args)?;
186 Ok(transformed_batch.schema())
187 }
188}