from tests.fixtures.spark_imports import get_spark_imports
# Bind the Spark names used by the tests below to module-level aliases.
# NOTE(review): get_spark_imports() presumably selects which Spark backend the
# suite runs against — confirm in tests/fixtures/spark_imports.
imports = get_spark_imports()
SparkSession = imports.SparkSession  # not referenced by the tests visible in this file
DoubleType = imports.DoubleType
IntegerType = imports.IntegerType
StringType = imports.StringType
StructType = imports.StructType
StructField = imports.StructField
F = imports.F  # functions namespace; the tests call F.col, F.sum and F.when
class TestChainedArithmetic:
    """Tests for literal-on-the-left ("reverse") and chained column arithmetic.

    Every test builds a small DataFrame with an explicit schema, applies an
    expression in which a Python number appears on the left of an arithmetic
    operator (``2 * F.col("x")``) and/or several operators are chained, and
    then checks the collected rows.

    Each test receives a ``spark`` session object that is supplied from
    outside this file (presumably a pytest fixture).
    """

    def test_reverse_multiplication(self, spark):
        """``2 * col`` doubles every value of the column."""
        schema = StructType([StructField("number_2", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"number_2": 1.0},
                {"number_2": 2.0},
                {"number_2": 3.0},
            ],
            schema=schema,
        )
        result = df.withColumn("result", 2 * F.col("number_2"))
        rows = result.collect()
        assert len(rows) == 3
        assert rows[0]["result"] == 2.0
        assert rows[1]["result"] == 4.0
        assert rows[2]["result"] == 6.0

    def test_reverse_addition(self, spark):
        """``2 + col`` adds the literal to every value of the column."""
        schema = StructType([StructField("number_2", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"number_2": 1.0},
                {"number_2": 2.0},
            ],
            schema=schema,
        )
        result = df.withColumn("result", 2 + F.col("number_2"))
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["result"] == 3.0
        assert rows[1]["result"] == 4.0

    def test_reverse_subtraction(self, spark):
        """``2 - col`` subtracts the column from the literal (not the reverse)."""
        schema = StructType([StructField("number_2", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"number_2": 1.0},
                {"number_2": 2.0},
            ],
            schema=schema,
        )
        result = df.withColumn("result", 2 - F.col("number_2"))
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["result"] == 1.0
        assert rows[1]["result"] == 0.0

    def test_reverse_division(self, spark):
        """``2 / col`` divides the literal by the column (not the reverse)."""
        schema = StructType([StructField("number_2", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"number_2": 1.0},
                {"number_2": 2.0},
            ],
            schema=schema,
        )
        result = df.withColumn("result", 2 / F.col("number_2"))
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["result"] == 2.0
        assert rows[1]["result"] == 1.0

    def test_reverse_modulo(self, spark):
        """``2 % col`` takes the literal modulo the column (not the reverse)."""
        schema = StructType([StructField("number_2", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"number_2": 3.0},
                {"number_2": 2.0},
            ],
            schema=schema,
        )
        result = df.withColumn("result", 2 % F.col("number_2"))
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["result"] == 2.0
        assert rows[1]["result"] == 0.0

    def test_chained_arithmetic_issue_237_example(self, spark):
        """``col + 2 * col + 0.01`` evaluates with normal operator precedence."""
        schema = StructType(
            [
                StructField("number_1", DoubleType(), True),
                StructField("number_2", DoubleType(), True),
            ]
        )
        df = spark.createDataFrame(
            [
                {"number_1": 1.0, "number_2": 1.0},
                {"number_1": 2.0, "number_2": 2.0},
                {"number_1": 3.0, "number_2": 3.0},
            ],
            schema=schema,
        )
        result = df.withColumn(
            "result", F.col("number_1") + 2 * F.col("number_2") + 0.01
        )
        rows = result.collect()
        assert len(rows) == 3
        assert rows[0]["result"] == 3.01
        assert rows[1]["result"] == 6.01
        assert rows[2]["result"] == 9.01

    def test_complex_chained_operations(self, spark):
        """``a * 3 - b / 2 + 1.5`` mixes four operators over two columns."""
        schema = StructType(
            [
                StructField("a", DoubleType(), True),
                StructField("b", DoubleType(), True),
            ]
        )
        df = spark.createDataFrame(
            [
                {"a": 10.0, "b": 5.0},
                {"a": 20.0, "b": 10.0},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("a") * 3 - F.col("b") / 2 + 1.5)
        rows = result.collect()
        assert len(rows) == 2
        # 30 - 2.5 + 1.5
        assert rows[0]["result"] == 29.0
        # 60 - 5 + 1.5
        assert rows[1]["result"] == 56.5

    def test_all_reverse_operations(self, spark):
        """All five reverse operators applied to the same column in one chain."""
        schema = StructType([StructField("col", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"col": 2.0},
            ],
            schema=schema,
        )
        result = (
            df.withColumn("add", 5 + F.col("col"))
            .withColumn("sub", 5 - F.col("col"))
            .withColumn("mul", 5 * F.col("col"))
            .withColumn("div", 5 / F.col("col"))
            .withColumn("mod", 5 % F.col("col"))
        )
        rows = result.collect()
        assert rows[0]["add"] == 7.0
        assert rows[0]["sub"] == 3.0
        assert rows[0]["mul"] == 10.0
        assert rows[0]["div"] == 2.5
        assert rows[0]["mod"] == 1.0

    def test_reverse_operations_with_integers(self, spark):
        """Integer literal times an integer-typed column."""
        schema = StructType([StructField("col", IntegerType(), True)])
        df = spark.createDataFrame(
            [
                {"col": 3},
            ],
            schema=schema,
        )
        result = df.withColumn("result", 10 * F.col("col"))
        rows = result.collect()
        assert rows[0]["result"] == 30

    def test_reverse_operations_with_floats(self, spark):
        """Float literal times a double-typed column."""
        schema = StructType([StructField("col", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"col": 2.0},
            ],
            schema=schema,
        )
        result = df.withColumn("result", 3.5 * F.col("col"))
        rows = result.collect()
        assert rows[0]["result"] == 7.0

    def test_nested_chained_operations(self, spark):
        """Three reverse products over three columns, chained with ``+``/``-``."""
        schema = StructType(
            [
                StructField("a", DoubleType(), True),
                StructField("b", DoubleType(), True),
                StructField("c", DoubleType(), True),
            ]
        )
        df = spark.createDataFrame(
            [
                {"a": 1.0, "b": 2.0, "c": 3.0},
            ],
            schema=schema,
        )
        result = df.withColumn(
            "result", 2 * F.col("a") + 3 * F.col("b") - 4 * F.col("c") + 1.0
        )
        rows = result.collect()
        # 2 + 6 - 12 + 1
        assert rows[0]["result"] == -3.0

    def test_reverse_operations_in_select(self, spark):
        """A reverse product can be aliased and used inside ``select``."""
        schema = StructType([StructField("col", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"col": 5.0},
            ],
            schema=schema,
        )
        result = df.select((10 * F.col("col")).alias("result"))
        rows = result.collect()
        assert rows[0]["result"] == 50.0

    def test_reverse_operations_in_filter(self, spark):
        """A reverse product can be compared inside ``filter``."""
        schema = StructType([StructField("value", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"value": 5.0},
                {"value": 10.0},
                {"value": 15.0},
            ],
            schema=schema,
        )
        result = df.filter(2 * F.col("value") > 10)
        rows = result.collect()
        # 2 * 5.0 == 10 is not strictly greater than 10, so that row is dropped.
        assert len(rows) == 2
        assert rows[0]["value"] == 10.0
        assert rows[1]["value"] == 15.0

    def test_mixed_forward_and_reverse_operations(self, spark):
        """``col * literal`` and ``literal * col`` combined in one expression."""
        schema = StructType(
            [
                StructField("a", DoubleType(), True),
                StructField("b", DoubleType(), True),
            ]
        )
        df = spark.createDataFrame(
            [
                {"a": 2.0, "b": 3.0},
            ],
            schema=schema,
        )
        result = df.withColumn("result", F.col("a") * 5 + 10 * F.col("b"))
        rows = result.collect()
        # 10 + 30
        assert rows[0]["result"] == 40.0

    def test_reverse_operations_with_null_values(self, spark):
        """A null column value yields a null result for ``10 * col``."""
        schema = StructType([StructField("col", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"col": 2.0},
                {"col": None},
                {"col": 4.0},
            ],
            schema=schema,
        )
        result = df.withColumn("result", 10 * F.col("col"))
        rows = result.collect()
        assert rows[0]["result"] == 20.0
        assert rows[1]["result"] is None
        assert rows[2]["result"] == 40.0

    def test_reverse_operations_with_negative_numbers(self, spark):
        """Negative literal on the left of a product."""
        schema = StructType([StructField("col", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"col": 5.0},
            ],
            schema=schema,
        )
        result = df.withColumn("result", -2 * F.col("col"))
        rows = result.collect()
        assert rows[0]["result"] == -10.0

    def test_reverse_operations_chained_with_arithmetic(self, spark):
        """Reverse products chained where one column appears more than once."""
        schema = StructType(
            [
                StructField("x", DoubleType(), True),
                StructField("y", DoubleType(), True),
            ]
        )
        df = spark.createDataFrame(
            [
                {"x": 1.0, "y": 2.0},
            ],
            schema=schema,
        )
        result = df.withColumn(
            "result", 2 * F.col("x") + 3 * F.col("y") - 4 * F.col("x") + 1.0
        )
        rows = result.collect()
        # 2 + 6 - 4 + 1
        assert rows[0]["result"] == 5.0

    def test_operator_precedence(self, spark):
        """``1 + 2 * col`` multiplies before adding."""
        schema = StructType([StructField("a", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"a": 2.0},
            ],
            schema=schema,
        )
        result = df.withColumn("result", 1 + 2 * F.col("a"))
        rows = result.collect()
        assert rows[0]["result"] == 5.0

    def test_operator_precedence_with_parentheses_equivalent(self, spark):
        """``(1 + 2) * col`` honours the parentheses."""
        schema = StructType([StructField("a", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"a": 2.0},
            ],
            schema=schema,
        )
        # The parenthesised literals are folded by Python to 3 before the
        # column expression is built.
        result = df.withColumn("result", (1 + 2) * F.col("a"))
        rows = result.collect()
        assert rows[0]["result"] == 6.0

    def test_all_operators_in_single_expression(self, spark):
        """``+``, ``*``, ``-``, ``/`` and ``%`` reverse forms in one expression."""
        schema = StructType([StructField("a", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"a": 3.0},
            ],
            schema=schema,
        )
        result = df.withColumn(
            "result", 10 + 2 * F.col("a") - 5 / F.col("a") + 3 % F.col("a")
        )
        rows = result.collect()
        # 10 + 6 - 5/3 + 0; compared with a tolerance because 5/3 is inexact.
        assert abs(rows[0]["result"] - 14.333333333333334) < 0.0001

    def test_reverse_operations_with_zero(self, spark):
        """Zero as the left-hand literal for ``*``, ``+`` and ``-``."""
        schema = StructType([StructField("col", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"col": 5.0},
            ],
            schema=schema,
        )
        result = (
            df.withColumn("mul", 0 * F.col("col"))
            .withColumn("add", 0 + F.col("col"))
            .withColumn("sub", 0 - F.col("col"))
        )
        rows = result.collect()
        assert rows[0]["mul"] == 0.0
        assert rows[0]["add"] == 5.0
        assert rows[0]["sub"] == -5.0

    def test_reverse_operations_with_one(self, spark):
        """One as the left-hand literal for ``*`` and ``/``."""
        schema = StructType([StructField("col", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"col": 5.0},
            ],
            schema=schema,
        )
        result = df.withColumn("mul", 1 * F.col("col")).withColumn(
            "div", 1 / F.col("col")
        )
        rows = result.collect()
        assert rows[0]["mul"] == 5.0
        assert rows[0]["div"] == 0.2

    def test_reverse_operations_with_negative_literals(self, spark):
        """Negative left-hand literals for ``*``, ``+`` and ``-``."""
        schema = StructType([StructField("col", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"col": 5.0},
            ],
            schema=schema,
        )
        result = (
            df.withColumn("mul", -2 * F.col("col"))
            .withColumn("add", -3 + F.col("col"))
            .withColumn("sub", -4 - F.col("col"))
        )
        rows = result.collect()
        assert rows[0]["mul"] == -10.0
        assert rows[0]["add"] == 2.0
        assert rows[0]["sub"] == -9.0

    def test_reverse_operations_with_decimal_literals(self, spark):
        """Fractional left-hand literals for ``*``, ``/`` and ``+``."""
        schema = StructType([StructField("col", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"col": 4.0},
            ],
            schema=schema,
        )
        result = (
            df.withColumn("mul", 0.5 * F.col("col"))
            .withColumn("div", 0.5 / F.col("col"))
            .withColumn("add", 0.5 + F.col("col"))
        )
        rows = result.collect()
        assert rows[0]["mul"] == 2.0
        assert rows[0]["div"] == 0.125
        assert rows[0]["add"] == 4.5

    def test_chained_operations_with_mixed_types(self, spark):
        """Integer and float literals as left operands in the same expression."""
        schema = StructType([StructField("col", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"col": 2.0},
            ],
            schema=schema,
        )
        result = df.withColumn("result", 3 * F.col("col") + 1.5 * F.col("col"))
        rows = result.collect()
        # 6 + 3
        assert rows[0]["result"] == 9.0

    def test_very_long_chained_expression(self, spark):
        """Four reverse products of the same column chained after a literal."""
        schema = StructType([StructField("a", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"a": 2.0},
            ],
            schema=schema,
        )
        result = df.withColumn(
            "result",
            1 + 2 * F.col("a") + 3 * F.col("a") - 4 * F.col("a") + 5 * F.col("a"),
        )
        rows = result.collect()
        # 1 + 4 + 6 - 8 + 10
        assert rows[0]["result"] == 13.0

    def test_reverse_operations_in_orderby(self, spark):
        """A reverse product can be used as the ``orderBy`` key."""
        schema = StructType([StructField("value", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"value": 3.0},
                {"value": 1.0},
                {"value": 2.0},
            ],
            schema=schema,
        )
        result = df.orderBy(2 * F.col("value"))
        rows = result.collect()
        assert len(rows) == 3
        assert rows[0]["value"] == 1.0
        assert rows[1]["value"] == 2.0
        assert rows[2]["value"] == 3.0

    def test_reverse_operations_in_groupby_aggregation(self, spark):
        """A reverse product can be the argument of ``F.sum`` in an aggregation."""
        schema = StructType(
            [
                StructField("category", StringType(), True),
                StructField("value", DoubleType(), True),
            ]
        )
        df = spark.createDataFrame(
            [
                {"category": "A", "value": 2.0},
                {"category": "A", "value": 3.0},
                {"category": "B", "value": 4.0},
            ],
            schema=schema,
        )
        result = df.groupBy("category").agg(F.sum(2 * F.col("value")).alias("total"))
        rows = result.collect()
        assert len(rows) == 2
        # Look rows up by category rather than relying on group output order.
        row_a = next(r for r in rows if r["category"] == "A")
        row_b = next(r for r in rows if r["category"] == "B")
        assert row_a["total"] == 10.0
        assert row_b["total"] == 8.0

    def test_reverse_operations_with_when_otherwise(self, spark):
        """Reverse products in the condition and both branches of ``F.when``."""
        schema = StructType([StructField("value", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"value": 5.0},
                {"value": 15.0},
            ],
            schema=schema,
        )
        result = df.withColumn(
            "result",
            F.when(2 * F.col("value") > 10, 3 * F.col("value")).otherwise(
                1 * F.col("value")
            ),
        )
        rows = result.collect()
        # 2 * 5 is not > 10 -> otherwise branch; 2 * 15 > 10 -> 3 * 15.
        assert rows[0]["result"] == 5.0
        assert rows[1]["result"] == 45.0

    def test_reverse_operations_with_cast(self, spark):
        """The result of a reverse product can be cast to ``int``."""
        schema = StructType([StructField("col", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"col": 2.5},
            ],
            schema=schema,
        )
        result = df.withColumn("result", (2 * F.col("col")).cast("int"))
        rows = result.collect()
        assert rows[0]["result"] == 5

    def test_reverse_operations_with_string_columns(self, spark):
        """``2 * col`` on a string column holding numeric text gives numbers."""
        schema = StructType([StructField("string_col", StringType(), True)])
        df = spark.createDataFrame(
            [
                {"string_col": "10.0"},
                {"string_col": "20"},
            ],
            schema=schema,
        )
        result = df.withColumn("result", 2 * F.col("string_col"))
        rows = result.collect()
        assert rows[0]["result"] == 20.0
        assert rows[1]["result"] == 40.0

    def test_reverse_operations_division_by_zero(self, spark):
        """``10 / col`` is expected to be null where the column is zero."""
        schema = StructType([StructField("col", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"col": 0.0},
                {"col": 5.0},
            ],
            schema=schema,
        )
        result = df.withColumn("result", 10 / F.col("col"))
        rows = result.collect()
        assert rows[0]["result"] is None
        assert rows[1]["result"] == 2.0

    def test_reverse_operations_modulo_by_zero(self, spark):
        """``10 % col`` is expected to be null where the column is zero."""
        schema = StructType([StructField("col", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"col": 0.0},
                {"col": 3.0},
            ],
            schema=schema,
        )
        result = df.withColumn("result", 10 % F.col("col"))
        rows = result.collect()
        assert rows[0]["result"] is None
        assert rows[1]["result"] == 1.0

    def test_reverse_operations_with_null_literals(self, spark):
        """A null column value yields a null result for ``2 * col``."""
        schema = StructType([StructField("col", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"col": 5.0},
                {"col": None},
                {"col": 10.0},
            ],
            schema=schema,
        )
        result = df.withColumn("result", 2 * F.col("col"))
        rows = result.collect()
        assert rows[0]["result"] == 10.0
        assert rows[1]["result"] is None
        assert rows[2]["result"] == 20.0

    def test_very_complex_nested_expression(self, spark):
        """Six reverse products over three columns chained with ``+``/``-``."""
        schema = StructType(
            [
                StructField("a", DoubleType(), True),
                StructField("b", DoubleType(), True),
                StructField("c", DoubleType(), True),
            ]
        )
        df = spark.createDataFrame(
            [
                {"a": 1.0, "b": 2.0, "c": 3.0},
            ],
            schema=schema,
        )
        result = df.withColumn(
            "result",
            2 * F.col("a")
            + 3 * F.col("b")
            - 4 * F.col("c")
            + 5 * F.col("a")
            - 6 * F.col("b")
            + 7 * F.col("c"),
        )
        rows = result.collect()
        # 2 + 6 - 12 + 5 - 12 + 21
        assert rows[0]["result"] == 10.0

    def test_reverse_operations_with_select_multiple_columns(self, spark):
        """Several aliased reverse expressions in a single ``select``."""
        schema = StructType([StructField("col", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"col": 3.0},
            ],
            schema=schema,
        )
        result = df.select(
            (2 * F.col("col")).alias("double"),
            (3 + F.col("col")).alias("add"),
            (10 - F.col("col")).alias("sub"),
        )
        rows = result.collect()
        assert rows[0]["double"] == 6.0
        assert rows[0]["add"] == 6.0
        assert rows[0]["sub"] == 7.0

    def test_reverse_operations_preserve_precision(self, spark):
        """``3 * 0.1`` stays within 1e-4 of 0.3."""
        schema = StructType([StructField("col", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"col": 0.1},
            ],
            schema=schema,
        )
        result = df.withColumn("result", 3 * F.col("col"))
        rows = result.collect()
        # Tolerance comparison: 3 * 0.1 is not exactly 0.3 in binary floats.
        assert abs(rows[0]["result"] - 0.3) < 0.0001

    def test_reverse_operations_with_large_numbers(self, spark):
        """Reverse product with a column value of one million."""
        schema = StructType([StructField("col", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"col": 1000000.0},
            ],
            schema=schema,
        )
        result = df.withColumn("result", 2 * F.col("col"))
        rows = result.collect()
        assert rows[0]["result"] == 2000000.0

    def test_reverse_operations_with_small_numbers(self, spark):
        """``1000 * 0.0001`` stays within 1e-4 of 0.1."""
        schema = StructType([StructField("col", DoubleType(), True)])
        df = spark.createDataFrame(
            [
                {"col": 0.0001},
            ],
            schema=schema,
        )
        result = df.withColumn("result", 1000 * F.col("col"))
        rows = result.collect()
        assert abs(rows[0]["result"] - 0.1) < 0.0001