from sparkless.testing import get_imports
# Objects handed back by sparkless.testing.get_imports(); every test in this
# module refers to them through these module-level aliases. IntegerType and
# LongType are bound for parity with the other type aliases even though the
# tests below do not currently use them.
_imports = get_imports()
(
    SparkSession,
    F,
    StructType,
    StructField,
    StringType,
    IntegerType,
    LongType,
    DoubleType,
    Window,
) = (
    _imports.SparkSession,
    _imports.F,
    _imports.StructType,
    _imports.StructField,
    _imports.StringType,
    _imports.IntegerType,
    _imports.LongType,
    _imports.DoubleType,
    _imports.Window,
)
class TestIssue337GroupedDataMean:
    """Regression tests for issue 337: ``GroupedData.mean``.

    Each test obtains a session named ``"issue-337"`` via ``getOrCreate`` and
    stops it in a ``finally`` block. Aggregated columns are looked up by the
    ``avg(<column>)`` name that ``mean`` produces.
    """

    def test_grouped_data_mean_single_column(self):
        """mean() over one column returns a per-group average."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = df.groupBy("Name").mean("Value")
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            bob_row = next(row for row in rows if row["Name"] == "Bob")
            assert alice_row["avg(Value)"] == 5.5
            assert bob_row["avg(Value)"] == 5.0
        finally:
            spark.stop()

    def test_grouped_data_mean_multiple_columns(self):
        """mean() over several columns yields one avg(<col>) per column."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value1": 1, "Value2": 2},
                    {"Name": "Alice", "Value1": 10, "Value2": 20},
                    {"Name": "Bob", "Value1": 5, "Value2": 6},
                ]
            )
            result = df.groupBy("Name").mean("Value1", "Value2")
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            bob_row = next(row for row in rows if row["Name"] == "Bob")
            assert alice_row["avg(Value1)"] == 5.5
            assert alice_row["avg(Value2)"] == 11.0
            assert bob_row["avg(Value1)"] == 5.0
            assert bob_row["avg(Value2)"] == 6.0
        finally:
            spark.stop()

    def test_grouped_data_mean_no_columns(self):
        """mean() called with no column arguments.

        A backend that rejects the no-argument form is tolerated. When the
        call succeeds, its result is checked. The assertions sit outside the
        ``try`` so an ``AssertionError`` is reported instead of swallowed.
        """
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            try:
                result = df.groupBy("Name").mean()
                rows = result.collect()
            except Exception:
                # Best-effort: the no-argument form may be unsupported.
                return
            assert len(rows) == 2
            assert "avg(1)" in rows[0] or "avg(Value)" in rows[0]
        finally:
            spark.stop()

    def test_grouped_data_mean_with_column_object(self):
        """A Column argument either raises a not-iterable error or works."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            try:
                df.groupBy("Name").mean(F.col("Value")).collect()
            except Exception as exc:
                # If the Column is rejected, the message must say why.
                msg = str(exc)
                assert "NOT_ITERABLE" in msg or "Column is not iterable" in msg
            else:
                rows = df.groupBy("Name").mean("Value").collect()
                assert len(rows) == 2
                alice_row = next(row for row in rows if row["Name"] == "Alice")
                assert alice_row["avg(Value)"] == 5.5
        finally:
            spark.stop()

    def test_grouped_data_mean_with_null_values(self):
        """Null values do not contribute to the group average."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": None},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = df.groupBy("Name").mean("Value")
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["avg(Value)"] == 5.5
        finally:
            spark.stop()

    def test_grouped_data_mean_equals_avg(self):
        """mean() and avg() produce the same value for every group."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result_mean = df.groupBy("Name").mean("Value")
            result_avg = df.groupBy("Name").avg("Value")
            rows_mean = result_mean.collect()
            rows_avg = result_avg.collect()
            assert len(rows_mean) == len(rows_avg) == 2
            for mean_row in rows_mean:
                avg_row = next(
                    row for row in rows_avg if row["Name"] == mean_row["Name"]
                )
                assert mean_row["avg(Value)"] == avg_row["avg(Value)"]
        finally:
            spark.stop()

    def test_grouped_data_mean_with_float_values(self):
        """Floating-point inputs are averaged per group."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1.5},
                    {"Name": "Alice", "Value": 10.5},
                    {"Name": "Bob", "Value": 5.25},
                ]
            )
            result = df.groupBy("Name").mean("Value")
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["avg(Value)"] == 6.0
        finally:
            spark.stop()

    def test_grouped_data_mean_with_negative_values(self):
        """Negative values are included in the average."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": -10},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": -5},
                ]
            )
            result = df.groupBy("Name").mean("Value")
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["avg(Value)"] == 0.0
        finally:
            spark.stop()

    def test_grouped_data_mean_with_zero_values(self):
        """A group whose values are all zero averages to 0.0."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 0},
                    {"Name": "Alice", "Value": 0},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = df.groupBy("Name").mean("Value")
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["avg(Value)"] == 0.0
        finally:
            spark.stop()

    def test_grouped_data_mean_with_single_row_per_group(self):
        """A one-row group averages to that row's value."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = df.groupBy("Name").mean("Value")
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["avg(Value)"] == 1.0
        finally:
            spark.stop()

    def test_grouped_data_mean_with_empty_dataframe(self):
        """Grouping an empty DataFrame yields no rows."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame([], schema="Name string, Value int")
            result = df.groupBy("Name").mean("Value")
            rows = result.collect()
            assert len(rows) == 0
        finally:
            spark.stop()

    def test_grouped_data_mean_with_chained_operations(self):
        """mean() output can be sorted by the grouping column."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = df.groupBy("Name").mean("Value").orderBy("Name")
            rows = result.collect()
            assert len(rows) == 2
            # Check the actual order; membership alone would pass unsorted.
            assert [row["Name"] for row in rows] == ["Alice", "Bob"]
        finally:
            spark.stop()

    def test_grouped_data_mean_with_select(self):
        """The avg(Value) column can be selected by name."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = df.groupBy("Name").mean("Value").select("Name", "avg(Value)")
            rows = result.collect()
            assert len(rows) == 2
            assert "Name" in rows[0]
            assert "avg(Value)" in rows[0]
        finally:
            spark.stop()

    def test_grouped_data_mean_with_filter(self):
        """The avg(Value) column can be used in a filter predicate."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = df.groupBy("Name").mean("Value").filter(F.col("avg(Value)") > 5.0)
            rows = result.collect()
            assert len(rows) == 1
            assert rows[0]["Name"] == "Alice"
        finally:
            spark.stop()

    def test_grouped_data_mean_with_multiple_group_columns(self):
        """Grouping by two columns averages each distinct combination."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Type": "A", "Value": 1},
                    {"Name": "Alice", "Type": "A", "Value": 10},
                    {"Name": "Alice", "Type": "B", "Value": 5},
                    {"Name": "Bob", "Type": "A", "Value": 3},
                ]
            )
            result = df.groupBy("Name", "Type").mean("Value")
            rows = result.collect()
            assert len(rows) == 3
            alice_a = next(
                row for row in rows if row["Name"] == "Alice" and row["Type"] == "A"
            )
            assert alice_a["avg(Value)"] == 5.5
        finally:
            spark.stop()

    def test_grouped_data_mean_with_all_null_values(self):
        """An all-null group averages to None (0 is also tolerated)."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": None},
                    {"Name": "Alice", "Value": None},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = df.groupBy("Name").mean("Value")
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["avg(Value)"] is None or alice_row["avg(Value)"] == 0
        finally:
            spark.stop()

    def test_grouped_data_mean_with_large_dataset(self):
        """Averages are correct for groups of 20 and 10 rows."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            data = [{"Name": "Alice", "Value": i} for i in range(1, 21)]
            data.extend([{"Name": "Bob", "Value": i} for i in range(1, 11)])
            df = spark.createDataFrame(data)
            result = df.groupBy("Name").mean("Value")
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["avg(Value)"] == 10.5
            bob_row = next(row for row in rows if row["Name"] == "Bob")
            assert bob_row["avg(Value)"] == 5.5
        finally:
            spark.stop()

    def test_grouped_data_mean_with_duplicate_values(self):
        """Repeated identical values average to that value."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 5},
                    {"Name": "Alice", "Value": 5},
                    {"Name": "Alice", "Value": 5},
                    {"Name": "Bob", "Value": 10},
                ]
            )
            result = df.groupBy("Name").mean("Value")
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["avg(Value)"] == 5.0
        finally:
            spark.stop()

    def test_grouped_data_mean_with_very_large_numbers(self):
        """Values in the millions are averaged exactly."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1000000},
                    {"Name": "Alice", "Value": 2000000},
                    {"Name": "Bob", "Value": 500000},
                ]
            )
            result = df.groupBy("Name").mean("Value")
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["avg(Value)"] == 1500000.0
        finally:
            spark.stop()

    def test_grouped_data_mean_with_very_small_numbers(self):
        """Small fractional values are averaged within a tolerance."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 0.0001},
                    {"Name": "Alice", "Value": 0.0002},
                    {"Name": "Bob", "Value": 0.0005},
                ]
            )
            result = df.groupBy("Name").mean("Value")
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert abs(alice_row["avg(Value)"] - 0.00015) < 0.000001
        finally:
            spark.stop()

    def test_grouped_data_mean_with_mixed_int_float(self):
        """Averaging a DoubleType column defined by an explicit schema."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            schema = StructType(
                [
                    StructField("Name", StringType(), True),
                    StructField("Value", DoubleType(), True),
                ]
            )
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1.0},
                    {"Name": "Alice", "Value": 2.5},
                    {"Name": "Alice", "Value": 3.0},
                    {"Name": "Bob", "Value": 5.0},
                ],
                schema=schema,
            )
            result = df.groupBy("Name").mean("Value")
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert abs(alice_row["avg(Value)"] - 2.1666666666666665) < 0.0001
        finally:
            spark.stop()

    def test_grouped_data_mean_with_join(self):
        """mean() output can be left-joined to another DataFrame."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Name": "Alice", "Type": "A"},
                    {"Name": "Bob", "Type": "B"},
                ]
            )
            result = df1.groupBy("Name").mean("Value").join(df2, on="Name", how="left")
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["avg(Value)"] == 5.5
            assert alice_row["Type"] == "A"
        finally:
            spark.stop()

    def test_grouped_data_mean_with_union(self):
        """mean() works on the result of unionByName."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Name": "Bob", "Value": 5},
                ]
            )
            combined = df1.unionByName(df2, allowMissingColumns=True)
            result = combined.groupBy("Name").mean("Value")
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["avg(Value)"] == 5.5
        finally:
            spark.stop()

    def test_grouped_data_mean_with_distinct(self):
        """Duplicate rows removed by distinct() do not affect the mean."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = df.distinct().groupBy("Name").mean("Value")
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            # (1 + 10) / 2; without distinct() this would be (1 + 1 + 10) / 3.
            assert alice_row["avg(Value)"] == 5.5
        finally:
            spark.stop()

    def test_grouped_data_mean_with_limit(self):
        """limit() caps the number of aggregated rows returned."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                    {"Name": "Charlie", "Value": 3},
                ]
            )
            result = df.groupBy("Name").mean("Value").limit(2)
            rows = result.collect()
            assert len(rows) == 2
        finally:
            spark.stop()

    def test_grouped_data_mean_with_withColumn(self):
        """A derived column can be computed from avg(Value)."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = (
                df.groupBy("Name")
                .mean("Value")
                .withColumn("DoubleMean", F.col("avg(Value)") * 2)
            )
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["DoubleMean"] == 11.0
        finally:
            spark.stop()

    def test_grouped_data_mean_with_drop(self):
        """Dropping the grouping column keeps avg(Value)."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1, "Other": "X"},
                    {"Name": "Alice", "Value": 10, "Other": "Y"},
                    {"Name": "Bob", "Value": 5, "Other": "Z"},
                ]
            )
            result = df.groupBy("Name").mean("Value").drop("Name")
            rows = result.collect()
            assert len(rows) == 2
            assert "avg(Value)" in rows[0]
        finally:
            spark.stop()

    def test_grouped_data_mean_with_alias(self):
        """avg(Value) can be renamed with alias() in a select."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = (
                df.groupBy("Name")
                .mean("Value")
                .select(F.col("avg(Value)").alias("MeanValue"), "Name")
            )
            rows = result.collect()
            assert len(rows) == 2
            assert "MeanValue" in rows[0]
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["MeanValue"] == 5.5
        finally:
            spark.stop()

    def test_grouped_data_mean_with_case_when(self):
        """avg(Value) can drive a when/otherwise expression."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = (
                df.groupBy("Name")
                .mean("Value")
                .withColumn(
                    "Category",
                    F.when(F.col("avg(Value)") > 5.0, "High").otherwise("Low"),
                )
            )
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["Category"] == "High"
            bob_row = next(row for row in rows if row["Name"] == "Bob")
            assert bob_row["Category"] == "Low"
        finally:
            spark.stop()

    def test_grouped_data_mean_with_coalesce(self):
        """avg(Value) can be passed to coalesce() with a literal fallback."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = (
                df.groupBy("Name")
                .mean("Value")
                .withColumn(
                    "MeanOrZero",
                    F.coalesce(F.col("avg(Value)"), F.lit(0)),
                )
            )
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["MeanOrZero"] == 5.5
        finally:
            spark.stop()

    def test_grouped_data_mean_with_cast(self):
        """Casting avg(Value) to int truncates 5.5 to 5."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = (
                df.groupBy("Name")
                .mean("Value")
                .withColumn("MeanInt", F.col("avg(Value)").cast("int"))
            )
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["MeanInt"] == 5
        finally:
            spark.stop()

    def test_grouped_data_mean_schema_verification(self):
        """The result schema contains the grouping and avg(Value) fields."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                ]
            )
            result = df.groupBy("Name").mean("Value")
            schema = result.schema
            field_names = [field.name for field in schema.fields]
            assert "Name" in field_names
            assert "avg(Value)" in field_names
        finally:
            spark.stop()

    def test_grouped_data_mean_with_multiple_aggregations(self):
        """mean() output joins with a separate max() aggregation."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = (
                df.groupBy("Name")
                .mean("Value")
                .join(
                    df.groupBy("Name").agg(F.max("Value").alias("MaxValue")),
                    on="Name",
                )
            )
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["avg(Value)"] == 5.5
            assert alice_row["MaxValue"] == 10
        finally:
            spark.stop()

    def test_grouped_data_mean_with_window_functions(self):
        """A window function can be applied on top of mean() output."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = (
                df.groupBy("Name")
                .mean("Value")
                .withColumn(
                    "RowNum",
                    F.row_number().over(Window.orderBy("Name")),
                )
            )
            rows = result.collect()
            assert len(rows) == 2
            assert "RowNum" in rows[0]
        finally:
            spark.stop()

    def test_grouped_data_mean_with_orderBy(self):
        """orderBy on the grouping column returns names in sorted order."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                    {"Name": "Charlie", "Value": 3},
                ]
            )
            result = df.groupBy("Name").mean("Value").orderBy("Name")
            rows = result.collect()
            assert len(rows) == 3
            names = [row["Name"] for row in rows]
            assert names == sorted(names)
        finally:
            spark.stop()

    def test_grouped_data_mean_with_desc_orderBy(self):
        """Descending order on avg(Value) puts the largest mean first."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = (
                df.groupBy("Name").mean("Value").orderBy(F.col("avg(Value)").desc())
            )
            rows = result.collect()
            assert len(rows) == 2
            assert rows[0]["Name"] == "Alice"
            assert rows[0]["avg(Value)"] == 5.5
        finally:
            spark.stop()

    def test_grouped_data_mean_with_many_groups(self):
        """Ten single-row groups each average to their own value."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            data = [{"Name": f"Person{i}", "Value": i} for i in range(10)]
            df = spark.createDataFrame(data)
            result = df.groupBy("Name").mean("Value")
            rows = result.collect()
            assert len(rows) == 10
            for i in range(10):
                person_row = next(row for row in rows if row["Name"] == f"Person{i}")
                assert person_row["avg(Value)"] == float(i)
        finally:
            spark.stop()

    def test_grouped_data_mean_with_complex_chained_operations(self):
        """Chain filter, select/alias and orderBy onto a two-key mean()."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1, "Type": "A"},
                    {"Name": "Alice", "Value": 10, "Type": "A"},
                    {"Name": "Bob", "Value": 5, "Type": "B"},
                    {"Name": "Charlie", "Value": 3, "Type": "A"},
                ]
            )
            result = (
                df.groupBy("Name", "Type")
                .mean("Value")
                .filter(F.col("avg(Value)") > 4.0)
                .select("Name", "Type", F.col("avg(Value)").alias("MeanValue"))
                .orderBy("MeanValue", ascending=False)
            )
            rows = result.collect()
            assert len(rows) == 2
            assert rows[0]["Name"] == "Alice"
            assert rows[0]["MeanValue"] == 5.5
        finally:
            spark.stop()

    def test_grouped_data_mean_with_nested_select(self):
        """Successive selects can narrow mean() output to the group key."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = (
                df.groupBy("Name")
                .mean("Value")
                .select("Name", "avg(Value)")
                .select("Name")
            )
            rows = result.collect()
            assert len(rows) == 2
            assert "Name" in rows[0]
        finally:
            spark.stop()

    def test_grouped_data_mean_with_string_column_error(self):
        """mean() over a string column either raises or returns one row per group.

        Raising is an accepted outcome. If the call succeeds, the row count is
        checked outside the ``try`` so a wrong result is not swallowed.
        """
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": "A"},
                    {"Name": "Alice", "Value": "B"},
                    {"Name": "Bob", "Value": "C"},
                ]
            )
            try:
                result = df.groupBy("Name").mean("Value")
                rows = result.collect()
            except Exception:
                # Rejecting a non-numeric column is an acceptable outcome.
                return
            assert len(rows) == 2
        finally:
            spark.stop()

    def test_grouped_data_mean_with_column_alias_in_groupBy(self):
        """Grouping by an aliased column exposes the alias in the output."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = (
                df.select(F.col("Name").alias("Person"), "Value")
                .groupBy("Person")
                .mean("Value")
            )
            rows = result.collect()
            assert len(rows) == 2
            assert "Person" in rows[0]
            alice_row = next(row for row in rows if row["Person"] == "Alice")
            assert alice_row["avg(Value)"] == 5.5
        finally:
            spark.stop()

    def test_grouped_data_mean_with_computed_column(self):
        """mean() works on a column added with withColumn()."""
        spark = SparkSession.builder.appName("issue-337").getOrCreate()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": 1},
                    {"Name": "Alice", "Value": 10},
                    {"Name": "Bob", "Value": 5},
                ]
            )
            result = (
                df.withColumn("DoubleValue", F.col("Value") * 2)
                .groupBy("Name")
                .mean("DoubleValue")
            )
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["avg(DoubleValue)"] == 11.0
        finally:
            spark.stop()