import pytest
from sparkless.testing import get_imports
# Resolve the import namespace once at module load; every alias below is an
# attribute of the object returned by get_imports().
imports = get_imports()

# Non-schema aliases.
SparkSession = imports.SparkSession
F = imports.F

# Schema container and column type aliases.
StructType = imports.StructType
StructField = imports.StructField
StringType = imports.StringType
IntegerType = imports.IntegerType
class TestNaFill:
def test_na_fill_scalar(self, spark):
    """Fill every null in a two-column string frame with the scalar "0"."""
    records = [
        {"key": "A", "value": "1"},
        {"key": None, "value": "2"},
        {"key": "C", "value": None},
    ]
    filled = spark.createDataFrame(records).na.fill("0").collect()

    assert len(filled) == 3

    # Expected (key, value) per row, in input order.
    expected = [("A", "1"), ("0", "2"), ("C", "0")]
    for row, (key, value) in zip(filled, expected):
        assert row["key"] == key
        assert row["value"] == value
def test_na_fill_dict(self, spark):
    """Fill nulls per column from a mapping of column name to value.

    "col2" is absent from the mapping, so its null is expected to remain.
    """
    records = [
        {"col1": None, "col2": "X", "col3": None},
        {"col1": "A", "col2": None, "col3": "Y"},
    ]
    replacements = {"col1": "DEFAULT1", "col3": "DEFAULT3"}
    filled = spark.createDataFrame(records).na.fill(replacements).collect()

    assert len(filled) == 2
    first, second = filled

    assert first["col1"] == "DEFAULT1"
    assert first["col2"] == "X"
    assert first["col3"] == "DEFAULT3"

    assert second["col1"] == "A"
    assert second["col2"] is None
    assert second["col3"] == "Y"
def test_na_fill_subset(self, spark):
    """Fill nulls only in the columns named by a list ``subset``.

    "key" and "value" are filled with "FILLED"; "other" is outside the
    subset, so its null must survive.
    """
    # Schema types come from the module-level aliases.
    schema = StructType(
        [
            StructField("key", StringType(), nullable=True),
            StructField("value", StringType(), nullable=True),
            StructField("other", StringType(), nullable=True),
        ]
    )
    df = spark.createDataFrame(
        [
            {"key": None, "value": None, "other": "X"},
            {"key": "A", "value": None, "other": None},
        ],
        schema=schema,
    )

    rows = df.na.fill("FILLED", subset=["key", "value"]).collect()

    assert len(rows) == 2

    assert rows[0]["key"] == "FILLED"
    assert rows[0]["value"] == "FILLED"
    assert rows[0]["other"] == "X"

    assert rows[1]["key"] == "A"
    assert rows[1]["value"] == "FILLED"
    assert rows[1]["other"] is None
def test_na_fill_after_join(self, spark):
    """Fill nulls with 0 on the result of an inner join on "key"."""
    left = spark.createDataFrame(
        [
            {"key": "123", "value_left": 1},
            {"key": "456", "value_left": None},
        ]
    )
    right = spark.createDataFrame(
        [
            {"key": "123", "value_right": None},
            {"key": "456", "value_right": 2},
        ]
    )

    joined = left.join(right, on="key", how="inner")
    rows = joined.na.fill(0).collect()
    assert len(rows) == 2

    def first_with_key(key):
        """Return the first collected row whose "key" matches, else None."""
        return next((row for row in rows if row["key"] == key), None)

    # Rows are looked up by key, so the join's output order is not relied on.
    row_123 = first_with_key("123")
    assert row_123 is not None
    assert row_123["value_left"] == 1
    assert row_123["value_right"] == 0

    row_456 = first_with_key("456")
    assert row_456 is not None
    assert row_456["value_left"] == 0
    assert row_456["value_right"] == 2
