def test_cte_with_join(spark) -> None:
    """Check that an inner join pairs every employee with its department name.

    ``employees`` and ``departments`` are written with ``saveAsTable`` and
    joined through ``spark.sql``; both tables are dropped in ``finally``.
    Output columns are accepted under either their bare name (``name``) or an
    alias-prefixed name (``e_name``).
    """
    try:
        spark.createDataFrame(
            [(1, "Alice", 10), (2, "Bob", 20), (3, "Carol", 10)],
            ["id", "name", "dept_id"],
        ).write.mode("overwrite").saveAsTable("employees")
        spark.createDataFrame(
            [(10, "Engineering"), (20, "Sales")],
            ["id", "dept_name"],
        ).write.mode("overwrite").saveAsTable("departments")

        joined = spark.sql(
            "SELECT e.id, e.name, d.dept_name FROM employees e "
            "JOIN departments d ON e.dept_id = d.id"
        )
        fetched = joined.collect()
        assert len(fetched) == 3

        # Each selected column must be present under one of its two spellings.
        for prefixed, bare in (
            ("e_id", "id"),
            ("e_name", "name"),
            ("d_dept_name", "dept_name"),
        ):
            assert prefixed in joined.columns or bare in joined.columns

        def _pick(row, *candidates):
            # Read the first candidate that is an actual output column.
            key = next((c for c in candidates if c in joined.columns), None)
            return None if key is None else row[key]

        dept_of = {}
        for row in fetched:
            dept_of[_pick(row, "e_name", "name")] = _pick(
                row, "d_dept_name", "dept_name"
            )

        for person, dept in (
            ("Alice", "Engineering"),
            ("Bob", "Sales"),
            ("Carol", "Engineering"),
        ):
            assert dept_of[person] == dept
    finally:
        for statement in (
            "DROP TABLE IF EXISTS employees",
            "DROP TABLE IF EXISTS departments",
        ):
            spark.sql(statement)
def test_cte_with_multiple_joins(spark) -> None:
    """Check a query that chains two joins across three tables.

    ``employees`` is joined to ``departments`` and then to ``projects``. The
    test verifies the row count and that Alice resolves to "Engineering" and
    "ProjectA". All three tables are dropped in ``finally``. Output columns
    may be bare (``name``) or alias-prefixed (``e_name``).
    """
    try:
        spark.createDataFrame(
            [(1, "Alice", 10, 100), (2, "Bob", 20, 200)],
            ["id", "name", "dept_id", "project_id"],
        ).write.mode("overwrite").saveAsTable("employees")
        spark.createDataFrame(
            [(10, "Engineering"), (20, "Sales")],
            ["id", "dept_name"],
        ).write.mode("overwrite").saveAsTable("departments")
        spark.createDataFrame(
            [(100, "ProjectA"), (200, "ProjectB")],
            ["id", "project_name"],
        ).write.mode("overwrite").saveAsTable("projects")

        query = (
            "\n"
            "SELECT e.name, d.dept_name, p.project_name\n"
            "FROM employees e\n"
            "JOIN departments d ON e.dept_id = d.id\n"
            "JOIN projects p ON e.project_id = p.id\n"
        )
        joined = spark.sql(query)
        fetched = joined.collect()
        assert len(fetched) == 2

        def _pick(row, *candidates):
            # Read the first candidate that is an actual output column.
            key = next((c for c in candidates if c in joined.columns), None)
            return None if key is None else row[key]

        alice_rows = [
            row for row in fetched if _pick(row, "e_name", "name") == "Alice"
        ]
        alice = alice_rows[0]
        assert _pick(alice, "d_dept_name", "dept_name") == "Engineering"
        assert _pick(alice, "p_project_name", "project_name") == "ProjectA"
    finally:
        for statement in (
            "DROP TABLE IF EXISTS employees",
            "DROP TABLE IF EXISTS departments",
            "DROP TABLE IF EXISTS projects",
        ):
            spark.sql(statement)
def test_cte_with_left_join(spark) -> None:
    """Check that a LEFT JOIN keeps an employee with no matching department.

    Bob references department 30, which is not in ``departments``, so his row
    must still be returned with a null department name. Both tables are
    dropped in ``finally``. Output columns may be bare (``name``) or
    alias-prefixed (``e_name``).
    """
    try:
        spark.createDataFrame(
            [(1, "Alice", 10), (2, "Bob", 30)],
            ["id", "name", "dept_id"],
        ).write.mode("overwrite").saveAsTable("employees")
        spark.createDataFrame(
            [(10, "Engineering")],
            ["id", "dept_name"],
        ).write.mode("overwrite").saveAsTable("departments")

        joined = spark.sql(
            "SELECT e.name, d.dept_name FROM employees e "
            "LEFT JOIN departments d ON e.dept_id = d.id"
        )
        fetched = joined.collect()
        assert len(fetched) == 2

        def _pick(row, *candidates):
            # Read the first candidate that is an actual output column.
            key = next((c for c in candidates if c in joined.columns), None)
            return None if key is None else row[key]

        bob_rows = [row for row in fetched if _pick(row, "e_name", "name") == "Bob"]
        bob = bob_rows[0]
        assert _pick(bob, "d_dept_name", "dept_name") is None
    finally:
        for statement in (
            "DROP TABLE IF EXISTS employees",
            "DROP TABLE IF EXISTS departments",
        ):
            spark.sql(statement)
