datafusion_remote_table/
utils.rs

1use crate::{ConnectionOptions, DFResult, RemoteTable, TableSource};
2use datafusion::arrow::array::{
3    Array, BooleanArray, GenericByteArray, PrimitiveArray, RecordBatch,
4};
5use datafusion::arrow::datatypes::{
6    ArrowPrimitiveType, BinaryType, BooleanType, ByteArrayType, LargeBinaryType, LargeUtf8Type,
7    Utf8Type,
8};
9use datafusion::error::DataFusionError;
10use datafusion::prelude::SessionContext;
11use std::sync::Arc;
12
13pub async fn remote_collect(
14    options: ConnectionOptions,
15    sql: impl Into<String>,
16) -> DFResult<Vec<RecordBatch>> {
17    let table = RemoteTable::try_new(options, TableSource::Query(sql.into())).await?;
18    let ctx = SessionContext::new();
19    ctx.read_table(Arc::new(table))?.collect().await
20}
21
22pub async fn remote_collect_primitive_column<T: ArrowPrimitiveType>(
23    options: ConnectionOptions,
24    sql: impl Into<String>,
25    col_idx: usize,
26) -> DFResult<Vec<Option<T::Native>>> {
27    let batches = remote_collect(options, sql).await?;
28    extract_primitive_array::<T>(&batches, col_idx)
29}
30
31pub async fn remote_collect_utf8_column(
32    options: ConnectionOptions,
33    sql: impl Into<String>,
34    col_idx: usize,
35) -> DFResult<Vec<Option<String>>> {
36    let batches = remote_collect(options, sql).await?;
37    let vec = extract_byte_array::<Utf8Type>(&batches, col_idx)?;
38    Ok(vec.into_iter().map(|s| s.map(|s| s.to_string())).collect())
39}
40
41pub async fn remote_collect_large_utf8_column(
42    options: ConnectionOptions,
43    sql: impl Into<String>,
44    col_idx: usize,
45) -> DFResult<Vec<Option<String>>> {
46    let batches = remote_collect(options, sql).await?;
47    let vec = extract_byte_array::<LargeUtf8Type>(&batches, col_idx)?;
48    Ok(vec.into_iter().map(|s| s.map(|s| s.to_string())).collect())
49}
50
51pub async fn remote_collect_binary_column(
52    options: ConnectionOptions,
53    sql: impl Into<String>,
54    col_idx: usize,
55) -> DFResult<Vec<Option<Vec<u8>>>> {
56    let batches = remote_collect(options, sql).await?;
57    let vec = extract_byte_array::<BinaryType>(&batches, col_idx)?;
58    Ok(vec.into_iter().map(|s| s.map(|s| s.to_vec())).collect())
59}
60
61pub async fn remote_collect_large_binary_column(
62    options: ConnectionOptions,
63    sql: impl Into<String>,
64    col_idx: usize,
65) -> DFResult<Vec<Option<Vec<u8>>>> {
66    let batches = remote_collect(options, sql).await?;
67    let vec = extract_byte_array::<LargeBinaryType>(&batches, col_idx)?;
68    Ok(vec.into_iter().map(|s| s.map(|s| s.to_vec())).collect())
69}
70
71pub fn extract_primitive_array<T: ArrowPrimitiveType>(
72    batches: &[RecordBatch],
73    col_idx: usize,
74) -> DFResult<Vec<Option<T::Native>>> {
75    let mut result = Vec::new();
76    for batch in batches {
77        let column = batch.column(col_idx);
78        if let Some(array) = column.as_any().downcast_ref::<PrimitiveArray<T>>() {
79            result.extend(array.iter().collect::<Vec<_>>())
80        } else {
81            return Err(DataFusionError::Execution(format!(
82                "Column at index {col_idx} is not {} instead of {}",
83                T::DATA_TYPE,
84                column.data_type(),
85            )));
86        }
87    }
88    Ok(result)
89}
90
91pub fn extract_boolean_array(
92    batches: &[RecordBatch],
93    col_idx: usize,
94) -> DFResult<Vec<Option<bool>>> {
95    let mut result = Vec::new();
96    for batch in batches {
97        let column = batch.column(col_idx);
98        if let Some(array) = column.as_any().downcast_ref::<BooleanArray>() {
99            result.extend(array.iter().collect::<Vec<_>>())
100        } else {
101            return Err(DataFusionError::Execution(format!(
102                "Column at index {col_idx} is not {} instead of {}",
103                BooleanType::DATA_TYPE,
104                column.data_type(),
105            )));
106        }
107    }
108    Ok(result)
109}
110
111pub fn extract_byte_array<T: ByteArrayType>(
112    batches: &[RecordBatch],
113    col_idx: usize,
114) -> DFResult<Vec<Option<&T::Native>>> {
115    let mut result = Vec::new();
116    for batch in batches {
117        let column = batch.column(col_idx);
118        if let Some(array) = column.as_any().downcast_ref::<GenericByteArray<T>>() {
119            result.extend(array.iter().collect::<Vec<_>>())
120        } else {
121            return Err(DataFusionError::Execution(format!(
122                "Column at index {col_idx} is not {} instead of {}",
123                T::DATA_TYPE,
124                column.data_type(),
125            )));
126        }
127    }
128    Ok(result)
129}
130
#[cfg(test)]
mod tests {
    use crate::{extract_boolean_array, extract_byte_array, extract_primitive_array};
    use datafusion::arrow::array::{
        ArrayRef, BooleanArray, Int32Array, RecordBatch, StringArray,
    };
    use datafusion::arrow::datatypes::{DataType, Field, Int32Type, Schema, Utf8Type};
    use std::sync::Arc;

