import os
from tests.fixtures.spark_imports import get_spark_imports
_imports = get_spark_imports()
SparkSession = _imports.SparkSession
F = _imports.F
Window = _imports.Window
class TestIssue339ColumnSubscript:
def test_column_subscript_single_field(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
{"Name": "Bob", "StructVal": {"E1": 3, "E2": 4}},
]
)
result = df.withColumn("Extract-E1", F.col("StructVal")["E1"])
rows = result.collect()
assert len(rows) == 2
alice_row = next(row for row in rows if row["Name"] == "Alice")
bob_row = next(row for row in rows if row["Name"] == "Bob")
assert alice_row["Extract-E1"] == 1
assert bob_row["Extract-E1"] == 3
finally:
spark.stop()
def test_column_subscript_multiple_fields(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
{"Name": "Bob", "StructVal": {"E1": 3, "E2": 4}},
]
)
result = df.withColumn("Extract-E1", F.col("StructVal")["E1"]).withColumn(
"Extract-E2", F.col("StructVal")["E2"]
)
rows = result.collect()
assert len(rows) == 2
alice_row = next(row for row in rows if row["Name"] == "Alice")
bob_row = next(row for row in rows if row["Name"] == "Bob")
assert alice_row["Extract-E1"] == 1
assert alice_row["Extract-E2"] == 2
assert bob_row["Extract-E1"] == 3
assert bob_row["Extract-E2"] == 4
finally:
spark.stop()
def test_column_subscript_in_select(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
{"Name": "Bob", "StructVal": {"E1": 3, "E2": 4}},
]
)
result = df.select("Name", F.col("StructVal")["E1"].alias("E1"))
rows = result.collect()
assert len(rows) == 2
alice_row = next(row for row in rows if row["Name"] == "Alice")
assert alice_row["E1"] == 1
finally:
spark.stop()
def test_column_subscript_in_filter(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
{"Name": "Bob", "StructVal": {"E1": 3, "E2": 4}},
]
)
result = df.filter(F.col("StructVal")["E1"] > 2)
rows = result.collect()
assert len(rows) == 1
assert rows[0]["Name"] == "Bob"
finally:
spark.stop()
def test_column_subscript_equals_dot_notation(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
{"Name": "Bob", "StructVal": {"E1": 3, "E2": 4}},
]
)
result_subscript = df.withColumn("Extract-E1", F.col("StructVal")["E1"])
result_dot = df.withColumn("Extract-E1", F.col("StructVal.E1"))
rows_subscript = result_subscript.collect()
rows_dot = result_dot.collect()
assert len(rows_subscript) == len(rows_dot) == 2
for sub_row, dot_row in zip(rows_subscript, rows_dot):
assert sub_row["Extract-E1"] == dot_row["Extract-E1"]
finally:
spark.stop()
def test_column_subscript_with_nested_struct(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{
"Name": "Alice",
"Outer": {"Inner": {"E1": 1, "E2": 2}},
},
{
"Name": "Bob",
"Outer": {"Inner": {"E1": 3, "E2": 4}},
},
]
)
result = df.withColumn("Extract-E1", F.col("Outer")["Inner"]["E1"])
rows = result.collect()
assert len(rows) == 2
alice_row = next(row for row in rows if row["Name"] == "Alice")
assert alice_row["Extract-E1"] == 1
finally:
spark.stop()
def test_column_subscript_with_alias(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
{"Name": "Bob", "StructVal": {"E1": 3, "E2": 4}},
]
)
result = df.withColumn("Extract-E1", F.col("StructVal").alias("SV")["E1"])
rows = result.collect()
assert len(rows) == 2
alice_row = next(row for row in rows if row["Name"] == "Alice")
assert alice_row["Extract-E1"] == 1
finally:
spark.stop()
def test_column_subscript_with_null_struct(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
{"Name": "Bob", "StructVal": None},
]
)
result = df.withColumn("Extract-E1", F.col("StructVal")["E1"])
rows = result.collect()
assert len(rows) == 2
alice_row = next(row for row in rows if row["Name"] == "Alice")
assert alice_row["Extract-E1"] == 1
bob_row = next(row for row in rows if row["Name"] == "Bob")
assert bob_row["Extract-E1"] is None
finally:
spark.stop()
def test_column_subscript_with_null_field(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": 1, "E2": None}},
{"Name": "Bob", "StructVal": {"E1": 3, "E2": 4}},
]
)
result = df.withColumn("Extract-E2", F.col("StructVal")["E2"])
rows = result.collect()
assert len(rows) == 2
alice_row = next(row for row in rows if row["Name"] == "Alice")
assert alice_row["Extract-E2"] is None
bob_row = next(row for row in rows if row["Name"] == "Bob")
assert bob_row["Extract-E2"] == 4
finally:
spark.stop()
def test_column_subscript_with_orderBy(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
{"Name": "Bob", "StructVal": {"E1": 3, "E2": 4}},
]
)
result = df.withColumn("Extract-E1", F.col("StructVal")["E1"]).orderBy(
F.col("Extract-E1")
)
rows = result.collect()
assert len(rows) == 2
assert rows[0]["Extract-E1"] == 1
assert rows[1]["Extract-E1"] == 3
finally:
spark.stop()
def test_column_subscript_with_groupBy(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 3}},
{"Name": "Bob", "StructVal": {"E1": 3, "E2": 4}},
]
)
result = (
df.withColumn("Extract-E1", F.col("StructVal")["E1"])
.groupBy("Extract-E1")
.agg(F.count("*").alias("count"))
)
rows = result.collect()
assert len(rows) == 2
count_1 = next(row for row in rows if row["Extract-E1"] == 1)["count"]
count_3 = next(row for row in rows if row["Extract-E1"] == 3)["count"]
assert count_1 == 2
assert count_3 == 1
finally:
spark.stop()
def test_column_subscript_with_join(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df1 = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
{"Name": "Bob", "StructVal": {"E1": 3, "E2": 4}},
]
)
df2 = spark.createDataFrame(
[
{"E1": 1, "Type": "A"},
{"E1": 3, "Type": "B"},
]
)
df1_with_extract = df1.withColumn("Extract-E1", F.col("StructVal")["E1"])
result = df1_with_extract.join(
df2, df1_with_extract["Extract-E1"] == df2["E1"], how="left"
)
rows = result.collect()
assert len(rows) == 2
alice_row = next(row for row in rows if row["Name"] == "Alice")
assert alice_row["Type"] == "A"
finally:
spark.stop()
def test_column_subscript_with_case_when(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
{"Name": "Bob", "StructVal": {"E1": 3, "E2": 4}},
]
)
result = df.withColumn(
"Category",
F.when(F.col("StructVal")["E1"] > 2, "High").otherwise("Low"),
)
rows = result.collect()
assert len(rows) == 2
alice_row = next(row for row in rows if row["Name"] == "Alice")
assert alice_row["Category"] == "Low"
bob_row = next(row for row in rows if row["Name"] == "Bob")
assert bob_row["Category"] == "High"
finally:
spark.stop()
def test_column_subscript_with_comparison(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
{"Name": "Bob", "StructVal": {"E1": 3, "E2": 4}},
]
)
result = df.withColumn("IsHigh", F.col("StructVal")["E1"] > 2)
rows = result.collect()
assert len(rows) == 2
alice_row = next(row for row in rows if row["Name"] == "Alice")
assert alice_row["IsHigh"] is False
bob_row = next(row for row in rows if row["Name"] == "Bob")
assert bob_row["IsHigh"] is True
finally:
spark.stop()
def test_column_subscript_with_arithmetic(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
{"Name": "Bob", "StructVal": {"E1": 3, "E2": 4}},
]
)
result = df.withColumn(
"Sum", F.col("StructVal")["E1"] + F.col("StructVal")["E2"]
)
rows = result.collect()
assert len(rows) == 2
alice_row = next(row for row in rows if row["Name"] == "Alice")
assert alice_row["Sum"] == 3 bob_row = next(row for row in rows if row["Name"] == "Bob")
assert bob_row["Sum"] == 7 finally:
spark.stop()
def test_column_subscript_error_non_string_key(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
]
)
backend = (
os.getenv("MOCK_SPARK_TEST_BACKEND")
or os.getenv("SPARKLESS_TEST_BACKEND")
or "sparkless"
)
if backend == "pyspark":
try:
result = df.withColumn("Extract", F.col("StructVal")[1])
result.collect()
except Exception:
pass
else:
try:
result = df.withColumn("Extract", F.col("StructVal")[1])
result.collect()
assert False, "Should have raised TypeError"
except TypeError as e:
assert "string keys" in str(e) or "subscript access" in str(e)
finally:
spark.stop()
def test_column_subscript_schema_verification(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
]
)
result = df.withColumn("Extract-E1", F.col("StructVal")["E1"])
schema = result.schema
field_names = [field.name for field in schema.fields]
assert "Name" in field_names
assert "StructVal" in field_names
assert "Extract-E1" in field_names
finally:
spark.stop()
def test_column_subscript_with_multiple_struct_columns(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{
"Name": "Alice",
"Struct1": {"E1": 1},
"Struct2": {"E2": 2},
},
{
"Name": "Bob",
"Struct1": {"E1": 3},
"Struct2": {"E2": 4},
},
]
)
result = df.withColumn("Extract-E1", F.col("Struct1")["E1"]).withColumn(
"Extract-E2", F.col("Struct2")["E2"]
)
rows = result.collect()
assert len(rows) == 2
alice_row = next(row for row in rows if row["Name"] == "Alice")
assert alice_row["Extract-E1"] == 1
assert alice_row["Extract-E2"] == 2
finally:
spark.stop()
def test_column_subscript_with_computed_column(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
{"Name": "Bob", "StructVal": {"E1": 3, "E2": 4}},
]
)
result = df.withColumn(
"NewStruct", F.struct(F.col("Name"), F.lit("X"))
).withColumn("Extract", F.col("NewStruct")["Name"])
rows = result.collect()
assert len(rows) == 2
alice_row = next(row for row in rows if row["Name"] == "Alice")
assert alice_row["Extract"] == "Alice"
finally:
spark.stop()
def test_column_subscript_deeply_nested_struct(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{
"Name": "Alice",
"Level1": {
"Level2": {"Level3": {"Value": 100}},
"Other": 1,
},
},
{
"Name": "Bob",
"Level1": {
"Level2": {"Level3": {"Value": 200}},
"Other": 2,
},
},
]
)
result = df.withColumn(
"DeepValue", F.col("Level1")["Level2"]["Level3"]["Value"]
)
rows = result.collect()
assert len(rows) == 2
alice_row = next(row for row in rows if row["Name"] == "Alice")
assert alice_row["DeepValue"] == 100
bob_row = next(row for row in rows if row["Name"] == "Bob")
assert bob_row["DeepValue"] == 200
finally:
spark.stop()
def test_column_subscript_with_union(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df1 = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
]
)
df2 = spark.createDataFrame(
[
{"Name": "Bob", "StructVal": {"E1": 3, "E2": 4}},
]
)
df1_with_extract = df1.withColumn("Extract-E1", F.col("StructVal")["E1"])
df2_with_extract = df2.withColumn("Extract-E1", F.col("StructVal")["E1"])
result = df1_with_extract.union(df2_with_extract).orderBy("Name")
rows = result.collect()
assert len(rows) == 2
assert rows[0]["Extract-E1"] == 1
assert rows[1]["Extract-E1"] == 3
finally:
spark.stop()
def test_column_subscript_with_distinct(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
{"Name": "Bob", "StructVal": {"E1": 3, "E2": 4}},
]
)
df_with_extract = df.withColumn("Extract-E1", F.col("StructVal")["E1"])
result = df_with_extract.select("Name", "Extract-E1").distinct()
rows = result.collect()
assert len(rows) == 2
extract_values = {
row["Extract-E1"] for row in rows if row["Extract-E1"] is not None
}
assert extract_values == {1, 3}
finally:
spark.stop()
def test_column_subscript_with_cast(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": "1", "E2": "2"}},
{"Name": "Bob", "StructVal": {"E1": "3", "E2": "4"}},
]
)
result = df.withColumn(
"Extract-E1-Int", F.col("StructVal")["E1"].cast("int")
)
rows = result.collect()
assert len(rows) == 2
alice_row = next(row for row in rows if row["Name"] == "Alice")
assert alice_row["Extract-E1-Int"] == 1
finally:
spark.stop()
def test_column_subscript_with_window_function(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "Type": "A", "StructVal": {"E1": 1, "E2": 2}},
{"Name": "Bob", "Type": "A", "StructVal": {"E1": 3, "E2": 4}},
{"Name": "Charlie", "Type": "B", "StructVal": {"E1": 5, "E2": 6}},
]
)
w = Window.partitionBy("Type").orderBy("Name")
result = (
df.withColumn("Rank", F.row_number().over(w))
.withColumn("Extract-E1", F.col("StructVal")["E1"])
.orderBy("Type", "Name")
)
rows = result.collect()
assert len(rows) == 3
assert rows[0]["Rank"] == 1
assert rows[0]["Extract-E1"] == 1
finally:
spark.stop()
def test_column_subscript_with_multiple_aggregations(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "Type": "A", "StructVal": {"E1": 1, "E2": 2}},
{"Name": "Bob", "Type": "A", "StructVal": {"E1": 3, "E2": 4}},
{"Name": "Charlie", "Type": "B", "StructVal": {"E1": 5, "E2": 6}},
]
)
result = (
df.withColumn("Extract-E1", F.col("StructVal")["E1"])
.groupBy("Type")
.agg(
F.avg("Extract-E1").alias("AvgE1"),
F.max("Extract-E1").alias("MaxE1"),
F.min("Extract-E1").alias("MinE1"),
)
)
rows = result.collect()
assert len(rows) == 2
type_a = next(row for row in rows if row["Type"] == "A")
assert type_a["AvgE1"] == 2.0 assert type_a["MaxE1"] == 3
assert type_a["MinE1"] == 1
finally:
spark.stop()
def test_column_subscript_with_coalesce(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": None, "E2": 2}},
{"Name": "Bob", "StructVal": {"E1": 3, "E2": 4}},
]
)
result = (
df.withColumn("Extract-E1", F.col("StructVal")["E1"])
.withColumn("Extract-E2", F.col("StructVal")["E2"])
.withColumn(
"Coalesced", F.coalesce(F.col("Extract-E1"), F.col("Extract-E2"))
)
)
rows = result.collect()
assert len(rows) == 2
alice_row = next(row for row in rows if row["Name"] == "Alice")
assert alice_row["Coalesced"] in [2, None]
bob_row = next(row for row in rows if row["Name"] == "Bob")
assert bob_row["Coalesced"] == 3 finally:
spark.stop()
def test_column_subscript_with_when_otherwise_nested(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
{"Name": "Bob", "StructVal": {"E1": 3, "E2": 4}},
{"Name": "Charlie", "StructVal": {"E1": 5, "E2": 6}},
]
)
result = df.withColumn(
"Category",
F.when(F.col("StructVal")["E1"] < 2, "Low")
.when(F.col("StructVal")["E1"] < 4, "Medium")
.otherwise("High"),
)
rows = result.collect()
assert len(rows) == 3
alice_row = next(row for row in rows if row["Name"] == "Alice")
assert alice_row["Category"] == "Low"
bob_row = next(row for row in rows if row["Name"] == "Bob")
assert bob_row["Category"] == "Medium"
charlie_row = next(row for row in rows if row["Name"] == "Charlie")
assert charlie_row["Category"] == "High"
finally:
spark.stop()
def test_column_subscript_with_string_operations(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": "hello", "E2": "world"}},
{"Name": "Bob", "StructVal": {"E1": "test", "E2": "data"}},
]
)
result = df.withColumn(
"UpperE1", F.upper(F.col("StructVal")["E1"])
).withColumn("LengthE2", F.length(F.col("StructVal")["E2"]))
rows = result.collect()
assert len(rows) == 2
alice_row = next(row for row in rows if row["Name"] == "Alice")
assert alice_row["UpperE1"] == "HELLO"
assert alice_row["LengthE2"] == 5
finally:
spark.stop()
def test_column_subscript_with_limit(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "StructVal": {"E1": 1, "E2": 2}},
{"Name": "Bob", "StructVal": {"E1": 3, "E2": 4}},
{"Name": "Charlie", "StructVal": {"E1": 5, "E2": 6}},
]
)
result = (
df.withColumn("Extract-E1", F.col("StructVal")["E1"])
.orderBy("Extract-E1")
.limit(2)
)
rows = result.collect()
assert len(rows) == 2
assert rows[0]["Extract-E1"] == 1
assert rows[1]["Extract-E1"] == 3
finally:
spark.stop()
def test_column_subscript_chained_operations(self):
spark = SparkSession.builder.appName("issue-339").getOrCreate()
try:
df = spark.createDataFrame(
[
{"Name": "Alice", "Type": "A", "StructVal": {"E1": 1, "E2": 2}},
{"Name": "Bob", "Type": "A", "StructVal": {"E1": 3, "E2": 4}},
{"Name": "Charlie", "Type": "B", "StructVal": {"E1": 5, "E2": 6}},
]
)
result = (
df.withColumn("Extract-E1", F.col("StructVal")["E1"])
.withColumn("Extract-E2", F.col("StructVal")["E2"])
.withColumn("Sum", F.col("Extract-E1") + F.col("Extract-E2"))
.filter(F.col("Sum") > 5)
.groupBy("Type")
.agg(F.avg("Sum").alias("AvgSum"))
.orderBy("Type")
)
rows = result.collect()
assert len(rows) == 2
type_a = next(row for row in rows if row["Type"] == "A")
assert type_a["AvgSum"] == 7.0
type_b = next(row for row in rows if row["Type"] == "B")
assert type_b["AvgSum"] == 11.0
finally:
spark.stop()