from tests.tools.parity_base import ParityTestBase
class TestSQLAdvancedParity(ParityTestBase):
    """Parity tests for advanced ``spark.sql`` features.

    Each test writes small DataFrames to catalog tables with
    ``saveAsTable`` (overwrite mode), runs a SQL query against them and
    checks the collected result. Tables are always dropped in a
    ``finally`` block so a failing assertion does not leak tables into
    later tests.
    """

    @staticmethod
    def _drop_tables(spark, *names):
        """Drop each named table if it exists (best-effort cleanup)."""
        for name in names:
            spark.sql(f"DROP TABLE IF EXISTS {name}")

    def test_sql_with_inner_join(self, spark):
        """INNER JOIN keeps only employees whose dept_id matches a department."""
        try:
            emp_data = [("Alice", 1), ("Bob", 2), ("Charlie", 3)]
            emp_df = spark.createDataFrame(emp_data, ["name", "dept_id"])
            emp_df.write.mode("overwrite").saveAsTable("employees")
            dept_data = [(1, "IT"), (2, "HR")]
            dept_df = spark.createDataFrame(dept_data, ["id", "name"])
            dept_df.write.mode("overwrite").saveAsTable("departments")
            result = spark.sql("""
                SELECT e.name, d.name as dept_name
                FROM employees e
                INNER JOIN departments d ON e.dept_id = d.id
            """)
            rows = result.collect()
            assert len(rows) == 2
            # Charlie (dept_id 3) has no matching department and is dropped.
            pairs = {(row["name"], row["dept_name"]) for row in rows}
            assert pairs == {("Alice", "IT"), ("Bob", "HR")}
        finally:
            self._drop_tables(spark, "employees", "departments")

    def test_sql_with_left_join(self, spark):
        """LEFT JOIN keeps every employee, with NULL dept_name when unmatched."""
        try:
            emp_data = [("Alice", 1), ("Bob", 2), ("Charlie", 99)]
            emp_df = spark.createDataFrame(emp_data, ["name", "dept_id"])
            emp_df.write.mode("overwrite").saveAsTable("employees2")
            dept_data = [(1, "IT"), (2, "HR")]
            dept_df = spark.createDataFrame(dept_data, ["id", "name"])
            dept_df.write.mode("overwrite").saveAsTable("departments2")
            result = spark.sql("""
                SELECT e.name, d.name as dept_name
                FROM employees2 e
                LEFT JOIN departments2 d ON e.dept_id = d.id
            """)
            rows = result.collect()
            assert len(rows) == 3
            by_name = {row["name"]: row["dept_name"] for row in rows}
            # dept_id 99 has no department, so the right side is NULL.
            assert by_name == {"Alice": "IT", "Bob": "HR", "Charlie": None}
        finally:
            self._drop_tables(spark, "employees2", "departments2")

    def test_sql_with_order_by(self, spark):
        """ORDER BY sorts ascending by default and descending with DESC."""
        try:
            data = [("Alice", 30), ("Bob", 25), ("Charlie", 35)]
            df = spark.createDataFrame(data, ["name", "age"])
            df.write.mode("overwrite").saveAsTable("order_test")

            result = spark.sql("SELECT * FROM order_test ORDER BY age")
            rows = result.collect()
            assert rows[0]["name"] == "Bob"
            assert rows[1]["name"] == "Alice"
            assert rows[2]["name"] == "Charlie"

            result = spark.sql("SELECT * FROM order_test ORDER BY age DESC")
            rows = result.collect()
            assert rows[0]["name"] == "Charlie"
            assert rows[1]["name"] == "Alice"
            assert rows[2]["name"] == "Bob"
        finally:
            self._drop_tables(spark, "order_test")

    def test_sql_with_limit(self, spark):
        """LIMIT caps the number of returned rows."""
        try:
            data = [("Alice", 25), ("Bob", 30), ("Charlie", 35), ("David", 40)]
            df = spark.createDataFrame(data, ["name", "age"])
            df.write.mode("overwrite").saveAsTable("limit_test")
            result = spark.sql("SELECT * FROM limit_test LIMIT 2")
            assert result.count() == 2
        finally:
            self._drop_tables(spark, "limit_test")

    def test_sql_with_having(self, spark):
        """HAVING filters groups on an aggregate after GROUP BY."""
        try:
            # IT averages 55000.5 (just above the threshold); HR averages
            # exactly 55000, which the strict ">" must exclude.
            data = [
                ("Alice", "IT", 50000),
                ("Bob", "IT", 60001),
                ("Charlie", "HR", 55000),
            ]
            df = spark.createDataFrame(data, ["name", "dept", "salary"])
            df.write.mode("overwrite").saveAsTable("having_test")
            result = spark.sql("""
                SELECT dept, AVG(salary) as avg_salary
                FROM having_test
                GROUP BY dept
                HAVING AVG(salary) > 55000
            """)
            rows = result.collect()
            assert len(rows) == 1
            assert rows[0]["dept"] == "IT"
        finally:
            self._drop_tables(spark, "having_test")

    def test_sql_with_union(self, spark):
        """UNION of two disjoint tables returns all of their rows."""
        try:
            data1 = [("Alice", 25), ("Bob", 30)]
            df1 = spark.createDataFrame(data1, ["name", "age"])
            df1.write.mode("overwrite").saveAsTable("union_table1")
            data2 = [("Charlie", 35), ("David", 40)]
            df2 = spark.createDataFrame(data2, ["name", "age"])
            df2.write.mode("overwrite").saveAsTable("union_table2")
            result = spark.sql("""
                SELECT name, age FROM union_table1
                UNION
                SELECT name, age FROM union_table2
            """)
            assert result.count() == 4
        finally:
            self._drop_tables(spark, "union_table1", "union_table2")

    def test_sql_with_subquery(self, spark):
        """A scalar subquery in WHERE compares each row to the table average."""
        try:
            data = [("Alice", 50000), ("Bob", 60000), ("Charlie", 70000)]
            df = spark.createDataFrame(data, ["name", "salary"])
            df.write.mode("overwrite").saveAsTable("subquery_test")
            result = spark.sql("""
                SELECT name, salary
                FROM subquery_test
                WHERE salary > (SELECT AVG(salary) FROM subquery_test)
            """)
            rows = result.collect()
            # The average is 60000, so only Charlie is strictly above it.
            assert len(rows) == 1
            assert rows[0]["name"] == "Charlie"
        finally:
            self._drop_tables(spark, "subquery_test")

    def test_sql_with_case_when(self, spark):
        """CASE WHEN assigns a category per row from the first matching branch."""
        try:
            data = [("Alice", 25), ("Bob", 30), ("Charlie", 35)]
            df = spark.createDataFrame(data, ["name", "age"])
            df.write.mode("overwrite").saveAsTable("case_test")
            # ORDER BY is required: the positional assertions below would
            # otherwise depend on unspecified table-scan order.
            result = spark.sql("""
                SELECT name, age,
                       CASE WHEN age < 30 THEN 'Young'
                            WHEN age < 35 THEN 'Middle'
                            ELSE 'Senior' END as category
                FROM case_test
                ORDER BY age
            """)
            rows = result.collect()
            assert len(rows) == 3
            assert rows[0]["category"] == "Young"
            assert rows[1]["category"] == "Middle"
            assert rows[2]["category"] == "Senior"
        finally:
            self._drop_tables(spark, "case_test")

    def test_sql_with_like(self, spark):
        """LIKE with a trailing wildcard matches names by prefix."""
        try:
            data = [("Alice",), ("Bob",), ("Charlie",), ("David",)]
            df = spark.createDataFrame(data, ["name"])
            df.write.mode("overwrite").saveAsTable("like_test")
            result = spark.sql("SELECT * FROM like_test WHERE name LIKE 'A%'")
            rows = result.collect()
            assert len(rows) == 1
            assert rows[0]["name"] == "Alice"
        finally:
            self._drop_tables(spark, "like_test")

    def test_sql_with_in_clause(self, spark):
        """IN keeps only rows whose value appears in the literal list."""
        try:
            data = [("Alice", 25), ("Bob", 30), ("Charlie", 35)]
            df = spark.createDataFrame(data, ["name", "age"])
            df.write.mode("overwrite").saveAsTable("in_test")
            result = spark.sql("SELECT * FROM in_test WHERE age IN (25, 35)")
            rows = result.collect()
            assert len(rows) == 2
            names = {row["name"] for row in rows}
            assert "Alice" in names
            assert "Charlie" in names
        finally:
            self._drop_tables(spark, "in_test")