// llkv-sql 0.4.5-alpha
//
// SQL interface for the LLKV toolkit.
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

use arrow::array::Array as ArrowArray;
use llkv_runtime::{RuntimeContext, RuntimeStatementResult};
use llkv_sql::SqlEngine;
use llkv_storage::pager::MemPager;
use sqllogictest::{AsyncDB, DBOutput, DefaultColumnType};

/// Format a struct value in DuckDB-compatible format: {'field1': value1, 'field2': value2}
fn format_struct_value(struct_array: &arrow::array::StructArray, row_idx: usize) -> String {
    let mut parts = Vec::new();
    let field_names = struct_array.column_names();
    for (field_idx, column) in struct_array.columns().iter().enumerate() {
        let field_name = field_names[field_idx];
        let value_str = match column.data_type() {
            arrow::datatypes::DataType::Int64 => {
                let a = column
                    .as_any()
                    .downcast_ref::<arrow::array::Int64Array>()
                    .unwrap();
                if a.is_null(row_idx) {
                    "NULL".to_string()
                } else {
                    a.value(row_idx).to_string()
                }
            }
            arrow::datatypes::DataType::Int32 => {
                let a = column
                    .as_any()
                    .downcast_ref::<arrow::array::Int32Array>()
                    .unwrap();
                if a.is_null(row_idx) {
                    "NULL".to_string()
                } else {
                    a.value(row_idx).to_string()
                }
            }
            arrow::datatypes::DataType::Utf8 => {
                let a = column
                    .as_any()
                    .downcast_ref::<arrow::array::StringArray>()
                    .unwrap();
                if a.is_null(row_idx) {
                    "NULL".to_string()
                } else {
                    format!("'{}'", a.value(row_idx))
                }
            }
            arrow::datatypes::DataType::Float64 => {
                let a = column
                    .as_any()
                    .downcast_ref::<arrow::array::Float64Array>()
                    .unwrap();
                if a.is_null(row_idx) {
                    "NULL".to_string()
                } else {
                    a.value(row_idx).to_string()
                }
            }
            arrow::datatypes::DataType::Boolean => {
                let a = column
                    .as_any()
                    .downcast_ref::<arrow::array::BooleanArray>()
                    .unwrap();
                if a.is_null(row_idx) {
                    "NULL".to_string()
                } else if a.value(row_idx) {
                    "1".to_string()
                } else {
                    "0".to_string()
                }
            }
            arrow::datatypes::DataType::Struct(_) => {
                // Recursively format nested struct
                let a = column
                    .as_any()
                    .downcast_ref::<arrow::array::StructArray>()
                    .unwrap();
                if a.is_null(row_idx) {
                    "NULL".to_string()
                } else {
                    format_struct_value(a, row_idx)
                }
            }
            _ => "NULL".to_string(),
        };
        parts.push(format!("'{}': {}", field_name, value_str));
    }
    format!("{{{}}}", parts.join(", "))
}

/// sqllogictest database adapter that wraps a single `SqlEngine` over a `MemPager`.
pub struct EngineHarness {
    engine: SqlEngine<MemPager>,
}

impl EngineHarness {
    /// Wraps `engine` so it can be driven through the sqllogictest `AsyncDB` interface.
    pub fn new(engine: SqlEngine<MemPager>) -> Self {
        // The harness is moved out on return, so the address of this local would never
        // match the `{:p}` logged by `run()`; log the event without a misleading pointer.
        tracing::debug!("[HARNESS] new() created harness");
        Self { engine }
    }
}

/// Cloneable handle to one `RuntimeContext` backed by a single in-memory pager.
///
/// Clones share the same `RuntimeContext` (via `Arc`), so every engine produced by
/// [`SharedContext::make_engine`] is bound to that same context.
#[derive(Clone)]
pub struct SharedContext {
    context: Arc<RuntimeContext<MemPager>>,
}

impl Default for SharedContext {
    /// Equivalent to [`SharedContext::new`].
    fn default() -> Self {
        SharedContext::new()
    }
}

