import subprocess
import csv
from io import StringIO
import sys
import os
# Location of the sql-cli executable under test: two directories above this
# script's folder, then target/release/sql-cli (presumably the Cargo release
# build output of the repository root — confirm if this script is moved).
SQL_CLI = os.path.join(
    os.path.dirname(__file__),
    "../../target/release/sql-cli",
)
def run_query(query, cli=None):
    """Run *query* through the sql-cli binary and return its CSV output as rows.

    Args:
        query: SQL text, passed to the CLI with ``-q``.
        cli: Optional path to the CLI executable. When omitted (the default),
            the module-level ``SQL_CLI`` path is used.

    Returns:
        A list of dicts, one per CSV data row, keyed by the column names in
        the CSV header. Values are left as strings; callers convert them.

    Raises:
        RuntimeError: if the CLI exits with a non-zero status. The message
            carries the exit code and the CLI's stderr. (RuntimeError is an
            ``Exception`` subclass, so existing ``except Exception`` handlers
            still catch it.)
    """
    executable = SQL_CLI if cli is None else cli
    # Arguments go in as a list with no shell, so the query text is passed to
    # the CLI verbatim and never re-interpreted by a shell.
    result = subprocess.run(
        [executable, "-q", query, "-o", "csv"],
        capture_output=True,
        text=True,
        check=False,
    )
    if result.returncode != 0:
        raise RuntimeError(
            f"Query failed (exit code {result.returncode}): {result.stderr.strip()}"
        )
    return list(csv.DictReader(StringIO(result.stdout)))
def test_basic_range():
    """RANGE(1, 5) is expected to yield every integer from 1 through 5."""
    rows = run_query("SELECT value FROM RANGE(1, 5)")
    want = list(range(1, 6))
    got = [int(r['value']) for r in rows]
    assert got == want, f"Expected {want}, got {got}"
    print("✓ Basic RANGE test passed")
def test_range_with_step():
    """RANGE(0, 20, 5) is expected to step by 5 from 0 up to and including 20."""
    rows = run_query("SELECT value FROM RANGE(0, 20, 5)")
    want = list(range(0, 21, 5))
    got = [int(r['value']) for r in rows]
    assert got == want, f"Expected {want}, got {got}"
    print("✓ RANGE with step test passed")
def test_range_in_cte():
    """A RANGE wrapped in a CTE can be selected from with derived columns.

    Checks both the passthrough ``value`` column and a computed ``doubled``
    column. Failure messages show expected vs. actual lists.
    """
    query = """
WITH numbers AS (
SELECT value FROM RANGE(1, 5)
)
SELECT value, value * 2 AS doubled FROM numbers
"""
    result = run_query(query)
    expected_values = [1, 2, 3, 4, 5]
    expected_doubled = [2, 4, 6, 8, 10]
    actual_values = [int(row['value']) for row in result]
    actual_doubled = [int(row['doubled']) for row in result]
    assert actual_values == expected_values, (
        f"Values mismatch: expected {expected_values}, got {actual_values}"
    )
    assert actual_doubled == expected_doubled, (
        f"Doubled values mismatch: expected {expected_doubled}, got {actual_doubled}"
    )
    print("✓ RANGE in CTE test passed")
def test_nested_cte_with_range():
    """A CTE that reads from another RANGE-backed CTE produces n, n^2 and 2*n^2.

    The row count is checked first. Without that check, a short result would
    pass vacuously and a long one would crash with IndexError.
    """
    query = """
WITH first_range AS (
SELECT value AS n FROM RANGE(1, 3)
),
squared AS (
SELECT n, n * n AS sq FROM first_range
)
SELECT n, sq, sq * 2 AS doubled_square FROM squared
"""
    result = run_query(query)
    expected = [
        {'n': 1, 'sq': 1, 'doubled_square': 2},
        {'n': 2, 'sq': 4, 'doubled_square': 8},
        {'n': 3, 'sq': 9, 'doubled_square': 18},
    ]
    assert len(result) == len(expected), (
        f"Expected {len(expected)} rows, got {len(result)}"
    )
    for i, (row, want) in enumerate(zip(result, expected)):
        for column, value in want.items():
            got = int(row[column])
            assert got == value, (
                f"Row {i} column {column!r}: expected {value}, got {got}"
            )
    print("✓ Nested CTE with RANGE test passed")
def test_range_with_where_clause():
    """A WHERE filter on RANGE(1, 20) keeps only the multiples of 3."""
    rows = run_query("SELECT value FROM RANGE(1, 20) WHERE value % 3 = 0")
    want = [n for n in range(1, 21) if n % 3 == 0]
    got = [int(r['value']) for r in rows]
    assert got == want, f"Expected {want}, got {got}"
    print("✓ RANGE with WHERE clause test passed")