    /// Builds a batch with a single nullable column named "a" of the given type.
    fn single_column_batch(data_type: DataType, column: ArrayRef) -> RecordBatch {
        let schema = Schema::new(vec![Field::new("a", data_type, true)]);
        RecordBatch::try_new(Arc::new(schema), vec![column]).unwrap()
    }

    // The extractors are synchronous, so plain `#[test]` is enough; no async runtime needed.

    #[test]
    fn test_extract_primitive_array() {
        let expected = vec![Some(1), Some(2), None];
        let batches = vec![single_column_batch(
            DataType::Int32,
            Arc::new(Int32Array::from(expected.clone())),
        )];
        let result: Vec<Option<i32>> = extract_primitive_array::<Int32Type>(&batches, 0).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_extract_primitive_array_multiple_batches() {
        let batches = vec![
            single_column_batch(
                DataType::Int32,
                Arc::new(Int32Array::from(vec![Some(1), Some(2), None])),
            ),
            single_column_batch(DataType::Int32, Arc::new(Int32Array::from(vec![Some(3)]))),
        ];
        let result: Vec<Option<i32>> = extract_primitive_array::<Int32Type>(&batches, 0).unwrap();
        // Rows from later batches follow rows from earlier ones.
        assert_eq!(result, vec![Some(1), Some(2), None, Some(3)]);
    }

    #[test]
    fn test_extract_primitive_array_type_mismatch() {
        let batches = vec![single_column_batch(
            DataType::Utf8,
            Arc::new(StringArray::from(vec![Some("abc")])),
        )];
        assert!(extract_primitive_array::<Int32Type>(&batches, 0).is_err());
    }

    #[test]
    fn test_extract_bool_array() {
        let expected = vec![Some(true), Some(false), None];
        let batches = vec![single_column_batch(
            DataType::Boolean,
            Arc::new(BooleanArray::from(expected.clone())),
        )];
        let result: Vec<Option<bool>> = extract_boolean_array(&batches, 0).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_extract_byte_array() {
        let expected = vec![Some("abc"), Some("def"), None];
        let batches = vec![single_column_batch(
            DataType::Utf8,
            Arc::new(StringArray::from(expected.clone())),
        )];
        let result: Vec<Option<&str>> = extract_byte_array::<Utf8Type>(&batches, 0).unwrap();
        assert_eq!(result, expected);
    }
}