datafusion_remote_table/connection/
mysql.rs1use crate::connection::projections_contains;
2use crate::transform::transform_batch;
3use crate::{
4 project_remote_schema, Connection, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
5 RemoteType, Transform,
6};
7use async_stream::stream;
8use datafusion::arrow::array::{
9 make_builder, ArrayRef, Float32Builder, Float64Builder, Int16Builder, Int32Builder,
10 Int64Builder, Int8Builder, RecordBatch,
11};
12use datafusion::arrow::datatypes::SchemaRef;
13use datafusion::common::{project_schema, DataFusionError};
14use datafusion::execution::SendableRecordBatchStream;
15use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
16use futures::lock::Mutex;
17use futures::StreamExt;
18use mysql_async::consts::ColumnType;
19use mysql_async::prelude::Queryable;
20use mysql_async::{Column, Row};
21use std::sync::Arc;
22
23#[derive(Debug, Clone, derive_with::With)]
24pub struct MysqlConnectionOptions {
25 pub(crate) host: String,
26 pub(crate) port: u16,
27 pub(crate) username: String,
28 pub(crate) password: String,
29 pub(crate) database: Option<String>,
30}
31
32impl MysqlConnectionOptions {
33 pub fn new(
34 host: impl Into<String>,
35 port: u16,
36 username: impl Into<String>,
37 password: impl Into<String>,
38 ) -> Self {
39 Self {
40 host: host.into(),
41 port,
42 username: username.into(),
43 password: password.into(),
44 database: None,
45 }
46 }
47}
48
49#[derive(Debug)]
50pub struct MysqlPool {
51 pool: mysql_async::Pool,
52}
53
54pub fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
55 let opts_builder = mysql_async::OptsBuilder::default()
56 .ip_or_hostname(options.host.clone())
57 .tcp_port(options.port)
58 .user(Some(options.username.clone()))
59 .pass(Some(options.password.clone()))
60 .db_name(options.database.clone());
61 let pool = mysql_async::Pool::new(opts_builder);
62 Ok(MysqlPool { pool })
63}
64
65#[async_trait::async_trait]
66impl Pool for MysqlPool {
67 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
68 let conn = self.pool.get_conn().await.map_err(|e| {
69 DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
70 })?;
71 Ok(Arc::new(MysqlConnection {
72 conn: Arc::new(Mutex::new(conn)),
73 }))
74 }
75}
76
77#[derive(Debug)]
78pub struct MysqlConnection {
79 conn: Arc<Mutex<mysql_async::Conn>>,
80}
81
82#[async_trait::async_trait]
83impl Connection for MysqlConnection {
84 async fn infer_schema(
85 &self,
86 sql: &str,
87 transform: Option<Arc<dyn Transform>>,
88 ) -> DFResult<(RemoteSchema, SchemaRef)> {
89 let mut conn = self.conn.lock().await;
90 let conn = &mut *conn;
91 let row: Option<Row> = conn.query_first(sql).await.map_err(|e| {
92 DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}",))
93 })?;
94 let Some(row) = row else {
95 return Err(DataFusionError::Execution(
96 "No rows returned to infer schema".to_string(),
97 ));
98 };
99 let remote_schema = build_remote_schema(&row)?;
100 let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
101 if let Some(transform) = transform {
102 let batch = rows_to_batch(&[row], arrow_schema.clone(), None)?;
103 let transformed_batch = transform_batch(batch, transform.as_ref(), &remote_schema)?;
104 Ok((remote_schema, transformed_batch.schema()))
105 } else {
106 Ok((remote_schema, arrow_schema))
107 }
108 }
109
110 async fn query(
111 &self,
112 sql: String,
113 projection: Option<Vec<usize>>,
114 ) -> DFResult<(SendableRecordBatchStream, RemoteSchema)> {
115 let conn = Arc::clone(&self.conn);
116 let mut stream = Box::pin(stream! {
117 let mut conn = conn.lock().await;
118 let mut query_iter = conn
119 .query_iter(sql)
120 .await
121 .map_err(|e| {
122 DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}"))
123 })?;
124
125 let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
126 DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
127 })? else {
128 yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
129 return;
130 };
131
132 let mut chunked_stream = stream.chunks(4_000).boxed();
133
134 while let Some(chunk) = chunked_stream.next().await {
135 let rows = chunk
136 .into_iter()
137 .collect::<Result<Vec<_>, _>>()
138 .map_err(|e| {
139 DataFusionError::Execution(format!(
140 "Failed to collect rows from mysql due to {e}",
141 ))
142 })?;
143
144 yield Ok::<_, DataFusionError>(rows)
145 }
146 });
147
148 let Some(first_chunk) = stream.next().await else {
149 return Err(DataFusionError::Execution(
150 "No data returned from mysql".to_string(),
151 ));
152 };
153 let first_chunk = first_chunk?;
154
155 let Some(first_row) = first_chunk.first() else {
156 return Err(DataFusionError::Execution(
157 "No data returned from mysql".to_string(),
158 ));
159 };
160
161 let remote_schema = build_remote_schema(first_row)?;
162 let projected_remote_schema = project_remote_schema(&remote_schema, projection.as_ref());
163 let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
164 let first_chunk = rows_to_batch(
165 first_chunk.as_slice(),
166 arrow_schema.clone(),
167 projection.as_ref(),
168 )?;
169 let schema = first_chunk.schema();
170
171 let mut stream = stream.map(move |rows| {
172 let rows = rows?;
173 let batch = rows_to_batch(rows.as_slice(), arrow_schema.clone(), projection.as_ref())?;
174 Ok::<RecordBatch, DataFusionError>(batch)
175 });
176
177 let output_stream = async_stream::stream! {
178 yield Ok(first_chunk);
179 while let Some(batch) = stream.next().await {
180 yield batch
181 }
182 };
183
184 Ok((
185 Box::pin(RecordBatchStreamAdapter::new(schema, output_stream)),
186 projected_remote_schema,
187 ))
188 }
189}
190
191fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<RemoteType> {
192 match mysql_col.column_type() {
193 ColumnType::MYSQL_TYPE_TINY => Ok(RemoteType::Mysql(MysqlType::TinyInt)),
194 ColumnType::MYSQL_TYPE_SHORT => Ok(RemoteType::Mysql(MysqlType::SmallInt)),
195 ColumnType::MYSQL_TYPE_LONG => Ok(RemoteType::Mysql(MysqlType::Integer)),
196 ColumnType::MYSQL_TYPE_LONGLONG => Ok(RemoteType::Mysql(MysqlType::BigInt)),
197 ColumnType::MYSQL_TYPE_FLOAT => Ok(RemoteType::Mysql(MysqlType::Float)),
198 ColumnType::MYSQL_TYPE_DOUBLE => Ok(RemoteType::Mysql(MysqlType::Double)),
199 _ => Err(DataFusionError::NotImplemented(format!(
200 "Unsupported mysql type: {mysql_col:?}",
201 ))),
202 }
203}
204
205fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
206 let mut remote_fields = vec![];
207 for col in row.columns_ref() {
208 remote_fields.push(RemoteField::new(
209 col.name_str().to_string(),
210 mysql_type_to_remote_type(col)?,
211 true,
212 ));
213 }
214 Ok(RemoteSchema::new(remote_fields))
215}
216
/// Appends the value at column `$index` of `$row` to `$builder`.
///
/// Arguments:
/// - `$builder`: a dynamically typed array builder; it is downcast to
///   `$builder_ty`.
/// - `$mysql_col`: the column being read; used only in diagnostics.
/// - `$builder_ty`: the concrete builder type expected for this column.
/// - `$value_ty`: the Rust type the MySQL value is decoded into.
/// - `$row`, `$index`: the row and the column position to read.
///
/// The value is decoded as `Option<$value_ty>` through `get_opt`, so a SQL
/// `NULL` becomes a null entry. Decoding straight into `$value_ty` with
/// `Row::get` would panic on `NULL` and on any value that does not fit the
/// target type. A value that cannot be decoded makes the *enclosing function*
/// return `Err(DataFusionError::Execution(..))`, so this macro must be
/// invoked in a function returning a `Result` with that error type.
///
/// # Panics
///
/// Panics if `$builder` is not a `$builder_ty`; that indicates a mismatch
/// between the Arrow schema and the column handling code, i.e. a bug.
macro_rules! handle_primitive_type {
    ($builder:expr, $mysql_col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .expect(concat!(
                "Failed to downcast builder to ",
                stringify!($builder_ty),
                " for ",
                stringify!($mysql_col)
            ));

        match $row.get_opt::<Option<$value_ty>, usize>($index) {
            Some(Ok(Some(v))) => builder.append_value(v),
            // SQL NULL, or no value available at this index.
            Some(Ok(None)) | None => builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to convert value of mysql column {:?} to {}: {:?}",
                    $mysql_col,
                    stringify!($value_ty),
                    e
                )));
            }
        }
    }};
}
236
237fn rows_to_batch(
238 rows: &[Row],
239 arrow_schema: SchemaRef,
240 projection: Option<&Vec<usize>>,
241) -> DFResult<RecordBatch> {
242 let projected_schema = project_schema(&arrow_schema, projection)?;
243 let mut array_builders = vec![];
244 for field in arrow_schema.fields() {
245 let builder = make_builder(field.data_type(), rows.len());
246 array_builders.push(builder);
247 }
248
249 for row in rows {
250 for (idx, col) in row.columns_ref().iter().enumerate() {
251 if !projections_contains(projection, idx) {
252 continue;
253 }
254 let builder = &mut array_builders[idx];
255 match col.column_type() {
256 ColumnType::MYSQL_TYPE_TINY => {
257 handle_primitive_type!(builder, col, Int8Builder, i8, row, idx);
258 }
259 ColumnType::MYSQL_TYPE_SHORT => {
260 handle_primitive_type!(builder, col, Int16Builder, i16, row, idx);
261 }
262 ColumnType::MYSQL_TYPE_LONG => {
263 handle_primitive_type!(builder, col, Int32Builder, i32, row, idx);
264 }
265 ColumnType::MYSQL_TYPE_LONGLONG => {
266 handle_primitive_type!(builder, col, Int64Builder, i64, row, idx);
267 }
268 ColumnType::MYSQL_TYPE_FLOAT => {
269 handle_primitive_type!(builder, col, Float32Builder, f32, row, idx);
270 }
271 ColumnType::MYSQL_TYPE_DOUBLE => {
272 handle_primitive_type!(builder, col, Float64Builder, f64, row, idx);
273 }
274 _ => {
275 return Err(DataFusionError::NotImplemented(format!(
276 "Unsupported mysql type: {col:?}",
277 )));
278 }
279 }
280 }
281 }
282 let projected_columns = array_builders
283 .into_iter()
284 .enumerate()
285 .filter(|(idx, _)| projections_contains(projection, *idx))
286 .map(|(_, mut builder)| builder.finish())
287 .collect::<Vec<ArrayRef>>();
288 Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
289}