datafusion_remote_table/
transform.rs

1use crate::{DFResult, RemoteDbType, RemoteSchemaRef};
2use datafusion::arrow::array::RecordBatch;
3use datafusion::arrow::datatypes::SchemaRef;
4use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion};
5use datafusion::common::{DataFusionError, project_schema};
6use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream};
7use datafusion::logical_expr::TableProviderFilterPushDown;
8use datafusion::prelude::Expr;
9use futures::{Stream, StreamExt};
10use std::any::Any;
11use std::fmt::Debug;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::task::{Context, Poll};
15
16pub struct TransformArgs<'a> {
17    pub db_type: RemoteDbType,
18    pub table_schema: &'a SchemaRef,
19    pub remote_schema: &'a RemoteSchemaRef,
20}
21
22pub trait Transform: Debug + Send + Sync {
23    fn as_any(&self) -> &dyn Any;
24
25    fn transform(&self, batch: RecordBatch, args: TransformArgs) -> DFResult<RecordBatch>;
26
27    fn support_filter_pushdown(
28        &self,
29        filter: &Expr,
30        args: TransformArgs,
31    ) -> DFResult<TableProviderFilterPushDown>;
32
33    fn unparse_filter(&self, filter: &Expr, args: TransformArgs) -> DFResult<String>;
34}
35
36#[derive(Debug)]
37pub struct DefaultTransform {}
38
39impl Transform for DefaultTransform {
40    fn as_any(&self) -> &dyn Any {
41        self
42    }
43
44    fn transform(&self, batch: RecordBatch, _args: TransformArgs) -> DFResult<RecordBatch> {
45        Ok(batch)
46    }
47
48    fn support_filter_pushdown(
49        &self,
50        filter: &Expr,
51        args: TransformArgs,
52    ) -> DFResult<TableProviderFilterPushDown> {
53        let unparser = match args.db_type.create_unparser() {
54            Ok(unparser) => unparser,
55            Err(_) => return Ok(TableProviderFilterPushDown::Unsupported),
56        };
57        if unparser.expr_to_sql(filter).is_err() {
58            return Ok(TableProviderFilterPushDown::Unsupported);
59        }
60
61        let mut pushdown = TableProviderFilterPushDown::Exact;
62        filter
63            .apply(|e| {
64                if matches!(e, Expr::ScalarFunction(_)) {
65                    pushdown = TableProviderFilterPushDown::Unsupported;
66                }
67                Ok(TreeNodeRecursion::Continue)
68            })
69            .expect("won't fail");
70
71        Ok(pushdown)
72    }
73
74    fn unparse_filter(&self, filter: &Expr, args: TransformArgs) -> DFResult<String> {
75        let unparser = args.db_type.create_unparser()?;
76        let ast = unparser.expr_to_sql(filter)?;
77        Ok(format!("{ast}"))
78    }
79}
80
81pub(crate) struct TransformStream {
82    input: SendableRecordBatchStream,
83    transform: Arc<dyn Transform>,
84    table_schema: SchemaRef,
85    projection: Option<Vec<usize>>,
86    projected_transformed_schema: SchemaRef,
87    remote_schema: RemoteSchemaRef,
88    db_type: RemoteDbType,
89}
90
91impl TransformStream {
92    pub fn try_new(
93        input: SendableRecordBatchStream,
94        transform: Arc<dyn Transform>,
95        table_schema: SchemaRef,
96        projection: Option<Vec<usize>>,
97        remote_schema: RemoteSchemaRef,
98        db_type: RemoteDbType,
99    ) -> DFResult<Self> {
100        let input_schema = input.schema();
101        if input.schema() != table_schema {
102            return Err(DataFusionError::Execution(format!(
103                "Transform stream input schema is not equals to table schema, input schema: {input_schema:?}, table schema: {table_schema:?}"
104            )));
105        }
106        let transformed_table_schema = transform_schema(
107            transform.as_ref(),
108            table_schema.clone(),
109            Some(&remote_schema),
110            db_type,
111        )?;
112        let projected_transformed_schema =
113            project_schema(&transformed_table_schema, projection.as_ref())?;
114        Ok(Self {
115            input,
116            transform,
117            table_schema,
118            projection,
119            projected_transformed_schema,
120            remote_schema,
121            db_type,
122        })
123    }
124}
125
126impl Stream for TransformStream {
127    type Item = DFResult<RecordBatch>;
128    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
129        match self.input.poll_next_unpin(cx) {
130            Poll::Ready(Some(Ok(batch))) => {
131                let args = TransformArgs {
132                    db_type: self.db_type,
133                    table_schema: &self.table_schema,
134                    remote_schema: &self.remote_schema,
135                };
136                match self.transform.transform(batch, args) {
137                    Ok(transformed_batch) => {
138                        let projected_batch = if let Some(projection) = &self.projection {
139                            match transformed_batch.project(projection) {
140                                Ok(batch) => batch,
141                                Err(e) => return Poll::Ready(Some(Err(DataFusionError::from(e)))),
142                            }
143                        } else {
144                            transformed_batch
145                        };
146                        Poll::Ready(Some(Ok(projected_batch)))
147                    }
148                    Err(e) => Poll::Ready(Some(Err(e))),
149                }
150            }
151            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
152            Poll::Ready(None) => Poll::Ready(None),
153            Poll::Pending => Poll::Pending,
154        }
155    }
156}
157
158impl RecordBatchStream for TransformStream {
159    fn schema(&self) -> SchemaRef {
160        self.projected_transformed_schema.clone()
161    }
162}
163
164pub fn transform_schema(
165    transform: &dyn Transform,
166    schema: SchemaRef,
167    remote_schema: Option<&RemoteSchemaRef>,
168    db_type: RemoteDbType,
169) -> DFResult<SchemaRef> {
170    if transform.as_any().is::<DefaultTransform>() {
171        Ok(schema)
172    } else {
173        let Some(remote_schema) = remote_schema else {
174            return Err(DataFusionError::Execution(
175                "remote_schema is required for non-default transform".to_string(),
176            ));
177        };
178        let empty_batch = RecordBatch::new_empty(schema.clone());
179        let args = TransformArgs {
180            db_type,
181            table_schema: &schema,
182            remote_schema,
183        };
184        let transformed_batch = transform.transform(empty_batch, args)?;
185        Ok(transformed_batch.schema())
186    }
187}