def test_range_with_is_prime():
    """Filtering RANGE(2, 20) with IS_PRIME(value) = true keeps only the primes."""
    sql = "SELECT value FROM RANGE(2, 20) WHERE IS_PRIME(value) = true"
    found = list(map(int, (r['value'] for r in run_query(sql))))
    known_primes = [2, 3, 5, 7, 11, 13, 17, 19]
    assert found == known_primes, f"Expected primes {known_primes}, got {found}"
    print("✓ RANGE with IS_PRIME test passed")
def test_prime_count_in_range():
    """COUNT(*) over IS_PRIME-filtered RANGE(2, 100) returns 25.

    The aggregate is expected as exactly one row. The failure message reports
    the actual count.
    """
    query = "SELECT COUNT(*) AS prime_count FROM RANGE(2, 100) WHERE IS_PRIME(value) = true"
    result = run_query(query)
    assert len(result) == 1, f"Expected a single aggregate row, got {len(result)}"
    actual = int(result[0]['prime_count'])
    assert actual == 25, f"Expected 25 primes in 2..100, got {actual}"
    print("✓ Prime count in range test passed")
def test_prime_pi_with_range():
    """PRIME_PI(n) matches known prime-counting values for n = 10, 20, ..., 50.

    The expected values are counts of primes <= n (e.g. 4 primes up to 10).
    Every expected n must appear exactly once, so an empty or truncated
    result fails instead of passing vacuously.
    """
    query = """
SELECT
value AS n,
PRIME_PI(value) AS primes_up_to_n
FROM RANGE(10, 50, 10)
"""
    result = run_query(query)
    expected_counts = {10: 4, 20: 8, 30: 10, 40: 12, 50: 15}
    seen = sorted(int(row['n']) for row in result)
    assert seen == sorted(expected_counts), (
        f"Expected n values {sorted(expected_counts)}, got {seen}"
    )
    for row in result:
        n = int(row['n'])
        expected = expected_counts[n]
        actual = int(row['primes_up_to_n'])
        assert actual == expected, f"PRIME_PI({n}) should be {expected}, got {actual}"
    print("✓ PRIME_PI with RANGE test passed")
def test_prime_density_blocks():
    """Count primes in consecutive blocks of 100 using PRIME_PI differences.

    For block b the query computes PRIME_PI(100*b) - PRIME_PI(100*(b-1)),
    i.e. the primes in (100*(b-1), 100*b]. All five blocks must be present
    exactly once, so a missing block fails instead of being silently skipped.
    """
    query = """
WITH blocks AS (
SELECT value AS block_num FROM RANGE(1, 5)
)
SELECT
block_num,
PRIME_PI(block_num * 100) - PRIME_PI((block_num - 1) * 100) AS primes_in_block
FROM blocks
"""
    result = run_query(query)
    expected_counts = {1: 25, 2: 21, 3: 16, 4: 16, 5: 17}
    seen = sorted(int(row['block_num']) for row in result)
    assert seen == sorted(expected_counts), (
        f"Expected blocks {sorted(expected_counts)}, got {seen}"
    )
    for row in result:
        block = int(row['block_num'])
        expected = expected_counts[block]
        actual = int(row['primes_in_block'])
        assert actual == expected, f"Block {block} should have {expected} primes, got {actual}"
    print("✓ Prime density blocks test passed")
def test_complex_cte_with_aggregates():
    """Aggregate per-number IS_PRIME flags, computed in a CTE, over 1..100.

    Expects 100 numbers in total, split into 25 primes and 75 non-primes.
    """
    query = """
WITH prime_analysis AS (
SELECT
value,
IS_PRIME(value) AS is_prime
FROM RANGE(1, 100)
)
SELECT
COUNT(*) AS total_numbers,
SUM(CASE WHEN is_prime = true THEN 1 ELSE 0 END) AS prime_count,
SUM(CASE WHEN is_prime = false THEN 1 ELSE 0 END) AS non_prime_count
FROM prime_analysis
"""
    summary = run_query(query)[0]
    checks = (
        ('total_numbers', 100, "Should have 100 total numbers"),
        ('prime_count', 25, "Should have 25 primes"),
        ('non_prime_count', 75, "Should have 75 non-primes"),
    )
    for column, want, message in checks:
        assert int(summary[column]) == want, message
    print("✓ Complex CTE with aggregates test passed")
def test_multiple_ranges_in_cte():
    """Define two RANGE-backed CTEs but select only from the first.

    ``large_range`` is declared and never referenced. Only ``small_range``'s
    rows are checked (the printed label calls this the "limited" variant).
    """
    query = """
WITH small_range AS (
SELECT value AS small FROM RANGE(1, 3)
),
large_range AS (
SELECT value AS large FROM RANGE(10, 12)
)
SELECT * FROM small_range
"""
    rows = run_query(query)
    smalls = [int(r['small']) for r in rows]
    assert smalls == [1, 2, 3], "Single CTE selection should work"
    print("✓ Multiple RANGE in CTE (limited) test passed")
