ella-engine 0.1.5

Core engine implementation for the ella datastore.
Documentation
use std::{fmt::Debug, sync::Arc};

use datafusion::{
    arrow::datatypes::SchemaRef,
    datasource::TableProvider,
    error::{DataFusionError, Result as DfResult},
    execution::context::SessionState,
    logical_expr::{TableProviderFilterPushDown, TableType},
    physical_plan::ExecutionPlan,
    prelude::Expr,
};
use datafusion_proto::logical_plan::LogicalExtensionCodec;

use crate::{cluster::EllaCluster, registry::TableId, table::EllaTable};

fn encode_table(
    node: Arc<dyn datafusion::datasource::TableProvider>,
    buf: &mut Vec<u8>,
) -> datafusion::error::Result<()> {
    if let Some(table) = node.as_any().downcast_ref::<TableStub>() {
        serde_json::to_writer(buf, table.table())
            .map_err(|err| DataFusionError::External(Box::new(err)))?;
        Ok(())
    } else if let Some(table) = node.as_any().downcast_ref::<EllaTable>() {
        serde_json::to_writer(buf, table.id())
            .map_err(|err| DataFusionError::External(Box::new(err)))?;
        Ok(())
    } else {
        Err(DataFusionError::Internal(
            "failed to encode table provider".to_string(),
        ))
    }
}

/// Stateless [`LogicalExtensionCodec`] that decodes table providers into
/// [`TableStub`] placeholders instead of resolving them to real tables.
///
/// NOTE(review): the name suggests this is meant for the side of a connection that
/// has no local cluster to resolve tables against — confirm against callers.
#[derive(Debug, Default, Clone, Copy)]
pub struct RemoteExtensionCodec {}

impl LogicalExtensionCodec for RemoteExtensionCodec {
    /// Decoding of custom logical extension nodes is not supported.
    ///
    /// # Errors
    ///
    /// Always returns `DataFusionError::NotImplemented` (rather than panicking), so
    /// that a plan containing an extension node is rejected gracefully.
    fn try_decode(
        &self,
        _buf: &[u8],
        _inputs: &[datafusion::logical_expr::LogicalPlan],
        _ctx: &datafusion::prelude::SessionContext,
    ) -> datafusion::error::Result<datafusion::logical_expr::Extension> {
        Err(DataFusionError::NotImplemented(
            "unable to decode extension node".to_string(),
        ))
    }

    /// Encoding of custom logical extension nodes is not supported.
    ///
    /// # Errors
    ///
    /// Always returns `DataFusionError::NotImplemented`, including the debug
    /// representation of the offending node in the message.
    fn try_encode(
        &self,
        node: &datafusion::logical_expr::Extension,
        _buf: &mut Vec<u8>,
    ) -> datafusion::error::Result<()> {
        Err(DataFusionError::NotImplemented(format!(
            "unable to encode extension node: {:?}",
            node.node
        )))
    }

    /// Decodes a JSON-serialized `TableId` from `buf` and wraps it, together with
    /// the supplied `schema`, in a [`TableStub`].
    ///
    /// # Errors
    ///
    /// Returns `DataFusionError::External` if `buf` is not a valid serialized
    /// `TableId`.
    fn try_decode_table_provider(
        &self,
        buf: &[u8],
        schema: SchemaRef,
        _ctx: &datafusion::prelude::SessionContext,
    ) -> datafusion::error::Result<Arc<dyn datafusion::datasource::TableProvider>> {
        let table: TableId =
            serde_json::from_slice(buf).map_err(|err| DataFusionError::External(Box::new(err)))?;
        Ok(Arc::new(TableStub { schema, table }))
    }

    /// Encodes a table provider by delegating to `encode_table`.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `encode_table` returns.
    fn try_encode_table_provider(
        &self,
        node: Arc<dyn datafusion::datasource::TableProvider>,
        buf: &mut Vec<u8>,
    ) -> datafusion::error::Result<()> {
        encode_table(node, buf)
    }
}

/// [`LogicalExtensionCodec`] backed by an [`EllaCluster`].
///
/// Decoded table identifiers are looked up in the cluster (catalog, then schema,
/// then table) to obtain the actual table provider.
#[derive(Clone)]
pub struct EllaExtensionCodec {
    /// Cluster used to resolve decoded table identifiers.
    cluster: Arc<EllaCluster>,
}

impl Debug for EllaExtensionCodec {
    /// Formats the codec as an opaque struct; the `cluster` field is deliberately
    /// omitted from the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("EllaExtensionCodec");
        repr.finish_non_exhaustive()
    }
}

impl EllaExtensionCodec {
    pub fn new(cluster: Arc<EllaCluster>) -> Self {
        Self { cluster }
    }
}

