// arkflow_plugin/input/sql.rs
use arkflow_core::input::{register_input_builder, Ack, Input, InputBuilder, NoopAck};
use arkflow_core::{Error, MessageBatch};

use async_trait::async_trait;
use datafusion::arrow;
use datafusion::arrow::array::RecordBatch;
use datafusion::arrow::datatypes::Schema;
use datafusion::prelude::{SQLOptions, SessionContext};
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct SqlInputConfig {
15 select_sql: String,
16 create_table_sql: String,
17}
18
19pub struct SqlInput {
20 sql_config: SqlInputConfig,
21 read: AtomicBool,
22}
23
24impl SqlInput {
25 pub fn new(sql_config: SqlInputConfig) -> Result<Self, Error> {
26 Ok(Self {
27 sql_config,
28 read: AtomicBool::new(false),
29 })
30 }
31}
32
33#[async_trait]
34impl Input for SqlInput {
35 async fn connect(&self) -> Result<(), Error> {
36 Ok(())
37 }
38
39 async fn read(&self) -> Result<(MessageBatch, Arc<dyn Ack>), Error> {
40 if self.read.load(Ordering::Acquire) {
41 return Err(Error::EOF);
42 }
43
44 let ctx = SessionContext::new();
45 let sql_options = SQLOptions::new()
46 .with_allow_ddl(true)
47 .with_allow_dml(false)
48 .with_allow_statements(false);
49 ctx.sql_with_options(&self.sql_config.create_table_sql, sql_options)
50 .await
51 .map_err(|e| Error::Config(format!("Failed to execute create_table_sql: {}", e)))?;
52
53 let sql_options = SQLOptions::new()
54 .with_allow_ddl(false)
55 .with_allow_dml(false)
56 .with_allow_statements(false);
57 let df = ctx
58 .sql_with_options(&self.sql_config.select_sql, sql_options)
59 .await
60 .map_err(|e| Error::Read(format!("Failed to execute select_sql: {}", e)))?;
61
62 let result_batches = df
63 .collect()
64 .await
65 .map_err(|e| Error::Read(format!("Failed to collect data from SQL query: {}", e)))?;
66
67 let x = match result_batches.len() {
68 0 => RecordBatch::new_empty(Arc::new(Schema::empty())),
69 1 => result_batches[0].clone(),
70 _ => arrow::compute::concat_batches(&&result_batches[0].schema(), &result_batches)
71 .map_err(|e| Error::Process(format!("Merge batches failed: {}", e)))?,
72 };
73
74 self.read.store(true, Ordering::Release);
75 Ok((MessageBatch::new_arrow(x), Arc::new(NoopAck)))
76 }
77
78 async fn close(&self) -> Result<(), Error> {
79 Ok(())
80 }
81}
82
83pub(crate) struct SqlInputBuilder;
84impl InputBuilder for SqlInputBuilder {
85 fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Input>, Error> {
86 if config.is_none() {
87 return Err(Error::Config(
88 "SQL input configuration is missing".to_string(),
89 ));
90 }
91
92 let config: SqlInputConfig = serde_json::from_value(config.clone().unwrap())?;
93 Ok(Arc::new(SqlInput::new(config)?))
94 }
95}
96
97pub fn init() {
98 register_input_builder("sql", Arc::new(SqlInputBuilder));
99}
#[cfg(test)]
mod tests {
    use super::*;
    use arkflow_core::Content;
    use datafusion::arrow::array::{Int32Array, StringArray};
    use std::fs::File;
    use std::io::Write;
    use tempfile::tempdir;

    /// Config pointing at a (possibly nonexistent) `test.csv` external table.
    fn sample_config() -> SqlInputConfig {
        SqlInputConfig {
            select_sql: "SELECT * FROM test".to_string(),
            create_table_sql:
                "CREATE EXTERNAL TABLE test (id INT, name STRING) STORED AS CSV LOCATION 'test.csv'"
                    .to_string(),
        }
    }

    #[tokio::test]
    async fn test_sql_input_new() {
        assert!(SqlInput::new(sample_config()).is_ok());
    }

    #[tokio::test]
    async fn test_sql_input_connect() {
        let input = SqlInput::new(sample_config()).unwrap();
        assert!(input.connect().await.is_ok());
    }

    #[tokio::test]
    async fn test_sql_input_read() -> Result<(), Error> {
        // Write a small CSV fixture for the external table to read.
        let temp_dir = tempdir().unwrap();
        let csv_path = temp_dir.path().join("test.csv");
        let mut file = File::create(&csv_path).unwrap();
        writeln!(file, "id,name").unwrap();
        writeln!(file, "1,Alice").unwrap();
        writeln!(file, "2,Bob").unwrap();

        let config = SqlInputConfig {
            select_sql: "SELECT * FROM test".to_string(),
            create_table_sql: format!(
                "CREATE EXTERNAL TABLE test (id INT, name STRING) STORED AS CSV LOCATION '{}'",
                csv_path.to_str().unwrap()
            ),
        };

        let input = SqlInput::new(config)?;
        let (batch, _ack) = input.read().await?;

        match batch.content {
            Content::Arrow(record_batch) => {
                assert_eq!(record_batch.num_rows(), 2);
                assert_eq!(record_batch.num_columns(), 2);

                let ids = record_batch
                    .column(0)
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .unwrap();
                let names = record_batch
                    .column(1)
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .unwrap();

                assert_eq!(ids.value(0), 1);
                assert_eq!(ids.value(1), 2);
                assert_eq!(names.value(0), "Alice");
                assert_eq!(names.value(1), "Bob");
            }
            _ => panic!("Expected Arrow content"),
        }

        // A second read of the one-shot input must report end-of-stream.
        assert!(matches!(input.read().await, Err(Error::EOF)));

        Ok(())
    }

    #[tokio::test]
    async fn test_sql_input_invalid_sql() {
        let config = SqlInputConfig {
            select_sql: "INVALID SQL".to_string(),
            ..sample_config()
        };
        let input = SqlInput::new(config).unwrap();
        let result = input.read().await;
        assert!(matches!(result, Err(Error::Read(_))));
    }

    #[tokio::test]
    async fn test_sql_input_close() {
        let input = SqlInput::new(sample_config()).unwrap();
        assert!(input.close().await.is_ok());
    }
}