from tests.fixtures.spark_imports import get_spark_imports
# Obtain the Spark symbols through the shared test-fixture helper instead of
# importing them directly; each module-level name below is simply an attribute
# of the object the helper returns.
imports = get_spark_imports()
(
    SparkSession,
    StringType,
    IntegerType,
    DoubleType,
    StructType,
    StructField,
    F,
) = (
    imports.SparkSession,
    imports.StringType,
    imports.IntegerType,
    imports.DoubleType,
    imports.StructType,
    imports.StructField,
    imports.F,
)
class TestStringArithmetic:
    """Arithmetic operators applied to StringType columns.

    Each test builds a small DataFrame whose columns are declared as strings and
    applies ``+``, ``-``, ``*``, ``/``, ``%`` (alone or in larger expressions).
    The assertions expect numeric-looking strings (including integer-looking
    ones such as ``"20"``) to be coerced to double, and unparseable, empty or
    null strings to produce null rather than raise.

    ``spark`` is a pytest fixture defined outside this file (presumably a
    SparkSession-like object from a conftest — confirm).
    """

    def test_string_division_by_numeric_literal(self, spark):
        """string column / int literal divides the parsed numeric value."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "10.0"},
                {"string_1": "20"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("string_1") / 5)
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["result"] == 2.0
        assert rows[1]["result"] == 4.0

    def test_numeric_literal_divided_by_string(self, spark):
        """lit / string column works with the string on the right-hand side."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "10.0"},
                {"string_1": "5"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.lit(100) / F.col("string_1"))
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["result"] == 10.0
        assert rows[1]["result"] == 20.0

    def test_string_addition_with_numeric(self, spark):
        """string column + int literal adds to the parsed value."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "10.5"},
                {"string_1": "20"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("string_1") + 5)
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["result"] == 15.5
        assert rows[1]["result"] == 25.0

    def test_string_subtraction_with_numeric(self, spark):
        """string column - int literal subtracts from the parsed value."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "10.5"},
                {"string_1": "20"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("string_1") - 3)
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["result"] == 7.5
        assert rows[1]["result"] == 17.0

    def test_string_multiplication_with_numeric(self, spark):
        """string column * int literal multiplies the parsed value."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "10.5"},
                {"string_1": "20"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("string_1") * 2)
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["result"] == 21.0
        assert rows[1]["result"] == 40.0

    def test_string_modulo_with_numeric(self, spark):
        """string column % int literal yields the remainder as a double."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "10"},
                {"string_1": "7"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("string_1") % 3)
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["result"] == 1.0
        assert rows[1]["result"] == 1.0

    def test_string_arithmetic_with_string_column(self, spark):
        """Binary ops between two string columns coerce both operands."""
        schema = StructType(
            [
                StructField("string_1", StringType(), True),
                StructField("string_2", StringType(), True),
            ]
        )
        df = spark.createDataFrame(
            [
                {"string_1": "10.5", "string_2": "2"},
                {"string_1": "20", "string_2": "4"},
            ],
            schema=schema,
        )

        result = df.withColumn("div", F.col("string_1") / F.col("string_2"))
        rows = result.collect()
        assert rows[0]["div"] == 5.25
        assert rows[1]["div"] == 5.0

        # "+" on two strings is expected to be numeric addition, not concatenation.
        result = df.withColumn("add", F.col("string_1") + F.col("string_2"))
        rows = result.collect()
        assert rows[0]["add"] == 12.5
        assert rows[1]["add"] == 24.0

        result = df.withColumn("mul", F.col("string_1") * F.col("string_2"))
        rows = result.collect()
        assert rows[0]["mul"] == 21.0
        assert rows[1]["mul"] == 80.0

    def test_string_arithmetic_with_invalid_strings(self, spark):
        """An unparseable string yields null instead of failing the query."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "10.0"},
                {"string_1": "invalid"},
                {"string_1": "20"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("string_1") / 5)
        rows = result.collect()
        assert len(rows) == 3
        assert rows[0]["result"] == 2.0
        assert rows[1]["result"] is None
        assert rows[2]["result"] == 4.0

    def test_string_arithmetic_with_null_strings(self, spark):
        """A null string input propagates to a null result."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "10.0"},
                {"string_1": None},
                {"string_1": "20"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("string_1") / 5)
        rows = result.collect()
        assert len(rows) == 3
        assert rows[0]["result"] == 2.0
        assert rows[1]["result"] is None
        assert rows[2]["result"] == 4.0

    def test_string_arithmetic_result_type(self, spark):
        """string / int literal produces a DoubleType column in the schema."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "10.0"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("string_1") / 5)
        result_field = next(f for f in result.schema.fields if f.name == "result")
        assert isinstance(result_field.dataType, DoubleType)

    def test_string_arithmetic_chained_operations(self, spark):
        """(a + b) * 3 - 5 over string columns evaluates numerically."""
        schema = StructType(
            [
                StructField("string_1", StringType(), True),
                StructField("string_2", StringType(), True),
            ]
        )
        df = spark.createDataFrame(
            [
                {"string_1": "10.0", "string_2": "2"},
            ],
            schema=schema,
        )
        result = df.withColumn(
            "result", (F.col("string_1") + F.col("string_2")) * 3 - 5
        )
        rows = result.collect()
        # (10 + 2) * 3 - 5 == 31
        assert rows[0]["result"] == 31.0

    def test_string_arithmetic_with_integer_strings(self, spark):
        """Integer-looking strings divided by a float literal give doubles."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "10"},
                {"string_1": "25"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("string_1") / 2.5)
        rows = result.collect()
        assert rows[0]["result"] == 4.0
        assert rows[1]["result"] == 10.0

    def test_string_arithmetic_with_float_strings(self, spark):
        """Fractional strings are parsed with their decimal part intact."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "10.5"},
                {"string_1": "25.75"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("string_1") * 2)
        rows = result.collect()
        assert rows[0]["result"] == 21.0
        assert rows[1]["result"] == 51.5

    def test_string_arithmetic_division_by_zero(self, spark):
        """Division by zero must yield null or +/-infinity.

        Both outcomes are accepted; the test does not pin one down.
        """
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "10.0"},
                {"string_1": "20"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("string_1") / 0)
        rows = result.collect()
        result0 = rows[0]["result"]
        result1 = rows[1]["result"]
        assert result0 is None or (
            isinstance(result0, float)
            and (result0 == float("inf") or result0 == float("-inf"))
        )
        assert result1 is None or (
            isinstance(result1, float)
            and (result1 == float("inf") or result1 == float("-inf"))
        )

    def test_string_arithmetic_with_negative_numbers(self, spark):
        """A leading minus sign in the string is honoured."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "-10.5"},
                {"string_1": "20"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("string_1") + 5)
        rows = result.collect()
        assert rows[0]["result"] == -5.5
        assert rows[1]["result"] == 25.0

    def test_string_arithmetic_with_scientific_notation(self, spark):
        """Exponent notation ("1e2", "2.5e1") is parsed as a number."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "1e2"},
                {"string_1": "2.5e1"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("string_1") / 5)
        rows = result.collect()
        assert rows[0]["result"] == 20.0
        assert rows[1]["result"] == 5.0

    def test_string_arithmetic_with_empty_strings(self, spark):
        """An empty string is treated as unparseable and yields null."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": ""},
                {"string_1": "10.0"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("string_1") / 5)
        rows = result.collect()
        assert rows[0]["result"] is None
        assert rows[1]["result"] == 2.0

    def test_string_arithmetic_with_whitespace(self, spark):
        """Whitespace-padded numbers may be trimmed or rejected.

        Either 21.0 (padding trimmed) or null (treated as unparseable) is
        accepted for the padded row; the unpadded row must parse normally.
        """
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": " 10.5 "},
                {"string_1": "20"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("string_1") * 2)
        rows = result.collect()
        assert rows[0]["result"] == 21.0 or rows[0]["result"] is None
        assert rows[1]["result"] == 40.0

    def test_string_arithmetic_with_very_large_numbers(self, spark):
        """Large magnitudes survive parsing within a small tolerance."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "1e10"},
                {"string_1": "999999999999.99"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("string_1") / 1000)
        rows = result.collect()
        assert rows[0]["result"] == 10000000.0
        # Tolerance absorbs double rounding of the 15-significant-digit input.
        assert abs(rows[1]["result"] - 999999999.999) < 0.01

    def test_string_arithmetic_in_select(self, spark):
        """String arithmetic works inside select() with an alias."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "10.0"},
                {"string_1": "20"},
            ],
            schema=schema,
        )
        result = df.select((F.col("string_1") / 5).alias("result"))
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["result"] == 2.0
        assert rows[1]["result"] == 4.0

    def test_string_arithmetic_with_filter(self, spark):
        """A filter predicate can compare a string-derived arithmetic result."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "10.0"},
                {"string_1": "20"},
                {"string_1": "30"},
            ],
            schema=schema,
        )
        # 10/5 = 2 is dropped; 20/5 = 4 and 30/5 = 6 pass "> 3".
        result = df.filter(F.col("string_1") / 5 > 3)
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["string_1"] == "20"
        assert rows[1]["string_1"] == "30"

    def test_string_arithmetic_mixed_with_numeric_column(self, spark):
        """string column * IntegerType column multiplies numerically."""
        schema = StructType(
            [
                StructField("string_1", StringType(), True),
                StructField("numeric_1", IntegerType(), True),
            ]
        )
        df = spark.createDataFrame(
            [
                {"string_1": "10.5", "numeric_1": 2},
                {"string_1": "20", "numeric_1": 3},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("string_1") * F.col("numeric_1"))
        rows = result.collect()
        assert rows[0]["result"] == 21.0
        assert rows[1]["result"] == 60.0

    def test_string_arithmetic_with_when_otherwise(self, spark):
        """String arithmetic is usable in when() conditions and branches."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "10.0"},
                {"string_1": "20"},
            ],
            schema=schema,
        )
        # 10: 10/5 = 2 is not > 3 -> otherwise 10/2 = 5.
        # 20: 20/5 = 4 > 3        -> when     20*2 = 40.
        result = df.withColumn(
            "result",
            F.when(F.col("string_1") / 5 > 3, F.col("string_1") * 2).otherwise(
                F.col("string_1") / 2
            ),
        )
        rows = result.collect()
        assert rows[0]["result"] == 5.0
        assert rows[1]["result"] == 40.0

    def test_string_arithmetic_chained_with_cast(self, spark):
        """Casting (string / 2) to int truncates 5.25 to 5."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "10.5"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", (F.col("string_1") / 2).cast("int"))
        rows = result.collect()
        assert rows[0]["result"] == 5

    def test_string_arithmetic_all_operations_comprehensive(self, spark):
        """All five operators applied to one string value in a single plan."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "12.5"},
            ],
            schema=schema,
        )
        result = (
            df.withColumn("add", F.col("string_1") + 3)
            .withColumn("sub", F.col("string_1") - 3)
            .withColumn("mul", F.col("string_1") * 2)
            .withColumn("div", F.col("string_1") / 2)
            .withColumn("mod", F.col("string_1") % 5)
        )
        rows = result.collect()
        assert rows[0]["add"] == 15.5
        assert rows[0]["sub"] == 9.5
        assert rows[0]["mul"] == 25.0
        assert rows[0]["div"] == 6.25
        assert rows[0]["mod"] == 2.5

    def test_string_arithmetic_with_zero_string(self, spark):
        """"0" and "0.0" both parse to zero."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "0"},
                {"string_1": "0.0"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("string_1") * 10)
        rows = result.collect()
        assert rows[0]["result"] == 0.0
        assert rows[1]["result"] == 0.0

    def test_string_arithmetic_decimal_precision(self, spark):
        """Adding a float literal keeps results within float tolerance."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "0.1"},
                {"string_1": "0.2"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("string_1") + F.lit(0.3))
        rows = result.collect()
        # Compared with a tolerance: 0.1 + 0.3 is not exactly 0.4 in binary floating point.
        assert abs(rows[0]["result"] - 0.4) < 0.0001
        assert abs(rows[1]["result"] - 0.5) < 0.0001

    def test_string_arithmetic_with_negative_zero(self, spark):
        """"-0" and "-0.0" behave as zero in addition."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "-0"},
                {"string_1": "-0.0"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("string_1") + 5)
        rows = result.collect()
        assert rows[0]["result"] == 5.0
        assert rows[1]["result"] == 5.0

    def test_string_arithmetic_complex_expression(self, spark):
        """Nested expression over two string columns evaluates numerically."""
        schema = StructType(
            [
                StructField("string_1", StringType(), True),
                StructField("string_2", StringType(), True),
            ]
        )
        df = spark.createDataFrame(
            [
                {"string_1": "10", "string_2": "5"},
            ],
            schema=schema,
        )
        # ((10 + 5) * 2 - 10) / 5 == 4
        result = df.withColumn(
            "result",
            ((F.col("string_1") + F.col("string_2")) * 2 - F.col("string_1"))
            / F.col("string_2"),
        )
        rows = result.collect()
        assert rows[0]["result"] == 4.0

    def test_string_arithmetic_with_orderby(self, spark):
        """orderBy accepts an arithmetic expression over a string column."""
        schema = StructType([StructField("string_1", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_1": "30"},
                {"string_1": "10"},
                {"string_1": "20"},
            ],
            schema=schema,
        )
        result = df.orderBy(F.col("string_1") / 10)
        rows = result.collect()
        assert len(rows) == 3
        assert rows[0]["string_1"] == "10"
        assert rows[1]["string_1"] == "20"
        assert rows[2]["string_1"] == "30"

    def test_string_arithmetic_with_groupby_aggregation(self, spark):
        """sum() over a string column aggregates the parsed numeric values."""
        schema = StructType(
            [
                StructField("category", StringType(), True),
                StructField("string_value", StringType(), True),
            ]
        )
        df = spark.createDataFrame(
            [
                {"category": "A", "string_value": "10"},
                {"category": "A", "string_value": "20"},
                {"category": "B", "string_value": "30"},
            ],
            schema=schema,
        )
        result = df.groupBy("category").agg(
            F.sum(F.col("string_value")).alias("total")
        )
        rows = result.collect()
        assert len(rows) == 2
        # Group order is not guaranteed, so look rows up by key.
        row_a = next(r for r in rows if r["category"] == "A")
        row_b = next(r for r in rows if r["category"] == "B")
        assert row_a["total"] == 30.0
        assert row_b["total"] == 30.0