impl LogicalExtensionCodec for EllaExtensionCodec {
    /// Decoding of custom logical extension nodes is not supported.
    ///
    /// # Errors
    ///
    /// Always returns `DataFusionError::NotImplemented` (rather than panicking), so
    /// that a plan containing an extension node is rejected gracefully.
    fn try_decode(
        &self,
        _buf: &[u8],
        _inputs: &[datafusion::logical_expr::LogicalPlan],
        _ctx: &datafusion::prelude::SessionContext,
    ) -> datafusion::error::Result<datafusion::logical_expr::Extension> {
        Err(DataFusionError::NotImplemented(
            "unable to decode extension node".to_string(),
        ))
    }

    /// Encoding of custom logical extension nodes is not supported.
    ///
    /// # Errors
    ///
    /// Always returns `DataFusionError::NotImplemented`, including the debug
    /// representation of the offending node in the message.
    fn try_encode(
        &self,
        node: &datafusion::logical_expr::Extension,
        _buf: &mut Vec<u8>,
    ) -> datafusion::error::Result<()> {
        Err(DataFusionError::NotImplemented(format!(
            "unable to encode extension node: {:?}",
            node.node
        )))
    }

    /// Decodes a JSON-serialized `TableId` from `buf` and resolves it against the
    /// cluster by walking catalog, then schema, then table.
    ///
    /// The `_schema` argument is ignored; the resolved table is returned as-is.
    ///
    /// # Errors
    ///
    /// * `DataFusionError::External` if `buf` is not a valid serialized `TableId`.
    /// * `DataFusionError::Plan` if any step of the lookup yields nothing.
    fn try_decode_table_provider(
        &self,
        buf: &[u8],
        _schema: SchemaRef,
        _ctx: &datafusion::prelude::SessionContext,
    ) -> datafusion::error::Result<std::sync::Arc<dyn datafusion::datasource::TableProvider>> {
        let table: TableId =
            serde_json::from_slice(buf).map_err(|err| DataFusionError::External(Box::new(err)))?;

        self.cluster
            .catalog(&table.catalog)
            .and_then(|catalog| catalog.schema(&table.schema))
            .and_then(|schema| schema.table(&table.table))
            .ok_or_else(|| DataFusionError::Plan(format!("table {} not found", table)))
            // Unsize the concrete table handle into the trait object callers expect.
            .map(|t| t as Arc<_>)
    }

    /// Encodes a table provider by delegating to `encode_table`.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `encode_table` returns.
    fn try_encode_table_provider(
        &self,
        node: std::sync::Arc<dyn datafusion::datasource::TableProvider>,
        buf: &mut Vec<u8>,
    ) -> datafusion::error::Result<()> {
        encode_table(node, buf)
    }
}

/// Placeholder table provider that carries only a table identifier and a schema.
///
/// It exposes the schema through [`TableProvider`] but cannot be scanned; its
/// `scan` implementation always returns a `NotImplemented` error.
#[derive(Debug, Clone)]
pub struct TableStub {
    /// Identifier of the table this stub stands in for.
    table: TableId<'static>,
    /// Schema reported by the stub.
    schema: SchemaRef,
}

impl TableStub {
    pub fn new(table: TableId<'static>, schema: SchemaRef) -> Self {
        Self { table, schema }
    }

    pub fn table(&self) -> &TableId {
        &self.table
    }
}

#[async_trait::async_trait]
impl TableProvider for TableStub {
    /// Exposes the stub as `Any` so callers can downcast back to [`TableStub`].
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Returns a new handle to the schema stored in the stub.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    /// Stubs always report themselves as base tables.
    fn table_type(&self) -> TableType {
        TableType::Base
    }

    /// Reports every supplied filter as `Exact`, one entry per filter.
    ///
    /// NOTE(review): presumably this keeps filters attached to the table scan so
    /// they travel with the plan — confirm the intent, since the stub itself can
    /// never evaluate them.
    fn supports_filters_pushdown(
        &self,
        filters: &[&Expr],
    ) -> Result<Vec<TableProviderFilterPushDown>, DataFusionError> {
        let support = filters
            .iter()
            .map(|_| TableProviderFilterPushDown::Exact)
            .collect();
        Ok(support)
    }

    /// Scanning is unsupported: a stub carries no data.
    ///
    /// # Errors
    ///
    /// Always returns `DataFusionError::NotImplemented`.
    async fn scan(
        &self,
        _state: &SessionState,
        _projection: Option<&Vec<usize>>,
        _filters: &[Expr],
        _limit: Option<usize>,
    ) -> DfResult<Arc<dyn ExecutionPlan>> {
        let message = String::from("stub tables can't be scanned");
        Err(DataFusionError::NotImplemented(message))
    }
}