import json
import subprocess
import sys
from io import StringIO
from pathlib import Path

import pandas as pd
import pytest


class TestSqlEngine:
    """End-to-end tests that drive the ``sql-cli`` binary over CSV fixtures.

    Each test runs one SQL query through the release build of ``sql-cli`` with
    CSV output, parses what the binary prints to stdout into a DataFrame and
    asserts on the rows that come back.
    """

    @classmethod
    def setup_class(cls):
        """Resolve paths and make sure the binary and CSV fixtures exist.

        The project root is taken as three directory levels above this file.
        A missing ``target/release/sql-cli`` triggers ``cargo build --release``
        and missing fixture files trigger ``scripts/generate_simple_test.py``;
        both run with ``check=True`` so a failure aborts the whole class.
        """
        cls.project_root = Path(__file__).parent.parent.parent
        cls.sql_cli = str(cls.project_root / "target" / "release" / "sql-cli")

        if not Path(cls.sql_cli).exists():
            subprocess.run(
                ["cargo", "build", "--release"],
                cwd=cls.project_root,
                check=True,
            )

        data_dir = cls.project_root / "data"
        test_files = [
            data_dir / "test_simple_math.csv",
            data_dir / "test_simple_strings.csv",
        ]
        if not all(f.exists() for f in test_files):
            generator = cls.project_root / "scripts" / "generate_simple_test.py"
            # Run the generator with the interpreter that is running pytest:
            # a bare "python3" may be absent from PATH or point at a different
            # installation than the active environment.
            subprocess.run(
                [sys.executable, str(generator)],
                cwd=cls.project_root,
                check=True,
            )

    def run_query(self, csv_file: str, query: str) -> pd.DataFrame:
        """Run ``query`` against ``data/<csv_file>`` and return the parsed result.

        Args:
            csv_file: File name of the fixture inside the project's ``data``
                directory.
            query: SQL text passed to the binary via ``-q``.

        Returns:
            The CSV printed by ``sql-cli`` parsed into a DataFrame, or an empty
            DataFrame when the binary prints nothing but whitespace.

        Raises:
            AssertionError: If the binary exits with a non-zero status; the
                message carries its stderr.
            subprocess.TimeoutExpired: If the binary runs longer than 5 seconds.
        """
        cmd = [
            self.sql_cli,
            str(self.project_root / "data" / csv_file),
            "-q", query,
            "-o", "csv",
        ]
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=5)
        assert result.returncode == 0, f"Query failed: {result.stderr}"

        # read_csv raises on empty input, so blank output is handled separately.
        output = result.stdout.strip()
        if output:
            return pd.read_csv(StringIO(output))
        return pd.DataFrame()

    # ------------------------------------------------------------------
    # Arithmetic and math functions
    # ------------------------------------------------------------------

    def test_addition(self):
        """``a + b`` for row id 1 is expected to be 11."""
        df = self.run_query(
            "test_simple_math.csv",
            "SELECT id, a + b as result FROM test_simple_math WHERE id = 1",
        )
        assert len(df) == 1
        assert df.iloc[0]['result'] == 11

    def test_multiplication(self):
        """``a * b`` for row id 2 is expected to be 40."""
        df = self.run_query(
            "test_simple_math.csv",
            "SELECT id, a * b as result FROM test_simple_math WHERE id = 2",
        )
        assert len(df) == 1
        assert df.iloc[0]['result'] == 40

    def test_round_function(self):
        """``ROUND(c, 0)`` for row id 3 is expected to be 2."""
        df = self.run_query(
            "test_simple_math.csv",
            "SELECT id, ROUND(c, 0) as result FROM test_simple_math WHERE id = 3",
        )
        assert len(df) == 1
        assert df.iloc[0]['result'] == 2

    def test_abs_function(self):
        """``ABS(a - d)`` for row id 10 is expected to be 80."""
        df = self.run_query(
            "test_simple_math.csv",
            "SELECT id, ABS(a - d) as result FROM test_simple_math WHERE id = 10",
        )
        assert len(df) == 1
        assert df.iloc[0]['result'] == 80

    def test_power_function(self):
        """``POWER(a, 2)`` for row id 5 is expected to be 25."""
        df = self.run_query(
            "test_simple_math.csv",
            "SELECT id, POWER(a, 2) as result FROM test_simple_math WHERE id = 5",
        )
        assert len(df) == 1
        assert df.iloc[0]['result'] == 25

    def test_sqrt_function(self):
        """``SQRT(e)`` for row id 4 is expected to be 4 within 0.001."""
        df = self.run_query(
            "test_simple_math.csv",
            "SELECT id, SQRT(e) as result FROM test_simple_math WHERE id = 4",
        )
        assert len(df) == 1
        assert abs(df.iloc[0]['result'] - 4) < 0.001

    def test_mod_function(self):
        """``MOD(b, 7)`` for row id 3 is expected to be 2."""
        df = self.run_query(
            "test_simple_math.csv",
            "SELECT id, MOD(b, 7) as result FROM test_simple_math WHERE id = 3",
        )
        assert len(df) == 1
        assert df.iloc[0]['result'] == 2

    def test_complex_expression(self):
        """A nested arithmetic expression inside ``ROUND`` is about 11.0 for id 2."""
        df = self.run_query(
            "test_simple_math.csv",
            "SELECT id, ROUND((a + b) * c / 2, 1) as result FROM test_simple_math WHERE id = 2",
        )
        assert len(df) == 1
        assert abs(df.iloc[0]['result'] - 11.0) < 0.1

    # ------------------------------------------------------------------
    # String methods
    # ------------------------------------------------------------------

    def test_contains_method(self):
        """``name.Contains('li')`` is expected to match exactly ids 1 and 3."""
        df = self.run_query(
            "test_simple_strings.csv",
            "SELECT id, name FROM test_simple_strings WHERE name.Contains('li')",
        )
        assert len(df) == 2
        assert set(df['id'].tolist()) == {1, 3}

    def test_endswith_method(self):
        """``email.EndsWith('.com')`` is expected to match eight specific ids."""
        df = self.run_query(
            "test_simple_strings.csv",
            "SELECT id FROM test_simple_strings WHERE email.EndsWith('.com')",
        )
        assert len(df) == 8
        assert set(df['id'].tolist()) == {1, 3, 4, 5, 6, 8, 9, 10}

    def test_startswith_method(self):
        """``status.StartsWith('A')`` is expected to match six specific ids."""
        df = self.run_query(
            "test_simple_strings.csv",
            "SELECT id FROM test_simple_strings WHERE status.StartsWith('A')",
        )
        assert len(df) == 6
        assert set(df['id'].tolist()) == {1, 3, 5, 6, 7, 9}

    def test_trim_method(self):
        """``name.Trim()`` for row id 4 is expected to be ``'David'``."""
        df = self.run_query(
            "test_simple_strings.csv",
            "SELECT id, name.Trim() as trimmed FROM test_simple_strings WHERE id = 4",
        )
        assert len(df) == 1
        assert df.iloc[0]['trimmed'] == 'David'

    def test_length_method(self):
        """``name.Length()`` for row id 1 is expected to be 5."""
        df = self.run_query(
            "test_simple_strings.csv",
            "SELECT id, name.Length() as len FROM test_simple_strings WHERE id = 1",
        )
        assert len(df) == 1
        assert df.iloc[0]['len'] == 5

    def test_indexof_method(self):
        """``code.IndexOf('2')`` for row id 1 is expected to be 4."""
        df = self.run_query(
            "test_simple_strings.csv",
            "SELECT id, code.IndexOf('2') as pos FROM test_simple_strings WHERE id = 1",
        )
        assert len(df) == 1
        assert df.iloc[0]['pos'] == 4

    # ------------------------------------------------------------------
    # Aggregates (skipped)
    # ------------------------------------------------------------------

    @pytest.mark.skip(reason="Aggregate functions not yet implemented")
    def test_count_function(self):
        """``COUNT(*)`` over rows with ``a <= 5`` is expected to be 5."""
        df = self.run_query(
            "test_simple_math.csv",
            "SELECT COUNT(*) as total FROM test_simple_math WHERE a <= 5",
        )
        assert len(df) == 1
        assert df.iloc[0]['total'] == 5

    @pytest.mark.skip(reason="Aggregate functions not yet implemented")
    def test_sum_function(self):
        """``SUM(a)`` over rows with ``a <= 5`` is expected to be 15."""
        df = self.run_query(
            "test_simple_math.csv",
            "SELECT SUM(a) as total FROM test_simple_math WHERE a <= 5",
        )
        assert len(df) == 1
        assert df.iloc[0]['total'] == 15

    @pytest.mark.skip(reason="Aggregate functions not yet implemented")
    def test_avg_function(self):
        """``ROUND(AVG(a), 1)`` over rows with ``a <= 4`` is expected to be about 2.5."""
        df = self.run_query(
            "test_simple_math.csv",
            "SELECT ROUND(AVG(a), 1) as average FROM test_simple_math WHERE a <= 4",
        )
        assert len(df) == 1
        assert abs(df.iloc[0]['average'] - 2.5) < 0.1

    # ------------------------------------------------------------------
    # Filtering, ordering and limiting
    # ------------------------------------------------------------------

    def test_math_in_where(self):
        """An arithmetic predicate in ``WHERE`` is expected to keep 17 rows."""
        df = self.run_query(
            "test_simple_math.csv",
            "SELECT id FROM test_simple_math WHERE a * b > 100",
        )
        assert len(df) == 17

    def test_multiple_conditions(self):
        """An equality test ANDed with a string method keeps ids 1, 3, 5 and 9."""
        df = self.run_query(
            "test_simple_strings.csv",
            "SELECT id FROM test_simple_strings WHERE status = 'Active' AND email.EndsWith('.com')",
        )
        assert len(df) == 4
        assert set(df['id'].tolist()) == {1, 3, 5, 9}

    def test_order_by(self):
        """``ORDER BY a DESC`` returns the five matching rows from 5 down to 1."""
        df = self.run_query(
            "test_simple_math.csv",
            "SELECT id, a FROM test_simple_math WHERE a <= 5 ORDER BY a DESC",
        )
        assert len(df) == 5
        assert df['a'].tolist() == [5, 4, 3, 2, 1]

    def test_limit(self):
        """``ORDER BY id LIMIT 3`` returns ids 1, 2 and 3 in that order."""
        df = self.run_query(
            "test_simple_math.csv",
            "SELECT id FROM test_simple_math ORDER BY id LIMIT 3",
        )
        assert len(df) == 3
        assert df['id'].tolist() == [1, 2, 3]
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly
    # reports test failures to the calling shell instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))