from tests.fixtures.spark_imports import get_spark_imports
_imports = get_spark_imports()
SparkSession = _imports.SparkSession
col = _imports.F.col
to_timestamp = _imports.F.to_timestamp
to_date = _imports.F.to_date
StringType = _imports.StringType
def test_string_cast_schema_is_string_type():
    """Casting a column with ``.cast("string")`` must yield a StringType field.

    Builds a one-row DataFrame whose ``timestamp`` column holds a plain
    string, adds ``timestamp_str`` as a string cast of that column, and
    inspects the resulting schema.
    """
    session = SparkSession.builder.appName("BugRepro").getOrCreate()
    try:
        rows = [("user1", "2024-01-01 10:30:00", 100.0)]
        source = session.createDataFrame(rows, ["user_id", "timestamp", "value"])
        casted = source.withColumn("timestamp_str", col("timestamp").cast("string"))

        types_by_name = {field.name: field.dataType for field in casted.schema.fields}
        cast_type = types_by_name["timestamp_str"]
        assert isinstance(cast_type, StringType), (
            f"timestamp_str should be StringType in schema, got {type(cast_type).__name__}"
        )
    finally:
        # Stop the session even when the assertion fails.
        session.stop()
def test_string_cast_works_with_to_timestamp():
    """A string-cast column must be parseable by ``to_timestamp``/``to_date``.

    Casts ``timestamp`` to string as ``timestamp_str``, then derives
    ``event_date`` by parsing it with the ``yyyy-MM-dd HH:mm:ss`` pattern.

    Counting rows alone does not exercise the parse. The optimizer may prune
    the unused ``event_date`` projection before counting, so the expression
    would never be evaluated. The row is therefore collected and the parsed
    date is checked explicitly.
    """
    spark = SparkSession.builder.appName("BugRepro").getOrCreate()
    try:
        data = [("user1", "2024-01-01 10:30:00", 100.0)]
        df = spark.createDataFrame(data, ["user_id", "timestamp", "value"])
        transformed = df.withColumn(
            "timestamp_str",
            col("timestamp").cast("string"),
        )
        result = transformed.withColumn(
            "event_date",
            to_date(to_timestamp(col("timestamp_str"), "yyyy-MM-dd HH:mm:ss")),
        )

        # collect() forces event_date to be computed, unlike count().
        rows = result.collect()
        assert len(rows) == 1, "Should have 1 row"

        event_date = rows[0]["event_date"]
        # A failed parse shows up as NULL rather than raising.
        assert event_date is not None, "event_date should be parsed, got None"
        # str() tolerates date / datetime / string representations across backends.
        assert str(event_date).startswith("2024-01-01"), (
            f"event_date should be 2024-01-01, got {event_date!r}"
        )
    finally:
        # Stop the session even when an assertion fails.
        spark.stop()
def test_string_cast_from_datetime_column():
    """Casting a ``datetime``-sourced column to string must produce StringType.

    Unlike the string-input tests, the ``timestamp`` column here is built
    from a Python ``datetime`` value. The test checks the schema type of the
    cast column and that the DataFrame still counts exactly one row.
    """
    from datetime import datetime

    session = SparkSession.builder.appName("BugRepro").getOrCreate()
    try:
        record = ("user1", datetime(2024, 1, 1, 10, 30, 0), 100.0)
        frame = session.createDataFrame([record], ["user_id", "timestamp", "value"])
        casted = frame.withColumn("timestamp_str", col("timestamp").cast("string"))

        dtype = {f.name: f.dataType for f in casted.schema.fields}["timestamp_str"]
        assert isinstance(dtype, StringType), (
            f"timestamp_str should be StringType, got {type(dtype).__name__}"
        )

        assert casted.count() == 1, "Should have 1 row"
    finally:
        # Stop the session even when an assertion fails.
        session.stop()