from datetime import datetime
from tests.fixtures.spark_imports import get_spark_imports
# Resolve the Spark entry points through the shared test fixture helper.
_imports = get_spark_imports()
SparkSession, F = _imports.SparkSession, _imports.F

# Short aliases for the column functions used by the tests in this module.
col, to_timestamp, regexp_replace = F.col, F.to_timestamp, F.regexp_replace
class TestIssue138ColumnDropReference:
    """Regression tests for dropping a column that a derived column reads from.

    Every test builds a one-row DataFrame, derives ``snapshot_date_parsed``
    from ``snapshot_date``, drops ``snapshot_date``, and then runs an action
    (``count`` / ``collect``) on a follow-up ``select`` or ``filter``.  The
    class name ties these to issue 138.
    """

    # Pattern handed to ``to_timestamp`` for the sample ISO-like strings.
    _TIMESTAMP_FORMAT = "yyyy-MM-dd'T'HH:mm:ss"

    @staticmethod
    def _parse_then_drop_snapshot_date(df):
        """Add ``snapshot_date_parsed`` to *df* and drop ``snapshot_date``.

        The source string first has every ``.<digits>`` run removed
        (presumably fractional seconds -- the sample data contains none),
        is cast to string, and is then parsed with ``_TIMESTAMP_FORMAT``.

        Returns the transformed DataFrame; no action is triggered here.
        """
        cleaned = regexp_replace(col("snapshot_date"), r"\.\d+", "").cast("string")
        parsed = to_timestamp(
            cleaned, TestIssue138ColumnDropReference._TIMESTAMP_FORMAT
        )
        return df.withColumn("snapshot_date_parsed", parsed).drop("snapshot_date")

    def test_drop_column_after_transform(self):
        """Select the derived column after its source column was dropped."""
        spark = SparkSession.builder.appName("BugRepro").getOrCreate()
        try:
            data = [("inv1", "2024-01-15T10:30:00", 100)]
            df = spark.createDataFrame(
                data, ["inventory_id", "snapshot_date", "quantity_on_hand"]
            )
            transformed = self._parse_then_drop_snapshot_date(df)

            result = transformed.select("inventory_id", "snapshot_date_parsed")

            count = result.count()
            assert count == 1, f"Expected count 1, got {count}"

            rows = result.collect()
            assert len(rows) == 1
            assert rows[0]["inventory_id"] == "inv1"
            assert isinstance(rows[0]["snapshot_date_parsed"], datetime)
        finally:
            spark.stop()

    def test_drop_multiple_columns_after_transform(self):
        """Drop the source column and an unrelated column, then select."""
        spark = SparkSession.builder.appName("BugRepro").getOrCreate()
        try:
            data = [("inv1", "2024-01-15T10:30:00", 100, "temp")]
            df = spark.createDataFrame(
                data,
                ["inventory_id", "snapshot_date", "quantity_on_hand", "temp_col"],
            )
            transformed = self._parse_then_drop_snapshot_date(df).drop("temp_col")

            result = transformed.select(
                "inventory_id", "snapshot_date_parsed", "quantity_on_hand"
            )

            count = result.count()
            assert count == 1

            rows = result.collect()
            assert len(rows) == 1
            assert "snapshot_date" not in rows[0]
            assert "temp_col" not in rows[0]

            # ``result`` is an explicit projection of three other columns, so
            # the row checks above would pass even if ``drop`` had no effect.
            # Verify the drops on the transformed frame's own schema as well.
            assert "snapshot_date" not in transformed.columns
            assert "temp_col" not in transformed.columns
        finally:
            spark.stop()

    def test_drop_then_select(self):
        """Count rows of a select that follows the drop."""
        spark = SparkSession.builder.appName("BugRepro").getOrCreate()
        try:
            data = [("inv1", "2024-01-15T10:30:00", 100)]
            df = spark.createDataFrame(
                data, ["inventory_id", "snapshot_date", "quantity_on_hand"]
            )
            transformed = self._parse_then_drop_snapshot_date(df)

            result = transformed.select(
                "inventory_id", "snapshot_date_parsed", "quantity_on_hand"
            )

            count = result.count()
            assert count == 1
        finally:
            spark.stop()

    def test_drop_then_filter(self):
        """Count rows of a filter that follows the drop."""
        spark = SparkSession.builder.appName("BugRepro").getOrCreate()
        try:
            data = [("inv1", "2024-01-15T10:30:00", 100)]
            df = spark.createDataFrame(
                data, ["inventory_id", "snapshot_date", "quantity_on_hand"]
            )
            transformed = self._parse_then_drop_snapshot_date(df)

            result = transformed.filter(col("inventory_id") == "inv1")

            count = result.count()
            assert count == 1
        finally:
            spark.stop()