datafusion_remote_table/
utils.rs

1use crate::DFResult;
2use arrow::array::{Array, BooleanArray, GenericByteArray, PrimitiveArray, RecordBatch};
3use arrow::datatypes::{ArrowPrimitiveType, BooleanType, ByteArrayType, i256};
4use bigdecimal::BigDecimal;
5use bigdecimal::ToPrimitive;
6use datafusion_common::DataFusionError;
7use std::str::FromStr;
8
9pub fn extract_primitive_array<T: ArrowPrimitiveType>(
10    batches: &[RecordBatch],
11    col_idx: usize,
12) -> DFResult<Vec<Option<T::Native>>> {
13    let mut result = Vec::new();
14    for batch in batches {
15        let column = batch.column(col_idx);
16        if let Some(array) = column.as_any().downcast_ref::<PrimitiveArray<T>>() {
17            result.extend(array.iter().collect::<Vec<_>>())
18        } else {
19            return Err(DataFusionError::Execution(format!(
20                "Column at index {col_idx} is not {} instead of {}",
21                T::DATA_TYPE,
22                column.data_type(),
23            )));
24        }
25    }
26    Ok(result)
27}
28
29pub fn extract_boolean_array(
30    batches: &[RecordBatch],
31    col_idx: usize,
32) -> DFResult<Vec<Option<bool>>> {
33    let mut result = Vec::new();
34    for batch in batches {
35        let column = batch.column(col_idx);
36        if let Some(array) = column.as_any().downcast_ref::<BooleanArray>() {
37            result.extend(array.iter().collect::<Vec<_>>())
38        } else {
39            return Err(DataFusionError::Execution(format!(
40                "Column at index {col_idx} is not {} instead of {}",
41                BooleanType::DATA_TYPE,
42                column.data_type(),
43            )));
44        }
45    }
46    Ok(result)
47}
48
49pub fn extract_byte_array<T: ByteArrayType>(
50    batches: &[RecordBatch],
51    col_idx: usize,
52) -> DFResult<Vec<Option<&T::Native>>> {
53    let mut result = Vec::new();
54    for batch in batches {
55        let column = batch.column(col_idx);
56        if let Some(array) = column.as_any().downcast_ref::<GenericByteArray<T>>() {
57            result.extend(array.iter().collect::<Vec<_>>())
58        } else {
59            return Err(DataFusionError::Execution(format!(
60                "Column at index {col_idx} is not {} instead of {}",
61                T::DATA_TYPE,
62                column.data_type(),
63            )));
64        }
65    }
66    Ok(result)
67}
68
/// Returns the decimal string for `10^scale`.
///
/// A non-negative `scale` yields `"1"` followed by `scale` zeros (e.g. `2` -> `"100"`).
/// A negative `scale` yields a fraction whose only non-zero digit is a `1`
/// in the `-scale`-th decimal place (e.g. `-2` -> `"0.01"`).
pub fn gen_tenfold_scaling_factor(scale: i32) -> String {
    match scale {
        s if s >= 0 => {
            let mut factor = String::from("1");
            factor.push_str(&"0".repeat(s as usize));
            factor
        }
        s => {
            // The first fractional digit is the `1` itself at scale -1, so the
            // zero padding is one shorter than the magnitude.
            let padding = (-s - 1) as usize;
            let mut factor = String::from("0.");
            factor.push_str(&"0".repeat(padding));
            factor.push('1');
            factor
        }
    }
}
76
77pub fn big_decimal_to_i128(decimal: &BigDecimal, scale: Option<i32>) -> DFResult<i128> {
78    let scale = scale.unwrap_or_else(|| {
79        decimal
80            .fractional_digit_count()
81            .try_into()
82            .unwrap_or_default()
83    });
84    let scale_str = gen_tenfold_scaling_factor(scale);
85    let scale_decimal = BigDecimal::from_str(&scale_str).map_err(|e| {
86        DataFusionError::Execution(format!(
87            "Failed to parse str {scale_str} to BigDecimal: {e:?}",
88        ))
89    })?;
90    (decimal * scale_decimal).to_i128().ok_or_else(|| {
91        DataFusionError::Execution(format!(
92            "Failed to convert BigDecimal to i128 for {decimal:?}",
93        ))
94    })
95}
96
97pub fn big_decimal_to_i256(decimal: &BigDecimal, scale: Option<i32>) -> DFResult<i256> {
98    let scale = scale.unwrap_or_else(|| {
99        decimal
100            .fractional_digit_count()
101            .try_into()
102            .unwrap_or_default()
103    });
104    let scale_str = gen_tenfold_scaling_factor(scale);
105    let scale_decimal = BigDecimal::from_str(&scale_str).map_err(|e| {
106        DataFusionError::Execution(format!(
107            "Failed to parse str {scale_str} to BigDecimal: {e:?}",
108        ))
109    })?;
110    let scaled_decimal = decimal * scale_decimal;
111
112    // remove the fractional part, only keep the integer part
113    let integer_part = scaled_decimal.with_scale(0);
114
115    // Convert to string and then parse as i256
116    integer_part.to_string().parse::<i256>().map_err(|e| {
117        DataFusionError::Execution(format!("Failed to parse str {integer_part} to i256: {e:?}",))
118    })
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, BooleanArray, Int32Array, RecordBatch, StringArray};
    use arrow::datatypes::{DataType, Field, Int32Type, Schema, Utf8Type};
    use std::sync::Arc;

    /// Builds a batch with a single nullable column named `a` of `data_type` holding `array`.
    fn single_column_batch(data_type: DataType, array: ArrayRef) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("a", data_type, true)]));
        RecordBatch::try_new(schema, vec![array]).unwrap()
    }

    // The functions under test are synchronous, so plain `#[test]` suffices; no
    // async runtime is needed.
    #[test]
    fn test_extract_primitive_array() {
        let expected = vec![Some(1), Some(2), None];
        let batches = vec![single_column_batch(
            DataType::Int32,
            Arc::new(Int32Array::from(expected.clone())),
        )];
        let result: Vec<Option<i32>> = extract_primitive_array::<Int32Type>(&batches, 0).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_extract_primitive_array_concatenates_batches() {
        let batches = vec![
            single_column_batch(DataType::Int32, Arc::new(Int32Array::from(vec![Some(1), None]))),
            single_column_batch(DataType::Int32, Arc::new(Int32Array::from(vec![Some(3)]))),
        ];
        let result = extract_primitive_array::<Int32Type>(&batches, 0).unwrap();
        assert_eq!(result, vec![Some(1), None, Some(3)]);
    }

    #[test]
    fn test_extract_primitive_array_empty_input() {
        let result = extract_primitive_array::<Int32Type>(&[], 0).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_extract_primitive_array_type_mismatch() {
        let batches = vec![single_column_batch(
            DataType::Utf8,
            Arc::new(StringArray::from(vec![Some("x")])),
        )];
        assert!(extract_primitive_array::<Int32Type>(&batches, 0).is_err());
    }

    #[test]
    fn test_extract_bool_array() {
        let expected = vec![Some(true), Some(false), None];
        let batches = vec![single_column_batch(
            DataType::Boolean,
            Arc::new(BooleanArray::from(expected.clone())),
        )];
        let result: Vec<Option<bool>> = extract_boolean_array(&batches, 0).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_extract_bool_array_type_mismatch() {
        let batches = vec![single_column_batch(
            DataType::Int32,
            Arc::new(Int32Array::from(vec![Some(1)])),
        )];
        assert!(extract_boolean_array(&batches, 0).is_err());
    }

    #[test]
    fn test_extract_byte_array() {
        let expected = vec![Some("abc"), Some("def"), None];
        let batches = vec![single_column_batch(
            DataType::Utf8,
            Arc::new(StringArray::from(expected.clone())),
        )];
        let result: Vec<Option<&str>> = extract_byte_array::<Utf8Type>(&batches, 0).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_extract_byte_array_type_mismatch() {
        let batches = vec![single_column_batch(
            DataType::Int32,
            Arc::new(Int32Array::from(vec![Some(1)])),
        )];
        assert!(extract_byte_array::<Utf8Type>(&batches, 0).is_err());
    }

    #[test]
    fn test_gen_tenfold_scaling_factor() {
        assert_eq!(gen_tenfold_scaling_factor(0), "1");
        assert_eq!(gen_tenfold_scaling_factor(1), "10");
        assert_eq!(gen_tenfold_scaling_factor(2), "100");
        assert_eq!(gen_tenfold_scaling_factor(-1), "0.1");
        assert_eq!(gen_tenfold_scaling_factor(-2), "0.01");
    }

    #[test]
    fn test_big_decimal_to_i128() {
        let value = BigDecimal::from_str("123.45").unwrap();
        assert_eq!(big_decimal_to_i128(&value, Some(2)).unwrap(), 12345);
        assert_eq!(big_decimal_to_i128(&value, None).unwrap(), 12345);
        assert_eq!(big_decimal_to_i128(&value, Some(4)).unwrap(), 1234500);

        let negative = BigDecimal::from_str("-1.5").unwrap();
        assert_eq!(big_decimal_to_i128(&negative, Some(1)).unwrap(), -15);
    }

    #[test]
    fn test_big_decimal_to_i256() {
        let value = BigDecimal::from_str("123.45").unwrap();
        assert_eq!(
            big_decimal_to_i256(&value, Some(2)).unwrap(),
            i256::from_i128(12345)
        );
        assert_eq!(
            big_decimal_to_i256(&value, None).unwrap(),
            i256::from_i128(12345)
        );

        let negative = BigDecimal::from_str("-1.5").unwrap();
        assert_eq!(
            big_decimal_to_i256(&negative, Some(1)).unwrap(),
            i256::from_i128(-15)
        );
    }
}