datafusion_remote_table/
utils.rs

1use crate::{ConnectionOptions, DFResult, RemoteTable};
2use datafusion::arrow::array::{Array, GenericByteArray, PrimitiveArray, RecordBatch};
3use datafusion::arrow::datatypes::{
4    ArrowPrimitiveType, BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, Utf8Type,
5};
6use datafusion::error::DataFusionError;
7use datafusion::prelude::SessionContext;
8use std::sync::Arc;
9
10pub async fn remote_collect(
11    options: ConnectionOptions,
12    sql: impl Into<String>,
13) -> DFResult<Vec<RecordBatch>> {
14    let table = RemoteTable::try_new(options, sql).await?;
15    let ctx = SessionContext::new();
16    ctx.read_table(Arc::new(table))?.collect().await
17}
18
19pub async fn remote_collect_primitive_column<T: ArrowPrimitiveType>(
20    options: ConnectionOptions,
21    sql: impl Into<String>,
22    col_idx: usize,
23) -> DFResult<Vec<Option<T::Native>>> {
24    let batches = remote_collect(options, sql).await?;
25    extract_primitive_array::<T>(&batches, col_idx)
26}
27
28pub async fn remote_collect_utf8_column(
29    options: ConnectionOptions,
30    sql: impl Into<String>,
31    col_idx: usize,
32) -> DFResult<Vec<Option<String>>> {
33    let batches = remote_collect(options, sql).await?;
34    let vec = extract_byte_array::<Utf8Type>(&batches, col_idx)?;
35    Ok(vec.into_iter().map(|s| s.map(|s| s.to_string())).collect())
36}
37
38pub async fn remote_collect_large_utf8_column(
39    options: ConnectionOptions,
40    sql: impl Into<String>,
41    col_idx: usize,
42) -> DFResult<Vec<Option<String>>> {
43    let batches = remote_collect(options, sql).await?;
44    let vec = extract_byte_array::<LargeUtf8Type>(&batches, col_idx)?;
45    Ok(vec.into_iter().map(|s| s.map(|s| s.to_string())).collect())
46}
47
48pub async fn remote_collect_binary_column(
49    options: ConnectionOptions,
50    sql: impl Into<String>,
51    col_idx: usize,
52) -> DFResult<Vec<Option<Vec<u8>>>> {
53    let batches = remote_collect(options, sql).await?;
54    let vec = extract_byte_array::<BinaryType>(&batches, col_idx)?;
55    Ok(vec.into_iter().map(|s| s.map(|s| s.to_vec())).collect())
56}
57
58pub async fn remote_collect_large_binary_column(
59    options: ConnectionOptions,
60    sql: impl Into<String>,
61    col_idx: usize,
62) -> DFResult<Vec<Option<Vec<u8>>>> {
63    let batches = remote_collect(options, sql).await?;
64    let vec = extract_byte_array::<LargeBinaryType>(&batches, col_idx)?;
65    Ok(vec.into_iter().map(|s| s.map(|s| s.to_vec())).collect())
66}
67
68pub fn extract_primitive_array<T: ArrowPrimitiveType>(
69    batches: &[RecordBatch],
70    col_idx: usize,
71) -> DFResult<Vec<Option<T::Native>>> {
72    let mut result = Vec::new();
73    for batch in batches {
74        let column = batch.column(col_idx);
75        if let Some(array) = column.as_any().downcast_ref::<PrimitiveArray<T>>() {
76            result.extend(array.iter().collect::<Vec<_>>())
77        } else {
78            return Err(DataFusionError::Execution(format!(
79                "Column at index {col_idx} is not {} instead of {}",
80                T::DATA_TYPE,
81                column.data_type(),
82            )));
83        }
84    }
85    Ok(result)
86}
87
88pub fn extract_byte_array<T: ByteArrayType>(
89    batches: &[RecordBatch],
90    col_idx: usize,
91) -> DFResult<Vec<Option<&T::Native>>> {
92    let mut result = Vec::new();
93    for batch in batches {
94        let column = batch.column(col_idx);
95        if let Some(array) = column.as_any().downcast_ref::<GenericByteArray<T>>() {
96            result.extend(array.iter().collect::<Vec<_>>())
97        } else {
98            return Err(DataFusionError::Execution(format!(
99                "Column at index {col_idx} is not {} instead of {}",
100                T::DATA_TYPE,
101                column.data_type(),
102            )));
103        }
104    }
105    Ok(result)
106}
107
#[cfg(test)]
mod tests {
    use crate::{extract_byte_array, extract_primitive_array};
    use datafusion::arrow::array::{Int32Array, RecordBatch, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Int32Type, Schema, Utf8Type};
    use std::sync::Arc;

    // These tests are purely synchronous (no `.await`), so a plain `#[test]`
    // avoids spinning up a tokio runtime for nothing.

    /// `extract_primitive_array` round-trips an Int32 column, nulls included.
    #[test]
    fn test_extract_primitive_array() {
        let expected = vec![Some(1), Some(2), None];
        let batches = vec![
            RecordBatch::try_new(
                Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])),
                vec![Arc::new(Int32Array::from(expected.clone()))],
            )
            .unwrap(),
        ];
        let result: Vec<Option<i32>> = extract_primitive_array::<Int32Type>(&batches, 0).unwrap();
        assert_eq!(result, expected);
    }

    /// `extract_byte_array` round-trips a Utf8 column, nulls included.
    #[test]
    fn test_extract_byte_array() {
        let expected = vec![Some("abc"), Some("def"), None];
        let batches = vec![
            RecordBatch::try_new(
                Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])),
                vec![Arc::new(StringArray::from(expected.clone()))],
            )
            .unwrap(),
        ];
        let result: Vec<Option<&str>> = extract_byte_array::<Utf8Type>(&batches, 0).unwrap();
        assert_eq!(result, expected);
    }
}