import subprocess
import pytest
from pathlib import Path
from io import StringIO
import pandas as pd
class TestStringMethods:
    """Integration tests for SQL string methods, executed through ``sql-cli``.

    Each test runs a query with the release build of ``sql-cli`` against
    ``<project_root>/data/test_simple_strings.csv``, parses the CSV output into
    a pandas DataFrame and asserts on the rows that come back.
    """

    # Fixture file (relative to ``<project_root>/data``) queried by every test.
    CSV_FILE = "test_simple_strings.csv"

    @classmethod
    def setup_class(cls):
        """Locate the ``sql-cli`` release binary, building it if it is missing.

        ``project_root`` is the directory two levels above the directory that
        contains this file. ``cargo build --release`` is run from there with
        ``check=True``, so a failed build raises ``CalledProcessError``.
        """
        cls.project_root = Path(__file__).parent.parent.parent
        cls.sql_cli = str(cls.project_root / "target" / "release" / "sql-cli")
        if not Path(cls.sql_cli).exists():
            subprocess.run(
                ["cargo", "build", "--release"],
                cwd=cls.project_root,
                check=True,
            )

    def run_query(self, csv_file: str, query: str):
        """Run ``query`` against ``data/<csv_file>`` and parse the CSV output.

        Args:
            csv_file: File name under ``<project_root>/data``.
            query: SQL text passed to ``sql-cli`` via ``-q``.

        Returns:
            A ``(df, err)`` tuple: ``(None, stderr)`` when the process exits
            with a non-zero status, ``(empty DataFrame, None)`` when stdout is
            blank, otherwise ``(parsed DataFrame, None)``.

        Raises:
            subprocess.TimeoutExpired: if the CLI runs longer than 5 seconds.
        """
        cmd = [
            self.sql_cli,
            str(self.project_root / "data" / csv_file),
            "-q", query,
            "-o", "csv",
        ]
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=5)
        if result.returncode != 0:
            return None, result.stderr
        output = result.stdout.strip()
        if not output:
            return pd.DataFrame(), None
        return pd.read_csv(StringIO(output)), None

    def _query_ok(self, query: str):
        """Run ``query`` against ``CSV_FILE`` and return its DataFrame.

        Fails the calling test with the query and the CLI's stderr when the
        process reports an error, instead of letting a later ``len(None)``
        raise an unrelated ``TypeError``.
        """
        df, err = self.run_query(self.CSV_FILE, query)
        assert df is not None, f"sql-cli failed for query {query!r}: {err}"
        return df

    # ----------------------------------------------------------------- Trim

    def test_trim_removes_both_spaces(self):
        """Trim() yields 'David', 'Grace' and 'Henry' for ids 4, 7 and 8."""
        df = self._query_ok(
            "SELECT id, name.Trim() as trimmed FROM test_simple_strings WHERE id IN (4, 7, 8)")
        assert len(df) == 3
        assert df[df['id'] == 4]['trimmed'].iloc[0] == 'David'
        assert df[df['id'] == 7]['trimmed'].iloc[0] == 'Grace'
        assert df[df['id'] == 8]['trimmed'].iloc[0] == 'Henry'

    def test_trimstart_removes_leading_spaces(self):
        """TrimStart() on the name of id 7 is read back as 'Grace'."""
        df = self._query_ok(
            "SELECT id, name.TrimStart() as trimmed FROM test_simple_strings WHERE id = 7")
        assert len(df) == 1
        assert df.iloc[0]['trimmed'] == 'Grace'

    def test_trimend_removes_trailing_spaces(self):
        """TrimEnd() on the name of id 8 is read back as 'Henry'."""
        df = self._query_ok(
            "SELECT id, name.TrimEnd() as trimmed FROM test_simple_strings WHERE id = 8")
        assert len(df) == 1
        assert df.iloc[0]['trimmed'] == 'Henry'

    # --------------------------------------------------------------- Length

    def test_length_counts_characters(self):
        """Length() reports 5 for the name of id 1 and 9 for the name of id 4."""
        df = self._query_ok(
            "SELECT id, name, name.Length() as len FROM test_simple_strings WHERE id IN (1, 4)")
        assert len(df) == 2
        assert df[df['id'] == 1]['len'].iloc[0] == 5
        assert df[df['id'] == 4]['len'].iloc[0] == 9

    def test_length_in_where(self):
        """Length() can be used as a WHERE filter (names longer than 6)."""
        df = self._query_ok(
            "SELECT id, name FROM test_simple_strings WHERE name.Length() > 6")
        assert len(df) == 4
        assert set(df['id'].tolist()) == {3, 4, 7, 8}

    # -------------------------------------------------------------- IndexOf

    def test_indexof_finds_substring(self):
        """IndexOf('@') on the email of id 1 returns position 5."""
        df = self._query_ok(
            "SELECT id, email, email.IndexOf('@') as at_pos FROM test_simple_strings WHERE id = 1")
        assert len(df) == 1
        assert df.iloc[0]['at_pos'] == 5

    def test_indexof_returns_minus_one(self):
        """IndexOf returns -1 when the substring is absent."""
        df = self._query_ok(
            "SELECT id, name, name.IndexOf('x') as pos FROM test_simple_strings WHERE id = 1")
        assert len(df) == 1
        assert df.iloc[0]['pos'] == -1

    def test_indexof_in_where(self):
        """IndexOf() can be compared against a number inside WHERE."""
        df = self._query_ok(
            "SELECT id FROM test_simple_strings WHERE email.IndexOf('gmail') > 0")
        assert len(df) == 1
        assert df.iloc[0]['id'] == 5

    # ------------------------------------------------------------- Contains

    def test_contains_finds_substring(self):
        """Contains('example') on email selects ids 1, 4, 7 and 10."""
        df = self._query_ok(
            "SELECT id FROM test_simple_strings WHERE email.Contains('example')")
        assert len(df) == 4
        assert set(df['id'].tolist()) == {1, 4, 7, 10}

    def test_contains_case_sensitive(self):
        """Contains('active') on status selects ids 1, 2, 3, 5, 7, 8 and 9."""
        df = self._query_ok(
            "SELECT id FROM test_simple_strings WHERE status.Contains('active')")
        assert len(df) == 7
        assert set(df['id'].tolist()) == {1, 2, 3, 5, 7, 8, 9}

    # ------------------------------------------------ StartsWith / EndsWith

    def test_startswith_prefix_check(self):
        """StartsWith('ABC') on code selects only id 1."""
        df = self._query_ok(
            "SELECT id FROM test_simple_strings WHERE code.StartsWith('ABC')")
        assert len(df) == 1
        assert df.iloc[0]['id'] == 1

    def test_startswith_case_sensitive(self):
        """StartsWith('a') on status selects ids 1, 3, 5, 6, 7 and 9."""
        df = self._query_ok(
            "SELECT id FROM test_simple_strings WHERE status.StartsWith('a')")
        assert len(df) == 6
        assert set(df['id'].tolist()) == {1, 3, 5, 6, 7, 9}

    def test_endswith_suffix_check(self):
        """EndsWith('.org') on email selects ids 2 and 7."""
        df = self._query_ok(
            "SELECT id FROM test_simple_strings WHERE email.EndsWith('.org')")
        assert len(df) == 2
        assert set(df['id'].tolist()) == {2, 7}

    def test_endswith_with_spaces(self):
        """EndsWith(' ') on name selects ids 4 and 8."""
        df = self._query_ok(
            "SELECT id FROM test_simple_strings WHERE name.EndsWith(' ')")
        assert len(df) == 2
        assert set(df['id'].tolist()) == {4, 8}

    # ------------------------------------------------- Combined expressions

    def test_trim_then_length(self):
        """Trim().Length() on the name of id 4 is 5."""
        df = self._query_ok(
            "SELECT id, name.Trim().Length() as trimmed_len FROM test_simple_strings WHERE id = 4")
        assert len(df) == 1
        assert df.iloc[0]['trimmed_len'] == 5

    def test_multiple_string_methods_in_select(self):
        """Several string methods can appear in one SELECT list."""
        df = self._query_ok(
            "SELECT id, name.Trim() as trimmed, email.IndexOf('@') as at_pos, code.Length() as code_len FROM test_simple_strings WHERE id = 1")
        assert len(df) == 1
        row = df.iloc[0]
        assert row['trimmed'] == 'Alice'
        assert row['at_pos'] == 5
        assert row['code_len'] == 6

    def test_string_method_with_arithmetic(self):
        """Length() results can be added together (22 for id 1)."""
        df = self._query_ok(
            "SELECT id, name.Length() + email.Length() as total_len FROM test_simple_strings WHERE id = 1")
        assert len(df) == 1
        assert df.iloc[0]['total_len'] == 22

    # ----------------------------------------------------------- Edge cases

    def test_method_on_empty_string(self):
        """Placeholder: no query or assertion is performed yet."""
        pass

    def test_method_with_special_characters(self):
        """Contains('.') on email matches 10 rows."""
        df = self._query_ok(
            "SELECT id FROM test_simple_strings WHERE email.Contains('.')")
        assert len(df) == 10

    # ------------------------------------------------------ Method chaining

    def test_simple_method_chaining(self):
        """A two-step chain, Trim().Length(), returns 5 for id 4."""
        df = self._query_ok(
            "SELECT id, name.Trim().Length() as result FROM test_simple_strings WHERE id = 4")
        assert len(df) == 1
        assert df.iloc[0]['result'] == 5

    def test_triple_method_chaining(self):
        """A three-step chain returns 5 for id 4, or is skipped if the CLI rejects it."""
        df, _ = self.run_query(
            self.CSV_FILE,
            "SELECT id, name.Trim().TrimStart().Length() as result FROM test_simple_strings WHERE id = 4")
        # run_query yields df=None exactly when the CLI exits non-zero; keying
        # on that (rather than on stderr being non-empty) also covers a failure
        # that printed nothing to stderr.
        if df is None:
            pytest.skip("Triple method chaining not yet supported")
        assert len(df) == 1
        assert df.iloc[0]['result'] == 5

    # ---------------------------------------------------- Boolean operators

    def test_string_methods_with_and(self):
        """Two string predicates joined by AND select only id 1."""
        df = self._query_ok(
            "SELECT id FROM test_simple_strings WHERE name.StartsWith('A') AND email.EndsWith('.com')")
        assert len(df) == 1
        assert df.iloc[0]['id'] == 1

    def test_string_methods_with_or(self):
        """Two string predicates joined by OR select ids 1 and 2."""
        df = self._query_ok(
            "SELECT id FROM test_simple_strings WHERE name.StartsWith('A') OR name.StartsWith('B')")
        assert len(df) == 2
        assert set(df['id'].tolist()) == {1, 2}

    def test_string_method_with_not_equal(self):
        """Length() != 5 selects the eight ids other than 1 and 6."""
        df = self._query_ok(
            "SELECT id FROM test_simple_strings WHERE name.Length() != 5")
        assert len(df) == 8
        assert set(df['id'].tolist()) == {2, 3, 4, 5, 7, 8, 9, 10}

    def test_string_method_on_all_rows(self):
        """Trim() without a WHERE clause returns all 10 rows."""
        df = self._query_ok(
            "SELECT name.Trim() as trimmed FROM test_simple_strings")
        assert len(df) == 10
        trimmed = df['trimmed'].values
        assert 'David' in trimmed
        assert 'Grace' in trimmed
if __name__ == "__main__":
    # pytest.main returns the run's exit code; raise it so that executing this
    # file directly exits non-zero when a test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))