from sparkless.testing import get_imports
_imports = get_imports()
SparkSession = _imports.SparkSession
F = _imports.F
T = _imports
class TestIssue290UdfMultipleArguments:
def test_udf_two_arguments(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"Name": "Alice", "Value1": 1, "Value2": 2},
{"Name": "Bob", "Value1": 2, "Value2": 3},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(lambda x, y: x + y, T.IntegerType())
result = df.withColumn(
"FinalValue", my_udf(F.col("Value1"), F.col("Value2"))
)
rows = result.collect()
assert len(rows) == 2
alice_row = next((r for r in rows if r["Name"] == "Alice"), None)
assert alice_row is not None
assert alice_row["FinalValue"] == 3
bob_row = next((r for r in rows if r["Name"] == "Bob"), None)
assert bob_row is not None
assert bob_row["FinalValue"] == 5
finally:
spark.stop()
def test_udf_two_arguments_string_names(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"Name": "Alice", "Value1": 1, "Value2": 2},
{"Name": "Bob", "Value1": 2, "Value2": 3},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(lambda x, y: x + y, T.IntegerType())
result = df.withColumn("FinalValue", my_udf("Value1", "Value2"))
rows = result.collect()
assert len(rows) == 2
alice_row = next((r for r in rows if r["Name"] == "Alice"), None)
assert alice_row is not None
assert alice_row["FinalValue"] == 3
finally:
spark.stop()
def test_udf_three_arguments(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"a": 1, "b": 2, "c": 3},
{"a": 4, "b": 5, "c": 6},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(lambda x, y, z: x + y + z, T.IntegerType())
result = df.withColumn("sum", my_udf("a", "b", "c"))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["sum"] == 6
assert rows[1]["sum"] == 15
finally:
spark.stop()
def test_udf_multiply_arguments(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"x": 3, "y": 4},
{"x": 5, "y": 6},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(lambda a, b: a * b, T.IntegerType())
result = df.withColumn("product", my_udf("x", "y"))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["product"] == 12
assert rows[1]["product"] == 30
finally:
spark.stop()
def test_udf_string_concatenation(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"first": "Hello", "second": "World"},
{"first": "Foo", "second": "Bar"},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(lambda a, b: f"{a} {b}", T.StringType())
result = df.withColumn("combined", my_udf("first", "second"))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["combined"] == "Hello World"
assert rows[1]["combined"] == "Foo Bar"
finally:
spark.stop()
def test_udf_with_nulls(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"a": 1, "b": 2},
{"a": None, "b": 3},
{"a": 4, "b": None},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(lambda x, y: (x or 0) + (y or 0), T.IntegerType())
result = df.withColumn("sum", my_udf("a", "b"))
rows = result.collect()
assert len(rows) == 3
assert rows[0]["sum"] == 3
assert rows[1]["sum"] == 3
assert rows[2]["sum"] == 4
finally:
spark.stop()
def test_udf_in_select(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"x": 10, "y": 20},
{"x": 30, "y": 40},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(lambda a, b: a + b, T.IntegerType())
result = df.select("x", "y", my_udf("x", "y").alias("sum"))
rows = result.collect()
assert len(rows) == 2
assert "sum" in rows[0]
assert rows[0]["sum"] is not None
assert rows[1]["sum"] is not None
finally:
spark.stop()
def test_udf_mixed_types(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"name": "Alice", "age": 25, "score": 95.5},
{"name": "Bob", "age": 30, "score": 87.0},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(
lambda n, a, s: f"{n} is {a} years old with score {s}", T.StringType()
)
result = df.withColumn("info", my_udf("name", "age", "score"))
rows = result.collect()
assert len(rows) == 2
assert "Alice" in rows[0]["info"]
assert "25" in rows[0]["info"]
assert "Bob" in rows[1]["info"]
finally:
spark.stop()
def test_udf_single_argument_still_works(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [{"value": 5}, {"value": 10}]
df = spark.createDataFrame(data=data)
square = F.udf(lambda x: x * x, T.IntegerType())
result = df.withColumn("squared", square("value"))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["squared"] == 25
assert rows[1]["squared"] == 100
finally:
spark.stop()
def test_udf_four_arguments(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"a": 1, "b": 2, "c": 3, "d": 4},
{"a": 5, "b": 6, "c": 7, "d": 8},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(lambda w, x, y, z: w + x + y + z, T.IntegerType())
result = df.withColumn("total", my_udf("a", "b", "c", "d"))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["total"] == 10
assert rows[1]["total"] == 26
finally:
spark.stop()
def test_udf_with_computed_columns(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"x": 2, "y": 3},
{"x": 4, "y": 5},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(lambda a, b: a * b, T.IntegerType())
result = df.withColumn("product", my_udf(F.col("x") * 2, F.col("y") + 1))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["product"] is not None
assert rows[1]["product"] is not None
finally:
spark.stop()
def test_udf_decorator_pattern(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"a": 1, "b": 2},
{"a": 3, "b": 4},
]
df = spark.createDataFrame(data=data)
@F.udf(returnType=T.IntegerType())
def add_udf(x, y):
return x + y
result = df.withColumn("sum", add_udf("a", "b"))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["sum"] == 3
assert rows[1]["sum"] == 7
finally:
spark.stop()
def test_udf_empty_dataframe(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
schema = _imports.StructType(
[
_imports.StructField("a", _imports.IntegerType(), True),
_imports.StructField("b", _imports.IntegerType(), True),
]
)
df = spark.createDataFrame([], schema)
my_udf = F.udf(lambda x, y: x + y, T.IntegerType())
result = df.withColumn("sum", my_udf("a", "b"))
rows = result.collect()
assert len(rows) == 0
assert "sum" in result.columns
finally:
spark.stop()
def test_udf_five_arguments(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"a": 1, "b": 2, "c": 3, "d": 4, "e": 5},
{"a": 6, "b": 7, "c": 8, "d": 9, "e": 10},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(lambda v, w, x, y, z: v + w + x + y + z, T.IntegerType())
result = df.withColumn("total", my_udf("a", "b", "c", "d", "e"))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["total"] == 15
assert rows[1]["total"] == 40
finally:
spark.stop()
def test_udf_with_float_arguments(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"x": 1.5, "y": 2.5, "z": 3.5},
{"x": 4.25, "y": 5.75, "z": 6.25},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(lambda a, b, c: a * b + c, T.DoubleType())
result = df.withColumn("result", my_udf("x", "y", "z"))
rows = result.collect()
assert len(rows) == 2
assert abs(rows[0]["result"] - 7.25) < 0.01 assert abs(rows[1]["result"] - 30.6875) < 0.01
finally:
spark.stop()
def test_udf_with_boolean_arguments(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"a": True, "b": False, "c": True},
{"a": False, "b": True, "c": False},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(
lambda x, y, z: (x and y) or z, T.BooleanType()
) result = df.withColumn("result", my_udf("a", "b", "c"))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["result"] is True assert rows[1]["result"] is False finally:
spark.stop()
def test_udf_in_filter(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"x": 10, "y": 20, "z": 30},
{"x": 5, "y": 15, "z": 25},
{"x": 20, "y": 30, "z": 50},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(lambda a, b: a + b, T.IntegerType())
result = df.filter(my_udf("x", "y") > 20)
rows = result.collect()
assert len(rows) == 2
assert all(r["x"] + r["y"] > 20 for r in rows)
finally:
spark.stop()
def test_udf_in_groupby_agg(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"category": "A", "value1": 10, "value2": 5},
{"category": "A", "value1": 15, "value2": 10},
{"category": "B", "value1": 20, "value2": 15},
{"category": "B", "value1": 25, "value2": 20},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(lambda x, y: x + y, T.IntegerType())
result = df.groupBy("category").agg(
my_udf(F.sum("value1"), F.sum("value2")).alias("total_sum")
)
rows = result.collect()
assert len(rows) >= 0 finally:
spark.stop()
def test_udf_in_orderby(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"name": "Alice", "score1": 80, "score2": 90},
{"name": "Bob", "score1": 90, "score2": 85},
{"name": "Charlie", "score1": 85, "score2": 95},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(lambda x, y: x + y, T.IntegerType())
result = df.withColumn("total_score", my_udf("score1", "score2")).orderBy(
F.col("total_score").desc()
)
rows = result.collect()
assert len(rows) == 3
assert rows[0]["name"] == "Charlie"
assert rows[1]["name"] == "Bob"
assert rows[2]["name"] == "Alice"
finally:
spark.stop()
def test_udf_mixed_string_and_column_objects(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"a": 1, "b": 2, "c": 3},
{"a": 4, "b": 5, "c": 6},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(lambda x, y, z: x + y + z, T.IntegerType())
result = df.withColumn("sum", my_udf("a", F.col("b"), "c"))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["sum"] == 6
assert rows[1]["sum"] == 15
finally:
spark.stop()
def test_udf_nested_with_arithmetic(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"x": 10, "y": 5, "z": 2},
{"x": 20, "y": 10, "z": 3},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(lambda a, b: a * b, T.IntegerType())
result = df.withColumn(
"result", my_udf(F.col("x") + F.col("y"), F.col("z") * 2)
)
rows = result.collect()
assert len(rows) == 2
assert rows[0]["result"] is not None
assert rows[1]["result"] is not None
finally:
spark.stop()
def test_udf_with_date_operations(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
from datetime import datetime
data = [
{"date1": datetime(2023, 1, 1), "date2": datetime(2023, 1, 5)},
{"date1": datetime(2023, 2, 1), "date2": datetime(2023, 2, 10)},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(
lambda d1, d2: (d2 - d1).days if d1 and d2 else None, T.IntegerType()
)
result = df.withColumn("days_diff", my_udf("date1", "date2"))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["days_diff"] == 4
assert rows[1]["days_diff"] == 9
finally:
spark.stop()
def test_udf_in_join_condition(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data1 = [
{"id": 1, "value1": 10, "value2": 20},
{"id": 2, "value1": 15, "value2": 25},
]
data2 = [
{"id": 1, "sum": 30},
{"id": 2, "sum": 40},
]
df1 = spark.createDataFrame(data1)
df2 = spark.createDataFrame(data2)
my_udf = F.udf(lambda x, y: x + y, T.IntegerType())
df1_with_sum = df1.withColumn("computed_sum", my_udf("value1", "value2"))
result = df1_with_sum.join(
df2, df1_with_sum["computed_sum"] == df2["sum"], "inner"
)
rows = result.collect()
assert len(rows) == 2
assert all(
r["value1"] + r["value2"] == r["sum"] for r in rows
) finally:
spark.stop()
def test_udf_with_conditional_logic(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"price": 100, "discount": 10, "tax": 5},
{"price": 200, "discount": 20, "tax": 10},
{"price": 50, "discount": 5, "tax": 2},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(
lambda p, d, t: p - d + t if p > 75 else p - d, T.IntegerType()
)
result = df.withColumn("final_price", my_udf("price", "discount", "tax"))
rows = result.collect()
assert len(rows) == 3
assert rows[0]["final_price"] == 95 assert rows[1]["final_price"] == 190 assert rows[2]["final_price"] == 45 finally:
spark.stop()
def test_udf_six_arguments(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6},
{"a": 10, "b": 20, "c": 30, "d": 40, "e": 50, "f": 60},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(
lambda v, w, x, y, z, u: v + w + x + y + z + u, T.IntegerType()
)
result = df.withColumn("total", my_udf("a", "b", "c", "d", "e", "f"))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["total"] == 21 assert rows[1]["total"] == 210 finally:
spark.stop()
def test_udf_with_all_null_arguments(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"a": None, "b": None, "c": None},
{"a": 1, "b": 2, "c": 3},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(
lambda x, y, z: (x or 0) + (y or 0) + (z or 0), T.IntegerType()
)
result = df.withColumn("sum", my_udf("a", "b", "c"))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["sum"] == 0 assert rows[1]["sum"] == 6 finally:
spark.stop()
def test_udf_with_string_functions(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"first": "hello", "second": "world", "third": "!"},
{"first": "foo", "second": "bar", "third": "baz"},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(
lambda a, b, c: f"{a.upper()}-{b.upper()}-{c.upper()}",
T.StringType(),
)
result = df.withColumn("combined", my_udf("first", "second", "third"))
rows = result.collect()
assert len(rows) == 2
assert rows[0]["combined"] == "HELLO-WORLD-!"
assert rows[1]["combined"] == "FOO-BAR-BAZ"
finally:
spark.stop()
def test_udf_chained_operations(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{"x": 2, "y": 3, "z": 4},
{"x": 5, "y": 6, "z": 7},
]
df = spark.createDataFrame(data=data)
add_udf = F.udf(lambda a, b: a + b, T.IntegerType())
multiply_udf = F.udf(lambda a, b: a * b, T.IntegerType())
result = df.withColumn("sum_xy", add_udf("x", "y")).withColumn(
"product", multiply_udf("sum_xy", "z")
)
rows = result.collect()
assert len(rows) == 2
assert rows[0]["product"] == 20 assert rows[1]["product"] == 77 finally:
spark.stop()
def test_udf_with_large_number_of_columns(self):
spark = SparkSession.builder.appName("issue-290").getOrCreate()
try:
data = [
{f"col{i}": i for i in range(1, 11)},
{f"col{i}": i * 2 for i in range(1, 11)},
]
df = spark.createDataFrame(data=data)
my_udf = F.udf(
lambda a, b, c, d, e, f, g, h, i, j: (
a + b + c + d + e + f + g + h + i + j
),
T.IntegerType(),
)
result = df.withColumn(
"total",
my_udf(
"col1",
"col2",
"col3",
"col4",
"col5",
"col6",
"col7",
"col8",
"col9",
"col10",
),
)
rows = result.collect()
assert len(rows) == 2
assert rows[0]["total"] == 55 assert rows[1]["total"] == 110 finally:
spark.stop()