import subprocess
import json
import sys
import os
def run_sql_query(query):
    """Run *query* through the release ``sql-cli`` binary and return the decoded JSON.

    The binary is invoked as ``./target/release/sql-cli -q <query> -o json``,
    resolved relative to the current working directory, without a shell.

    If the process exits with a non-zero status, its stderr is echoed to this
    process's stderr and the interpreter exits with status 1. Otherwise the
    captured stdout is parsed with ``json.loads`` and returned.
    """
    command = ['./target/release/sql-cli', '-q', query, '-o', 'json']
    completed = subprocess.run(command, capture_output=True, text=True)

    # Any non-zero status (including negative signal codes) is a failure.
    if completed.returncode:
        print(f"Error running query: {completed.stderr}", file=sys.stderr)
        sys.exit(1)

    payload = completed.stdout
    return json.loads(payload)
def test_implicit_frame_with_order_by():
    """An ORDER BY with no frame clause must match ``RANGE UNBOUNDED PRECEDING``.

    Selects a running SUM with the implicit default frame next to the same
    aggregate with the frame written out. Checks that both equal the
    cumulative totals of the values 10..50 and that they agree with each other.
    """
    query = """
WITH test_data AS (
SELECT 1 as id, 10 as value UNION ALL
SELECT 2, 20 UNION ALL
SELECT 3, 30 UNION ALL
SELECT 4, 40 UNION ALL
SELECT 5, 50
)
SELECT
id,
value,
SUM(value) OVER (ORDER BY id) as sum_implicit,
SUM(value) OVER (ORDER BY id RANGE UNBOUNDED PRECEDING) as sum_explicit
FROM test_data
ORDER BY id
"""
    rows = run_sql_query(query)
    assert len(rows) == 5, f"Expected 5 rows, got {len(rows)}"

    running_totals = (10, 30, 60, 100, 150)
    for position, (row, want) in enumerate(zip(rows, running_totals), start=1):
        # Rows must come back in id order for the positional expectations to hold.
        assert row['id'] == position

        implicit = row['sum_implicit']
        assert implicit == want, \
            f"Row {position}: implicit sum {implicit} != expected {want}"

        explicit = row['sum_explicit']
        assert explicit == want, \
            f"Row {position}: explicit sum {explicit} != expected {want}"

        assert implicit == explicit, \
            f"Row {position}: implicit and explicit sums don't match"
def test_no_order_by_uses_all_rows():
    """Window aggregates without ORDER BY must span the entire partition.

    Every row shares the constant ``dummy`` partition key, so SUM, AVG and
    COUNT over ``PARTITION BY dummy`` (no ORDER BY, hence no running frame)
    should see all five rows. That gives 150, 30 and 5 on every row.
    """
    query = """
WITH test_data AS (
SELECT 1 as id, 10 as value UNION ALL
SELECT 2, 20 UNION ALL
SELECT 3, 30 UNION ALL
SELECT 4, 40 UNION ALL
SELECT 5, 50
)
SELECT
id,
value,
1 as dummy,
SUM(value) OVER (PARTITION BY dummy) as sum_partition,
AVG(value) OVER (PARTITION BY dummy) as avg_partition,
COUNT(*) OVER (PARTITION BY dummy) as count_partition
FROM test_data
ORDER BY id
"""
    results = run_sql_query(query)

    total_sum = 150
    avg_value = 30
    total_count = 5

    # Without this guard, an empty result set would make the loop below a
    # no-op and the test would pass without checking anything.
    assert len(results) == total_count, \
        f"Expected {total_count} rows, got {len(results)}"

    for row in results:
        assert row['sum_partition'] == total_sum, \
            f"Row {row['id']}: sum_partition {row['sum_partition']} != total {total_sum}"
        # AVG may come back as a float (30.0); == against int 30 still holds.
        assert row['avg_partition'] == avg_value, \
            f"Row {row['id']}: avg_partition {row['avg_partition']} != expected {avg_value}"
        assert row['count_partition'] == total_count, \
            f"Row {row['id']}: count_partition {row['count_partition']} != expected {total_count}"
