mod common;
use std::sync::Arc;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion::prelude::SessionContext;
use oxisql_core::{Connection, Row, Value};
use oxisql_datafusion::OxiSqlTableProvider;
/// `COUNT(*)` over a five-row in-memory snapshot must yield one Int64 row holding 5.
#[tokio::test]
async fn test_aggregation_query() {
    use arrow::array::Int64Array;

    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int64, false),
        Field::new("name", DataType::Utf8, false),
    ]));
    let column_names: Vec<String> = vec![String::from("id"), String::from("name")];

    // Build rows (1, "item1") .. (5, "item5").
    let mut snapshot: Vec<Row> = Vec::new();
    for n in 1i64..=5 {
        let values = vec![Value::I64(n), Value::Text(format!("item{n}"))];
        snapshot.push(Row::new(column_names.clone(), values));
    }
    let table = OxiSqlTableProvider::from_rows(snapshot, schema);

    let ctx = SessionContext::new();
    ctx.register_table("agg_test", Arc::new(table))
        .expect("register_table");

    let frame = ctx
        .sql("SELECT COUNT(*) AS cnt FROM agg_test")
        .await
        .expect("sql parse");
    let batches = frame.collect().await.expect("collect");

    let first = &batches[0];
    assert_eq!(first.num_rows(), 1, "aggregate should return 1 row");
    let counts = first
        .column(0)
        .as_any()
        .downcast_ref::<Int64Array>()
        .expect("count column should be Int64");
    assert_eq!(counts.value(0), 5, "COUNT(*) should equal 5");
}
/// The `Display` output must mention the single row/column count and start with
/// the `OxiSqlTableProvider(` prefix.
#[test]
fn test_provider_display() {
    let id_field = Field::new("id", DataType::Int64, false);
    let schema = Arc::new(Schema::new(vec![id_field]));
    let only_row = Row::new(vec!["id".into()], vec![Value::I64(1)]);
    let rendered = OxiSqlTableProvider::from_rows(vec![only_row], schema).to_string();

    // Checked in the same order as before so the first failing message is unchanged.
    assert!(
        rendered.contains('1'),
        "Display must include '1' for row/col count: {rendered}"
    );
    assert!(
        rendered.starts_with("OxiSqlTableProvider("),
        "Display must start with 'OxiSqlTableProvider(': {rendered}"
    );
}
/// Snapshotting a two-row embedded table via `from_connection` must load both rows.
#[tokio::test]
async fn test_from_connection() {
    use oxisql_embedded::EmbeddedConnection;

    let conn = EmbeddedConnection::open_memory().expect("open_memory");

    // Each setup statement paired with the label reported if it fails.
    let setup = [
        ("CREATE TABLE fc_test (id INTEGER, val TEXT)", "CREATE TABLE"),
        ("INSERT INTO fc_test VALUES (1, 'hello')", "INSERT 1"),
        ("INSERT INTO fc_test VALUES (2, 'world')", "INSERT 2"),
    ];
    for (statement, label) in setup {
        conn.execute(statement, &[]).await.expect(label);
    }

    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int64, true),
        Field::new("val", DataType::Utf8, true),
    ]));
    let snapshot = OxiSqlTableProvider::from_connection(&conn, "fc_test", schema)
        .await
        .expect("from_connection");
    assert_eq!(snapshot.len(), 2, "from_connection should load 2 rows");
}
/// After one extra INSERT, `refresh` must make the provider see exactly one more row.
#[tokio::test]
async fn test_provider_refresh() {
    use oxisql_embedded::EmbeddedConnection;

    const TABLE: &str = "refresh_test";

    let conn = EmbeddedConnection::open_memory().expect("open_memory");
    conn.execute("CREATE TABLE refresh_test (n INTEGER)", &[])
        .await
        .expect("CREATE TABLE");
    conn.execute("INSERT INTO refresh_test VALUES (1)", &[])
        .await
        .expect("INSERT 1");

    let n_field = Field::new("n", DataType::Int64, true);
    let mut provider =
        OxiSqlTableProvider::from_connection(&conn, TABLE, Arc::new(Schema::new(vec![n_field])))
            .await
            .expect("from_connection");
    let before = provider.len();
    assert_eq!(before, 1, "initial snapshot should have 1 row");

    // Mutate the underlying table; the snapshot only changes after `refresh`.
    conn.execute("INSERT INTO refresh_test VALUES (2)", &[])
        .await
        .expect("INSERT 2");
    provider.refresh(&conn, TABLE).await.expect("refresh");

    let after = provider.len();
    assert_eq!(after, before + 1, "after refresh there should be 2 rows");
}
/// Joins two in-memory snapshots and checks the per-user `SUM(total)`.
///
/// Only Alice (id 1) has orders (100.0 + 200.0), so the inner join must produce a
/// group for Alice totalling 300.0. Bob (id 2) has no orders and must be absent.
#[tokio::test]
async fn test_multi_table_join() {
    use arrow::array::{Float64Array, StringArray};

    let user_schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int64, false),
        Field::new("name", DataType::Utf8, false),
    ]));
    let order_schema = Arc::new(Schema::new(vec![
        Field::new("user_id", DataType::Int64, false),
        Field::new("total", DataType::Float64, false),
    ]));
    let users = vec![
        Row::new(
            vec!["id".into(), "name".into()],
            vec![Value::I64(1), Value::Text("Alice".into())],
        ),
        Row::new(
            vec!["id".into(), "name".into()],
            vec![Value::I64(2), Value::Text("Bob".into())],
        ),
    ];
    let orders = vec![
        Row::new(
            vec!["user_id".into(), "total".into()],
            vec![Value::I64(1), Value::F64(100.0)],
        ),
        Row::new(
            vec!["user_id".into(), "total".into()],
            vec![Value::I64(1), Value::F64(200.0)],
        ),
    ];
    let ctx = SessionContext::new();
    ctx.register_table(
        "join_users",
        Arc::new(OxiSqlTableProvider::from_rows(users, user_schema)),
    )
    .expect("register users");
    ctx.register_table(
        "join_orders",
        Arc::new(OxiSqlTableProvider::from_rows(orders, order_schema)),
    )
    .expect("register orders");
    let df = ctx
        .sql(
            "SELECT u.name, SUM(o.total) AS total \
             FROM join_users u \
             JOIN join_orders o ON u.id = o.user_id \
             GROUP BY u.name \
             ORDER BY u.name",
        )
        .await
        .expect("sql parse");
    let batches = df.collect().await.expect("collect");

    // Gather (name, total) pairs from every batch: DataFusion does not promise
    // that the whole result lands in `batches[0]`.
    let mut groups: Vec<(String, f64)> = Vec::new();
    for batch in &batches {
        let names = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .expect("name column should be Utf8");
        let totals = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .expect("total column should be Float64");
        groups.extend(
            (0..batch.num_rows()).map(|i| (names.value(i).to_string(), totals.value(i))),
        );
    }
    assert!(!groups.is_empty(), "JOIN query should return rows");

    let alice_total = groups
        .iter()
        .find(|(name, _)| name == "Alice")
        .map(|&(_, total)| total)
        .expect("Alice should appear in results");
    assert!(
        (alice_total - 300.0).abs() < 1e-6,
        "Alice's total should be 300.0, got {alice_total}"
    );
    assert!(
        groups.iter().all(|(name, _)| name != "Bob"),
        "Bob has no orders and must not appear in an inner JOIN"
    );
}
/// `with_sort` must store the configured sort keys verbatim and in declaration order.
#[test]
fn test_sort_order_builder() {
    // `Connection` comes from the file-level `use oxisql_core::{Connection, ...}`;
    // re-importing it here was redundant.
    use oxisql_datafusion::{OxiSqlStreamProvider, SortOrder};
    use oxisql_embedded::EmbeddedConnection;

    let schema = Arc::new(Schema::new(vec![
        Field::new("score", DataType::Float64, true),
        Field::new("id", DataType::Int64, true),
    ]));
    let conn =
        Arc::new(EmbeddedConnection::open_memory().expect("open_memory")) as Arc<dyn Connection>;
    let order = vec![
        ("score".to_string(), SortOrder::Desc),
        ("id".to_string(), SortOrder::Asc),
    ];
    // `order` is not used afterwards, so move it instead of cloning.
    let provider = OxiSqlStreamProvider::new(conn, "t", schema).with_sort(order);

    let got = provider.sort_order().expect("sort_order should be Some");
    assert_eq!(got.len(), 2, "two sort columns configured");
    assert_eq!(got[0].0, "score");
    assert_eq!(got[0].1, SortOrder::Desc);
    assert_eq!(got[1].0, "id");
    assert_eq!(got[1].1, SortOrder::Asc);
}
/// The snapshot provider reports an `id = 2` equality predicate as `Inexact` pushdown.
#[test]
fn test_snapshot_filter_returns_inexact() {
    use datafusion::common::Column;
    use datafusion::datasource::TableProvider;
    use datafusion::logical_expr::{BinaryExpr, Expr, Operator, TableProviderFilterPushDown};
    use datafusion::scalar::ScalarValue;

    let (rows, schema) = common::make_test_rows();
    let provider = OxiSqlTableProvider::from_rows(rows, schema);

    // Build `id = 2` from named parts.
    let id_column = Expr::Column(Column::new_unqualified("id"));
    let literal_two = Expr::Literal(ScalarValue::Int64(Some(2)), None);
    let id_eq_two = Expr::BinaryExpr(BinaryExpr {
        left: Box::new(id_column),
        op: Operator::Eq,
        right: Box::new(literal_two),
    });

    let verdicts = provider
        .supports_filters_pushdown(&[&id_eq_two])
        .expect("supports_filters_pushdown should not fail");
    assert_eq!(verdicts.len(), 1, "should return one result per filter");
    assert_eq!(
        verdicts[0],
        TableProviderFilterPushDown::Inexact,
        "binary equality filter should be Inexact"
    );
}
/// End-to-end: `WHERE id = 2` over the shared test snapshot yields exactly one row.
#[tokio::test]
async fn test_snapshot_filter_pushdown_equality() {
    let (rows, schema) = common::make_test_rows();
    let ctx = SessionContext::new();
    ctx.register_table("t_filter", Arc::new(OxiSqlTableProvider::from_rows(rows, schema)))
        .expect("register_table");

    let batches = ctx
        .sql("SELECT * FROM t_filter WHERE id = 2")
        .await
        .expect("sql parse")
        .collect()
        .await
        .expect("collect");

    let row_count = batches
        .iter()
        .fold(0usize, |acc, batch| acc + batch.num_rows());
    assert_eq!(row_count, 1, "WHERE id = 2 should return exactly 1 row");
}
#[tokio::test]
async fn test_range_partition_sorts_and_splits() {
let schema = Arc::new(arrow::datatypes::Schema::new(vec![
arrow::datatypes::Field::new("id", DataType::Int64, false),
]));
let col = vec!["id".to_string()];
let ids: Vec<i64> = vec![9, 3, 7, 1, 5, 2, 8, 4, 6];
let rows: Vec<Row> = ids
.iter()
.map(|&i| Row::new(col.clone(), vec![Value::I64(i)]))
.collect();
let provider =
OxiSqlTableProvider::from_rows(rows, Arc::clone(&schema)).with_range_partition("id", 3);
assert_eq!(provider.len(), 9, "with_range_partition must not drop rows");
let ctx = SessionContext::new();
ctx.register_table("range_test", Arc::new(provider))
.expect("register_table");
let df = ctx
.sql("SELECT id FROM range_test ORDER BY id")
.await
.expect("sql parse");
let batches = df.collect().await.expect("collect");
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(
total_rows, 9,
"all 9 rows must be returned after partitioning"
);
use arrow::array::Int64Array;
let all_ids: Vec<i64> = batches
.iter()
.flat_map(|b| {
b.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.expect("id column should be Int64")
.values()
.to_vec()
})
.collect();
let mut sorted_ids = all_ids.clone();
sorted_ids.sort_unstable();
assert_eq!(
all_ids, sorted_ids,
"ids should be returned in ascending order"
);
}
/// A single range partition must still expose every row of the shared test snapshot.
#[tokio::test]
async fn test_range_partition_single() {
    let (rows, schema) = common::make_test_rows();
    let ctx = SessionContext::new();
    ctx.register_table(
        "single_part",
        Arc::new(OxiSqlTableProvider::from_rows(rows, schema).with_range_partition("id", 1)),
    )
    .expect("register_table");

    let batches = ctx
        .sql("SELECT id FROM single_part ORDER BY id")
        .await
        .expect("sql parse")
        .collect()
        .await
        .expect("collect");

    let mut total_rows = 0usize;
    for batch in &batches {
        total_rows += batch.num_rows();
    }
    assert_eq!(total_rows, 2, "n_partitions=1 must still return all rows");
}
#[tokio::test]
async fn test_window_function_row_number() {
use oxisql_datafusion::OxiSqlContext;
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("name", DataType::Utf8, false),
Field::new("salary", DataType::Int64, false),
]));
let cols: Vec<String> = vec!["id".into(), "name".into(), "salary".into()];
let rows = vec![
Row::new(
cols.clone(),
vec![Value::I64(1), Value::Text("Alice".into()), Value::I64(1000)],
),
Row::new(
cols.clone(),
vec![Value::I64(2), Value::Text("Bob".into()), Value::I64(2000)],
),
Row::new(
cols.clone(),
vec![Value::I64(3), Value::Text("Carol".into()), Value::I64(1500)],
),
];
let ctx = OxiSqlContext::new();
ctx.register_snapshot("employees", rows, schema)
.expect("register_snapshot");
let results = ctx
.execute_sql(
"SELECT name, salary, \
ROW_NUMBER() OVER (ORDER BY salary DESC) AS rn \
FROM employees",
)
.await
.expect("execute_sql");
assert_eq!(
results.len(),
1,
"window query should return a single batch"
);
assert_eq!(results[0].num_rows(), 3, "expected 3 employee rows");
assert_eq!(
results[0].num_columns(),
3,
"result should have name, salary, and rn columns"
);
}
/// 1000 rows with a 100-row threshold must be split into more than one partition.
#[test]
fn test_auto_partition_splits_large_snapshot() {
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
    let rows: Vec<Row> = (0i64..1000)
        .map(|i| Row::new(vec!["id".to_string()], vec![Value::I64(i)]))
        .collect();

    let parts = OxiSqlTableProvider::from_rows(rows, schema)
        .with_auto_partition(4, 100)
        .partition_count();
    assert!(
        parts > 1,
        "with_auto_partition(4, 100) on 1000 rows must create more than one partition; \
         got partition_count() = {}",
        parts
    );
}
/// 10 rows sit below the 100-row threshold, so the snapshot must stay in one partition.
#[test]
fn test_auto_partition_single_for_small_snapshot() {
    let id_col = vec!["id".to_string()];
    let mut rows: Vec<Row> = Vec::with_capacity(10);
    for i in 0i64..10 {
        rows.push(Row::new(id_col.clone(), vec![Value::I64(i)]));
    }
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));

    let count = OxiSqlTableProvider::from_rows(rows, schema)
        .with_auto_partition(4, 100)
        .partition_count();
    assert_eq!(
        count,
        1,
        "with_auto_partition(4, 100) on 10 rows must remain a single partition; \
         got partition_count() = {}",
        count
    );
}