import pytest
from tests.fixtures.spark_imports import get_spark_imports
# Resolve the Spark-compatible API once through the shared test fixture, then
# re-export each entry point under the short name the tests in this module use.
_imports = get_spark_imports()

# Session entry point, column functions namespace, and window-spec builder.
SparkSession, F, Window = _imports.SparkSession, _imports.F, _imports.Window

# Schema and data-type names exposed by the fixture.
StructType, StructField = _imports.StructType, _imports.StructField
StringType, IntegerType, LongType = (
    _imports.StringType,
    _imports.IntegerType,
    _imports.LongType,
)
DoubleType, BooleanType = _imports.DoubleType, _imports.BooleanType
class TestIssue336WindowFunctionComparison:
def test_window_function_gt_comparison(self):
    """Check that ``row_number().over(w) > 0`` works as an ``F.when`` condition."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        people = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A"},
                {"Name": "Bob", "Type": "B"},
            ]
        )
        by_type = Window().partitionBy("Type").orderBy("Type")
        # Name the predicate so the comparison under test is explicit.
        is_positive = F.row_number().over(by_type) > 0
        flagged = people.withColumn(
            "GT-Zero", F.when(is_positive, F.lit(True)).otherwise(F.lit(False))
        )
        collected = flagged.collect()
        assert len(collected) == 2
        # Both rows are expected to satisfy the comparison.
        for position in (0, 1):
            assert collected[position]["GT-Zero"] is True
    finally:
        session.stop()
def test_window_function_lt_comparison(self):
    """Check that ``row_number().over(w) < 5`` works as an ``F.when`` condition."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        people = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A"},
                {"Name": "Bob", "Type": "B"},
            ]
        )
        by_type = Window().partitionBy("Type").orderBy("Type")
        below_five = F.row_number().over(by_type) < 5
        flagged = people.withColumn(
            "LT-Five", F.when(below_five, F.lit(True)).otherwise(F.lit(False))
        )
        collected = flagged.collect()
        assert len(collected) == 2
        # Both rows are expected to satisfy the comparison.
        for position in (0, 1):
            assert collected[position]["LT-Five"] is True
    finally:
        session.stop()
def test_window_function_ge_comparison(self):
    """Check that ``row_number().over(w) >= 1`` works as an ``F.when`` condition."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        people = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A"},
                {"Name": "Bob", "Type": "B"},
            ]
        )
        by_type = Window().partitionBy("Type").orderBy("Type")
        at_least_one = F.row_number().over(by_type) >= 1
        flagged = people.withColumn(
            "GE-One", F.when(at_least_one, F.lit(True)).otherwise(F.lit(False))
        )
        collected = flagged.collect()
        assert len(collected) == 2
        # Both rows are expected to satisfy the comparison.
        for position in (0, 1):
            assert collected[position]["GE-One"] is True
    finally:
        session.stop()
def test_window_function_le_comparison(self):
    """Check that ``row_number().over(w) <= 1`` works as an ``F.when`` condition."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        people = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A"},
                {"Name": "Bob", "Type": "B"},
            ]
        )
        by_type = Window().partitionBy("Type").orderBy("Type")
        at_most_one = F.row_number().over(by_type) <= 1
        flagged = people.withColumn(
            "LE-One", F.when(at_most_one, F.lit(True)).otherwise(F.lit(False))
        )
        collected = flagged.collect()
        assert len(collected) == 2
        # Each Type value occurs once, so every row is the first in its partition.
        for position in (0, 1):
            assert collected[position]["LE-One"] is True
    finally:
        session.stop()
def test_window_function_eq_comparison(self):
    """Check that ``row_number().over(w) == 1`` works as an ``F.when`` condition."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        people = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A"},
                {"Name": "Bob", "Type": "B"},
            ]
        )
        by_type = Window().partitionBy("Type").orderBy("Type")
        is_first = F.row_number().over(by_type) == 1
        labelled = people.withColumn(
            "EQ-One", F.when(is_first, F.lit("First")).otherwise(F.lit("Other"))
        )
        collected = labelled.collect()
        assert len(collected) == 2
        # Each Type value occurs once, so every row is the first in its partition.
        for position in (0, 1):
            assert collected[position]["EQ-One"] == "First"
    finally:
        session.stop()
def test_window_function_ne_comparison(self):
    """Check that ``row_number().over(w) != 0`` works as an ``F.when`` condition."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        people = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A"},
                {"Name": "Bob", "Type": "B"},
            ]
        )
        by_type = Window().partitionBy("Type").orderBy("Type")
        is_nonzero = F.row_number().over(by_type) != 0
        flagged = people.withColumn(
            "NE-Zero", F.when(is_nonzero, F.lit(True)).otherwise(F.lit(False))
        )
        collected = flagged.collect()
        assert len(collected) == 2
        # Both rows are expected to satisfy the comparison.
        for position in (0, 1):
            assert collected[position]["NE-Zero"] is True
    finally:
        session.stop()
def test_window_function_comparison_with_filter(self):
    """Expect an error when a window comparison is used directly in ``filter``."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        # Every step of the pipeline runs inside the raises-context, so the
        # error may surface at filter construction or at collect time.
        with pytest.raises(Exception) as exc_info:
            top_only = scores.filter(F.row_number().over(by_score) == 1)
            top_only.select("Name", "Type", "Score").collect()
        error_text = str(exc_info.value)
        assert "window functions inside WHERE clause" in error_text
    finally:
        session.stop()
def test_window_function_comparison_with_multiple_conditions(self):
    """Check a two-branch ``when`` chain where each branch compares a window column."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        people = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A"},
                {"Name": "Bob", "Type": "B"},
                {"Name": "Charlie", "Type": "C"},
            ]
        )
        by_type = Window().partitionBy("Type").orderBy("Type")

        def position():
            # Build a fresh window expression for each branch of the chain.
            return F.row_number().over(by_type)

        category = (
            F.when(position() == 1, F.lit("First"))
            .when(position() == 2, F.lit("Second"))
            .otherwise(F.lit("Other"))
        )
        collected = people.withColumn("Category", category).collect()
        assert len(collected) == 3
        # Each Type value occurs once, so every row is the first in its partition.
        for index in range(3):
            assert collected[index]["Category"] == "First"
    finally:
        session.stop()
