from __future__ import annotations
from tests.fixtures.spark_imports import get_spark_imports
# Resolve the Spark API shim once at module import time; the tests below use
# its ``F`` attribute for column functions (``F.col``, ``F.lit``).
# NOTE(review): presumably get_spark_imports() selects the backend under test
# (the "robin" engine named in the test names) — confirm in tests.fixtures.
_imports = get_spark_imports()
F = _imports.F
def test_string_division_by_numeric_literal_robin(spark) -> None:
    """Divide a string column, cast to double, by an integer literal.

    Results are looked up by their input value rather than by row position,
    so the assertions do not depend on the order in which ``collect()``
    returns rows, and each result is checked against the input it came from.
    """
    df = spark.createDataFrame(
        [{"string_1": "10.0"}, {"string_1": "20"}],
    )
    result = df.withColumn("result", F.col("string_1").cast("double") / 5)
    rows = result.collect()
    assert len(rows) == 2
    # Inputs are distinct, so keying by input loses no rows.
    by_input = {row["string_1"]: row["result"] for row in rows}
    assert by_input == {"10.0": 2.0, "20": 4.0}
def test_numeric_literal_divided_by_string_robin(spark) -> None:
    """Divide an integer literal by a string column cast to double.

    Results are looked up by their input value rather than by row position,
    so the assertions do not depend on the order in which ``collect()``
    returns rows, and each result is checked against the input it came from.
    """
    df = spark.createDataFrame(
        [{"string_1": "10.0"}, {"string_1": "5"}],
    )
    result = df.withColumn("result", F.lit(100) / F.col("string_1").cast("double"))
    rows = result.collect()
    assert len(rows) == 2
    # Inputs are distinct, so keying by input loses no rows.
    by_input = {row["string_1"]: row["result"] for row in rows}
    assert by_input == {"10.0": 10.0, "5": 20.0}
def test_string_arithmetic_with_invalid_strings_robin(spark) -> None:
    """Non-numeric strings cast to double become null and stay null through division.

    Valid numeric strings are divided normally; the non-numeric value
    ``"invalid"`` is expected to yield ``None``. Results are looked up by
    input value so the assertions do not depend on ``collect()`` row order.

    NOTE(review): the ``None`` expectation relies on non-ANSI cast semantics;
    with Spark's ANSI mode enabled a malformed cast raises instead — confirm
    the backend/config under test.
    """
    df = spark.createDataFrame(
        [{"string_1": "10.0"}, {"string_1": "invalid"}, {"string_1": "20"}],
    )
    result = df.withColumn(
        "result",
        F.col("string_1").cast("double") / 5,
    )
    rows = result.collect()
    assert len(rows) == 3
    # Inputs are distinct, so keying by input loses no rows.
    by_input = {row["string_1"]: row["result"] for row in rows}
    assert by_input == {"10.0": 2.0, "invalid": None, "20": 4.0}