impl SharedContext {
    /// Builds a fresh runtime context on top of a new, default `MemPager`.
    pub fn new() -> Self {
        let runtime = RuntimeContext::new(Arc::new(MemPager::default()));
        SharedContext {
            context: Arc::new(runtime),
        }
    }

    /// Creates a new `SqlEngine` bound to this shared context.
    pub fn make_engine(&self) -> SqlEngine<MemPager> {
        // NOTE(review): the meaning of the `false` flag is defined by
        // `SqlEngine::with_context`; confirm there before relying on it.
        let shared_runtime = Arc::clone(&self.context);
        SqlEngine::with_context(shared_runtime, false)
    }
}

#[async_trait::async_trait]
impl AsyncDB for EngineHarness {
    type Error = llkv_result::Error;
    type ColumnType = DefaultColumnType;

    async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error> {
        // Log which SQL is being executed by this harness
        tracing::debug!("[HARNESS {:p}] run() called, sql=\"{}\"", self, sql.trim());
        match self.engine.execute(sql) {
            Ok(mut results) => {
                tracing::trace!(
                    "[HARNESS] execute() returned Ok with {} results",
                    results.len()
                );
                if results.is_empty() {
                    return Ok(DBOutput::StatementComplete(0));
                }
                let result = results.remove(0);
                match result {
                    RuntimeStatementResult::Select { execution, .. } => {
                        let batches = execution.collect()?;
                        let mut rows: Vec<Vec<String>> = Vec::new();
                        for batch in &batches {
                            for row_idx in 0..batch.num_rows() {
                                let mut row: Vec<String> = Vec::new();
                                for col in 0..batch.num_columns() {
                                    let array = batch.column(col);
                                    let val = match array.data_type() {
                                        arrow::datatypes::DataType::Int64 => {
                                            let a = array
                                                .as_any()
                                                .downcast_ref::<arrow::array::Int64Array>()
                                                .unwrap();
                                            if a.is_null(row_idx) {
                                                "NULL".to_string()
                                            } else {
                                                a.value(row_idx).to_string()
                                            }
                                        }
                                        arrow::datatypes::DataType::UInt64 => {
                                            let a = array
                                                .as_any()
                                                .downcast_ref::<arrow::array::UInt64Array>()
                                                .unwrap();
                                            if a.is_null(row_idx) {
                                                "NULL".to_string()
                                            } else {
                                                a.value(row_idx).to_string()
                                            }
                                        }
                                        arrow::datatypes::DataType::Float64 => {
                                            let a = array
                                                .as_any()
                                                .downcast_ref::<arrow::array::Float64Array>()
                                                .unwrap();
                                            if a.is_null(row_idx) {
                                                "NULL".to_string()
                                            } else {
                                                a.value(row_idx).to_string()
                                            }
                                        }
                                        arrow::datatypes::DataType::Utf8 => {
                                            let a = array
                                                .as_any()
                                                .downcast_ref::<arrow::array::StringArray>()
                                                .unwrap();
                                            if a.is_null(row_idx) {
                                                "NULL".to_string()
                                            } else {
                                                a.value(row_idx).to_string()
                                            }
                                        }
                                        arrow::datatypes::DataType::Boolean => {
                                            let a = array
                                                .as_any()
                                                .downcast_ref::<arrow::array::BooleanArray>()
                                                .unwrap();
                                            if a.is_null(row_idx) {
                                                "NULL".to_string()
                                            } else if a.value(row_idx) {
                                                "1".to_string()
                                            } else {
                                                "0".to_string()
                                            }
                                        }
                                        arrow::datatypes::DataType::Struct(_) => {
                                            let a = array
                                                .as_any()
                                                .downcast_ref::<arrow::array::StructArray>()
                                                .unwrap();
                                            if a.is_null(row_idx) {
                                                "NULL".to_string()
                                            } else {
                                                format_struct_value(a, row_idx)
                                            }
                                        }
                                        _ => "".to_string(),
                                    };
                                    row.push(val);
                                }
                                rows.push(row);
                            }
                        }

                        let types = if let Some(first) = batches.first() {
                            (0..first.num_columns())
                                .map(|col| match first.column(col).data_type() {
                                    arrow::datatypes::DataType::Int64
                                    | arrow::datatypes::DataType::UInt64 => {
                                        DefaultColumnType::Integer
                                    }
                                    arrow::datatypes::DataType::Float64 => {
                                        DefaultColumnType::FloatingPoint
                                    }
                                    arrow::datatypes::DataType::Utf8 => DefaultColumnType::Text,
                                    _ => DefaultColumnType::Any,
                                })
                                .collect()
                        } else {
                            vec![]
                        };

                        Ok(DBOutput::Rows { types, rows })
                    }
                    RuntimeStatementResult::Insert { rows_inserted, .. } => {
                        // Return as a single-row result for compatibility with query directives
                        Ok(DBOutput::Rows {
                            types: vec![DefaultColumnType::Integer],
                            rows: vec![vec![rows_inserted.to_string()]],
                        })
                    }
                    RuntimeStatementResult::Update { rows_updated, .. } => {
                        // Return as a single-row result for compatibility with query directives
                        Ok(DBOutput::Rows {
                            types: vec![DefaultColumnType::Integer],
                            rows: vec![vec![rows_updated.to_string()]],
                        })
                    }
                    RuntimeStatementResult::Delete { rows_deleted, .. } => {
                        // Return as a single-row result for compatibility with query directives
                        Ok(DBOutput::Rows {
                            types: vec![DefaultColumnType::Integer],
                            rows: vec![vec![rows_deleted.to_string()]],
                        })
                    }
                    RuntimeStatementResult::CreateTable { .. } => {
                        Ok(DBOutput::StatementComplete(0))
                    }
                    RuntimeStatementResult::CreateIndex { .. } => {
                        Ok(DBOutput::StatementComplete(0))
                    }
                    RuntimeStatementResult::Transaction { .. } => {
                        Ok(DBOutput::StatementComplete(0))
                    }
                    RuntimeStatementResult::NoOp => Ok(DBOutput::StatementComplete(0)),
                }
            }
            Err(e) => {
                tracing::trace!("[HARNESS] execute() returned Err: {:?}", e);
                Err(e)
            }
        }
    }