def test_window_function_comparison_with_rank(self):
    """Check that ``rank().over(w) <= 2`` works as an ``F.when`` condition."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 100},
                {"Name": "Charlie", "Type": "A", "Score": 90},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        in_top_two = F.rank().over(by_score) <= 2
        flagged = scores.withColumn(
            "TopRank", F.when(in_top_two, F.lit(True)).otherwise(F.lit(False))
        )
        collected = flagged.collect()
        assert len(collected) == 3
        # Only the first two collected rows are checked.
        for position in (0, 1):
            assert collected[position]["TopRank"] is True
    finally:
        session.stop()
def test_window_function_comparison_with_dense_rank(self):
    """Check that ``dense_rank().over(w) == 1`` flags both rows tied on the top score."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 100},
                {"Name": "Charlie", "Type": "A", "Score": 90},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        is_top = F.dense_rank().over(by_score) == 1
        flagged = scores.withColumn(
            "TopDenseRank", F.when(is_top, F.lit(True)).otherwise(F.lit(False))
        )
        collected = flagged.collect()
        assert len(collected) == 3
        # Two rows share the highest score, so two flags are expected.
        flagged_rows = [row for row in collected if row["TopDenseRank"] is True]
        assert len(flagged_rows) == 2
    finally:
        session.stop()
def test_window_function_comparison_with_percent_rank(self):
    """Check that ``percent_rank().over(w) == 0.0`` flags the top-scoring row."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        at_zero = F.percent_rank().over(by_score) == 0.0
        flagged = scores.withColumn(
            "TopPercent", F.when(at_zero, F.lit(True)).otherwise(F.lit(False))
        )
        collected = flagged.collect()
        assert len(collected) == 3
        # Alice has the highest score and is the only row inspected.
        alice = next(row for row in collected if row["Name"] == "Alice")
        assert alice["TopPercent"] is True
    finally:
        session.stop()
def test_window_function_comparison_with_lag(self):
    """Check ``isnull(lag(...).over(w))`` as an ``F.when`` condition."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        no_previous = F.isnull(F.lag("Score", 1).over(by_score))
        flagged = scores.withColumn(
            "HasPrevious",
            F.when(no_previous, F.lit(False)).otherwise(F.lit(True)),
        )
        collected = flagged.collect()
        assert len(collected) == 3
        # Alice is first in descending score order, so she has no predecessor.
        alice = next(row for row in collected if row["Name"] == "Alice")
        assert alice["HasPrevious"] is False
        followers = [row for row in collected if row["Name"] != "Alice"]
        for follower in followers:
            assert follower["HasPrevious"] is True
    finally:
        session.stop()
def test_window_function_comparison_with_lead(self):
    """Check ``isnull(lead(...).over(w))`` as an ``F.when`` condition."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        no_next = F.isnull(F.lead("Score", 1).over(by_score))
        flagged = scores.withColumn(
            "HasNext", F.when(no_next, F.lit(False)).otherwise(F.lit(True))
        )
        collected = flagged.collect()
        assert len(collected) == 3
        # Charlie is last in descending score order, so he has no successor.
        charlie = next(row for row in collected if row["Name"] == "Charlie")
        assert charlie["HasNext"] is False
        for row in collected:
            if row["Name"] == "Charlie":
                continue
            assert row["HasNext"] is True
    finally:
        session.stop()
def test_window_function_comparison_with_sum(self):
    """Check that ``sum(...).over(w) > 150`` works as an ``F.when`` condition."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        exceeds_limit = F.sum("Score").over(by_score) > 150
        flagged = scores.withColumn(
            "HighRunningSum",
            F.when(exceeds_limit, F.lit(True)).otherwise(F.lit(False)),
        )
        collected = flagged.collect()
        assert len(collected) == 3
        # The running total is 100 at Alice and passes 150 from Bob onward.
        alice = next(row for row in collected if row["Name"] == "Alice")
        assert alice["HighRunningSum"] is False
        for row in collected:
            if row["Name"] == "Alice":
                continue
            assert row["HighRunningSum"] is True
    finally:
        session.stop()
def test_window_function_comparison_direct_filter(self):
    """Expect an error when ``filter`` is given a window comparison directly."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        # The whole pipeline stays inside the raises-context so the error is
        # caught whichever step raises it.
        with pytest.raises(Exception) as exc_info:
            filtered = scores.filter(F.row_number().over(by_score) == 1)
            projected = filtered.select("Name", "Type", "Score")
            projected.collect()
        error_text = str(exc_info.value)
        assert "window functions inside WHERE clause" in error_text
    finally:
        session.stop()
def test_window_function_comparison_with_eqNullSafe(self):
    """Check that ``row_number().over(w).eqNullSafe(1)`` works as an ``F.when`` condition."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        people = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A"},
                {"Name": "Bob", "Type": "B"},
            ]
        )
        by_type = Window().partitionBy("Type").orderBy("Type")
        is_first = F.row_number().over(by_type).eqNullSafe(1)
        labelled = people.withColumn(
            "EQ-One-NullSafe",
            F.when(is_first, F.lit("First")).otherwise(F.lit("Other")),
        )
        collected = labelled.collect()
        assert len(collected) == 2
        # Each Type value occurs once, so every row is the first in its partition.
        for position in (0, 1):
            assert collected[position]["EQ-One-NullSafe"] == "First"
    finally:
        session.stop()
