from datetime import date, datetime
from sparkless.testing import get_imports
# Pull the Spark API bundle from sparkless.testing.get_imports() once and
# expose the pieces these tests use as short module-level names.
_imports = get_imports()
SparkSession = _imports.SparkSession
F = _imports.F
col, to_timestamp, to_date, current_date = (
    F.col,
    F.to_timestamp,
    F.to_date,
    F.current_date,
)
class TestIssue139DatetimeValidationCompatibility:
    """Regression tests for issue 139: validating parsed date/timestamp columns.

    Each test parses string columns with ``to_timestamp``/``to_date``, filters
    the result the way a validation step would (null checks, range checks,
    column-to-column comparisons), then checks the row count, the Python type
    and the value of what ``collect()`` returns.

    A fresh session is created before every test and stopped after it.

    NOTE(review): the expected timestamp values assume the session timezone
    matches the Python process's local timezone (the PySpark default), so
    parsing and collecting round-trip the wall-clock time unchanged.
    Confirm this for other backends.
    """

    # Application name used for every session these tests create.
    APP_NAME = "BugRepro"
    # Spark datetime pattern matching the ISO-like "2024-01-15T10:30:00" inputs.
    _TIMESTAMP_FORMAT = "yyyy-MM-dd'T'HH:mm:ss"

    def setup_method(self, method=None):
        """Create the session used by the current test (pytest xunit hook)."""
        self.spark = SparkSession.builder.appName(self.APP_NAME).getOrCreate()

    def teardown_method(self, method=None):
        """Stop the session created in ``setup_method``, even if the test failed."""
        self.spark.stop()

    def test_validation_with_datetime_column(self):
        """to_timestamp output survives an isNotNull() validation filter."""
        data = [
            ("e1", "2024-01-15T10:30:00"),
            ("e2", "2024-01-16T11:00:00"),
            ("e3", "2024-01-17T12:00:00"),
        ]
        df = self.spark.createDataFrame(data, ["event_id", "event_date_str"])
        transformed = df.withColumn(
            "event_date",
            to_timestamp(col("event_date_str"), self._TIMESTAMP_FORMAT),
        )

        validation_result = transformed.filter(col("event_date").isNotNull())

        count = validation_result.count()
        assert count == 3, f"Expected 3 valid rows, got {count}"
        rows = validation_result.collect()
        assert len(rows) == 3
        for row in rows:
            assert row["event_date"] is not None
            assert isinstance(row["event_date"], datetime)
        # Key by id so the value check does not depend on collect() ordering.
        parsed = {row["event_id"]: row["event_date"] for row in rows}
        assert parsed == {
            "e1": datetime(2024, 1, 15, 10, 30, 0),
            "e2": datetime(2024, 1, 16, 11, 0, 0),
            "e3": datetime(2024, 1, 17, 12, 0, 0),
        }

    def test_validation_with_date_column_and_operations(self):
        """to_date output works in a null check combined with a current_date() bound."""
        data = [("p1", "John", "1990-01-15")]
        df = self.spark.createDataFrame(
            data, ["patient_id", "first_name", "date_of_birth"]
        )
        transformed = df.withColumn(
            "birth_date", to_date(col("date_of_birth"), "yyyy-MM-dd")
        )

        validation_result = transformed.filter(
            col("birth_date").isNotNull() & (col("birth_date") < current_date())
        )

        count = validation_result.count()
        assert count == 1, f"Expected 1 valid row, got {count}"
        rows = validation_result.collect()
        assert len(rows) == 1
        birth_date = rows[0]["birth_date"]
        assert birth_date is not None
        # datetime is a subclass of date, so isinstance(x, date) alone would
        # also accept a timestamp; to_date should yield a plain date.
        assert isinstance(birth_date, date)
        assert not isinstance(birth_date, datetime)
        assert birth_date == date(1990, 1, 15)

    def test_validation_with_datetime_comparison(self):
        """Parsed timestamps can be range-checked against Python datetime literals."""
        data = [
            ("t1", "2024-01-10T10:00:00"),
            ("t2", "2024-01-15T10:00:00"),
            ("t3", "2024-01-20T10:00:00"),
        ]
        df = self.spark.createDataFrame(data, ["txn_id", "txn_date_str"])
        transformed = df.withColumn(
            "txn_date", to_timestamp(col("txn_date_str"), self._TIMESTAMP_FORMAT)
        )
        start_date = datetime(2024, 1, 12, 0, 0, 0)
        end_date = datetime(2024, 1, 18, 23, 59, 59)

        validation_result = transformed.filter(
            (col("txn_date") >= start_date) & (col("txn_date") <= end_date)
        )

        count = validation_result.count()
        assert count == 1, f"Expected 1 valid row, got {count}"
        rows = validation_result.collect()
        assert len(rows) == 1
        assert rows[0]["txn_id"] == "t2"
        assert isinstance(rows[0]["txn_date"], datetime)
        assert rows[0]["txn_date"] == datetime(2024, 1, 15, 10, 0, 0)

    def test_validation_with_multiple_datetime_columns(self):
        """Two parsed timestamp columns can be compared against each other."""
        data = [
            ("r1", "2024-01-10T10:00:00", "2024-01-15T10:00:00"),
            ("r2", "2024-01-12T10:00:00", "2024-01-18T10:00:00"),
        ]
        df = self.spark.createDataFrame(
            data, ["record_id", "start_date_str", "end_date_str"]
        )
        transformed = df.withColumn(
            "start_date",
            to_timestamp(col("start_date_str"), self._TIMESTAMP_FORMAT),
        ).withColumn(
            "end_date", to_timestamp(col("end_date_str"), self._TIMESTAMP_FORMAT)
        )

        validation_result = transformed.filter(col("end_date") > col("start_date"))

        count = validation_result.count()
        assert count == 2, f"Expected 2 valid rows, got {count}"
        rows = validation_result.collect()
        assert len(rows) == 2
        assert {row["record_id"] for row in rows} == {"r1", "r2"}
        for row in rows:
            assert isinstance(row["start_date"], datetime)
            assert isinstance(row["end_date"], datetime)
            assert row["end_date"] > row["start_date"]

    def test_validation_with_datetime_after_column_rename(self):
        """A parsed timestamp column stays usable after unrelated renames and drops."""
        data = [("inv1", "2024-01-15T10:30:00")]
        df = self.spark.createDataFrame(data, ["inventory_id", "snapshot_date"])
        transformed = (
            df.withColumn(
                "snapshot_date_parsed",
                to_timestamp(col("snapshot_date"), self._TIMESTAMP_FORMAT),
            )
            .withColumnRenamed("inventory_id", "id")
            .drop("snapshot_date")
        )

        validation_result = transformed.filter(
            col("snapshot_date_parsed").isNotNull()
        )

        # The rename and the drop must actually have taken effect.
        assert validation_result.columns == ["id", "snapshot_date_parsed"]
        count = validation_result.count()
        assert count == 1, f"Expected 1 valid row, got {count}"
        rows = validation_result.collect()
        assert len(rows) == 1
        assert rows[0]["id"] == "inv1"
        assert isinstance(rows[0]["snapshot_date_parsed"], datetime)
        assert rows[0]["snapshot_date_parsed"] == datetime(2024, 1, 15, 10, 30, 0)