def test_na_fill_nonexistent_column(self, spark):
    """na.fill must raise when it is given a column name not in the frame.

    Both ways of naming columns are exercised: the ``subset`` list and the
    keys of a dict value.
    """
    # Schema types come from the module-level aliases.
    schema = StructType(
        [
            StructField("col1", StringType(), nullable=True),
            StructField("col2", StringType()),
        ]
    )
    df = spark.createDataFrame([{"col1": None, "col2": "X"}], schema=schema)

    # NOTE(review): `Exception` is left broad because the concrete error
    # type raised by the backend is not visible from this file; narrow it
    # once the type is confirmed.
    with pytest.raises(Exception):
        df.na.fill("FILLED", subset=["col1", "nonexistent"])

    with pytest.raises(Exception):
        df.na.fill({"col1": "FILLED", "nonexistent": "VALUE"})
def test_na_fill_different_types(self, spark):
    """Fill a single null with an int, str, float, and bool value in turn.

    Each case builds a one-row frame whose "key" and "value" columns share
    one data type, fills with a value of that type, and checks the "value"
    cell of the only row.
    """
    # Types are resolved from get_imports() at call time; get_imports itself
    # is the module-level import.
    imports = get_imports()
    StructType = imports.StructType
    StructField = imports.StructField

    def filled_value(type_cls, key, fill_value):
        """Return the "value" cell of a one-row frame after na.fill."""
        schema = StructType(
            [
                StructField("key", type_cls()),
                StructField("value", type_cls(), nullable=True),
            ]
        )
        df = spark.createDataFrame([{"key": key, "value": None}], schema=schema)
        return df.na.fill(fill_value).collect()[0]["value"]

    assert filled_value(imports.IntegerType, 1, 0) == 0
    assert filled_value(imports.StringType, "A", "DEFAULT") == "DEFAULT"
    assert filled_value(imports.DoubleType, 1.5, 0.0) == 0.0
    # Identity check: 0 == False in Python, so equality would be too weak.
    assert filled_value(imports.BooleanType, True, False) is False
def test_na_fill_chained_operations(self, spark):
    """Chain two na.fill calls, each restricted to a different column.

    "age" is in neither subset, so its null is expected to remain.
    """
    people = spark.createDataFrame(
        [
            {"name": None, "age": 25, "city": None},
            {"name": "Bob", "age": None, "city": "NYC"},
        ]
    )

    names_filled = people.na.fill("UNKNOWN", subset=["name"])
    fully_filled = names_filled.na.fill("N/A", subset=["city"])
    collected = fully_filled.collect()

    assert len(collected) == 2
    first, second = collected

    assert first["name"] == "UNKNOWN"
    assert first["age"] == 25
    assert first["city"] == "N/A"

    assert second["name"] == "Bob"
    assert second["age"] is None
    assert second["city"] == "NYC"
def test_na_fill_pyspark_parity(self, spark):
    """Fill a null "value_left" with 0 and check the result's rows and schema.

    Besides the cell values, the filled frame's schema must still contain a
    field named "value_left".
    """
    source = spark.createDataFrame(
        [
            {"key": "123", "value_left": 1},
            {"key": "456", "value_left": None},
        ]
    )
    filled_df = source.na.fill(0)
    collected = filled_df.collect()

    assert len(collected) == 2

    # Expected (key, value_left) per row, in input order.
    expected = [("123", 1), ("456", 0)]
    for row, (key, value_left) in zip(collected, expected):
        assert row["key"] == key
        assert row["value_left"] == value_left

    field_names = (field.name for field in filled_df.schema.fields)
    assert any(name == "value_left" for name in field_names)
def test_na_fill_empty_dataframe(self, spark):
    """na.fill on a frame with no rows yields no rows and keeps both fields."""
    # Types are resolved from get_imports() at call time; get_imports itself
    # is the module-level import.
    imports = get_imports()
    StructType = imports.StructType
    StructField = imports.StructField
    StringType = imports.StringType

    schema = StructType(
        [
            StructField("col1", StringType()),
            StructField("col2", StringType()),
        ]
    )
    result = spark.createDataFrame([], schema=schema).na.fill("DEFAULT")

    assert len(result.collect()) == 0
    assert len(result.schema.fields) == 2
