use std::sync::Arc;
use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion::arrow::array::RecordBatch;
use datafusion::error::DataFusionError;
use datafusion::logical_expr::{create_udaf, create_udf, AccumulatorFactoryFunction, Volatility};
use datafusion::physical_plan::displayable;
use datafusion::prelude::{DataFrame, SessionContext};
use oxisql_core::Connection;
use crate::error::OxiSqlFusionError;
use crate::provider::OxiSqlTableProvider;
use crate::stream::OxiSqlStreamProvider;
use crate::types::value_to_arrow_type;
/// Wrapper around a DataFusion [`SessionContext`] exposing helpers for
/// registering tables, views and user-defined functions on it.
pub struct OxiSqlContext {
/// The wrapped DataFusion session that the helper methods operate on.
inner: SessionContext,
}
impl OxiSqlContext {
/// Creates a context wrapping a freshly constructed [`SessionContext`].
pub fn new() -> Self {
    let inner = SessionContext::new();
    Self { inner }
}
pub fn from_session_context(ctx: SessionContext) -> Self {
Self { inner: ctx }
}
pub fn register_table(
&self,
name: &str,
conn: Arc<dyn Connection>,
schema: SchemaRef,
) -> Result<(), OxiSqlFusionError> {
let provider = Arc::new(OxiSqlStreamProvider::new(conn, name, schema));
self.inner
.register_table(name, provider)
.map(|_| ())
.map_err(|e| OxiSqlFusionError::OxiSql(e.to_string()))
}
pub fn register_snapshot(
&self,
name: &str,
rows: Vec<oxisql_core::Row>,
schema: SchemaRef,
) -> Result<(), OxiSqlFusionError> {
use crate::provider::OxiSqlTableProvider;
let provider = Arc::new(OxiSqlTableProvider::from_rows(rows, schema));
self.inner
.register_table(name, provider)
.map(|_| ())
.map_err(|e| OxiSqlFusionError::OxiSql(e.to_string()))
}
/// Removes the table registered under `name` from the wrapped session.
///
/// Returns `Ok(true)` when the session hands back a previously registered
/// provider and `Ok(false)` when it hands back nothing.
///
/// # Errors
///
/// Returns [`OxiSqlFusionError::OxiSql`] carrying the stringified error when
/// the session fails to deregister.
pub fn deregister_table(&self, name: &str) -> Result<bool, OxiSqlFusionError> {
    match self.inner.deregister_table(name) {
        Ok(previous) => Ok(previous.is_some()),
        Err(err) => Err(OxiSqlFusionError::OxiSql(err.to_string())),
    }
}
/// Passes `sql` to the wrapped session and returns the resulting
/// [`DataFrame`].
///
/// # Errors
///
/// Any session error is converted to [`OxiSqlFusionError::OxiSql`] holding
/// its string form.
pub async fn sql(&self, sql: &str) -> Result<DataFrame, OxiSqlFusionError> {
    let outcome = self.inner.sql(sql).await;
    outcome.map_err(|err| OxiSqlFusionError::OxiSql(err.to_string()))
}
/// Runs `sql` through [`Self::sql`] and collects the resulting frame into
/// record batches.
///
/// # Errors
///
/// Propagates the error from [`Self::sql`]; a failure while collecting is
/// converted to [`OxiSqlFusionError::OxiSql`] holding its string form.
pub async fn execute_sql(&self, sql: &str) -> Result<Vec<RecordBatch>, OxiSqlFusionError> {
    let frame = self.sql(sql).await?;
    let batches = frame.collect().await;
    batches.map_err(|err| OxiSqlFusionError::OxiSql(err.to_string()))
}
/// Submits `CREATE VIEW {name} AS {sql}` to the wrapped session.
///
/// Both `name` and `sql` are interpolated into the statement verbatim, with
/// no quoting or escaping, so they must come from a trusted source.
///
/// # Errors
///
/// Returns the session's [`DataFusionError`] unchanged. Unlike the other
/// methods on this type, the error is not wrapped in `OxiSqlFusionError`.
pub async fn register_view(
    &self,
    name: &str,
    sql: &str,
) -> Result<(), datafusion::error::DataFusionError> {
    let statement = format!("CREATE VIEW {name} AS {sql}");
    // The returned frame is not needed; only success or failure matters.
    self.inner.sql(&statement).await.map(|_| ())
}
pub fn session_context(&self) -> &SessionContext {
&self.inner
}
pub fn into_session_context(self) -> SessionContext {
self.inner
}
/// Registers `func` as a scalar UDF called `name` on the wrapped session.
///
/// The UDF is created with `param_types` as its argument types,
/// `return_type` as its result type and [`Volatility::Immutable`]. On each
/// invocation the arguments are converted with
/// `ColumnarValue::values_to_arrays`, handed to `func`, and the array it
/// returns is wrapped as `ColumnarValue::Array`.
///
/// This method always returns `Ok(())`.
pub fn register_scalar_function(
    &self,
    name: &str,
    return_type: DataType,
    param_types: Vec<DataType>,
    func: impl Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError> + Send + Sync + 'static,
) -> Result<(), OxiSqlFusionError> {
    use datafusion::logical_expr::ColumnarValue;

    let user_fn = Arc::new(func);
    // Adapter from the `ColumnarValue`-based UDF signature to the
    // array-based callback supplied by the caller.
    let adapter = move |inputs: &[ColumnarValue]| -> Result<ColumnarValue, DataFusionError> {
        let columns = ColumnarValue::values_to_arrays(inputs)?;
        user_fn(&columns).map(ColumnarValue::Array)
    };
    self.inner.register_udf(create_udf(
        name,
        param_types,
        return_type,
        Volatility::Immutable,
        Arc::new(adapter),
    ));
    Ok(())
}
/// Registers an aggregate UDF called `name` on the wrapped session.
///
/// The UDAF is created with `input_types`, `return_type`, the `accumulator`
/// factory and `state_types`, using [`Volatility::Immutable`].
///
/// This method always returns `Ok(())`.
pub fn register_aggregate_function(
    &self,
    name: &str,
    input_types: Vec<DataType>,
    return_type: DataType,
    accumulator: AccumulatorFactoryFunction,
    state_types: Vec<DataType>,
) -> Result<(), OxiSqlFusionError> {
    let shared_return = Arc::new(return_type);
    let shared_state = Arc::new(state_types);
    self.inner.register_udaf(create_udaf(
        name,
        input_types,
        shared_return,
        Volatility::Immutable,
        accumulator,
        shared_state,
    ));
    Ok(())
}
/// Method form of the free function [`register_embedded_table`]: forwards
/// `conn` and `table_name` together with the wrapped session and returns
/// that function's result unchanged.
pub async fn register_embedded_table(
    &self,
    conn: &dyn Connection,
    table_name: &str,
) -> Result<(), OxiSqlFusionError> {
    let ctx = &self.inner;
    register_embedded_table(ctx, conn, table_name).await
}
/// Renders the logical and physical plans for `sql` as a single string.
///
/// The output has a `== Logical Plan ==` section containing the indented
/// logical plan, a blank line, and a `== Physical Plan ==` section
/// containing the indented physical plan.
///
/// # Errors
///
/// Propagates the error from [`Self::sql`]; a failure to build the physical
/// plan is returned as [`OxiSqlFusionError::DataFusion`].
pub async fn explain_plan(&self, sql: &str) -> Result<String, OxiSqlFusionError> {
    let frame = self.sql(sql).await?;
    // Render the logical plan first: building the physical plan consumes
    // the frame.
    let logical = frame.logical_plan().display_indent().to_string();
    let exec_plan = match frame.create_physical_plan().await {
        Ok(plan) => plan,
        Err(err) => return Err(OxiSqlFusionError::DataFusion(err)),
    };
    let physical = displayable(exec_plan.as_ref()).indent(true).to_string();
    Ok(format!(
        "== Logical Plan ==\n{logical}\n\n== Physical Plan ==\n{physical}"
    ))
}
}
impl Default for OxiSqlContext {
fn default() -> Self {
Self::new()
}
}
#[cfg(feature = "columnar")]
impl OxiSqlContext {
    /// Opens `path` with `crate::parquet::ParquetTableProvider::open` and
    /// registers the resulting provider under `name`.
    ///
    /// Only compiled when the `columnar` feature is enabled.
    ///
    /// # Errors
    ///
    /// Propagates the error from opening `path` (converted by `?`); a
    /// registration failure is returned as
    /// [`OxiSqlFusionError::DataFusion`].
    pub fn register_parquet(
        &self,
        name: &str,
        path: impl AsRef<std::path::Path>,
    ) -> Result<(), OxiSqlFusionError> {
        let table = Arc::new(crate::parquet::ParquetTableProvider::open(path)?);
        match self.inner.register_table(name, table) {
            Ok(_) => Ok(()),
            Err(err) => Err(OxiSqlFusionError::DataFusion(err)),
        }
    }
}
/// Registers `name` on `ctx` as a table served by an
/// [`OxiSqlStreamProvider`] built from `conn`, `name` and `schema`.
///
/// # Errors
///
/// Returns [`OxiSqlFusionError::OxiSql`] carrying the stringified error when
/// `ctx` rejects the registration.
pub fn register_oxisql_table(
    ctx: &SessionContext,
    name: &str,
    conn: Arc<dyn Connection>,
    schema: SchemaRef,
) -> Result<(), OxiSqlFusionError> {
    let stream_provider = OxiSqlStreamProvider::new(conn, name, schema);
    match ctx.register_table(name, Arc::new(stream_provider)) {
        Ok(_) => Ok(()),
        Err(err) => Err(OxiSqlFusionError::OxiSql(err.to_string())),
    }
}
/// Builds an Arrow schema from the column labels and values of `row`.
///
/// Returns `None` when the row reports no columns. Otherwise each column
/// becomes a nullable field named after its label, typed by
/// `value_to_arrow_type` applied to the value at that index; when no value or
/// no type is available the field falls back to `DataType::Utf8`.
///
/// NOTE(review): only this one row is consulted, so a column whose value here
/// yields no Arrow type is typed `Utf8` regardless of what later rows hold.
/// Confirm that downstream conversion tolerates this.
fn infer_schema_from_first_row(row: &oxisql_core::Row) -> Option<SchemaRef> {
    let labels = row.columns();
    if labels.is_empty() {
        return None;
    }

    let mut fields: Vec<Field> = Vec::new();
    for (idx, name) in labels.iter().enumerate() {
        let dtype = match row.get_by_index(idx).and_then(value_to_arrow_type) {
            Some(known) => known,
            None => DataType::Utf8,
        };
        fields.push(Field::new(name.as_str(), dtype, true));
    }
    Some(Arc::new(Schema::new(fields)))
}
/// Queries every row of `table_name` over `conn` and registers the result on
/// `ctx` as a table of the same name, built with
/// [`OxiSqlTableProvider::from_rows`].
///
/// The schema is inferred from the first returned row. When the query
/// returns no rows the function returns `Ok(())` **without registering
/// anything**, because there is no row to infer a schema from.
///
/// `table_name` is interpolated verbatim into `SELECT * FROM {table_name}`,
/// so it must come from a trusted source.
///
/// # Errors
///
/// * [`OxiSqlFusionError::OxiSql`] when the query fails (stringified error)
///   or when the first row reports no columns.
/// * [`OxiSqlFusionError::DataFusion`] when `ctx` rejects the registration.
pub async fn register_embedded_table(
    ctx: &SessionContext,
    conn: &dyn Connection,
    table_name: &str,
) -> Result<(), OxiSqlFusionError> {
    let query = format!("SELECT * FROM {table_name}");
    let rows = match conn.query(&query, &[]).await {
        Ok(fetched) => fetched,
        Err(err) => return Err(OxiSqlFusionError::OxiSql(err.to_string())),
    };
    if rows.is_empty() {
        return Ok(());
    }

    let schema = match infer_schema_from_first_row(&rows[0]) {
        Some(inferred) => inferred,
        None => {
            return Err(OxiSqlFusionError::OxiSql(format!(
                "table '{table_name}' returned rows with no columns"
            )))
        }
    };

    let snapshot = Arc::new(OxiSqlTableProvider::from_rows(rows, schema));
    ctx.register_table(table_name, snapshot)
        .map_err(OxiSqlFusionError::DataFusion)?;
    Ok(())
}