use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use arrow::array::Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use arrow_flight::sql::client::FlightSqlServiceClient;
use futures::TryStreamExt;
use rhei_core::OlapEngine;
use rhei_olap::OlapBackend;
use tonic::transport::Channel;
/// Bind an ephemeral localhost port, serve the Flight SQL service for `olap`
/// on it in a background task, and return the bound address.
///
/// The listener is bound *before* the task is spawned, so incoming
/// connections queue in the OS backlog even if the accept loop is not yet
/// running; the short sleep just gives the spawned task time to start.
async fn start_server(olap: OlapBackend) -> SocketAddr {
    let tcp = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let bound = tcp.local_addr().unwrap();
    let server = arrow_flight::flight_service_server::FlightServiceServer::new(
        rhei_flight::RheiFlightSqlService::new(olap),
    );
    tokio::spawn(async move {
        let incoming = tokio_stream::wrappers::TcpListenerStream::new(tcp);
        tonic::transport::Server::builder()
            .add_service(server)
            .serve_with_incoming(incoming)
            .await
            .unwrap();
    });
    // Grace period for the server task to begin accepting.
    tokio::time::sleep(Duration::from_millis(50)).await;
    bound
}
/// Open a plaintext HTTP/2 channel to `addr` and wrap it in a
/// Flight SQL client. Panics if the connection cannot be established.
async fn connect(addr: SocketAddr) -> FlightSqlServiceClient<Channel> {
    let endpoint = Channel::from_shared(format!("http://{addr}")).unwrap();
    let channel = endpoint.connect().await.unwrap();
    FlightSqlServiceClient::new(channel)
}
/// Fetch every endpoint of a `FlightInfo` via `do_get` and gather all
/// record batches into a single `Vec`. Endpoints without a ticket are
/// skipped; any transport or decode error panics (test helper).
async fn collect_batches(
    client: &mut FlightSqlServiceClient<Channel>,
    info: arrow_flight::FlightInfo,
) -> Vec<RecordBatch> {
    let mut out = Vec::new();
    for ticket in info.endpoint.into_iter().filter_map(|e| e.ticket) {
        let mut fetched: Vec<RecordBatch> = client
            .do_get(ticket)
            .await
            .unwrap()
            .try_collect()
            .await
            .unwrap();
        out.append(&mut fetched);
    }
    out
}
// End-to-end tests against the DataFusion backend: query shapes, NULLs,
// prepared statements, auth-free writes rejection, and concurrency.
#[cfg(feature = "datafusion-backend")]
mod datafusion_tests {
use super::*;
use rhei_olap::{DataFusionEngine, SharedDataFusionEngine};
// Start a Flight SQL server over a DataFusion engine seeded with a 4-row
// `students(id, name, score)` table; `score` is nullable and row 4 (Dave)
// carries a NULL score. Returns the server address plus the backend handle
// so tests can also query locally if needed.
async fn setup() -> (SocketAddr, OlapBackend) {
let engine = SharedDataFusionEngine::new(DataFusionEngine::new());
let olap = OlapBackend::DataFusion(engine.clone());
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("name", DataType::Utf8, false),
Field::new("score", DataType::Float64, true),
]));
engine.create_table("students", &schema, &[]).await.unwrap();
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(arrow::array::Int64Array::from(vec![1, 2, 3, 4])),
Arc::new(arrow::array::StringArray::from(vec![
"Alice", "Bob", "Carol", "Dave",
])),
Arc::new(arrow::array::Float64Array::from(vec![
Some(95.5),
Some(87.3),
Some(92.1),
None, // Dave's score is NULL — exercised by test_null_handling
])),
],
)
.unwrap();
engine.load_arrow("students", &[batch]).await.unwrap();
let addr = start_server(olap.clone()).await;
(addr, olap)
}
// The Flight handshake should succeed on an auth-free server regardless of
// the credentials supplied.
#[tokio::test]
async fn test_handshake() {
let (addr, _) = setup().await;
let mut client = connect(addr).await;
let result = client.handshake("user", "pass").await;
assert!(result.is_ok());
}
// Full-table scan returns all 4 seeded rows with all 3 columns.
#[tokio::test]
async fn test_simple_select() {
let (addr, _) = setup().await;
let mut client = connect(addr).await;
let info = client
.execute("SELECT * FROM students".into(), None)
.await
.unwrap();
let batches = collect_batches(&mut client, info).await;
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(total_rows, 4);
assert_eq!(batches[0].num_columns(), 3);
}
// WHERE filter narrows to one row; verify the projected value round-trips.
#[tokio::test]
async fn test_select_with_where() {
let (addr, _) = setup().await;
let mut client = connect(addr).await;
let info = client
.execute("SELECT name FROM students WHERE id = 1".into(), None)
.await
.unwrap();
let batches = collect_batches(&mut client, info).await;
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(total_rows, 1);
let col = batches[0]
.column(0)
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
assert_eq!(col.value(0), "Alice");
}
// COUNT/AVG aggregates collapse to a single row; COUNT(*) covers NULLs too.
#[tokio::test]
async fn test_aggregate_query() {
let (addr, _) = setup().await;
let mut client = connect(addr).await;
let info = client
.execute(
"SELECT COUNT(*) as cnt, AVG(score) as avg_score FROM students".into(),
None,
)
.await
.unwrap();
let batches = collect_batches(&mut client, info).await;
assert_eq!(batches[0].num_rows(), 1);
let cnt = batches[0]
.column(0)
.as_any()
.downcast_ref::<arrow::array::Int64Array>()
.unwrap();
assert_eq!(cnt.value(0), 4);
}
// A NULL in the source data must survive the Flight round-trip as a NULL.
#[tokio::test]
async fn test_null_handling() {
let (addr, _) = setup().await;
let mut client = connect(addr).await;
let info = client
.execute("SELECT score FROM students WHERE id = 4".into(), None)
.await
.unwrap();
let batches = collect_batches(&mut client, info).await;
let col = batches[0]
.column(0)
.as_any()
.downcast_ref::<arrow::array::Float64Array>()
.unwrap();
assert!(col.is_null(0));
}
// A predicate matching nothing yields zero total rows (batches may still
// be present but empty).
#[tokio::test]
async fn test_empty_result() {
let (addr, _) = setup().await;
let mut client = connect(addr).await;
let info = client
.execute("SELECT * FROM students WHERE id = 999".into(), None)
.await
.unwrap();
let batches = collect_batches(&mut client, info).await;
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(total_rows, 0);
}
// Result batches carry the table schema with original field names/order.
#[tokio::test]
async fn test_schema_from_results() {
let (addr, _) = setup().await;
let mut client = connect(addr).await;
let info = client
.execute("SELECT * FROM students".into(), None)
.await
.unwrap();
let batches = collect_batches(&mut client, info).await;
let schema = batches[0].schema();
assert_eq!(schema.fields().len(), 3);
assert_eq!(schema.field(0).name(), "id");
assert_eq!(schema.field(1).name(), "name");
assert_eq!(schema.field(2).name(), "score");
}
// Prepared-statement lifecycle: prepare -> execute -> close. Two students
// have score > 90 (Alice 95.5, Carol 92.1).
#[tokio::test]
async fn test_prepared_statement() {
let (addr, _) = setup().await;
let mut client = connect(addr).await;
let mut stmt = client
.prepare(
"SELECT name, score FROM students WHERE score > 90".into(),
None,
)
.await
.unwrap();
let info = stmt.execute().await.unwrap();
let batches = collect_batches(&mut client, info).await;
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(total_rows, 2);
stmt.close().await.unwrap();
}
// The service is read-only: DML via execute_update must be rejected.
#[tokio::test]
async fn test_write_rejected() {
let (addr, _) = setup().await;
let mut client = connect(addr).await;
let result = client
.execute_update("INSERT INTO students VALUES (5, 'Eve', 88.0)".into(), None)
.await;
assert!(result.is_err());
}
// NOTE(review): get_flight_info on an unknown table apparently succeeds and
// the error only surfaces at do_get time — this test pins that behavior.
#[tokio::test]
async fn test_invalid_sql() {
let (addr, _) = setup().await;
let mut client = connect(addr).await;
let info = client
.execute("SELECT * FROM nonexistent_table".into(), None)
.await
.unwrap();
for endpoint in info.endpoint {
if let Some(ticket) = endpoint.ticket {
let result = client.do_get(ticket).await;
assert!(result.is_err(), "do_get should fail for invalid SQL");
}
}
}
// Eight clients query concurrently against one server; `id >= i % 4` always
// matches at least one of ids 1..4, so every task must see rows.
#[tokio::test]
async fn test_concurrent_queries() {
let (addr, _) = setup().await;
let mut handles = Vec::new();
for i in 0..8 {
let addr = addr;
handles.push(tokio::spawn(async move {
let mut client = connect(addr).await;
let sql = format!("SELECT * FROM students WHERE id >= {}", i % 4);
let info = client.execute(sql, None).await.unwrap();
let batches = collect_batches(&mut client, info).await;
let rows: usize = batches.iter().map(|b| b.num_rows()).sum();
assert!(rows > 0);
}));
}
for h in handles {
h.await.unwrap();
}
}
// Two tables joined over Flight: users(1, Alice) x scores(1, 100) yields a
// single row whose name column round-trips intact.
#[tokio::test]
async fn test_multiple_tables() {
let engine = SharedDataFusionEngine::new(DataFusionEngine::new());
let olap = OlapBackend::DataFusion(engine.clone());
let schema1 = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("name", DataType::Utf8, false),
]));
let schema2 = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("value", DataType::Int64, false),
]));
engine.create_table("users", &schema1, &[]).await.unwrap();
engine.create_table("scores", &schema2, &[]).await.unwrap();
let users_batch = RecordBatch::try_new(
schema1,
vec![
Arc::new(arrow::array::Int64Array::from(vec![1])),
Arc::new(arrow::array::StringArray::from(vec!["Alice"])),
],
)
.unwrap();
engine.load_arrow("users", &[users_batch]).await.unwrap();
let scores_batch = RecordBatch::try_new(
schema2,
vec![
Arc::new(arrow::array::Int64Array::from(vec![1])),
Arc::new(arrow::array::Int64Array::from(vec![100])),
],
)
.unwrap();
engine.load_arrow("scores", &[scores_batch]).await.unwrap();
let addr = start_server(olap).await;
let mut client = connect(addr).await;
let info = client
.execute(
"SELECT u.name, s.value FROM users u JOIN scores s ON u.id = s.id".into(),
None,
)
.await
.unwrap();
let batches = collect_batches(&mut client, info).await;
assert_eq!(batches[0].num_rows(), 1);
let name = batches[0]
.column(0)
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
assert_eq!(name.value(0), "Alice");
}
}
// Smoke tests for the DuckDB backend: same service surface as the
// DataFusion suite, but rows are loaded with SQL INSERTs instead of Arrow
// batches (exercising DuckDB's own ingest path).
#[cfg(feature = "duckdb-backend")]
mod duckdb_tests {
use super::*;
use rhei_olap::{DuckDbEngine, SharedDuckDbEngine};
// Start a Flight SQL server over an in-memory DuckDB engine seeded with a
// 3-row `students(id, name, score)` table (no NULL scores here).
async fn setup() -> (SocketAddr, OlapBackend) {
let engine = SharedDuckDbEngine::new(DuckDbEngine::in_memory().unwrap());
let olap = OlapBackend::DuckDb(engine.clone());
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("name", DataType::Utf8, false),
Field::new("score", DataType::Float64, true),
]));
engine.create_table("students", &schema, &[]).await.unwrap();
engine
.execute("INSERT INTO students VALUES (1, 'Alice', 95.5)")
.await
.unwrap();
engine
.execute("INSERT INTO students VALUES (2, 'Bob', 87.3)")
.await
.unwrap();
engine
.execute("INSERT INTO students VALUES (3, 'Carol', 92.1)")
.await
.unwrap();
let addr = start_server(olap.clone()).await;
(addr, olap)
}
// Full scan returns all 3 inserted rows.
#[tokio::test]
async fn test_simple_select() {
let (addr, _) = setup().await;
let mut client = connect(addr).await;
let info = client
.execute("SELECT * FROM students".into(), None)
.await
.unwrap();
let batches = collect_batches(&mut client, info).await;
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(total_rows, 3);
}
// SUM aggregate collapses to a single result row.
#[tokio::test]
async fn test_aggregate_query() {
let (addr, _) = setup().await;
let mut client = connect(addr).await;
let info = client
.execute("SELECT SUM(score) as total FROM students".into(), None)
.await
.unwrap();
let batches = collect_batches(&mut client, info).await;
assert_eq!(batches[0].num_rows(), 1);
}
// Prepared-statement lifecycle (prepare -> execute -> close) over DuckDB.
#[tokio::test]
async fn test_prepared_statement() {
let (addr, _) = setup().await;
let mut client = connect(addr).await;
let mut stmt = client
.prepare("SELECT name FROM students ORDER BY name".into(), None)
.await
.unwrap();
let info = stmt.execute().await.unwrap();
let batches = collect_batches(&mut client, info).await;
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(total_rows, 3);
stmt.close().await.unwrap();
}
// Writes over Flight are rejected even though the engine itself accepts
// INSERTs (see setup) — the read-only policy lives in the Flight service.
#[tokio::test]
async fn test_write_rejected() {
let (addr, _) = setup().await;
let mut client = connect(addr).await;
let result = client
.execute_update("DELETE FROM students WHERE id = 1".into(), None)
.await;
assert!(result.is_err());
}
}
// Bearer-token auth tests: valid/missing/wrong token, the no-auth default,
// and that do_get is protected, not just get_flight_info.
#[cfg(feature = "datafusion-backend")]
mod auth_tests {
use super::*;
use rhei_olap::{DataFusionEngine, SharedDataFusionEngine};
// Like start_server, but the service is configured with a required bearer
// token via with_auth_token. Seeds a 1-row table `t(id, val)` = (1, 42).
async fn start_authed_server(token: &str) -> SocketAddr {
let engine = SharedDataFusionEngine::new(DataFusionEngine::new());
let olap = OlapBackend::DataFusion(engine.clone());
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("val", DataType::Int64, false),
]));
engine.create_table("t", &schema, &[]).await.unwrap();
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(arrow::array::Int64Array::from(vec![1_i64])),
Arc::new(arrow::array::Int64Array::from(vec![42_i64])),
],
)
.unwrap();
engine.load_arrow("t", &[batch]).await.unwrap();
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let service =
rhei_flight::RheiFlightSqlService::new(olap).with_auth_token(token.to_string());
let svc = arrow_flight::flight_service_server::FlightServiceServer::new(service);
tokio::spawn(async move {
tonic::transport::Server::builder()
.add_service(svc)
.serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
.await
.unwrap();
});
// Grace period for the spawned server task to start accepting.
tokio::time::sleep(Duration::from_millis(50)).await;
addr
}
// Connect and attach a bearer token to all subsequent client requests.
async fn connect_with_token(addr: SocketAddr, token: &str) -> FlightSqlServiceClient<Channel> {
let mut client = connect(addr).await;
client.set_token(token.to_string());
client
}
// The matching token lets a query through and data round-trips normally.
#[tokio::test]
async fn test_valid_token_allows_query() {
let addr = start_authed_server("secret").await;
let mut client = connect_with_token(addr, "secret").await;
let info = client
.execute("SELECT * FROM t".into(), None)
.await
.expect("query with valid token should succeed");
let batches = collect_batches(&mut client, info).await;
let rows: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(rows, 1);
}
// No token at all -> rejected, with an unauthenticated/bearer error.
#[tokio::test]
async fn test_missing_token_rejected() {
let addr = start_authed_server("secret").await;
let mut client = connect(addr).await;
let result = client.execute("SELECT * FROM t".into(), None).await;
assert!(result.is_err(), "request without token should be rejected");
let err = result.unwrap_err().to_string().to_lowercase();
assert!(
err.contains("unauthenticated") || err.contains("bearer"),
"error should mention unauthenticated/bearer, got: {err}"
);
}
// A present-but-wrong token is rejected the same way as a missing one.
#[tokio::test]
async fn test_wrong_token_rejected() {
let addr = start_authed_server("secret").await;
let mut client = connect_with_token(addr, "wrong-token").await;
let result = client.execute("SELECT * FROM t".into(), None).await;
assert!(
result.is_err(),
"request with wrong token should be rejected"
);
let err = result.unwrap_err().to_string().to_lowercase();
assert!(
err.contains("unauthenticated") || err.contains("bearer"),
"error should mention unauthenticated/bearer, got: {err}"
);
}
// A server built without with_auth_token accepts tokenless requests.
#[tokio::test]
async fn test_no_auth_server_accepts_any_request() {
let engine = SharedDataFusionEngine::new(DataFusionEngine::new());
let olap = OlapBackend::DataFusion(engine.clone());
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
engine.create_table("empty", &schema, &[]).await.unwrap();
let addr = start_server(olap).await;
let mut client = connect(addr).await;
let info = client
.execute("SELECT * FROM empty".into(), None)
.await
.expect("no-auth server should accept requests without a token");
let batches = collect_batches(&mut client, info).await;
let rows: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(rows, 0);
}
// A ticket obtained by an authed client must not be redeemable by an
// unauthenticated client — do_get itself enforces the token.
#[tokio::test]
async fn test_do_get_also_protected() {
let addr = start_authed_server("s3cr3t").await;
let mut auth_client = connect_with_token(addr, "s3cr3t").await;
let info = auth_client
.execute("SELECT * FROM t".into(), None)
.await
.expect("get_flight_info with correct token should succeed");
let mut unauth_client = connect(addr).await;
for endpoint in info.endpoint {
if let Some(ticket) = endpoint.ticket {
let result = unauth_client.do_get(ticket).await;
assert!(result.is_err(), "do_get without token should be rejected");
}
}
}
}
// Opt-in benchmark comparing direct OLAP queries with the same queries over
// the Flight SQL transport. Skipped unless RHEI_TEST_BENCH is set; results
// go to stderr only (no assertions on timings).
#[cfg(feature = "datafusion-backend")]
mod bench {
use super::*;
use rhei_olap::{DataFusionEngine, SharedDataFusionEngine};
use std::time::Instant;
// Build an `orders` table with n synthetic rows (deterministic ids, 1000
// distinct user_ids, 5 categories x 4 regions), loaded in chunks of up to
// 10_000 rows, then start a Flight server over it.
async fn setup_large(n: usize) -> (SocketAddr, OlapBackend) {
let engine = SharedDataFusionEngine::new(DataFusionEngine::new());
let olap = OlapBackend::DataFusion(engine.clone());
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("user_id", DataType::Int64, false),
Field::new("amount", DataType::Float64, false),
Field::new("category", DataType::Utf8, false),
Field::new("region", DataType::Utf8, false),
]));
engine.create_table("orders", &schema, &[]).await.unwrap();
let categories = ["electronics", "clothing", "food", "books", "sports"];
let regions = ["north", "south", "east", "west"];
let batch_size = 10_000.min(n);
let mut loaded = 0;
while loaded < n {
let chunk = batch_size.min(n - loaded);
let ids: Vec<i64> = (loaded..loaded + chunk).map(|i| i as i64).collect();
let user_ids: Vec<i64> = ids.iter().map(|i| i % 1000).collect();
let amounts: Vec<f64> = ids.iter().map(|i| (*i as f64) * 1.23 % 999.99).collect();
let cats: Vec<&str> = ids.iter().map(|i| categories[*i as usize % 5]).collect();
let regs: Vec<&str> = ids.iter().map(|i| regions[*i as usize % 4]).collect();
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(arrow::array::Int64Array::from(ids)),
Arc::new(arrow::array::Int64Array::from(user_ids)),
Arc::new(arrow::array::Float64Array::from(amounts)),
Arc::new(arrow::array::StringArray::from(cats)),
Arc::new(arrow::array::StringArray::from(regs)),
],
)
.unwrap();
engine.load_arrow("orders", &[batch]).await.unwrap();
loaded += chunk;
}
let addr = start_server(olap.clone()).await;
(addr, olap)
}
// Sum of row counts across batches.
fn total_rows(batches: &[RecordBatch]) -> usize {
batches.iter().map(|b| b.num_rows()).sum()
}
// Approximate in-memory payload size across batches (Arrow array memory).
fn total_bytes(batches: &[RecordBatch]) -> usize {
batches.iter().map(|b| b.get_array_memory_size()).sum()
}
// Run `sql` `iterations` times through both paths — directly against the
// backend and over Flight — and print q/s, MB/s, and the Flight/local
// overhead ratio. One untimed warm-up per path precedes the measurement.
async fn bench_query(
olap: &OlapBackend,
client: &mut FlightSqlServiceClient<Channel>,
sql: &str,
label: &str,
iterations: usize,
) {
// Warm-up: one untimed run on each path (plans get cached, connections
// get established) so the timed loops measure steady state.
let _ = olap.query(sql).await.unwrap();
let info = client.execute(sql.to_string(), None).await.unwrap();
let _ = collect_batches(client, info).await;
let start = Instant::now();
let mut local_rows = 0;
let mut local_bytes = 0;
for _ in 0..iterations {
let batches = olap.query(sql).await.unwrap();
local_rows += total_rows(&batches);
local_bytes += total_bytes(&batches);
}
let local_elapsed = start.elapsed();
let start = Instant::now();
let mut flight_rows = 0;
let mut flight_bytes = 0;
for _ in 0..iterations {
let info = client.execute(sql.to_string(), None).await.unwrap();
let batches = collect_batches(client, info).await;
flight_rows += total_rows(&batches);
flight_bytes += total_bytes(&batches);
}
let flight_elapsed = start.elapsed();
let local_qps = iterations as f64 / local_elapsed.as_secs_f64();
let flight_qps = iterations as f64 / flight_elapsed.as_secs_f64();
let local_mbps = local_bytes as f64 / local_elapsed.as_secs_f64() / 1_048_576.0;
let flight_mbps = flight_bytes as f64 / flight_elapsed.as_secs_f64() / 1_048_576.0;
// Ratio > 1.0 means Flight is slower than the local path.
let overhead = flight_elapsed.as_secs_f64() / local_elapsed.as_secs_f64();
eprintln!(" {label}:");
eprintln!(
" Local: {local_qps:>8.0} q/s {local_mbps:>8.1} MB/s ({local_rows} rows in {:.1}ms)",
local_elapsed.as_secs_f64() * 1000.0
);
eprintln!(
" Flight: {flight_qps:>8.0} q/s {flight_mbps:>8.1} MB/s ({flight_rows} rows in {:.1}ms)",
flight_elapsed.as_secs_f64() * 1000.0
);
eprintln!(" Overhead: {overhead:.2}x");
}
// Benchmark driver: 100K-row table, five representative query shapes with
// iteration counts scaled inversely to result size.
#[tokio::test]
async fn bench_olap_vs_flight() {
if std::env::var("RHEI_TEST_BENCH").is_err() {
eprintln!("skipping OLAP benchmark (set RHEI_TEST_BENCH=1 to run)");
return;
}
let n = 100_000;
eprintln!("\n=== OLAP Benchmark: Local vs FlightSQL ({n} rows) ===\n");
let (addr, olap) = setup_large(n).await;
let mut client = connect(addr).await;
bench_query(
&olap,
&mut client,
"SELECT * FROM orders",
"Full scan (100K rows, 5 cols)",
10,
)
.await;
bench_query(
&olap,
&mut client,
"SELECT category, COUNT(*) as cnt, AVG(amount) as avg_amt FROM orders GROUP BY category",
"GROUP BY aggregate (5 result rows)",
100,
)
.await;
bench_query(
&olap,
&mut client,
"SELECT * FROM orders WHERE region = 'north' AND amount > 500.0",
"Filtered scan (~6K rows)",
50,
)
.await;
bench_query(
&olap,
&mut client,
"SELECT user_id, SUM(amount) as total FROM orders GROUP BY user_id ORDER BY total DESC LIMIT 10",
"Top-10 users by spend",
100,
)
.await;
bench_query(
&olap,
&mut client,
"SELECT category, region, SUM(amount) as total, \
RANK() OVER (PARTITION BY category ORDER BY SUM(amount) DESC) as rank \
FROM orders GROUP BY category, region",
"Window function (rank by category)",
50,
)
.await;
eprintln!("\n=== Benchmark complete ===\n");
}
}