def test_range_with_calculations():
    """Arithmetic and SQRT projections over RANGE(1, 5) match Python's results.

    The set of values is checked first, so an empty result no longer passes
    vacuously. SQRT is compared with an absolute tolerance of 0.001.
    """
    query = """
SELECT
value,
value * value AS squared,
value * value * value AS cubed,
SQRT(value) AS square_root
FROM RANGE(1, 5)
"""
    result = run_query(query)
    seen = sorted(int(row['value']) for row in result)
    assert seen == [1, 2, 3, 4, 5], f"Expected values [1, 2, 3, 4, 5], got {seen}"
    for row in result:
        n = int(row['value'])
        squared = int(row['squared'])
        cubed = int(row['cubed'])
        root = float(row['square_root'])
        assert squared == n * n, f"{n}^2 should be {n * n}, got {squared}"
        assert cubed == n * n * n, f"{n}^3 should be {n * n * n}, got {cubed}"
        assert abs(root - (n ** 0.5)) < 0.001, (
            f"SQRT({n}) should be ~{n ** 0.5:.4f}, got {root}"
        )
    print("✓ RANGE with calculations test passed")
def test_range_edge_cases():
    """Boundary behaviour of RANGE: start == end, negative bounds, coarse step."""
    # Degenerate range: start equals end, so exactly one row (5) is expected.
    single = run_query("SELECT value FROM RANGE(5, 5)")
    assert len(single) == 1 and int(single[0]['value']) == 5

    # Remaining cases compare the full list of returned values.
    cases = (
        ("SELECT value FROM RANGE(-5, -1)", [-5, -4, -3, -2, -1]),
        ("SELECT value FROM RANGE(0, 100, 25)", [0, 25, 50, 75, 100]),
    )
    for sql, want in cases:
        got = [int(r['value']) for r in run_query(sql)]
        assert got == want
    print("✓ RANGE edge cases test passed")
def test_range_with_group_by():
    """GROUP BY over a CTE that tags RANGE(1, 20) values with value % 5.

    The printed label calls the CTE a "workaround". Presumably GROUP BY
    directly on RANGE is unsupported; confirm before relying on that.
    All five remainder groups must appear in ORDER BY order. Each must hold
    4 numbers and the expected sum (for example, 5+10+15+20 = 50 for
    remainder 0).
    """
    query = """
WITH numbers AS (
SELECT
value,
value % 5 AS remainder
FROM RANGE(1, 20)
)
SELECT
remainder,
COUNT(*) AS count,
SUM(value) AS sum_values
FROM numbers
GROUP BY remainder
ORDER BY remainder
"""
    result = run_query(query)
    expected_sums = {0: 50, 1: 34, 2: 38, 3: 42, 4: 46}
    remainders = [int(row['remainder']) for row in result]
    assert remainders == sorted(expected_sums), (
        f"Expected remainder groups {sorted(expected_sums)} in order, got {remainders}"
    )
    for row in result:
        remainder = int(row['remainder'])
        count = int(row['count'])
        total = int(row['sum_values'])
        assert count == 4, (
            f"Each remainder group should have 4 numbers; group {remainder} has {count}"
        )
        assert total == expected_sums[remainder], (
            f"Group {remainder} sum should be {expected_sums[remainder]}, got {total}"
        )
    print("✓ RANGE with GROUP BY (using CTE workaround) test passed")
def test_range_with_order_by():
    """ORDER BY a computed column (squared DESC) should reverse 1..5."""
    query = """
SELECT
value,
value * value AS squared
FROM RANGE(1, 5)
ORDER BY squared DESC
"""
    values_in_order = [int(r['value']) for r in run_query(query)]
    assert values_in_order == list(range(5, 0, -1))
    print("✓ RANGE with ORDER BY test passed")
