datafusion_remote_table/
codec.rs

1use crate::generated::prost as protobuf;
2use crate::{
3    connect, ConnectionOptions, DFResult, MysqlConnectionOptions, OracleConnectionOptions,
4    PostgresConnectionOptions, RemoteTableExec, Transform,
5};
6use datafusion::arrow::datatypes::SchemaRef;
7use datafusion::common::DataFusionError;
8use datafusion::execution::FunctionRegistry;
9use datafusion::physical_plan::ExecutionPlan;
10use datafusion_proto::convert_required;
11use datafusion_proto::physical_plan::PhysicalExtensionCodec;
12use datafusion_proto::protobuf::proto_error;
13use prost::Message;
14use std::fmt::Debug;
15use std::path::Path;
16use std::sync::Arc;
17
18pub trait TransformCodec: Debug + Send + Sync {
19    fn try_encode(&self, value: &dyn Transform) -> DFResult<Vec<u8>>;
20    fn try_decode(&self, value: &[u8]) -> DFResult<Arc<dyn Transform>>;
21}
22
23#[derive(Debug)]
24pub struct RemotePhysicalCodec {
25    transform_codec: Option<Arc<dyn TransformCodec>>,
26}
27
28impl RemotePhysicalCodec {
29    pub fn new(transform_codec: Option<Arc<dyn TransformCodec>>) -> Self {
30        Self { transform_codec }
31    }
32}
33
34impl PhysicalExtensionCodec for RemotePhysicalCodec {
35    fn try_decode(
36        &self,
37        buf: &[u8],
38        _inputs: &[Arc<dyn ExecutionPlan>],
39        _registry: &dyn FunctionRegistry,
40    ) -> DFResult<Arc<dyn ExecutionPlan>> {
41        let proto = protobuf::RemoteTableExec::decode(buf).map_err(|e| {
42            DataFusionError::Internal(format!(
43                "Failed to decode remote table execution plan: {e:?}"
44            ))
45        })?;
46
47        let transform = if let Some(bytes) = proto.transform {
48            let Some(transform_codec) = self.transform_codec.as_ref() else {
49                return Err(DataFusionError::Execution(
50                    "No transform codec found".to_string(),
51                ));
52            };
53            Some(transform_codec.try_decode(&bytes)?)
54        } else {
55            None
56        };
57
58        let projected_schema: SchemaRef = Arc::new(convert_required!(&proto.projected_schema)?);
59
60        let projection: Option<Vec<usize>> = proto
61            .projection
62            .map(|p| p.projection.iter().map(|n| *n as usize).collect());
63
64        let conn_options = parse_connection_options(proto.conn_options.unwrap());
65        let conn = tokio::task::block_in_place(|| {
66            tokio::runtime::Handle::current().block_on(async {
67                let pool = connect(&conn_options).await?;
68                let conn = pool.get().await?;
69                Ok::<_, DataFusionError>(conn)
70            })
71        })?;
72
73        Ok(Arc::new(RemoteTableExec::new(
74            conn_options,
75            projected_schema,
76            proto.sql,
77            projection,
78            transform,
79            conn,
80        )))
81    }
82
83    fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> DFResult<()> {
84        if let Some(exec) = node.as_any().downcast_ref::<RemoteTableExec>() {
85            let serialized_transform = if let Some(transform) = exec.transform.as_ref() {
86                let Some(transform_codec) = self.transform_codec.as_ref() else {
87                    return Err(DataFusionError::Execution(
88                        "No transform codec found".to_string(),
89                    ));
90                };
91                let bytes = transform_codec.try_encode(transform.as_ref())?;
92                Some(bytes)
93            } else {
94                None
95            };
96
97            let serialized_connection_options = serialize_connection_options(&exec.conn_options);
98
99            let proto = protobuf::RemoteTableExec {
100                conn_options: Some(serialized_connection_options),
101                sql: exec.sql.clone(),
102                projected_schema: Some(exec.schema().as_ref().try_into()?),
103                projection: exec
104                    .projection
105                    .as_ref()
106                    .map(|p| serialize_projection(p.as_slice())),
107                transform: serialized_transform,
108            };
109
110            proto.encode(buf).map_err(|e| {
111                DataFusionError::Internal(format!(
112                    "Failed to encode remote table execution plan: {e:?}"
113                ))
114            })?;
115            Ok(())
116        } else {
117            Err(DataFusionError::Execution(format!(
118                "Failed to encode {}",
119                RemoteTableExec::static_name()
120            )))
121        }
122    }
123}
124
125fn serialize_connection_options(options: &ConnectionOptions) -> protobuf::ConnectionOptions {
126    match options {
127        ConnectionOptions::Postgres(options) => protobuf::ConnectionOptions {
128            connection_options: Some(protobuf::connection_options::ConnectionOptions::Postgres(
129                protobuf::PostgresConnectionOptions {
130                    host: options.host.clone(),
131                    port: options.port as u32,
132                    username: options.username.clone(),
133                    password: options.password.clone(),
134                    database: options.database.clone(),
135                },
136            )),
137        },
138        ConnectionOptions::Mysql(options) => protobuf::ConnectionOptions {
139            connection_options: Some(protobuf::connection_options::ConnectionOptions::Mysql(
140                protobuf::MysqlConnectionOptions {
141                    host: options.host.clone(),
142                    port: options.port as u32,
143                    username: options.username.clone(),
144                    password: options.password.clone(),
145                    database: options.database.clone(),
146                },
147            )),
148        },
149        ConnectionOptions::Oracle(options) => protobuf::ConnectionOptions {
150            connection_options: Some(protobuf::connection_options::ConnectionOptions::Oracle(
151                protobuf::OracleConnectionOptions {
152                    host: options.host.clone(),
153                    port: options.port as u32,
154                    username: options.username.clone(),
155                    password: options.password.clone(),
156                    database: options.database.clone(),
157                },
158            )),
159        },
160        ConnectionOptions::Sqlite(path) => protobuf::ConnectionOptions {
161            connection_options: Some(protobuf::connection_options::ConnectionOptions::Sqlite(
162                protobuf::SqliteConnectionOptions {
163                    path: path.to_str().unwrap().to_string(),
164                },
165            )),
166        },
167    }
168}
169
170fn parse_connection_options(options: protobuf::ConnectionOptions) -> ConnectionOptions {
171    match options.connection_options {
172        Some(protobuf::connection_options::ConnectionOptions::Postgres(options)) => {
173            ConnectionOptions::Postgres(PostgresConnectionOptions {
174                host: options.host,
175                port: options.port as u16,
176                username: options.username,
177                password: options.password,
178                database: options.database,
179            })
180        }
181        Some(protobuf::connection_options::ConnectionOptions::Mysql(options)) => {
182            ConnectionOptions::Mysql(MysqlConnectionOptions {
183                host: options.host,
184                port: options.port as u16,
185                username: options.username,
186                password: options.password,
187                database: options.database,
188            })
189        }
190        Some(protobuf::connection_options::ConnectionOptions::Oracle(options)) => {
191            ConnectionOptions::Oracle(OracleConnectionOptions {
192                host: options.host,
193                port: options.port as u16,
194                username: options.username,
195                password: options.password,
196                database: options.database,
197            })
198        }
199        Some(protobuf::connection_options::ConnectionOptions::Sqlite(options)) => {
200            ConnectionOptions::Sqlite(Path::new(&options.path).to_path_buf())
201        }
202        _ => panic!("Failed to parse connection options: {options:?}"),
203    }
204}
205
206fn serialize_projection(projection: &[usize]) -> protobuf::Projection {
207    protobuf::Projection {
208        projection: projection.iter().map(|n| *n as u32).collect(),
209    }
210}