use sql_cli::data::data_view::DataView;
use sql_cli::data::datatable::{DataColumn, DataRow, DataTable, DataValue};
use sql_cli::data::query_engine::QueryEngine;
use std::sync::Arc;
/// Returns an owned copy of the cell at (`row_idx`, `col_idx`) in a query result.
///
/// # Panics
///
/// Panics if `view` has no row at `row_idx`, or if that row has no value at `col_idx`.
/// The messages name which of the two lookups failed, rather than leaving a bare
/// `unwrap` panic at the call site.
fn get_value(view: &DataView, row_idx: usize, col_idx: usize) -> DataValue {
    view.get_row(row_idx)
        .expect("get_value: result view has no row at the requested row index")
        .get(col_idx)
        .expect("get_value: result row has no value at the requested column index")
        .clone()
}
/// Builds the `test_math` fixture table shared by every test in this file.
///
/// Columns, in order: `id`, `quantity`, `price`, `discount`, `negative`.
/// Three rows with ids 1..=3. The `negative` column holds `Float`s in rows 1 and 3,
/// but `Integer(-42)` in row 2, so the functions under test see both numeric variants.
fn create_test_table() -> Arc<DataTable> {
    let mut table = DataTable::new("test_math");
    for column_name in ["id", "quantity", "price", "discount", "negative"] {
        table.add_column(DataColumn::new(column_name));
    }

    let fixture_rows = vec![
        vec![
            DataValue::Integer(1),
            DataValue::Integer(10),
            DataValue::Float(25.456),
            DataValue::Float(2.5),
            DataValue::Float(-15.789),
        ],
        vec![
            DataValue::Integer(2),
            DataValue::Integer(7),
            DataValue::Float(99.999),
            DataValue::Float(10.0),
            DataValue::Integer(-42),
        ],
        vec![
            DataValue::Integer(3),
            DataValue::Integer(3),
            DataValue::Float(15.234),
            DataValue::Float(1.567),
            DataValue::Float(-3.14159),
        ],
    ];
    for values in fixture_rows {
        table.add_row(DataRow::new(values)).unwrap();
    }

    Arc::new(table)
}
/// `ROUND(x)` without a precision rounds to the nearest whole number and yields `Integer`s.
#[test]
fn test_round_basic() {
    let engine = QueryEngine::new();
    let result = engine
        .execute(
            create_test_table(),
            "SELECT id, ROUND(price) as rounded_price FROM test_math",
        )
        .unwrap();

    assert_eq!(result.row_count(), 3);
    let expected = [
        DataValue::Integer(25),
        DataValue::Integer(100),
        DataValue::Integer(15),
    ];
    for (row, want) in expected.iter().enumerate() {
        assert_eq!(&get_value(&result, row, 1), want);
    }
}
/// `ROUND(x, 2)` keeps two decimal places and yields `Float`s.
#[test]
fn test_round_with_decimals() {
    let engine = QueryEngine::new();
    let sql = "SELECT id, ROUND(price, 2) as rounded_price FROM test_math";
    let view = engine.execute(create_test_table(), sql).unwrap();

    assert_eq!(view.row_count(), 3);
    for (row_idx, &rounded) in [25.46, 100.0, 15.23].iter().enumerate() {
        assert_eq!(get_value(&view, row_idx, 1), DataValue::Float(rounded));
    }
}
/// `ROUND` accepts an arithmetic expression, not just a column, as its first argument.
#[test]
fn test_round_with_nested_expression() {
    let query_engine = QueryEngine::new();
    let output = query_engine
        .execute(
            create_test_table(),
            "SELECT id, ROUND(quantity * price / 100, 3) as result FROM test_math",
        )
        .unwrap();

    assert_eq!(output.row_count(), 3);
    // Unrounded per row: 2.5456, 6.99993, 0.45702.
    let checks = [(0, 2.546), (1, 7.0), (2, 0.457)];
    for &(row, value) in checks.iter() {
        assert_eq!(get_value(&output, row, 1), DataValue::Float(value));
    }
}
/// `ABS` keeps the input's numeric variant: `Float` stays `Float`, `Integer` stays `Integer`.
#[test]
fn test_abs_function() {
    let table = create_test_table();
    let engine = QueryEngine::new();
    let view = engine
        .execute(table, "SELECT id, ABS(negative) as abs_value FROM test_math")
        .unwrap();

    assert_eq!(view.row_count(), 3);
    let expected = vec![
        DataValue::Float(15.789),
        DataValue::Integer(42),
        DataValue::Float(3.14159),
    ];
    let actual: Vec<DataValue> = (0..3).map(|row| get_value(&view, row, 1)).collect();
    assert_eq!(actual, expected);
}
/// `FLOOR(price)` drops the fractional part of each (positive) price and yields `Integer`s.
#[test]
fn test_floor_function() {
    let engine = QueryEngine::new();
    let floored = engine
        .execute(
            create_test_table(),
            "SELECT id, FLOOR(price) as floor_price FROM test_math",
        )
        .unwrap();

    assert_eq!(floored.row_count(), 3);
    let expected_floors = [25, 99, 15];
    for (row, &floor) in expected_floors.iter().enumerate() {
        assert_eq!(get_value(&floored, row, 1), DataValue::Integer(floor));
    }
}
/// `CEILING(price)` rounds each price up to the next whole number and yields `Integer`s.
#[test]
fn test_ceiling_function() {
    let table = create_test_table();
    let engine = QueryEngine::new();
    let view = engine
        .execute(table, "SELECT id, CEILING(price) as ceil_price FROM test_math")
        .unwrap();

    assert_eq!(view.row_count(), 3);
    let ceilings: Vec<DataValue> = (0..view.row_count())
        .map(|row| get_value(&view, row, 1))
        .collect();
    assert_eq!(
        ceilings,
        vec![
            DataValue::Integer(26),
            DataValue::Integer(100),
            DataValue::Integer(16),
        ]
    );
}
/// The short `CEIL` spelling rounds each discount up to a whole `Integer`.
/// The whole-valued 10.0 maps to 10.
#[test]
fn test_ceil_alias_function() {
    let engine = QueryEngine::new();
    let sql = "SELECT id, CEIL(discount) as ceil_discount FROM test_math";
    let view = engine.execute(create_test_table(), sql).unwrap();

    assert_eq!(view.row_count(), 3);
    let expected: Vec<DataValue> = [3, 10, 2].iter().map(|&n| DataValue::Integer(n)).collect();
    for (row, want) in expected.iter().enumerate() {
        assert_eq!(&get_value(&view, row, 1), want);
    }
}
/// A function call can drive a `WHERE` filter: only price 99.999 rounds above 50.
#[test]
fn test_functions_in_where_clause() {
    let engine = QueryEngine::new();
    let filtered = engine
        .execute(
            create_test_table(),
            "SELECT id, price FROM test_math WHERE ROUND(price) > 50",
        )
        .unwrap();

    assert_eq!(filtered.row_count(), 1);
    let surviving_id = get_value(&filtered, 0, 0);
    assert_eq!(surviving_id, DataValue::Integer(2));
}
/// Functions compose (`ROUND(ABS(x), 1)`). The `Integer` input in row 2 stays `Integer`
/// through both calls.
#[test]
fn test_nested_functions() {
    let table = create_test_table();
    let engine = QueryEngine::new();
    let sql = "SELECT id, ROUND(ABS(negative), 1) as result FROM test_math";
    let view = engine.execute(table, sql).unwrap();

    assert_eq!(view.row_count(), 3);
    let expectations = [
        (0, DataValue::Float(15.8)),
        (1, DataValue::Integer(42)),
        (2, DataValue::Float(3.1)),
    ];
    for (row, expected) in expectations.iter() {
        assert_eq!(get_value(&view, *row, 1), *expected);
    }
}
/// A computed, aliased column (`net_price`) can be used in `ORDER BY`.
/// Rows come back sorted by descending net price.
#[test]
fn test_functions_with_order_by() {
    let engine = QueryEngine::new();
    let view = engine
        .execute(
            create_test_table(),
            "SELECT id, ROUND(price - discount, 2) as net_price
FROM test_math
ORDER BY net_price DESC",
        )
        .unwrap();

    assert_eq!(view.row_count(), 3);
    // (id, net_price) in the expected DESC order.
    let expected_rows = [(2, 90.0), (1, 22.96), (3, 13.67)];
    for (row, &(id, net)) in expected_rows.iter().enumerate() {
        assert_eq!(get_value(&view, row, 0), DataValue::Integer(id));
        assert_eq!(get_value(&view, row, 1), DataValue::Float(net));
    }
}
/// Multiplication, subtraction and two function calls combined in one projected expression.
#[test]
fn test_complex_expression_with_functions() {
    let engine = QueryEngine::new();
    let view = engine
        .execute(
            create_test_table(),
            "SELECT id, quantity * ROUND(price, 1) - ABS(negative) as complex_calc
FROM test_math",
        )
        .unwrap();

    assert_eq!(view.row_count(), 3);
    // 10 * 25.5 - 15.789 and 7 * 100.0 - 42.
    assert_eq!(get_value(&view, 0, 1), DataValue::Float(239.211));
    assert_eq!(get_value(&view, 1, 1), DataValue::Float(658.0));

    // 3 * 15.2 - 3.14159: compared with a tolerance rather than exact equality.
    match get_value(&view, 2, 1) {
        DataValue::Float(f) => assert!((f - 42.45841).abs() < 0.00001),
        _ => panic!("Expected Float value"),
    }
}
/// `SELECT *` can be combined with an extra computed column.
/// The computed column comes after the five source columns.
#[test]
fn test_functions_with_select_star() {
    let engine = QueryEngine::new();
    let view = engine
        .execute(
            create_test_table(),
            "SELECT *, ROUND(quantity * price, 2) as total FROM test_math WHERE id = 1",
        )
        .unwrap();

    assert_eq!(view.row_count(), 1);
    // Five fixture columns plus `total`.
    assert_eq!(view.column_count(), 6);
    let total = get_value(&view, 0, 5);
    assert_eq!(total, DataValue::Float(254.56));
}
/// `MOD(a, b)` returns the integer remainder. It is checked against two different columns.
#[test]
fn test_mod_function() {
    let engine = QueryEngine::new();
    let view = engine
        .execute(
            create_test_table(),
            "SELECT id, MOD(id, 2) as even_odd, MOD(quantity, 3) as q_mod_3 FROM test_math",
        )
        .unwrap();

    assert_eq!(view.row_count(), 3);
    // Per row: (id % 2, quantity % 3).
    let remainders = [(1, 1), (0, 1), (1, 0)];
    for (row, &(id_mod_2, qty_mod_3)) in remainders.iter().enumerate() {
        assert_eq!(get_value(&view, row, 1), DataValue::Integer(id_mod_2));
        assert_eq!(get_value(&view, row, 2), DataValue::Integer(qty_mod_3));
    }
}
/// `QUOTIENT(a, b)` returns the whole-number part of a / b.
/// It is applied both to a column and to the output of `ROUND`.
#[test]
fn test_quotient_function() {
    let engine = QueryEngine::new();
    let view = engine
        .execute(
            create_test_table(),
            "SELECT id, QUOTIENT(quantity, 3) as q_div_3, QUOTIENT(ROUND(price), 10) as price_bucket FROM test_math",
        )
        .unwrap();

    assert_eq!(view.row_count(), 3);
    // Per row: (quantity / 3, ROUND(price) / 10).
    let quotients = [(3, 2), (2, 10), (1, 1)];
    for (row, &(q_div_3, bucket)) in quotients.iter().enumerate() {
        assert_eq!(get_value(&view, row, 1), DataValue::Integer(q_div_3));
        assert_eq!(get_value(&view, row, 2), DataValue::Integer(bucket));
    }
}
/// Tests `POWER`, `SQRT` and `POW(x, 0.5)` on the first two fixture rows (quantity 10 and 7).
///
/// Each square-root column is checked for both its variant and its value.
/// A non-`Float` result now fails the test instead of silently skipping the comparison.
/// The `pow_half` column, which the query always selected, is now asserted too.
#[test]
fn test_power_and_sqrt_functions() {
    /// Asserts that `value` is a `Float` within 0.0001 of `expected`; `what` labels failures.
    fn assert_float_near(value: DataValue, expected: f64, what: &str) {
        match value {
            DataValue::Float(f) => assert!(
                (f - expected).abs() < 0.0001,
                "{}: expected ~{}, got {}",
                what,
                expected,
                f
            ),
            other => panic!("{}: expected Float, got {:?}", what, other),
        }
    }

    let table = create_test_table();
    let engine = QueryEngine::new();
    let view = engine
        .execute(
            table.clone(),
            "SELECT id, POWER(2, quantity) as two_to_q, SQRT(quantity) as sqrt_q, POW(quantity, 0.5) as pow_half FROM test_math WHERE id <= 2",
        )
        .unwrap();
    assert_eq!(view.row_count(), 2);

    // Row 0: quantity = 10.
    assert_eq!(get_value(&view, 0, 1), DataValue::Float(1024.0));
    assert_float_near(get_value(&view, 0, 2), 3.1622776, "SQRT(10)");
    assert_float_near(get_value(&view, 0, 3), 3.1622776, "POW(10, 0.5)");

    // Row 1: quantity = 7.
    assert_eq!(get_value(&view, 1, 1), DataValue::Float(128.0));
    assert_float_near(get_value(&view, 1, 2), 2.6457513, "SQRT(7)");
    assert_float_near(get_value(&view, 1, 3), 2.6457513, "POW(7, 0.5)");
}
/// `LOG10`, `LN` (round-tripped through `EXP`) and two-argument `LOG(value, base)`
/// applied to constant arguments.
#[test]
fn test_logarithm_functions() {
    let engine = QueryEngine::new();
    let view = engine
        .execute(
            create_test_table(),
            "SELECT id, LOG10(100) as log10_100, LN(EXP(1)) as ln_e, LOG(8, 2) as log2_8 FROM test_math WHERE id = 1",
        )
        .unwrap();

    assert_eq!(view.row_count(), 1);
    let logs: Vec<DataValue> = (1..=3).map(|col| get_value(&view, 0, col)).collect();
    assert_eq!(
        logs,
        vec![
            DataValue::Float(2.0),
            DataValue::Float(1.0),
            DataValue::Float(3.0),
        ]
    );
}
/// `PI()` returns pi as a `Float`, and `ROUND(PI(), 5)` rounds it to five decimals.
///
/// The first check is an exhaustive match.
/// A non-`Float` `PI()` result now fails instead of silently skipping the comparison.
#[test]
fn test_pi_function() {
    let table = create_test_table();
    let engine = QueryEngine::new();
    let view = engine
        .execute(
            table.clone(),
            "SELECT id, PI() as pi_value, ROUND(PI(), 5) as pi_rounded FROM test_math WHERE id = 1",
        )
        .unwrap();
    assert_eq!(view.row_count(), 1);

    match get_value(&view, 0, 1) {
        DataValue::Float(f) => assert!(
            (f - std::f64::consts::PI).abs() < 0.000001,
            "PI() returned {}",
            f
        ),
        other => panic!("PI(): expected Float, got {:?}", other),
    }
    assert_eq!(get_value(&view, 0, 2), DataValue::Float(3.14159));
}
/// `MOD` inside `WHERE`: only the row with quantity 10 has an even quantity.
#[test]
fn test_math_functions_in_where() {
    let engine = QueryEngine::new();
    let evens = engine
        .execute(
            create_test_table(),
            "SELECT id, quantity FROM test_math WHERE MOD(quantity, 2) = 0",
        )
        .unwrap();

    assert_eq!(evens.row_count(), 1);
    let (id, quantity) = (get_value(&evens, 0, 0), get_value(&evens, 0, 1));
    assert_eq!(id, DataValue::Integer(1));
    assert_eq!(quantity, DataValue::Integer(10));
}
/// Deeply nested calls: `ROUND(SQRT(POWER(q, 2) + POWER(3, 2)), 2)`.
/// This is the hypotenuse of a right triangle with legs `quantity` and 3.
#[test]
fn test_combined_math_functions() {
    let engine = QueryEngine::new();
    let sql = "SELECT id, ROUND(SQRT(POWER(quantity, 2) + POWER(3, 2)), 2) as hypotenuse FROM test_math";
    let view = engine.execute(create_test_table(), sql).unwrap();

    assert_eq!(view.row_count(), 3);
    // sqrt(109), sqrt(58), sqrt(18) rounded to two places.
    for (row, &hyp) in [10.44, 7.62, 4.24].iter().enumerate() {
        assert_eq!(get_value(&view, row, 1), DataValue::Float(hyp));
    }
}