arkflow_plugin/input/
sql.rs

1use arkflow_core::input::{register_input_builder, Ack, Input, InputBuilder, NoopAck};
2use arkflow_core::{Error, MessageBatch};
3
4use async_trait::async_trait;
5use datafusion::arrow;
6use datafusion::arrow::array::RecordBatch;
7use datafusion::arrow::datatypes::Schema;
8use datafusion::prelude::{SQLOptions, SessionContext};
9use serde::{Deserialize, Serialize};
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::Arc;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct SqlInputConfig {
15    select_sql: String,
16    create_table_sql: String,
17}
18
19pub struct SqlInput {
20    sql_config: SqlInputConfig,
21    read: AtomicBool,
22}
23
24impl SqlInput {
25    pub fn new(sql_config: SqlInputConfig) -> Result<Self, Error> {
26        Ok(Self {
27            sql_config,
28            read: AtomicBool::new(false),
29        })
30    }
31}
32
33#[async_trait]
34impl Input for SqlInput {
35    async fn connect(&self) -> Result<(), Error> {
36        Ok(())
37    }
38
39    async fn read(&self) -> Result<(MessageBatch, Arc<dyn Ack>), Error> {
40        if self.read.load(Ordering::Acquire) {
41            return Err(Error::EOF);
42        }
43
44        let ctx = SessionContext::new();
45        let sql_options = SQLOptions::new()
46            .with_allow_ddl(true)
47            .with_allow_dml(false)
48            .with_allow_statements(false);
49        ctx.sql_with_options(&self.sql_config.create_table_sql, sql_options)
50            .await
51            .map_err(|e| Error::Config(format!("Failed to execute create_table_sql: {}", e)))?;
52
53        let sql_options = SQLOptions::new()
54            .with_allow_ddl(false)
55            .with_allow_dml(false)
56            .with_allow_statements(false);
57        let df = ctx
58            .sql_with_options(&self.sql_config.select_sql, sql_options)
59            .await
60            .map_err(|e| Error::Read(format!("Failed to execute select_sql: {}", e)))?;
61
62        let result_batches = df
63            .collect()
64            .await
65            .map_err(|e| Error::Read(format!("Failed to collect data from SQL query: {}", e)))?;
66
67        let x = match result_batches.len() {
68            0 => RecordBatch::new_empty(Arc::new(Schema::empty())),
69            1 => result_batches[0].clone(),
70            _ => arrow::compute::concat_batches(&&result_batches[0].schema(), &result_batches)
71                .map_err(|e| Error::Process(format!("Merge batches failed: {}", e)))?,
72        };
73
74        self.read.store(true, Ordering::Release);
75        Ok((MessageBatch::new_arrow(x), Arc::new(NoopAck)))
76    }
77
78    async fn close(&self) -> Result<(), Error> {
79        Ok(())
80    }
81}
82
83pub(crate) struct SqlInputBuilder;
84impl InputBuilder for SqlInputBuilder {
85    fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Input>, Error> {
86        if config.is_none() {
87            return Err(Error::Config(
88                "SQL input configuration is missing".to_string(),
89            ));
90        }
91
92        let config: SqlInputConfig = serde_json::from_value(config.clone().unwrap())?;
93        Ok(Arc::new(SqlInput::new(config)?))
94    }
95}
96
97pub fn init() {
98    register_input_builder("sql", Arc::new(SqlInputBuilder));
99}
#[cfg(test)]
mod tests {
    use super::*;
    use arkflow_core::Content;
    use datafusion::arrow::array::{Int32Array, StringArray};
    use std::fs::File;
    use std::io::Write;
    use tempfile::tempdir;

    /// Query used by every test that expects the SQL to be valid.
    const SELECT_ALL: &str = "SELECT * FROM test";

    /// Builds a config whose external CSV table `test` is stored at `location`.
    fn csv_config(select_sql: &str, location: &str) -> SqlInputConfig {
        SqlInputConfig {
            select_sql: select_sql.to_string(),
            create_table_sql: format!(
                "CREATE EXTERNAL TABLE test (id INT, name STRING) STORED AS CSV LOCATION '{}'",
                location
            ),
        }
    }

    #[tokio::test]
    async fn test_sql_input_new() {
        let created = SqlInput::new(csv_config(SELECT_ALL, "test.csv"));
        assert!(created.is_ok());
    }

    #[tokio::test]
    async fn test_sql_input_connect() {
        let input = SqlInput::new(csv_config(SELECT_ALL, "test.csv")).unwrap();
        assert!(input.connect().await.is_ok());
    }

    #[tokio::test]
    async fn test_sql_input_read() -> Result<(), Error> {
        // Create a temporary directory containing the CSV test data.
        let temp_dir = tempdir().unwrap();
        let csv_path = temp_dir.path().join("test.csv");
        let mut file = File::create(&csv_path).unwrap();
        file.write_all(b"id,name\n1,Alice\n2,Bob\n").unwrap();

        let input = SqlInput::new(csv_config(SELECT_ALL, csv_path.to_str().unwrap()))?;
        let (batch, _ack) = input.read().await?;

        // Verify the returned data.
        let Content::Arrow(record_batch) = batch.content else {
            panic!("Expected Arrow content");
        };
        assert_eq!(record_batch.num_rows(), 2);
        assert_eq!(record_batch.num_columns(), 2);

        let ids = record_batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        let names = record_batch
            .column(1)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        for (row, (id, name)) in [(1, "Alice"), (2, "Bob")].into_iter().enumerate() {
            assert_eq!(ids.value(row), id);
            assert_eq!(names.value(row), name);
        }

        // The input is single-shot: a second read must report end of input.
        let second = input.read().await;
        assert!(matches!(second, Err(Error::EOF)));

        Ok(())
    }

    #[tokio::test]
    async fn test_sql_input_invalid_sql() {
        let input = SqlInput::new(csv_config("INVALID SQL", "test.csv")).unwrap();
        let outcome = input.read().await;
        assert!(matches!(outcome, Err(Error::Read(_))));
    }

    #[tokio::test]
    async fn test_sql_input_close() {
        let input = SqlInput::new(csv_config(SELECT_ALL, "test.csv")).unwrap();
        assert!(input.close().await.is_ok());
    }
}