def test_na_fill_no_nulls(self, spark):
    """na.fill leaves a frame that contains no nulls unchanged."""
    records = [
        {"key": "A", "value": 1},
        {"key": "B", "value": 2},
    ]
    collected = spark.createDataFrame(records).na.fill(0).collect()

    assert len(collected) == 2

    # Every output cell must equal the corresponding input cell.
    for row, original in zip(collected, records):
        assert row["key"] == original["key"]
        assert row["value"] == original["value"]
def test_na_fill_equivalence_with_fillna(self, spark):
    """df.na.fill(value) and df.fillna(value) must produce the same rows."""
    df = spark.createDataFrame(
        [
            {"col1": None, "col2": "X", "col3": None},
            {"col1": "A", "col2": None, "col3": "Y"},
        ]
    )

    rows_na = df.na.fill("FILLED").collect()
    rows_fillna = df.fillna("FILLED").collect()

    # Length is checked first so the pairwise zip below cannot silently
    # drop trailing rows from the longer result.
    assert len(rows_na) == len(rows_fillna)

    for row_na, row_fillna in zip(rows_na, rows_fillna):
        for column in ("col1", "col2", "col3"):
            assert row_na[column] == row_fillna[column]
def test_na_fill_subset_string(self, spark):
    """Accept a bare string as ``subset`` and fill only that one column.

    Only "key" is named, so every null in "value" must remain.
    """
    # Types are resolved from get_imports() at call time; get_imports itself
    # is the module-level import.
    imports = get_imports()
    StructType = imports.StructType
    StructField = imports.StructField
    StringType = imports.StringType

    schema = StructType(
        [
            StructField("key", StringType(), nullable=True),
            StructField("value", StringType(), nullable=True),
        ]
    )
    df = spark.createDataFrame(
        [
            {"key": None, "value": None},
            {"key": "A", "value": None},
        ],
        schema=schema,
    )

    rows = df.na.fill("FILLED", subset="key").collect()

    # Filling must not add or drop rows.
    assert len(rows) == 2

    assert rows[0]["key"] == "FILLED"
    assert rows[0]["value"] is None

    assert rows[1]["key"] == "A"
    assert rows[1]["value"] is None
def test_na_fill_subset_tuple(self, spark):
    """Accept a tuple as ``subset`` and fill only the named columns."""
    # Types are resolved from get_imports() at call time; get_imports itself
    # is the module-level import.
    imports = get_imports()
    StructType = imports.StructType
    StructField = imports.StructField
    StringType = imports.StringType

    schema = StructType(
        [
            StructField("col1", StringType(), nullable=True),
            StructField("col2", StringType(), nullable=True),
            StructField("col3", StringType()),
        ]
    )
    df = spark.createDataFrame(
        [
            {"col1": None, "col2": None, "col3": "X"},
        ],
        schema=schema,
    )

    rows = df.na.fill("FILLED", subset=("col1", "col2")).collect()

    # Filling must not add or drop rows.
    assert len(rows) == 1

    assert rows[0]["col1"] == "FILLED"
    assert rows[0]["col2"] == "FILLED"
    assert rows[0]["col3"] == "X"
def test_na_fill_all_nulls(self, spark):
    """Fill a frame in which every cell is null with one scalar string."""
    # Types are resolved from get_imports() at call time; get_imports itself
    # is the module-level import.
    imports = get_imports()
    StructType = imports.StructType
    StructField = imports.StructField
    StringType = imports.StringType

    schema = StructType(
        [
            StructField("col1", StringType()),
            StructField("col2", StringType()),
        ]
    )
    df = spark.createDataFrame(
        [
            {"col1": None, "col2": None},
            {"col1": None, "col2": None},
        ],
        schema=schema,
    )

    rows = df.na.fill("ALL_FILLED").collect()

    assert len(rows) == 2
    # Every cell of every row must now hold the fill value.
    for row in rows:
        assert row["col1"] == "ALL_FILLED"
        assert row["col2"] == "ALL_FILLED"