import subprocess
import csv
import io
import json
from pathlib import Path
import tempfile
import pandas as pd
# Release binary produced by `cargo build --release`, located relative to this
# file: three `.parent` hops up, then target/release/sql-cli.
SQL_CLI = Path(__file__).parent.parent.parent.joinpath("target", "release", "sql-cli")
def run_query(csv_file, query, timeout=60.0):
    """Run ``query`` against ``csv_file`` with the sql-cli binary and parse its CSV output.

    Args:
        csv_file: Path (``str`` or ``Path``) of the CSV file handed to the CLI.
        query: SQL text passed to the CLI via ``-q``.
        timeout: Seconds to wait for the CLI before giving up; ``None`` waits
            indefinitely (the previous behaviour).

    Returns:
        A list of dicts, one per output row, mapping column name to the raw
        string value. An empty list when the CLI printed no CSV lines.

    Raises:
        RuntimeError: if the CLI exits with a non-zero status or exceeds
            ``timeout``. (Subclass of ``Exception``, so existing
            ``except Exception`` handlers still catch it.)
    """
    cmd = [str(SQL_CLI), str(csv_file), "-q", query, "-o", "csv"]
    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            # The binary is built with cargo (Rust), whose string output is UTF-8;
            # don't depend on the platform's locale encoding to decode it.
            encoding="utf-8",
            # Without a timeout a hung binary would stall the whole test run.
            timeout=timeout,
        )
    except subprocess.TimeoutExpired as err:
        raise RuntimeError(f"Query timed out after {timeout}s: {query}") from err
    if result.returncode != 0:
        raise RuntimeError(f"Query failed: {result.stderr}")
    # Lines beginning with '#' are not part of the CSV payload and are dropped
    # before parsing.
    csv_lines = [
        line
        for line in result.stdout.strip().split('\n')
        if not line.startswith('#')
    ]
    if not csv_lines:
        return []
    return list(csv.DictReader(io.StringIO('\n'.join(csv_lines))))
def test_lag_basic():
    """LAG(value, 1) yields the previous row's value, and empty for the first row."""
    with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
        f.write("id,value\n")
        f.write("1,100\n")
        f.write("2,200\n")
        f.write("3,300\n")
        f.write("4,400\n")
        temp_file = f.name
    try:
        query = "SELECT id, value, LAG(value, 1) OVER (ORDER BY id) as prev_value FROM test"
        results = run_query(temp_file, query)
        assert len(results) == 4
        # The first row has no predecessor; its LAG is expected as an empty CSV field.
        assert results[0]['prev_value'] == ''
        assert results[1]['prev_value'] == '100'
        assert results[2]['prev_value'] == '200'
        assert results[3]['prev_value'] == '300'
        print("✓ test_lag_basic passed")
    finally:
        Path(temp_file).unlink()
def test_lead_basic():
    """LEAD(value, 1) yields the next row's value, and empty for the last row."""
    with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
        f.write("id,value\n")
        f.write("1,100\n")
        f.write("2,200\n")
        f.write("3,300\n")
        f.write("4,400\n")
        temp_file = f.name
    try:
        query = "SELECT id, value, LEAD(value, 1) OVER (ORDER BY id) as next_value FROM test"
        results = run_query(temp_file, query)
        assert len(results) == 4
        assert results[0]['next_value'] == '200'
        assert results[1]['next_value'] == '300'
        assert results[2]['next_value'] == '400'
        # The last row has no successor; its LEAD is expected as an empty CSV field.
        assert results[3]['next_value'] == ''
        print("✓ test_lead_basic passed")
    finally:
        Path(temp_file).unlink()
def test_row_number_basic():
    """ROW_NUMBER ordered by score DESC numbers rows by score, best first.

    Rows are expected back in input order, so each name is checked against
    the rank its score earns (95 -> 1, 92 -> 2, 88 -> 3, 87 -> 4).
    """
    csv_text = "name,score\nAlice,95\nBob,87\nCharlie,92\nDiana,88\n"
    with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as handle:
        handle.write(csv_text)
        csv_path = handle.name
    try:
        sql = "SELECT name, score, ROW_NUMBER() OVER (ORDER BY score DESC) as rank FROM test"
        rows = run_query(csv_path, sql)
        assert len(rows) == 4
        expected = [('Alice', '1'), ('Bob', '4'), ('Charlie', '2'), ('Diana', '3')]
        for row, (name, position) in zip(rows, expected):
            assert row['name'] == name and row['rank'] == position
        print("✓ test_row_number_basic passed")
    finally:
        Path(csv_path).unlink()
