convergence_arrow/
datafusion.rs1use crate::table::{record_batch_to_rows, schema_to_field_desc};
4use async_trait::async_trait;
5use convergence::engine::{Engine, Portal};
6use convergence::protocol::{ErrorResponse, FieldDescription, SqlState};
7use convergence::protocol_ext::DataRowBatch;
8use convergence::sqlparser::ast::Statement;
9use convergence::sqlparser::dialect::PostgreSqlDialect;
10use convergence::sqlparser::parser::Parser;
11use datafusion::error::DataFusionError;
12use datafusion::prelude::*;
13
14fn df_err_to_sql(err: DataFusionError) -> ErrorResponse {
15 ErrorResponse::error(SqlState::DataException, err.to_string())
16}
17
18fn dummy_query() -> Statement {
20 let mut statements = Parser::parse_sql(&PostgreSqlDialect {}, "select 1").expect("failed to parse dummy statement");
21 statements.remove(0)
22}
23
24fn translate_statement(statement: &Statement) -> Statement {
25 match statement {
26 Statement::SetVariable { .. } => dummy_query(),
27 other => other.clone(),
28 }
29}
30
31pub struct DataFusionPortal {
33 df: DataFrame,
34}
35
36#[async_trait]
37impl Portal for DataFusionPortal {
38 async fn fetch(&mut self, batch: &mut DataRowBatch) -> Result<(), ErrorResponse> {
39 for arrow_batch in self.df.clone().collect().await.map_err(df_err_to_sql)? {
40 record_batch_to_rows(&arrow_batch, batch)?;
41 }
42 Ok(())
43 }
44}
45
46pub struct DataFusionEngine {
48 ctx: SessionContext,
49}
50
51impl DataFusionEngine {
52 pub fn new(ctx: SessionContext) -> Self {
54 Self { ctx }
55 }
56}
57
58#[async_trait]
59impl Engine for DataFusionEngine {
60 type PortalType = DataFusionPortal;
61
62 async fn prepare(&mut self, statement: &Statement) -> Result<Vec<FieldDescription>, ErrorResponse> {
63 let plan = self
64 .ctx
65 .sql(&translate_statement(statement).to_string())
66 .await
67 .map_err(df_err_to_sql)?;
68
69 schema_to_field_desc(&plan.schema().clone().into())
70 }
71
72 async fn create_portal(&mut self, statement: &Statement) -> Result<Self::PortalType, ErrorResponse> {
73 let df = self
74 .ctx
75 .sql(&translate_statement(statement).to_string())
76 .await
77 .map_err(df_err_to_sql)?;
78
79 Ok(DataFusionPortal { df })
80 }
81}