from sparkless.sql import SparkSession, functions as F
def test_datediff_returns_integer_type_not_long():
    """Verify that ``F.datediff`` yields an ``int`` column rather than ``long``.

    Builds a two-row frame of date strings (the second row has ``None`` in
    column ``a``), selects the day difference aliased as ``dd``, and checks
    both the schema's simple string and the collected values: ``8`` for the
    first row and ``None`` for the row with the null operand.

    NOTE(review): the app name suggests this is a regression test for issue
    1404 -- confirm against the tracker.
    """
    session = SparkSession.builder.appName("issue_1404").getOrCreate()
    try:
        date_pairs = [("2020-01-10", "2020-01-02"), (None, "2020-01-02")]
        frame = session.createDataFrame(date_pairs, ["a", "b"])

        day_gap = F.datediff(F.col("a"), F.col("b")).alias("dd")
        diffed = frame.select(day_gap)

        # The column type is checked through the schema's compact string form.
        schema_text = diffed.schema.simpleString()
        assert "dd:int" in schema_text, (
            f"datediff result must be IntegerType (dd:int), got schema: {schema_text}"
        )
        assert "dd:long" not in schema_text, (
            f"datediff result must not be LongType (dd:long), got schema: {schema_text}"
        )

        # Values: a plain day count for the first row, null for the second.
        collected = diffed.collect()
        assert len(collected) == 2
        assert collected[0]["dd"] == 8
        assert collected[1]["dd"] is None
    finally:
        # Always release the session, even when an assertion above fails.
        session.stop()