def test_partition_by():
    """ROW_NUMBER restarts within each department, ranking salaries high to low."""
    lines = [
        "department,employee,salary",
        "Sales,Alice,50000",
        "Sales,Bob,45000",
        "Sales,Charlie,55000",
        "IT,David,60000",
        "IT,Eve,65000",
        "IT,Frank,58000",
    ]
    with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as handle:
        handle.write("".join(line + "\n" for line in lines))
        csv_path = handle.name
    try:
        sql = (
            "\n"
            "SELECT department, employee, salary,\n"
            "ROW_NUMBER() OVER (PARTITION BY department ORDER BY salary DESC) as dept_rank\n"
            "FROM test\n"
        )
        output = run_query(csv_path, sql)
        # Per department: (employee, expected rank) in input order.
        expectations = {
            'Sales': [('Alice', '2'), ('Bob', '3'), ('Charlie', '1')],
            'IT': [('David', '2'), ('Eve', '1'), ('Frank', '3')],
        }
        for dept, expected in expectations.items():
            members = [row for row in output if row['department'] == dept]
            for idx, (employee, position) in enumerate(expected):
                assert members[idx]['employee'] == employee and members[idx]['dept_rank'] == position
        print("✓ test_partition_by passed")
    finally:
        Path(csv_path).unlink()
def test_lag_with_partition():
    """LAG restarts per category: the first month of each category has no previous value."""
    with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
        f.write("category,month,sales\n")
        f.write("A,1,100\n")
        f.write("A,2,150\n")
        f.write("A,3,200\n")
        f.write("B,1,300\n")
        f.write("B,2,350\n")
        f.write("B,3,400\n")
        temp_file = f.name
    try:
        query = (
            "\n"
            "SELECT category, month, sales,\n"
            "LAG(sales, 1) OVER (PARTITION BY category ORDER BY month) as prev_sales\n"
            "FROM test\n"
        )
        results = run_query(temp_file, query)

        cat_a = [r for r in results if r['category'] == 'A']
        # Check the size first so a missing row fails clearly instead of with IndexError.
        assert len(cat_a) == 3
        # The partition boundary means month 1 must not see another category's value.
        assert cat_a[0]['prev_sales'] == ''
        assert cat_a[1]['prev_sales'] == '100'
        assert cat_a[2]['prev_sales'] == '150'

        cat_b = [r for r in results if r['category'] == 'B']
        assert len(cat_b) == 3
        assert cat_b[0]['prev_sales'] == ''
        assert cat_b[1]['prev_sales'] == '300'
        assert cat_b[2]['prev_sales'] == '350'
        print("✓ test_lag_with_partition passed")
    finally:
        Path(temp_file).unlink()
def test_multiple_window_functions():
    """LAG, LEAD and ROW_NUMBER with different orderings can coexist in one SELECT."""
    with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
        f.write("id,value\n")
        f.write("1,10\n")
        f.write("2,20\n")
        f.write("3,30\n")
        f.write("4,40\n")
        temp_file = f.name
    try:
        query = (
            "\n"
            "SELECT id, value,\n"
            "LAG(value, 1) OVER (ORDER BY id) as prev_val,\n"
            "LEAD(value, 1) OVER (ORDER BY id) as next_val,\n"
            "ROW_NUMBER() OVER (ORDER BY value DESC) as rank\n"
            "FROM test\n"
        )
        results = run_query(temp_file, query)
        assert len(results) == 4

        # First row: no predecessor; ranked last because it has the smallest value.
        assert results[0]['id'] == '1'
        assert results[0]['prev_val'] == ''
        assert results[0]['next_val'] == '20'
        assert results[0]['rank'] == '4'

        # Last row: no successor; ranked first because it has the largest value.
        assert results[3]['id'] == '4'
        assert results[3]['prev_val'] == '30'
        assert results[3]['next_val'] == ''
        assert results[3]['rank'] == '1'
        print("✓ test_multiple_window_functions passed")
    finally:
        Path(temp_file).unlink()
def test_first_last_value():
    """FIRST_VALUE / LAST_VALUE within each group, using the default window frame.

    The assertions expect LAST_VALUE to equal the current row's value. That is
    the result when the frame ends at the current row, which is the SQL default
    when ORDER BY is given without an explicit frame.
    """
    with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
        f.write("group_id,sequence,value\n")
        f.write("A,1,100\n")
        f.write("A,2,200\n")
        f.write("A,3,300\n")
        f.write("B,1,400\n")
        f.write("B,2,500\n")
        temp_file = f.name
    try:
        query = (
            "\n"
            "SELECT group_id, sequence, value,\n"
            "FIRST_VALUE(value) OVER (PARTITION BY group_id ORDER BY sequence) as first_val,\n"
            "LAST_VALUE(value) OVER (PARTITION BY group_id ORDER BY sequence) as last_val\n"
            "FROM test\n"
        )
        results = run_query(temp_file, query)

        group_a = [r for r in results if r['group_id'] == 'A']
        assert len(group_a) == 3
        for row in group_a:
            assert row['first_val'] == '100'
        assert group_a[0]['last_val'] == '100'
        assert group_a[1]['last_val'] == '200'
        assert group_a[2]['last_val'] == '300'

        group_b = [r for r in results if r['group_id'] == 'B']
        assert len(group_b) == 2
        assert group_b[0]['first_val'] == '400'
        assert group_b[0]['last_val'] == '400'
        assert group_b[1]['first_val'] == '400'
        assert group_b[1]['last_val'] == '500'
        print("✓ test_first_last_value passed")
    finally:
        Path(temp_file).unlink()
