from tests.fixtures.spark_imports import get_spark_imports
def test_groupby_agg_accepts_aggregate_function_objects(spark) -> None:
    """Check that ``groupBy(...).agg(...)`` accepts ``F.first`` / ``F.last`` objects.

    Builds a three-row DataFrame covering two departments, groups it by
    ``dept`` and aggregates ``salary`` with ``F.first`` and ``F.last``. The
    result must expose one column per aggregate and one row per department.

    Args:
        spark: Session fixture used to create the input DataFrame.
    """
    F = get_spark_imports().F

    df = spark.createDataFrame(
        [
            {"dept": "IT", "salary": 100},
            {"dept": "IT", "salary": 200},
            {"dept": "HR", "salary": 150},
        ]
    )

    # No try/except around the body: AssertionError derives from Exception,
    # so a blanket ``except Exception: pass`` would make this test pass
    # unconditionally and hide real regressions.
    result = df.groupBy("dept").agg(
        F.first("salary"),
        F.last("salary"),
    )

    # Accept either the "func(column)" name or the bare function name for
    # each aggregate column.
    columns = result.columns
    assert "first(salary)" in columns or "first" in columns
    assert "last(salary)" in columns or "last" in columns

    # Two distinct departments ("IT", "HR") -> two grouped rows.
    assert result.count() == 2