from tests.fixtures.spark_imports import get_spark_imports

# Spark entry points and SQL types come from the shared test fixture rather
# than from a direct import; they are bound as module globals so every test
# below can use the familiar short names.
_imports = get_spark_imports()
(
    SparkSession,
    F,
    Window,
    StringType,
    StructType,
    StructField,
    IntegerType,
    LongType,
    ArrayType,
) = (
    _imports.SparkSession,
    _imports.F,
    _imports.Window,
    _imports.StringType,
    _imports.StructType,
    _imports.StructField,
    _imports.IntegerType,
    _imports.LongType,
    _imports.ArrayType,
)


class TestIssue331ArrayContainsJoin:
    """Joins whose condition is ``F.array_contains(left_array, right_value)``.

    Every test builds its own small DataFrames, runs the join, and stops the
    session in ``finally``. Join output order is not guaranteed, so tests look
    rows up by key (or sort them in Python) instead of relying on positions.
    """

    @staticmethod
    def _get_spark():
        """Get or create the Spark session named ``issue-331``.

        The caller is responsible for calling ``stop()`` on the result.
        """
        return SparkSession.builder.appName("issue-331").getOrCreate()

    def test_array_contains_join_basic(self):
        """Left join pairs each person with the dept whose ID is in their array."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3]},
                    {"Name": "Bob", "IDs": [4, 5, 6]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 3},
                    {"Dept": "B", "ID": 5},
                ]
            )
            result = df1.join(df2, on=F.array_contains(df1.IDs, df2.ID), how="left")
            rows = result.collect()
            assert len(rows) == 2
            # Row order of a join is unspecified; index by name instead of position.
            by_name = {row["Name"]: row for row in rows}
            assert by_name.keys() == {"Alice", "Bob"}
            alice = by_name["Alice"]
            assert alice["IDs"] == [1, 2, 3]
            assert alice["Dept"] == "A"
            assert alice["ID"] == 3
            bob = by_name["Bob"]
            assert bob["IDs"] == [4, 5, 6]
            assert bob["Dept"] == "B"
            assert bob["ID"] == 5
        finally:
            spark.stop()

    def test_array_contains_join_inner(self):
        """Inner join drops the left row (Charlie) whose array matches nothing."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3]},
                    {"Name": "Bob", "IDs": [4, 5, 6]},
                    {"Name": "Charlie", "IDs": [7, 8, 9]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 3},
                    {"Dept": "B", "ID": 5},
                ]
            )
            result = df1.join(df2, on=F.array_contains(df1.IDs, df2.ID), how="inner")
            rows = result.collect()
            assert len(rows) == 2
            names = {row["Name"] for row in rows}
            assert names == {"Alice", "Bob"}
        finally:
            spark.stop()

    def test_array_contains_join_left(self):
        """Left join keeps the unmatched left row with null right-side columns."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3]},
                    {"Name": "Bob", "IDs": [4, 5, 6]},
                    {"Name": "Charlie", "IDs": [7, 8, 9]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 3},
                    {"Dept": "B", "ID": 5},
                ]
            )
            result = df1.join(df2, on=F.array_contains(df1.IDs, df2.ID), how="left")
            rows = result.collect()
            assert len(rows) == 3
            charlie_row = next(row for row in rows if row["Name"] == "Charlie")
            assert charlie_row["Dept"] is None
            assert charlie_row["ID"] is None
        finally:
            spark.stop()

    def test_array_contains_join_multiple_matches(self):
        """One left row whose array contains several right IDs yields one row per match."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 1},
                    {"Dept": "B", "ID": 2},
                    {"Dept": "C", "ID": 3},
                ]
            )
            result = df1.join(df2, on=F.array_contains(df1.IDs, df2.ID), how="inner")
            rows = result.collect()
            assert len(rows) == 3
            for row in rows:
                assert row["Name"] == "Alice"
                assert row["IDs"] == [1, 2, 3]
            depts = {row["Dept"] for row in rows}
            assert depts == {"A", "B", "C"}
        finally:
            spark.stop()

    def test_array_contains_join_no_matches(self):
        """With no matching IDs, inner join is empty and left join null-pads."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 10},
                    {"Dept": "B", "ID": 20},
                ]
            )
            result = df1.join(df2, on=F.array_contains(df1.IDs, df2.ID), how="inner")
            rows = result.collect()
            assert len(rows) == 0
            result2 = df1.join(df2, on=F.array_contains(df1.IDs, df2.ID), how="left")
            rows2 = result2.collect()
            assert len(rows2) == 1
            assert rows2[0]["Name"] == "Alice"
            assert rows2[0]["Dept"] is None
            assert rows2[0]["ID"] is None
        finally:
            spark.stop()

    def test_array_contains_join_null_arrays(self):
        """A null left array never matches, but a left join must still keep its row."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3]},
                    {"Name": "Bob", "IDs": None},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 3},
                ]
            )
            result = df1.join(df2, on=F.array_contains(df1.IDs, df2.ID), how="left")
            rows = result.collect()
            # Left-join semantics: every left row appears at least once, so Bob
            # must be present (null-padded) rather than optionally absent.
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["Dept"] == "A"
            bob_row = next(row for row in rows if row["Name"] == "Bob")
            assert bob_row["Dept"] is None
            assert bob_row["ID"] is None
        finally:
            spark.stop()

    def test_array_contains_join_null_ids(self):
        """A null right-side ID never matches, so only Dept A joins to Alice."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 3},
                    {"Dept": "B", "ID": None},
                ]
            )
            result = df1.join(df2, on=F.array_contains(df1.IDs, df2.ID), how="left")
            rows = result.collect()
            # Alice matched Dept A, so the left join emits no extra null-padded row.
            assert len(rows) == 1
            assert rows[0]["Name"] == "Alice"
            assert rows[0]["Dept"] == "A"
            assert rows[0]["ID"] == 3
        finally:
            spark.stop()

    def test_array_contains_join_right(self):
        """Right join keeps every right row; Dept C (ID 10) has no matching array."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3]},
                    {"Name": "Bob", "IDs": [4, 5, 6]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 3},
                    {"Dept": "B", "ID": 5},
                    {"Dept": "C", "ID": 10},
                ]
            )
            result = df1.join(df2, on=F.array_contains(df1.IDs, df2.ID), how="right")
            rows = result.collect()
            assert len(rows) == 3
            by_dept = {row["Dept"]: row for row in rows}
            assert by_dept.keys() == {"A", "B", "C"}
            assert by_dept["A"]["Name"] == "Alice"
            assert by_dept["B"]["Name"] == "Bob"
            assert by_dept["C"]["Name"] is None
        finally:
            spark.stop()

    def test_array_contains_join_outer(self):
        """Full outer join keeps matched pairs plus unmatched rows from both sides."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3]},
                    {"Name": "Bob", "IDs": [4, 5, 6]},
                    {"Name": "Charlie", "IDs": [7, 8, 9]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 3},
                    {"Dept": "B", "ID": 5},
                    {"Dept": "C", "ID": 10},
                ]
            )
            result = df1.join(df2, on=F.array_contains(df1.IDs, df2.ID), how="outer")
            rows = result.collect()
            # Alice-A, Bob-B, Charlie-(null), (null)-C.
            assert len(rows) == 4
            names = {row["Name"] for row in rows if row["Name"] is not None}
            assert names == {"Alice", "Bob", "Charlie"}
            depts = {row["Dept"] for row in rows if row["Dept"] is not None}
            assert depts == {"A", "B", "C"}
            charlie_row = next(row for row in rows if row["Name"] == "Charlie")
            assert charlie_row["Dept"] is None
            dept_c_row = next(row for row in rows if row["Dept"] == "C")
            assert dept_c_row["Name"] is None
        finally:
            spark.stop()

    def test_array_contains_join_with_select(self):
        """Selecting columns after the join preserves the matched values."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3]},
                    {"Name": "Bob", "IDs": [4, 5, 6]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 3},
                    {"Dept": "B", "ID": 5},
                ]
            )
            result = df1.join(
                df2, on=F.array_contains(df1.IDs, df2.ID), how="left"
            ).select("Name", "Dept")
            # Sort in Python so the assertions do not depend on join output order.
            rows = sorted(result.collect(), key=lambda row: row["Name"])
            assert len(rows) == 2
            assert rows[0]["Name"] == "Alice"
            assert rows[0]["Dept"] == "A"
            assert rows[1]["Name"] == "Bob"
            assert rows[1]["Dept"] == "B"
        finally:
            spark.stop()

    def test_array_contains_join_with_filter(self):
        """Filtering on a right-side column after the join keeps only that match."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3]},
                    {"Name": "Bob", "IDs": [4, 5, 6]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 3},
                    {"Dept": "B", "ID": 5},
                ]
            )
            result = df1.join(
                df2, on=F.array_contains(df1.IDs, df2.ID), how="left"
            ).filter(F.col("Dept") == "A")
            rows = result.collect()
            assert len(rows) == 1
            assert rows[0]["Name"] == "Alice"
            assert rows[0]["Dept"] == "A"
        finally:
            spark.stop()

    def test_array_contains_join_column_name_conflicts(self):
        """Columns present on both sides (Name, Value) do not break the join."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3], "Value": 10},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Name": "Bob", "ID": 3, "Value": 20},
                ]
            )
            result = df1.join(df2, on=F.array_contains(df1.IDs, df2.ID), how="inner")
            rows = result.collect()
            assert len(rows) == 1
            assert "Name" in rows[0]
            assert "Value" in rows[0]
        finally:
            spark.stop()

    def test_array_contains_join_empty_dataframes(self):
        """Joining two empty, explicitly-typed DataFrames yields no rows."""
        spark = self._get_spark()
        try:
            schema1 = StructType(
                [
                    StructField("Name", StringType(), True),
                    StructField("IDs", ArrayType(IntegerType()), True),
                ]
            )
            schema2 = StructType(
                [
                    StructField("Dept", StringType(), True),
                    StructField("ID", IntegerType(), True),
                ]
            )
            df1 = spark.createDataFrame([], schema=schema1)
            df2 = spark.createDataFrame([], schema=schema2)
            result = df1.join(df2, on=F.array_contains(df1.IDs, df2.ID), how="inner")
            rows = result.collect()
            assert len(rows) == 0
        finally:
            spark.stop()

    def test_array_contains_join_backward_compatibility(self):
        """Plain equi-joins on a column name keep working alongside the new path."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "ID": 1},
                    {"Name": "Bob", "ID": 2},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 1},
                    {"Dept": "B", "ID": 2},
                ]
            )
            result = df1.join(df2, on="ID", how="inner")
            # Sort in Python so the assertions do not depend on join output order.
            rows = sorted(result.collect(), key=lambda row: row["Name"])
            assert len(rows) == 2
            assert rows[0]["Name"] == "Alice"
            assert rows[0]["Dept"] == "A"
            assert rows[1]["Name"] == "Bob"
            assert rows[1]["Dept"] == "B"
        finally:
            spark.stop()

    def test_array_contains_join_empty_arrays(self):
        """An empty left array matches nothing; a left join null-pads it."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": []},
                    {"Name": "Bob", "IDs": [1, 2, 3]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 1},
                ]
            )
            result = df1.join(df2, on=F.array_contains(df1.IDs, df2.ID), how="left")
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["Dept"] is None
            assert alice_row["ID"] is None
            bob_row = next(row for row in rows if row["Name"] == "Bob")
            assert bob_row["Dept"] == "A"
        finally:
            spark.stop()

    def test_array_contains_join_duplicate_values(self):
        """Repeated values inside the array do not duplicate joined rows."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 1, 2, 2, 3]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 1},
                    {"Dept": "B", "ID": 2},
                ]
            )
            result = df1.join(df2, on=F.array_contains(df1.IDs, df2.ID), how="inner")
            rows = result.collect()
            assert len(rows) == 2
            depts = {row["Dept"] for row in rows}
            assert depts == {"A", "B"}
        finally:
            spark.stop()

    def test_array_contains_join_string_arrays(self):
        """array_contains joins work on string arrays."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "Tags": ["python", "spark", "data"]},
                    {"Name": "Bob", "Tags": ["java", "scala"]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Skill": "Python", "Tag": "python"},
                    {"Skill": "Java", "Tag": "java"},
                ]
            )
            result = df1.join(df2, on=F.array_contains(df1.Tags, df2.Tag), how="left")
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["Skill"] == "Python"
            bob_row = next(row for row in rows if row["Name"] == "Bob")
            assert bob_row["Skill"] == "Java"
        finally:
            spark.stop()

    def test_array_contains_join_float_arrays(self):
        """array_contains joins work on float arrays with exactly representable values."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "Values": [1.5, 2.5, 3.5]},
                    {"Name": "Bob", "Values": [4.0, 5.0]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Category": "A", "Value": 2.5},
                    {"Category": "B", "Value": 5.0},
                ]
            )
            result = df1.join(
                df2, on=F.array_contains(df1.Values, df2.Value), how="left"
            )
            rows = result.collect()
            assert len(rows) == 2
            alice_row = next(row for row in rows if row["Name"] == "Alice")
            assert alice_row["Category"] == "A"
            bob_row = next(row for row in rows if row["Name"] == "Bob")
            assert bob_row["Category"] == "B"
        finally:
            spark.stop()

    def test_array_contains_join_large_arrays(self):
        """Matches at the start, middle and end of a 100-element array are found."""
        spark = self._get_spark()
        try:
            large_array = list(range(1, 101))
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": large_array},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 50},
                    {"Dept": "B", "ID": 1},
                    {"Dept": "C", "ID": 100},
                ]
            )
            result = df1.join(df2, on=F.array_contains(df1.IDs, df2.ID), how="inner")
            rows = result.collect()
            assert len(rows) == 3
            depts = {row["Dept"] for row in rows}
            assert depts == {"A", "B", "C"}
        finally:
            spark.stop()

    def test_array_contains_join_with_where_clause(self):
        """Filtering on a left-side column after the join keeps only that row."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3], "Age": 25},
                    {"Name": "Bob", "IDs": [4, 5, 6], "Age": 30},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 3},
                    {"Dept": "B", "ID": 5},
                ]
            )
            result = df1.join(
                df2, on=F.array_contains(df1.IDs, df2.ID), how="left"
            ).filter(F.col("Age") > 25)
            rows = result.collect()
            assert len(rows) == 1
            assert rows[0]["Name"] == "Bob"
            assert rows[0]["Dept"] == "B"
        finally:
            spark.stop()

    def test_array_contains_join_with_aggregation(self):
        """groupBy/count over the joined result counts people per matched dept."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3]},
                    {"Name": "Bob", "IDs": [4, 5, 6]},
                    {"Name": "Charlie", "IDs": [1, 2, 3]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 3},
                    {"Dept": "B", "ID": 5},
                ]
            )
            result = (
                df1.join(df2, on=F.array_contains(df1.IDs, df2.ID), how="left")
                .groupBy("Dept")
                .agg(F.count("Name").alias("Count"))
            )
            rows = result.collect()
            # Alice and Charlie both contain ID 3 (Dept A); Bob contains ID 5
            # (Dept B). Every left row matches, so there is no null-Dept group.
            dept_counts = {row["Dept"]: row["Count"] for row in rows}
            assert dept_counts == {"A": 2, "B": 1}
        finally:
            spark.stop()

    def test_array_contains_join_with_window_functions(self):
        """A window over the joined result ranks rows within each dept."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3], "Score": 100},
                    {"Name": "Bob", "IDs": [4, 5, 6], "Score": 90},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 3},
                    {"Dept": "B", "ID": 5},
                ]
            )
            window = Window.partitionBy("Dept").orderBy(F.col("Score").desc())
            result = df1.join(
                df2, on=F.array_contains(df1.IDs, df2.ID), how="left"
            ).withColumn("Rank", F.row_number().over(window))
            rows = result.collect()
            assert len(rows) == 2
            # Each dept holds exactly one person, so every rank is 1.
            for row in rows:
                assert "Rank" in row
                assert row["Rank"] == 1
        finally:
            spark.stop()

    def test_array_contains_join_with_union(self):
        """Results of two array_contains joins can be combined with unionByName."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 3},
                ]
            )
            df3 = spark.createDataFrame(
                [
                    {"Name": "Bob", "IDs": [4, 5, 6]},
                ]
            )
            result1 = df1.join(df2, on=F.array_contains(df1.IDs, df2.ID), how="left")
            result2 = df3.join(df2, on=F.array_contains(df3.IDs, df2.ID), how="left")
            combined = result1.unionByName(result2, allowMissingColumns=True)
            rows = combined.collect()
            assert len(rows) == 2
            names = {row["Name"] for row in rows}
            assert names == {"Alice", "Bob"}
        finally:
            spark.stop()

    def test_array_contains_join_with_distinct(self):
        """distinct() collapses identical joined rows from duplicate left rows."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3], "Value": 10},
                    {"Name": "Alice", "IDs": [1, 2, 3], "Value": 10},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 3},
                ]
            )
            result = (
                df1.join(df2, on=F.array_contains(df1.IDs, df2.ID), how="left")
                .select("Name", "Dept", "ID", "Value")
                .distinct()
            )
            rows = result.collect()
            assert len(rows) == 1
            assert rows[0]["Name"] == "Alice"
        finally:
            spark.stop()

    def test_array_contains_join_with_limit(self):
        """limit(1) on the joined result returns exactly one row."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3]},
                    {"Name": "Bob", "IDs": [4, 5, 6]},
                    {"Name": "Charlie", "IDs": [7, 8, 9]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 3},
                    {"Dept": "B", "ID": 5},
                ]
            )
            result = df1.join(
                df2, on=F.array_contains(df1.IDs, df2.ID), how="left"
            ).limit(1)
            rows = result.collect()
            assert len(rows) == 1
        finally:
            spark.stop()

    def test_array_contains_join_multiple_conditions_same_df(self):
        """Every element of one left array can match a distinct right row."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3, 4, 5]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 1},
                    {"Dept": "B", "ID": 2},
                    {"Dept": "C", "ID": 3},
                    {"Dept": "D", "ID": 4},
                    {"Dept": "E", "ID": 5},
                ]
            )
            result = df1.join(df2, on=F.array_contains(df1.IDs, df2.ID), how="inner")
            rows = result.collect()
            assert len(rows) == 5
            depts = {row["Dept"] for row in rows}
            assert depts == {"A", "B", "C", "D", "E"}
            for row in rows:
                assert row["Name"] == "Alice"
        finally:
            spark.stop()

    def test_array_contains_join_with_nested_select(self):
        """Aliased column selection after the join exposes the new names."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3]},
                    {"Name": "Bob", "IDs": [4, 5, 6]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 3},
                    {"Dept": "B", "ID": 5},
                ]
            )
            result = df1.join(
                df2, on=F.array_contains(df1.IDs, df2.ID), how="left"
            ).select(
                F.col("Name"),
                F.col("Dept").alias("Department"),
                F.col("ID").alias("MatchedID"),
            )
            rows = result.collect()
            assert len(rows) == 2
            assert "Name" in rows[0]
            assert "Department" in rows[0]
            assert "MatchedID" in rows[0]
        finally:
            spark.stop()

    def test_array_contains_join_with_case_when(self):
        """when/otherwise over a joined column labels matched rows."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3]},
                    {"Name": "Bob", "IDs": [4, 5, 6]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 3},
                    {"Dept": "B", "ID": 5},
                ]
            )
            result = df1.join(
                df2, on=F.array_contains(df1.IDs, df2.ID), how="left"
            ).withColumn(
                "Status",
                F.when(F.col("Dept").isNotNull(), "Matched").otherwise("NoMatch"),
            )
            rows = result.collect()
            assert len(rows) == 2
            for row in rows:
                assert "Status" in row
                if row["Dept"] is not None:
                    assert row["Status"] == "Matched"
                else:
                    assert row["Status"] == "NoMatch"
        finally:
            spark.stop()

    def test_array_contains_join_with_coalesce(self):
        """coalesce over a joined column yields the matched dept for each person."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3]},
                    {"Name": "Bob", "IDs": [4, 5, 6]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 3},
                    {"Dept": "B", "ID": 5},
                ]
            )
            result = df1.join(
                df2, on=F.array_contains(df1.IDs, df2.ID), how="left"
            ).withColumn("FinalDept", F.coalesce(F.col("Dept"), F.lit("Unknown")))
            rows = result.collect()
            assert len(rows) == 2
            for row in rows:
                assert "FinalDept" in row
            # Both people match a dept, so the "Unknown" fallback is never used.
            final_depts = {row["Name"]: row["FinalDept"] for row in rows}
            assert final_depts == {"Alice": "A", "Bob": "B"}
        finally:
            spark.stop()

    def test_array_contains_join_with_cast(self):
        """cast() over a joined column works on the join result."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 3},
                ]
            )
            result = df1.join(
                df2, on=F.array_contains(df1.IDs, df2.ID), how="left"
            ).withColumn("DeptStr", F.col("Dept").cast(StringType()))
            rows = result.collect()
            assert len(rows) == 1
            assert rows[0]["DeptStr"] == "A"
        finally:
            spark.stop()

    def test_array_contains_join_schema_verification(self):
        """The joined schema contains the columns of both inputs."""
        spark = self._get_spark()
        try:
            df1 = spark.createDataFrame(
                [
                    {"Name": "Alice", "IDs": [1, 2, 3]},
                ]
            )
            df2 = spark.createDataFrame(
                [
                    {"Dept": "A", "ID": 3},
                ]
            )
            result = df1.join(df2, on=F.array_contains(df1.IDs, df2.ID), how="left")
            schema = result.schema
            field_names = [field.name for field in schema.fields]
            assert "Name" in field_names
            assert "IDs" in field_names
            assert "Dept" in field_names
            assert "ID" in field_names
        finally:
            spark.stop()