datafusion_remote_table/
transform.rs

1use crate::{DFResult, RemoteDbType, RemoteSchemaRef};
2use arrow::array::RecordBatch;
3use arrow::datatypes::SchemaRef;
4use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
5use datafusion_common::{DataFusionError, project_schema};
6use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream};
7use datafusion_expr::Expr;
8use datafusion_expr::TableProviderFilterPushDown;
9use futures::{Stream, StreamExt};
10use std::any::Any;
11use std::fmt::Debug;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::task::{Context, Poll};
15
16#[derive(Debug, Clone)]
17pub struct TransformArgs<'a> {
18    pub db_type: RemoteDbType,
19    pub table_schema: &'a SchemaRef,
20    pub remote_schema: &'a RemoteSchemaRef,
21}
22
23pub trait Transform: Debug + Send + Sync {
24    fn as_any(&self) -> &dyn Any;
25
26    fn transform(&self, batch: RecordBatch, args: TransformArgs) -> DFResult<RecordBatch>;
27
28    fn support_filter_pushdown(
29        &self,
30        filter: &Expr,
31        args: TransformArgs,
32    ) -> DFResult<TableProviderFilterPushDown>;
33
34    fn unparse_filter(&self, filter: &Expr, args: TransformArgs) -> DFResult<String>;
35}
36
37#[derive(Debug)]
38pub struct DefaultTransform {}
39
40impl Transform for DefaultTransform {
41    fn as_any(&self) -> &dyn Any {
42        self
43    }
44
45    fn transform(&self, batch: RecordBatch, _args: TransformArgs) -> DFResult<RecordBatch> {
46        Ok(batch)
47    }
48
49    fn support_filter_pushdown(
50        &self,
51        filter: &Expr,
52        args: TransformArgs,
53    ) -> DFResult<TableProviderFilterPushDown> {
54        let unparser = match args.db_type.create_unparser() {
55            Ok(unparser) => unparser,
56            Err(_) => return Ok(TableProviderFilterPushDown::Unsupported),
57        };
58        if unparser.expr_to_sql(filter).is_err() {
59            return Ok(TableProviderFilterPushDown::Unsupported);
60        }
61
62        let mut pushdown = TableProviderFilterPushDown::Exact;
63        filter
64            .apply(|e| {
65                if matches!(e, Expr::ScalarFunction(_)) {
66                    pushdown = TableProviderFilterPushDown::Unsupported;
67                }
68                Ok(TreeNodeRecursion::Continue)
69            })
70            .expect("won't fail");
71
72        Ok(pushdown)
73    }
74
75    fn unparse_filter(&self, filter: &Expr, args: TransformArgs) -> DFResult<String> {
76        let unparser = args.db_type.create_unparser()?;
77        let ast = unparser.expr_to_sql(filter)?;
78        Ok(format!("{ast}"))
79    }
80}
81
82pub(crate) struct TransformStream {
83    input: SendableRecordBatchStream,
84    transform: Arc<dyn Transform>,
85    table_schema: SchemaRef,
86    projection: Option<Vec<usize>>,
87    projected_transformed_schema: SchemaRef,
88    remote_schema: RemoteSchemaRef,
89    db_type: RemoteDbType,
90}
91
92impl TransformStream {
93    pub fn try_new(
94        input: SendableRecordBatchStream,
95        transform: Arc<dyn Transform>,
96        table_schema: SchemaRef,
97        projection: Option<Vec<usize>>,
98        remote_schema: RemoteSchemaRef,
99        db_type: RemoteDbType,
100    ) -> DFResult<Self> {
101        let input_schema = input.schema();
102        if input.schema() != table_schema {
103            return Err(DataFusionError::Execution(format!(
104                "Transform stream input schema is not equals to table schema, input schema: {input_schema:?}, table schema: {table_schema:?}"
105            )));
106        }
107        let transformed_table_schema = transform_schema(
108            transform.as_ref(),
109            table_schema.clone(),
110            Some(&remote_schema),
111            db_type,
112        )?;
113        let projected_transformed_schema =
114            project_schema(&transformed_table_schema, projection.as_ref())?;
115        Ok(Self {
116            input,
117            transform,
118            table_schema,
119            projection,
120            projected_transformed_schema,
121            remote_schema,
122            db_type,
123        })
124    }
125}
126
127impl Stream for TransformStream {
128    type Item = DFResult<RecordBatch>;
129    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
130        match self.input.poll_next_unpin(cx) {
131            Poll::Ready(Some(Ok(batch))) => {
132                let args = TransformArgs {
133                    db_type: self.db_type,
134                    table_schema: &self.table_schema,
135                    remote_schema: &self.remote_schema,
136                };
137                match self.transform.transform(batch, args) {
138                    Ok(transformed_batch) => {
139                        let projected_batch = if let Some(projection) = &self.projection {
140                            match transformed_batch.project(projection) {
141                                Ok(batch) => batch,
142                                Err(e) => return Poll::Ready(Some(Err(DataFusionError::from(e)))),
143                            }
144                        } else {
145                            transformed_batch
146                        };
147                        Poll::Ready(Some(Ok(projected_batch)))
148                    }
149                    Err(e) => Poll::Ready(Some(Err(e))),
150                }
151            }
152            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
153            Poll::Ready(None) => Poll::Ready(None),
154            Poll::Pending => Poll::Pending,
155        }
156    }
157}
158
159impl RecordBatchStream for TransformStream {
160    fn schema(&self) -> SchemaRef {
161        self.projected_transformed_schema.clone()
162    }
163}
164
165pub fn transform_schema(
166    transform: &dyn Transform,
167    schema: SchemaRef,
168    remote_schema: Option<&RemoteSchemaRef>,
169    db_type: RemoteDbType,
170) -> DFResult<SchemaRef> {
171    if transform.as_any().is::<DefaultTransform>() {
172        Ok(schema)
173    } else {
174        let Some(remote_schema) = remote_schema else {
175            return Err(DataFusionError::Execution(
176                "remote_schema is required for non-default transform".to_string(),
177            ));
178        };
179        let empty_batch = RecordBatch::new_empty(schema.clone());
180        let args = TransformArgs {
181            db_type,
182            table_schema: &schema,
183            remote_schema,
184        };
185        let transformed_batch = transform.transform(empty_batch, args)?;
186        Ok(transformed_batch.schema())
187    }
188}