1use crate::sql::arrow_sql_gen::arrow::map_data_type_to_array_builder_optional;
2use arrow::{
3 array::{
4 ArrayBuilder, ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder,
5 Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
6 LargeBinaryBuilder, LargeStringBuilder, NullBuilder, RecordBatch, RecordBatchOptions,
7 StringBuilder, StringDictionaryBuilder, Time64NanosecondBuilder,
8 TimestampMicrosecondBuilder, UInt64Builder,
9 },
10 datatypes::{i256, DataType, Date32Type, Field, Schema, SchemaRef, TimeUnit, UInt16Type},
11};
12use bigdecimal::BigDecimal;
13use bigdecimal::ToPrimitive;
14use chrono::{NaiveDate, NaiveTime, Timelike};
15use mysql_async::{consts::ColumnFlags, consts::ColumnType, FromValueError, Row, Value};
16use snafu::{ResultExt, Snafu};
17use std::{convert, sync::Arc};
18use time::PrimitiveDateTime;
19
20#[derive(Debug, Snafu)]
21pub enum Error {
22 #[snafu(display("Failed to build record batch: {source}"))]
23 FailedToBuildRecordBatch {
24 source: datafusion::arrow::error::ArrowError,
25 },
26
27 #[snafu(display("No builder found for index {index}"))]
28 NoBuilderForIndex { index: usize },
29
30 #[snafu(display("Failed to downcast builder for {:?}", mysql_type))]
31 FailedToDowncastBuilder { mysql_type: String },
32
33 #[snafu(display("Integer overflow when converting u64 to i64: {source}"))]
34 FailedToConvertU64toI64 {
35 source: <u64 as convert::TryInto<i64>>::Error,
36 },
37
38 #[snafu(display("Integer overflow when converting u128 to i64: {source}"))]
39 FailedToConvertU128toI64 {
40 source: <u128 as convert::TryInto<i64>>::Error,
41 },
42
43 #[snafu(display("Failed to get a row value for column {column}({mysql_type:?}): {source}"))]
44 FailedToGetRowValue {
45 column: String,
46 mysql_type: ColumnType,
47 source: mysql_async::FromValueError,
48 },
49
50 #[snafu(display("Cannot represent BigDecimal as i128: {big_decimal}"))]
51 FailedToConvertBigDecimalToI128 { big_decimal: BigDecimal },
52
53 #[snafu(display("Failed to find field {column_name} in schema"))]
54 FailedToFindFieldInSchema { column_name: String },
55
56 #[snafu(display("No Arrow field found for index {index}"))]
57 NoArrowFieldForIndex { index: usize },
58
59 #[snafu(display("No column name for index: {index}"))]
60 NoColumnNameForIndex { index: usize },
61}
62
63pub type Result<T, E = Error> = std::result::Result<T, E>;
64
/// Reads the cell at `$index` of `$row` as `$value_ty` and appends it (or a
/// null) to the column builder, which must downcast to `$builder_ty`.
///
/// Expands to code that early-returns from the enclosing function with
/// `NoBuilderForIndex`, `FailedToDowncastBuilder` or `FailedToGetRowValue`
/// on failure, so it may only be used inside a function returning this
/// module's `Result`.
macro_rules! handle_primitive_type {
    ($builder:expr, $type:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $column_name:expr) => {{
        let Some(dyn_builder) = $builder else {
            return NoBuilderForIndexSnafu { index: $index }.fail();
        };
        let Some(typed_builder) = dyn_builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            return FailedToDowncastBuilderSnafu {
                mysql_type: format!("{:?}", $type),
            }
            .fail();
        };

        // A SQL NULL surfaces as a conversion error from the driver;
        // `handle_null_error` folds it into `Ok(None)`.
        let fetched = handle_null_error($row.get_opt::<$value_ty, usize>($index).transpose())
            .context(FailedToGetRowValueSnafu {
                column: $column_name,
                mysql_type: $type,
            })?;

        if let Some(value) = fetched {
            typed_builder.append_value(value);
        } else {
            typed_builder.append_null();
        }
    }};
}
89
90#[allow(clippy::too_many_lines)]
97pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option<SchemaRef>) -> Result<RecordBatch> {
98 let mut arrow_fields: Vec<Option<Field>> = Vec::new();
99 let mut arrow_columns_builders: Vec<Option<Box<dyn ArrayBuilder>>> = Vec::new();
100 let mut mysql_types: Vec<ColumnType> = Vec::new();
101 let mut column_names: Vec<String> = Vec::new();
102 let mut column_is_binary_stats: Vec<bool> = Vec::new();
103 let mut column_is_enum_stats: Vec<bool> = Vec::new();
104 let mut column_use_large_str_or_blob_stats: Vec<bool> = Vec::new();
105
106 if !rows.is_empty() {
107 let row = &rows[0];
108 for column in row.columns().iter() {
109 let column_name = column.name_str();
110 let column_type = column.column_type();
111 let column_is_binary = column.flags().contains(ColumnFlags::BINARY_FLAG);
112 let column_is_enum = column.flags().contains(ColumnFlags::ENUM_FLAG);
113 let column_use_large_str_or_blob = column.column_length() > 2_u32.pow(31) - 1;
114
115 let (decimal_precision, decimal_scale) = match column_type {
116 ColumnType::MYSQL_TYPE_DECIMAL | ColumnType::MYSQL_TYPE_NEWDECIMAL => {
117 match projected_schema {
119 Some(schema) => {
120 let precision =
121 get_decimal_column_precision(&column_name, schema).unwrap_or(76);
122 (Some(precision), Some(column.decimals() as i8))
123 }
124 None => (Some(76), Some(column.decimals() as i8)),
125 }
126 }
127 _ => (None, None),
128 };
129
130 let data_type = map_column_to_data_type(
131 column_type,
132 column_is_binary,
133 column_is_enum,
134 column_use_large_str_or_blob,
135 decimal_precision,
136 decimal_scale,
137 );
138
139 arrow_fields.push(
140 data_type
141 .clone()
142 .map(|data_type| Field::new(column_name.clone(), data_type.clone(), true)),
143 );
144 arrow_columns_builders
145 .push(map_data_type_to_array_builder_optional(data_type.as_ref()));
146 mysql_types.push(column_type);
147 column_names.push(column_name.to_string());
148 column_is_binary_stats.push(column_is_binary);
149 column_is_enum_stats.push(column_is_enum);
150 column_use_large_str_or_blob_stats.push(column_use_large_str_or_blob);
151 }
152 }
153
154 for row in rows {
155 for (i, mysql_type) in mysql_types.iter().enumerate() {
156 let Some(builder) = arrow_columns_builders.get_mut(i) else {
157 return NoBuilderForIndexSnafu { index: i }.fail();
158 };
159
160 let column_name = column_names.get(i).cloned().unwrap_or_default();
161
162 match *mysql_type {
163 ColumnType::MYSQL_TYPE_NULL => {
164 let Some(builder) = builder else {
165 return NoBuilderForIndexSnafu { index: i }.fail();
166 };
167 let Some(builder) = builder.as_any_mut().downcast_mut::<NullBuilder>() else {
168 return FailedToDowncastBuilderSnafu {
169 mysql_type: format!("{mysql_type:?}"),
170 }
171 .fail();
172 };
173 builder.append_null();
174 }
175 ColumnType::MYSQL_TYPE_BIT => {
176 let Some(builder) = builder else {
177 return NoBuilderForIndexSnafu { index: i }.fail();
178 };
179 let Some(builder) = builder.as_any_mut().downcast_mut::<UInt64Builder>() else {
180 return FailedToDowncastBuilderSnafu {
181 mysql_type: format!("{mysql_type:?}"),
182 }
183 .fail();
184 };
185 let value = row.get_opt::<Value, usize>(i).transpose().context(
186 FailedToGetRowValueSnafu {
187 column: column_name,
188 mysql_type: ColumnType::MYSQL_TYPE_BIT,
189 },
190 )?;
191 match value {
192 Some(Value::Bytes(mut bytes)) => {
193 while bytes.len() < 8 {
194 bytes.insert(0, 0);
195 }
196 let mut array = [0u8; 8];
197 array.copy_from_slice(&bytes);
198 builder.append_value(u64::from_be_bytes(array));
199 }
200 _ => builder.append_null(),
201 }
202 }
203 ColumnType::MYSQL_TYPE_TINY => {
204 handle_primitive_type!(
205 builder,
206 ColumnType::MYSQL_TYPE_TINY,
207 Int8Builder,
208 i8,
209 row,
210 i,
211 column_name
212 );
213 }
214 column_type @ (ColumnType::MYSQL_TYPE_SHORT | ColumnType::MYSQL_TYPE_YEAR) => {
215 handle_primitive_type!(
216 builder,
217 column_type,
218 Int16Builder,
219 i16,
220 row,
221 i,
222 column_name
223 );
224 }
225 column_type @ (ColumnType::MYSQL_TYPE_INT24 | ColumnType::MYSQL_TYPE_LONG) => {
226 handle_primitive_type!(
227 builder,
228 column_type,
229 Int32Builder,
230 i32,
231 row,
232 i,
233 column_name
234 );
235 }
236 ColumnType::MYSQL_TYPE_LONGLONG => {
237 handle_primitive_type!(
238 builder,
239 ColumnType::MYSQL_TYPE_LONGLONG,
240 Int64Builder,
241 i64,
242 row,
243 i,
244 column_name
245 );
246 }
247 ColumnType::MYSQL_TYPE_FLOAT => {
248 handle_primitive_type!(
249 builder,
250 ColumnType::MYSQL_TYPE_FLOAT,
251 Float32Builder,
252 f32,
253 row,
254 i,
255 column_name
256 );
257 }
258 ColumnType::MYSQL_TYPE_DOUBLE => {
259 handle_primitive_type!(
260 builder,
261 ColumnType::MYSQL_TYPE_DOUBLE,
262 Float64Builder,
263 f64,
264 row,
265 i,
266 column_name
267 );
268 }
269 ColumnType::MYSQL_TYPE_DECIMAL | ColumnType::MYSQL_TYPE_NEWDECIMAL => {
270 let Some(builder) = builder else {
271 return NoBuilderForIndexSnafu { index: i }.fail();
272 };
273
274 let arrow_field = match arrow_fields.get(i) {
275 Some(Some(field)) => field,
276 _ => return NoArrowFieldForIndexSnafu { index: i }.fail(),
277 };
278
279 match arrow_field.data_type() {
280 DataType::Decimal128(_, _) => {
281 let Some(builder) =
282 builder.as_any_mut().downcast_mut::<Decimal128Builder>()
283 else {
284 return FailedToDowncastBuilderSnafu {
285 mysql_type: format!("{mysql_type:?}"),
286 }
287 .fail();
288 };
289 let val =
290 handle_null_error(row.get_opt::<BigDecimal, usize>(i).transpose())
291 .context(FailedToGetRowValueSnafu {
292 column: column_name,
293 mysql_type: ColumnType::MYSQL_TYPE_DECIMAL,
294 })?;
295
296 let scale = match &val {
297 Some(val) => val.fractional_digit_count(),
298 None => 0,
299 };
300
301 let Some(val) = val else {
302 builder.append_null();
303 continue;
304 };
305
306 let Some(val) = to_decimal_128(&val, scale) else {
307 return FailedToConvertBigDecimalToI128Snafu { big_decimal: val }
308 .fail();
309 };
310
311 builder.append_value(val);
312 }
313 DataType::Decimal256(_, _) => {
314 let Some(builder) =
315 builder.as_any_mut().downcast_mut::<Decimal256Builder>()
316 else {
317 return FailedToDowncastBuilderSnafu {
318 mysql_type: format!("{mysql_type:?}"),
319 }
320 .fail();
321 };
322
323 let val =
324 handle_null_error(row.get_opt::<BigDecimal, usize>(i).transpose())
325 .context(FailedToGetRowValueSnafu {
326 column: column_name,
327 mysql_type: ColumnType::MYSQL_TYPE_DECIMAL,
328 })?;
329
330 let Some(val) = val else {
331 builder.append_null();
332 continue;
333 };
334
335 let val = to_decimal_256(&val);
336
337 builder.append_value(val);
338 }
339 _ => unreachable!(),
341 }
342 }
343 column_type @ (ColumnType::MYSQL_TYPE_VARCHAR | ColumnType::MYSQL_TYPE_JSON) => {
344 handle_primitive_type!(
345 builder,
346 column_type,
347 LargeStringBuilder,
348 String,
349 row,
350 i,
351 column_name
352 );
353 }
354 ColumnType::MYSQL_TYPE_BLOB => {
355 match (
356 column_use_large_str_or_blob_stats[i],
357 column_is_binary_stats[i],
358 ) {
359 (true, true) => handle_primitive_type!(
360 builder,
361 ColumnType::MYSQL_TYPE_BLOB,
362 LargeBinaryBuilder,
363 Vec<u8>,
364 row,
365 i,
366 column_name
367 ),
368 (true, false) => handle_primitive_type!(
369 builder,
370 ColumnType::MYSQL_TYPE_BLOB,
371 LargeStringBuilder,
372 String,
373 row,
374 i,
375 column_name
376 ),
377 (false, true) => handle_primitive_type!(
378 builder,
379 ColumnType::MYSQL_TYPE_BLOB,
380 BinaryBuilder,
381 Vec<u8>,
382 row,
383 i,
384 column_name
385 ),
386 (false, false) => handle_primitive_type!(
387 builder,
388 ColumnType::MYSQL_TYPE_BLOB,
389 StringBuilder,
390 String,
391 row,
392 i,
393 column_name
394 ),
395 }
396 }
397 ColumnType::MYSQL_TYPE_ENUM => {
398 unreachable!()
401 }
402 column_type @ (ColumnType::MYSQL_TYPE_STRING
403 | ColumnType::MYSQL_TYPE_VAR_STRING) => {
404 if column_is_enum_stats[i] {
406 let Some(builder) = builder else {
407 return NoBuilderForIndexSnafu { index: i }.fail();
408 };
409 let Some(builder) = builder
410 .as_any_mut()
411 .downcast_mut::<StringDictionaryBuilder<UInt16Type>>()
412 else {
413 return FailedToDowncastBuilderSnafu {
414 mysql_type: format!("{mysql_type:?}"),
415 }
416 .fail();
417 };
418
419 let v = handle_null_error(row.get_opt::<String, usize>(i).transpose())
420 .context(FailedToGetRowValueSnafu {
421 column: column_name,
422 mysql_type: ColumnType::MYSQL_TYPE_ENUM,
423 })?;
424
425 match v {
426 Some(v) => {
427 builder.append_value(v);
428 }
429 None => builder.append_null(),
430 }
431 } else if column_is_binary_stats[i] {
432 handle_primitive_type!(
433 builder,
434 column_type,
435 BinaryBuilder,
436 Vec<u8>,
437 row,
438 i,
439 column_name
440 );
441 } else {
442 handle_primitive_type!(
443 builder,
444 column_type,
445 StringBuilder,
446 String,
447 row,
448 i,
449 column_name
450 );
451 }
452 }
453 ColumnType::MYSQL_TYPE_DATE => {
454 let Some(builder) = builder else {
455 return NoBuilderForIndexSnafu { index: i }.fail();
456 };
457 let Some(builder) = builder.as_any_mut().downcast_mut::<Date32Builder>() else {
458 return FailedToDowncastBuilderSnafu {
459 mysql_type: format!("{mysql_type:?}"),
460 }
461 .fail();
462 };
463
464 let v = match handle_null_error(row.get_opt::<NaiveDate, usize>(i).transpose())
465 {
466 Ok(v) => v,
467 Err(err) => {
468 if matches!(err, FromValueError(Value::Date(0, 0, 0, 0, 0, 0, 0))) {
470 None
471 } else {
472 return Err(Error::FailedToGetRowValue {
473 column: column_name,
474 mysql_type: ColumnType::MYSQL_TYPE_DATE,
475 source: err,
476 });
477 }
478 }
479 };
480
481 match v {
482 Some(v) => {
483 builder.append_value(Date32Type::from_naive_date(v));
484 }
485 None => builder.append_null(),
486 }
487 }
488 ColumnType::MYSQL_TYPE_TIME => {
489 let Some(builder) = builder else {
490 return NoBuilderForIndexSnafu { index: i }.fail();
491 };
492 let Some(builder) = builder
493 .as_any_mut()
494 .downcast_mut::<Time64NanosecondBuilder>()
495 else {
496 return FailedToDowncastBuilderSnafu {
497 mysql_type: format!("{mysql_type:?}"),
498 }
499 .fail();
500 };
501 let v = handle_null_error(row.get_opt::<NaiveTime, usize>(i).transpose())
502 .context(FailedToGetRowValueSnafu {
503 column: column_name,
504 mysql_type: ColumnType::MYSQL_TYPE_TIME,
505 })?;
506
507 match v {
508 Some(value) => {
509 builder.append_value(
510 i64::from(value.num_seconds_from_midnight()) * 1_000_000_000
511 + i64::from(value.nanosecond()),
512 );
513 }
514 None => builder.append_null(),
515 }
516 }
517 column_type @ (ColumnType::MYSQL_TYPE_TIMESTAMP
518 | ColumnType::MYSQL_TYPE_DATETIME) => {
519 let Some(builder) = builder else {
520 return NoBuilderForIndexSnafu { index: i }.fail();
521 };
522 let Some(builder) = builder
523 .as_any_mut()
524 .downcast_mut::<TimestampMicrosecondBuilder>()
525 else {
526 return FailedToDowncastBuilderSnafu {
527 mysql_type: format!("{mysql_type:?}"),
528 }
529 .fail();
530 };
531 let v = match handle_null_error(
532 row.get_opt::<PrimitiveDateTime, usize>(i).transpose(),
533 ) {
534 Ok(v) => v,
535 Err(err) => {
536 if matches!(err, FromValueError(Value::Date(0, 0, 0, 0, 0, 0, 0))) {
538 None
539 } else {
540 return Err(Error::FailedToGetRowValue {
541 column: column_name,
542 mysql_type: column_type,
543 source: err,
544 });
545 }
546 }
547 };
548
549 match v {
550 Some(v) => {
551 #[allow(clippy::cast_possible_truncation)]
552 let timestamp_micros =
553 (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
554 builder.append_value(timestamp_micros);
555 }
556 None => builder.append_null(),
557 }
558 }
559 _ => unimplemented!("Unsupported column type {:?}", mysql_type),
560 }
561 }
562 }
563
564 let columns = arrow_columns_builders
565 .into_iter()
566 .filter_map(|builder| builder.map(|mut b| b.finish()))
567 .collect::<Vec<ArrayRef>>();
568 let arrow_fields = arrow_fields.into_iter().flatten().collect::<Vec<Field>>();
569 let options = &RecordBatchOptions::new().with_row_count(Some(rows.len()));
570 RecordBatch::try_new_with_options(Arc::new(Schema::new(arrow_fields)), columns, options)
571 .map_err(|err| Error::FailedToBuildRecordBatch { source: err })
572}
573
574#[allow(clippy::unnecessary_wraps)]
575pub fn map_column_to_data_type(
576 column_type: ColumnType,
577 column_is_binary: bool,
578 column_is_enum: bool,
579 column_use_large_str_or_blob: bool,
580 column_decimal_precision: Option<u8>,
581 column_decimal_scale: Option<i8>,
582) -> Option<DataType> {
583 match column_type {
584 ColumnType::MYSQL_TYPE_NULL => Some(DataType::Null),
585 ColumnType::MYSQL_TYPE_BIT => Some(DataType::UInt64),
586 ColumnType::MYSQL_TYPE_TINY => Some(DataType::Int8),
587 ColumnType::MYSQL_TYPE_YEAR | ColumnType::MYSQL_TYPE_SHORT => Some(DataType::Int16),
588 ColumnType::MYSQL_TYPE_INT24 | ColumnType::MYSQL_TYPE_LONG => Some(DataType::Int32),
589 ColumnType::MYSQL_TYPE_LONGLONG => Some(DataType::Int64),
590 ColumnType::MYSQL_TYPE_FLOAT => Some(DataType::Float32),
591 ColumnType::MYSQL_TYPE_DOUBLE => Some(DataType::Float64),
592 ColumnType::MYSQL_TYPE_DECIMAL | ColumnType::MYSQL_TYPE_NEWDECIMAL => {
594 if column_decimal_precision.unwrap_or_default() > 38 {
595 return Some(DataType::Decimal256(column_decimal_precision.unwrap_or_default(), column_decimal_scale.unwrap_or_default()));
596 }
597 Some(DataType::Decimal128(column_decimal_precision.unwrap_or_default(), column_decimal_scale.unwrap_or_default()))
598 },
599 ColumnType::MYSQL_TYPE_TIMESTAMP | ColumnType::MYSQL_TYPE_DATETIME => {
600 Some(DataType::Timestamp(TimeUnit::Microsecond, None))
601 },
602 ColumnType::MYSQL_TYPE_DATE => Some(DataType::Date32),
603 ColumnType::MYSQL_TYPE_TIME => {
604 Some(DataType::Time64(TimeUnit::Nanosecond))
605 }
606 ColumnType::MYSQL_TYPE_VARCHAR
607 | ColumnType::MYSQL_TYPE_JSON => Some(DataType::LargeUtf8),
608 ColumnType::MYSQL_TYPE_BLOB => {
612 match (column_use_large_str_or_blob, column_is_binary) {
613 (true, true) => Some(DataType::LargeBinary),
614 (true, false) => Some(DataType::LargeUtf8),
615 (false, true) => Some(DataType::Binary),
616 (false, false) => Some(DataType::Utf8),
617 }
618 }
619 ColumnType::MYSQL_TYPE_ENUM | ColumnType::MYSQL_TYPE_SET => unreachable!(),
620 ColumnType::MYSQL_TYPE_STRING
621 | ColumnType::MYSQL_TYPE_VAR_STRING => {
622 if column_is_enum {
623 Some(DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)))
624 } else if column_is_binary {
625 Some(DataType::Binary)
626 } else {
627 Some(DataType::Utf8)
628 }
629 },
630 ColumnType::MYSQL_TYPE_TYPED_ARRAY
632 | ColumnType::MYSQL_TYPE_NEWDATE
634 | ColumnType::MYSQL_TYPE_UNKNOWN
636 | ColumnType::MYSQL_TYPE_TIMESTAMP2
637 | ColumnType::MYSQL_TYPE_DATETIME2
638 | ColumnType::MYSQL_TYPE_TIME2
639 | ColumnType::MYSQL_TYPE_LONG_BLOB
640 | ColumnType::MYSQL_TYPE_TINY_BLOB
641 | ColumnType::MYSQL_TYPE_MEDIUM_BLOB
642 | ColumnType::MYSQL_TYPE_GEOMETRY
643 | ColumnType::MYSQL_TYPE_VECTOR => {
644 unimplemented!("Unsupported column type {:?}", column_type)
645 }
646 }
647}
648
649fn to_decimal_128(decimal: &BigDecimal, scale: i64) -> Option<i128> {
650 (decimal * 10i128.pow(scale.try_into().unwrap_or_default())).to_i128()
651}
652
653fn to_decimal_256(decimal: &BigDecimal) -> i256 {
654 let (bigint_value, _) = decimal.as_bigint_and_exponent();
655 let mut bigint_bytes = bigint_value.to_signed_bytes_le();
656
657 let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
658 let fill_byte = if is_negative { 0xFF } else { 0x00 };
659
660 if bigint_bytes.len() > 32 {
661 bigint_bytes.truncate(32);
662 } else {
663 bigint_bytes.resize(32, fill_byte);
664 };
665
666 let mut array = [0u8; 32];
667 array.copy_from_slice(&bigint_bytes);
668
669 i256::from_le_bytes(array)
670}
671
672fn get_decimal_column_precision(column_name: &str, projected_schema: &SchemaRef) -> Option<u8> {
673 let field = projected_schema.field_with_name(column_name).ok()?;
674 match field.data_type() {
675 DataType::Decimal256(precision, _) | DataType::Decimal128(precision, _) => Some(*precision),
676 _ => None,
677 }
678}
679fn handle_null_error<T>(
680 result: Result<Option<T>, FromValueError>,
681) -> Result<Option<T>, FromValueError> {
682 match result {
683 Ok(val) => Ok(val),
684 Err(FromValueError(Value::NULL)) => Ok(None),
685 err => err,
686 }
687}