import pytest
from tests.fixtures.spark_imports import get_spark_imports
# Resolve the Spark API bundle once at import time via the shared test fixture helper.
imports = get_spark_imports()
# Functions namespace used by the aggregations below (F.sum, F.avg, F.collect_list, ...).
# NOTE(review): presumably mirrors ``pyspark.sql.functions`` -- confirm against the helper.
F = imports.F
class TestPivotGroupedData:
    """Exercise ``groupBy(...).pivot(...)`` with a range of aggregations.

    Every test groups by ``type`` and pivots on that same column with the
    explicit value list ``["A", "B"]``. Each result row is therefore expected
    to carry an aggregate only in the pivot column matching its own ``type``;
    the off-diagonal cell is expected to be ``None`` (or possibly ``[]`` for
    the collection aggregates).
    """

    @pytest.fixture
    def sample_data(self):
        """Return three records: type ``"A"`` with values 1 and 10, type ``"B"`` with 5."""
        return [
            {"type": "A", "value": 1},
            {"type": "A", "value": 10},
            {"type": "B", "value": 5},
        ]

    @staticmethod
    def _row_for(rows, type_value):
        """Return the first collected row whose ``type`` equals *type_value*.

        Fails the calling test with a descriptive message when no row matches.
        """
        row = next((r for r in rows if r["type"] == type_value), None)
        assert row is not None, f"no result row with type == {type_value!r}"
        return row

    def test_pivot_sum(self, spark, sample_data):
        """``sum`` yields per-type totals and exposes pivot values as columns."""
        df = spark.createDataFrame(sample_data)
        result = df.groupBy("type").pivot("type", ["A", "B"]).sum("value")
        rows = result.collect()
        assert len(rows) == 2

        schema_names = [field.name for field in result.schema.fields]
        assert "type" in schema_names
        assert "A" in schema_names
        assert "B" in schema_names

        row_a = self._row_for(rows, "A")
        row_b = self._row_for(rows, "B")
        assert row_a["A"] == 11
        assert row_a["B"] is None
        assert row_b["A"] is None
        assert row_b["B"] == 5

    def test_pivot_avg(self, spark, sample_data):
        """``avg`` yields the per-type arithmetic mean."""
        df = spark.createDataFrame(sample_data)
        result = df.groupBy("type").pivot("type", ["A", "B"]).avg("value")
        rows = result.collect()
        assert len(rows) == 2

        row_a = self._row_for(rows, "A")
        row_b = self._row_for(rows, "B")
        assert row_a["A"] == 5.5
        assert row_a["B"] is None
        assert row_b["A"] is None
        assert row_b["B"] == 5.0

    def test_pivot_count(self, spark, sample_data):
        """``count`` yields the number of rows per type."""
        df = spark.createDataFrame(sample_data)
        result = df.groupBy("type").pivot("type", ["A", "B"]).count()
        rows = result.collect()
        assert len(rows) == 2

        row_a = self._row_for(rows, "A")
        row_b = self._row_for(rows, "B")
        assert row_a["A"] == 2
        assert row_a["B"] is None
        assert row_b["A"] is None
        assert row_b["B"] == 1

    def test_pivot_max(self, spark, sample_data):
        """``max`` yields the largest value per type."""
        df = spark.createDataFrame(sample_data)
        result = df.groupBy("type").pivot("type", ["A", "B"]).max("value")
        rows = result.collect()
        assert len(rows) == 2

        row_a = self._row_for(rows, "A")
        row_b = self._row_for(rows, "B")
        assert row_a["A"] == 10
        assert row_a["B"] is None
        assert row_b["A"] is None
        assert row_b["B"] == 5

    def test_pivot_min(self, spark, sample_data):
        """``min`` yields the smallest value per type."""
        df = spark.createDataFrame(sample_data)
        result = df.groupBy("type").pivot("type", ["A", "B"]).min("value")
        rows = result.collect()
        assert len(rows) == 2

        row_a = self._row_for(rows, "A")
        row_b = self._row_for(rows, "B")
        assert row_a["A"] == 1
        assert row_a["B"] is None
        assert row_b["A"] is None
        assert row_b["B"] == 5

    def test_pivot_count_distinct(self, spark):
        """``count_distinct`` ignores the duplicated value within type ``"A"``."""
        data = [
            {"type": "A", "value": 1},
            {"type": "A", "value": 1},
            {"type": "A", "value": 10},
            {"type": "B", "value": 5},
        ]
        df = spark.createDataFrame(data)
        result = (
            df.groupBy("type").pivot("type", ["A", "B"]).agg(F.count_distinct("value"))
        )
        rows = result.collect()
        assert len(rows) == 2

        row_a = self._row_for(rows, "A")
        row_b = self._row_for(rows, "B")
        assert row_a["A"] == 2
        assert row_a["B"] is None
        assert row_b["A"] is None
        assert row_b["B"] == 1

    def test_pivot_collect_list(self, spark):
        """``collect_list`` gathers every value of a type into a list."""
        data = [
            {"type": "A", "value": 1},
            {"type": "A", "value": 10},
            {"type": "B", "value": 5},
        ]
        df = spark.createDataFrame(data)
        result = (
            df.groupBy("type").pivot("type", ["A", "B"]).agg(F.collect_list("value"))
        )
        rows = result.collect()
        assert len(rows) == 2

        row_a = self._row_for(rows, "A")
        row_b = self._row_for(rows, "B")
        # Element order of collect_list is not guaranteed, so compare sorted.
        assert sorted(row_a["A"]) == [1, 10]
        # An empty cell may surface as either None or an empty list.
        assert row_a["B"] is None or row_a["B"] == []
        assert row_b["A"] is None or row_b["A"] == []
        assert row_b["B"] == [5]

    def test_pivot_collect_set(self, spark):
        """``collect_set`` gathers the distinct values of a type."""
        data = [
            {"type": "A", "value": 1},
            {"type": "A", "value": 1},
            {"type": "A", "value": 10},
            {"type": "B", "value": 5},
        ]
        df = spark.createDataFrame(data)
        result = (
            df.groupBy("type").pivot("type", ["A", "B"]).agg(F.collect_set("value"))
        )
        rows = result.collect()
        assert len(rows) == 2

        row_a = self._row_for(rows, "A")
        row_b = self._row_for(rows, "B")
        # Compare as a set: only membership matters, not element order.
        assert set(row_a["A"]) == {1, 10}
        # An empty cell may surface as either None or an empty list.
        assert row_a["B"] is None or row_a["B"] == []
        assert row_b["A"] is None or row_b["A"] == []
        assert row_b["B"] == [5]

    def test_pivot_first(self, spark, sample_data):
        """``first`` yields the first value seen for each type."""
        df = spark.createDataFrame(sample_data)
        result = df.groupBy("type").pivot("type", ["A", "B"]).agg(F.first("value"))
        rows = result.collect()
        assert len(rows) == 2

        row_a = self._row_for(rows, "A")
        row_b = self._row_for(rows, "B")
        # NOTE(review): relies on input order being preserved for this small
        # dataset; first() is order-dependent -- confirm this is stable.
        assert row_a["A"] == 1
        assert row_a["B"] is None
        assert row_b["A"] is None
        assert row_b["B"] == 5

    def test_pivot_last(self, spark, sample_data):
        """``last`` yields the last value seen for each type."""
        df = spark.createDataFrame(sample_data)
        result = df.groupBy("type").pivot("type", ["A", "B"]).agg(F.last("value"))
        rows = result.collect()
        assert len(rows) == 2

        row_a = self._row_for(rows, "A")
        row_b = self._row_for(rows, "B")
        # NOTE(review): relies on input order being preserved for this small
        # dataset; last() is order-dependent -- confirm this is stable.
        assert row_a["A"] == 10
        assert row_a["B"] is None
        assert row_b["A"] is None
        assert row_b["B"] == 5

    def test_pivot_stddev(self, spark):
        """``stddev`` yields the sample standard deviation per type."""
        data = [
            {"type": "A", "value": 1},
            {"type": "A", "value": 10},
            {"type": "B", "value": 5},
        ]
        df = spark.createDataFrame(data)
        result = df.groupBy("type").pivot("type", ["A", "B"]).agg(F.stddev("value"))
        rows = result.collect()
        assert len(rows) == 2

        row_a = self._row_for(rows, "A")
        row_b = self._row_for(rows, "B")
        assert row_a["A"] is not None
        assert isinstance(row_a["A"], float)
        # Sample stddev of [1, 10]: sqrt(((1 - 5.5)**2 + (10 - 5.5)**2) / (2 - 1)).
        assert row_a["A"] == pytest.approx(40.5 ** 0.5)
        assert row_a["B"] is None
        assert row_b["A"] is None
        # A single observation has no sample standard deviation.
        assert row_b["B"] is None

    def test_pivot_variance(self, spark):
        """``variance`` yields the sample variance per type."""
        data = [
            {"type": "A", "value": 1},
            {"type": "A", "value": 10},
            {"type": "B", "value": 5},
        ]
        df = spark.createDataFrame(data)
        result = df.groupBy("type").pivot("type", ["A", "B"]).agg(F.variance("value"))
        rows = result.collect()
        assert len(rows) == 2

        row_a = self._row_for(rows, "A")
        row_b = self._row_for(rows, "B")
        assert row_a["A"] is not None
        assert isinstance(row_a["A"], float)
        # Sample variance of [1, 10]: ((1 - 5.5)**2 + (10 - 5.5)**2) / (2 - 1).
        assert row_a["A"] == pytest.approx(40.5)
        assert row_a["B"] is None
        assert row_b["A"] is None
        # A single observation has no sample variance.
        assert row_b["B"] is None

    def test_pivot_mean(self, spark, sample_data):
        """``mean`` yields the same per-type averages as ``avg``."""
        df = spark.createDataFrame(sample_data)
        result = df.groupBy("type").pivot("type", ["A", "B"]).mean("value")
        rows = result.collect()
        assert len(rows) == 2

        row_a = self._row_for(rows, "A")
        row_b = self._row_for(rows, "B")
        assert row_a["A"] == 5.5
        assert row_a["B"] is None
        assert row_b["A"] is None
        assert row_b["B"] == 5.0

    def test_pivot_multiple_aggregates(self, spark, sample_data):
        """Several aliased aggregates produce ``<pivot value>_<alias>`` columns."""
        df = spark.createDataFrame(sample_data)
        result = (
            df.groupBy("type")
            .pivot("type", ["A", "B"])
            .agg(F.sum("value").alias("total"), F.avg("value").alias("avg_val"))
        )
        rows = result.collect()
        assert len(rows) == 2

        schema_names = [field.name for field in result.schema.fields]
        assert "type" in schema_names
        assert "A_total" in schema_names
        assert "A_avg_val" in schema_names
        assert "B_total" in schema_names
        assert "B_avg_val" in schema_names

    def test_pivot_single_aggregate_with_alias(self, spark, sample_data):
        """A single aliased aggregate still produces a usable pivoted schema."""
        df = spark.createDataFrame(sample_data)
        result = (
            df.groupBy("type")
            .pivot("type", ["A", "B"])
            .agg(F.sum("value").alias("total"))
        )
        rows = result.collect()
        assert len(rows) == 2

        schema_names = [field.name for field in result.schema.fields]
        assert "type" in schema_names
        # Either naming scheme is accepted: a column named after the alias, or
        # one column per pivot value.
        assert "total" in schema_names or ("A" in schema_names and "B" in schema_names)

    def test_pivot_empty_groups(self, spark):
        """A listed pivot value with no matching rows yields a ``None`` cell."""
        data = [{"type": "A", "value": 1}]
        df = spark.createDataFrame(data)
        result = df.groupBy("type").pivot("type", ["A", "B"]).sum("value")
        rows = result.collect()
        assert len(rows) == 1

        row_a = self._row_for(rows, "A")
        assert row_a["A"] == 1
        assert row_a["B"] is None