datafusion_remote_table/
utils.rs1use crate::{ConnectionOptions, DFResult, RemoteTable};
2use datafusion::arrow::array::{Array, GenericByteArray, PrimitiveArray, RecordBatch};
3use datafusion::arrow::datatypes::{
4 ArrowPrimitiveType, BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, Utf8Type,
5};
6use datafusion::error::DataFusionError;
7use datafusion::prelude::SessionContext;
8use std::sync::Arc;
9
10pub async fn remote_collect(
11 options: ConnectionOptions,
12 sql: impl Into<String>,
13) -> DFResult<Vec<RecordBatch>> {
14 let table = RemoteTable::try_new(options, sql).await?;
15 let ctx = SessionContext::new();
16 ctx.read_table(Arc::new(table))?.collect().await
17}
18
19pub async fn remote_collect_primitive_column<T: ArrowPrimitiveType>(
20 options: ConnectionOptions,
21 sql: impl Into<String>,
22 col_idx: usize,
23) -> DFResult<Vec<Option<T::Native>>> {
24 let batches = remote_collect(options, sql).await?;
25 extract_primitive_array::<T>(&batches, col_idx)
26}
27
28pub async fn remote_collect_utf8_column(
29 options: ConnectionOptions,
30 sql: impl Into<String>,
31 col_idx: usize,
32) -> DFResult<Vec<Option<String>>> {
33 let batches = remote_collect(options, sql).await?;
34 let vec = extract_byte_array::<Utf8Type>(&batches, col_idx)?;
35 Ok(vec.into_iter().map(|s| s.map(|s| s.to_string())).collect())
36}
37
38pub async fn remote_collect_large_utf8_column(
39 options: ConnectionOptions,
40 sql: impl Into<String>,
41 col_idx: usize,
42) -> DFResult<Vec<Option<String>>> {
43 let batches = remote_collect(options, sql).await?;
44 let vec = extract_byte_array::<LargeUtf8Type>(&batches, col_idx)?;
45 Ok(vec.into_iter().map(|s| s.map(|s| s.to_string())).collect())
46}
47
48pub async fn remote_collect_binary_column(
49 options: ConnectionOptions,
50 sql: impl Into<String>,
51 col_idx: usize,
52) -> DFResult<Vec<Option<Vec<u8>>>> {
53 let batches = remote_collect(options, sql).await?;
54 let vec = extract_byte_array::<BinaryType>(&batches, col_idx)?;
55 Ok(vec.into_iter().map(|s| s.map(|s| s.to_vec())).collect())
56}
57
58pub async fn remote_collect_large_binary_column(
59 options: ConnectionOptions,
60 sql: impl Into<String>,
61 col_idx: usize,
62) -> DFResult<Vec<Option<Vec<u8>>>> {
63 let batches = remote_collect(options, sql).await?;
64 let vec = extract_byte_array::<LargeBinaryType>(&batches, col_idx)?;
65 Ok(vec.into_iter().map(|s| s.map(|s| s.to_vec())).collect())
66}
67
68pub fn extract_primitive_array<T: ArrowPrimitiveType>(
69 batches: &[RecordBatch],
70 col_idx: usize,
71) -> DFResult<Vec<Option<T::Native>>> {
72 let mut result = Vec::new();
73 for batch in batches {
74 let column = batch.column(col_idx);
75 if let Some(array) = column.as_any().downcast_ref::<PrimitiveArray<T>>() {
76 result.extend(array.iter().collect::<Vec<_>>())
77 } else {
78 return Err(DataFusionError::Execution(format!(
79 "Column at index {col_idx} is not {} instead of {}",
80 T::DATA_TYPE,
81 column.data_type(),
82 )));
83 }
84 }
85 Ok(result)
86}
87
88pub fn extract_byte_array<T: ByteArrayType>(
89 batches: &[RecordBatch],
90 col_idx: usize,
91) -> DFResult<Vec<Option<&T::Native>>> {
92 let mut result = Vec::new();
93 for batch in batches {
94 let column = batch.column(col_idx);
95 if let Some(array) = column.as_any().downcast_ref::<GenericByteArray<T>>() {
96 result.extend(array.iter().collect::<Vec<_>>())
97 } else {
98 return Err(DataFusionError::Execution(format!(
99 "Column at index {col_idx} is not {} instead of {}",
100 T::DATA_TYPE,
101 column.data_type(),
102 )));
103 }
104 }
105 Ok(result)
106}
107
#[cfg(test)]
mod tests {
    // `super` names the defining module directly, so these tests do not rely on
    // a crate-root re-export of the extractors.
    use super::{extract_byte_array, extract_primitive_array};
    use datafusion::arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Int32Type, Schema, Utf8Type};
    use std::sync::Arc;

    /// Builds a batch with a single nullable column named `a`.
    fn single_column_batch(data_type: DataType, array: ArrayRef) -> RecordBatch {
        RecordBatch::try_new(
            Arc::new(Schema::new(vec![Field::new("a", data_type, true)])),
            vec![array],
        )
        .unwrap()
    }

    // The extractors are synchronous, so plain `#[test]` suffices; no async runtime is needed.
    #[test]
    fn test_extract_primitive_array() {
        let expected = vec![Some(1), Some(2), None];
        let batches = vec![single_column_batch(
            DataType::Int32,
            Arc::new(Int32Array::from(expected.clone())),
        )];
        let result: Vec<Option<i32>> = extract_primitive_array::<Int32Type>(&batches, 0).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_extract_primitive_array_across_batches() {
        let batches = vec![
            single_column_batch(
                DataType::Int32,
                Arc::new(Int32Array::from(vec![Some(1), None])),
            ),
            single_column_batch(DataType::Int32, Arc::new(Int32Array::from(vec![Some(3)]))),
        ];
        let result = extract_primitive_array::<Int32Type>(&batches, 0).unwrap();
        assert_eq!(result, vec![Some(1), None, Some(3)]);
    }

    #[test]
    fn test_extract_from_no_batches_is_empty() {
        assert!(extract_primitive_array::<Int32Type>(&[], 0)
            .unwrap()
            .is_empty());
        assert!(extract_byte_array::<Utf8Type>(&[], 0).unwrap().is_empty());
    }

    #[test]
    fn test_extract_byte_array() {
        let expected = vec![Some("abc"), Some("def"), None];
        let batches = vec![single_column_batch(
            DataType::Utf8,
            Arc::new(StringArray::from(expected.clone())),
        )];
        let result: Vec<Option<&str>> = extract_byte_array::<Utf8Type>(&batches, 0).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_extract_mismatched_type_is_error() {
        let batches = vec![single_column_batch(
            DataType::Int32,
            Arc::new(Int32Array::from(vec![1, 2])),
        )];
        assert!(extract_byte_array::<Utf8Type>(&batches, 0).is_err());
    }
}