    async fn shutdown(&mut self) {}
}

/// Boxed, sendable future that resolves to a new [`EngineHarness`].
///
/// The error type is `()`; the factory built by [`make_factory_factory`] always resolves
/// to `Ok`.
pub type HarnessFuture = Pin<Box<dyn Future<Output = Result<EngineHarness, ()>> + Send + 'static>>;
/// Thread-safe constructor that yields one [`HarnessFuture`] per call.
pub type HarnessFactory = Box<dyn Fn() -> HarnessFuture + Send + Sync + 'static>;

/// Returns a cloneable closure that builds harness factories.
///
/// Each call of the returned closure creates a fresh [`SharedContext`] and returns a
/// [`HarnessFactory`] bound to it. Every harness produced by that factory gets its own
/// `SqlEngine`, but all of them share the factory's single context. Factory calls are
/// numbered from 0 in the debug logs.
pub fn make_factory_factory() -> impl Fn() -> HarnessFactory + Clone {
    || {
        tracing::trace!("[FACTORY] make_factory_factory: Creating SharedContext");
        let shared = SharedContext::new();
        // Owned directly by the factory closure: `fetch_add` only needs `&self`, and
        // `AtomicUsize` is `Send + Sync`, so no `Arc` is required.
        let counter = AtomicUsize::new(0);
        let factory: HarnessFactory = Box::new(move || {
            let n = counter.fetch_add(1, Ordering::SeqCst);
            tracing::debug!(
                "[FACTORY] Factory called #{}: Creating new EngineHarness",
                n
            );
            let shared = shared.clone();
            Box::pin(async move {
                let engine = shared.make_engine();
                tracing::debug!(
                    "[FACTORY] Factory #{}: Created SqlEngine with new Session",
                    n
                );
                Ok::<_, ()>(EngineHarness::new(engine))
            })
        });
        factory
    }
}