def test_window_function_comparison_with_isnotnull(self):
    """Check ``isnotnull(lead(...).over(w))`` as an ``F.when`` condition."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        has_successor = F.isnotnull(F.lead("Score", 1).over(by_score))
        flagged = scores.withColumn(
            "HasNext",
            F.when(has_successor, F.lit(True)).otherwise(F.lit(False)),
        )
        collected = flagged.collect()
        assert len(collected) == 3
        # Charlie is last in descending score order, so he has no successor.
        charlie = next(row for row in collected if row["Name"] == "Charlie")
        assert charlie["HasNext"] is False
        earlier_rows = [row for row in collected if row["Name"] != "Charlie"]
        for earlier in earlier_rows:
            assert earlier["HasNext"] is True
    finally:
        session.stop()
def test_window_function_comparison_with_null_values(self):
    """Check a ``lag``-based window condition when one Score is ``None``."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": None},
                {"Name": "Charlie", "Type": "A", "Score": 80},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        no_previous = F.isnull(F.lag("Score", 1).over(by_score))
        labelled = scores.withColumn(
            "HasPrevious",
            F.when(no_previous, F.lit("NoPrev")).otherwise(F.lit("HasPrev")),
        )
        collected = labelled.collect()
        assert len(collected) == 3
        # Only Alice, the highest non-null score, is inspected.
        alice = next(row for row in collected if row["Name"] == "Alice")
        assert alice["HasPrevious"] == "NoPrev"
    finally:
        session.stop()
def test_window_function_comparison_with_empty_dataframe(self):
    """Check that a window comparison on an empty DataFrame collects zero rows."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        # An explicit schema is supplied because there are no rows to infer from.
        empty = session.createDataFrame([], schema="Name string, Type string")
        by_type = Window().partitionBy("Type").orderBy("Type")
        is_positive = F.row_number().over(by_type) > 0
        flagged = empty.withColumn(
            "GT-Zero", F.when(is_positive, F.lit(True)).otherwise(F.lit(False))
        )
        collected = flagged.collect()
        assert len(collected) == 0
    finally:
        session.stop()
def test_window_function_comparison_with_single_row(self):
    """Check a window comparison on a one-row DataFrame."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        single = session.createDataFrame([{"Name": "Alice", "Type": "A"}])
        by_type = Window().partitionBy("Type").orderBy("Type")
        is_first = F.row_number().over(by_type) == 1
        labelled = single.withColumn(
            "EQ-One", F.when(is_first, F.lit("First")).otherwise(F.lit("Other"))
        )
        collected = labelled.collect()
        assert len(collected) == 1
        only_row = collected[0]
        assert only_row["EQ-One"] == "First"
    finally:
        session.stop()
def test_window_function_comparison_with_large_dataset(self):
    """Check that ``row_number().over(w) <= 10`` flags exactly ten of twenty rows."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        # Twenty people in one partition with strictly decreasing scores.
        records = [
            {"Name": f"Person{i}", "Type": "A", "Score": 100 - i} for i in range(20)
        ]
        scores = session.createDataFrame(records)
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        in_top_ten = F.row_number().over(by_score) <= 10
        flagged = scores.withColumn(
            "TopTen", F.when(in_top_ten, F.lit(True)).otherwise(F.lit(False))
        )
        collected = flagged.collect()
        assert len(collected) == 20
        top_ten = [row for row in collected if row["TopTen"] is True]
        assert len(top_ten) == 10
    finally:
        session.stop()
def test_window_function_comparison_with_multiple_window_functions(self):
    """Check that two window-comparison columns can be added to one DataFrame."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        with_first = scores.withColumn(
            "IsFirst",
            F.when(F.row_number().over(by_score) == 1, F.lit(True)).otherwise(
                F.lit(False)
            ),
        )
        with_both = with_first.withColumn(
            "TopRank",
            F.when(F.rank().over(by_score) == 1, F.lit(True)).otherwise(
                F.lit(False)
            ),
        )
        collected = with_both.collect()
        assert len(collected) == 3
        # Only the presence of both derived columns is verified.
        first_row = collected[0]
        assert "IsFirst" in first_row
        assert "TopRank" in first_row
    finally:
        session.stop()
def test_window_function_comparison_with_arithmetic_operations(self):
    """Check a comparison on a column derived from window arithmetic."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        # Step 1: materialise row_number + 1 as its own column.
        shifted = scores.withColumn(
            "RankPlusOne", (F.row_number().over(by_score) + 1)
        )
        # Step 2: compare that plain column, not the window expression.
        above_two = F.col("RankPlusOne") > 2
        flagged = shifted.withColumn(
            "GT-Two", F.when(above_two, F.lit(True)).otherwise(F.lit(False))
        )
        collected = flagged.collect()
        assert len(collected) == 3
        # Alice gets 1 + 1 = 2, which is not greater than 2; the others exceed it.
        alice = next(row for row in collected if row["Name"] == "Alice")
        assert alice["GT-Two"] is False
        lower_ranked = [row for row in collected if row["Name"] != "Alice"]
        for row in lower_ranked:
            assert row["GT-Two"] is True
    finally:
        session.stop()
def test_window_function_comparison_with_select(self):
    """Check that a window-comparison column survives a following ``select``."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        is_first = F.row_number().over(by_score) == 1
        flagged = scores.withColumn(
            "IsFirst", F.when(is_first, F.lit(True)).otherwise(F.lit(False))
        )
        projected = flagged.select("Name", "Type", "IsFirst")
        collected = projected.collect()
        assert len(collected) == 2
        assert "IsFirst" in collected[0]
    finally:
        session.stop()
def test_window_function_comparison_with_orderby(self):
    """Check that ``orderBy`` can follow a window-comparison column."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        is_first = F.row_number().over(by_score) == 1
        flagged = scores.withColumn(
            "IsFirst", F.when(is_first, F.lit(True)).otherwise(F.lit(False))
        )
        collected = flagged.orderBy("Name").collect()
        assert len(collected) == 3
        # Only membership of the first name in the input set is checked.
        known_names = ["Alice", "Bob", "Charlie"]
        assert collected[0]["Name"] in known_names
    finally:
        session.stop()
def test_window_function_comparison_with_groupby(self):
    """Check that ``groupBy``/``agg`` can follow a window-comparison column."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "B", "Score": 80},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        is_first = F.row_number().over(by_score) == 1
        flagged = scores.withColumn(
            "IsFirst", F.when(is_first, F.lit(True)).otherwise(F.lit(False))
        )
        per_type = flagged.groupBy("Type").agg(F.max("Score").alias("MaxScore"))
        collected = per_type.collect()
        # One aggregated row per distinct Type value.
        assert len(collected) == 2
        seen_types = {row["Type"] for row in collected}
        assert seen_types == {"A", "B"}
    finally:
        session.stop()
