import subprocess
import csv
import io
from pathlib import Path
import tempfile
# Directory two levels above the one holding this script -- presumably the
# Cargo project root, since the binary is expected under target/release.
_PROJECT_ROOT = Path(__file__).parent.parent.parent
SQL_CLI = _PROJECT_ROOT.joinpath("target", "release", "sql-cli")
def run_query(csv_file, query, *, cli=None, timeout=60):
    """Run *query* against *csv_file* with sql-cli and parse the CSV it prints.

    Args:
        csv_file: Path (``str`` or ``Path``) of the CSV file handed to the CLI.
        query: SQL text passed via ``-q``.
        cli: Executable to invoke; ``None`` means the module-level ``SQL_CLI``.
        timeout: Seconds to wait for the process; ``None`` waits forever.

    Returns:
        A list of dicts, one per output row, keyed by the header row, with the
        raw CSV string values. An empty list if no CSV lines were printed.

    Raises:
        RuntimeError: if the CLI exits non-zero or does not finish in time.
    """
    executable = SQL_CLI if cli is None else cli
    try:
        result = subprocess.run(
            [str(executable), str(csv_file), "-q", query, "-o", "csv"],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired as exc:
        # Without a timeout a hung CLI would stall the whole test run.
        raise RuntimeError(f"Query timed out after {timeout}s") from exc
    if result.returncode != 0:
        # Include the exit code: stderr may be empty (e.g. the process crashed).
        raise RuntimeError(
            f"Query failed (exit code {result.returncode}): {result.stderr}"
        )
    # Drop '#'-prefixed lines so only CSV reaches the parser.
    csv_lines = [
        line for line in result.stdout.strip().splitlines()
        if not line.startswith('#')
    ]
    if not csv_lines:
        return []
    return list(csv.DictReader(io.StringIO('\n'.join(csv_lines))))
def test_null_arithmetic_with_lag():
    """Check that subtracting a NULL produced by LAG yields NULL.

    The first row has no predecessor, so the test expects both ``prev`` and
    ``diff`` to be empty; the later rows must carry ordinary differences.
    """
    temp_file = None
    try:
        # Create the file inside the try so the finally clause removes it
        # (delete=False) even if writing fails part-way.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
            temp_file = f.name
            f.write("id,value\n1,100\n2,200\n3,300\n")
        query = """
        SELECT
            id,
            value,
            LAG(value, 1) OVER (ORDER BY id) as prev,
            value - LAG(value, 1) OVER (ORDER BY id) as diff
        FROM test
        """
        results = run_query(temp_file, query)
        # Check the row count first so a short result gives a clear failure
        # instead of an IndexError.
        assert len(results) == 3, results
        assert results[0]['prev'] == ''
        assert results[0]['diff'] == ''
        assert results[1]['prev'] == '100'
        assert results[1]['diff'] == '100'
        assert results[2]['prev'] == '200'
        assert results[2]['diff'] == '100'
        print("✓ test_null_arithmetic_with_lag passed")
    finally:
        if temp_file is not None:
            Path(temp_file).unlink()
def test_null_arithmetic_operations():
    """Check that +, -, * and / yield NULL when either operand is NULL.

    The fixture has one fully populated row (10, 5) and two rows with an empty
    field, which the test expects to be treated as NULL.
    """
    temp_file = None
    try:
        with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
            temp_file = f.name  # record first so cleanup works even if a write fails
            f.write("a,b\n10,5\n20,\n,30\n")
        # (operator, output alias, expected value for the 10/5 row)
        cases = [
            ("+", "sum", "15"),
            ("-", "diff", "5"),
            ("*", "product", "50"),
            ("/", "quotient", "2"),
        ]
        for op, alias, expected in cases:
            query = f"SELECT a, b, a {op} b as {alias} FROM test"
            results = run_query(temp_file, query)
            assert len(results) == 3, (op, results)
            assert results[0][alias] == expected, (op, results[0])
            assert results[1][alias] == '', (op, results[1])
            assert results[2][alias] == '', (op, results[2])
        print("✓ test_null_arithmetic_operations passed")
    finally:
        if temp_file is not None:
            Path(temp_file).unlink()
def test_complex_null_expressions():
    """Check that NULL propagates through the nested expression (x + y) * z."""
    temp_file = None
    try:
        with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
            temp_file = f.name  # record first so cleanup works even if a write fails
            f.write("x,y,z\n10,20,30\n5,,15\n")
        query = "SELECT x, y, z, (x + y) * z as result FROM test"
        results = run_query(temp_file, query)
        assert len(results) == 2, results
        assert results[0]['result'] == '900'
        assert results[1]['result'] == ''
        print("✓ test_complex_null_expressions passed")
    finally:
        if temp_file is not None:
            Path(temp_file).unlink()
def main():
    """Run every NULL-arithmetic test and print a summary.

    Returns:
        0 if all tests pass; 1 if the sql-cli binary is missing or any test
        fails (intended as the process exit status).
    """
    print("Running NULL arithmetic tests...")
    if not SQL_CLI.exists():
        print(f"Error: sql-cli not found at {SQL_CLI}")
        print("Please run: cargo build --release")
        return 1
    tests = [
        test_null_arithmetic_with_lag,
        test_null_arithmetic_operations,
        test_complex_null_expressions,
    ]
    failed = 0
    for test in tests:
        try:
            test()
        except Exception as e:
            # A bare ``assert`` raises AssertionError with an empty message, so
            # print the exception type too or the report says nothing useful.
            print(f"✗ {test.__name__} failed: {type(e).__name__}: {e}")
            failed += 1
    if failed == 0:
        print(f"\n✅ All {len(tests)} NULL arithmetic tests passed!")
        return 0
    print(f"\n❌ {failed}/{len(tests)} tests failed")
    return 1
if __name__ == "__main__":
exit(main())