convergence_arrow/
datafusion.rs

1//! Provides a DataFusion-powered implementation of the [Engine] trait.
2
3use crate::table::{record_batch_to_rows, schema_to_field_desc};
4use async_trait::async_trait;
5use convergence::engine::{Engine, Portal};
6use convergence::protocol::{ErrorResponse, FieldDescription, SqlState};
7use convergence::protocol_ext::DataRowBatch;
8use convergence::sqlparser::ast::Statement;
9use convergence::sqlparser::dialect::PostgreSqlDialect;
10use convergence::sqlparser::parser::Parser;
11use datafusion::error::DataFusionError;
12use datafusion::prelude::*;
13
14fn df_err_to_sql(err: DataFusionError) -> ErrorResponse {
15	ErrorResponse::error(SqlState::DataException, err.to_string())
16}
17
18// dummy query used as replacement for set variable statements etc
19fn dummy_query() -> Statement {
20	let mut statements = Parser::parse_sql(&PostgreSqlDialect {}, "select 1").expect("failed to parse dummy statement");
21	statements.remove(0)
22}
23
24fn translate_statement(statement: &Statement) -> Statement {
25	match statement {
26		Statement::SetVariable { .. } => dummy_query(),
27		other => other.clone(),
28	}
29}
30
31/// A portal built using a logical DataFusion query plan.
32pub struct DataFusionPortal {
33	df: DataFrame,
34}
35
36#[async_trait]
37impl Portal for DataFusionPortal {
38	async fn fetch(&mut self, batch: &mut DataRowBatch) -> Result<(), ErrorResponse> {
39		for arrow_batch in self.df.clone().collect().await.map_err(df_err_to_sql)? {
40			record_batch_to_rows(&arrow_batch, batch)?;
41		}
42		Ok(())
43	}
44}
45
46/// An engine instance using DataFusion for catalogue management and queries.
47pub struct DataFusionEngine {
48	ctx: SessionContext,
49}
50
51impl DataFusionEngine {
52	/// Creates a new engine instance using the given DataFusion execution context.
53	pub fn new(ctx: SessionContext) -> Self {
54		Self { ctx }
55	}
56}
57
58#[async_trait]
59impl Engine for DataFusionEngine {
60	type PortalType = DataFusionPortal;
61
62	async fn prepare(&mut self, statement: &Statement) -> Result<Vec<FieldDescription>, ErrorResponse> {
63		let plan = self
64			.ctx
65			.sql(&translate_statement(statement).to_string())
66			.await
67			.map_err(df_err_to_sql)?;
68
69		schema_to_field_desc(&plan.schema().clone().into())
70	}
71
72	async fn create_portal(&mut self, statement: &Statement) -> Result<Self::PortalType, ErrorResponse> {
73		let df = self
74			.ctx
75			.sql(&translate_statement(statement).to_string())
76			.await
77			.map_err(df_err_to_sql)?;
78
79		Ok(DataFusionPortal { df })
80	}
81}