from __future__ import annotations
from sparkless.testing import get_imports
from tests.utils import assert_rows_equal
# Obtain the column-functions namespace ``F`` through sparkless's test helper
# instead of a direct import. NOTE(review): presumably get_imports() picks the
# backend (sparkless vs. real PySpark) the suite runs against — confirm in
# sparkless.testing.
_imports = get_imports()
F = _imports.F
# Column order shared by every employee fixture below.
SCHEMA_EMPLOYEES = [
    "id",
    "name",
    "age",
    "salary",
    "department",
]

# Employee fixture rows keyed by column name.
INPUT_EMPLOYEES = [
    {"id": 1, "name": "Alice", "age": 25, "salary": 50000, "department": "IT"},
    {"id": 2, "name": "Bob", "age": 30, "salary": 60000, "department": "HR"},
    {"id": 3, "name": "Charlie", "age": 35, "salary": 70000, "department": "IT"},
    {"id": 4, "name": "David", "age": 40, "salary": 80000, "department": "Finance"},
]

# The same rows as positional tuples in SCHEMA_EMPLOYEES order. They are derived
# from the dict form so the two fixtures cannot drift apart.
INPUT_EMPLOYEES_TUPLES = [
    tuple(employee[column] for column in SCHEMA_EMPLOYEES)
    for employee in INPUT_EMPLOYEES
]
def test_filter_salary_gt_60000(spark) -> None:
    """Keep only employees whose salary is strictly greater than 60000."""
    employees = spark.createDataFrame(INPUT_EMPLOYEES_TUPLES, SCHEMA_EMPLOYEES)
    above_threshold = F.col("salary") > F.lit(60000)
    actual = [row.asDict() for row in employees.filter(above_threshold).collect()]
    assert_rows_equal(
        actual,
        [
            {"age": 35, "department": "IT", "id": 3, "name": "Charlie", "salary": 70000},
            {"age": 40, "department": "Finance", "id": 4, "name": "David", "salary": 80000},
        ],
        order_matters=False,
    )
def test_filter_and_operator(spark) -> None:
    """Combining predicates with ``&`` keeps only rows that satisfy both."""
    frame = spark.createDataFrame(
        [{"a": 1, "b": 2}, {"a": 2, "b": 3}, {"a": 3, "b": 1}],
        ["a", "b"],
    )
    a_above_one = F.col("a") > F.lit(1)
    b_above_one = F.col("b") > F.lit(1)
    matched = frame.filter(a_above_one & b_above_one).collect()
    # Only (a=2, b=3) clears both thresholds.
    assert len(matched) == 1
    only = matched[0]
    assert only["a"] == 2
    assert only["b"] == 3
def test_filter_or_operator(spark) -> None:
    """Combining predicates with ``|`` keeps rows that satisfy either side.

    The input includes a row (a=1, b=1) that fails both predicates. Without
    it, every row would match and a filter that silently did nothing would
    still pass.
    """
    data = [
        {"a": 1, "b": 2},  # only b > 1
        {"a": 2, "b": 3},  # both sides true
        {"a": 3, "b": 1},  # only a > 1
        {"a": 1, "b": 1},  # neither side -> must be filtered out
    ]
    schema = ["a", "b"]
    df = spark.createDataFrame(data, schema)
    result = df.filter((F.col("a") > F.lit(1)) | (F.col("b") > F.lit(1)))
    rows = result.collect()
    assert len(rows) == 3
    # Compare as sorted pairs so the assertion does not depend on row order.
    assert sorted((r["a"], r["b"]) for r in rows) == [(1, 2), (2, 3), (3, 1)]
def test_basic_select(spark) -> None:
    """Selecting a subset of columns by name preserves row order and values."""
    columns = ("id", "name", "age")
    projected = spark.createDataFrame(INPUT_EMPLOYEES_TUPLES, SCHEMA_EMPLOYEES).select(*columns)
    actual = [row.asDict() for row in projected.collect()]
    expected = [
        dict(zip(columns, values))
        for values in [
            (1, "Alice", 25),
            (2, "Bob", 30),
            (3, "Charlie", 35),
            (4, "David", 40),
        ]
    ]
    assert_rows_equal(actual, expected, order_matters=True)
def test_select_with_alias(spark) -> None:
    """Column aliases rename the output fields without altering values."""
    source = spark.createDataFrame(INPUT_EMPLOYEES_TUPLES, SCHEMA_EMPLOYEES)
    renamed = source.select(
        F.col("id").alias("user_id"),
        F.col("name").alias("full_name"),
    )
    actual = [row.asDict() for row in renamed.collect()]
    expected = [
        {"user_id": user_id, "full_name": full_name}
        for user_id, full_name in [
            (1, "Alice"),
            (2, "Bob"),
            (3, "Charlie"),
            (4, "David"),
        ]
    ]
    assert_rows_equal(actual, expected, order_matters=True)
def test_aggregation_avg_count(spark) -> None:
    """groupBy + agg computes the per-department average salary and headcount."""
    employees = spark.createDataFrame(INPUT_EMPLOYEES_TUPLES, SCHEMA_EMPLOYEES)
    aggregations = (
        F.avg("salary").alias("avg_salary"),
        F.count("id").alias("count"),
    )
    summary = employees.groupBy("department").agg(*aggregations)
    actual = [row.asDict() for row in summary.collect()]
    # IT has two employees (50000, 70000), so its average is 60000.0.
    expected = [
        {"department": department, "avg_salary": avg_salary, "count": headcount}
        for department, avg_salary, headcount in [
            ("Finance", 80000.0, 1),
            ("HR", 60000.0, 1),
            ("IT", 60000.0, 2),
        ]
    ]
    assert_rows_equal(actual, expected, order_matters=False)
def test_inner_join(spark) -> None:
    """Inner join on ``dept_id`` keeps only employees with a matching department.

    David (dept 30) has no department row, and department 40 has no employees,
    so both must be dropped. Columns from the right-hand frame (``location``)
    must be carried onto each matched row.
    """
    employees_data = [
        (1, "Alice", 10, 50000),
        (2, "Bob", 20, 60000),
        (3, "Charlie", 10, 70000),
        (4, "David", 30, 55000),
    ]
    departments_data = [
        (10, "IT", "NYC"),
        (20, "HR", "LA"),
        (40, "Finance", "Chicago"),
    ]
    emp_schema = [
        "id",
        "name",
        "dept_id",
        "salary",
    ]
    dept_schema = ["dept_id", "name", "location"]
    emp_df = spark.createDataFrame(employees_data, emp_schema)
    dept_df = spark.createDataFrame(departments_data, dept_schema)
    result = emp_df.join(dept_df, on="dept_id", how="inner")
    rows = result.collect()
    assert len(rows) == 3
    by_id = {r["id"]: r for r in rows}
    assert set(by_id) == {1, 2, 3}
    # Both inputs have a "name" column, so the joined rows carry it twice.
    # Avoid reading it here; check only unambiguous columns.
    assert by_id[1]["dept_id"] == 10
    assert by_id[1]["salary"] == 50000
    assert by_id[1]["location"] == "NYC"
    assert by_id[2]["dept_id"] == 20
    assert by_id[2]["salary"] == 60000
    assert by_id[2]["location"] == "LA"
    assert by_id[3]["dept_id"] == 10
    assert by_id[3]["salary"] == 70000
    assert by_id[3]["location"] == "NYC"