mod common;
use std::sync::Arc;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion::prelude::SessionContext;
use oxisql_core::Connection;
/// `id = 42` (column compared to an `Int64` literal) must be rendered as the
/// SQL text `(id = 42)`.
#[test]
fn test_expr_to_sql_eq() {
    use datafusion::common::Column;
    use datafusion::logical_expr::{BinaryExpr, Expr, Operator};
    use datafusion::scalar::ScalarValue;
    use oxisql_datafusion::stream::expr_to_sql;

    // Build both operands first, then combine them into the comparison.
    let column = Expr::Column(Column::new_unqualified("id"));
    let literal = Expr::Literal(ScalarValue::Int64(Some(42)), None);
    let comparison = Expr::BinaryExpr(BinaryExpr {
        left: Box::new(column),
        op: Operator::Eq,
        right: Box::new(literal),
    });

    let rendered = expr_to_sql(&comparison).expect("should translate");
    assert_eq!(rendered, "(id = 42)");
}
/// An `IsNull` expression over the column `name` must be rendered as
/// `(name IS NULL)`.
#[test]
fn test_expr_to_sql_is_null() {
    use datafusion::common::Column;
    use datafusion::logical_expr::Expr;
    use oxisql_datafusion::stream::expr_to_sql;

    let column = Expr::Column(Column::new_unqualified("name"));
    let predicate = Expr::IsNull(Box::new(column));

    let rendered = expr_to_sql(&predicate).expect("should translate");
    assert_eq!(rendered, "(name IS NULL)");
}
/// An `IsNotNull` expression over the column `age` must be rendered as
/// `(age IS NOT NULL)`.
#[test]
fn test_expr_to_sql_is_not_null() {
    use datafusion::common::Column;
    use datafusion::logical_expr::Expr;
    use oxisql_datafusion::stream::expr_to_sql;

    let column = Expr::Column(Column::new_unqualified("age"));
    let predicate = Expr::IsNotNull(Box::new(column));

    let rendered = expr_to_sql(&predicate).expect("should translate");
    assert_eq!(rendered, "(age IS NOT NULL)");
}
/// Negating `active = TRUE` must be rendered as `(NOT (active = TRUE))`,
/// i.e. the inner comparison keeps its own parentheses.
#[test]
fn test_expr_to_sql_not() {
    use datafusion::common::Column;
    use datafusion::logical_expr::{BinaryExpr, Expr, Operator};
    use datafusion::scalar::ScalarValue;
    use oxisql_datafusion::stream::expr_to_sql;

    let column = Expr::Column(Column::new_unqualified("active"));
    let truth = Expr::Literal(ScalarValue::Boolean(Some(true)), None);
    let comparison = Expr::BinaryExpr(BinaryExpr {
        left: Box::new(column),
        op: Operator::Eq,
        right: Box::new(truth),
    });
    let negated = Expr::Not(Box::new(comparison));

    let rendered = expr_to_sql(&negated).expect("should translate");
    assert_eq!(rendered, "(NOT (active = TRUE))");
}
/// An `IsTrue` expression is expected to have no SQL translation:
/// `expr_to_sql` must return `None` for it.
#[test]
fn test_expr_to_sql_unsupported_returns_none() {
    use datafusion::common::Column;
    use datafusion::logical_expr::Expr;
    use oxisql_datafusion::stream::expr_to_sql;

    let column = Expr::Column(Column::new_unqualified("flag"));
    let unsupported = Expr::IsTrue(Box::new(column));

    let rendered = expr_to_sql(&unsupported);
    assert!(rendered.is_none());
}
/// Creates a three-row `users` table on a connection opened with
/// `EmbeddedConnection::open_memory`, exposes it to DataFusion through
/// `OxiSqlStreamProvider`, and checks that `SELECT * ... ORDER BY id`
/// returns all three rows.
#[tokio::test]
async fn test_stream_provider_basic() {
    use oxisql_datafusion::OxiSqlStreamProvider;
    use oxisql_embedded::EmbeddedConnection;

    let conn = EmbeddedConnection::open_memory().expect("open_memory");

    // Each setup statement is paired with the label reported if it fails.
    let setup = [
        ("CREATE TABLE users (id INTEGER, name TEXT)", "CREATE TABLE"),
        ("INSERT INTO users VALUES (1, 'Alice')", "INSERT 1"),
        ("INSERT INTO users VALUES (2, 'Bob')", "INSERT 2"),
        ("INSERT INTO users VALUES (3, 'Carol')", "INSERT 3"),
    ];
    for (statement, label) in setup {
        conn.execute(statement, &[]).await.expect(label);
    }

    // Arrow schema handed to the provider for the `users` table.
    let fields = vec![
        Field::new("id", DataType::Int64, true),
        Field::new("name", DataType::Utf8, true),
    ];
    let schema = Arc::new(Schema::new(fields));

    let shared_conn: Arc<dyn oxisql_core::Connection> = Arc::new(conn);
    let provider = OxiSqlStreamProvider::new(shared_conn, "users", schema);

    let ctx = SessionContext::new();
    ctx.register_table("users", Arc::new(provider))
        .expect("register_table");

    let frame = ctx
        .sql("SELECT * FROM users ORDER BY id")
        .await
        .expect("sql parse");
    let batches = frame.collect().await.expect("collect");

    // The result may arrive in several batches; count rows across all of them.
    let total_rows = batches
        .iter()
        .fold(0usize, |count, batch| count + batch.num_rows());
    assert_eq!(total_rows, 3, "expected 3 rows");
}
/// Inserts the values 0 through 9 into `nums` and checks that an ordered
/// query with `LIMIT 5` issued through DataFusion returns exactly five rows.
#[tokio::test]
async fn test_stream_provider_with_limit() {
    use oxisql_datafusion::OxiSqlStreamProvider;
    use oxisql_embedded::EmbeddedConnection;

    let conn = EmbeddedConnection::open_memory().expect("open_memory");
    conn.execute("CREATE TABLE nums (v INTEGER)", &[])
        .await
        .expect("CREATE TABLE");

    for value in 0_i64..10 {
        let insert = format!("INSERT INTO nums VALUES ({value})");
        conn.execute(&insert, &[]).await.expect("INSERT");
    }

    let fields = vec![Field::new("v", DataType::Int64, true)];
    let schema = Arc::new(Schema::new(fields));

    let shared_conn: Arc<dyn oxisql_core::Connection> = Arc::new(conn);
    let provider = OxiSqlStreamProvider::new(shared_conn, "nums", schema);

    let ctx = SessionContext::new();
    ctx.register_table("nums", Arc::new(provider))
        .expect("register_table");

    let frame = ctx
        .sql("SELECT v FROM nums ORDER BY v LIMIT 5")
        .await
        .expect("sql parse");
    let batches = frame.collect().await.expect("collect");

    // Count rows across every returned batch.
    let total_rows = batches
        .iter()
        .fold(0usize, |count, batch| count + batch.num_rows());
    assert_eq!(total_rows, 5, "LIMIT 5 should return exactly 5 rows");
}
/// Populates `scores` with ids 1..=10 and scores 10, 20, ..., 100, then
/// checks that `WHERE score > 50` leaves exactly the five rows with scores
/// 60 through 100.
///
/// Only the final row count is asserted; whether the predicate is evaluated
/// by the provider or by DataFusion is not observable from this test.
#[tokio::test]
async fn test_filter_pushdown_reduces_rows() {
    use oxisql_embedded::EmbeddedConnection;

    let conn = EmbeddedConnection::open_memory().expect("open");
    conn.execute("CREATE TABLE scores (id INTEGER, score INTEGER)", &[])
        .await
        .expect("create");

    for id in 1i64..=10 {
        let score = id * 10;
        let insert = format!("INSERT INTO scores (id, score) VALUES ({id}, {score})");
        conn.execute(&insert, &[]).await.expect("insert");
    }

    let fields = vec![
        Field::new("id", DataType::Int64, true),
        Field::new("score", DataType::Int64, true),
    ];
    let schema = Arc::new(Schema::new(fields));

    let shared_conn: Arc<dyn oxisql_core::Connection> = Arc::new(conn);
    let provider =
        oxisql_datafusion::OxiSqlStreamProvider::new(shared_conn, "scores", Arc::clone(&schema));

    let ctx = SessionContext::new();
    ctx.register_table("scores", Arc::new(provider))
        .expect("register");

    let frame = ctx
        .sql("SELECT id FROM scores WHERE score > 50")
        .await
        .expect("sql");
    let batches = frame.collect().await.expect("collect");

    // Count rows across every returned batch.
    let total_rows = batches
        .iter()
        .fold(0usize, |count, batch| count + batch.num_rows());
    assert_eq!(total_rows, 5, "filter should reduce to 5 rows");
}
/// Selecting `id, name` from the three-column `items` table must yield
/// two-column batches containing all three inserted rows.
///
/// The column count is checked on every returned batch and the row count is
/// summed across batches, so the test does not rely on the result arriving
/// as a single batch and fails with a descriptive message (rather than an
/// index-out-of-bounds panic) if no batch is returned at all.
#[tokio::test]
async fn test_projection_pushdown() {
    use oxisql_datafusion::OxiSqlStreamProvider;
    use oxisql_embedded::EmbeddedConnection;

    let conn = EmbeddedConnection::open_memory().expect("open");
    conn.execute(
        "CREATE TABLE items (id INTEGER, name TEXT, value INTEGER)",
        &[],
    )
    .await
    .expect("create");

    for i in 1i64..=3 {
        conn.execute(
            &format!(
                "INSERT INTO items (id, name, value) VALUES ({i}, 'item{i}', {})",
                i * 100
            ),
            &[],
        )
        .await
        .expect("insert");
    }

    // The provider is given the full three-column schema; the query below
    // asks for only two of those columns.
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int64, true),
        Field::new("name", DataType::Utf8, true),
        Field::new("value", DataType::Int64, true),
    ]));
    let provider = OxiSqlStreamProvider::new(
        Arc::new(conn) as Arc<dyn oxisql_core::Connection>,
        "items",
        Arc::clone(&schema),
    );

    let ctx = SessionContext::new();
    ctx.register_table("items", Arc::new(provider))
        .expect("register");

    let df = ctx.sql("SELECT id, name FROM items").await.expect("sql");
    let results = df.collect().await.expect("collect");

    // Every batch, however many there are, must expose only the projection.
    for batch in &results {
        assert_eq!(
            batch.num_columns(),
            2,
            "projected batch must have 2 columns"
        );
    }

    // Rows may be spread over several batches; compare the overall total.
    let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
    assert_eq!(total_rows, 3, "expected 3 rows across all batches");
}
/// Inserts the values 1..=20 into `nums` and checks that `LIMIT 5` (without
/// an `ORDER BY`) returns exactly five rows.
///
/// The table holds more rows than the limit, so anything other than five
/// rows is wrong: more means the limit was ignored, fewer means rows were
/// lost. An upper-bound-only check would also accept an empty result.
#[tokio::test]
async fn test_limit_pushdown() {
    use oxisql_embedded::EmbeddedConnection;

    let conn = EmbeddedConnection::open_memory().expect("open");
    conn.execute("CREATE TABLE nums (n INTEGER)", &[])
        .await
        .expect("create");

    for i in 1i64..=20 {
        conn.execute(&format!("INSERT INTO nums (n) VALUES ({i})"), &[])
            .await
            .expect("insert");
    }

    let schema = Arc::new(Schema::new(vec![Field::new("n", DataType::Int64, true)]));
    let provider = oxisql_datafusion::OxiSqlStreamProvider::new(
        Arc::new(conn) as Arc<dyn oxisql_core::Connection>,
        "nums",
        Arc::clone(&schema),
    );

    let ctx = SessionContext::new();
    ctx.register_table("nums", Arc::new(provider))
        .expect("register");

    let df = ctx.sql("SELECT n FROM nums LIMIT 5").await.expect("sql");
    let results = df.collect().await.expect("collect");

    // Rows may be spread over several batches; compare the overall total.
    let total: usize = results.iter().map(|b| b.num_rows()).sum();
    assert_eq!(
        total, 5,
        "limit should produce exactly 5 rows, got {total}"
    );
}
/// Selecting `id, label` from a three-column `items` table must return both
/// inserted rows, and every returned batch must expose exactly two columns.
#[tokio::test]
async fn test_stream_provider_projection() {
    use oxisql_datafusion::OxiSqlStreamProvider;
    use oxisql_embedded::EmbeddedConnection;

    let conn = EmbeddedConnection::open_memory().expect("open_memory");

    // Each setup statement is paired with the label reported if it fails.
    let setup = [
        (
            "CREATE TABLE items (id INTEGER, label TEXT, score FLOAT)",
            "CREATE TABLE",
        ),
        ("INSERT INTO items VALUES (1, 'foo', 1.5)", "INSERT"),
        ("INSERT INTO items VALUES (2, 'bar', 2.5)", "INSERT"),
    ];
    for (statement, label) in setup {
        conn.execute(statement, &[]).await.expect(label);
    }

    let fields = vec![
        Field::new("id", DataType::Int64, true),
        Field::new("label", DataType::Utf8, true),
        Field::new("score", DataType::Float64, true),
    ];
    let schema = Arc::new(Schema::new(fields));

    let shared_conn: Arc<dyn oxisql_core::Connection> = Arc::new(conn);
    let provider = OxiSqlStreamProvider::new(shared_conn, "items", schema);

    let ctx = SessionContext::new();
    ctx.register_table("items", Arc::new(provider))
        .expect("register_table");

    let frame = ctx
        .sql("SELECT id, label FROM items ORDER BY id")
        .await
        .expect("sql parse");
    let batches = frame.collect().await.expect("collect");

    // Row total first, then the per-batch column check.
    let total_rows = batches
        .iter()
        .fold(0usize, |count, batch| count + batch.num_rows());
    assert_eq!(total_rows, 2);

    batches.iter().for_each(|batch| {
        assert_eq!(
            batch.num_columns(),
            2,
            "projected batch must have 2 columns"
        );
    });
}
/// Compiled only with the `sqlite` feature: builds a four-row `products`
/// table on a `SqliteConnection`, wraps it with
/// `OxiSqlStreamProvider::from_sqlite`, and runs two DataFusion queries
/// against it, a full scan (4 rows) and a `price >= 10.0` filter (2 rows).
#[cfg(feature = "sqlite")]
#[tokio::test]
async fn test_sqlite_stream_provider() {
    use oxisql_datafusion::stream::OxiSqlStreamProvider;
    use oxisql_sqlite_compat::SqliteConnection;

    let conn = SqliteConnection::open_memory()
        .await
        .expect("open in-memory SQLite");

    // Each setup statement is paired with the label reported if it fails.
    let setup = [
        (
            "CREATE TABLE products (id INTEGER, name TEXT, price FLOAT)",
            "CREATE TABLE products",
        ),
        ("INSERT INTO products VALUES (1, 'Alpha', 9.99)", "INSERT 1"),
        ("INSERT INTO products VALUES (2, 'Beta', 19.99)", "INSERT 2"),
        ("INSERT INTO products VALUES (3, 'Gamma', 29.99)", "INSERT 3"),
        ("INSERT INTO products VALUES (4, 'Delta', 4.99)", "INSERT 4"),
    ];
    for (statement, label) in setup {
        conn.execute(statement, &[]).await.expect(label);
    }

    let fields = vec![
        Field::new("id", DataType::Int64, true),
        Field::new("name", DataType::Utf8, true),
        Field::new("price", DataType::Float64, true),
    ];
    let schema = Arc::new(Schema::new(fields));
    let provider = OxiSqlStreamProvider::from_sqlite(conn, "products", Arc::clone(&schema));

    let ctx = SessionContext::new();
    ctx.register_table("products", Arc::new(provider))
        .expect("register_table");

    // Unfiltered scan: every inserted row must come back.
    let all_frame = ctx
        .sql("SELECT * FROM products ORDER BY id")
        .await
        .expect("sql parse");
    let all_batches = all_frame.collect().await.expect("collect");
    let total_rows = all_batches
        .iter()
        .fold(0usize, |count, batch| count + batch.num_rows());
    assert_eq!(total_rows, 4, "expected 4 rows from SQLite via DataFusion");

    // Filtered scan: only the 19.99 and 29.99 rows satisfy `price >= 10.0`.
    let filtered_frame = ctx
        .sql("SELECT name FROM products WHERE price >= 10.0 ORDER BY id")
        .await
        .expect("sql parse filtered");
    let filtered_batches = filtered_frame.collect().await.expect("collect filtered");
    let filtered_rows = filtered_batches
        .iter()
        .fold(0usize, |count, batch| count + batch.num_rows());
    assert_eq!(
        filtered_rows, 2,
        "filter price >= 10.0 should return 2 rows (Beta, Gamma)"
    );
}