def test_cte_with_where_clause(spark) -> None:
    """Verify that a WHERE filter on a joined query keeps only matching rows.

    Of the three employees, only Bob (60000) and Carol (70000) earn more than
    55000, so exactly those two must be returned and Alice (50000) must be
    filtered out. Both tables are dropped in ``finally``. Output columns may
    be bare (``name``) or alias-prefixed (``e_name``).
    """
    try:
        employees = spark.createDataFrame(
            [(1, "Alice", 10, 50000), (2, "Bob", 20, 60000), (3, "Carol", 10, 70000)],
            ["id", "name", "dept_id", "salary"],
        )
        employees.write.mode("overwrite").saveAsTable("employees")
        departments = spark.createDataFrame(
            [(10, "Engineering"), (20, "Sales")],
            ["id", "dept_name"],
        )
        departments.write.mode("overwrite").saveAsTable("departments")

        result = spark.sql(
            "\n"
            "SELECT e.name, d.dept_name, e.salary\n"
            "FROM employees e\n"
            "JOIN departments d ON e.dept_id = d.id\n"
            "WHERE e.salary > 55000\n"
        )
        rows = result.collect()
        # Exactly two employees clear the threshold. A lower bound alone would
        # let an ignored WHERE clause (three rows) pass unnoticed.
        assert len(rows) == 2

        def _val(r, *keys):
            # Read the first key that is an actual output column.
            for k in keys:
                if k in result.columns:
                    return r[k]
            return None

        names = {_val(r, "e_name", "name") for r in rows}
        # Set equality also proves that Alice was filtered out.
        assert names == {"Bob", "Carol"}
    finally:
        spark.sql("DROP TABLE IF EXISTS employees")
        spark.sql("DROP TABLE IF EXISTS departments")
def test_cte_with_aggregation_after_join(spark) -> None:
    """Check ``COUNT(*) ... GROUP BY`` applied on top of a join.

    Four employees are split evenly across two departments, so the grouped
    result must contain two rows, each with ``emp_count`` equal to 2. Both
    tables are dropped in ``finally``. The department column may be bare
    (``dept_name``) or alias-prefixed (``d_dept_name``).
    """
    try:
        spark.createDataFrame(
            [(1, "Alice", 10), (2, "Bob", 20), (3, "Carol", 10), (4, "Dave", 20)],
            ["id", "name", "dept_id"],
        ).write.mode("overwrite").saveAsTable("employees")
        spark.createDataFrame(
            [(10, "Engineering"), (20, "Sales")],
            ["id", "dept_name"],
        ).write.mode("overwrite").saveAsTable("departments")

        query = (
            "\n"
            "SELECT d.dept_name, COUNT(*) as emp_count\n"
            "FROM employees e\n"
            "JOIN departments d ON e.dept_id = d.id\n"
            "GROUP BY d.dept_name\n"
        )
        grouped = spark.sql(query)
        fetched = grouped.collect()
        assert len(fetched) == 2

        def _pick(row, *candidates):
            # Read the first candidate that is an actual output column.
            key = next((c for c in candidates if c in grouped.columns), None)
            return None if key is None else row[key]

        headcount = {}
        for row in fetched:
            headcount[_pick(row, "d_dept_name", "dept_name")] = _pick(
                row, "emp_count"
            )

        for dept in ("Engineering", "Sales"):
            assert headcount[dept] == 2
    finally:
        for statement in (
            "DROP TABLE IF EXISTS employees",
            "DROP TABLE IF EXISTS departments",
        ):
            spark.sql(statement)
def test_cte_with_self_join(spark) -> None:
    """Verify a LEFT self-join that resolves each employee's manager name.

    Bob and Carol both have ``manager_id`` 1 (Alice), while Alice has no
    manager, so the result must map Bob and Carol to "Alice" and Alice to
    null. The table is dropped in ``finally``. Output columns are read under
    their SQL aliases (``employee``, ``manager``), falling back to the
    alias-prefixed names (``e_name``, ``m_name``).
    """
    try:
        employees = spark.createDataFrame(
            [(1, "Alice", None), (2, "Bob", 1), (3, "Carol", 1)],
            ["id", "name", "manager_id"],
        )
        employees.write.mode("overwrite").saveAsTable("employees")

        result = spark.sql(
            "\n"
            "SELECT e.name as employee, m.name as manager\n"
            "FROM employees e\n"
            "LEFT JOIN employees m ON e.manager_id = m.id\n"
        )
        rows = result.collect()
        assert len(rows) == 3

        def _val(r, *keys):
            # Read the first key that is an actual output column.
            for k in keys:
                if k in result.columns:
                    return r[k]
            return None

        manager_of = {
            _val(r, "employee", "e_name"): _val(r, "manager", "m_name") for r in rows
        }
        # Pin the full mapping: both reports resolve to Alice, and the LEFT
        # JOIN keeps Alice herself with a null manager.
        assert manager_of == {"Alice": None, "Bob": "Alice", "Carol": "Alice"}
    finally:
        spark.sql("DROP TABLE IF EXISTS employees")