def test_window_function_comparison_with_join(self):
    """Check that a left join can follow a window-comparison column."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
            ]
        )
        departments = session.createDataFrame(
            [
                {"Type": "A", "Dept": "Engineering"},
                {"Type": "B", "Dept": "Sales"},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        is_first = F.row_number().over(by_score) == 1
        flagged = scores.withColumn(
            "IsFirst", F.when(is_first, F.lit(True)).otherwise(F.lit(False))
        )
        joined = flagged.join(departments, on="Type", how="left")
        collected = joined.collect()
        assert len(collected) == 2
        # Both left-side rows have Type "A", which maps to Engineering.
        assert collected[0]["Dept"] == "Engineering"
    finally:
        session.stop()
def test_window_function_comparison_with_union(self):
    """Check that two DataFrames with window-comparison columns can be unioned."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        alice_frame = session.createDataFrame(
            [{"Name": "Alice", "Type": "A", "Score": 100}]
        )
        bob_frame = session.createDataFrame(
            [{"Name": "Bob", "Type": "B", "Score": 90}]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())

        def with_first_flag(frame):
            # Add the same window-comparison column to either input frame.
            return frame.withColumn(
                "IsFirst",
                F.when(F.row_number().over(by_score) == 1, F.lit(True)).otherwise(
                    F.lit(False)
                ),
            )

        flagged_alice = with_first_flag(alice_frame)
        flagged_bob = with_first_flag(bob_frame)
        merged = flagged_alice.unionByName(flagged_bob, allowMissingColumns=True)
        collected = merged.collect()
        assert len(collected) == 2
        seen_names = {row["Name"] for row in collected}
        assert seen_names == {"Alice", "Bob"}
    finally:
        session.stop()
def test_window_function_comparison_with_distinct(self):
    """Check that ``distinct`` can follow a window-comparison column."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        # Two identical input rows.
        duplicates = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Alice", "Type": "A", "Score": 100},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        is_first = F.row_number().over(by_score) == 1
        flagged = duplicates.withColumn(
            "IsFirst", F.when(is_first, F.lit(True)).otherwise(F.lit(False))
        )
        # Dropping IsFirst before distinct leaves the two rows identical again.
        deduplicated = flagged.select("Name", "Type", "Score").distinct()
        collected = deduplicated.collect()
        assert len(collected) == 1
        assert collected[0]["Name"] == "Alice"
    finally:
        session.stop()
def test_window_function_comparison_with_limit(self):
    """Check that ``limit`` can follow a window-comparison column."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        is_first = F.row_number().over(by_score) == 1
        flagged = scores.withColumn(
            "IsFirst", F.when(is_first, F.lit(True)).otherwise(F.lit(False))
        )
        collected = flagged.limit(2).collect()
        # Three input rows are capped at two.
        assert len(collected) == 2
    finally:
        session.stop()
def test_window_function_comparison_chained_operations(self):
    """Check a window-comparison column used by a later filter and select.

    The window comparison is first materialised as the boolean column
    ``IsFirst``; the filter then references that plain column rather than
    the window expression itself. One top-scoring row per ``Type`` is
    expected to remain.
    """
    spark = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        df = spark.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "B", "Score": 80},
            ]
        )
        w = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        result = (
            df.withColumn(
                "IsFirst",
                F.when(F.row_number().over(w) == 1, F.lit(True)).otherwise(
                    F.lit(False)
                ),
            )
            .filter(F.col("IsFirst"))
            .select("Name", "Type", "Score")
        )
        rows = result.collect()
        # One surviving row per partition: Type "A" and Type "B".
        assert len(rows) == 2
        types = {row["Type"] for row in rows}
        assert types == {"A", "B"}
    finally:
        # Always release the session, even when an assertion fails.
        spark.stop()
def test_window_function_comparison_with_nested_select(self):
    """Check two successive ``select`` calls after a window-comparison column."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        is_first = F.row_number().over(by_score) == 1
        flagged = scores.withColumn(
            "IsFirst", F.when(is_first, F.lit(True)).otherwise(F.lit(False))
        )
        # The second projection drops the derived column again.
        name_and_flag = flagged.select("Name", "IsFirst")
        names_only = name_and_flag.select("Name")
        collected = names_only.collect()
        assert len(collected) == 2
        assert "Name" in collected[0]
    finally:
        session.stop()
def test_window_function_comparison_with_case_when_chain(self):
    """Check a three-branch ``when`` chain of window comparisons."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())

        def position():
            # Build a fresh window expression for each branch of the chain.
            return F.row_number().over(by_score)

        medal = (
            F.when(position() == 1, F.lit("Gold"))
            .when(position() == 2, F.lit("Silver"))
            .when(position() == 3, F.lit("Bronze"))
            .otherwise(F.lit("Other"))
        )
        collected = scores.withColumn("Category", medal).collect()
        assert len(collected) == 3
        # Only membership in the label set is verified, not the assignment.
        allowed = ["Gold", "Silver", "Bronze", "Other"]
        for row in collected:
            assert row["Category"] in allowed
    finally:
        session.stop()
