arkflow_plugin/processor/
sql.rs

1//! SQL processor component
2//!
3//! DataFusion is used to process data with SQL queries.
4
5use arkflow_core::processor::{register_processor_builder, Processor, ProcessorBuilder};
6use arkflow_core::{Content, Error, MessageBatch};
7use async_trait::async_trait;
8use datafusion::arrow;
9use datafusion::arrow::datatypes::Schema;
10use datafusion::arrow::record_batch::RecordBatch;
11use datafusion::prelude::*;
12use serde::{Deserialize, Serialize};
13use std::sync::Arc;
14
/// Table name used when `SqlProcessorConfig::table_name` is not set.
const DEFAULT_TABLE_NAME: &str = "flow";
/// SQL processor configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SqlProcessorConfig {
    /// SQL query statement executed against each incoming batch.
    pub query: String,

    /// Table name (used in SQL queries); defaults to `"flow"` when `None`.
    pub table_name: Option<String>,
}
25
/// SQL processor component
///
/// Runs the configured SQL query against each incoming Arrow record batch
/// using an in-process DataFusion session.
pub struct SqlProcessor {
    // Query text and optional table name; see `SqlProcessorConfig`.
    config: SqlProcessorConfig,
}
30
31impl SqlProcessor {
32    /// Create a new SQL processor component.
33    pub fn new(config: SqlProcessorConfig) -> Result<Self, Error> {
34        Ok(Self { config })
35    }
36
37    /// Execute SQL query
38    async fn execute_query(&self, batch: RecordBatch) -> Result<RecordBatch, Error> {
39        // Create a session context
40        let ctx = SessionContext::new();
41        let table_name = self
42            .config
43            .table_name
44            .as_deref()
45            .unwrap_or(DEFAULT_TABLE_NAME);
46
47        ctx.register_batch(table_name, batch)
48            .map_err(|e| Error::Process(format!("Registration failed: {}", e)))?;
49
50        // Execute the SQL query and collect the results.
51        let sql_options = SQLOptions::new()
52            .with_allow_ddl(false)
53            .with_allow_dml(false)
54            .with_allow_statements(false);
55        let df = ctx
56            .sql_with_options(&self.config.query, sql_options)
57            .await
58            .map_err(|e| Error::Process(format!("SQL query error: {}", e)))?;
59
60        let result_batches = df
61            .collect()
62            .await
63            .map_err(|e| Error::Process(format!("Collection query results error: {}", e)))?;
64
65        if result_batches.is_empty() {
66            return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
67        }
68
69        if result_batches.len() == 1 {
70            return Ok(result_batches[0].clone());
71        }
72
73        Ok(
74            arrow::compute::concat_batches(&&result_batches[0].schema(), &result_batches)
75                .map_err(|e| Error::Process(format!("Batch merge failed: {}", e)))?,
76        )
77    }
78}
79
80#[async_trait]
81impl Processor for SqlProcessor {
82    async fn process(&self, msg_batch: MessageBatch) -> Result<Vec<MessageBatch>, Error> {
83        // If the batch is empty, return an empty result.
84        if msg_batch.is_empty() {
85            return Ok(vec![]);
86        }
87
88        let batch: RecordBatch = match msg_batch.content {
89            Content::Arrow(v) => v,
90            Content::Binary(_) => {
91                return Err(Error::Process("Unsupported input format".to_string()))?;
92            }
93        };
94
95        // Execute SQL query
96        let result_batch = self.execute_query(batch).await?;
97        Ok(vec![MessageBatch::new_arrow(result_batch)])
98    }
99
100    async fn close(&self) -> Result<(), Error> {
101        Ok(())
102    }
103}
104
105pub(crate) struct SqlProcessorBuilder;
106impl ProcessorBuilder for SqlProcessorBuilder {
107    fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Processor>, Error> {
108        if config.is_none() {
109            return Err(Error::Config(
110                "Batch processor configuration is missing".to_string(),
111            ));
112        }
113        let config: SqlProcessorConfig = serde_json::from_value(config.clone().unwrap())?;
114        Ok(Arc::new(SqlProcessor::new(config)?))
115    }
116}
117
/// Register the SQL processor builder under the component name `"sql"`.
pub fn init() {
    register_processor_builder("sql", Arc::new(SqlProcessorBuilder));
}
121
#[cfg(test)]
mod tests {
    //! Unit tests for the SQL processor: simple selection, filtering,
    //! projection, custom table names, and the error paths.
    use super::*;
    use datafusion::arrow::array::{Int32Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    // Helper function to create a test record batch with two columns
    // (`id: Int32`, `name: Utf8`) and three rows.
    fn create_test_batch() -> RecordBatch {
        // Create schema
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        // Create data
        let id_array = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let name_array = Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"]));

        // Create record batch
        RecordBatch::try_new(schema, vec![id_array, name_array]).unwrap()
    }

    #[tokio::test]
    async fn test_sql_processor_new() {
        // Test creating a new SQL processor
        let config = SqlProcessorConfig {
            query: "SELECT * FROM flow".to_string(),
            table_name: None,
        };
        let processor = SqlProcessor::new(config);
        assert!(processor.is_ok());
    }

    #[tokio::test]
    async fn test_sql_processor_process_simple_query() -> Result<(), Error> {
        // Test processing a simple SELECT query
        let config = SqlProcessorConfig {
            query: "SELECT * FROM flow".to_string(),
            table_name: None,
        };
        let processor = SqlProcessor::new(config)?;
        let batch = create_test_batch();
        let msg_batch = MessageBatch::new_arrow(batch);

        let result = processor.process(msg_batch).await?;

        // Verify the result
        assert_eq!(result.len(), 1);
        match &result[0].content {
            Content::Arrow(record_batch) => {
                // Check that all rows were returned
                assert_eq!(record_batch.num_rows(), 3);
                assert_eq!(record_batch.num_columns(), 2);

                // Check column values
                let id_array = record_batch
                    .column(0)
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .unwrap();
                let name_array = record_batch
                    .column(1)
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .unwrap();

                assert_eq!(id_array.value(0), 1);
                assert_eq!(id_array.value(1), 2);
                assert_eq!(id_array.value(2), 3);
                assert_eq!(name_array.value(0), "Alice");
                assert_eq!(name_array.value(1), "Bob");
                assert_eq!(name_array.value(2), "Charlie");
            }
            _ => panic!("Expected Arrow content"),
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_sql_processor_process_filter_query() -> Result<(), Error> {
        // Test processing a query with a filter
        let config = SqlProcessorConfig {
            query: "SELECT * FROM flow WHERE id > 1".to_string(),
            table_name: None,
        };
        let processor = SqlProcessor::new(config)?;
        let batch = create_test_batch();
        let msg_batch = MessageBatch::new_arrow(batch);

        let result = processor.process(msg_batch).await?;

        // Verify the result
        assert_eq!(result.len(), 1);
        match &result[0].content {
            Content::Arrow(record_batch) => {
                // Check that only filtered rows were returned (ids 2 and 3)
                assert_eq!(record_batch.num_rows(), 2);
                assert_eq!(record_batch.num_columns(), 2);

                // Check column values
                let id_array = record_batch
                    .column(0)
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .unwrap();

                assert_eq!(id_array.value(0), 2);
                assert_eq!(id_array.value(1), 3);
            }
            _ => panic!("Expected Arrow content"),
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_sql_processor_process_projection_query() -> Result<(), Error> {
        // Test processing a query with column projection
        let config = SqlProcessorConfig {
            query: "SELECT id FROM flow".to_string(),
            table_name: None,
        };
        let processor = SqlProcessor::new(config)?;
        let batch = create_test_batch();
        let msg_batch = MessageBatch::new_arrow(batch);

        let result = processor.process(msg_batch).await?;

        // Verify the result
        assert_eq!(result.len(), 1);
        match &result[0].content {
            Content::Arrow(record_batch) => {
                // Check that only the id column was returned
                assert_eq!(record_batch.num_rows(), 3);
                assert_eq!(record_batch.num_columns(), 1);

                // Check column values
                let id_array = record_batch
                    .column(0)
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .unwrap();

                assert_eq!(id_array.value(0), 1);
                assert_eq!(id_array.value(1), 2);
                assert_eq!(id_array.value(2), 3);
            }
            _ => panic!("Expected Arrow content"),
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_sql_processor_process_empty_batch() -> Result<(), Error> {
        // Test processing an empty batch
        let config = SqlProcessorConfig {
            query: "SELECT * FROM flow".to_string(),
            table_name: None,
        };
        let processor = SqlProcessor::new(config)?;
        // An empty binary batch short-circuits in `process` before the
        // content-format check, so no "unsupported format" error occurs.
        let msg_batch = MessageBatch::new_binary(vec![]);

        let result = processor.process(msg_batch).await?;

        // Verify that an empty result is returned
        assert_eq!(result.len(), 0);

        Ok(())
    }

    #[tokio::test]
    async fn test_sql_processor_process_binary_content() {
        // Test processing binary content (should return an error)
        let config = SqlProcessorConfig {
            query: "SELECT * FROM flow".to_string(),
            table_name: None,
        };
        let processor = SqlProcessor::new(config).unwrap();
        // Non-empty binary content reaches the format check and is rejected.
        let msg_batch = MessageBatch::new_binary(vec![vec![1]]);

        let result = processor.process(msg_batch).await;

        // Verify that an error is returned
        assert!(matches!(result, Err(Error::Process(_))));
    }

    #[tokio::test]
    async fn test_sql_processor_process_invalid_query() {
        // Test processing with an invalid SQL query
        let config = SqlProcessorConfig {
            query: "INVALID SQL".to_string(),
            table_name: None,
        };
        let processor = SqlProcessor::new(config).unwrap();
        let batch = create_test_batch();
        let msg_batch = MessageBatch::new_arrow(batch);

        let result = processor.process(msg_batch).await;

        // Verify that an error is returned
        assert!(matches!(result, Err(Error::Process(_))));
    }

    #[tokio::test]
    async fn test_sql_processor_process_custom_table_name() -> Result<(), Error> {
        // Test processing with a custom table name
        let config = SqlProcessorConfig {
            query: "SELECT * FROM custom_table".to_string(),
            table_name: Some("custom_table".to_string()),
        };
        let processor = SqlProcessor::new(config)?;
        let batch = create_test_batch();
        let msg_batch = MessageBatch::new_arrow(batch);

        let result = processor.process(msg_batch).await?;

        // Verify the result
        assert_eq!(result.len(), 1);
        match &result[0].content {
            Content::Arrow(record_batch) => {
                assert_eq!(record_batch.num_rows(), 3);
            }
            _ => panic!("Expected Arrow content"),
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_sql_processor_close() {
        // Test closing the processor
        let config = SqlProcessorConfig {
            query: "SELECT * FROM flow".to_string(),
            table_name: None,
        };
        let processor = SqlProcessor::new(config).unwrap();

        // Verify that close returns Ok
        assert!(processor.close().await.is_ok());
    }
}