// arkflow 0.1.0
//
// High-performance Rust flow processing engine
// Documentation
use crate::input::{Ack, Input, NoopAck};
use crate::{Error, MessageBatch};
use async_trait::async_trait;
use datafusion::arrow;
use datafusion::arrow::array::RecordBatch;
use datafusion::arrow::datatypes::Schema;
use datafusion::prelude::{SQLOptions, SessionContext};
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

/// Configuration for the SQL input: the two SQL texts it executes on each read.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SqlConfig {
    /// Query whose collected result becomes the emitted message batch.
    /// It is executed with DDL, DML and statements all disallowed.
    select_sql: String,
    /// SQL executed before `select_sql` on the same session, with DDL allowed
    /// (presumably a `CREATE ... TABLE` that registers the source — confirm
    /// against the deployment's configuration).
    create_source_sql: String,
}

/// Input that runs the configured SQL and emits the result as a single batch.
pub struct SqlInput {
    /// SQL texts to execute when the input is read.
    sql_config: SqlConfig,
    /// Set to `true` after a read has succeeded; subsequent reads then return
    /// `Error::Done` instead of running the query again.
    read: AtomicBool,
}

impl SqlInput {
    /// Creates a SQL input from `sql_config`.
    ///
    /// The configuration is cloned so the input owns its own copy, and the
    /// "already read" flag starts out cleared. This constructor currently
    /// never returns an error.
    pub fn new(sql_config: &SqlConfig) -> Result<Self, Error> {
        let sql_config = sql_config.clone();
        let read = AtomicBool::new(false);
        Ok(Self { sql_config, read })
    }
}

#[async_trait]
impl Input for SqlInput {
    /// No-op: nothing is set up ahead of time; the query session is created
    /// inside `read`.
    async fn connect(&self) -> Result<(), Error> {
        Ok(())
    }

    /// Runs the configured SQL once and returns the whole result as one batch.
    ///
    /// A fresh `SessionContext` is created, `create_source_sql` is executed
    /// with DDL allowed, and then `select_sql` is executed with DDL, DML and
    /// statements all disallowed. All result batches are merged into a single
    /// `RecordBatch`. When the query yields no batches, an empty batch that
    /// carries the query's output schema is returned.
    ///
    /// # Errors
    ///
    /// * `Error::Done` if a previous call already returned a batch.
    /// * `Error::Config` if `create_source_sql` fails.
    /// * `Error::Reading` if `select_sql` fails to plan or to collect.
    /// * `Error::Processing` if the result batches cannot be concatenated.
    ///
    /// The "already read" flag is only set after success, so a failed call can
    /// be retried.
    async fn read(&self) -> Result<(MessageBatch, Arc<dyn Ack>), Error> {
        if self.read.load(Ordering::Acquire) {
            return Err(Error::Done);
        }

        let ctx = SessionContext::new();

        // Source registration is the only step that may run DDL.
        let source_options = SQLOptions::new()
            .with_allow_ddl(true)
            .with_allow_dml(false)
            .with_allow_statements(false);
        ctx.sql_with_options(&self.sql_config.create_source_sql, source_options)
            .await
            .map_err(|e| Error::Config(format!("Failed to execute SQL query: {}", e)))?;

        // The user query is restricted to plain queries.
        let query_options = SQLOptions::new()
            .with_allow_ddl(false)
            .with_allow_dml(false)
            .with_allow_statements(false);
        let df = ctx
            .sql_with_options(&self.sql_config.select_sql, query_options)
            .await
            .map_err(|e| Error::Reading(format!("Failed to execute SQL query: {}", e)))?;

        // Capture the output schema before `collect` consumes the DataFrame so
        // that an empty result still describes its columns.
        let query_schema = Arc::new(Schema::from(df.schema()));

        let result_batches = df
            .collect()
            .await
            .map_err(|e| Error::Reading(format!("Failed to collect data from SQL query: {}", e)))?;

        let merged = match result_batches.as_slice() {
            [] => RecordBatch::new_empty(query_schema),
            [only] => only.clone(),
            // The error text below reads "failed to merge batches".
            [first, ..] => arrow::compute::concat_batches(&first.schema(), &result_batches)
                .map_err(|e| Error::Processing(format!("合并批次失败: {}", e)))?,
        };

        self.read.store(true, Ordering::Release);
        Ok((MessageBatch::new_arrow(merged), Arc::new(NoopAck)))
    }

    /// No-op: the session created in `read` is dropped when that call returns.
    async fn close(&self) -> Result<(), Error> {
        Ok(())
    }
}