from sparkless.testing import get_imports
def _norm(val):
if val is None:
return None
if isinstance(val, (int, float)):
return float(val) if isinstance(val, float) else int(val)
return val
class TestIssue393SumStringColumn:
    """Window-aggregate tests for ``F.sum`` / ``F.avg`` on string-valued columns.

    Each test builds a small DataFrame whose ``Value`` column holds numeric
    strings, applies an aggregate over a window, collects the rows and checks
    the aggregated values through ``_norm``.

    The ``spark`` argument is supplied by the test harness (presumably a
    pytest fixture defined outside this file). ``F`` and ``Window`` are
    resolved at run time through ``get_imports()``.

    Every test asserts the number of collected rows before inspecting them,
    so that an empty result cannot make a ``for`` loop of assertions pass
    vacuously.
    """

    def test_sum_string_column_partition_by_order_by(self, spark):
        """Sum over ``partitionBy("Type").orderBy("Type")`` using ``df.Value``."""
        imports = get_imports()
        F, Window = imports.F, imports.Window
        df = spark.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Value": "10"},
                {"Name": "Bob", "Type": "A", "Value": "20"},
            ]
        )
        # Window is instantiated (``Window()``) here, unlike the class-level
        # ``Window.partitionBy`` form used in the running-sum test below.
        w = Window().partitionBy("Type").orderBy("Type")
        df = df.withColumn("SumValue", F.sum(df.Value).over(w))
        rows = df.collect()
        assert len(rows) == 2
        # Both rows share the ordering key, and each is expected to carry
        # the partition total.
        for r in rows:
            assert _norm(r["SumValue"]) == 30.0

    def test_avg_string_column(self, spark):
        """Average of "10", "20", "30" within one partition is 20.0 on every row."""
        imports = get_imports()
        F, Window = imports.F, imports.Window
        df = spark.createDataFrame(
            [
                {"Type": "A", "Value": "10"},
                {"Type": "A", "Value": "20"},
                {"Type": "A", "Value": "30"},
            ]
        )
        w = Window().partitionBy("Type").orderBy("Type")
        df = df.withColumn("AvgValue", F.avg("Value").over(w))
        rows = df.collect()
        # Guard against an empty result silently skipping the loop below.
        assert len(rows) == 3
        for r in rows:
            assert _norm(r["AvgValue"]) == 20.0

    def test_sum_string_column_running_sum(self, spark):
        """Ordering by ``Name`` yields a running sum: Alice 10.0, Bob 30.0."""
        imports = get_imports()
        F, Window = imports.F, imports.Window
        df = spark.createDataFrame(
            [
                {"Type": "A", "Name": "Alice", "Value": "10"},
                {"Type": "A", "Name": "Bob", "Value": "20"},
            ]
        )
        w = Window.partitionBy("Type").orderBy("Name")
        df = df.withColumn("SumValue", F.sum("Value").over(w))
        rows = df.collect()
        assert len(rows) == 2
        # Look rows up by name so the check does not depend on collect() order.
        alice = next(r for r in rows if r["Name"] == "Alice")
        bob = next(r for r in rows if r["Name"] == "Bob")
        assert _norm(alice["SumValue"]) == 10.0
        assert _norm(bob["SumValue"]) == 30.0

    def test_sum_string_column_with_show(self, spark):
        """Calling ``show()`` before ``collect()`` must not change the result."""
        imports = get_imports()
        F, Window = imports.F, imports.Window
        df = spark.createDataFrame(
            [
                {"Name": "Alice", "Type": "A", "Value": "10"},
                {"Name": "Bob", "Type": "A", "Value": "20"},
            ]
        )
        w = Window().partitionBy("Type").orderBy("Type")
        df = df.withColumn("SumValue", F.sum(df.Value).over(w))
        # Exercise the display path as well; only the collected rows are checked.
        df.show()
        rows = df.collect()
        assert len(rows) == 2
        for r in rows:
            assert _norm(r["SumValue"]) == 30.0

    def test_sum_string_column_with_nulls(self, spark):
        """A ``None`` value does not contribute to the sum (10 + 20 == 30.0)."""
        imports = get_imports()
        F, Window = imports.F, imports.Window
        df = spark.createDataFrame(
            [
                {"Type": "A", "Value": "10"},
                {"Type": "A", "Value": None},
                {"Type": "A", "Value": "20"},
            ]
        )
        w = Window().partitionBy("Type").orderBy("Type")
        df = df.withColumn("SumValue", F.sum("Value").over(w))
        rows = df.collect()
        # The row holding None must still be present in the output.
        assert len(rows) == 3
        for r in rows:
            assert _norm(r["SumValue"]) == 30.0

    def test_sum_string_column_no_partition_running_sum(self, spark):
        """Running sum over a window that only orders by ``Value``."""
        imports = get_imports()
        F, Window = imports.F, imports.Window
        df = spark.createDataFrame(
            [
                {"Value": "5"},
                {"Value": "15"},
                {"Value": "10"},
            ]
        )
        w = Window.orderBy("Value")
        df = df.withColumn("SumValue", F.sum("Value").over(w))
        rows = df.collect()
        sums = sorted([_norm(r["SumValue"]) for r in rows])
        # The expected running totals match string (lexicographic) ordering
        # of the column: "10" < "15" < "5" gives 10, 10+15=25, 25+5=30.
        # Comparing against a three-element list also pins the row count.
        assert sums == [10.0, 25.0, 30.0]

    def test_avg_string_column_multiple_partitions(self, spark):
        """Averages are computed per partition: A -> 15.0, B -> 40.0."""
        imports = get_imports()
        F, Window = imports.F, imports.Window
        df = spark.createDataFrame(
            [
                {"Type": "A", "Value": "10"},
                {"Type": "A", "Value": "20"},
                {"Type": "B", "Value": "30"},
                {"Type": "B", "Value": "50"},
            ]
        )
        w = Window().partitionBy("Type").orderBy("Type")
        df = df.withColumn("AvgValue", F.avg("Value").over(w))
        rows = df.collect()
        a_rows = [r for r in rows if r["Type"] == "A"]
        b_rows = [r for r in rows if r["Type"] == "B"]
        # Ensure both partitions are actually present before looping.
        assert len(a_rows) == 2
        assert len(b_rows) == 2
        for r in a_rows:
            assert _norm(r["AvgValue"]) == 15.0
        for r in b_rows:
            assert _norm(r["AvgValue"]) == 40.0

    def test_sum_string_column_decimal_like(self, spark):
        """Strings with a fractional part are summed too: "1.5" + "2.5" == 4.0."""
        imports = get_imports()
        F, Window = imports.F, imports.Window
        df = spark.createDataFrame(
            [
                {"Type": "A", "Value": "1.5"},
                {"Type": "A", "Value": "2.5"},
            ]
        )
        w = Window().partitionBy("Type").orderBy("Type")
        df = df.withColumn("SumValue", F.sum("Value").over(w))
        rows = df.collect()
        assert len(rows) == 2
        for r in rows:
            assert _norm(r["SumValue"]) == 4.0

    def test_sum_string_column_single_row_partition(self, spark):
        """A partition with a single row sums to that row's own value."""
        imports = get_imports()
        F, Window = imports.F, imports.Window
        df = spark.createDataFrame([{"Type": "A", "Value": "42"}])
        w = Window().partitionBy("Type").orderBy("Type")
        df = df.withColumn("SumValue", F.sum("Value").over(w))
        rows = df.collect()
        assert len(rows) == 1
        assert _norm(rows[0]["SumValue"]) == 42.0

    def test_sum_string_column_select_after(self, spark):
        """The window column survives a following ``select`` of named columns."""
        imports = get_imports()
        F, Window = imports.F, imports.Window
        df = spark.createDataFrame(
            [
                {"Type": "A", "Value": "10"},
                {"Type": "A", "Value": "20"},
                {"Type": "B", "Value": "5"},
            ]
        )
        w = Window().partitionBy("Type").orderBy("Type")
        df = df.withColumn("SumValue", F.sum("Value").over(w)).select(
            "Type", "Value", "SumValue"
        )
        rows = df.collect()
        assert len(rows) == 3
        a_rows = [r for r in rows if r["Type"] == "A"]
        # Check the count first so a missing row fails as an assertion
        # rather than as an IndexError on the positional lookups below.
        assert len(a_rows) == 2
        assert _norm(a_rows[0]["SumValue"]) == 30.0
        assert _norm(a_rows[1]["SumValue"]) == 30.0
        b_rows = [r for r in rows if r["Type"] == "B"]
        assert len(b_rows) == 1
        assert _norm(b_rows[0]["SumValue"]) == 5.0