def test_window_function_comparison_with_coalesce(self):
    """Check that ``coalesce`` fills the nulls left by a window-comparison column."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        # IsFirst is 1 for the top row and null for every other row.
        is_first = F.row_number().over(by_score) == 1
        flagged = scores.withColumn(
            "IsFirst", F.when(is_first, F.lit(1)).otherwise(F.lit(None))
        )
        # Replace those nulls with 0.
        filled = flagged.withColumn(
            "RankOrOne", F.coalesce(F.col("IsFirst"), F.lit(0))
        )
        collected = filled.collect()
        assert len(collected) == 2
        alice = next(row for row in collected if row["Name"] == "Alice")
        assert alice["RankOrOne"] == 1
        bob = next(row for row in collected if row["Name"] == "Bob")
        assert bob["RankOrOne"] == 0
    finally:
        session.stop()
def test_window_function_comparison_with_cast(self):
    """Check that a window-comparison string column can be cast to int."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        is_first = F.row_number().over(by_score) == 1
        as_text = scores.withColumn(
            "RankStr", F.when(is_first, F.lit("1")).otherwise(F.lit("0"))
        )
        as_number = as_text.withColumn("RankInt", F.col("RankStr").cast("int"))
        collected = as_number.collect()
        assert len(collected) == 2
        first_row = collected[0]
        assert "RankInt" in first_row
        # The cast value may be an int or None.
        cast_value = first_row["RankInt"]
        assert cast_value is None or isinstance(cast_value, int)
    finally:
        session.stop()
def test_window_function_comparison_with_avg(self):
    """Check that ``avg(...).over(w) > 85`` works as an ``F.when`` condition."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        above_threshold = F.avg("Score").over(by_score) > 85
        flagged = scores.withColumn(
            "AboveAvg",
            F.when(above_threshold, F.lit(True)).otherwise(F.lit(False)),
        )
        collected = flagged.collect()
        assert len(collected) == 3
        # The running averages are 100, 95 and 90, all above the threshold.
        for index in range(len(collected)):
            assert collected[index]["AboveAvg"] is True
    finally:
        session.stop()
def test_window_function_comparison_with_max(self):
    """Check a windowed ``max`` compared against another column."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        # Column-to-column comparison: windowed max against the row's own Score.
        matches_max = F.max("Score").over(by_score) == F.col("Score")
        flagged = scores.withColumn(
            "IsMax", F.when(matches_max, F.lit(True)).otherwise(F.lit(False))
        )
        collected = flagged.collect()
        assert len(collected) == 3
        alice = next(row for row in collected if row["Name"] == "Alice")
        assert alice["IsMax"] is True
    finally:
        session.stop()
def test_window_function_comparison_with_min(self):
    """Check a windowed ``min`` compared against another column."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        # Column-to-column comparison: windowed min against the row's own Score.
        matches_min = F.min("Score").over(by_score) == F.col("Score")
        flagged = scores.withColumn(
            "IsMin", F.when(matches_min, F.lit(True)).otherwise(F.lit(False))
        )
        collected = flagged.collect()
        assert len(collected) == 3
        charlie = next(row for row in collected if row["Name"] == "Charlie")
        assert charlie["IsMin"] is True
    finally:
        session.stop()
def test_window_function_comparison_with_count(self):
    """Check that ``count(...).over(w) > 1`` works as an ``F.when`` condition."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        more_than_one = F.count("Score").over(by_score) > 1
        flagged = scores.withColumn(
            "HasMultiple",
            F.when(more_than_one, F.lit(True)).otherwise(F.lit(False)),
        )
        collected = flagged.collect()
        assert len(collected) == 3
        # The running count is 1 at Alice and grows for each later row.
        alice = next(row for row in collected if row["Name"] == "Alice")
        later_rows = [row for row in collected if row["Name"] != "Alice"]
        assert alice["HasMultiple"] is False
        for later in later_rows:
            assert later["HasMultiple"] is True
    finally:
        session.stop()
def test_window_function_comparison_with_ntile(self):
    """Check that ``ntile(2).over(w) == 1`` flags half of four rows."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
                {"Name": "David", "Type": "A", "Score": 70},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        in_first_bucket = F.ntile(2).over(by_score) == 1
        flagged = scores.withColumn(
            "TopHalf",
            F.when(in_first_bucket, F.lit(True)).otherwise(F.lit(False)),
        )
        collected = flagged.collect()
        assert len(collected) == 4
        # Four rows split into two buckets put two rows in the first bucket.
        top_half = [row for row in collected if row["TopHalf"] is True]
        assert len(top_half) == 2
    finally:
        session.stop()
def test_window_function_comparison_with_cume_dist(self):
    """Check that ``cume_dist().over(w) > 0.5`` can be evaluated and collected."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        past_half = F.cume_dist().over(by_score) > 0.5
        flagged = scores.withColumn(
            "HighCumeDist",
            F.when(past_half, F.lit(True)).otherwise(F.lit(False)),
        )
        collected = flagged.collect()
        assert len(collected) == 3
        # Only the presence of the derived column is verified, not its values.
        for index in range(len(collected)):
            assert "HighCumeDist" in collected[index]
    finally:
        session.stop()
def test_window_function_comparison_with_first_value(self):
    """Check a windowed ``first_value`` compared against another column."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        matches_first = F.first_value("Score").over(by_score) == F.col("Score")
        flagged = scores.withColumn(
            "IsFirstValue",
            F.when(matches_first, F.lit(True)).otherwise(F.lit(False)),
        )
        collected = flagged.collect()
        assert len(collected) == 3
        # Alice has the top score, which is the first value of the ordered window.
        alice = next(row for row in collected if row["Name"] == "Alice")
        assert alice["IsFirstValue"] is True
    finally:
        session.stop()
def test_window_function_comparison_with_last_value(self):
    """Check a windowed ``last_value`` compared against another column."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        matches_last = F.last_value("Score").over(by_score) == F.col("Score")
        flagged = scores.withColumn(
            "IsLastValue",
            F.when(matches_last, F.lit(True)).otherwise(F.lit(False)),
        )
        collected = flagged.collect()
        assert len(collected) == 3
        # Every row is expected to match the last value of its own frame.
        # NOTE(review): this presumes the default frame ends at the current row.
        for index in range(len(collected)):
            assert collected[index]["IsLastValue"] is True
    finally:
        session.stop()
