# robin-sparkless 4.0.0
#
# PySpark-like DataFrame API in Rust on Polars; no JVM.
# Documentation
"""Tests for #377: WindowSpec.rowsBetween() and rangeBetween() (PySpark parity). Uses shared spark fixture and get_spark_imports()."""

from __future__ import annotations

from tests.fixtures.spark_imports import get_spark_imports

_imports = get_spark_imports()
# Module-level aliases for the functions module and the Window class, used by
# every test below.
F, Window = _imports.F, _imports.Window


def test_window_constants() -> None:
    """Window frame-boundary constants have their PySpark values.

    unboundedPreceding / unboundedFollowing are the 64-bit signed integer
    minimum / maximum, and currentRow is 0.
    """
    # Plain attribute access: a missing constant still fails (AttributeError),
    # without the indirection of getattr() with a literal name.
    assert Window.unboundedPreceding == -(2**63)
    assert Window.currentRow == 0
    assert Window.unboundedFollowing == 2**63 - 1


def test_rows_between_chaining(spark) -> None:
    """Window.partitionBy().orderBy().rowsBetween(start, end) returns a window usable with row_number().over().

    Row numbers must restart per partition: (a,10)->1, (a,20)->2, (b,15)->1.
    """
    df = spark.createDataFrame(
        [
            {"dept": "a", "salary": 10},
            {"dept": "a", "salary": 20},
            {"dept": "b", "salary": 15},
        ],
        schema="dept string, salary int",
    )
    win = (
        Window.partitionBy("dept")
        .orderBy("salary")
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)
    )
    out = df.withColumn("rn", F.row_number().over(win))
    rows = out.collect()
    assert len(rows) == 3
    # collect() order is not guaranteed, so key each row number by its
    # (dept, salary) pair; this checks the exact per-partition ranks rather
    # than merely that the values 1 and 2 appear somewhere.
    rn_by_key = {(r["dept"], r["salary"]): r["rn"] for r in rows}
    assert rn_by_key == {("a", 10): 1, ("a", 20): 2, ("b", 15): 1}


def test_range_between_chaining(spark) -> None:
    """Window.partitionBy().orderBy().rangeBetween(start, end) returns a usable window.

    row_number() requires a ROWS frame in PySpark, so it is evaluated over a
    rowsBetween() window. The rangeBetween() window is exercised with a
    running sum, which is valid over a RANGE frame.
    """
    df = spark.createDataFrame(
        [{"g": 1, "v": 1}, {"g": 1, "v": 2}, {"g": 1, "v": 3}],
        schema="g int, v int",
    )
    rows_win = (
        Window.partitionBy("g")
        .orderBy("v")
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)
    )
    range_win = (
        Window.partitionBy("g")
        .orderBy("v")
        .rangeBetween(Window.unboundedPreceding, Window.currentRow)
    )
    out = df.withColumn("rn", F.row_number().over(rows_win)).withColumn(
        "running_total", F.sum("v").over(range_win)
    )
    rows = out.collect()
    assert len(rows) == 3
    # collect() order is not guaranteed; sort by the ordering column first.
    ordered = sorted(rows, key=lambda r: r["v"])
    assert [r["rn"] for r in ordered] == [1, 2, 3]
    # A RANGE frame from unboundedPreceding to currentRow over the distinct
    # v values 1, 2, 3 yields the cumulative sums 1, 1+2, 1+2+3.
    assert [r["running_total"] for r in ordered] == [1, 3, 6]