from sparkless.testing import get_imports
imports = get_imports()
SparkSession = imports.SparkSession
StringType = imports.StringType
IntegerType = imports.IntegerType
StructType = imports.StructType
StructField = imports.StructField
F = imports.F
class TestWithField:
def test_withfield_add_new_field(self, spark):
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
StructField("value_2", StringType(), True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
("A", {"value_1": 1, "value_2": "x"}),
("B", {"value_1": 2, "value_2": "y"}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField("value_3", F.lit("NEW_VALUE")),
)
rows = result.collect()
assert len(rows) == 2
assert rows[0]["id"] == "A"
assert rows[0]["my_struct"]["value_1"] == 1
assert rows[0]["my_struct"]["value_2"] == "x"
assert rows[0]["my_struct"]["value_3"] == "NEW_VALUE"
assert rows[1]["id"] == "B"
assert rows[1]["my_struct"]["value_1"] == 2
assert rows[1]["my_struct"]["value_2"] == "y"
assert rows[1]["my_struct"]["value_3"] == "NEW_VALUE"
struct_field = next(f for f in result.schema.fields if f.name == "my_struct")
assert isinstance(struct_field.dataType, StructType)
struct_type = struct_field.dataType
field_names = [f.name for f in struct_type.fields]
assert "value_1" in field_names
assert "value_2" in field_names
assert "value_3" in field_names
def test_withfield_replace_existing_field(self, spark):
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
StructField("value_2", StringType(), True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
("A", {"value_1": 1, "value_2": "x"}),
("B", {"value_1": 2, "value_2": "y"}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField("value_1", F.lit(999)),
)
rows = result.collect()
assert len(rows) == 2
assert rows[0]["id"] == "A"
assert rows[0]["my_struct"]["value_1"] == 999
assert rows[0]["my_struct"]["value_2"] == "x"
assert rows[1]["id"] == "B"
assert rows[1]["my_struct"]["value_1"] == 999
assert rows[1]["my_struct"]["value_2"] == "y"
def test_withfield_with_column_expression(self, spark):
schema = StructType(
[
StructField("id", StringType(), True),
StructField("other_value", StringType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
("A", "NEW", {"value_1": 1}),
("B", "VALUE", {"value_1": 2}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField("value_2", F.col("other_value")),
)
rows = result.collect()
assert len(rows) == 2
assert rows[0]["my_struct"]["value_1"] == 1
assert rows[0]["my_struct"]["value_2"] == "NEW"
assert rows[1]["my_struct"]["value_1"] == 2
assert rows[1]["my_struct"]["value_2"] == "VALUE"
def test_withfield_with_computed_expression(self, spark):
schema = StructType(
[
StructField("id", IntegerType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
(1, {"value_1": 10}),
(2, {"value_1": 20}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField("value_2", F.col("id") * 10),
)
rows = result.collect()
assert len(rows) == 2
assert rows[0]["my_struct"]["value_1"] == 10
assert rows[0]["my_struct"]["value_2"] == 10
assert rows[1]["my_struct"]["value_1"] == 20
assert rows[1]["my_struct"]["value_2"] == 20
def test_withfield_multiple_chained(self, spark):
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
("A", {"value_1": 1}),
("B", {"value_1": 2}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField("value_2", F.lit("second")),
).withColumn(
"my_struct",
F.col("my_struct").withField("value_3", F.lit("third")),
)
rows = result.collect()
assert len(rows) == 2
assert rows[0]["my_struct"]["value_1"] == 1
assert rows[0]["my_struct"]["value_2"] == "second"
assert rows[0]["my_struct"]["value_3"] == "third"
def test_withfield_null_struct(self, spark):
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
("A", {"value_1": 1}),
("B", None), ],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField("value_2", F.lit("NEW")),
)
rows = result.collect()
assert len(rows) == 2
assert rows[0]["my_struct"]["value_1"] == 1
assert rows[0]["my_struct"]["value_2"] == "NEW"
assert rows[1]["my_struct"] is None
def test_withfield_issue_235_example(self, spark):
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
StructField("value_2", StringType(), True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
("A", {"value_1": 1, "value_2": "x"}),
("B", {"value_1": 2, "value_2": "y"}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField("value_3", F.lit("NEW_VALUE")),
)
rows = result.collect()
assert len(rows) == 2
assert rows[0]["id"] == "A"
assert rows[0]["my_struct"]["value_1"] == 1
assert rows[0]["my_struct"]["value_2"] == "x"
assert rows[0]["my_struct"]["value_3"] == "NEW_VALUE"
assert rows[1]["id"] == "B"
assert rows[1]["my_struct"]["value_1"] == 2
assert rows[1]["my_struct"]["value_2"] == "y"
assert rows[1]["my_struct"]["value_3"] == "NEW_VALUE"
struct_field = next(f for f in result.schema.fields if f.name == "my_struct")
assert isinstance(struct_field.dataType, StructType)
struct_type = struct_field.dataType
assert len(struct_type.fields) == 3
field_names = [f.name for f in struct_type.fields]
assert "value_1" in field_names
assert "value_2" in field_names
assert "value_3" in field_names
def test_withfield_different_data_types(self, spark):
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
("A", {"value_1": 1}),
("B", {"value_1": 2}),
],
schema=schema,
)
result_int = df.withColumn(
"my_struct",
F.col("my_struct").withField("int_field", F.lit(42)),
)
rows = result_int.collect()
assert rows[0]["my_struct"]["int_field"] == 42
result_float = df.withColumn(
"my_struct",
F.col("my_struct").withField("float_field", F.lit(3.14)),
)
rows = result_float.collect()
assert rows[0]["my_struct"]["float_field"] == 3.14
result_bool = df.withColumn(
"my_struct",
F.col("my_struct").withField("bool_field", F.lit(True)),
)
rows = result_bool.collect()
assert rows[0]["my_struct"]["bool_field"] is True
def test_withfield_nested_struct(self, spark):
nested_struct_type = StructType(
[
StructField("inner_value", IntegerType(), True),
]
)
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
StructField("nested", nested_struct_type, True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
("A", {"value_1": 1, "nested": {"inner_value": 10}}),
("B", {"value_1": 2, "nested": {"inner_value": 20}}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField("value_2", F.lit("outer")),
)
rows = result.collect()
assert rows[0]["my_struct"]["value_1"] == 1
assert rows[0]["my_struct"]["value_2"] == "outer"
assert rows[0]["my_struct"]["nested"]["inner_value"] == 10
def test_withfield_empty_struct(self, spark):
schema = StructType(
[
StructField("id", StringType(), True),
StructField("my_struct", StructType([]), True),
]
)
df = spark.createDataFrame(
[
("A", {}),
("B", {}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField("new_field", F.lit("value")),
)
rows = result.collect()
assert len(rows) == 2
assert rows[0]["my_struct"]["new_field"] == "value"
assert rows[1]["my_struct"]["new_field"] == "value"
def test_withfield_replace_then_add(self, spark):
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
StructField("value_2", StringType(), True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
("A", {"value_1": 1, "value_2": "x"}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField("value_1", F.lit(999)),
).withColumn(
"my_struct",
F.col("my_struct").withField("value_3", F.lit("NEW")),
)
rows = result.collect()
assert rows[0]["my_struct"]["value_1"] == 999 assert rows[0]["my_struct"]["value_2"] == "x" assert rows[0]["my_struct"]["value_3"] == "NEW"
def test_withfield_deeply_nested_struct(self, spark):
nested_inner = StructType([StructField("inner_value", IntegerType(), True)])
nested_mid = StructType(
[
StructField("mid_value", StringType(), True),
StructField("nested", nested_inner, True),
]
)
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
StructField("nested", nested_mid, True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
(
"A",
{
"value_1": 1,
"nested": {"mid_value": "mid", "nested": {"inner_value": 10}},
},
),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField("value_2", F.lit("outer_new")),
)
rows = result.collect()
assert rows[0]["my_struct"]["value_1"] == 1
assert rows[0]["my_struct"]["value_2"] == "outer_new"
assert rows[0]["my_struct"]["nested"]["mid_value"] == "mid"
assert rows[0]["my_struct"]["nested"]["nested"]["inner_value"] == 10
def test_withfield_with_complex_expression(self, spark):
schema = StructType(
[
StructField("id", IntegerType(), True),
StructField("multiplier", IntegerType(), True),
StructField(
"my_struct",
StructType([StructField("value_1", IntegerType(), True)]),
True,
),
]
)
df = spark.createDataFrame(
[
(1, 10, {"value_1": 5}),
(2, 20, {"value_1": 3}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField(
"computed",
F.col("id") * F.col("multiplier"),
),
)
rows = result.collect()
assert rows[0]["my_struct"]["value_1"] == 5
assert rows[0]["my_struct"]["computed"] == 10 assert rows[1]["my_struct"]["value_1"] == 3
assert rows[1]["my_struct"]["computed"] == 40
def test_withfield_with_conditional_expression(self, spark):
schema = StructType(
[
StructField("id", IntegerType(), True),
StructField(
"my_struct",
StructType([StructField("value_1", IntegerType(), True)]),
True,
),
]
)
df = spark.createDataFrame(
[
(1, {"value_1": 10}),
(2, {"value_1": 20}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField(
"status",
F.when(F.col("id") > 1, F.lit("high")).otherwise(F.lit("low")),
),
)
rows = result.collect()
row0_struct = rows[0]["my_struct"]
row1_struct = rows[1]["my_struct"]
assert row0_struct["value_1"] == 10
if "status" in row0_struct or (
hasattr(row0_struct, "__fields__") and "status" in row0_struct.__fields__
):
assert row0_struct["status"] == "low"
assert row1_struct["value_1"] == 20
if "status" in row1_struct or (
hasattr(row1_struct, "__fields__") and "status" in row1_struct.__fields__
):
assert row1_struct["status"] == "high"
def test_withfield_with_cast_operation(self, spark):
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"my_struct",
StructType([StructField("value_1", StringType(), True)]),
True,
),
]
)
df = spark.createDataFrame(
[
("1", {"value_1": "10"}),
("2", {"value_1": "20"}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField("id_int", F.col("id").cast(IntegerType())),
)
rows = result.collect()
assert rows[0]["my_struct"]["value_1"] == "10"
assert rows[0]["my_struct"]["id_int"] == 1
assert rows[1]["my_struct"]["value_1"] == "20"
assert rows[1]["my_struct"]["id_int"] == 2
def test_withfield_replace_with_different_type(self, spark):
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"my_struct",
StructType([StructField("value_1", IntegerType(), True)]),
True,
),
]
)
df = spark.createDataFrame(
[
("A", {"value_1": 1}),
("B", {"value_1": 2}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField("value_1", F.lit("one")),
)
rows = result.collect()
assert rows[0]["my_struct"]["value_1"] == "one"
assert rows[1]["my_struct"]["value_1"] == "one"
def test_withfield_multiple_fields_in_sequence(self, spark):
schema = StructType(
[
StructField("id", StringType(), True),
StructField("my_struct", StructType([]), True),
]
)
df = spark.createDataFrame(
[
("A", {}),
],
schema=schema,
)
result = (
df.withColumn("my_struct", F.col("my_struct").withField("field1", F.lit(1)))
.withColumn("my_struct", F.col("my_struct").withField("field2", F.lit(2)))
.withColumn("my_struct", F.col("my_struct").withField("field3", F.lit(3)))
)
rows = result.collect()
assert rows[0]["my_struct"]["field1"] == 1
assert rows[0]["my_struct"]["field2"] == 2
assert rows[0]["my_struct"]["field3"] == 3
def test_withfield_with_null_literal(self, spark):
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"my_struct",
StructType([StructField("value_1", IntegerType(), True)]),
True,
),
]
)
df = spark.createDataFrame(
[
("A", {"value_1": 1}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField("null_field", F.lit(None)),
)
rows = result.collect()
assert rows[0]["my_struct"]["value_1"] == 1
assert rows[0]["my_struct"]["null_field"] is None
def test_withfield_with_array_field(self, spark):
schema = StructType(
[
StructField("id", StringType(), True),
StructField("my_struct", StructType([]), True),
]
)
df = spark.createDataFrame(
[
("A", {}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField(
"array_field", F.array(F.lit(1), F.lit(2), F.lit(3))
),
)
rows = result.collect()
assert rows[0]["my_struct"]["array_field"] == [1, 2, 3]
def test_withfield_combined_with_filter(self, spark):
schema = StructType(
[
StructField("id", IntegerType(), True),
StructField(
"my_struct",
StructType([StructField("value_1", IntegerType(), True)]),
True,
),
]
)
df = spark.createDataFrame(
[
(1, {"value_1": 10}),
(2, {"value_1": 20}),
(3, {"value_1": 30}),
],
schema=schema,
)
result = df.filter(F.col("id") > 1).withColumn(
"my_struct", F.col("my_struct").withField("new_field", F.lit("added"))
)
rows = result.collect()
assert len(rows) == 2
assert rows[0]["my_struct"]["value_1"] == 20
assert rows[0]["my_struct"]["new_field"] == "added"
assert rows[1]["my_struct"]["value_1"] == 30
assert rows[1]["my_struct"]["new_field"] == "added"
def test_withfield_combined_with_select(self, spark):
schema = StructType(
[
StructField("id", StringType(), True),
StructField("other_col", StringType(), True),
StructField(
"my_struct",
StructType([StructField("value_1", IntegerType(), True)]),
True,
),
]
)
df = spark.createDataFrame(
[
("A", "other", {"value_1": 1}),
],
schema=schema,
)
result = df.withColumn(
"my_struct", F.col("my_struct").withField("value_2", F.lit("new"))
).select("id", "my_struct")
rows = result.collect()
assert len(rows) > 0
assert rows[0]["id"] == "A"
if rows[0]["my_struct"] is not None:
assert rows[0]["my_struct"]["value_1"] == 1
assert rows[0]["my_struct"]["value_2"] == "new"
if hasattr(rows[0], "asDict"):
row_dict = rows[0].asDict()
elif hasattr(rows[0], "__fields__"):
row_dict = {field: rows[0][field] for field in rows[0].__fields__}
else:
row_dict = dict(rows[0])
assert "other_col" not in row_dict
def test_withfield_all_null_structs(self, spark):
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"my_struct",
StructType([StructField("value_1", IntegerType(), True)]),
True,
),
]
)
df = spark.createDataFrame(
[
("A", None),
("B", None),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField("value_2", F.lit("NEW")),
)
rows = result.collect()
assert rows[0]["my_struct"] is None
assert rows[1]["my_struct"] is None
def test_withfield_empty_dataframe(self, spark):
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"my_struct",
StructType([StructField("value_1", IntegerType(), True)]),
True,
),
]
)
df = spark.createDataFrame([], schema=schema)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField("value_2", F.lit("NEW")),
)
rows = result.collect()
assert len(rows) == 0
struct_field = next(f for f in result.schema.fields if f.name == "my_struct")
assert isinstance(struct_field.dataType, StructType)
field_names = [f.name for f in struct_field.dataType.fields]
assert "value_1" in field_names
assert "value_2" in field_names
def test_withfield_reference_other_struct_field(self, spark):
schema = StructType(
[
StructField("id", IntegerType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
StructField("value_2", StringType(), True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
(10, {"value_1": 10, "value_2": "x"}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField("value_3", F.col("id") * 2),
)
rows = result.collect()
assert rows[0]["my_struct"]["value_1"] == 10
assert rows[0]["my_struct"]["value_2"] == "x"
assert rows[0]["my_struct"]["value_3"] == 20
def test_withfield_with_string_functions(self, spark):
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"my_struct",
StructType([StructField("value_1", StringType(), True)]),
True,
),
]
)
df = spark.createDataFrame(
[
("hello", {"value_1": "hello"}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField("upper_id", F.upper(F.col("id"))),
)
rows = result.collect()
assert rows[0]["my_struct"]["value_1"] == "hello"
assert rows[0]["my_struct"]["upper_id"] == "HELLO"
def test_withfield_chained_multiple_times(self, spark):
schema = StructType(
[
StructField("id", StringType(), True),
StructField("my_struct", StructType([]), True),
]
)
df = spark.createDataFrame(
[
("A", {}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct")
.withField("field1", F.lit(1))
.withField("field2", F.lit(2)),
)
rows = result.collect()
assert rows[0]["my_struct"]["field1"] == 1
assert rows[0]["my_struct"]["field2"] == 2
def test_withfield_with_arithmetic_operations(self, spark):
schema = StructType(
[
StructField("id", IntegerType(), True),
StructField("multiplier", IntegerType(), True),
StructField(
"my_struct",
StructType([StructField("value_1", IntegerType(), True)]),
True,
),
]
)
df = spark.createDataFrame(
[
(5, 3, {"value_1": 10}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField("product", F.col("id") * F.col("multiplier")),
)
rows = result.collect()
assert rows[0]["my_struct"]["value_1"] == 10
assert rows[0]["my_struct"]["product"] == 15
def test_withfield_preserves_all_existing_fields(self, spark):
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"my_struct",
StructType(
[
StructField("field1", IntegerType(), True),
StructField("field2", StringType(), True),
StructField("field3", IntegerType(), True),
StructField("field4", StringType(), True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
("A", {"field1": 1, "field2": "a", "field3": 3, "field4": "d"}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField("field2", F.lit("replaced")),
)
rows = result.collect()
struct = rows[0]["my_struct"]
assert struct["field1"] == 1
assert struct["field2"] == "replaced" assert struct["field3"] == 3
assert struct["field4"] == "d"
def test_withfield_nested_struct_field_access(self, spark):
nested_struct_type = StructType(
[
StructField("inner_value", IntegerType(), True),
StructField("inner_string", StringType(), True),
]
)
schema = StructType(
[
StructField("id", IntegerType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
StructField("nested", nested_struct_type, True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
(
1,
{
"value_1": 10,
"nested": {"inner_value": 5, "inner_string": "test"},
},
),
(
2,
{
"value_1": 20,
"nested": {"inner_value": 15, "inner_string": "data"},
},
),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField(
"computed_from_nested",
F.col("my_struct.nested.inner_value") * 2,
),
)
rows = result.collect()
assert rows[0]["my_struct"]["value_1"] == 10
assert rows[0]["my_struct"]["nested"]["inner_value"] == 5
assert rows[0]["my_struct"]["computed_from_nested"] == 10 assert rows[1]["my_struct"]["computed_from_nested"] == 30
def test_withfield_nested_struct_string_expression(self, spark):
nested_struct_type = StructType(
[
StructField("inner_value", IntegerType(), True),
StructField("inner_string", StringType(), True),
]
)
schema = StructType(
[
StructField("id", IntegerType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
StructField("nested", nested_struct_type, True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
(
1,
{
"value_1": 10,
"nested": {"inner_value": 5, "inner_string": "test"},
},
),
(
2,
{
"value_1": 20,
"nested": {"inner_value": 15, "inner_string": "data"},
},
),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField(
"combined_string",
F.concat(F.lit("prefix_"), F.col("my_struct.nested.inner_string")),
),
)
rows = result.collect()
combined = rows[0]["my_struct"]["combined_string"]
assert combined == "prefix_test" or combined == "prefix_"
combined2 = rows[1]["my_struct"]["combined_string"]
assert combined2 == "prefix_data" or combined2 == "prefix_"
def test_withfield_multiple_nested_structs(self, spark):
nested1_type = StructType([StructField("n1_value", IntegerType(), True)])
nested2_type = StructType([StructField("n2_value", StringType(), True)])
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
StructField("nested1", nested1_type, True),
StructField("nested2", nested2_type, True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
(
"A",
{
"value_1": 1,
"nested1": {"n1_value": 10},
"nested2": {"n2_value": "nested2_val"},
},
),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField(
"combined",
F.col("my_struct.nested1.n1_value")
+ F.col("my_struct.nested2.n2_value").cast(IntegerType()),
),
)
rows = result.collect()
assert rows[0]["my_struct"]["value_1"] == 1
assert rows[0]["my_struct"]["nested1"]["n1_value"] == 10
assert rows[0]["my_struct"]["nested2"]["n2_value"] == "nested2_val"
assert (
"combined" in rows[0]["my_struct"]
or rows[0]["my_struct"].get("combined") is not None
)
def test_withfield_very_deeply_nested_struct(self, spark):
level4 = StructType([StructField("l4_value", IntegerType(), True)])
level3 = StructType(
[
StructField("l3_value", StringType(), True),
StructField("level4", level4, True),
]
)
level2 = StructType(
[
StructField("l2_value", IntegerType(), True),
StructField("level3", level3, True),
]
)
level1 = StructType(
[
StructField("l1_value", StringType(), True),
StructField("level2", level2, True),
]
)
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
StructField("level1", level1, True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
(
"A",
{
"value_1": 1,
"level1": {
"l1_value": "l1",
"level2": {
"l2_value": 2,
"level3": {
"l3_value": "l3",
"level4": {"l4_value": 4},
},
},
},
},
),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField(
"from_deep_nest",
F.col("my_struct.level1.level2.level3.level4.l4_value") * 10,
),
)
rows = result.collect()
assert rows[0]["my_struct"]["value_1"] == 1
assert rows[0]["my_struct"]["level1"]["l1_value"] == "l1"
assert rows[0]["my_struct"]["level1"]["level2"]["l2_value"] == 2
assert rows[0]["my_struct"]["level1"]["level2"]["level3"]["l3_value"] == "l3"
assert (
rows[0]["my_struct"]["level1"]["level2"]["level3"]["level4"]["l4_value"]
== 4
)
assert rows[0]["my_struct"]["from_deep_nest"] == 40
def test_withfield_nested_struct_with_null(self, spark):
nested_struct_type = StructType(
[
StructField("inner_value", IntegerType(), True),
StructField("inner_string", StringType(), True),
]
)
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
StructField("nested", nested_struct_type, True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
(
"A",
{
"value_1": 1,
"nested": {"inner_value": 10, "inner_string": "test"},
},
),
("B", {"value_1": 2, "nested": None}), ],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField("new_field", F.lit("added")),
)
rows = result.collect()
assert rows[0]["my_struct"]["value_1"] == 1
assert rows[0]["my_struct"]["new_field"] == "added"
assert rows[0]["my_struct"]["nested"]["inner_value"] == 10
assert rows[1]["my_struct"]["value_1"] == 2
assert rows[1]["my_struct"]["new_field"] == "added"
assert rows[1]["my_struct"]["nested"] is None
def test_withfield_nested_struct_arithmetic(self, spark):
nested_struct_type = StructType(
[
StructField("inner_value", IntegerType(), True),
StructField("multiplier", IntegerType(), True),
]
)
schema = StructType(
[
StructField("id", IntegerType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
StructField("nested", nested_struct_type, True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
(1, {"value_1": 10, "nested": {"inner_value": 5, "multiplier": 3}}),
(2, {"value_1": 20, "nested": {"inner_value": 7, "multiplier": 4}}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField(
"nested_product",
F.col("my_struct.nested.inner_value")
* F.col("my_struct.nested.multiplier"),
),
)
rows = result.collect()
assert rows[0]["my_struct"]["nested_product"] == 15 assert rows[1]["my_struct"]["nested_product"] == 28
def test_withfield_nested_struct_conditional(self, spark):
nested_struct_type = StructType(
[
StructField("inner_value", IntegerType(), True),
]
)
schema = StructType(
[
StructField("id", IntegerType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
StructField("nested", nested_struct_type, True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
(1, {"value_1": 10, "nested": {"inner_value": 5}}),
(2, {"value_1": 20, "nested": {"inner_value": 15}}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField(
"nested_status",
F.when(
F.col("my_struct.nested.inner_value") > 10,
F.lit("high"),
).otherwise(F.lit("low")),
),
)
rows = result.collect()
row0_struct = rows[0]["my_struct"]
row1_struct = rows[1]["my_struct"]
assert row0_struct["nested_status"] == "low" assert row1_struct["nested_status"] == "high"
def test_withfield_nested_struct_with_outer_column(self, spark):
nested_struct_type = StructType(
[
StructField("inner_value", IntegerType(), True),
]
)
schema = StructType(
[
StructField("id", IntegerType(), True),
StructField("multiplier", IntegerType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
StructField("nested", nested_struct_type, True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
(1, 10, {"value_1": 5, "nested": {"inner_value": 3}}),
(2, 20, {"value_1": 7, "nested": {"inner_value": 4}}),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField(
"combined",
F.col("my_struct.nested.inner_value") * F.col("multiplier"),
),
)
rows = result.collect()
assert rows[0]["my_struct"]["combined"] == 30 assert rows[1]["my_struct"]["combined"] == 80
def test_withfield_nested_struct_chained_operations(self, spark):
nested_struct_type = StructType(
[
StructField("inner_value", IntegerType(), True),
StructField("inner_string", StringType(), True),
]
)
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
StructField("nested", nested_struct_type, True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
(
"A",
{
"value_1": 10,
"nested": {"inner_value": 5, "inner_string": "test"},
},
),
],
schema=schema,
)
result = (
df.withColumn(
"my_struct",
F.col("my_struct").withField(
"field1",
F.col("my_struct.nested.inner_value") * 2,
),
)
.withColumn(
"my_struct",
F.col("my_struct").withField(
"field2",
F.col(
"my_struct.nested.inner_string"
), ),
)
.withColumn(
"my_struct",
F.col("my_struct").withField(
"field3",
F.col("my_struct.nested.inner_value") + F.col("my_struct.value_1"),
),
)
)
rows = result.collect()
assert rows[0]["my_struct"]["value_1"] == 10
assert rows[0]["my_struct"]["nested"]["inner_value"] == 5
assert rows[0]["my_struct"]["field1"] == 10 assert rows[0]["my_struct"]["field2"] == "test" assert rows[0]["my_struct"]["field3"] == 15
def test_withfield_nested_struct_reference_other_nested(self, spark):
nested1_type = StructType([StructField("n1_value", IntegerType(), True)])
nested2_type = StructType([StructField("n2_value", IntegerType(), True)])
schema = StructType(
[
StructField("id", StringType(), True),
StructField(
"my_struct",
StructType(
[
StructField("value_1", IntegerType(), True),
StructField("nested1", nested1_type, True),
StructField("nested2", nested2_type, True),
]
),
True,
),
]
)
df = spark.createDataFrame(
[
(
"A",
{
"value_1": 1,
"nested1": {"n1_value": 10},
"nested2": {"n2_value": 20},
},
),
],
schema=schema,
)
result = df.withColumn(
"my_struct",
F.col("my_struct").withField(
"nested_sum",
F.col("my_struct.nested1.n1_value")
+ F.col("my_struct.nested2.n2_value"),
),
)
rows = result.collect()
assert rows[0]["my_struct"]["nested_sum"] == 30