def test_range_with_window_functions():
    """ROW_NUMBER and SUM window functions partitioned by value % 3 over 1..12.

    Every row is checked, not just the first.
    - Within each group the members are 1,4,7,10 / 2,5,8,11 / 3,6,9,12, so
      ROW_NUMBER ordered by value is ``(value - 1) // 3 + 1``.
    - The group sums are 22 (grp 1), 26 (grp 2) and 30 (grp 0).
    """
    query = """
WITH data AS (
SELECT
value,
value % 3 AS grp
FROM RANGE(1, 12)
)
SELECT
value,
grp,
ROW_NUMBER() OVER (PARTITION BY grp ORDER BY value) AS row_num,
SUM(value) OVER (PARTITION BY grp) AS group_sum
FROM data
ORDER BY value
"""
    result = run_query(query)
    values = [int(row['value']) for row in result]
    assert values == list(range(1, 13)), f"Expected values 1..12 in order, got {values}"
    expected_group_sums = {0: 30, 1: 22, 2: 26}
    for row in result:
        value = int(row['value'])
        grp = int(row['grp'])
        row_num = int(row['row_num'])
        group_sum = int(row['group_sum'])
        assert grp == value % 3, f"grp for {value} should be {value % 3}, got {grp}"
        expected_row_num = (value - 1) // 3 + 1
        assert row_num == expected_row_num, (
            f"row_num for {value} should be {expected_row_num}, got {row_num}"
        )
        assert group_sum == expected_group_sums[grp], (
            f"Group {grp} sum should be {expected_group_sums[grp]}, got {group_sum}"
        )
    print("✓ RANGE with window functions (ROW_NUMBER and SUM) test passed")
def test_cte_range_with_partition():
    """Categorize 1..20 as prime/even/odd in chained CTEs, then window over category.

    Categories are assigned in CASE order: primes first, then remaining evens,
    then everything else. Over 1..20 that gives 8 primes, 9 evens and 3 odds
    (1, 9, 15).

    Checks:
    - every row carries one of these categories;
    - the per-category row counts match;
    - category_count matches the count for that category;
    - ROW_NUMBER numbers each category's rows 1..k in num order.
    """
    query = """
WITH base_numbers AS (
SELECT value AS num FROM RANGE(1, 20)
),
categorized AS (
SELECT
num,
CASE
WHEN IS_PRIME(num) = true THEN 'prime'
WHEN num % 2 = 0 THEN 'even'
ELSE 'odd'
END AS category
FROM base_numbers
)
SELECT
num,
category,
ROW_NUMBER() OVER (PARTITION BY category ORDER BY num) AS position_in_category,
COUNT(num) OVER (PARTITION BY category) AS category_count
FROM categorized
ORDER BY num
"""
    result = run_query(query)
    expected_counts = {'prime': 8, 'even': 9, 'odd': 3}

    assert len(result) == sum(expected_counts.values()), (
        f"Expected {sum(expected_counts.values())} rows, got {len(result)}"
    )

    # Rows keep the query's ORDER BY num within each category bucket.
    by_category = {
        category: [r for r in result if r['category'] == category]
        for category in expected_counts
    }
    for category, want in expected_counts.items():
        rows = by_category[category]
        assert len(rows) == want, (
            f"Expected {want} {category} rows, got {len(rows)}"
        )
        positions = [int(r['position_in_category']) for r in rows]
        assert positions == list(range(1, want + 1)), (
            f"{category} positions should be 1..{want}, got {positions}"
        )

    prime_2 = next((r for r in result if int(r['num']) == 2), None)
    assert prime_2 is not None, "Row for num=2 is missing"
    assert prime_2['category'] == 'prime'
    assert int(prime_2['position_in_category']) == 1
    assert int(prime_2['category_count']) == 8

    for row in result:
        category = row['category']
        assert category in expected_counts, f"Unexpected category {category!r}"
        count = int(row['category_count'])
        assert count == expected_counts[category], (
            f"category_count for {category} should be {expected_counts[category]}, got {count}"
        )
    print("✓ CTE with RANGE and PARTITION BY test passed")
def main():
    """Run every RANGE/CTE test in order and print a pass/fail summary.

    Each test runs in isolation. Any exception it raises is reported and
    counted as a failure, and the remaining tests still run. The process
    exits with status 1 if any test failed.
    """
    tests = [
        test_basic_range,
        test_range_with_step,
        test_range_in_cte,
        test_nested_cte_with_range,
        test_range_with_where_clause,
        test_range_with_is_prime,
        test_prime_count_in_range,
        test_prime_pi_with_range,
        test_prime_density_blocks,
        test_complex_cte_with_aggregates,
        test_multiple_ranges_in_cte,
        test_range_with_calculations,
        test_range_edge_cases,
        test_range_with_group_by,
        test_range_with_order_by,
        test_range_with_window_functions,
        test_cte_range_with_partition,
    ]
    passed = 0
    failed = 0
    print("Running RANGE and CTE tests...")
    print("=" * 60)
    for test in tests:
        try:
            test()
        except Exception as exc:  # runner boundary: report and keep going
            # A bare `assert` has an empty message, so str(exc) may be "".
            # Always include the exception type so the failure is identifiable.
            message = str(exc)
            detail = f"{type(exc).__name__}: {message}" if message else type(exc).__name__
            print(f"✗ {test.__name__}: {detail}")
            failed += 1
        else:
            passed += 1
    print("=" * 60)
    print(f"Results: {passed} passed, {failed} failed")
    if failed > 0:
        sys.exit(1)
# Run the suite only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()