datafusion_remote_table/
utils.rs

1use crate::{ConnectionOptions, DFResult, RemoteSource, RemoteTable};
2use bigdecimal::BigDecimal;
3use bigdecimal::ToPrimitive;
4use datafusion::arrow::array::{
5    Array, BooleanArray, GenericByteArray, PrimitiveArray, RecordBatch,
6};
7use datafusion::arrow::datatypes::{
8    ArrowPrimitiveType, BinaryType, BooleanType, ByteArrayType, LargeBinaryType, LargeUtf8Type,
9    Utf8Type, i256,
10};
11use datafusion::error::DataFusionError;
12use datafusion::prelude::SessionContext;
13use std::str::FromStr;
14use std::sync::Arc;
15
16pub async fn remote_collect(
17    options: ConnectionOptions,
18    sql: impl Into<String>,
19) -> DFResult<Vec<RecordBatch>> {
20    let table = RemoteTable::try_new(options, RemoteSource::Query(sql.into())).await?;
21    let ctx = SessionContext::new();
22    ctx.read_table(Arc::new(table))?.collect().await
23}
24
25pub async fn remote_collect_primitive_column<T: ArrowPrimitiveType>(
26    options: ConnectionOptions,
27    sql: impl Into<String>,
28    col_idx: usize,
29) -> DFResult<Vec<Option<T::Native>>> {
30    let batches = remote_collect(options, sql).await?;
31    extract_primitive_array::<T>(&batches, col_idx)
32}
33
34pub async fn remote_collect_utf8_column(
35    options: ConnectionOptions,
36    sql: impl Into<String>,
37    col_idx: usize,
38) -> DFResult<Vec<Option<String>>> {
39    let batches = remote_collect(options, sql).await?;
40    let vec = extract_byte_array::<Utf8Type>(&batches, col_idx)?;
41    Ok(vec.into_iter().map(|s| s.map(|s| s.to_string())).collect())
42}
43
44pub async fn remote_collect_large_utf8_column(
45    options: ConnectionOptions,
46    sql: impl Into<String>,
47    col_idx: usize,
48) -> DFResult<Vec<Option<String>>> {
49    let batches = remote_collect(options, sql).await?;
50    let vec = extract_byte_array::<LargeUtf8Type>(&batches, col_idx)?;
51    Ok(vec.into_iter().map(|s| s.map(|s| s.to_string())).collect())
52}
53
54pub async fn remote_collect_binary_column(
55    options: ConnectionOptions,
56    sql: impl Into<String>,
57    col_idx: usize,
58) -> DFResult<Vec<Option<Vec<u8>>>> {
59    let batches = remote_collect(options, sql).await?;
60    let vec = extract_byte_array::<BinaryType>(&batches, col_idx)?;
61    Ok(vec.into_iter().map(|s| s.map(|s| s.to_vec())).collect())
62}
63
64pub async fn remote_collect_large_binary_column(
65    options: ConnectionOptions,
66    sql: impl Into<String>,
67    col_idx: usize,
68) -> DFResult<Vec<Option<Vec<u8>>>> {
69    let batches = remote_collect(options, sql).await?;
70    let vec = extract_byte_array::<LargeBinaryType>(&batches, col_idx)?;
71    Ok(vec.into_iter().map(|s| s.map(|s| s.to_vec())).collect())
72}
73
74pub fn extract_primitive_array<T: ArrowPrimitiveType>(
75    batches: &[RecordBatch],
76    col_idx: usize,
77) -> DFResult<Vec<Option<T::Native>>> {
78    let mut result = Vec::new();
79    for batch in batches {
80        let column = batch.column(col_idx);
81        if let Some(array) = column.as_any().downcast_ref::<PrimitiveArray<T>>() {
82            result.extend(array.iter().collect::<Vec<_>>())
83        } else {
84            return Err(DataFusionError::Execution(format!(
85                "Column at index {col_idx} is not {} instead of {}",
86                T::DATA_TYPE,
87                column.data_type(),
88            )));
89        }
90    }
91    Ok(result)
92}
93
94pub fn extract_boolean_array(
95    batches: &[RecordBatch],
96    col_idx: usize,
97) -> DFResult<Vec<Option<bool>>> {
98    let mut result = Vec::new();
99    for batch in batches {
100        let column = batch.column(col_idx);
101        if let Some(array) = column.as_any().downcast_ref::<BooleanArray>() {
102            result.extend(array.iter().collect::<Vec<_>>())
103        } else {
104            return Err(DataFusionError::Execution(format!(
105                "Column at index {col_idx} is not {} instead of {}",
106                BooleanType::DATA_TYPE,
107                column.data_type(),
108            )));
109        }
110    }
111    Ok(result)
112}
113
114pub fn extract_byte_array<T: ByteArrayType>(
115    batches: &[RecordBatch],
116    col_idx: usize,
117) -> DFResult<Vec<Option<&T::Native>>> {
118    let mut result = Vec::new();
119    for batch in batches {
120        let column = batch.column(col_idx);
121        if let Some(array) = column.as_any().downcast_ref::<GenericByteArray<T>>() {
122            result.extend(array.iter().collect::<Vec<_>>())
123        } else {
124            return Err(DataFusionError::Execution(format!(
125                "Column at index {col_idx} is not {} instead of {}",
126                T::DATA_TYPE,
127                column.data_type(),
128            )));
129        }
130    }
131    Ok(result)
132}
133
/// Renders ten to the power of `scale` as a plain decimal string.
///
/// A non-negative `scale` yields `1` followed by `scale` zeros (`2` gives
/// `"100"`); a negative `scale` yields the matching fraction (`-2` gives
/// `"0.01"`).
pub fn gen_tenfold_scaling_factor(scale: i32) -> String {
    if scale >= 0 {
        format!("1{}", "0".repeat(scale as usize))
    } else {
        // `unsigned_abs` is total, whereas `-scale` overflows for `i32::MIN`.
        // `scale < 0` guarantees the magnitude is at least 1, so the
        // subtraction cannot underflow.
        let zeros_after_point = (scale.unsigned_abs() - 1) as usize;
        format!("0.{}1", "0".repeat(zeros_after_point))
    }
}
141
142pub fn big_decimal_to_i128(decimal: &BigDecimal, scale: Option<i32>) -> DFResult<i128> {
143    let scale = scale.unwrap_or_else(|| {
144        decimal
145            .fractional_digit_count()
146            .try_into()
147            .unwrap_or_default()
148    });
149    let scale_str = gen_tenfold_scaling_factor(scale);
150    let scale_decimal = BigDecimal::from_str(&scale_str).map_err(|e| {
151        DataFusionError::Execution(format!(
152            "Failed to parse str {scale_str} to BigDecimal: {e:?}",
153        ))
154    })?;
155    (decimal * scale_decimal).to_i128().ok_or_else(|| {
156        DataFusionError::Execution(format!(
157            "Failed to convert BigDecimal to i128 for {decimal:?}",
158        ))
159    })
160}
161
162pub fn big_decimal_to_i256(decimal: &BigDecimal, scale: Option<i32>) -> DFResult<i256> {
163    let scale = scale.unwrap_or_else(|| {
164        decimal
165            .fractional_digit_count()
166            .try_into()
167            .unwrap_or_default()
168    });
169    let scale_str = gen_tenfold_scaling_factor(scale);
170    let scale_decimal = BigDecimal::from_str(&scale_str).map_err(|e| {
171        DataFusionError::Execution(format!(
172            "Failed to parse str {scale_str} to BigDecimal: {e:?}",
173        ))
174    })?;
175    let scaled_decimal = decimal * scale_decimal;
176
177    // remove the fractional part, only keep the integer part
178    let integer_part = scaled_decimal.with_scale(0);
179
180    // Convert to string and then parse as i256
181    integer_part.to_string().parse::<i256>().map_err(|e| {
182        DataFusionError::Execution(format!("Failed to parse str {integer_part} to i256: {e:?}",))
183    })
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::array::{
        ArrayRef, BooleanArray, Int32Array, RecordBatch, StringArray,
    };
    use datafusion::arrow::datatypes::{DataType, Field, Int32Type, Schema, Utf8Type};
    use std::sync::Arc;

    /// Builds a batch holding one nullable column named `a`.
    fn single_column_batch(data_type: DataType, column: ArrayRef) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("a", data_type, true)]));
        RecordBatch::try_new(schema, vec![column]).unwrap()
    }

    // The extraction helpers are synchronous, so plain `#[test]` is enough;
    // no async runtime is needed.

    #[test]
    fn test_extract_primitive_array() {
        let expected = vec![Some(1), Some(2), None];
        let batches = vec![single_column_batch(
            DataType::Int32,
            Arc::new(Int32Array::from(expected.clone())),
        )];
        let result: Vec<Option<i32>> = extract_primitive_array::<Int32Type>(&batches, 0).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_extract_primitive_array_multiple_batches() {
        let batches = vec![
            single_column_batch(
                DataType::Int32,
                Arc::new(Int32Array::from(vec![Some(1), None])),
            ),
            single_column_batch(DataType::Int32, Arc::new(Int32Array::from(vec![Some(3)]))),
        ];
        let result = extract_primitive_array::<Int32Type>(&batches, 0).unwrap();
        assert_eq!(result, vec![Some(1), None, Some(3)]);
    }

    #[test]
    fn test_extract_bool_array() {
        let expected = vec![Some(true), Some(false), None];
        let batches = vec![single_column_batch(
            DataType::Boolean,
            Arc::new(BooleanArray::from(expected.clone())),
        )];
        let result: Vec<Option<bool>> = extract_boolean_array(&batches, 0).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_extract_byte_array() {
        let expected = vec![Some("abc"), Some("def"), None];
        let batches = vec![single_column_batch(
            DataType::Utf8,
            Arc::new(StringArray::from(expected.clone())),
        )];
        let result: Vec<Option<&str>> = extract_byte_array::<Utf8Type>(&batches, 0).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_extract_rejects_mismatched_column_type() {
        let batches = vec![single_column_batch(
            DataType::Utf8,
            Arc::new(StringArray::from(vec![Some("abc")])),
        )];
        assert!(extract_primitive_array::<Int32Type>(&batches, 0).is_err());
        assert!(extract_boolean_array(&batches, 0).is_err());
        assert!(extract_byte_array::<BinaryType>(&batches, 0).is_err());
    }

    #[test]
    fn test_gen_tenfold_scaling_factor() {
        assert_eq!(gen_tenfold_scaling_factor(0), "1");
        assert_eq!(gen_tenfold_scaling_factor(1), "10");
        assert_eq!(gen_tenfold_scaling_factor(2), "100");
        assert_eq!(gen_tenfold_scaling_factor(-1), "0.1");
        assert_eq!(gen_tenfold_scaling_factor(-2), "0.01");
    }

    #[test]
    fn test_big_decimal_to_i128() {
        let decimal = BigDecimal::from_str("1.23").unwrap();
        assert_eq!(big_decimal_to_i128(&decimal, Some(2)).unwrap(), 123);
        assert_eq!(big_decimal_to_i128(&decimal, Some(3)).unwrap(), 1230);
        // Without an explicit scale the fractional digit count (2) is used.
        assert_eq!(big_decimal_to_i128(&decimal, None).unwrap(), 123);
    }

    #[test]
    fn test_big_decimal_to_i256() {
        let decimal = BigDecimal::from_str("1.23").unwrap();
        assert_eq!(
            big_decimal_to_i256(&decimal, Some(2)).unwrap(),
            i256::from_i128(123)
        );
        assert_eq!(
            big_decimal_to_i256(&decimal, None).unwrap(),
            i256::from_i128(123)
        );
    }
}