from sparkless.testing import get_imports
# Resolve the SparkSession class and the functions module (F) through
# sparkless.testing.get_imports() instead of importing a Spark package directly.
# NOTE(review): presumably this lets the same "parity" tests run against more
# than one backend (e.g. sparkless vs. PySpark) — confirm in sparkless.testing.
_imports = get_imports()
SparkSession = _imports.SparkSession
F = _imports.F
class TestColumnSubscriptParity:
    """Parity tests for struct-field access via Column subscripting (issue 339).

    Each test builds a two-row DataFrame with a ``StructVal`` struct column and
    checks that ``F.col("StructVal")["<field>"]`` extracts the named field.
    Rows are always matched by their ``Name`` value, so no assertion depends on
    the order in which ``collect()`` returns rows.
    """

    @staticmethod
    def _create_df(spark):
        """Build the shared fixture: rows Alice (E1=1, E2=2) and Bob (E1=3, E2=4)."""
        # A fresh list is built on every call so no test can observe another
        # test's data.
        return spark.createDataFrame(
            [
                {"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
                {"Name": "Bob", "StructVal": {"E1": 3, "E2": 4}},
            ]
        )

    @staticmethod
    def _rows_by_name(rows):
        """Index collected rows by their ``Name`` value.

        Args:
            rows: The list returned by ``DataFrame.collect()``.

        Returns:
            A dict mapping each ``Name`` to its row.

        Raises:
            AssertionError: If two rows share a ``Name``. Without this check,
                the dict would silently drop one of them.
        """
        by_name = {row["Name"]: row for row in rows}
        assert len(by_name) == len(rows), "duplicate Name values in result"
        return by_name

    def test_column_subscript_single_field_parity(self):
        """``col["E1"]`` on a struct column yields each row's E1 field."""
        spark = SparkSession.builder.appName("issue-339-parity").getOrCreate()
        try:
            df = self._create_df(spark)
            result = df.withColumn("Extract-E1", F.col("StructVal")["E1"])
            rows = result.collect()
            assert len(rows) == 2
            by_name = self._rows_by_name(rows)
            assert by_name["Alice"]["Extract-E1"] == 1
            assert by_name["Bob"]["Extract-E1"] == 3
        finally:
            spark.stop()

    def test_column_subscript_multiple_fields_parity(self):
        """Chained ``withColumn`` calls can subscript different struct fields."""
        spark = SparkSession.builder.appName("issue-339-parity").getOrCreate()
        try:
            df = self._create_df(spark)
            result = df.withColumn("Extract-E1", F.col("StructVal")["E1"]).withColumn(
                "Extract-E2", F.col("StructVal")["E2"]
            )
            rows = result.collect()
            assert len(rows) == 2
            by_name = self._rows_by_name(rows)
            assert by_name["Alice"]["Extract-E1"] == 1
            assert by_name["Alice"]["Extract-E2"] == 2
            assert by_name["Bob"]["Extract-E1"] == 3
            assert by_name["Bob"]["Extract-E2"] == 4
        finally:
            spark.stop()

    def test_column_subscript_equals_dot_notation_parity(self):
        """Subscript access and dotted-path access agree on every row.

        The two results come from separate ``collect()`` calls, so rows are
        paired by ``Name`` rather than by position. The expected values are
        also pinned, so the test cannot pass when both access forms return
        the same wrong value (e.g. ``None``).
        """
        spark = SparkSession.builder.appName("issue-339-parity").getOrCreate()
        try:
            df = self._create_df(spark)
            result_subscript = df.withColumn("Extract-E1", F.col("StructVal")["E1"])
            result_dot = df.withColumn("Extract-E1", F.col("StructVal.E1"))
            rows_subscript = result_subscript.collect()
            rows_dot = result_dot.collect()
            assert len(rows_subscript) == len(rows_dot) == 2
            subscript_values = {
                name: row["Extract-E1"]
                for name, row in self._rows_by_name(rows_subscript).items()
            }
            dot_values = {
                name: row["Extract-E1"]
                for name, row in self._rows_by_name(rows_dot).items()
            }
            assert subscript_values == dot_values == {"Alice": 1, "Bob": 3}
        finally:
            spark.stop()