def test_window_function_comparison_with_countDistinct(self):
    """Expect an error when ``countDistinct`` is used over a window."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 100},
                {"Name": "Charlie", "Type": "A", "Score": 90},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        # Expression building and collection both run inside the
        # raises-context, so the error may surface at either point.
        with pytest.raises(Exception) as exc_info:
            several_distinct = F.countDistinct("Score").over(by_score) > 1
            flagged = scores.withColumn(
                "HasMultipleDistinct",
                F.when(several_distinct, F.lit(True)).otherwise(F.lit(False)),
            )
            flagged.collect()
        error_text = str(exc_info.value)
        # Either the descriptive message or the error-class name is accepted.
        accepted_markers = (
            "Distinct window functions are not supported",
            "DISTINCT_WINDOW_FUNCTION_UNSUPPORTED",
        )
        assert any(marker in error_text for marker in accepted_markers)
    finally:
        session.stop()
def test_window_function_comparison_with_string_values(self):
    """Check a window comparison when the window orders by a string column."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        categories = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Category": "High"},
                {"Name": "Bob", "Type": "A", "Category": "Medium"},
                {"Name": "Charlie", "Type": "A", "Category": "Low"},
            ]
        )
        by_category = (
            Window().partitionBy("Type").orderBy(F.col("Category").desc())
        )
        is_first = F.row_number().over(by_category) == 1
        flagged = categories.withColumn(
            "IsFirst", F.when(is_first, F.lit(True)).otherwise(F.lit(False))
        )
        collected = flagged.collect()
        assert len(collected) == 3
        # Exactly one row in the single partition can hold row number 1.
        firsts = [row for row in collected if row["IsFirst"] is True]
        assert len(firsts) == 1
        for index in range(len(collected)):
            assert "IsFirst" in collected[index]
    finally:
        session.stop()
def test_window_function_comparison_with_float_values(self):
    """Check a window comparison when the window orders by a float column."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100.5},
                {"Name": "Bob", "Type": "A", "Score": 90.3},
                {"Name": "Charlie", "Type": "A", "Score": 80.7},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        is_first = F.row_number().over(by_score) == 1
        flagged = scores.withColumn(
            "IsFirst", F.when(is_first, F.lit(True)).otherwise(F.lit(False))
        )
        collected = flagged.collect()
        assert len(collected) == 3
        # Alice has the highest score, so she takes row number 1.
        alice = next(row for row in collected if row["Name"] == "Alice")
        assert alice["IsFirst"] is True
    finally:
        session.stop()
def test_window_function_comparison_schema_verification(self):
    """Check that the schema keeps the input columns and gains the derived one."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        single = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A"},
            ]
        )
        by_type = Window().partitionBy("Type").orderBy("Type")
        is_positive = F.row_number().over(by_type) > 0
        flagged = single.withColumn(
            "GT-Zero", F.when(is_positive, F.lit(True)).otherwise(F.lit(False))
        )
        # Only the schema is inspected; the DataFrame is never collected.
        column_names = [field.name for field in flagged.schema.fields]
        for expected in ("Name", "Type", "GT-Zero"):
            assert expected in column_names
    finally:
        session.stop()
