from __future__ import annotations
def test_select_expression_literal_plus_column() -> None:
    """Check that ``lit(2) + col("x")`` in ``select`` gives ``x + 2`` per row.

    Covers both the unaliased and the aliased form. The unaliased form's
    generated column name is not asserted. Instead, each row must contain
    exactly one value, and that value is read positionally.
    """
    import robin_sparkless as rs

    spark = rs.SparkSession.builder().app_name("test").get_or_create()
    df = spark.createDataFrame([{"x": 10}, {"x": 20}], [("x", "int")])

    result = df.select(rs.lit(2) + rs.col("x")).collect()
    assert len(result) == 2
    # A single expression must project a single column. Without this check,
    # an unexpected extra column would be silently ignored by the positional
    # read below.
    assert all(len(r) == 1 for r in result)
    vals = [next(iter(r.values())) for r in result]
    assert vals == [12, 22]

    result_with_alias = df.select(
        (rs.lit(2) + rs.col("x")).alias("plus_two")
    ).collect()
    assert result_with_alias == [{"plus_two": 12}, {"plus_two": 22}]
def test_select_expression_literal_times_column() -> None:
    """Select an aliased ``lit(3) * col("x")`` and check every row is tripled."""
    import robin_sparkless as rs

    session = rs.SparkSession.builder().app_name("test").get_or_create()
    frame = session.createDataFrame([{"x": 10}, {"x": 20}], [("x", "int")])
    tripled_expr = (rs.lit(3) * rs.col("x")).alias("tripled")
    rows = frame.select(tripled_expr).collect()
    expected = [{"tripled": 30}, {"tripled": 60}]
    assert rows == expected
def test_select_expression_aliased() -> None:
    """Select an aliased ``col("a") * 2`` and check each value is doubled."""
    import robin_sparkless as rs

    session = rs.SparkSession.builder().app_name("test").get_or_create()
    source_rows = [{"a": 1}, {"a": 2}, {"a": 3}]
    frame = session.createDataFrame(source_rows, [("a", "int")])
    doubled_expr = (rs.col("a") * 2).alias("doubled")
    rows = frame.select(doubled_expr).collect()
    expected = [{"doubled": 2}, {"doubled": 4}, {"doubled": 6}]
    assert rows == expected
def test_select_mixed_column_names_and_expressions() -> None:
    """Select a plain column name together with an aliased expression."""
    import robin_sparkless as rs

    builder = rs.SparkSession.builder()
    session = builder.app_name("test").get_or_create()
    frame = session.createDataFrame([{"x": 10}, {"x": 20}], [("x", "int")])
    plus_two = (rs.lit(2) + rs.col("x")).alias("plus_two")
    rows = frame.select("x", plus_two).collect()
    assert rows == [
        {"x": 10, "plus_two": 12},
        {"x": 20, "plus_two": 22},
    ]
def test_select_list_of_expressions() -> None:
    """Pass ``select`` a single list that mixes a bare column and an aliased sum."""
    import robin_sparkless as rs

    session = rs.SparkSession.builder().app_name("test").get_or_create()
    data = [{"a": 1, "b": 10}, {"a": 2, "b": 20}]
    schema = [("a", "int"), ("b", "int")]
    frame = session.createDataFrame(data, schema)
    projections = [
        rs.col("a"),
        (rs.col("a") + rs.col("b")).alias("sum_ab"),
    ]
    rows = frame.select(projections).collect()
    expected = [{"a": 1, "sum_ab": 11}, {"a": 2, "sum_ab": 22}]
    assert rows == expected
def test_select_varargs_expressions() -> None:
    """Pass several column expressions to ``select`` as separate arguments."""
    import robin_sparkless as rs

    session = rs.SparkSession.builder().app_name("test").get_or_create()
    frame = session.createDataFrame([{"x": 5}, {"x": 7}], [("x", "int")])
    x_col = rs.col("x")
    doubled = (rs.col("x") * 2).alias("doubled")
    plus_one = (rs.lit(1) + rs.col("x")).alias("plus_one")
    rows = frame.select(x_col, doubled, plus_one).collect()
    expected = [
        {"x": 5, "doubled": 10, "plus_one": 6},
        {"x": 7, "doubled": 14, "plus_one": 8},
    ]
    assert rows == expected
def test_select_string_column_names_still_work() -> None:
    """Select plain string column names, both as a list and as varargs."""
    import robin_sparkless as rs

    session = rs.SparkSession.builder().app_name("test").get_or_create()
    frame = session.createDataFrame(
        [{"a": 1, "b": 2, "c": 3}, {"a": 4, "b": 5, "c": 6}],
        [("a", "int"), ("b", "int"), ("c", "int")],
    )
    expected = [{"a": 1, "b": 2}, {"a": 4, "b": 5}]
    # Both calling conventions must project the same two columns.
    assert frame.select(["a", "b"]).collect() == expected
    assert frame.select("a", "b").collect() == expected
def test_select_mixed_strings_and_expressions() -> None:
    """Select string column names with an aliased expression between them."""
    import robin_sparkless as rs

    session = rs.SparkSession.builder().app_name("test").get_or_create()
    frame = session.createDataFrame(
        [{"a": 1, "b": 10, "c": 100}, {"a": 2, "b": 20, "c": 200}],
        [("a", "int"), ("b", "int"), ("c", "int")],
    )
    b_twice = (rs.col("b") * 2).alias("b_twice")
    rows = frame.select("a", b_twice, "c").collect()
    expected = [
        {"a": 1, "b_twice": 20, "c": 100},
        {"a": 2, "b_twice": 40, "c": 200},
    ]
    assert rows == expected
def test_select_expression_with_multiply_literal_left() -> None:
    """Use a plain Python ``int`` as the left operand of ``*`` with a column."""
    import robin_sparkless as rs

    session = rs.SparkSession.builder().app_name("test").get_or_create()
    frame = session.createDataFrame([{"x": 10}, {"x": 20}], [("x", "int")])
    # ``int * column`` relies on the column type's reflected multiply (__rmul__).
    tripled = (3 * rs.col("x")).alias("tripled")
    rows = frame.select(tripled).collect()
    assert rows == [{"tripled": 30}, {"tripled": 60}]
def test_select_complex_expression() -> None:
    """Evaluate a nested ``(a + b) * c`` expression per row."""
    import robin_sparkless as rs

    session = rs.SparkSession.builder().app_name("test").get_or_create()
    frame = session.createDataFrame(
        [{"a": 1, "b": 2, "c": 3}, {"a": 4, "b": 5, "c": 6}],
        [("a", "int"), ("b", "int"), ("c", "int")],
    )
    a_plus_b = rs.col("a") + rs.col("b")
    product = (a_plus_b * rs.col("c")).alias("product")
    rows = frame.select(product).collect()
    # (1 + 2) * 3 == 9 and (4 + 5) * 6 == 54
    assert rows == [{"product": 9}, {"product": 54}]
def test_select_single_expression_no_alias() -> None:
    """An unaliased ``col("x") + lit(10)`` yields one row with one value."""
    import robin_sparkless as rs

    session = rs.SparkSession.builder().app_name("test").get_or_create()
    frame = session.createDataFrame([{"x": 1}], [("x", "int")])
    rows = frame.select(rs.col("x") + rs.lit(10)).collect()
    assert len(rows) == 1
    first = rows[0]
    # The generated column name is not checked; only its single value is.
    assert len(first) == 1
    value = next(iter(first.values()))
    assert value == 11