from sparkless.testing import get_imports
_imports = get_imports()
SparkSession = _imports.SparkSession
F = _imports.F
col = F.col
to_date = F.to_date
datediff = F.datediff
current_date = F.current_date
floor = F.floor
class TestIssue137DatetimeValidation:
    """Regression tests for issue 137: filters over ``to_date``-derived columns.

    Every test parses ``date_of_birth`` strings (``yyyy-MM-dd``) into a
    ``birth_date`` column, optionally derives an ``age`` column from it, and
    checks that validation-style filters keep the expected number of rows.
    """

    @staticmethod
    def _open_session():
        """Return the session every test in this class uses (app ``BugRepro``)."""
        return SparkSession.builder.appName("BugRepro").getOrCreate()

    @staticmethod
    def _with_birth_date(frame):
        """Return *frame* plus ``birth_date`` parsed from ``date_of_birth``."""
        parsed = to_date(col("date_of_birth"), "yyyy-MM-dd")
        return frame.withColumn("birth_date", parsed)

    @staticmethod
    def _with_age(frame):
        """Return *frame* plus ``age`` = floor(days since ``birth_date`` / 365.25)."""
        # Dividing by 365.25 rather than 365 averages in leap days.
        days_since_birth = datediff(current_date(), col("birth_date"))
        return frame.withColumn("age", floor(days_since_birth / 365.25))

    def test_datetime_validation_with_age_calculation(self):
        """One parsed birth date yields a non-null, non-negative age that passes the filter."""
        spark = self._open_session()
        try:
            patients = spark.createDataFrame(
                [("p1", "John", "1990-01-15")],
                ["patient_id", "first_name", "date_of_birth"],
            )
            enriched = self._with_age(self._with_birth_date(patients))

            has_id = col("patient_id").isNotNull()
            has_age = col("age").isNotNull()
            age_ok = col("age") >= 0
            valid = enriched.filter(has_id & has_age & age_ok)

            n_valid = valid.count()
            assert n_valid == 1, f"Expected 1 valid row, got {n_valid}"

            collected = valid.collect()
            assert len(collected) == 1
            first = collected[0]
            assert first["patient_id"] == "p1"
            assert first["age"] is not None
            assert first["age"] >= 0
        finally:
            # Release the session even when an assertion above fails.
            spark.stop()

    def test_datetime_validation_simple_filter(self):
        """Both well-formed date strings parse to non-null ``birth_date`` values."""
        spark = self._open_session()
        try:
            patients = spark.createDataFrame(
                [("p1", "1990-01-15"), ("p2", "1985-05-20")],
                ["patient_id", "date_of_birth"],
            )
            parsed = self._with_birth_date(patients)
            non_null = parsed.filter(col("birth_date").isNotNull())

            n_valid = non_null.count()
            assert n_valid == 2, f"Expected 2 valid rows, got {n_valid}"
        finally:
            spark.stop()

    def test_datetime_validation_with_multiple_conditions(self):
        """Null checks on id, birth date and age plus ``age > 0`` keep both rows."""
        spark = self._open_session()
        try:
            patients = spark.createDataFrame(
                [("p1", "John", "1990-01-15"), ("p2", "Jane", "1985-05-20")],
                ["patient_id", "first_name", "date_of_birth"],
            )
            enriched = self._with_age(self._with_birth_date(patients))

            # Strictly positive age here, unlike the ``>= 0`` check in the
            # age-calculation test.
            condition = (
                col("patient_id").isNotNull()
                & col("birth_date").isNotNull()
                & col("age").isNotNull()
                & (col("age") > 0)
            )

            n_valid = enriched.filter(condition).count()
            assert n_valid == 2, f"Expected 2 valid rows, got {n_valid}"
        finally:
            spark.stop()