def test_explicit_frame_overrides_implicit():
    """An explicit ROWS frame must replace the implicit ORDER BY frame.

    Selects three SUMs over the same ORDER BY:
    - the implicit default, a running total;
    - a two-row sliding window (``1 PRECEDING`` to ``CURRENT ROW``);
    - a suffix sum (``CURRENT ROW`` to ``UNBOUNDED FOLLOWING``).

    Each column is checked against its own expected sequence, so the test
    fails if an explicit frame is ignored in favour of the implicit one.
    """
    query = """
WITH test_data AS (
SELECT 1 as id, 10 as value UNION ALL
SELECT 2, 20 UNION ALL
SELECT 3, 30 UNION ALL
SELECT 4, 40 UNION ALL
SELECT 5, 50
)
SELECT
id,
value,
SUM(value) OVER (ORDER BY id) as sum_implicit,
SUM(value) OVER (ORDER BY id ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) as sum_2_rows,
SUM(value) OVER (ORDER BY id ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) as sum_future
FROM test_data
ORDER BY id
"""
    results = run_sql_query(query)

    # Fail clearly on a wrong row count instead of passing vacuously on an
    # empty result or raising IndexError on extra rows.
    assert len(results) == 5, f"Expected 5 rows, got {len(results)}"

    expected_implicit = [10, 30, 60, 100, 150]
    expected_2_rows = [10, 30, 50, 70, 90]
    expected_future = [150, 140, 120, 90, 50]

    for i, row in enumerate(results):
        # The expectations are positional, so rows must arrive in id order.
        assert row['id'] == i + 1, \
            f"Row {i+1}: unexpected id {row['id']}"
        # The implicit column is the baseline the explicit frames must differ from.
        assert row['sum_implicit'] == expected_implicit[i], \
            f"Row {i+1}: implicit sum {row['sum_implicit']} != expected {expected_implicit[i]}"
        assert row['sum_2_rows'] == expected_2_rows[i], \
            f"Row {i+1}: 2-row sum {row['sum_2_rows']} != expected {expected_2_rows[i]}"
        assert row['sum_future'] == expected_future[i], \
            f"Row {i+1}: future sum {row['sum_future']} != expected {expected_future[i]}"
def test_range_vs_rows_semantics():
    """Check that the implicit ORDER BY frame agrees with RANGE UNBOUNDED PRECEDING.

    Also checks that both, along with ROWS UNBOUNDED PRECEDING, yield the
    running totals of the values 10..50.

    NOTE(review): every ``id`` in the fixture is unique, so the current row
    has no ORDER BY peers. ROWS and RANGE therefore produce identical results
    here, and this test cannot detect an engine that wrongly uses ROWS for the
    implicit frame. A fixture with duplicate ``id`` values would be needed for
    that.
    """
    query = """
WITH test_data AS (
SELECT 1 as id, 10 as value UNION ALL
SELECT 2, 20 UNION ALL
SELECT 3, 30 UNION ALL
SELECT 4, 40 UNION ALL
SELECT 5, 50
)
SELECT
id,
value,
SUM(value) OVER (ORDER BY id) as sum_implicit,
SUM(value) OVER (ORDER BY id RANGE UNBOUNDED PRECEDING) as sum_range_explicit,
SUM(value) OVER (ORDER BY id ROWS UNBOUNDED PRECEDING) as sum_rows_explicit
FROM test_data
ORDER BY id
"""
    results = run_sql_query(query)

    # Guard against a vacuous pass on an empty result and an IndexError on
    # extra rows; both would otherwise hide real problems.
    assert len(results) == 5, f"Expected 5 rows, got {len(results)}"

    expected_sums = [10, 30, 60, 100, 150]
    for i, row in enumerate(results):
        # The expectations are positional, so rows must arrive in id order.
        assert row['id'] == i + 1, \
            f"Row {i+1}: unexpected id {row['id']}"
        assert row['sum_implicit'] == expected_sums[i], \
            f"Row {i+1}: implicit sum {row['sum_implicit']} != expected {expected_sums[i]}"
        assert row['sum_range_explicit'] == expected_sums[i], \
            f"Row {i+1}: range sum {row['sum_range_explicit']} != expected {expected_sums[i]}"
        assert row['sum_rows_explicit'] == expected_sums[i], \
            f"Row {i+1}: rows sum {row['sum_rows_explicit']} != expected {expected_sums[i]}"
        assert row['sum_implicit'] == row['sum_range_explicit'], \
            f"Row {i+1}: implicit doesn't match RANGE semantics"
def main():
    """Run each implicit-frame test in order and stop at the first problem.

    Prints one status line per test. An AssertionError is reported as FAILED
    and any other Exception as ERROR; either one prints the message and exits
    with status 1. If every test passes, a summary line is printed.
    """
    suite = (
        test_implicit_frame_with_order_by,
        test_no_order_by_uses_all_rows,
        test_explicit_frame_overrides_implicit,
        test_range_vs_rows_semantics,
    )

    for case in suite:
        print(f"Running {case.__name__}...", end=' ')
        try:
            case()
            print("✓ PASSED")
        except Exception as err:
            # Assertion failures and unexpected errors share one reporting path;
            # only the status label differs.
            label = "✗ FAILED" if isinstance(err, AssertionError) else "✗ ERROR"
            print(label)
            print(f" {err}")
            sys.exit(1)

    print("\nAll implicit frame tests passed!")
# Run the suite only when executed as a script, not when imported.
if __name__ == "__main__":
    main()