1use crate::{ConnectionOptions, DFResult, RemoteSource, RemoteTable};
2use bigdecimal::BigDecimal;
3use bigdecimal::ToPrimitive;
4use datafusion::arrow::array::{
5 Array, BooleanArray, GenericByteArray, PrimitiveArray, RecordBatch,
6};
7use datafusion::arrow::datatypes::{
8 ArrowPrimitiveType, BinaryType, BooleanType, ByteArrayType, LargeBinaryType, LargeUtf8Type,
9 Utf8Type, i256,
10};
11use datafusion::error::DataFusionError;
12use datafusion::prelude::SessionContext;
13use std::str::FromStr;
14use std::sync::Arc;
15
16pub async fn remote_collect(
17 options: ConnectionOptions,
18 sql: impl Into<String>,
19) -> DFResult<Vec<RecordBatch>> {
20 let table = RemoteTable::try_new(options, RemoteSource::Query(sql.into())).await?;
21 let ctx = SessionContext::new();
22 ctx.read_table(Arc::new(table))?.collect().await
23}
24
25pub async fn remote_collect_primitive_column<T: ArrowPrimitiveType>(
26 options: ConnectionOptions,
27 sql: impl Into<String>,
28 col_idx: usize,
29) -> DFResult<Vec<Option<T::Native>>> {
30 let batches = remote_collect(options, sql).await?;
31 extract_primitive_array::<T>(&batches, col_idx)
32}
33
34pub async fn remote_collect_utf8_column(
35 options: ConnectionOptions,
36 sql: impl Into<String>,
37 col_idx: usize,
38) -> DFResult<Vec<Option<String>>> {
39 let batches = remote_collect(options, sql).await?;
40 let vec = extract_byte_array::<Utf8Type>(&batches, col_idx)?;
41 Ok(vec.into_iter().map(|s| s.map(|s| s.to_string())).collect())
42}
43
44pub async fn remote_collect_large_utf8_column(
45 options: ConnectionOptions,
46 sql: impl Into<String>,
47 col_idx: usize,
48) -> DFResult<Vec<Option<String>>> {
49 let batches = remote_collect(options, sql).await?;
50 let vec = extract_byte_array::<LargeUtf8Type>(&batches, col_idx)?;
51 Ok(vec.into_iter().map(|s| s.map(|s| s.to_string())).collect())
52}
53
54pub async fn remote_collect_binary_column(
55 options: ConnectionOptions,
56 sql: impl Into<String>,
57 col_idx: usize,
58) -> DFResult<Vec<Option<Vec<u8>>>> {
59 let batches = remote_collect(options, sql).await?;
60 let vec = extract_byte_array::<BinaryType>(&batches, col_idx)?;
61 Ok(vec.into_iter().map(|s| s.map(|s| s.to_vec())).collect())
62}
63
64pub async fn remote_collect_large_binary_column(
65 options: ConnectionOptions,
66 sql: impl Into<String>,
67 col_idx: usize,
68) -> DFResult<Vec<Option<Vec<u8>>>> {
69 let batches = remote_collect(options, sql).await?;
70 let vec = extract_byte_array::<LargeBinaryType>(&batches, col_idx)?;
71 Ok(vec.into_iter().map(|s| s.map(|s| s.to_vec())).collect())
72}
73
74pub fn extract_primitive_array<T: ArrowPrimitiveType>(
75 batches: &[RecordBatch],
76 col_idx: usize,
77) -> DFResult<Vec<Option<T::Native>>> {
78 let mut result = Vec::new();
79 for batch in batches {
80 let column = batch.column(col_idx);
81 if let Some(array) = column.as_any().downcast_ref::<PrimitiveArray<T>>() {
82 result.extend(array.iter().collect::<Vec<_>>())
83 } else {
84 return Err(DataFusionError::Execution(format!(
85 "Column at index {col_idx} is not {} instead of {}",
86 T::DATA_TYPE,
87 column.data_type(),
88 )));
89 }
90 }
91 Ok(result)
92}
93
94pub fn extract_boolean_array(
95 batches: &[RecordBatch],
96 col_idx: usize,
97) -> DFResult<Vec<Option<bool>>> {
98 let mut result = Vec::new();
99 for batch in batches {
100 let column = batch.column(col_idx);
101 if let Some(array) = column.as_any().downcast_ref::<BooleanArray>() {
102 result.extend(array.iter().collect::<Vec<_>>())
103 } else {
104 return Err(DataFusionError::Execution(format!(
105 "Column at index {col_idx} is not {} instead of {}",
106 BooleanType::DATA_TYPE,
107 column.data_type(),
108 )));
109 }
110 }
111 Ok(result)
112}
113
114pub fn extract_byte_array<T: ByteArrayType>(
115 batches: &[RecordBatch],
116 col_idx: usize,
117) -> DFResult<Vec<Option<&T::Native>>> {
118 let mut result = Vec::new();
119 for batch in batches {
120 let column = batch.column(col_idx);
121 if let Some(array) = column.as_any().downcast_ref::<GenericByteArray<T>>() {
122 result.extend(array.iter().collect::<Vec<_>>())
123 } else {
124 return Err(DataFusionError::Execution(format!(
125 "Column at index {col_idx} is not {} instead of {}",
126 T::DATA_TYPE,
127 column.data_type(),
128 )));
129 }
130 }
131 Ok(result)
132}
133
/// Builds the decimal string representation of 10^`scale`:
/// `"1"` followed by `scale` zeros for non-negative scales, and
/// `"0.0…01"` with `|scale|` fractional digits for negative scales.
pub fn gen_tenfold_scaling_factor(scale: i32) -> String {
    if scale >= 0 {
        format!("1{}", "0".repeat(scale as usize))
    } else {
        // `unsigned_abs` avoids the signed-negation overflow of `-scale`
        // when `scale == i32::MIN` (a debug-build panic in the old code).
        let fraction_zeros = scale.unsigned_abs() as usize - 1;
        format!("0.{}1", "0".repeat(fraction_zeros))
    }
}
141
142pub fn big_decimal_to_i128(decimal: &BigDecimal, scale: Option<i32>) -> DFResult<i128> {
143 let scale = scale.unwrap_or_else(|| {
144 decimal
145 .fractional_digit_count()
146 .try_into()
147 .unwrap_or_default()
148 });
149 let scale_str = gen_tenfold_scaling_factor(scale);
150 let scale_decimal = BigDecimal::from_str(&scale_str).map_err(|e| {
151 DataFusionError::Execution(format!(
152 "Failed to parse str {scale_str} to BigDecimal: {e:?}",
153 ))
154 })?;
155 (decimal * scale_decimal).to_i128().ok_or_else(|| {
156 DataFusionError::Execution(format!(
157 "Failed to convert BigDecimal to i128 for {decimal:?}",
158 ))
159 })
160}
161
162pub fn big_decimal_to_i256(decimal: &BigDecimal, scale: Option<i32>) -> DFResult<i256> {
163 let scale = scale.unwrap_or_else(|| {
164 decimal
165 .fractional_digit_count()
166 .try_into()
167 .unwrap_or_default()
168 });
169 let scale_str = gen_tenfold_scaling_factor(scale);
170 let scale_decimal = BigDecimal::from_str(&scale_str).map_err(|e| {
171 DataFusionError::Execution(format!(
172 "Failed to parse str {scale_str} to BigDecimal: {e:?}",
173 ))
174 })?;
175 let scaled_decimal = decimal * scale_decimal;
176
177 let integer_part = scaled_decimal.with_scale(0);
179
180 integer_part.to_string().parse::<i256>().map_err(|e| {
182 DataFusionError::Execution(format!("Failed to parse str {integer_part} to i256: {e:?}",))
183 })
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::array::{BooleanArray, Int32Array, RecordBatch, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Int32Type, Schema, Utf8Type};
    use std::sync::Arc;

    // None of the extraction tests await anything, so they use plain `#[test]`
    // instead of `#[tokio::test]`, which spun up an async runtime for no reason.

    #[test]
    fn test_extract_primitive_array() {
        let expected = vec![Some(1), Some(2), None];
        let batches = vec![
            RecordBatch::try_new(
                Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])),
                vec![Arc::new(Int32Array::from(expected.clone()))],
            )
            .unwrap(),
        ];
        let result: Vec<Option<i32>> = extract_primitive_array::<Int32Type>(&batches, 0).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_extract_bool_array() {
        let expected = vec![Some(true), Some(false), None];
        let batches = vec![
            RecordBatch::try_new(
                Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, true)])),
                vec![Arc::new(BooleanArray::from(expected.clone()))],
            )
            .unwrap(),
        ];
        let result: Vec<Option<bool>> = extract_boolean_array(&batches, 0).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_extract_byte_array() {
        let expected = vec![Some("abc"), Some("def"), None];
        let batches = vec![
            RecordBatch::try_new(
                Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])),
                vec![Arc::new(StringArray::from(expected.clone()))],
            )
            .unwrap(),
        ];
        let result: Vec<Option<&str>> = extract_byte_array::<Utf8Type>(&batches, 0).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_gen_tenfold_scaling_factor() {
        assert_eq!(gen_tenfold_scaling_factor(0), "1");
        assert_eq!(gen_tenfold_scaling_factor(1), "10");
        assert_eq!(gen_tenfold_scaling_factor(2), "100");
        assert_eq!(gen_tenfold_scaling_factor(-1), "0.1");
        assert_eq!(gen_tenfold_scaling_factor(-2), "0.01");
    }
}