def test_lag_with_offset():
    """LAG(value, 2) looks two rows back, evaluated over the WHERE-filtered rows."""
    with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
        f.write("id,value\n")
        for i in range(1, 11):
            f.write(f"{i},{i*10}\n")
        temp_file = f.name
    try:
        query = (
            "\n"
            "SELECT id, value,\n"
            "LAG(value, 2) OVER (ORDER BY id) as lag_2\n"
            "FROM test\n"
            "WHERE id <= 5\n"
        )
        results = run_query(temp_file, query)
        # WHERE id <= 5 must keep exactly the first five of the ten input rows.
        assert len(results) == 5
        # The first two rows have nothing two positions back.
        assert [r['lag_2'] for r in results] == ['', '', '10', '20', '30']
        print("✓ test_lag_with_offset passed")
    finally:
        Path(temp_file).unlink()
def test_complex_partitioning():
    """ROW_NUMBER over a two-column partition (region, month) on the bundled sales data.

    Relies on data/sales_data.csv containing exactly two rows for every
    region/month pair checked below; each pair must be numbered 1 and 2.
    """
    data_path = Path(__file__).parent.parent.parent / "data" / "sales_data.csv"
    sql = (
        "\n"
        "SELECT region, salesperson, month, sales_amount,\n"
        "ROW_NUMBER() OVER (PARTITION BY region, month ORDER BY sales_amount DESC) as rank\n"
        "FROM test\n"
    )
    rows = run_query(data_path, sql)
    regions = ('North', 'South', 'East', 'West')
    months = ('2024-01', '2024-02', '2024-03')
    for region in regions:
        for month in months:
            bucket = [row for row in rows if row['region'] == region and row['month'] == month]
            assert len(bucket) == 2
            assert sorted(int(row['rank']) for row in bucket) == [1, 2]
    print("✓ test_complex_partitioning passed")
def test_where_with_window_function():
    """Rows ranked 1 per region (filtered in Python) carry that region's maximum sales.

    The rank-1 rows are selected client-side rather than with a SQL WHERE
    clause. The data is expected to span four regions.
    """
    data_path = Path(__file__).parent.parent.parent / "data" / "sales_data.csv"
    sql = (
        "\n"
        "SELECT region, salesperson, month, sales_amount,\n"
        "ROW_NUMBER() OVER (PARTITION BY region ORDER BY sales_amount DESC) as rank\n"
        "FROM test\n"
    )
    rows = run_query(data_path, sql)
    winners = [row for row in rows if row['rank'] == '1']
    winner_regions = {row['region'] for row in winners}
    assert len(winner_regions) == 4
    for region in winner_regions:
        top = next(row for row in winners if row['region'] == region)
        best = max(int(row['sales_amount']) for row in rows if row['region'] == region)
        assert int(top['sales_amount']) == best
    print("✓ test_where_with_window_function passed")
def main():
    """Run every window-function test against the release sql-cli binary.

    Returns:
        Process exit status: 0 when every test passes, 1 when the binary is
        missing or at least one test fails.
    """
    import traceback

    print("Running window function tests...")
    if not SQL_CLI.exists():
        print(f"Error: sql-cli not found at {SQL_CLI}")
        print("Please run: cargo build --release")
        return 1
    tests = [
        test_lag_basic,
        test_lead_basic,
        test_row_number_basic,
        test_partition_by,
        test_lag_with_partition,
        test_multiple_window_functions,
        test_first_last_value,
        test_lag_with_offset,
        test_complex_partitioning,
        test_where_with_window_function,
    ]
    failed = 0
    for test in tests:
        try:
            test()
        except Exception as e:
            # Bare `assert` statements raise AssertionError with an empty
            # message. Report the exception type and the line that raised it
            # so the failure can be located without re-running under pytest.
            frame = traceback.extract_tb(e.__traceback__)[-1]
            detail = str(e) or "(no message)"
            print(
                f"✗ {test.__name__} failed: {type(e).__name__}: {detail} "
                f"(line {frame.lineno})"
            )
            failed += 1
    if failed:
        print(f"\n❌ {failed}/{len(tests)} tests failed")
        return 1
    print(f"\n✅ All {len(tests)} window function tests passed!")
    return 0
if __name__ == "__main__":
    # `exit` is a convenience injected by the `site` module for interactive use
    # and is absent under `python -S`; SystemExit is always available.
    raise SystemExit(main())