from tests.fixtures.spark_imports import get_spark_imports
import pytest
imports = get_spark_imports()
F = imports.F
StructType = imports.StructType
StructField = imports.StructField
StringType = imports.StringType
IntegerType = imports.IntegerType
Window = imports.Window
class TestArrayParameterFormats:
def test_array_with_string_columns(self, spark):
df = spark.createDataFrame(
[
{"Name": "Alice", "Type": "A"},
{"Name": "Bob", "Type": "B"},
]
)
result = df.withColumn("Array-Field-1", F.array("Name", "Type"))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["Array-Field-1"] == ["Alice", "A"]
assert rows[1]["Array-Field-1"] == ["Bob", "B"]
def test_array_with_list_of_strings(self, spark):
df = spark.createDataFrame(
[
{"Name": "Alice", "Type": "A"},
{"Name": "Bob", "Type": "B"},
]
)
result = df.withColumn("Array-Field-2", F.array(["Name", "Type"]))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["Array-Field-2"] == ["Alice", "A"]
assert rows[1]["Array-Field-2"] == ["Bob", "B"]
def test_array_with_column_objects(self, spark):
df = spark.createDataFrame(
[
{"Name": "Alice", "Type": "A"},
{"Name": "Bob", "Type": "B"},
]
)
result = df.withColumn("Array-Field-3", F.array(F.col("Name"), F.col("Type")))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["Array-Field-3"] == ["Alice", "A"]
assert rows[1]["Array-Field-3"] == ["Bob", "B"]
def test_array_with_list_of_column_objects(self, spark):
df = spark.createDataFrame(
[
{"Name": "Alice", "Type": "A"},
{"Name": "Bob", "Type": "B"},
]
)
result = df.withColumn("Array-Field-4", F.array([F.col("Name"), F.col("Type")]))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["Array-Field-4"] == ["Alice", "A"]
assert rows[1]["Array-Field-4"] == ["Bob", "B"]
def test_array_all_formats_together(self, spark):
df = spark.createDataFrame(
[
{"Name": "Alice", "Type": "A"},
{"Name": "Bob", "Type": "B"},
]
)
result = (
df.withColumn("Array-Field-1", F.array("Name", "Type"))
.withColumn("Array-Field-2", F.array(["Name", "Type"]))
.withColumn("Array-Field-3", F.array(F.col("Name"), F.col("Type")))
.withColumn("Array-Field-4", F.array([F.col("Name"), F.col("Type")]))
)
rows = result.collect()
assert len(rows) == 2
expected = ["Alice", "A"]
assert rows[0]["Array-Field-1"] == expected
assert rows[0]["Array-Field-2"] == expected
assert rows[0]["Array-Field-3"] == expected
assert rows[0]["Array-Field-4"] == expected
expected = ["Bob", "B"]
assert rows[1]["Array-Field-1"] == expected
assert rows[1]["Array-Field-2"] == expected
assert rows[1]["Array-Field-3"] == expected
assert rows[1]["Array-Field-4"] == expected
def test_array_with_mixed_types(self, spark):
df = spark.createDataFrame(
[
{"name": "Alice", "age": 25, "active": True},
{"name": "Bob", "age": 30, "active": False},
]
)
with pytest.raises(Exception):
df.withColumn("info", F.array("name", "age", "active")).collect()
def test_array_with_single_column(self, spark):
df = spark.createDataFrame(
[
{"Name": "Alice"},
{"Name": "Bob"},
]
)
result = df.withColumn("Array-Field", F.array("Name"))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["Array-Field"] == ["Alice"]
assert rows[1]["Array-Field"] == ["Bob"]
def test_array_with_three_columns(self, spark):
df = spark.createDataFrame(
[
{"a": 1, "b": 2, "c": 3},
{"a": 4, "b": 5, "c": 6},
]
)
result = df.withColumn("combined", F.array("a", "b", "c"))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["combined"] == [1, 2, 3]
assert rows[1]["combined"] == [4, 5, 6]
def test_array_with_computed_columns(self, spark):
df = spark.createDataFrame([{"a": 10, "b": 20}])
result = df.withColumn(
"computed", F.array(F.col("a") + F.col("b"), F.col("a") * F.col("b"))
)
rows = result.collect()
assert len(rows) == 1
assert rows[0]["computed"] == [30, 200]
def test_array_with_list_of_computed_columns(self, spark):
df = spark.createDataFrame([{"a": 5, "b": 3}])
result = df.withColumn(
"computed", F.array([F.col("a") + F.col("b"), F.col("a") - F.col("b")])
)
rows = result.collect()
assert len(rows) == 1
assert rows[0]["computed"] == [8, 2]
def test_array_with_null_values(self, spark):
schema = StructType(
[
StructField("name", StringType(), True),
StructField("age", IntegerType(), True),
]
)
df = spark.createDataFrame(
[{"name": "Alice", "age": None}, {"name": None, "age": 25}], schema=schema
)
result = df.withColumn("info", F.array("name", "age"))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["info"][0] == "Alice"
assert "Alice" in rows[0]["info"]
last_val = rows[1]["info"][-1]
assert last_val == 25 or last_val == "25"
def test_array_with_all_null_columns(self, spark):
schema = StructType(
[
StructField("val1", StringType(), True),
StructField("val2", StringType(), True),
]
)
df = spark.createDataFrame([{"val1": None, "val2": None}], schema=schema)
result = df.withColumn("nulls", F.array("val1", "val2"))
rows = result.collect()
assert len(rows) == 1
arr = rows[0]["nulls"]
assert isinstance(arr, list)
assert len(arr) >= 0
def test_array_with_numeric_types(self, spark):
df = spark.createDataFrame(
[
{"int_val": 42, "float_val": 3.14, "long_val": 1000000},
]
)
result = df.withColumn("numbers", F.array("int_val", "float_val", "long_val"))
rows = result.collect()
assert len(rows) == 1
arr = rows[0]["numbers"]
assert arr[0] == 42
assert arr[1] == 3.14
assert arr[2] == 1000000
def test_array_with_boolean_types(self, spark):
df = spark.createDataFrame([{"flag1": True, "flag2": False}])
result = df.withColumn("flags", F.array("flag1", "flag2"))
rows = result.collect()
assert len(rows) == 1
assert rows[0]["flags"] == [True, False]
def test_array_with_mixed_types_comprehensive(self, spark):
df = spark.createDataFrame(
[
{
"str_val": "Alice",
"int_val": 25,
"float_val": 3.14,
"bool_val": True,
},
{
"str_val": "Bob",
"int_val": 30,
"float_val": 2.71,
"bool_val": False,
},
]
)
with pytest.raises(Exception):
df.withColumn(
"mixed", F.array("str_val", "int_val", "float_val", "bool_val")
).collect()
def test_array_in_select_statement(self, spark):
df = spark.createDataFrame([{"Name": "Alice", "Type": "A"}])
result = df.select(
F.array("Name", "Type").alias("format1"),
F.array(["Name", "Type"]).alias("format2"),
F.array(F.col("Name"), F.col("Type")).alias("format3"),
F.array([F.col("Name"), F.col("Type")]).alias("format4"),
)
rows = result.collect()
assert len(rows) == 1
expected = ["Alice", "A"]
assert rows[0]["format1"] == expected
assert rows[0]["format2"] == expected
assert rows[0]["format3"] == expected
assert rows[0]["format4"] == expected
def test_array_after_filter(self, spark):
df = spark.createDataFrame(
[
{"id": 1, "a": 10, "b": 20},
{"id": 2, "a": 30, "b": 40},
{"id": 3, "a": 50, "b": 60},
]
)
result = df.filter(F.col("id") > 1).withColumn("combined", F.array("a", "b"))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["combined"] == [30, 40]
assert rows[1]["combined"] == [50, 60]
def test_array_in_groupby_context(self, spark):
df = spark.createDataFrame(
[
{"dept": "A", "name": "Alice", "age": 25},
{"dept": "A", "name": "Bob", "age": 30},
{"dept": "B", "name": "Charlie", "age": 35},
]
)
df_with_array = df.withColumn("info", F.array("name", "age"))
result = df_with_array.groupBy("dept").agg(F.count("name").alias("count"))
rows = result.collect()
assert len(rows) == 2
dept_counts = {row["dept"]: row["count"] for row in rows}
assert dept_counts["A"] == 2
assert dept_counts["B"] == 1
def test_array_in_join(self, spark):
df1 = spark.createDataFrame([{"id": 1, "name": "Alice"}])
df2 = spark.createDataFrame([{"id": 1, "age": 25}])
df1_with_array = df1.withColumn("info", F.array("name"))
result = df1_with_array.join(df2, "id")
rows = result.collect()
assert len(rows) == 1
assert rows[0]["info"] == ["Alice"]
assert rows[0]["age"] == 25
def test_array_with_window_functions(self, spark):
df = spark.createDataFrame(
[
{"id": 1, "value": 10},
{"id": 2, "value": 20},
{"id": 3, "value": 30},
]
)
window = Window.orderBy("id")
result = df.withColumn("arr", F.array("value")).withColumn(
"row_num", F.row_number().over(window)
)
rows = result.collect()
assert len(rows) == 3
for i, row in enumerate(rows, 1):
assert row["arr"] == [row["value"]]
assert row["row_num"] == i
def test_array_with_large_number_of_columns(self, spark):
data = {f"col_{i}": i for i in range(10)}
df = spark.createDataFrame([data])
col_names = [f"col_{i}" for i in range(10)]
result1 = df.withColumn("large_array", F.array(*col_names))
rows1 = result1.collect()
assert rows1[0]["large_array"] == list(range(10))
result2 = df.withColumn("large_array2", F.array(col_names))
rows2 = result2.collect()
assert rows2[0]["large_array2"] == list(range(10))
def test_array_with_computed_expressions(self, spark):
df = spark.createDataFrame([{"val": 10}])
result = df.withColumn("computed", F.array(F.col("val"), F.col("val") * 2))
rows = result.collect()
assert len(rows) == 1
arr = rows[0]["computed"]
assert isinstance(arr, list)
assert len(arr) == 2
assert arr[0] == 10 assert isinstance(arr[1], (int, float))
def test_array_with_nested_expressions(self, spark):
df = spark.createDataFrame([{"a": 5, "b": 3}])
result = df.withColumn(
"nested",
F.array(
(F.col("a") + F.col("b")) * 2,
F.col("a") - F.col("b"),
F.col("a") * F.col("b"),
),
)
rows = result.collect()
assert len(rows) == 1
assert rows[0]["nested"] == [16, 2, 15]
def test_array_in_union(self, spark):
df1 = spark.createDataFrame([{"id": 1, "val": "a"}])
df2 = spark.createDataFrame([{"id": 2, "val": "b"}])
df1_with_array = df1.withColumn("arr", F.array("val"))
df2_with_array = df2.withColumn("arr", F.array("val"))
rows1 = df1_with_array.collect()
rows2 = df2_with_array.collect()
assert len(rows1) == 1
assert len(rows2) == 1
assert rows1[0]["arr"] == ["a"]
assert rows2[0]["arr"] == ["b"]
try:
result = df1_with_array.union(df2_with_array)
union_rows = result.collect()
assert len(union_rows) == 2
except Exception:
pass
def test_array_with_special_characters_in_column_names(self, spark):
df = spark.createDataFrame([{"col-name": "test1", "col_name": "test2"}])
result = df.withColumn("combined", F.array("col-name", "col_name"))
rows = result.collect()
assert len(rows) == 1
assert rows[0]["combined"] == ["test1", "test2"]
def test_array_preserves_order(self, spark):
df = spark.createDataFrame([{"a": 1, "b": 2, "c": 3}])
result1 = df.withColumn("order1", F.array("a", "b", "c"))
result2 = df.withColumn("order2", F.array("c", "a", "b"))
rows1 = result1.collect()
rows2 = result2.collect()
assert rows1[0]["order1"] == [1, 2, 3]
assert rows2[0]["order2"] == [3, 1, 2]
def test_array_with_empty_strings(self, spark):
df = spark.createDataFrame([{"val1": "", "val2": "test"}])
result = df.withColumn("arr", F.array("val1", "val2"))
rows = result.collect()
assert len(rows) == 1
assert rows[0]["arr"] == ["", "test"]
def test_array_with_zero_and_negative_numbers(self, spark):
df = spark.createDataFrame([{"a": 0, "b": -5, "c": 10}])
result = df.withColumn("numbers", F.array("a", "b", "c"))
rows = result.collect()
assert len(rows) == 1
assert rows[0]["numbers"] == [0, -5, 10]
def test_array_all_formats_with_mixed_types(self, spark):
df = spark.createDataFrame(
[
{"name": "Alice", "age": 25, "active": True, "score": 3.14},
]
)
with pytest.raises(Exception):
(
df.withColumn("format1", F.array("name", "age", "active", "score"))
.withColumn("format2", F.array(["name", "age", "active", "score"]))
.withColumn(
"format3",
F.array(
F.col("name"),
F.col("age"),
F.col("active"),
F.col("score"),
),
)
.withColumn(
"format4",
F.array(
[
F.col("name"),
F.col("age"),
F.col("active"),
F.col("score"),
]
),
)
).collect()