// arkflow_plugin/processor/sql.rs
use arkflow_core::processor::{register_processor_builder, Processor, ProcessorBuilder};
6use arkflow_core::{Content, Error, MessageBatch};
7use async_trait::async_trait;
8use datafusion::arrow;
9use datafusion::arrow::datatypes::Schema;
10use datafusion::arrow::record_batch::RecordBatch;
11use datafusion::prelude::*;
12use serde::{Deserialize, Serialize};
13use std::sync::Arc;
14
// Table name used to register the input batch when none is configured.
const DEFAULT_TABLE_NAME: &str = "flow";

/// Configuration for the SQL processor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SqlProcessorConfig {
    /// SQL query executed against the registered input table.
    pub query: String,

    /// Name under which the incoming batch is registered as a table;
    /// falls back to `DEFAULT_TABLE_NAME` ("flow") when absent.
    pub table_name: Option<String>,
}
25
/// Processor that runs a configured SQL query over incoming Arrow
/// record batches using an embedded DataFusion session.
pub struct SqlProcessor {
    // Immutable configuration captured at construction time.
    config: SqlProcessorConfig,
}
30
31impl SqlProcessor {
32 pub fn new(config: SqlProcessorConfig) -> Result<Self, Error> {
34 Ok(Self { config })
35 }
36
37 async fn execute_query(&self, batch: RecordBatch) -> Result<RecordBatch, Error> {
39 let ctx = SessionContext::new();
41 let table_name = self
42 .config
43 .table_name
44 .as_deref()
45 .unwrap_or(DEFAULT_TABLE_NAME);
46
47 ctx.register_batch(table_name, batch)
48 .map_err(|e| Error::Process(format!("Registration failed: {}", e)))?;
49
50 let sql_options = SQLOptions::new()
52 .with_allow_ddl(false)
53 .with_allow_dml(false)
54 .with_allow_statements(false);
55 let df = ctx
56 .sql_with_options(&self.config.query, sql_options)
57 .await
58 .map_err(|e| Error::Process(format!("SQL query error: {}", e)))?;
59
60 let result_batches = df
61 .collect()
62 .await
63 .map_err(|e| Error::Process(format!("Collection query results error: {}", e)))?;
64
65 if result_batches.is_empty() {
66 return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
67 }
68
69 if result_batches.len() == 1 {
70 return Ok(result_batches[0].clone());
71 }
72
73 Ok(
74 arrow::compute::concat_batches(&&result_batches[0].schema(), &result_batches)
75 .map_err(|e| Error::Process(format!("Batch merge failed: {}", e)))?,
76 )
77 }
78}
79
80#[async_trait]
81impl Processor for SqlProcessor {
82 async fn process(&self, msg_batch: MessageBatch) -> Result<Vec<MessageBatch>, Error> {
83 if msg_batch.is_empty() {
85 return Ok(vec![]);
86 }
87
88 let batch: RecordBatch = match msg_batch.content {
89 Content::Arrow(v) => v,
90 Content::Binary(_) => {
91 return Err(Error::Process("Unsupported input format".to_string()))?;
92 }
93 };
94
95 let result_batch = self.execute_query(batch).await?;
97 Ok(vec![MessageBatch::new_arrow(result_batch)])
98 }
99
100 async fn close(&self) -> Result<(), Error> {
101 Ok(())
102 }
103}
104
105pub(crate) struct SqlProcessorBuilder;
106impl ProcessorBuilder for SqlProcessorBuilder {
107 fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Processor>, Error> {
108 if config.is_none() {
109 return Err(Error::Config(
110 "Batch processor configuration is missing".to_string(),
111 ));
112 }
113 let config: SqlProcessorConfig = serde_json::from_value(config.clone().unwrap())?;
114 Ok(Arc::new(SqlProcessor::new(config)?))
115 }
116}
117
/// Registers the SQL processor builder under the "sql" component name.
pub fn init() {
    register_processor_builder("sql", Arc::new(SqlProcessorBuilder));
}
121
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::array::{Int32Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Builds a three-row fixture batch with columns `id: Int32` and `name: Utf8`.
    fn create_test_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
            ],
        )
        .unwrap()
    }

    /// Convenience constructor: processor for `query` with optional table name.
    fn build_processor(query: &str, table_name: Option<&str>) -> Result<SqlProcessor, Error> {
        SqlProcessor::new(SqlProcessorConfig {
            query: query.to_string(),
            table_name: table_name.map(String::from),
        })
    }

    /// Unwraps the Arrow batch inside a message, panicking on other content.
    fn expect_arrow(msg: &MessageBatch) -> &RecordBatch {
        match &msg.content {
            Content::Arrow(record_batch) => record_batch,
            _ => panic!("Expected Arrow content"),
        }
    }

    #[tokio::test]
    async fn test_sql_processor_new() {
        assert!(build_processor("SELECT * FROM flow", None).is_ok());
    }

    #[tokio::test]
    async fn test_sql_processor_process_simple_query() -> Result<(), Error> {
        let processor = build_processor("SELECT * FROM flow", None)?;
        let result = processor
            .process(MessageBatch::new_arrow(create_test_batch()))
            .await?;

        assert_eq!(result.len(), 1);
        let record_batch = expect_arrow(&result[0]);
        assert_eq!(record_batch.num_rows(), 3);
        assert_eq!(record_batch.num_columns(), 2);

        let ids = record_batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        let names = record_batch
            .column(1)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!(ids.value(0), 1);
        assert_eq!(ids.value(1), 2);
        assert_eq!(ids.value(2), 3);
        assert_eq!(names.value(0), "Alice");
        assert_eq!(names.value(1), "Bob");
        assert_eq!(names.value(2), "Charlie");

        Ok(())
    }

    #[tokio::test]
    async fn test_sql_processor_process_filter_query() -> Result<(), Error> {
        let processor = build_processor("SELECT * FROM flow WHERE id > 1", None)?;
        let result = processor
            .process(MessageBatch::new_arrow(create_test_batch()))
            .await?;

        assert_eq!(result.len(), 1);
        let record_batch = expect_arrow(&result[0]);
        assert_eq!(record_batch.num_rows(), 2);
        assert_eq!(record_batch.num_columns(), 2);

        let ids = record_batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(ids.value(0), 2);
        assert_eq!(ids.value(1), 3);

        Ok(())
    }

    #[tokio::test]
    async fn test_sql_processor_process_projection_query() -> Result<(), Error> {
        let processor = build_processor("SELECT id FROM flow", None)?;
        let result = processor
            .process(MessageBatch::new_arrow(create_test_batch()))
            .await?;

        assert_eq!(result.len(), 1);
        let record_batch = expect_arrow(&result[0]);
        assert_eq!(record_batch.num_rows(), 3);
        assert_eq!(record_batch.num_columns(), 1);

        let ids = record_batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(ids.value(0), 1);
        assert_eq!(ids.value(1), 2);
        assert_eq!(ids.value(2), 3);

        Ok(())
    }

    #[tokio::test]
    async fn test_sql_processor_process_empty_batch() -> Result<(), Error> {
        let processor = build_processor("SELECT * FROM flow", None)?;
        let result = processor.process(MessageBatch::new_binary(vec![])).await?;
        assert_eq!(result.len(), 0);
        Ok(())
    }

    #[tokio::test]
    async fn test_sql_processor_process_binary_content() {
        let processor = build_processor("SELECT * FROM flow", None).unwrap();
        let result = processor
            .process(MessageBatch::new_binary(vec![vec![1]]))
            .await;
        assert!(matches!(result, Err(Error::Process(_))));
    }

    #[tokio::test]
    async fn test_sql_processor_process_invalid_query() {
        let processor = build_processor("INVALID SQL", None).unwrap();
        let result = processor
            .process(MessageBatch::new_arrow(create_test_batch()))
            .await;
        assert!(matches!(result, Err(Error::Process(_))));
    }

    #[tokio::test]
    async fn test_sql_processor_process_custom_table_name() -> Result<(), Error> {
        let processor = build_processor("SELECT * FROM custom_table", Some("custom_table"))?;
        let result = processor
            .process(MessageBatch::new_arrow(create_test_batch()))
            .await?;

        assert_eq!(result.len(), 1);
        assert_eq!(expect_arrow(&result[0]).num_rows(), 3);

        Ok(())
    }

    #[tokio::test]
    async fn test_sql_processor_close() {
        let processor = build_processor("SELECT * FROM flow", None).unwrap();
        assert!(processor.close().await.is_ok());
    }
}