import pytest
from tests.fixtures.spark_imports import get_spark_imports
_imports = get_spark_imports()
F = _imports.F
StructType = _imports.StructType
StructField = _imports.StructField
StringType = _imports.StringType
Window = _imports.Window
class TestCreateMap:
def test_create_map_with_literals(self, spark):
df = spark.createDataFrame([{"val1": "a", "val2": 1}])
result = df.select(
F.create_map(
F.lit("key1"), F.col("val1"), F.lit("key2"), F.col("val2")
).alias("map_col")
)
rows = result.collect()
assert len(rows) == 1
assert rows[0]["map_col"] == {"key1": "a", "key2": "1"}
def test_create_map_with_column_values(self, spark):
df = spark.createDataFrame(
[
{"first": "Alice", "last": "Smith"},
{"first": "Bob", "last": "Jones"},
]
)
result = df.select(
F.create_map(
F.lit("first_name"),
F.col("first"),
F.lit("last_name"),
F.col("last"),
).alias("map_col")
)
rows = result.collect()
assert len(rows) == 2
assert rows[0]["map_col"] == {"first_name": "Alice", "last_name": "Smith"}
assert rows[1]["map_col"] == {"first_name": "Bob", "last_name": "Jones"}
def test_create_map_multiple_pairs(self, spark):
df = spark.createDataFrame([{"a": 1, "b": 2, "c": 3}])
result = df.select(
F.create_map(
F.lit("x"),
F.col("a"),
F.lit("y"),
F.col("b"),
F.lit("z"),
F.col("c"),
).alias("map_col")
)
rows = result.collect()
assert len(rows) == 1
assert rows[0]["map_col"] == {"x": 1, "y": 2, "z": 3}
def test_create_map_with_null_values(self, spark):
schema = StructType(
[
StructField("val1", StringType()),
StructField("val2", StringType()),
]
)
df = spark.createDataFrame([{"val1": "a", "val2": None}], schema=schema)
result = df.select(
F.create_map(
F.lit("key1"), F.col("val1"), F.lit("key2"), F.col("val2")
).alias("map_col")
)
rows = result.collect()
assert len(rows) == 1
assert rows[0]["map_col"]["key1"] == "a"
assert rows[0]["map_col"]["key2"] is None
def test_create_map_validation_odd_args(self, spark):
df = spark.createDataFrame([{"x": 1}])
with pytest.raises(Exception):
df.select(F.create_map(F.lit("key1")).alias("m")).collect()
def test_create_map_empty_returns_empty_map(self, spark):
df = spark.createDataFrame([{"Name": "Alice"}, {"Name": "Bob"}])
result = df.withColumn("NewMap", F.create_map())
rows = result.collect()
assert len(rows) == 2
assert rows[0]["NewMap"] == {}
assert rows[1]["NewMap"] == {}
def test_create_map_empty_list_returns_empty_map(self, spark):
df = spark.createDataFrame([{"Name": "Alice"}, {"Name": "Bob"}])
result = df.withColumn("NewMap", F.create_map([]))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["NewMap"] == {}
assert rows[1]["NewMap"] == {}
def test_create_map_in_withcolumn(self, spark):
df = spark.createDataFrame([{"name": "Alice", "age": 25}])
result = df.withColumn(
"info",
F.create_map(
F.lit("name"),
F.col("name"),
F.lit("age"),
F.col("age"),
),
)
rows = result.collect()
assert len(rows) == 1
assert rows[0]["info"]["name"] == "Alice"
assert rows[0]["info"]["age"] in (25, "25")
def test_create_map_after_filter(self, spark):
df = spark.createDataFrame(
[
{"val1": "a", "val2": 1},
{"val1": "b", "val2": 2},
]
)
result = df.filter(F.col("val2") > 1).select(
F.create_map(
F.lit("key"),
F.col("val1"),
).alias("map_col")
)
rows = result.collect()
assert len(rows) == 1
assert rows[0]["map_col"] == {"key": "b"}
def test_create_map_all_literals(self, spark):
df = spark.createDataFrame([{"dummy": 1}])
result = df.select(
F.create_map(
F.lit("static_key"),
F.lit("static_value"),
).alias("map_col")
)
rows = result.collect()
assert len(rows) == 1
assert rows[0]["map_col"] == {"static_key": "static_value"}
def test_create_map_with_numeric_types(self, spark):
df = spark.createDataFrame(
[
{"int_val": 42, "float_val": 3.14, "long_val": 1000000},
]
)
result = df.select(
F.create_map(
F.lit("int"),
F.col("int_val"),
F.lit("float"),
F.col("float_val"),
F.lit("long"),
F.col("long_val"),
).alias("map_col")
)
rows = result.collect()
assert len(rows) == 1
assert rows[0]["map_col"]["int"] == 42
assert rows[0]["map_col"]["float"] == 3.14
assert rows[0]["map_col"]["long"] == 1000000
def test_create_map_with_boolean_values(self, spark):
df = spark.createDataFrame([{"flag1": True, "flag2": False}])
result = df.select(
F.create_map(
F.lit("active"),
F.col("flag1"),
F.lit("inactive"),
F.col("flag2"),
).alias("map_col")
)
rows = result.collect()
assert len(rows) == 1
assert rows[0]["map_col"]["active"] is True
assert rows[0]["map_col"]["inactive"] is False
def test_create_map_with_computed_values(self, spark):
df = spark.createDataFrame([{"a": 10, "b": 20}])
result = df.select(
F.create_map(
F.lit("sum"),
F.col("a") + F.col("b"),
F.lit("product"),
F.col("a") * F.col("b"),
).alias("map_col")
)
rows = result.collect()
assert len(rows) == 1
assert rows[0]["map_col"]["sum"] == 30
assert rows[0]["map_col"]["product"] == 200
def test_create_map_multiple_maps_in_select(self, spark):
df = spark.createDataFrame([{"name": "Alice", "age": 25}])
result = df.select(
F.create_map(F.lit("name"), F.col("name")).alias("name_map"),
F.create_map(F.lit("age"), F.col("age")).alias("age_map"),
)
rows = result.collect()
assert len(rows) == 1
assert rows[0]["name_map"] == {"name": "Alice"}
assert rows[0]["age_map"]["age"] in (25, "25")
def test_create_map_in_groupby_agg(self, spark):
df = spark.createDataFrame(
[
{"dept": "A", "name": "Alice", "salary": 100},
{"dept": "A", "name": "Bob", "salary": 200},
{"dept": "B", "name": "Charlie", "salary": 150},
]
)
result = df.groupBy("dept").agg(F.max("salary").alias("max_salary"))
rows = result.collect()
assert len(rows) >= 1
result2 = df.select(
F.create_map(
F.lit("dept"),
F.col("dept"),
F.lit("name"),
F.col("name"),
).alias("info")
)
rows2 = result2.collect()
assert len(rows2) == 3
def test_create_map_with_special_characters_in_keys(self, spark):
df = spark.createDataFrame([{"val": "test"}])
result = df.select(
F.create_map(
F.lit("key_with_underscore"),
F.col("val"),
F.lit("key-with-dash"),
F.col("val"),
F.lit("key.with.dot"),
F.col("val"),
).alias("map_col")
)
rows = result.collect()
assert len(rows) == 1
assert rows[0]["map_col"]["key_with_underscore"] == "test"
assert rows[0]["map_col"]["key-with-dash"] == "test"
assert rows[0]["map_col"]["key.with.dot"] == "test"
def test_create_map_nested_in_expressions(self, spark):
df = spark.createDataFrame([{"val": 10}])
result = df.select(
F.create_map(
F.lit("value"),
F.col("val"),
F.lit("doubled"),
F.col("val") * 2,
).alias("map_col")
)
rows = result.collect()
assert len(rows) == 1
map_val = rows[0]["map_col"]
assert map_val["value"] == 10
assert map_val["doubled"] == 20
def test_create_map_with_empty_string_keys(self, spark):
df = spark.createDataFrame([{"val": "test"}])
result = df.select(
F.create_map(
F.lit(""),
F.col("val"),
).alias("map_col")
)
rows = result.collect()
assert len(rows) == 1
assert rows[0]["map_col"][""] == "test"
def test_create_map_after_union(self, spark):
df = spark.createDataFrame([{"val": 1}, {"val": 2}])
result = df.select(F.create_map(F.lit("value"), F.col("val")).alias("map_col"))
rows = result.collect()
assert len(rows) == 2
values = set()
for row in rows:
map_val = row["map_col"]
if map_val is not None and isinstance(map_val, dict) and "value" in map_val:
values.add(map_val["value"])
assert values == {1, 2}
def test_create_map_with_null_keys(self, spark):
df = spark.createDataFrame([{"val": "test"}])
with pytest.raises(Exception):
df.select(
F.create_map(
F.lit(None),
F.col("val"),
F.lit("valid_key"),
F.col("val"),
).alias("map_col")
).collect()
def test_create_map_large_number_of_pairs(self, spark):
df = spark.createDataFrame([{"val": 1}])
pairs = []
for i in range(10):
pairs.extend([F.lit(f"key_{i}"), F.col("val") + i])
result = df.select(F.create_map(*pairs).alias("map_col"))
rows = result.collect()
assert len(rows) == 1
map_val = rows[0]["map_col"]
assert len(map_val) == 10
for i in range(10):
assert map_val[f"key_{i}"] == 1 + i
def test_create_map_empty_in_select(self, spark):
df = spark.createDataFrame([{"id": 1}, {"id": 2}, {"id": 3}])
result = df.select(F.col("id"), F.create_map().alias("empty_map"))
rows = result.collect()
assert len(rows) == 3
for row in rows:
assert row["empty_map"] == {}
assert "id" in row
def test_create_map_empty_with_different_data_types(self, spark):
df1 = spark.createDataFrame([{"name": "Alice"}])
result1 = df1.withColumn("meta", F.create_map())
assert result1.collect()[0]["meta"] == {}
df2 = spark.createDataFrame([{"age": 25}])
result2 = df2.withColumn("meta", F.create_map())
assert result2.collect()[0]["meta"] == {}
df3 = spark.createDataFrame([{"active": True}])
result3 = df3.withColumn("meta", F.create_map())
assert result3.collect()[0]["meta"] == {}
schema = StructType([StructField("val", StringType(), True)])
df4 = spark.createDataFrame([{"val": None}], schema=schema)
result4 = df4.withColumn("meta", F.create_map())
assert result4.collect()[0]["meta"] == {}
def test_create_map_empty_after_filter(self, spark):
df = spark.createDataFrame(
[
{"id": 1, "value": 10},
{"id": 2, "value": 20},
{"id": 3, "value": 30},
]
)
result = df.filter(F.col("value") > 15).withColumn("empty_map", F.create_map())
rows = result.collect()
assert len(rows) == 2
for row in rows:
assert row["empty_map"] == {}
assert row["value"] > 15
def test_create_map_empty_in_groupby_context(self, spark):
df = spark.createDataFrame(
[
{"dept": "A", "name": "Alice"},
{"dept": "A", "name": "Bob"},
{"dept": "B", "name": "Charlie"},
]
)
df_with_map = df.withColumn("meta", F.create_map())
result = df_with_map.groupBy("dept").agg(F.count("name").alias("count"))
rows = result.collect()
assert len(rows) == 2
dept_counts = {row["dept"]: row["count"] for row in rows}
assert dept_counts["A"] == 2
assert dept_counts["B"] == 1
def test_create_map_empty_in_join(self, spark):
df1 = spark.createDataFrame([{"id": 1, "name": "Alice"}])
df2 = spark.createDataFrame([{"id": 1, "age": 25}])
df1_with_map = df1.withColumn("meta", F.create_map())
result = df1_with_map.join(df2, "id")
rows = result.collect()
assert len(rows) == 1
assert rows[0]["meta"] == {}
assert rows[0]["name"] == "Alice"
assert rows[0]["age"] == 25
def test_create_map_empty_with_window_functions(self, spark):
df = spark.createDataFrame(
[
{"id": 1, "value": 10},
{"id": 2, "value": 20},
{"id": 3, "value": 30},
]
)
window = Window.orderBy("id")
result = df.withColumn("meta", F.create_map()).withColumn(
"row_num", F.row_number().over(window)
)
rows = result.collect()
assert len(rows) == 3
for row in rows:
assert row["meta"] == {}
assert "row_num" in row
def test_create_map_empty_in_nested_expressions(self, spark):
df = spark.createDataFrame([{"val": 10}])
result = df.select(
F.when(F.col("val") > 5, F.create_map())
.otherwise(F.create_map())
.alias("map_col")
)
rows = result.collect()
assert len(rows) == 1
assert rows[0]["map_col"] == {}
def test_create_map_empty_with_null_handling(self, spark):
schema = StructType([StructField("val", StringType(), True)])
df = spark.createDataFrame([{"val": None}, {"val": "test"}], schema=schema)
result = df.withColumn("meta", F.create_map()).withColumn(
"coalesced", F.coalesce(F.col("val"), F.lit("default"))
)
rows = result.collect()
assert len(rows) == 2
for row in rows:
assert row["meta"] == {}
def test_create_map_empty_multiple_times(self, spark):
df = spark.createDataFrame([{"id": 1}])
result = df.select(
F.col("id"),
F.create_map().alias("map1"),
F.create_map().alias("map2"),
F.create_map().alias("map3"),
)
rows = result.collect()
assert len(rows) == 1
row = rows[0]
assert row["map1"] == {}
assert row["map2"] == {}
assert row["map3"] == {}
def test_create_map_empty_with_computed_columns(self, spark):
df = spark.createDataFrame([{"a": 10, "b": 20}])
result = df.select(
F.col("a"),
F.col("b"),
(F.col("a") + F.col("b")).alias("sum"),
F.create_map().alias("meta"),
)
rows = result.collect()
assert len(rows) == 1
assert rows[0]["sum"] == 30
assert rows[0]["meta"] == {}
def test_create_map_empty_in_union(self, spark):
df1 = spark.createDataFrame([{"id": 1, "val": "a"}])
df2 = spark.createDataFrame([{"id": 2, "val": "b"}])
df1_with_map = df1.withColumn("meta", F.create_map())
df2_with_map = df2.withColumn("meta", F.create_map())
result = df1_with_map.union(df2_with_map)
rows = result.collect()
assert len(rows) == 2
for row in rows:
assert row["meta"] == {}
def test_create_map_empty_preserves_type(self, spark):
df = spark.createDataFrame([{"id": 1}])
result = df.withColumn("meta", F.create_map())
schema = result.schema
meta_field = next((f for f in schema.fields if f.name == "meta"), None)
assert meta_field is not None
rows = result.collect()
assert len(rows) == 1
assert rows[0]["meta"] == {}
def test_create_map_empty_tuple_returns_empty_map(self, spark):
df = spark.createDataFrame([{"Name": "Alice"}, {"Name": "Bob"}])
result = df.withColumn("NewMap", F.create_map())
rows = result.collect()
assert len(rows) == 2
assert rows[0]["NewMap"] == {}
assert rows[1]["NewMap"] == {}
def test_create_map_empty_list_in_select(self, spark):
df = spark.createDataFrame([{"id": 1}, {"id": 2}, {"id": 3}])
result = df.select(F.col("id"), F.create_map([]).alias("empty_map"))
rows = result.collect()
assert len(rows) == 3
for row in rows:
assert row["empty_map"] == {}
assert "id" in row
def test_create_map_empty_list_with_different_data_types(self, spark):
df1 = spark.createDataFrame([{"name": "Alice"}])
assert df1.withColumn("meta", F.create_map([])).collect()[0]["meta"] == {}
df2 = spark.createDataFrame([{"age": 25}])
assert df2.withColumn("meta", F.create_map([])).collect()[0]["meta"] == {}
df3 = spark.createDataFrame([{"active": True}])
assert df3.withColumn("meta", F.create_map([])).collect()[0]["meta"] == {}
schema = StructType([StructField("val", StringType(), True)])
df4 = spark.createDataFrame([{"val": None}], schema=schema)
assert df4.withColumn("meta", F.create_map([])).collect()[0]["meta"] == {}
def test_create_map_empty_list_after_filter(self, spark):
df = spark.createDataFrame(
[
{"id": 1, "value": 10},
{"id": 2, "value": 20},
{"id": 3, "value": 30},
]
)
result = df.filter(F.col("value") > 15).withColumn(
"empty_map", F.create_map([])
)
rows = result.collect()
assert len(rows) == 2
for row in rows:
assert row["empty_map"] == {}
assert row["value"] > 15
def test_create_map_empty_list_and_create_map_equivalent(self, spark):
df = spark.createDataFrame([{"id": 1}, {"id": 2}])
result = df.select(
F.col("id"),
F.create_map().alias("map_no_args"),
F.create_map([]).alias("map_empty_list"),
)
rows = result.collect()
assert len(rows) == 2
for row in rows:
assert row["map_no_args"] == {}
assert row["map_empty_list"] == {}
assert row["map_no_args"] == row["map_empty_list"]
def test_create_map_empty_list_multiple_times(self, spark):
df = spark.createDataFrame([{"id": 1}])
result = df.select(
F.col("id"),
F.create_map([]).alias("map1"),
F.create_map([]).alias("map2"),
F.create_map([]).alias("map3"),
)
rows = result.collect()
assert len(rows) == 1
assert rows[0]["map1"] == {}
assert rows[0]["map2"] == {}
assert rows[0]["map3"] == {}
def test_create_map_empty_list_in_union(self, spark):
df1 = spark.createDataFrame([{"id": 1, "val": "a"}])
df2 = spark.createDataFrame([{"id": 2, "val": "b"}])
df1_with_map = df1.withColumn("meta", F.create_map([]))
df2_with_map = df2.withColumn("meta", F.create_map([]))
result = df1_with_map.union(df2_with_map)
rows = result.collect()
assert len(rows) == 2
for row in rows:
assert row["meta"] == {}
def test_create_map_empty_list_show(self, spark):
df = spark.createDataFrame([{"Name": "Alice"}, {"Name": "Bob"}])
df = df.withColumn("NewMap", F.create_map([]))
df.show()
rows = df.collect()
assert len(rows) == 2
assert rows[0]["Name"] == "Alice" and rows[0]["NewMap"] == {}
assert rows[1]["Name"] == "Bob" and rows[1]["NewMap"] == {}
def test_create_map_empty_list_with_computed_columns(self, spark):
df = spark.createDataFrame([{"a": 10, "b": 20}])
result = df.select(
F.col("a"),
F.col("b"),
(F.col("a") + F.col("b")).alias("sum"),
F.create_map([]).alias("meta"),
)
rows = result.collect()
assert len(rows) == 1
assert rows[0]["sum"] == 30
assert rows[0]["meta"] == {}
def test_create_map_empty_list_in_join(self, spark):
df1 = spark.createDataFrame([{"id": 1, "name": "Alice"}])
df2 = spark.createDataFrame([{"id": 1, "age": 25}])
df1_with_map = df1.withColumn("meta", F.create_map([]))
result = df1_with_map.join(df2, "id")
rows = result.collect()
assert len(rows) == 1
assert rows[0]["meta"] == {}
assert rows[0]["name"] == "Alice"
assert rows[0]["age"] == 25