use sql_cli::data::data_view::DataView;
use sql_cli::data::datatable::{DataColumn, DataRow, DataTable, DataValue};
use sql_cli::data::query_engine::QueryEngine;
use std::sync::Arc;
/// Returns an owned copy of the cell at (`row_idx`, `col_idx`) in `view`.
///
/// # Panics
///
/// Both the row lookup and the column lookup are unwrapped, so this panics
/// when either of them fails to produce a value.
fn get_value(view: &DataView, row_idx: usize, col_idx: usize) -> DataValue {
    let row = view.get_row(row_idx).unwrap();
    let cell = row.get(col_idx).unwrap();
    cell.clone()
}
/// Builds the shared `test_arithmetic` fixture used by every test in this file.
///
/// Columns, in order: `id`, `quantity`, `price`, `discount`.
///
/// | id | quantity | price  | discount |
/// |----|----------|--------|----------|
/// | 1  | 10       | 25.50  | 2.50     |
/// | 2  | 5        | 100.00 | 10.00    |
/// | 3  | 3        | 15     | 1.50     |
/// | 4  | 8        | 45.75  | 5.00     |
///
/// Note that row 3 stores `price` as an `Integer` while the other rows use
/// `Float`, so the `price` column holds both numeric variants.
///
/// # Panics
///
/// Panics if `add_row` reports an error for any of the rows.
fn create_test_table() -> Arc<DataTable> {
    let mut table = DataTable::new("test_arithmetic");

    for column_name in ["id", "quantity", "price", "discount"] {
        table.add_column(DataColumn::new(column_name));
    }

    let rows = [
        vec![
            DataValue::Integer(1),
            DataValue::Integer(10),
            DataValue::Float(25.50),
            DataValue::Float(2.50),
        ],
        vec![
            DataValue::Integer(2),
            DataValue::Integer(5),
            DataValue::Float(100.00),
            DataValue::Float(10.00),
        ],
        vec![
            DataValue::Integer(3),
            DataValue::Integer(3),
            // Integer-typed price (the other rows use Float).
            DataValue::Integer(15),
            DataValue::Float(1.50),
        ],
        vec![
            DataValue::Integer(4),
            DataValue::Integer(8),
            DataValue::Float(45.75),
            DataValue::Float(5.00),
        ],
    ];

    for values in rows {
        table.add_row(DataRow::new(values)).unwrap();
    }

    Arc::new(table)
}
/// `WHERE` can filter on the product of two columns.
///
/// For the fixture rows, `quantity * price` is 255, 500, 45 and 366 for
/// ids 1 through 4, so ids 1, 2 and 4 pass the `> 200` filter.
#[test]
fn test_where_with_multiplication() {
    let sql = "SELECT id, quantity, price FROM test_arithmetic WHERE quantity * price > 200";
    let table = create_test_table();
    let engine = QueryEngine::new();

    let view = engine.execute(Arc::clone(&table), sql).unwrap();

    assert_eq!(
        view.row_count(),
        3,
        "Should return 3 rows where quantity * price > 200"
    );

    // The matching ids are expected in the order the rows were inserted.
    let expected_ids = [1, 2, 4];
    for (row_idx, &id) in expected_ids.iter().enumerate() {
        assert_eq!(get_value(&view, row_idx, 0), DataValue::Integer(id));
    }
}
/// `WHERE` can filter on the difference of two columns.
///
/// For the fixture rows, `price - discount` is 23, 90, 13.5 and 40.75 for
/// ids 1 through 4, so three rows (ids 1, 3 and 4) are below 50. Only the
/// row count is asserted here.
#[test]
fn test_where_with_subtraction() {
    let sql = "SELECT id, price, discount FROM test_arithmetic WHERE price - discount < 50";
    let table = create_test_table();
    let engine = QueryEngine::new();

    let view = engine.execute(Arc::clone(&table), sql).unwrap();
    let matched = view.row_count();

    assert_eq!(
        matched, 3,
        "Should return 3 rows where price - discount < 50"
    );
}
/// `WHERE` can filter on the sum of two columns.
///
/// For the fixture rows, `price + discount` is 28, 110, 16.5 and 50.75 for
/// ids 1 through 4, so only id 2 exceeds 100.
#[test]
fn test_where_with_addition() {
    let sql = "SELECT id FROM test_arithmetic WHERE price + discount > 100";
    let table = create_test_table();
    let engine = QueryEngine::new();

    let view = engine.execute(Arc::clone(&table), sql).unwrap();

    assert_eq!(
        view.row_count(),
        1,
        "Should return 1 row where price + discount > 100"
    );

    let only_id = get_value(&view, 0, 0);
    assert_eq!(only_id, DataValue::Integer(2));
}
/// `WHERE` can filter on the quotient of two columns.
///
/// For the fixture rows, `price / quantity` is 2.55, 20, 5 and 5.71875 for
/// ids 1 through 4, so only id 2 exceeds 10.
#[test]
fn test_where_with_division() {
    let sql = "SELECT id, price, quantity FROM test_arithmetic WHERE price / quantity > 10";
    let table = create_test_table();
    let engine = QueryEngine::new();

    let view = engine.execute(Arc::clone(&table), sql).unwrap();

    assert_eq!(
        view.row_count(),
        1,
        "Should return 1 row where price / quantity > 10"
    );

    let only_id = get_value(&view, 0, 0);
    assert_eq!(only_id, DataValue::Integer(2));
}
/// `WHERE` can filter on a parenthesised, multi-operator expression.
///
/// For the fixture rows, `(quantity * price) - discount` is 252.5, 490, 43.5
/// and 361 for ids 1 through 4, so three rows (ids 1, 2 and 4) exceed 250.
/// Only the row count is asserted here.
#[test]
fn test_where_with_complex_expression() {
    let sql = "SELECT id FROM test_arithmetic WHERE (quantity * price) - discount > 250";
    let table = create_test_table();
    let engine = QueryEngine::new();

    let view = engine.execute(Arc::clone(&table), sql).unwrap();
    let matched = view.row_count();

    assert_eq!(
        matched, 3,
        "Should return 3 rows where (quantity * price) - discount > 250"
    );
}
/// An arithmetic `WHERE` filter combined with computed, aliased `SELECT` columns.
///
/// The filter `quantity * price > 200` keeps ids 1, 2 and 4 from the fixture.
/// For each surviving row the test checks all three projected columns:
/// `id`, `total` (`quantity * price`) and `net_price` (`price - discount`).
///
/// Every matching row is verified, not just the first two, so a wrong value
/// in the last row cannot slip past the row-count assertion.
#[test]
fn test_where_expression_with_computed_select() {
    // The continuation lines start at column 0 because they are part of the
    // string literal itself.
    let sql = "SELECT id, quantity * price as total, price - discount as net_price
FROM test_arithmetic
WHERE quantity * price > 200";
    let table = create_test_table();
    let engine = QueryEngine::new();

    let view = engine.execute(Arc::clone(&table), sql).unwrap();

    assert_eq!(view.row_count(), 3);

    // (id, total, net_price) per result row, in fixture insertion order.
    // All values are exactly representable in binary floating point, so
    // exact equality is safe here.
    let expected = [
        (1, 255.0, 23.0),
        (2, 500.0, 90.0),
        (4, 366.0, 40.75),
    ];
    for (row_idx, &(id, total, net_price)) in expected.iter().enumerate() {
        assert_eq!(
            get_value(&view, row_idx, 0),
            DataValue::Integer(id),
            "id mismatch at result row {}",
            row_idx
        );
        assert_eq!(
            get_value(&view, row_idx, 1),
            DataValue::Float(total),
            "total mismatch at result row {}",
            row_idx
        );
        assert_eq!(
            get_value(&view, row_idx, 2),
            DataValue::Float(net_price),
            "net_price mismatch at result row {}",
            row_idx
        );
    }
}
/// An arithmetic `WHERE` filter combined with `ORDER BY` on a computed alias.
///
/// The filter `quantity * price > 100` keeps ids 1 (255), 2 (500) and 4 (366)
/// from the fixture; sorting by `total` descending yields ids 2, 4, 1.
#[test]
fn test_where_expression_with_order_by() {
    // The continuation lines start at column 0 because they are part of the
    // string literal itself.
    let sql = "SELECT id, quantity * price as total
FROM test_arithmetic
WHERE quantity * price > 100
ORDER BY total DESC";
    let table = create_test_table();
    let engine = QueryEngine::new();

    let view = engine.execute(Arc::clone(&table), sql).unwrap();

    assert_eq!(view.row_count(), 3);

    // Ids in descending order of their computed total.
    let ids_by_total_desc = [2, 4, 1];
    for (row_idx, &id) in ids_by_total_desc.iter().enumerate() {
        assert_eq!(get_value(&view, row_idx, 0), DataValue::Integer(id));
    }
}