def test_window_function_comparison_with_complex_filter(self):
    """Check a compound filter on a materialised window column."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
                {"Name": "David", "Type": "B", "Score": 95},
            ]
        )
        by_score = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        # Store the row number as a plain column before filtering on it.
        ranked = scores.withColumn("Rank", F.row_number().over(by_score))
        top_of_type_a = (F.col("Rank") == 1) & (F.col("Type") == "A")
        selected = ranked.filter(top_of_type_a).select("Name", "Type", "Score")
        collected = selected.collect()
        assert len(collected) == 1
        assert collected[0]["Name"] == "Alice"
    finally:
        session.stop()
def test_window_function_comparison_with_multiple_partitions(self):
    """Check a window comparison across two partitions.

    The data holds two ``Type`` partitions of two rows each, ordered by
    descending ``Score``. ``row_number() == 1`` should therefore flag one
    row per partition, giving two ``True`` values in total.
    """
    spark = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        df = spark.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "B", "Score": 95},
                {"Name": "David", "Type": "B", "Score": 85},
            ]
        )
        w = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        result = df.withColumn(
            "IsFirst",
            F.when(F.row_number().over(w) == 1, F.lit(True)).otherwise(
                F.lit(False)
            ),
        )
        rows = result.collect()
        assert len(rows) == 4
        # One first-place row per partition.
        true_count = sum(1 for row in rows if row["IsFirst"] is True)
        assert true_count == 2
    finally:
        # Always release the session, even when an assertion fails.
        spark.stop()
def test_window_function_comparison_with_no_partition(self):
    """Check a window comparison on a window with ordering but no partitioning."""
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        scores = session.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Score": 100},
                {"Name": "Bob", "Type": "A", "Score": 90},
                {"Name": "Charlie", "Type": "A", "Score": 80},
            ]
        )
        # No partitionBy: the window spec only defines an ordering.
        global_order = Window().orderBy(F.col("Score").desc())
        is_first = F.row_number().over(global_order) == 1
        flagged = scores.withColumn(
            "IsFirst", F.when(is_first, F.lit(True)).otherwise(F.lit(False))
        )
        collected = flagged.collect()
        assert len(collected) == 3
        firsts = [row for row in collected if row["IsFirst"] is True]
        assert len(firsts) == 1
        # Only membership of the flagged name in the input set is checked.
        winner = next(row for row in collected if row["IsFirst"] is True)
        known_names = ["Alice", "Bob", "Charlie"]
        assert winner["Name"] in known_names
    finally:
        session.stop()
def test_window_function_comparison_with_rowsBetween(self):
    """Compare a running sum over a rowsBetween frame against a threshold.

    The frame spans unboundedPreceding..currentRow within each Type, ordered
    by descending Score. The test expects three rows back and Alice's
    HighRunningSum (sum > 150) to be False.
    """
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        records = [
            {"Name": "Alice", "Type": "A", "Score": 100},
            {"Name": "Bob", "Type": "A", "Score": 90},
            {"Name": "Charlie", "Type": "A", "Score": 80},
        ]
        frame = session.createDataFrame(records)
        by_type = Window().partitionBy("Type")
        ordered = by_type.orderBy(F.col("Score").desc())
        running = ordered.rowsBetween(Window.unboundedPreceding, Window.currentRow)
        exceeds_threshold = F.sum("Score").over(running) > 150
        flagged = frame.withColumn(
            "HighRunningSum",
            F.when(exceeds_threshold, F.lit(True)).otherwise(F.lit(False)),
        )
        collected = flagged.collect()
        assert len(collected) == 3
        alice = next(entry for entry in collected if entry["Name"] == "Alice")
        assert alice["HighRunningSum"] is False
    finally:
        session.stop()
def test_window_function_comparison_with_rangeBetween(self):
    """Compare a running sum over a rangeBetween frame against a threshold.

    Only the result shape is verified: three rows come back and each one
    exposes the HighRunningSum column. The flag values are not asserted.
    """
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        records = [
            {"Name": "Alice", "Type": "A", "Score": 100},
            {"Name": "Bob", "Type": "A", "Score": 90},
            {"Name": "Charlie", "Type": "A", "Score": 80},
        ]
        frame = session.createDataFrame(records)
        by_type = Window().partitionBy("Type")
        ordered = by_type.orderBy(F.col("Score").desc())
        ranged = ordered.rangeBetween(Window.unboundedPreceding, Window.currentRow)
        exceeds_threshold = F.sum("Score").over(ranged) > 150
        flagged = frame.withColumn(
            "HighRunningSum",
            F.when(exceeds_threshold, F.lit(True)).otherwise(F.lit(False)),
        )
        collected = flagged.collect()
        assert len(collected) == 3
        for entry in collected:
            assert "HighRunningSum" in entry
    finally:
        session.stop()
def test_window_function_comparison_with_negative_values(self):
    """Check row_number() == 1 when every ordering value is negative.

    Scores are -10, -20 and -30, ordered descending within Type "A". The
    test expects Alice (Score -10) to be flagged IsFirst=True.
    """
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        records = [
            {"Name": "Alice", "Type": "A", "Score": -10},
            {"Name": "Bob", "Type": "A", "Score": -20},
            {"Name": "Charlie", "Type": "A", "Score": -30},
        ]
        frame = session.createDataFrame(records)
        per_type = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        is_top = F.row_number().over(per_type) == 1
        flagged = frame.withColumn(
            "IsFirst",
            F.when(is_top, F.lit(True)).otherwise(F.lit(False)),
        )
        collected = flagged.collect()
        assert len(collected) == 3
        alice = next(entry for entry in collected if entry["Name"] == "Alice")
        assert alice["IsFirst"] is True
    finally:
        session.stop()
def test_window_function_comparison_with_zero_values(self):
    """Check row_number() == 1 when every ordering value is zero.

    All three rows share Score 0, so the test does not pin which row wins;
    it only expects exactly one row flagged IsFirst=True.
    """
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        records = [
            {"Name": "Alice", "Type": "A", "Score": 0},
            {"Name": "Bob", "Type": "A", "Score": 0},
            {"Name": "Charlie", "Type": "A", "Score": 0},
        ]
        frame = session.createDataFrame(records)
        per_type = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        is_top = F.row_number().over(per_type) == 1
        flagged = frame.withColumn(
            "IsFirst",
            F.when(is_top, F.lit(True)).otherwise(F.lit(False)),
        )
        collected = flagged.collect()
        assert len(collected) == 3
        winners = [entry for entry in collected if entry["IsFirst"] is True]
        assert len(winners) == 1
    finally:
        session.stop()
def test_window_function_comparison_with_duplicate_scores(self):
    """Check row_number() == 1 when all rows tie on the ordering column.

    All three rows share Score 100. The test expects exactly one row
    flagged IsFirst=True without asserting which one.
    """
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        records = [
            {"Name": "Alice", "Type": "A", "Score": 100},
            {"Name": "Bob", "Type": "A", "Score": 100},
            {"Name": "Charlie", "Type": "A", "Score": 100},
        ]
        frame = session.createDataFrame(records)
        per_type = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        is_top = F.row_number().over(per_type) == 1
        flagged = frame.withColumn(
            "IsFirst",
            F.when(is_top, F.lit(True)).otherwise(F.lit(False)),
        )
        collected = flagged.collect()
        assert len(collected) == 3
        winners = [entry for entry in collected if entry["IsFirst"] is True]
        assert len(winners) == 1
    finally:
        session.stop()
def test_window_function_comparison_with_all_null_partition(self):
    """Check row_number() == 1 when the ordering column is entirely None.

    An explicit schema is passed because both Score values are None. The
    test expects two rows back with exactly one flagged IsFirst=True.
    """
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        columns = [
            StructField("Name", StringType(), True),
            StructField("Type", StringType(), True),
            StructField("Score", IntegerType(), True),
        ]
        layout = StructType(columns)
        records = [
            {"Name": "Alice", "Type": "A", "Score": None},
            {"Name": "Bob", "Type": "A", "Score": None},
        ]
        frame = session.createDataFrame(records, schema=layout)
        per_type = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        is_top = F.row_number().over(per_type) == 1
        flagged = frame.withColumn(
            "IsFirst",
            F.when(is_top, F.lit(True)).otherwise(F.lit(False)),
        )
        collected = flagged.collect()
        assert len(collected) == 2
        winners = [entry for entry in collected if entry["IsFirst"] is True]
        assert len(winners) == 1
    finally:
        session.stop()
def test_window_function_comparison_with_mixed_types(self):
    """Check a window comparison whose when()/otherwise() yields strings.

    The input mixes integer and string columns, and the flag column holds
    "Yes"/"No" rather than booleans. The test expects Alice to be "Yes".
    """
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        records = [
            {"Name": "Alice", "Type": "A", "Score": 100, "Category": "High"},
            {"Name": "Bob", "Type": "A", "Score": 90, "Category": "Medium"},
        ]
        frame = session.createDataFrame(records)
        per_type = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        is_top = F.row_number().over(per_type) == 1
        labelled = frame.withColumn(
            "IsFirst",
            F.when(is_top, F.lit("Yes")).otherwise(F.lit("No")),
        )
        collected = labelled.collect()
        assert len(collected) == 2
        alice = next(entry for entry in collected if entry["Name"] == "Alice")
        assert alice["IsFirst"] == "Yes"
    finally:
        session.stop()
def test_window_function_comparison_with_desc_ordering(self):
    """Check row_number() == 1 under descending Score ordering.

    With Scores 100/90/80 ordered descending, the test expects Alice
    (Score 100) to be flagged IsFirst=True.
    """
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        records = [
            {"Name": "Alice", "Type": "A", "Score": 100},
            {"Name": "Bob", "Type": "A", "Score": 90},
            {"Name": "Charlie", "Type": "A", "Score": 80},
        ]
        frame = session.createDataFrame(records)
        descending = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        is_top = F.row_number().over(descending) == 1
        flagged = frame.withColumn(
            "IsFirst",
            F.when(is_top, F.lit(True)).otherwise(F.lit(False)),
        )
        collected = flagged.collect()
        assert len(collected) == 3
        alice = next(entry for entry in collected if entry["Name"] == "Alice")
        assert alice["IsFirst"] is True
    finally:
        session.stop()
def test_window_function_comparison_with_asc_ordering(self):
    """Check row_number() == 1 under ascending Score ordering.

    With Scores 100/90/80 ordered ascending, the test expects Charlie
    (Score 80) to be flagged IsFirst=True.
    """
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        records = [
            {"Name": "Alice", "Type": "A", "Score": 100},
            {"Name": "Bob", "Type": "A", "Score": 90},
            {"Name": "Charlie", "Type": "A", "Score": 80},
        ]
        frame = session.createDataFrame(records)
        ascending = Window().partitionBy("Type").orderBy(F.col("Score").asc())
        is_top = F.row_number().over(ascending) == 1
        flagged = frame.withColumn(
            "IsFirst",
            F.when(is_top, F.lit(True)).otherwise(F.lit(False)),
        )
        collected = flagged.collect()
        assert len(collected) == 3
        charlie = next(entry for entry in collected if entry["Name"] == "Charlie")
        assert charlie["IsFirst"] is True
    finally:
        session.stop()
def test_window_function_comparison_with_multiple_orderby_columns(self):
    """Check row_number() == 1 with a two-column ordering.

    Rows are ordered by Score descending, then Age ascending. The test
    expects exactly one flagged row, with Score 100 and a Name of either
    "Alice" or "Bob".
    """
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        records = [
            {"Name": "Alice", "Type": "A", "Score": 100, "Age": 25},
            {"Name": "Bob", "Type": "A", "Score": 100, "Age": 30},
            {"Name": "Charlie", "Type": "A", "Score": 90, "Age": 20},
        ]
        frame = session.createDataFrame(records)
        by_type = Window().partitionBy("Type")
        ordered = by_type.orderBy(F.col("Score").desc(), F.col("Age").asc())
        is_top = F.row_number().over(ordered) == 1
        flagged = frame.withColumn(
            "IsFirst",
            F.when(is_top, F.lit(True)).otherwise(F.lit(False)),
        )
        collected = flagged.collect()
        assert len(collected) == 3
        winners = [entry for entry in collected if entry["IsFirst"] is True]
        assert len(winners) == 1
        first_row = winners[0]
        assert first_row["Score"] == 100
        assert first_row["Name"] in ["Alice", "Bob"]
    finally:
        session.stop()
def test_window_function_comparison_with_chained_filters(self):
    """Expect an error when a window comparison is used directly in filter().

    The filter/filter/select/collect pipeline must raise, and the message
    must contain "window functions inside WHERE clause".
    """
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        records = [
            {"Name": "Alice", "Type": "A", "Score": 100},
            {"Name": "Bob", "Type": "A", "Score": 90},
            {"Name": "Charlie", "Type": "A", "Score": 80},
            {"Name": "David", "Type": "B", "Score": 95},
        ]
        frame = session.createDataFrame(records)
        per_type = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        # Keep the whole pipeline inside the context manager: the error may
        # surface when the filter is built or only once collect() runs.
        with pytest.raises(Exception) as raised:
            top_only = frame.filter(F.row_number().over(per_type) == 1)
            type_a_only = top_only.filter(F.col("Type") == "A")
            projected = type_a_only.select("Name", "Type", "Score")
            projected.collect()
        message = str(raised.value)
        assert "window functions inside WHERE clause" in message
    finally:
        session.stop()
def test_window_function_comparison_with_window_function_in_value(self):
    """Compare two materialised window-function columns against each other.

    Rank1 numbers rows by descending Score and Rank2 by ascending Score.
    RankMatch flags rows where both ranks agree; the test expects Bob
    (the middle Score of three) to have RankMatch=True.
    """
    session = SparkSession.builder.appName("issue-336").getOrCreate()
    try:
        records = [
            {"Name": "Alice", "Type": "A", "Score": 100},
            {"Name": "Bob", "Type": "A", "Score": 90},
            {"Name": "Charlie", "Type": "A", "Score": 80},
        ]
        frame = session.createDataFrame(records)
        descending = Window().partitionBy("Type").orderBy(F.col("Score").desc())
        ascending = Window().partitionBy("Type").orderBy(F.col("Score").asc())
        with_desc_rank = frame.withColumn("Rank1", F.row_number().over(descending))
        with_both_ranks = with_desc_rank.withColumn(
            "Rank2", F.row_number().over(ascending)
        )
        ranks_agree = F.col("Rank1") == F.col("Rank2")
        compared = with_both_ranks.withColumn(
            "RankMatch",
            F.when(ranks_agree, F.lit(True)).otherwise(F.lit(False)),
        )
        collected = compared.collect()
        assert len(collected) == 3
        bob = next(entry for entry in collected if entry["Name"] == "Bob")
        assert bob["RankMatch"] is True
    finally:
        session.stop()