from sparkless.testing import get_imports
# Resolve the Spark entry points through sparkless's testing helper instead of
# importing them directly; every test below uses these two names.
# NOTE(review): presumably get_imports() picks the active backend (sparkless
# vs. real PySpark) — confirm against sparkless.testing.
_imports = get_imports()
SparkSession = _imports.SparkSession
F = _imports.F
class TestIssue286AggregateFunctionArithmetic:
    """Tests for arithmetic applied to aggregate expressions (issue 286).

    Each test builds a small DataFrame with ``Name`` and ``Value`` columns,
    groups by ``Name``, applies an arithmetic operator to the result of an
    aggregate function (``count``, ``sum``, ``avg``, ...) inside ``agg`` and
    checks the collected values. Every test creates its own session and stops
    it in a ``finally`` clause so a failing assertion does not leak it.
    """

    @staticmethod
    def _row_for(rows, name):
        """Return the first row in *rows* whose ``Name`` equals *name*.

        Fails the calling test with an assertion error when no row matches.
        """
        found = next((r for r in rows if r["Name"] == name), None)
        assert found is not None, f"no result row for Name == {name!r}"
        return found

    def test_count_distinct_minus_one(self):
        """``countDistinct(col) - 1`` is evaluated per group and keeps its alias."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 2},
                    {"Name": "Bob", "Value": 3},
                ]
            )
            result = df.groupBy("Name").agg(
                (F.countDistinct("Value") - 1).alias("count_minus_one")
            )
            rows = result.collect()
            assert len(rows) == 2

            # Alice has 2 distinct values -> 2 - 1 = 1.
            assert self._row_for(rows, "Alice")["count_minus_one"] == 1
            # Bob has 1 distinct value -> 1 - 1 = 0.
            assert self._row_for(rows, "Bob")["count_minus_one"] == 0

            assert "count_minus_one" in result.columns
        finally:
            spark.stop()

    def test_count_plus_one(self):
        """``count(col) + 1`` is evaluated per group."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 2},
                    {"Name": "Bob", "Value": 3},
                ]
            )
            result = df.groupBy("Name").agg(
                (F.count("Value") + 1).alias("count_plus_one")
            )
            rows = result.collect()
            assert len(rows) == 2

            # Alice: 2 rows + 1 = 3; Bob: 1 row + 1 = 2.
            assert self._row_for(rows, "Alice")["count_plus_one"] == 3
            assert self._row_for(rows, "Bob")["count_plus_one"] == 2
        finally:
            spark.stop()

    def test_sum_multiply(self):
        """``sum(col) * 2`` is evaluated per group."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 2},
                    {"Name": "Bob", "Value": 3},
                ]
            )
            result = df.groupBy("Name").agg((F.sum("Value") * 2).alias("sum_times_two"))
            rows = result.collect()
            assert len(rows) == 2

            # Alice: (1 + 2) * 2 = 6; Bob: 3 * 2 = 6.
            assert self._row_for(rows, "Alice")["sum_times_two"] == 6
            assert self._row_for(rows, "Bob")["sum_times_two"] == 6
        finally:
            spark.stop()

    def test_avg_divide(self):
        """``avg(col) / 2`` is evaluated per group."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Alice", "Value": 20},
                    {"Name": "Bob", "Value": 30},
                ]
            )
            result = df.groupBy("Name").agg(
                (F.avg("Value") / 2).alias("avg_divided_by_two")
            )
            rows = result.collect()
            assert len(rows) == 2

            # Alice: avg(10, 20) / 2 = 7.5; Bob: 30 / 2 = 15.0.
            assert self._row_for(rows, "Alice")["avg_divided_by_two"] == 7.5
            assert self._row_for(rows, "Bob")["avg_divided_by_two"] == 15.0
        finally:
            spark.stop()

    def test_max_modulo(self):
        """``max(col) % 3`` is evaluated per group."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 7},
                    {"Name": "Alice", "Value": 8},
                    {"Name": "Bob", "Value": 9},
                ]
            )
            result = df.groupBy("Name").agg((F.max("Value") % 3).alias("max_mod_three"))
            rows = result.collect()
            assert len(rows) == 2

            # Alice: max(7, 8) % 3 = 2; Bob: 9 % 3 = 0.
            assert self._row_for(rows, "Alice")["max_mod_three"] == 2
            assert self._row_for(rows, "Bob")["max_mod_three"] == 0
        finally:
            spark.stop()

    def test_reverse_operations(self):
        """A literal on the left-hand side (``10 - countDistinct(col))`` works."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 2},
                    {"Name": "Bob", "Value": 3},
                ]
            )
            result = df.groupBy("Name").agg(
                (10 - F.countDistinct("Value")).alias("ten_minus_count")
            )
            rows = result.collect()
            assert len(rows) == 2

            # Alice: 10 - 2 = 8; Bob: 10 - 1 = 9.
            assert self._row_for(rows, "Alice")["ten_minus_count"] == 8
            assert self._row_for(rows, "Bob")["ten_minus_count"] == 9
        finally:
            spark.stop()

    def test_chained_arithmetic(self):
        """Two chained operators on one aggregate are both applied."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 2},
                    {"Name": "Bob", "Value": 3},
                ]
            )
            result = df.groupBy("Name").agg(
                ((F.countDistinct("Value") - 1) * 2).alias("count_minus_one_times_two")
            )
            rows = result.collect()
            assert len(rows) == 2

            # Alice: (2 - 1) * 2 = 2; Bob: (1 - 1) * 2 = 0.
            assert self._row_for(rows, "Alice")["count_minus_one_times_two"] == 2
            assert self._row_for(rows, "Bob")["count_minus_one_times_two"] == 0
        finally:
            spark.stop()

    def test_multiple_aggregate_arithmetic(self):
        """Several arithmetic aggregate expressions can share one ``agg`` call."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 2},
                    {"Name": "Bob", "Value": 3},
                ]
            )
            result = df.groupBy("Name").agg(
                (F.countDistinct("Value") - 1).alias("count_minus_one"),
                (F.count("Value") + 1).alias("count_plus_one"),
                (F.sum("Value") * 2).alias("sum_times_two"),
            )
            rows = result.collect()
            assert len(rows) == 2

            assert "count_minus_one" in result.columns
            assert "count_plus_one" in result.columns
            assert "sum_times_two" in result.columns

            alice_row = self._row_for(rows, "Alice")
            assert alice_row["count_minus_one"] == 1  # 2 - 1
            assert alice_row["count_plus_one"] == 3  # 2 + 1
            assert alice_row["sum_times_two"] == 6  # (1 + 2) * 2
        finally:
            spark.stop()

    def test_arithmetic_with_nulls(self):
        """``count(col) + 1`` when a group's values are all ``None``."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": None},
                    {"Name": "Alice", "Value": None},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = df.groupBy("Name").agg(
                (F.count("Value") + 1).alias("count_plus_one")
            )
            rows = result.collect()
            assert len(rows) == 2

            # Alice's values are all None, so the expected count is 0 -> 0 + 1 = 1.
            assert self._row_for(rows, "Alice")["count_plus_one"] == 1
            # Bob has one non-null value -> 1 + 1 = 2.
            assert self._row_for(rows, "Bob")["count_plus_one"] == 2
        finally:
            spark.stop()

    def test_arithmetic_with_floats(self):
        """``sum(col) * 1.5`` over float values, compared with a tolerance."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1.5},
                    {"Name": "Alice", "Value": 2.5},
                    {"Name": "Bob", "Value": 3.7},
                ]
            )
            result = df.groupBy("Name").agg(
                (F.sum("Value") * 1.5).alias("sum_times_one_five")
            )
            rows = result.collect()
            assert len(rows) == 2

            # Alice: (1.5 + 2.5) * 1.5 = 6.0; Bob: 3.7 * 1.5 = 5.55.
            alice_row = self._row_for(rows, "Alice")
            assert abs(alice_row["sum_times_one_five"] - 6.0) < 0.001
            bob_row = self._row_for(rows, "Bob")
            assert abs(bob_row["sum_times_one_five"] - 5.55) < 0.001
        finally:
            spark.stop()

    def test_arithmetic_with_negative_numbers(self):
        """``sum(col) + 10`` when the group sum is negative."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": -5},
                    {"Name": "Alice", "Value": -3},
                    {"Name": "Bob", "Value": 10},
                ]
            )
            result = df.groupBy("Name").agg((F.sum("Value") + 10).alias("sum_plus_ten"))
            rows = result.collect()
            assert len(rows) == 2

            # Alice: (-5 + -3) + 10 = 2; Bob: 10 + 10 = 20.
            assert self._row_for(rows, "Alice")["sum_plus_ten"] == 2
            assert self._row_for(rows, "Bob")["sum_plus_ten"] == 20
        finally:
            spark.stop()

    def test_arithmetic_with_zero(self):
        """``sum(col) * 2`` when the group sum is zero."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 0},
                    {"Name": "Alice", "Value": 0},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = df.groupBy("Name").agg((F.sum("Value") * 2).alias("sum_times_two"))
            rows = result.collect()
            assert len(rows) == 2

            # Alice: (0 + 0) * 2 = 0; Bob: 5 * 2 = 10.
            assert self._row_for(rows, "Alice")["sum_times_two"] == 0
            assert self._row_for(rows, "Bob")["sum_times_two"] == 10
        finally:
            spark.stop()

    def test_division_by_zero_handling(self):
        """``sum(col) / 0`` is expected to yield ``None`` rather than raise."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Alice", "Value": 20},
                ]
            )
            result = df.groupBy("Name").agg(
                (F.sum("Value") / 0).alias("sum_divided_by_zero")
            )
            rows = result.collect()
            assert len(rows) == 1

            alice_row = rows[0]
            assert alice_row["sum_divided_by_zero"] is None
        finally:
            spark.stop()

    def test_modulo_by_zero_handling(self):
        """``sum(col) % 0`` is expected to yield ``None`` rather than raise."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Alice", "Value": 20},
                ]
            )
            result = df.groupBy("Name").agg((F.sum("Value") % 0).alias("sum_mod_zero"))
            rows = result.collect()
            assert len(rows) == 1

            alice_row = rows[0]
            assert alice_row["sum_mod_zero"] is None
        finally:
            spark.stop()

    def test_min_function_arithmetic(self):
        """``min(col) + 5`` is evaluated per group."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 5},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 3},
                ]
            )
            result = df.groupBy("Name").agg((F.min("Value") + 5).alias("min_plus_five"))
            rows = result.collect()
            assert len(rows) == 2

            # Alice: min(5, 10) + 5 = 10; Bob: 3 + 5 = 8.
            assert self._row_for(rows, "Alice")["min_plus_five"] == 10
            assert self._row_for(rows, "Bob")["min_plus_five"] == 8
        finally:
            spark.stop()

    def test_stddev_arithmetic(self):
        """``stddev(col) * 2`` produces a non-null numeric value for Alice."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Alice", "Value": 20},
                    {"Name": "Alice", "Value": 30},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = df.groupBy("Name").agg(
                (F.stddev("Value") * 2).alias("stddev_times_two")
            )
            rows = result.collect()
            assert len(rows) == 2

            # Only presence and type are checked here, not the exact value.
            alice_row = self._row_for(rows, "Alice")
            assert alice_row["stddev_times_two"] is not None
            assert isinstance(alice_row["stddev_times_two"], (int, float))
        finally:
            spark.stop()

    def test_variance_arithmetic(self):
        """``variance(col) + 1`` produces a non-null numeric value for Alice."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Alice", "Value": 20},
                    {"Name": "Alice", "Value": 30},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = df.groupBy("Name").agg(
                (F.variance("Value") + 1).alias("variance_plus_one")
            )
            rows = result.collect()
            assert len(rows) == 2

            # Only presence and type are checked here, not the exact value.
            alice_row = self._row_for(rows, "Alice")
            assert alice_row["variance_plus_one"] is not None
            assert isinstance(alice_row["variance_plus_one"], (int, float))
        finally:
            spark.stop()

    def test_complex_nested_operations(self):
        """Three nested operators on one aggregate are all applied in order."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Alice", "Value": 20},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = df.groupBy("Name").agg(
                (((F.sum("Value") + 1) * 2) - 3).alias("complex_expr")
            )
            rows = result.collect()
            assert len(rows) == 2

            # Alice: ((30 + 1) * 2) - 3 = 59; Bob: ((5 + 1) * 2) - 3 = 9.
            assert self._row_for(rows, "Alice")["complex_expr"] == 59
            assert self._row_for(rows, "Bob")["complex_expr"] == 9
        finally:
            spark.stop()

    def test_all_arithmetic_operators(self):
        """``+``, ``-``, ``*``, ``/`` and ``%`` with the aggregate on the left."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Alice", "Value": 20},
                ]
            )
            result = df.groupBy("Name").agg(
                (F.sum("Value") + 5).alias("add"),
                (F.sum("Value") - 5).alias("sub"),
                (F.sum("Value") * 2).alias("mul"),
                (F.sum("Value") / 2).alias("div"),
                (F.sum("Value") % 7).alias("mod"),
            )
            rows = result.collect()
            assert len(rows) == 1

            # The group sum is 10 + 20 = 30.
            alice_row = rows[0]
            assert alice_row["add"] == 35  # 30 + 5
            assert alice_row["sub"] == 25  # 30 - 5
            assert alice_row["mul"] == 60  # 30 * 2
            assert alice_row["div"] == 15.0  # 30 / 2
            assert alice_row["mod"] == 2  # 30 % 7
        finally:
            spark.stop()

    def test_reverse_all_operators(self):
        """``+``, ``-``, ``*``, ``/`` and ``%`` with the aggregate on the right."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 5},
                    {"Name": "Alice", "Value": 10},
                ]
            )
            result = df.groupBy("Name").agg(
                (10 + F.sum("Value")).alias("radd"),
                (100 - F.sum("Value")).alias("rsub"),
                (2 * F.sum("Value")).alias("rmul"),
                (60 / F.sum("Value")).alias("rtruediv"),
                (30 % F.sum("Value")).alias("rmod"),
            )
            rows = result.collect()
            assert len(rows) == 1

            # The group sum is 5 + 10 = 15.
            alice_row = rows[0]
            assert alice_row["radd"] == 25  # 10 + 15
            assert alice_row["rsub"] == 85  # 100 - 15
            assert alice_row["rmul"] == 30  # 2 * 15
            assert abs(alice_row["rtruediv"] - 4.0) < 0.001  # 60 / 15
            assert alice_row["rmod"] == 0  # 30 % 15
        finally:
            spark.stop()

    def test_count_star_arithmetic(self):
        """``count("*") - 1`` is evaluated per group."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 2},
                    {"Name": "Bob", "Value": 3},
                ]
            )
            result = df.groupBy("Name").agg(
                (F.count("*") - 1).alias("count_star_minus_one")
            )
            rows = result.collect()
            assert len(rows) == 2

            # Alice: 2 rows - 1 = 1; Bob: 1 row - 1 = 0.
            assert self._row_for(rows, "Alice")["count_star_minus_one"] == 1
            assert self._row_for(rows, "Bob")["count_star_minus_one"] == 0
        finally:
            spark.stop()

    def test_empty_group_handling(self):
        """Aggregating after a filter that matches nothing returns no rows."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 20},
                ]
            )
            # No input row is named "Charlie", so nothing reaches the groupBy.
            filtered_df = df.filter(F.col("Name") == "Charlie")
            result = filtered_df.groupBy("Name").agg(
                (F.sum("Value") + 1).alias("sum_plus_one")
            )
            rows = result.collect()
            assert len(rows) == 0
        finally:
            spark.stop()

    def test_large_numbers(self):
        """``sum(col) * 2`` over values in the millions."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1000000},
                    {"Name": "Alice", "Value": 2000000},
                ]
            )
            result = df.groupBy("Name").agg((F.sum("Value") * 2).alias("sum_times_two"))
            rows = result.collect()
            assert len(rows) == 1

            alice_row = rows[0]
            assert alice_row["sum_times_two"] == 6000000  # 3000000 * 2
        finally:
            spark.stop()

    def test_mixed_aggregate_functions(self):
        """Five different aggregates, each with a different operator, in one ``agg``."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Alice", "Value": 20},
                    {"Name": "Alice", "Value": 30},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = df.groupBy("Name").agg(
                (F.count("Value") + 1).alias("count_plus_one"),
                (F.sum("Value") - 5).alias("sum_minus_five"),
                (F.avg("Value") * 2).alias("avg_times_two"),
                (F.max("Value") / 2).alias("max_divided_by_two"),
                (F.min("Value") % 3).alias("min_mod_three"),
            )
            rows = result.collect()
            assert len(rows) == 2

            alice_row = self._row_for(rows, "Alice")
            assert alice_row["count_plus_one"] == 4  # 3 + 1
            assert alice_row["sum_minus_five"] == 55  # 60 - 5
            assert alice_row["avg_times_two"] == 40.0  # 20 * 2
            assert alice_row["max_divided_by_two"] == 15.0  # 30 / 2
            assert alice_row["min_mod_three"] == 1  # 10 % 3
        finally:
            spark.stop()

    def test_arithmetic_with_alias(self):
        """The alias given to an arithmetic aggregate becomes the column name."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Alice", "Value": 20},
                ]
            )
            result = df.groupBy("Name").agg((F.sum("Value") - 1).alias("custom_alias"))
            rows = result.collect()
            assert len(rows) == 1

            alice_row = rows[0]
            assert "custom_alias" in result.columns
            assert alice_row["custom_alias"] == 29  # 30 - 1
        finally:
            spark.stop()

    def test_arithmetic_precedence(self):
        """``sum(col) + 1 * 2`` adds 2 to the sum, not ``(sum + 1) * 2``."""
        spark = SparkSession.builder.appName("issue-286").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Alice", "Value": 20},
                ]
            )
            # Python evaluates ``1 * 2`` first, so the expression is ``sum + 2``.
            result = df.groupBy("Name").agg(
                (F.sum("Value") + 1 * 2).alias("precedence_test")
            )
            rows = result.collect()
            assert len(rows) == 1

            alice_row = rows[0]
            assert alice_row["precedence_test"] == 32  # 30 + 2
        finally:
            spark.stop()