from __future__ import annotations
from tests.fixtures.spark_imports import get_spark_imports
# Resolve the function namespace (``F``) and the ``Window`` class through the
# shared test fixture so the tests in this module use the same objects.
_imports = get_spark_imports()
F, Window = _imports.F, _imports.Window
def test_window_constants() -> None:
    """Check the three frame-boundary sentinels exposed on ``Window``.

    The unbounded bounds are expected to equal the minimum and maximum of a
    signed 64-bit integer, and the current-row marker is expected to be zero.
    """
    int64_min = -(2**63)
    int64_max = 2**63 - 1

    assert Window.unboundedPreceding == int64_min
    assert Window.currentRow == 0
    assert Window.unboundedFollowing == int64_max
def test_rows_between_chaining(spark) -> None:
    """Chain ``partitionBy`` -> ``orderBy`` -> ``rowsBetween`` and use the result.

    Builds a three-row frame with two departments, applies ``row_number()``
    over a window partitioned by ``dept`` and ordered by ``salary`` with a
    ROWS frame from unbounded preceding to the current row, and verifies the
    exact row number assigned to every input row.

    Args:
        spark: Session object supplied by the test suite's ``spark`` fixture.
    """
    df = spark.createDataFrame(
        [
            {"dept": "a", "salary": 10},
            {"dept": "a", "salary": 20},
            {"dept": "b", "salary": 15},
        ],
        schema="dept string, salary int",
    )
    win = (
        Window.partitionBy("dept")
        .orderBy("salary")
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)
    )

    rows = df.withColumn("rn", F.row_number().over(win)).collect()

    assert len(rows) == 3
    # Key by the input row so the check does not depend on collect() order and
    # pins the numbering per partition (numbering restarts at 1 for dept "b").
    rn_by_row = {(r["dept"], r["salary"]): r["rn"] for r in rows}
    assert rn_by_row == {("a", 10): 1, ("a", 20): 2, ("b", 15): 1}
def test_range_between_chaining(spark) -> None:
    """Chain ``partitionBy`` -> ``orderBy`` -> ``rangeBetween`` and use the result.

    Applies a running ``sum`` of ``v`` over a RANGE frame from unbounded
    preceding to the current row within a single partition, and verifies the
    cumulative total attached to every input row.

    An aggregate is used rather than ``row_number()`` because Spark only
    accepts ``row_number()`` with its required ROWS frame, so it cannot be
    combined with ``rangeBetween``.

    Args:
        spark: Session object supplied by the test suite's ``spark`` fixture.
    """
    df = spark.createDataFrame(
        [{"g": 1, "v": 1}, {"g": 1, "v": 2}, {"g": 1, "v": 3}],
        schema="g int, v int",
    )
    win = (
        Window.partitionBy("g")
        .orderBy("v")
        # This test is about the range-based frame; the row-based frame is
        # covered by test_rows_between_chaining.
        .rangeBetween(Window.unboundedPreceding, Window.currentRow)
    )

    rows = df.withColumn("running_sum", F.sum("v").over(win)).collect()

    assert len(rows) == 3
    # Key by ``v`` so the check does not depend on collect() order.
    sum_by_v = {r["v"]: r["running_sum"] for r in rows}
    assert sum_by_v == {1: 1, 2: 3, 3: 6}