from sparkless.testing import get_imports
_imports = get_imports()
F = _imports.F
Window = _imports.Window
class TestIssue355DiamondDependency:
    """Regression tests for issue 355: ``unionByName`` over a diamond-shaped plan.

    Every test derives two or more branches from the same source DataFrame
    (the "diamond") and recombines them with ``unionByName``. ``spark`` is the
    session fixture supplied by the test harness.
    """

    def test_unionByName_diamond_dependency(self, spark):
        """Anti-join and inner-join branches of ``existing`` reassemble it exactly.

        ``branch_a`` keeps rows of ``existing`` whose id is absent from
        ``source`` (id 3); ``branch_b`` keeps rows whose id is present (ids 1
        and 2). The union must contain every original row once, with its
        original ``value``.
        """
        existing = spark.createDataFrame(
            [(1, "a", 100), (2, "b", 200), (3, "c", 300)],
            ["id", "name", "value"],
        )
        source = spark.createDataFrame(
            [(1, "a", 150), (2, "b", 250)],
            ["id", "name", "value"],
        )
        # Marker column: after the left join it is null exactly for ids that
        # do not occur in ``source`` (a left anti-join).
        sk = source.select("id").distinct().withColumn("_m", F.lit(True))
        branch_a = (
            existing.join(sk, on=["id"], how="left")
            .filter(F.col("_m").isNull())
            .drop("_m")
        )
        branch_b = existing.join(source.select("id").distinct(), on=["id"], how="inner")
        combined = branch_a.unionByName(branch_b, allowMissingColumns=True)
        result = combined.collect()
        assert len(result) == 3, f"Expected 3 rows, got {len(result)}"
        ids = {row.id for row in result}
        # Check duplicates first: a set of ids (and the dict below) silently
        # collapses repeated ids, so compare its size against the row count.
        assert len(ids) == len(result), "Found duplicate IDs in result"
        assert ids == {1, 2, 3}, f"Expected ids {{1, 2, 3}}, got {ids}"
        id_to_value = {row.id: row.value for row in result}
        assert id_to_value[1] == 100, (
            f"Expected value 100 for id=1, got {id_to_value[1]}"
        )
        assert id_to_value[2] == 200, (
            f"Expected value 200 for id=2, got {id_to_value[2]}"
        )
        assert id_to_value[3] == 300, (
            f"Expected value 300 for id=3, got {id_to_value[3]}"
        )

    def test_unionByName_diamond_dependency_with_filters(self, spark):
        """Complementary filters (even / odd ids) on one base union back to all rows."""
        base = spark.createDataFrame(
            [(1, "a", 10), (2, "b", 20), (3, "c", 30), (4, "d", 40)],
            ["id", "name", "value"],
        )
        branch_a = base.filter(F.col("id") % 2 == 0)
        branch_b = base.filter(F.col("id") % 2 == 1)
        result = branch_a.unionByName(branch_b).orderBy("id").collect()
        assert len(result) == 4, f"Expected 4 rows, got {len(result)}"
        ids = {row.id for row in result}
        assert ids == {1, 2, 3, 4}, f"Expected ids {{1, 2, 3, 4}}, got {ids}"
        # orderBy("id") makes the row order deterministic, so pin it.
        ordered = [(row.id, row.value) for row in result]
        assert ordered == [(1, 10), (2, 20), (3, 30), (4, 40)], (
            f"Unexpected ordered (id, value) pairs: {ordered}"
        )

    def test_unionByName_diamond_dependency_with_select(self, spark):
        """Branches selecting different columns union with nulls for the gaps."""
        base = spark.createDataFrame(
            [(1, "a", 100), (2, "b", 200), (3, "c", 300)],
            ["id", "name", "value"],
        )
        branch_a = base.select("id", "value").withColumn("source", F.lit("A"))
        branch_b = base.select("id", "name").withColumn("source", F.lit("B"))
        result = (
            branch_a.unionByName(branch_b, allowMissingColumns=True)
            .orderBy("id", "source")
            .collect()
        )
        assert len(result) == 6, f"Expected 6 rows, got {len(result)}"
        # Ordering by (id, source) is total here, so every id must appear
        # exactly twice: once from branch A, then once from branch B.
        keys = [(row.id, row.source) for row in result]
        assert keys == [
            (1, "A"), (1, "B"), (2, "A"), (2, "B"), (3, "A"), (3, "B"),
        ], f"Each ID should appear exactly twice, got {keys}"
        names = {1: "a", 2: "b", 3: "c"}
        for row in result:
            if row.source == "A":
                # branch_a never had ``name``; allowMissingColumns fills null.
                assert row.value == row.id * 100, f"Bad value in {row}"
                assert row.name is None, f"Expected null name in {row}"
            else:
                # branch_b never had ``value``.
                assert row.name == names[row.id], f"Bad name in {row}"
                assert row.value is None, f"Expected null value in {row}"

    def test_unionByName_diamond_dependency_with_withColumn(self, spark):
        """Each branch adds a different derived column; both survive the union."""
        base = spark.createDataFrame(
            [(1, 10), (2, 20), (3, 30)],
            ["id", "value"],
        )
        branch_a = base.withColumn("doubled", F.col("value") * 2)
        branch_b = base.withColumn("tripled", F.col("value") * 3)
        result = (
            branch_a.unionByName(branch_b, allowMissingColumns=True)
            .orderBy("id")
            .collect()
        )
        assert len(result) == 6, f"Expected 6 rows, got {len(result)}"
        # Access the columns directly: with allowMissingColumns=True both must
        # exist on every row, so a missing column is a failure, not a skip.
        doubled_rows = [row for row in result if row.doubled is not None]
        tripled_rows = [row for row in result if row.tripled is not None]
        assert len(doubled_rows) == 3, (
            f"Expected 3 rows with doubled set, got {len(doubled_rows)}"
        )
        assert len(tripled_rows) == 3, (
            f"Expected 3 rows with tripled set, got {len(tripled_rows)}"
        )
        for row in doubled_rows:
            assert row.tripled is None, f"Expected null tripled in {row}"
            assert row.doubled == row.value * 2, (
                f"Expected doubled={row.value * 2}, got {row.doubled}"
            )
        for row in tripled_rows:
            assert row.doubled is None, f"Expected null doubled in {row}"
            assert row.tripled == row.value * 3, (
                f"Expected tripled={row.value * 3}, got {row.tripled}"
            )

    def test_unionByName_diamond_dependency_three_branches(self, spark):
        """Three disjoint filtered branches of one base chain-union into all rows."""
        base = spark.createDataFrame(
            [(1, "a"), (2, "b"), (3, "c"), (4, "d")],
            ["id", "name"],
        )
        branch_a = base.filter(F.col("id") < 2).withColumn("branch", F.lit("A"))
        branch_b = base.filter((F.col("id") >= 2) & (F.col("id") < 4)).withColumn(
            "branch", F.lit("B")
        )
        branch_c = base.filter(F.col("id") >= 4).withColumn("branch", F.lit("C"))
        result = (
            branch_a.unionByName(branch_b).unionByName(branch_c).orderBy("id").collect()
        )
        assert len(result) == 4, f"Expected 4 rows, got {len(result)}"
        ids = {row.id for row in result}
        assert ids == {1, 2, 3, 4}, f"Expected ids {{1, 2, 3, 4}}, got {ids}"
        # Each id must carry the label of the branch whose filter admitted it.
        labelled = [(row.id, row.branch) for row in result]
        assert labelled == [(1, "A"), (2, "B"), (3, "B"), (4, "C")], (
            f"Unexpected (id, branch) pairs: {labelled}"
        )

    def test_unionByName_diamond_dependency_nested_transformations(self, spark):
        """Branches with multi-step pipelines that end in the same schema."""
        base = spark.createDataFrame(
            [(1, "a", 10), (2, "b", 20), (3, "c", 30)],
            ["id", "name", "value"],
        )
        branch_a = (
            base.filter(F.col("id") <= 2)
            .withColumn("multiplied", F.col("value") * 2)
            .select("id", "name", "multiplied")
            .withColumnRenamed("multiplied", "result")
        )
        branch_b = (
            base.filter(F.col("id") > 2)
            .withColumn("added", F.col("value") + 10)
            .select("id", "name", "added")
            .withColumnRenamed("added", "result")
        )
        result = branch_a.unionByName(branch_b).orderBy("id").collect()
        assert len(result) == 3, f"Expected 3 rows, got {len(result)}"
        ids = {row.id for row in result}
        assert ids == {1, 2, 3}, f"Expected ids {{1, 2, 3}}, got {ids}"
        # In the fixture value == id * 10.
        for row in result:
            if row.id <= 2:
                assert row.result == row.id * 10 * 2, (
                    f"Expected result={row.id * 10 * 2} for id={row.id}, got {row.result}"
                )
            else:
                assert row.result == row.id * 10 + 10, (
                    f"Expected result={row.id * 10 + 10} for id={row.id}, got {row.result}"
                )

    def test_unionByName_diamond_dependency_with_aggregations(self, spark):
        """Two aggregations over different keys union into a sparse result."""
        base = spark.createDataFrame(
            [(1, "A", 10), (1, "A", 20), (2, "B", 30), (2, "B", 40)],
            ["id", "category", "value"],
        )
        branch_a = base.groupBy("id").agg(F.sum("value").alias("total"))
        branch_b = base.groupBy("category").agg(F.avg("value").alias("average"))
        result = (
            branch_a.unionByName(branch_b, allowMissingColumns=True)
            .orderBy("id")
            .collect()
        )
        assert len(result) == 4, f"Expected 4 rows, got {len(result)}"
        id_results = [row for row in result if row.id is not None]
        assert len(id_results) == 2, "Should have 2 rows from branch_a"
        id_to_total = {row.id: row.total for row in id_results}
        assert id_to_total[1] == 30, f"Expected total=30 for id=1, got {id_to_total[1]}"
        assert id_to_total[2] == 70, f"Expected total=70 for id=2, got {id_to_total[2]}"
        category_results = [row for row in result if row.category is not None]
        assert len(category_results) == 2, "Should have 2 rows from branch_b"
        category_to_avg = {row.category: row.average for row in category_results}
        assert category_to_avg["A"] == 15.0, (
            f"Expected average=15.0 for category=A, got {category_to_avg['A']}"
        )
        assert category_to_avg["B"] == 35.0, (
            f"Expected average=35.0 for category=B, got {category_to_avg['B']}"
        )

    def test_unionByName_diamond_dependency_empty_branch(self, spark):
        """A branch whose filter matches nothing contributes no rows."""
        base = spark.createDataFrame(
            [(1, "a"), (2, "b"), (3, "c")],
            ["id", "name"],
        )
        branch_a = base.filter(F.col("id") > 0)
        branch_b = base.filter(F.col("id") < 0)
        result = branch_a.unionByName(branch_b).collect()
        assert len(result) == 3, f"Expected 3 rows, got {len(result)}"
        ids = {row.id for row in result}
        assert ids == {1, 2, 3}, f"Expected ids {{1, 2, 3}}, got {ids}"

    def test_unionByName_diamond_dependency_single_row(self, spark):
        """Single-row base; one branch has an extra column that must null-fill."""
        base = spark.createDataFrame([(1, "a", 100)], ["id", "name", "value"])
        branch_a = base.withColumn("source", F.lit("A"))
        branch_b = base.withColumn("source", F.lit("B")).withColumn(
            "doubled", F.col("value") * 2
        )
        result = branch_a.unionByName(branch_b, allowMissingColumns=True).collect()
        assert len(result) == 2, f"Expected 2 rows, got {len(result)}"
        by_source = {row.source: row for row in result}
        assert set(by_source) == {"A", "B"}, f"Unexpected sources: {set(by_source)}"
        assert by_source["A"].doubled is None, "Branch A should have null doubled"
        assert by_source["B"].doubled == 200, (
            f"Expected doubled=200 for branch B, got {by_source['B'].doubled}"
        )

    def test_unionByName_diamond_dependency_complex_expressions(self, spark):
        """Branches compute the same column name from different expressions."""
        base = spark.createDataFrame(
            [(1, 10, 5), (2, 20, 10), (3, 30, 15)],
            ["id", "value1", "value2"],
        )
        branch_a = base.withColumn(
            "computed", F.col("value1") * 2 + F.col("value2")
        ).select("id", "computed")
        branch_b = base.withColumn(
            "computed", F.col("value2") * 3 - F.col("value1")
        ).select("id", "computed")
        result = branch_a.unionByName(branch_b).orderBy("id").collect()
        assert len(result) == 6, f"Expected 6 rows, got {len(result)}"
        # Rows sharing an id have no defined relative order, so compare sorted
        # values per id: branch_a gives v1*2+v2, branch_b gives v2*3-v1.
        computed_by_id = {}
        for row in result:
            computed_by_id.setdefault(row.id, []).append(row.computed)
        actual = {key: sorted(values) for key, values in computed_by_id.items()}
        expected = {1: [5, 25], 2: [10, 50], 3: [15, 75]}
        assert actual == expected, f"Unexpected computed values per id: {actual}"

    def test_unionByName_diamond_dependency_with_drop(self, spark):
        """Branches dropping different columns union with nulls for the dropped ones."""
        base = spark.createDataFrame(
            [(1, "a", 10, "x"), (2, "b", 20, "y"), (3, "c", 30, "z")],
            ["id", "name", "value", "extra"],
        )
        branch_a = base.drop("extra")
        branch_b = base.drop("name")
        result = (
            branch_a.unionByName(branch_b, allowMissingColumns=True)
            .orderBy("id")
            .collect()
        )
        assert len(result) == 6, f"Expected 6 rows, got {len(result)}"
        # Rows from branch_a lack ``extra``; rows from branch_b lack ``name``.
        missing_extra = [row for row in result if row.extra is None]
        missing_name = [row for row in result if row.name is None]
        assert len(missing_extra) == 3, (
            f"Expected 3 rows with null extra, got {len(missing_extra)}"
        )
        assert len(missing_name) == 3, (
            f"Expected 3 rows with null name, got {len(missing_name)}"
        )
        assert all(row.value == row.id * 10 for row in result), (
            "Shared column value should be preserved in every row"
        )

    def test_unionByName_diamond_dependency_multiple_unions(self, spark):
        """Four single-row branches of one base chain-union in order."""
        base = spark.createDataFrame(
            [(1, "a"), (2, "b"), (3, "c"), (4, "d")],
            ["id", "name"],
        )
        branch_a = base.filter(F.col("id") == 1)
        branch_b = base.filter(F.col("id") == 2)
        branch_c = base.filter(F.col("id") == 3)
        branch_d = base.filter(F.col("id") == 4)
        result = (
            branch_a.unionByName(branch_b)
            .unionByName(branch_c)
            .unionByName(branch_d)
            .orderBy("id")
            .collect()
        )
        assert len(result) == 4, f"Expected 4 rows, got {len(result)}"
        ids = {row.id for row in result}
        assert ids == {1, 2, 3, 4}, f"Expected ids {{1, 2, 3, 4}}, got {ids}"
        ordered = [(row.id, row.name) for row in result]
        assert ordered == [(1, "a"), (2, "b"), (3, "c"), (4, "d")], (
            f"Unexpected ordered (id, name) pairs: {ordered}"
        )

    def test_unionByName_diamond_dependency_with_window_functions(self, spark):
        """Branches with different window functions over the same base."""
        base = spark.createDataFrame(
            [(1, "A", 10), (1, "A", 20), (2, "B", 30), (2, "B", 40)],
            ["id", "category", "value"],
        )
        branch_a = base.withColumn(
            "row_num", F.row_number().over(Window.partitionBy("id").orderBy("value"))
        )
        branch_b = base.withColumn(
            "rank_val", F.rank().over(Window.partitionBy("category").orderBy("value"))
        )
        result = (
            branch_a.unionByName(branch_b, allowMissingColumns=True)
            .orderBy("id", "value")
            .collect()
        )
        assert len(result) == 8, f"Expected 8 rows, got {len(result)}"
        # id and category partition the data identically (1 <-> A, 2 <-> B), so
        # both window functions number each partition's values 1, 2.
        expected = [(1, 10, 1), (1, 20, 2), (2, 30, 1), (2, 40, 2)]
        row_nums = sorted(
            (row.id, row.value, row.row_num)
            for row in result
            if row.row_num is not None
        )
        rank_vals = sorted(
            (row.id, row.value, row.rank_val)
            for row in result
            if row.rank_val is not None
        )
        assert row_nums == expected, f"Unexpected row_num values: {row_nums}"
        assert rank_vals == expected, f"Unexpected rank_val values: {rank_vals}"

    def test_unionByName_diamond_dependency_preserves_original_data(self, spark):
        """Evaluating a union of derived branches must not alter the base DataFrame."""
        base = spark.createDataFrame(
            [(1, "a", 100), (2, "b", 200)],
            ["id", "name", "value"],
        )
        branch_a = base.filter(F.col("id") == 1).withColumn("branch", F.lit("A"))
        branch_b = base.filter(F.col("id") == 2).withColumn("branch", F.lit("B"))
        unioned = branch_a.unionByName(branch_b)
        # Materialize the (lazy) union BEFORE inspecting ``base``; otherwise the
        # base check runs before the union has executed and cannot detect
        # mutation caused by evaluating it.
        union_rows = unioned.collect()
        assert len(union_rows) == 2, "Union should have 2 rows"
        labelled = {(row.id, row.branch) for row in union_rows}
        assert labelled == {(1, "A"), (2, "B")}, (
            f"Unexpected (id, branch) pairs in union: {labelled}"
        )
        base_rows = base.collect()
        assert len(base_rows) == 2, "Original DataFrame should be unchanged"
        assert base.columns == ["id", "name", "value"], (
            "Original DataFrame should be unchanged"
        )
        # Look values up by id rather than relying on collect() row order.
        base_values = {row.id: row.value for row in base_rows}
        assert base_values == {1: 100, 2: 200}, (
            "Original DataFrame values should be unchanged"
        )