1use crate::{ConnectionOptions, DFResult, RemoteTable, TableSource};
2use datafusion::arrow::array::{
3 Array, BooleanArray, GenericByteArray, PrimitiveArray, RecordBatch,
4};
5use datafusion::arrow::datatypes::{
6 ArrowPrimitiveType, BinaryType, BooleanType, ByteArrayType, LargeBinaryType, LargeUtf8Type,
7 Utf8Type,
8};
9use datafusion::error::DataFusionError;
10use datafusion::prelude::SessionContext;
11use std::sync::Arc;
12
13pub async fn remote_collect(
14 options: ConnectionOptions,
15 sql: impl Into<String>,
16) -> DFResult<Vec<RecordBatch>> {
17 let table = RemoteTable::try_new(options, TableSource::Query(sql.into())).await?;
18 let ctx = SessionContext::new();
19 ctx.read_table(Arc::new(table))?.collect().await
20}
21
22pub async fn remote_collect_primitive_column<T: ArrowPrimitiveType>(
23 options: ConnectionOptions,
24 sql: impl Into<String>,
25 col_idx: usize,
26) -> DFResult<Vec<Option<T::Native>>> {
27 let batches = remote_collect(options, sql).await?;
28 extract_primitive_array::<T>(&batches, col_idx)
29}
30
31pub async fn remote_collect_utf8_column(
32 options: ConnectionOptions,
33 sql: impl Into<String>,
34 col_idx: usize,
35) -> DFResult<Vec<Option<String>>> {
36 let batches = remote_collect(options, sql).await?;
37 let vec = extract_byte_array::<Utf8Type>(&batches, col_idx)?;
38 Ok(vec.into_iter().map(|s| s.map(|s| s.to_string())).collect())
39}
40
41pub async fn remote_collect_large_utf8_column(
42 options: ConnectionOptions,
43 sql: impl Into<String>,
44 col_idx: usize,
45) -> DFResult<Vec<Option<String>>> {
46 let batches = remote_collect(options, sql).await?;
47 let vec = extract_byte_array::<LargeUtf8Type>(&batches, col_idx)?;
48 Ok(vec.into_iter().map(|s| s.map(|s| s.to_string())).collect())
49}
50
51pub async fn remote_collect_binary_column(
52 options: ConnectionOptions,
53 sql: impl Into<String>,
54 col_idx: usize,
55) -> DFResult<Vec<Option<Vec<u8>>>> {
56 let batches = remote_collect(options, sql).await?;
57 let vec = extract_byte_array::<BinaryType>(&batches, col_idx)?;
58 Ok(vec.into_iter().map(|s| s.map(|s| s.to_vec())).collect())
59}
60
61pub async fn remote_collect_large_binary_column(
62 options: ConnectionOptions,
63 sql: impl Into<String>,
64 col_idx: usize,
65) -> DFResult<Vec<Option<Vec<u8>>>> {
66 let batches = remote_collect(options, sql).await?;
67 let vec = extract_byte_array::<LargeBinaryType>(&batches, col_idx)?;
68 Ok(vec.into_iter().map(|s| s.map(|s| s.to_vec())).collect())
69}
70
71pub fn extract_primitive_array<T: ArrowPrimitiveType>(
72 batches: &[RecordBatch],
73 col_idx: usize,
74) -> DFResult<Vec<Option<T::Native>>> {
75 let mut result = Vec::new();
76 for batch in batches {
77 let column = batch.column(col_idx);
78 if let Some(array) = column.as_any().downcast_ref::<PrimitiveArray<T>>() {
79 result.extend(array.iter().collect::<Vec<_>>())
80 } else {
81 return Err(DataFusionError::Execution(format!(
82 "Column at index {col_idx} is not {} instead of {}",
83 T::DATA_TYPE,
84 column.data_type(),
85 )));
86 }
87 }
88 Ok(result)
89}
90
91pub fn extract_boolean_array(
92 batches: &[RecordBatch],
93 col_idx: usize,
94) -> DFResult<Vec<Option<bool>>> {
95 let mut result = Vec::new();
96 for batch in batches {
97 let column = batch.column(col_idx);
98 if let Some(array) = column.as_any().downcast_ref::<BooleanArray>() {
99 result.extend(array.iter().collect::<Vec<_>>())
100 } else {
101 return Err(DataFusionError::Execution(format!(
102 "Column at index {col_idx} is not {} instead of {}",
103 BooleanType::DATA_TYPE,
104 column.data_type(),
105 )));
106 }
107 }
108 Ok(result)
109}
110
111pub fn extract_byte_array<T: ByteArrayType>(
112 batches: &[RecordBatch],
113 col_idx: usize,
114) -> DFResult<Vec<Option<&T::Native>>> {
115 let mut result = Vec::new();
116 for batch in batches {
117 let column = batch.column(col_idx);
118 if let Some(array) = column.as_any().downcast_ref::<GenericByteArray<T>>() {
119 result.extend(array.iter().collect::<Vec<_>>())
120 } else {
121 return Err(DataFusionError::Execution(format!(
122 "Column at index {col_idx} is not {} instead of {}",
123 T::DATA_TYPE,
124 column.data_type(),
125 )));
126 }
127 }
128 Ok(result)
129}
130
#[cfg(test)]
mod tests {
    use crate::{extract_boolean_array, extract_byte_array, extract_primitive_array};
    use datafusion::arrow::array::{BooleanArray, Int32Array, RecordBatch, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Int32Type, Schema, Utf8Type};
    use std::sync::Arc;

    // The extraction helpers are synchronous, so plain `#[test]` functions are
    // enough; no async runtime is required.

    /// Builds a single-column, nullable Int32 batch holding `values`.
    fn int32_batch(values: Vec<Option<i32>>) -> RecordBatch {
        RecordBatch::try_new(
            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])),
            vec![Arc::new(Int32Array::from(values))],
        )
        .unwrap()
    }

    #[test]
    fn test_extract_primitive_array() {
        let expected = vec![Some(1), Some(2), None];
        let batches = vec![int32_batch(expected.clone())];
        let result: Vec<Option<i32>> = extract_primitive_array::<Int32Type>(&batches, 0).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_extract_primitive_array_multiple_batches() {
        // Values from consecutive batches are concatenated in batch order.
        let batches = vec![
            int32_batch(vec![Some(1), None]),
            int32_batch(vec![Some(3)]),
        ];
        let result: Vec<Option<i32>> = extract_primitive_array::<Int32Type>(&batches, 0).unwrap();
        assert_eq!(result, vec![Some(1), None, Some(3)]);
    }

    #[test]
    fn test_extract_bool_array() {
        let expected = vec![Some(true), Some(false), None];
        let batches = vec![
            RecordBatch::try_new(
                Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, true)])),
                vec![Arc::new(BooleanArray::from(expected.clone()))],
            )
            .unwrap(),
        ];
        let result: Vec<Option<bool>> = extract_boolean_array(&batches, 0).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_extract_byte_array() {
        let expected = vec![Some("abc"), Some("def"), None];
        let batches = vec![
            RecordBatch::try_new(
                Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])),
                vec![Arc::new(StringArray::from(expected.clone()))],
            )
            .unwrap(),
        ];
        let result: Vec<Option<&str>> = extract_byte_array::<Utf8Type>(&batches, 0).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_extract_type_mismatch_is_error() {
        // An Int32 column is neither a boolean nor a UTF-8 column.
        let batches = vec![int32_batch(vec![Some(1)])];
        assert!(extract_boolean_array(&batches, 0).is_err());
        assert!(extract_byte_array::<Utf8Type>(&batches, 0).is_err());
    }
}