datafusion_remote_table/
codec.rs1use crate::generated::prost as protobuf;
2use crate::{
3 connect, ConnectionOptions, DFResult, MysqlConnectionOptions, OracleConnectionOptions,
4 PostgresConnectionOptions, RemoteTableExec, Transform,
5};
6use datafusion::arrow::datatypes::SchemaRef;
7use datafusion::common::DataFusionError;
8use datafusion::execution::FunctionRegistry;
9use datafusion::physical_plan::ExecutionPlan;
10use datafusion_proto::convert_required;
11use datafusion_proto::physical_plan::PhysicalExtensionCodec;
12use datafusion_proto::protobuf::proto_error;
13use prost::Message;
14use std::fmt::Debug;
15use std::path::Path;
16use std::sync::Arc;
17
18pub trait TransformCodec: Debug + Send + Sync {
19 fn try_encode(&self, value: &dyn Transform) -> DFResult<Vec<u8>>;
20 fn try_decode(&self, value: &[u8]) -> DFResult<Arc<dyn Transform>>;
21}
22
23#[derive(Debug)]
24pub struct RemotePhysicalCodec {
25 transform_codec: Option<Arc<dyn TransformCodec>>,
26}
27
28impl RemotePhysicalCodec {
29 pub fn new(transform_codec: Option<Arc<dyn TransformCodec>>) -> Self {
30 Self { transform_codec }
31 }
32}
33
34impl PhysicalExtensionCodec for RemotePhysicalCodec {
35 fn try_decode(
36 &self,
37 buf: &[u8],
38 _inputs: &[Arc<dyn ExecutionPlan>],
39 _registry: &dyn FunctionRegistry,
40 ) -> DFResult<Arc<dyn ExecutionPlan>> {
41 let proto = protobuf::RemoteTableExec::decode(buf).map_err(|e| {
42 DataFusionError::Internal(format!(
43 "Failed to decode remote table execution plan: {e:?}"
44 ))
45 })?;
46
47 let transform = if let Some(bytes) = proto.transform {
48 let Some(transform_codec) = self.transform_codec.as_ref() else {
49 return Err(DataFusionError::Execution(
50 "No transform codec found".to_string(),
51 ));
52 };
53 Some(transform_codec.try_decode(&bytes)?)
54 } else {
55 None
56 };
57
58 let projected_schema: SchemaRef = Arc::new(convert_required!(&proto.projected_schema)?);
59
60 let projection: Option<Vec<usize>> = proto
61 .projection
62 .map(|p| p.projection.iter().map(|n| *n as usize).collect());
63
64 let conn_options = parse_connection_options(proto.conn_options.unwrap());
65 let conn = tokio::task::block_in_place(|| {
66 tokio::runtime::Handle::current().block_on(async {
67 let pool = connect(&conn_options).await?;
68 let conn = pool.get().await?;
69 Ok::<_, DataFusionError>(conn)
70 })
71 })?;
72
73 Ok(Arc::new(RemoteTableExec::new(
74 conn_options,
75 projected_schema,
76 proto.sql,
77 projection,
78 transform,
79 conn,
80 )))
81 }
82
83 fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> DFResult<()> {
84 if let Some(exec) = node.as_any().downcast_ref::<RemoteTableExec>() {
85 let serialized_transform = if let Some(transform) = exec.transform.as_ref() {
86 let Some(transform_codec) = self.transform_codec.as_ref() else {
87 return Err(DataFusionError::Execution(
88 "No transform codec found".to_string(),
89 ));
90 };
91 let bytes = transform_codec.try_encode(transform.as_ref())?;
92 Some(bytes)
93 } else {
94 None
95 };
96
97 let serialized_connection_options = serialize_connection_options(&exec.conn_options);
98
99 let proto = protobuf::RemoteTableExec {
100 conn_options: Some(serialized_connection_options),
101 sql: exec.sql.clone(),
102 projected_schema: Some(exec.schema().as_ref().try_into()?),
103 projection: exec
104 .projection
105 .as_ref()
106 .map(|p| serialize_projection(p.as_slice())),
107 transform: serialized_transform,
108 };
109
110 proto.encode(buf).map_err(|e| {
111 DataFusionError::Internal(format!(
112 "Failed to encode remote table execution plan: {e:?}"
113 ))
114 })?;
115 Ok(())
116 } else {
117 Err(DataFusionError::Execution(format!(
118 "Failed to encode {}",
119 RemoteTableExec::static_name()
120 )))
121 }
122 }
123}
124
125fn serialize_connection_options(options: &ConnectionOptions) -> protobuf::ConnectionOptions {
126 match options {
127 ConnectionOptions::Postgres(options) => protobuf::ConnectionOptions {
128 connection_options: Some(protobuf::connection_options::ConnectionOptions::Postgres(
129 protobuf::PostgresConnectionOptions {
130 host: options.host.clone(),
131 port: options.port as u32,
132 username: options.username.clone(),
133 password: options.password.clone(),
134 database: options.database.clone(),
135 },
136 )),
137 },
138 ConnectionOptions::Mysql(options) => protobuf::ConnectionOptions {
139 connection_options: Some(protobuf::connection_options::ConnectionOptions::Mysql(
140 protobuf::MysqlConnectionOptions {
141 host: options.host.clone(),
142 port: options.port as u32,
143 username: options.username.clone(),
144 password: options.password.clone(),
145 database: options.database.clone(),
146 },
147 )),
148 },
149 ConnectionOptions::Oracle(options) => protobuf::ConnectionOptions {
150 connection_options: Some(protobuf::connection_options::ConnectionOptions::Oracle(
151 protobuf::OracleConnectionOptions {
152 host: options.host.clone(),
153 port: options.port as u32,
154 username: options.username.clone(),
155 password: options.password.clone(),
156 database: options.database.clone(),
157 },
158 )),
159 },
160 ConnectionOptions::Sqlite(path) => protobuf::ConnectionOptions {
161 connection_options: Some(protobuf::connection_options::ConnectionOptions::Sqlite(
162 protobuf::SqliteConnectionOptions {
163 path: path.to_str().unwrap().to_string(),
164 },
165 )),
166 },
167 }
168}
169
170fn parse_connection_options(options: protobuf::ConnectionOptions) -> ConnectionOptions {
171 match options.connection_options {
172 Some(protobuf::connection_options::ConnectionOptions::Postgres(options)) => {
173 ConnectionOptions::Postgres(PostgresConnectionOptions {
174 host: options.host,
175 port: options.port as u16,
176 username: options.username,
177 password: options.password,
178 database: options.database,
179 })
180 }
181 Some(protobuf::connection_options::ConnectionOptions::Mysql(options)) => {
182 ConnectionOptions::Mysql(MysqlConnectionOptions {
183 host: options.host,
184 port: options.port as u16,
185 username: options.username,
186 password: options.password,
187 database: options.database,
188 })
189 }
190 Some(protobuf::connection_options::ConnectionOptions::Oracle(options)) => {
191 ConnectionOptions::Oracle(OracleConnectionOptions {
192 host: options.host,
193 port: options.port as u16,
194 username: options.username,
195 password: options.password,
196 database: options.database,
197 })
198 }
199 Some(protobuf::connection_options::ConnectionOptions::Sqlite(options)) => {
200 ConnectionOptions::Sqlite(Path::new(&options.path).to_path_buf())
201 }
202 _ => panic!("Failed to parse connection options: {options:?}"),
203 }
204}
205
206fn serialize_projection(projection: &[usize]) -> protobuf::Projection {
207 protobuf::Projection {
208 projection: projection.iter().map(|n| *n as u32).collect(),
209 }
210}