from __future__ import annotations
import pytest
def test_import_module() -> None:
import robin_sparkless
assert robin_sparkless is not None
def test_configure_for_multiprocessing() -> None:
import robin_sparkless as rs
rs._configure_for_multiprocessing()
def test_spark_session_builder() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
assert spark is not None
assert spark.is_case_sensitive() in (True, False)
def test_create_dataframe_and_collect() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
data = [(1, 25, "Alice"), (2, 30, "Bob"), (3, 35, "Carol")]
df = spark.create_dataframe(data, ["id", "age", "name"])
assert df is not None
n = df.count()
assert n == 3
rows = df.collect()
assert isinstance(rows, list)
assert len(rows) == 3
assert rows[0] == {"id": 1, "age": 25, "name": "Alice"}
assert rows[1]["name"] == "Bob"
assert rows[2]["age"] == 35
def test_filter_and_select() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
data = [(1, 25, "Alice"), (2, 30, "Bob"), (3, 35, "Carol")]
df = spark.create_dataframe(data, ["id", "age", "name"])
filtered = df.filter(rs.col("age").gt(rs.lit(28)))
assert filtered.count() == 2
rows = filtered.collect()
assert all(r["age"] > 28 for r in rows)
selected = df.select(["id", "name"])
assert selected.count() == 3
first = selected.collect()[0]
assert "id" in first and "name" in first
assert "age" not in first
def test_filter_with_and_or_operators() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
data = [(1, 25, "Alice"), (2, 30, "Bob"), (3, 35, "Carol"), (4, 40, "Dave")]
df = spark.create_dataframe(data, ["id", "age", "name"])
filtered_and = df.filter(
(rs.col("age").gt(rs.lit(26))) & (rs.col("age").lt(rs.lit(36)))
)
rows_and = filtered_and.collect()
assert len(rows_and) == 2
assert all(26 < r["age"] < 36 for r in rows_and)
filtered_or = df.filter(
(rs.col("age").lt(rs.lit(26))) | (rs.col("age").gt(rs.lit(35)))
)
rows_or = filtered_or.collect()
assert len(rows_or) == 2
assert rows_or[0]["age"] == 25 and rows_or[1]["age"] == 40
def test_filter_column_vs_column() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
data = [[1, 5], [2, 4], [3, 1], [4, 2], [5, 1]]
schema = [("a", "bigint"), ("b", "bigint")]
df = spark._create_dataframe_from_rows(data, schema)
filtered_method = df.filter(rs.col("a").gt(rs.col("b")))
rows_method = filtered_method.collect()
assert len(rows_method) == 3
assert {(r["a"], r["b"]) for r in rows_method} == {(3, 1), (4, 2), (5, 1)}
filtered_op = df.filter(rs.col("a") > rs.col("b"))
rows_op = filtered_op.collect()
assert len(rows_op) == 3
assert {(r["a"], r["b"]) for r in rows_op} == {(3, 1), (4, 2), (5, 1)}
assert df.filter(rs.col("a") < rs.col("b")).count() == 2 assert df.filter(rs.col("a") >= rs.col("b")).count() == 3 assert df.filter(rs.col("a") <= rs.col("b")).count() == 2 assert df.filter(rs.col("a") == rs.col("b")).count() == 0
assert df.filter(rs.col("a") != rs.col("b")).count() == 5
def test_filter_column_vs_column_all_method_forms() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
data = [[3, 1], [1, 3], [2, 2], [0, 5]]
schema = [("x", "bigint"), ("y", "bigint")]
df = spark._create_dataframe_from_rows(data, schema)
assert df.filter(rs.col("x").gt(rs.col("y"))).count() == 1 assert df.filter(rs.col("x").ge(rs.col("y"))).count() == 2 assert df.filter(rs.col("x").lt(rs.col("y"))).count() == 2 assert df.filter(rs.col("x").le(rs.col("y"))).count() == 3 assert df.filter(rs.col("x").eq(rs.col("y"))).count() == 1 assert df.filter(rs.col("x") != rs.col("y")).count() == 3
def test_filter_column_vs_column_combined_with_literal() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
data = [[1, 5], [2, 4], [3, 1], [4, 2], [5, 1]]
schema = [("a", "bigint"), ("b", "bigint")]
df = spark._create_dataframe_from_rows(data, schema)
out = df.filter((rs.col("a") > rs.col("b")) & (rs.col("a") > 2))
rows = out.collect()
assert len(rows) == 3
assert {(r["a"], r["b"]) for r in rows} == {(3, 1), (4, 2), (5, 1)}
out2 = df.filter((rs.col("a") < rs.col("b")) | (rs.col("b") >= rs.lit(5)))
assert out2.count() == 2
def test_filter_column_vs_column_with_with_column() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
data = [[10, 5], [3, 7], [0, 0]]
schema = [("p", "bigint"), ("q", "bigint")]
df = spark._create_dataframe_from_rows(data, schema)
out = df.with_column("p_gt_q", rs.col("p") > rs.col("q"))
rows = out.collect()
assert len(rows) == 3
assert rows[0] == {"p": 10, "q": 5, "p_gt_q": True}
assert rows[1] == {"p": 3, "q": 7, "p_gt_q": False}
assert rows[2] == {"p": 0, "q": 0, "p_gt_q": False}
out2 = df.with_column("p_le_q", rs.col("p").le(rs.col("q")))
assert out2.collect()[0]["p_le_q"] is False
assert out2.collect()[1]["p_le_q"] is True
def test_filter_column_vs_column_strings() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
data = [["apple", "banana"], ["banana", "apple"], ["x", "x"]]
schema = [("s1", "string"), ("s2", "string")]
df = spark._create_dataframe_from_rows(data, schema)
out = df.filter(rs.col("s1") > rs.col("s2"))
rows = out.collect()
assert len(rows) == 1 and rows[0]["s1"] == "banana" and rows[0]["s2"] == "apple"
assert df.filter(rs.col("s1") == rs.col("s2")).count() == 1
assert df.filter(rs.col("s1") != rs.col("s2")).count() == 2
def test_filter_column_vs_column_empty_and_all_match() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
data = [[1, 2], [3, 4]]
schema = [("a", "bigint"), ("b", "bigint")]
df = spark._create_dataframe_from_rows(data, schema)
empty = df.filter(rs.col("a") == rs.col("b"))
assert empty.count() == 0
assert empty.collect() == []
all_match = df.filter(rs.col("a") != rs.col("b"))
assert all_match.count() == 2
def test_column_operator_overloads_pyspark_style() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe(
[(1, 25, "a"), (2, 35, "b"), (3, 30, "c")], ["id", "age", "name"]
)
expr_method = rs.col("age").gt(rs.lit(30))
out_method = df.filter(expr_method).collect()
assert len(out_method) == 1 and out_method[0]["age"] == 35
expr_op = rs.col("age") > rs.lit(30)
out_op = df.filter(expr_op).collect()
assert out_op == out_method
out_scalar = df.filter(rs.col("age") > 30).collect()
assert out_scalar == out_method
assert df.filter(rs.col("age") >= 30).count() == 2
assert df.filter(rs.col("age") < 30).count() == 1
assert df.filter(rs.col("age") <= 25).count() == 1
assert df.filter(rs.col("age") == 30).count() == 1
assert df.filter(rs.col("age") != 30).count() == 2
def test_column_operator_overloads_operator_vs_method_parity() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
data = [[1, 5], [2, 4], [3, 1], [4, 2], [5, 1]]
schema = [("a", "bigint"), ("b", "bigint")]
df = spark._create_dataframe_from_rows(data, schema)
for op_name, op_fn, method_fn in [
("gt", lambda c1, c2: c1 > c2, lambda c1, c2: c1.gt(c2)),
("lt", lambda c1, c2: c1 < c2, lambda c1, c2: c1.lt(c2)),
("ge", lambda c1, c2: c1 >= c2, lambda c1, c2: c1.ge(c2)),
("le", lambda c1, c2: c1 <= c2, lambda c1, c2: c1.le(c2)),
("eq", lambda c1, c2: c1 == c2, lambda c1, c2: c1.eq(c2)),
("ne", lambda c1, c2: c1 != c2, lambda c1, c2: c1.ne(c2)),
]:
out_op = df.filter(op_fn(rs.col("a"), rs.col("b"))).collect()
out_method = df.filter(method_fn(rs.col("a"), rs.col("b"))).collect()
assert out_op == out_method, f"{op_name}: operator vs method should match"
for op_name, op_fn, method_fn in [
("gt", lambda c, v: c > v, lambda c, v: c.gt(v)),
("lt", lambda c, v: c < v, lambda c, v: c.lt(v)),
("ge", lambda c, v: c >= v, lambda c, v: c.ge(v)),
("le", lambda c, v: c <= v, lambda c, v: c.le(v)),
("eq", lambda c, v: c == v, lambda c, v: c.eq(v)),
("ne", lambda c, v: c != v, lambda c, v: c.ne(v)),
]:
out_op = df.filter(op_fn(rs.col("a"), 3)).collect()
out_method = df.filter(method_fn(rs.col("a"), 3)).collect()
assert out_op == out_method, (
f"{op_name} scalar: operator vs method should match"
)
def test_column_operator_overloads_with_column_pyspark_semantics() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
data = [[10, 5], [3, 7], [0, 0], [2, 2]]
schema = [("p", "bigint"), ("q", "bigint")]
df = spark._create_dataframe_from_rows(data, schema)
out = df.with_column("p_gt_q", rs.col("p") > rs.col("q"))
rows = out.collect()
assert [r["p_gt_q"] for r in rows] == [True, False, False, False]
out2 = df.with_column("p_eq_q", rs.col("p") == rs.col("q"))
assert [r["p_eq_q"] for r in out2.collect()] == [False, False, True, True]
out3 = df.with_column("p_ge_5", rs.col("p") >= 5)
assert [r["p_ge_5"] for r in out3.collect()] == [True, False, False, False]
def test_column_operator_overloads_combined_and_or_pyspark_semantics() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
data = [[1, 5], [2, 4], [3, 1], [4, 2], [5, 1]]
schema = [("a", "bigint"), ("b", "bigint")]
df = spark._create_dataframe_from_rows(data, schema)
out = df.filter((rs.col("a") > rs.col("b")) & (rs.col("a") > 2)).collect()
assert len(out) == 3
assert {(r["a"], r["b"]) for r in out} == {(3, 1), (4, 2), (5, 1)}
out2 = df.filter((rs.col("a") < rs.col("b")) | (rs.col("b") == 1)).collect()
assert len(out2) == 4
assert {(r["a"], r["b"]) for r in out2} == {(1, 5), (2, 4), (3, 1), (5, 1)}
out3 = df.filter((rs.col("a") >= 4) | (rs.col("b") <= 2)).collect()
assert len(out3) == 3
assert {(r["a"], r["b"]) for r in out3} == {(3, 1), (4, 2), (5, 1)}
def test_column_operator_overloads_float_and_string_scalar() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark._create_dataframe_from_rows(
[
{"id": 1, "score": 2.0, "name": "Alice"},
{"id": 2, "score": 3.5, "name": "Bob"},
{"id": 3, "score": 2.5, "name": "Charlie"},
],
[("id", "bigint"), ("score", "double"), ("name", "string")],
)
out = df.filter(rs.col("score") > 2.5).collect()
assert len(out) == 1 and out[0]["name"] == "Bob"
assert df.filter(rs.col("score") >= 3.5).count() == 1
assert df.filter(rs.col("score") == 2.0).count() == 1
assert df.filter(rs.col("name") == "Bob").count() == 1
assert df.filter(rs.col("name") != "Alice").count() == 2
assert df.filter(rs.col("name") > "B").count() == 2
def test_column_operator_overloads_reflected_scalar() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe(
[(1, 25, "a"), (2, 35, "b"), (3, 30, "c")], ["id", "age", "name"]
)
out = df.filter(30 < rs.col("age")).collect()
assert len(out) == 1 and out[0]["age"] == 35
assert df.filter(25 <= rs.col("age")).count() == 3
out2 = df.filter(30 > rs.col("age")).collect()
assert len(out2) == 1 and out2[0]["age"] == 25
def test_filter_column_vs_column_scalar_still_works() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
data = [[1, 5], [2, 4], [3, 1]]
schema = [("a", "bigint"), ("b", "bigint")]
df = spark._create_dataframe_from_rows(data, schema)
assert df.filter(rs.col("a") > 2).count() == 1
assert df.filter(rs.col("a").gt(2)).count() == 1
assert df.filter(rs.col("b") >= rs.lit(4)).count() == 2
assert df.filter(rs.col("b").ge(rs.lit(4))).count() == 2
def test_filter_accepts_literal_bool() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
data = [(1, 25, "Alice"), (2, 30, "Bob")]
df = spark.create_dataframe(data, ["id", "age", "name"])
out_true = df.filter(True)
assert out_true.count() == 2
assert out_true.collect() == df.collect()
out_false = df.filter(False)
assert out_false.count() == 0
assert out_false.collect() == []
def test_filter_literal_bool_empty_dataframe() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe([], ["id", "age", "name"])
assert df.count() == 0
assert df.filter(True).count() == 0
assert df.filter(True).collect() == []
assert df.filter(False).count() == 0
assert df.filter(False).collect() == []
def test_filter_literal_bool_preserves_schema() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
data = [(1, 10, "a"), (2, 20, "b")]
df = spark.create_dataframe(data, ["id", "v", "label"])
out = df.filter(False)
rows = out.collect()
assert rows == []
out2 = out.with_column("extra", rs.lit(1))
assert out2.count() == 0
assert out2.collect() == []
def test_filter_literal_bool_chained_with_column_filter() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
data = [(1, 25, "Alice"), (2, 30, "Bob"), (3, 35, "Carol")]
df = spark.create_dataframe(data, ["id", "age", "name"])
out = df.filter(True).filter(rs.col("age") > 28)
assert out.count() == 2
assert {r["name"] for r in out.collect()} == {"Bob", "Carol"}
out_empty = df.filter(False).limit(5)
assert out_empty.count() == 0
def test_filter_literal_bool_from_rows_schema() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
data = [{"a": 1, "b": "x"}, {"a": 2, "b": "y"}]
schema = [("a", "bigint"), ("b", "string")]
df = spark._create_dataframe_from_rows(data, schema)
assert df.filter(True).count() == 2
assert df.filter(False).count() == 0
assert df.filter(False).collect() == []
def test_filter_condition_type_error() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe([(1, 2, "a")], ["id", "v", "name"])
with pytest.raises(TypeError, match="condition must be a Column or literal bool"):
df.filter(1) with pytest.raises(TypeError, match="condition must be a Column or literal bool"):
df.filter("age > 10") with pytest.raises(TypeError, match="condition must be a Column or literal bool"):
df.filter(None)
def test_with_column_and_show() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
data = [(1, 25, "Alice"), (2, 30, "Bob")]
df = spark.create_dataframe(data, ["id", "age", "name"])
doubled = df.with_column("age2", rs.col("age").gt(rs.lit(20)))
assert doubled.count() == 2
doubled.show(5)
def test_group_by_count() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe(
[(1, 0, "a"), (2, 0, "a"), (3, 0, "b")], ["id", "age", "grp"]
)
grouped = df.group_by(["grp"])
result = grouped.count()
assert result is not None
rows = result.collect()
assert len(rows) == 2
counts = {r["grp"]: r["count"] for r in rows}
assert counts["a"] == 2
assert counts["b"] == 1
def test_col_lit_when() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe(
[(1, 10, ""), (2, 20, ""), (3, 30, "")], ["id", "age", "name"]
)
expr = (
rs.when(rs.col("age").gt(rs.lit(15)))
.then(rs.lit("high"))
.otherwise(rs.lit("low"))
)
out = df.with_column("level", expr)
rows = out.collect()
assert len(rows) == 3
assert rows[0]["level"] == "low"
assert rows[1]["level"] == "high"
assert rows[2]["level"] == "high"
def test_lit_date_and_datetime() -> None:
import datetime
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe(
[(1, 10, "a"), (2, 20, "b"), (3, 30, "c")], ["id", "x", "name"]
)
out = df.with_column("const_date", rs.lit(datetime.date(2025, 6, 15)))
out = out.with_column(
"const_ts",
rs.lit(datetime.datetime(2025, 6, 15, 12, 30, 45, 123456)),
)
rows = out.collect()
assert len(rows) == 3
assert rows[0]["const_date"] in (datetime.date(2025, 6, 15), "2025-06-15")
ts_val = rows[0]["const_ts"]
assert ts_val == datetime.datetime(2025, 6, 15, 12, 30, 45, 123456) or (
isinstance(ts_val, str) and ts_val.startswith("2025-06-15")
)
out2 = df.with_column("d1", rs.lit(datetime.date(2025, 1, 1)))
out2 = out2.with_column("d2", rs.lit(datetime.date(2025, 6, 1)))
filtered = out2.filter(rs.col("d1").lt(rs.col("d2")))
assert filtered.count() == 3
def test_lit_date_edge_cases() -> None:
import datetime
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe([(1, 10, "a"), (2, 20, "b")], ["id", "x", "name"])
out = df.with_column("epoch", rs.lit(datetime.date(1970, 1, 1)))
rows = out.collect()
assert len(rows) == 2
assert rows[0]["epoch"] in (datetime.date(1970, 1, 1), "1970-01-01")
out = df.with_column("leap", rs.lit(datetime.date(2024, 2, 29)))
rows = out.collect()
assert rows[0]["leap"] in (datetime.date(2024, 2, 29), "2024-02-29")
out = df.with_column("early", rs.lit(datetime.date(1, 1, 1)))
rows = out.collect()
assert rows[0]["early"] in (datetime.date(1, 1, 1), "0001-01-01")
def test_lit_datetime_edge_cases() -> None:
import datetime
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe([(1, 10, "a")], ["id", "x", "name"])
out = df.with_column(
"midnight",
rs.lit(datetime.datetime(2025, 3, 15, 0, 0, 0, 0)),
)
rows = out.collect()
assert len(rows) == 1
ts = rows[0]["midnight"]
assert ts == datetime.datetime(2025, 3, 15, 0, 0, 0, 0) or (
isinstance(ts, str) and "2025-03-15" in ts and "00:00" in ts
)
out = df.with_column(
"with_micros",
rs.lit(datetime.datetime(2025, 1, 1, 23, 59, 59, 999999)),
)
rows = out.collect()
assert len(rows) == 1
ts = rows[0]["with_micros"]
assert ts == datetime.datetime(2025, 1, 1, 23, 59, 59, 999999) or (
isinstance(ts, str) and "2025-01-01" in ts
)
def test_lit_date_filter_comparisons() -> None:
import datetime
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe(
[(1, 10, "a"), (2, 20, "b"), (3, 30, "c")], ["id", "x", "name"]
)
df = df.with_column("d", rs.lit(datetime.date(2025, 5, 15)))
threshold = datetime.date(2025, 5, 15)
eq_rows = df.filter(rs.col("d").eq(rs.lit(threshold)))
assert eq_rows.count() == 3
future = datetime.date(2026, 1, 1)
gt_rows = df.filter(rs.col("d").gt(rs.lit(future)))
assert gt_rows.count() == 0
lt_rows = df.filter(rs.col("d").lt(rs.lit(future)))
assert lt_rows.count() == 3
ge_rows = df.filter(rs.col("d").ge(rs.lit(threshold)))
assert ge_rows.count() == 3
past = datetime.date(2020, 1, 1)
le_rows = df.filter(rs.col("d").le(rs.lit(past)))
assert le_rows.count() == 0
def test_lit_rejects_non_date_datetime() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe([(1, 10, "a")], ["id", "x", "name"])
with pytest.raises(TypeError, match="lit\\(\\) supports only"):
df.with_column("bad", rs.lit([1, 2, 3]))
with pytest.raises(TypeError, match="lit\\(\\) supports only"):
df.with_column("bad", rs.lit({"a": 1}))
with pytest.raises(TypeError, match="lit\\(\\) supports only"):
df.with_column("bad", rs.lit(b"bytes"))
def test_lit_date_and_datetime_in_when() -> None:
import datetime
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe(
[(1, 1, "a"), (2, 2, "b"), (3, 10, "c")], ["id", "x", "name"]
)
out = df.with_column(
"bucket",
rs.when(rs.col("x").lt(rs.lit(5)))
.then(rs.lit(datetime.date(2025, 1, 1)))
.otherwise(rs.lit(datetime.date(2025, 12, 31))),
)
rows = out.collect()
assert len(rows) == 3
assert rows[0]["bucket"] in (datetime.date(2025, 1, 1), "2025-01-01")
assert rows[2]["bucket"] in (datetime.date(2025, 12, 31), "2025-12-31")
out2 = df.with_column(
"ts",
rs.when(rs.col("id").eq(rs.lit(2)))
.then(rs.lit(datetime.datetime(2025, 6, 15, 12, 0, 0, 0)))
.otherwise(rs.lit(datetime.datetime(2025, 1, 1, 0, 0, 0, 0))),
)
rows2 = out2.collect()
assert len(rows2) == 3
mid_ts = rows2[1]["ts"]
assert mid_ts == datetime.datetime(2025, 6, 15, 12, 0, 0, 0) or (
isinstance(mid_ts, str) and "2025-06-15" in mid_ts and "12" in mid_ts
)
def test_limit_and_distinct() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
data = [(1, 1, "a"), (2, 2, "b"), (3, 3, "c")]
df = spark.create_dataframe(data, ["id", "age", "name"])
limited = df.limit(2)
assert limited.count() == 2
distinct_df = spark.create_dataframe(
[(1, 1, "x"), (1, 1, "x"), (2, 2, "y")], ["id", "age", "name"]
)
uniq = distinct_df.distinct()
assert uniq.count() == 2
def test_window_row_number_rank_over() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe(
[(1, 100, "a"), (2, 90, "a"), (3, 80, "b")], ["id", "salary", "dept"]
)
out = df.with_column(
"rn",
rs.col("salary").row_number(descending=True).over(["dept"]),
)
rows = out.collect()
assert len(rows) == 3
rn_by_id = {r["id"]: r["rn"] for r in rows}
assert rn_by_id[1] == 1 and rn_by_id[2] == 2 and rn_by_id[3] == 1
out2 = df.with_column(
"rk",
rs.col("salary").rank(descending=True).over(["dept"]),
)
rows2 = out2.collect()
assert [r["rk"] for r in rows2] == [1, 2, 1]
out3 = df.with_column(
"dr",
rs.col("salary").dense_rank(descending=True).over(["dept"]),
)
rows3 = out3.collect()
assert [r["dr"] for r in rows3] == [1, 2, 1]
out4 = df.with_column(
"prev",
rs.col("salary").lag(1).over(["dept"]),
)
out4 = out4.with_column(
"nxt",
rs.col("salary").lead(1).over(["dept"]),
)
rows4 = out4.collect()
by_id = {r["id"]: r for r in rows4}
assert by_id[3]["prev"] is None and by_id[3]["nxt"] is None
a_rows = [r for r in rows4 if r["dept"] == "a"]
assert len(a_rows) == 2
assert sum(1 for r in a_rows if r["prev"] is not None or r["nxt"] is not None) >= 1
out5 = df.with_column(
"total_by_dept",
rs.sum(rs.col("salary")).over(["dept"]),
)
rows5 = out5.collect()
assert rows5[0]["total_by_dept"] == 190 and rows5[2]["total_by_dept"] == 80
def test_aggregate_functions() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe(
[(1, 10, "a"), (2, 20, "a"), (3, 30, "b")], ["id", "val", "grp"]
)
grouped = df.group_by(["grp"])
agg_df = grouped.agg([rs.sum(rs.col("val")), rs.count(rs.col("id"))])
rows = agg_df.collect()
assert len(rows) == 2
def test_stat_cov_corr() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe(
[(1, 10, "a"), (2, 20, "b"), (3, 30, "c")], ["id", "x", "name"]
)
stat = df.stat()
c = stat.cov("id", "x")
assert isinstance(c, (int, float))
r = stat.corr("id", "x")
assert isinstance(r, (int, float))
assert -1.0 - 1e-9 <= r <= 1.0 + 1e-9 or (r != r)
def test_na_fill_drop() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe(
[(1, 10, "a"), (2, 20, "b"), (3, 30, "c")], ["id", "age", "name"]
)
filled = df.na().fill(rs.lit(0))
assert filled.count() == 3
dropped = df.na().drop()
assert dropped.count() == 3
def test_with_columns_and_renamed() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe([(1, 25, "Alice")], ["id", "age", "name"])
out = df.with_columns({"extra": rs.lit(42)})
rows = out.collect()
assert rows[0]["extra"] == 42
renamed = df.with_columns_renamed({"name": "full_name"})
row = renamed.collect()[0]
assert "full_name" in row and "name" not in row
def test_to_pandas() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe([(1, 25, "Alice")], ["id", "age", "name"])
result = df.to_pandas()
assert result is not None
assert isinstance(result, list)
assert len(result) == 1
assert (
result[0]["id"] == 1 and result[0]["age"] == 25 and result[0]["name"] == "Alice"
)
def test_ascii_base64() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe([(1, 65, "A")], ["id", "code", "name"])
out = df.with_column("ascii_val", rs.ascii(rs.col("name")))
assert out.count() == 1
df2 = spark.create_dataframe([(1, 2, "hello")], ["id", "x", "msg"])
out2 = df2.with_column("enc", rs.base64(rs.col("msg")))
assert out2.count() == 1
def test_filter_nonexistent_column_raises() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe([(1, 25, "Alice")], ["id", "age", "name"])
with pytest.raises(Exception):
df.filter(rs.col("nonexistent").gt(rs.lit(0)))
def test_select_nonexistent_column_raises() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe([(1, 25, "Alice")], ["id", "age", "name"])
with pytest.raises(Exception):
df.select(["id", "nonexistent"])
def test_read_api_and_write_parquet_csv_json() -> None:
import tempfile
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe([(1, 10, "a"), (2, 20, "b")], ["id", "x", "label"])
with tempfile.TemporaryDirectory() as tmpdir:
parquet_path = f"{tmpdir}/out.parquet"
df.write().mode("overwrite").parquet(parquet_path)
back = spark.read().parquet(parquet_path)
assert back.count() == 2
assert back.collect()[0]["id"] == 1
csv_path = f"{tmpdir}/out.csv"
df.write().mode("overwrite").option("header", "true").csv(csv_path)
back_csv = spark.read().option("header", "true").csv(csv_path)
assert back_csv.count() == 2
json_path = f"{tmpdir}/out.json"
df.write().mode("overwrite").json(json_path)
back_json = spark.read().json(json_path)
assert back_json.count() == 2
df.write().mode("overwrite").format("parquet").save(
f"{tmpdir}/via_save.parquet"
)
via_save = spark.read().format("parquet").load(f"{tmpdir}/via_save.parquet")
assert via_save.count() == 2
def test_delta_write_and_read(spark) -> None:
import tempfile
df = spark.create_dataframe([(1, 1, "a"), (2, 2, "b")], ["id", "num", "name"])
try:
with tempfile.TemporaryDirectory() as tmpdir:
path = f"{tmpdir}/delta_table"
df.write_delta(path, overwrite=True)
back = spark.read_delta(path)
rows = back.collect()
assert len(rows) == 2
assert rows[0]["id"] == 1 and rows[0]["name"] == "a"
except (RuntimeError, AttributeError) as e:
if (
isinstance(e, AttributeError)
or "delta" in str(e).lower()
or "Delta Lake" in str(e)
):
pytest.skip(
"Delta Lake feature not built (build with --features pyo3,delta)"
)
raise
def test_sparkless_parity_join_inner_returns_rows() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
left = spark.create_dataframe([(1, 10, "a"), (2, 20, "b")], ["id", "v", "label"])
right = spark.create_dataframe([(1, 100, "x"), (3, 300, "z")], ["id", "w", "tag"])
joined = left.join(right, ["id"], "inner")
rows = joined.collect()
assert len(rows) == 1
assert rows[0]["id"] == 1 and rows[0]["v"] == 10 and rows[0]["w"] == 100
def test_join_on_string_single_column() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df1 = spark._create_dataframe_from_rows(
[{"id": 1, "x": 10}], [("id", "bigint"), ("x", "bigint")]
)
df2 = spark._create_dataframe_from_rows(
[{"id": 1, "y": 20}], [("id", "bigint"), ("y", "bigint")]
)
result = df1.join(df2, on="id", how="inner")
rows = result.collect()
assert len(rows) == 1
assert rows[0]["id"] == 1 and rows[0]["x"] == 10 and rows[0]["y"] == 20
result_list = df1.join(df2, on=["id"], how="inner")
assert result_list.collect() == rows
def test_join_on_string_all_join_types() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
left = spark.create_dataframe(
[(1, 10, "a"), (2, 20, "b"), (3, 30, "c")], ["id", "v", "label"]
)
right = spark.create_dataframe([(1, 100, "x"), (3, 300, "z")], ["id", "w", "tag"])
for how in ("inner", "left", "right", "outer", "left_semi", "left_anti"):
with_str = left.join(right, on="id", how=how).collect()
with_list = left.join(right, on=["id"], how=how).collect()
assert len(with_str) == len(with_list), f"how={how} row count mismatch"
if how not in ("left_semi", "left_anti"):
def by_id(r):
return (r.get("id"), r.get("v"), r.get("w"), r.get("tag"))
assert sorted(with_str, key=by_id) == sorted(with_list, key=by_id), (
f"how={how} rows differ"
)
else:
assert sorted(r["id"] for r in with_str) == sorted(
r["id"] for r in with_list
), f"how={how} ids differ"
def test_join_on_tuple_single_and_multi_column() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df1 = spark._create_dataframe_from_rows(
[{"id": 1, "x": 10}], [("id", "bigint"), ("x", "bigint")]
)
df2 = spark._create_dataframe_from_rows(
[{"id": 1, "y": 20}], [("id", "bigint"), ("y", "bigint")]
)
result_tuple = df1.join(df2, on=("id",), how="inner").collect()
result_list = df1.join(df2, on=["id"], how="inner").collect()
assert result_tuple == result_list
left = spark._create_dataframe_from_rows(
[{"a": 1, "b": 2, "v": 10}],
[("a", "bigint"), ("b", "bigint"), ("v", "bigint")],
)
right = spark._create_dataframe_from_rows(
[{"a": 1, "b": 2, "w": 20}],
[("a", "bigint"), ("b", "bigint"), ("w", "bigint")],
)
on_tuple = left.join(right, on=("a", "b"), how="inner").collect()
on_list = left.join(right, on=["a", "b"], how="inner").collect()
assert on_tuple == on_list
assert len(on_tuple) == 1 and on_tuple[0]["v"] == 10 and on_tuple[0]["w"] == 20
def test_join_on_string_no_matches() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df1 = spark._create_dataframe_from_rows(
[{"id": 1, "x": 10}], [("id", "bigint"), ("x", "bigint")]
)
df2 = spark._create_dataframe_from_rows(
[{"id": 99, "y": 20}], [("id", "bigint"), ("y", "bigint")]
)
result = df1.join(df2, on="id", how="inner").collect()
assert result == []
assert df1.join(df2, on=["id"], how="inner").collect() == result
def test_join_on_string_multiple_matches() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
left = spark._create_dataframe_from_rows(
[{"id": 1, "x": 10}, {"id": 1, "x": 11}],
[("id", "bigint"), ("x", "bigint")],
)
right = spark._create_dataframe_from_rows(
[{"id": 1, "y": 20}], [("id", "bigint"), ("y", "bigint")]
)
result = left.join(right, on="id", how="inner").collect()
assert len(result) == 2
assert {r["x"] for r in result} == {10, 11}
assert all(r["id"] == 1 and r["y"] == 20 for r in result)
result_list = left.join(right, on=["id"], how="inner").collect()
assert sorted(result, key=lambda r: r["x"]) == sorted(
result_list, key=lambda r: r["x"]
)
def test_join_on_invalid_type_raises() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df1 = spark._create_dataframe_from_rows([{"id": 1}], [("id", "bigint")])
df2 = spark._create_dataframe_from_rows([{"id": 1}], [("id", "bigint")])
for invalid in (42, None, 3.14):
with pytest.raises(
TypeError, match="join 'on' must be str or list/tuple of str"
):
df1.join(df2, on=invalid, how="inner")
def test_join_on_string_pyspark_semantics_single_key_column() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark._create_dataframe_from_rows(
[{"name": "Alice", "age": 2}, {"name": "Bob", "age": 5}],
[("name", "string"), ("age", "bigint")],
)
df2 = spark._create_dataframe_from_rows(
[{"name": "Tom", "height": 80}, {"name": "Bob", "height": 85}],
[("name", "string"), ("height", "bigint")],
)
result = df.join(df2, on="name", how="inner").collect()
assert len(result) == 1
assert (
result[0]["name"] == "Bob"
and result[0]["age"] == 5
and result[0]["height"] == 85
)
assert list(result[0].keys()) == ["name", "age", "height"], (
"join on single column must not duplicate key: output should have one 'name' column"
)
result_list = df.join(df2, on=["name"], how="inner").collect()
assert result_list == result
def test_join_on_string_pyspark_semantics_outer_and_semi_anti() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark._create_dataframe_from_rows(
[{"name": "Alice", "age": 2}, {"name": "Bob", "age": 5}],
[("name", "string"), ("age", "bigint")],
)
df2 = spark._create_dataframe_from_rows(
[{"name": "Tom", "height": 80}, {"name": "Bob", "height": 85}],
[("name", "string"), ("height", "bigint")],
)
outer = df.join(df2, on="name", how="outer").collect()
assert len(outer) == 3
names = {r["name"] for r in outer}
assert names == {"Alice", "Bob", "Tom"}
semi = df.join(df2, on="name", how="left_semi").collect()
assert len(semi) == 1 and semi[0]["name"] == "Bob"
assert "height" not in semi[0]
anti = df.join(df2, on="name", how="left_anti").collect()
assert len(anti) == 1 and anti[0]["name"] == "Alice"
assert "height" not in anti[0]
def test_join_on_list_multiple_columns_pyspark_semantics() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark._create_dataframe_from_rows(
[{"name": "Alice", "age": 2}, {"name": "Bob", "age": 5}],
[("name", "string"), ("age", "bigint")],
)
df3 = spark._create_dataframe_from_rows(
[
{"name": "Alice", "age": 10, "height": 80},
{"name": "Bob", "age": 5, "height": None},
{"name": "Tom", "age": None, "height": None},
],
[("name", "string"), ("age", "bigint"), ("height", "bigint")],
)
result = df.join(df3, on=["name", "age"], how="inner").collect()
assert len(result) == 1
assert result[0]["name"] == "Bob" and result[0]["age"] == 5
assert list(result[0].keys()) == ["name", "age", "height"]
def test_sparkless_parity_join_left_returns_rows() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
left = spark.create_dataframe(
[(1, 0, "a"), (2, 0, "b")], ["id", "_", "label"]
).drop(["_"])
right = spark.create_dataframe([(1, 0, "x")], ["id", "_", "tag"]).drop(["_"])
joined = left.join(right, ["id"], "left")
rows = joined.collect()
assert len(rows) == 2
assert rows[0]["id"] == 1 and rows[0]["tag"] == "x"
assert rows[1]["id"] == 2 and rows[1]["tag"] is None
def test_sparkless_parity_join_right_returns_rows() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
left = spark.create_dataframe([(1, 0, "x")], ["id", "_", "tag"]).drop(["_"])
right = spark.create_dataframe(
[(1, 0, "a"), (2, 0, "b")], ["id", "_", "label"]
).drop(["_"])
joined = left.join(right, ["id"], "right")
rows = joined.collect()
assert len(rows) == 2
def test_sparkless_parity_join_outer_returns_rows() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
left = spark.create_dataframe(
[(1, 0, "a"), (2, 0, "b")], ["id", "_", "label"]
).drop(["_"])
right = spark.create_dataframe([(1, 0, "x"), (3, 0, "z")], ["id", "_", "tag"]).drop(
["_"]
)
joined = left.join(right, ["id"], "outer")
rows = joined.collect()
assert len(rows) == 3
def test_sparkless_parity_join_left_semi_returns_rows() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
left = spark.create_dataframe(
[(1, 0, "a"), (2, 0, "b"), (3, 0, "c")], ["id", "_", "label"]
).drop(["_"])
right = spark.create_dataframe([(2, 0, "x")], ["id", "_", "tag"]).drop(["_"])
joined = left.join(right, ["id"], "left_semi")
rows = joined.collect()
assert len(rows) == 1
assert rows[0]["id"] == 2 and rows[0]["label"] == "b"
assert "tag" not in rows[0]
def test_sparkless_parity_join_left_anti_returns_rows() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
left = spark.create_dataframe(
[(1, 0, "a"), (2, 0, "b"), (3, 0, "c")], ["id", "_", "label"]
).drop(["_"])
right = spark.create_dataframe([(2, 0, "x")], ["id", "_", "tag"]).drop(["_"])
joined = left.join(right, ["id"], "left_anti")
rows = joined.collect()
assert len(rows) == 2
ids = {r["id"] for r in rows}
assert ids == {1, 3}
def test_sparkless_parity_filter_simple_returns_rows() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe(
[(1, 25, "Alice"), (2, 30, "Bob"), (3, 35, "Carol")], ["id", "age", "name"]
)
out = df.filter(rs.col("age").gt(rs.lit(28)))
rows = out.collect()
assert len(rows) == 2
assert all(r["age"] > 28 for r in rows)
def test_sparkless_parity_filter_boolean_returns_rows() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe(
[(1, 0, "a"), (2, 0, "b"), (3, 0, "c")], ["id", "_", "name"]
)
df = df.with_column(
"flag",
rs.when(rs.col("id").eq(rs.lit(2))).then(rs.lit(False)).otherwise(rs.lit(True)),
).drop(["_"])
out = df.filter(rs.col("flag").eq(rs.lit(True)))
rows = out.collect()
assert len(rows) == 2
assert [r["id"] for r in rows] == [1, 3]
def test_sparkless_parity_select_returns_rows() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe([(1, 10, "a"), (2, 20, "b")], ["id", "v", "name"])
out = df.select(["id", "name"])
rows = out.collect()
assert len(rows) == 2
assert rows[0]["id"] == 1 and rows[0]["name"] == "a"
def test_sparkless_parity_select_with_alias_returns_rows() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe([(1, 25, "Alice")], ["id", "age", "name"])
df = df.with_column("years", rs.col("age"))
out = df.select(["years", "name"])
rows = out.collect()
assert len(rows) == 1
assert rows[0]["years"] == 25 and rows[0]["name"] == "Alice"
def test_sparkless_parity_with_column_returns_rows() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe([(1, 0, "a"), (2, 0, "b")], ["id", "_", "name"]).drop(
["_"]
)
out = df.with_column("double_id", rs.col("id").multiply(rs.lit(2)))
rows = out.collect()
assert len(rows) == 2
assert rows[0]["double_id"] == 2 and rows[1]["double_id"] == 4
def test_issue_179_with_column_expression_operators() -> None:
import robin_sparkless as rs
F = rs
spark = F.SparkSession.builder().app_name("test").get_or_create()
data = [{"a": 1}, {"a": 2}, {"a": 3}]
schema = [("a", "int")]
df = spark._create_dataframe_from_rows(data, schema)
expr = F.col("a") * F.lit(2)
result = df.with_column("doubled", expr).collect()
assert result == [
{"a": 1, "doubled": 2},
{"a": 2, "doubled": 4},
{"a": 3, "doubled": 6},
]
df2 = spark._create_dataframe_from_rows([{"x": 10}, {"x": 20}], [("x", "int")])
result2 = df2.with_column("plus_two", F.lit(2) + F.col("x")).collect()
assert result2 == [
{"x": 10, "plus_two": 12},
{"x": 20, "plus_two": 22},
]
result3 = df.with_column("times_two", F.col("a") * 2).collect()
assert result3 == [
{"a": 1, "times_two": 2},
{"a": 2, "times_two": 4},
{"a": 3, "times_two": 6},
]
result4 = df2.with_column("tripled", 3 * F.col("x")).collect()
assert result4 == [
{"x": 10, "tripled": 30},
{"x": 20, "tripled": 60},
]
def test_sparkless_parity_drop_returns_rows() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe([(1, 10, "a"), (2, 20, "b")], ["id", "v", "name"])
out = df.drop(["v"])
rows = out.collect()
assert len(rows) == 2
assert "v" not in rows[0] and "id" in rows[0] and "name" in rows[0]
def test_sparkless_parity_distinct_returns_rows() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe(
[(1, 0, "a"), (1, 0, "a"), (2, 0, "b")], ["id", "_", "name"]
).drop(["_"])
out = df.distinct()
rows = out.collect()
assert len(rows) == 2
def test_sparkless_parity_order_by_desc_returns_rows() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark._create_dataframe_from_rows(
[[1, "a"], [2, "b"], [3, "c"]], [("id", "bigint"), ("name", "string")]
)
out = df.order_by(["id"], ascending=[False])
rows = out.collect()
assert len(rows) == 3
assert rows[0]["id"] == 3 and rows[2]["id"] == 1
def test_sparkless_parity_filter_comparison_not_column_existence() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe(
[(1, 10, "a"), (2, 20, "b"), (3, 30, "c")], ["id", "v", "x"]
).drop(["x"])
out = df.filter(rs.col("v").gt(rs.lit(15)))
rows = out.collect()
assert len(rows) == 2
assert rows[0]["v"] == 20 and rows[1]["v"] == 30
def test_sparkless_parity_table_read_returns_rows() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe(
[(1, 10, "a"), (2, 20, "b"), (3, 30, "c")], ["id", "v", "name"]
)
try:
spark.create_or_replace_temp_view("t", df)
except AttributeError:
pytest.skip("sql feature not built (create_or_replace_temp_view not available)")
read_back = spark.table("t")
rows = read_back.collect()
assert len(rows) == 3
assert rows[0]["id"] == 1 and rows[2]["name"] == "c"
def test_save_as_table_and_catalog() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe([(1, 10, "a"), (2, 20, "b")], ["id", "v", "name"])
try:
df.write().saveAsTable("t1")
except AttributeError:
pytest.skip("sql feature not built (saveAsTable not available)")
read_back = spark.table("t1")
assert read_back.count() == 2
assert spark.catalog().tableExists("t1", None)
assert "t1" in spark.catalog().listTables(None)
df2 = spark.create_dataframe([(3, 30, "c")], ["id", "v", "name"])
df2.write().saveAsTable("t1", mode="overwrite")
assert spark.table("t1").count() == 1
df_temp = spark.create_dataframe([(99, 99, "temp")], ["id", "v", "name"])
spark.create_or_replace_temp_view("x", df_temp)
df_saved = spark.create_dataframe([(1, 1, "saved")], ["id", "v", "name"])
df_saved.write().saveAsTable("x", mode="overwrite")
rows_x = spark.table("x").collect()
assert len(rows_x) == 1 and rows_x[0]["name"] == "temp"
names = spark.catalog().listTables(None)
assert "t1" in names and "x" in names
spark.catalog().dropTable("t1")
assert not spark.catalog().tableExists("t1", None)
assert spark.catalog().tableExists("x", None)
df3 = spark.create_dataframe([(1, 2, "d")], ["id", "v", "name"])
df3.write().saveAsTable("delta_t")
rd = spark.read_delta("delta_t")
assert rd.count() == 1 and rd.collect()[0]["name"] == "d"
def test_global_temp_view_persists_across_sessions() -> None:
import robin_sparkless as rs
try:
spark1 = rs.SparkSession.builder().app_name("g1").get_or_create()
df = spark1.create_dataframe(
[(1, 25, "Alice"), (2, 30, "Bob")], ["id", "age", "name"]
)
df.createOrReplaceGlobalTempView("people")
assert spark1.table("global_temp.people").count() == 2
spark2 = rs.SparkSession.builder().app_name("g2").get_or_create()
assert spark2.table("global_temp.people").count() == 2
assert spark2.catalog().listTables("global_temp") == ["people"]
assert spark2.catalog().dropGlobalTempView("people") is True
with pytest.raises(Exception, match="not found"):
spark2.table("global_temp.people")
except AttributeError:
pytest.skip("sql feature not built")
def test_save_as_table_without_session_raises() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe([(1, 2, "x")], ["id", "v", "name"])
try:
df.write().saveAsTable("_no_session_test", mode="overwrite")
except AttributeError:
pytest.skip("sql feature not built")
assert spark.table("_no_session_test").count() == 1
def test_save_as_table_mode_error_append_ignore() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df1 = spark.create_dataframe([(1, 10, "a")], ["id", "v", "name"])
try:
df1.write().saveAsTable("m1", mode="error")
except AttributeError:
pytest.skip("sql feature not built")
assert spark.table("m1").count() == 1
with pytest.raises(Exception, match="already exists"):
df1.write().saveAsTable("m1", mode="error")
df2 = spark.create_dataframe([(2, 20, "b")], ["id", "v", "name"])
df2.write().saveAsTable("m1", mode="append")
assert spark.table("m1").count() == 2
df3 = spark.create_dataframe([(3, 30, "c")], ["id", "v", "name"])
df3.write().saveAsTable("m1", mode="ignore") assert spark.table("m1").count() == 2
def test_phase_a_signature_alignment() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark._create_dataframe_from_rows(
[{"id": 1, "name": "Alice", "start": "2024-01-15", "end": "2024-06-15"}],
[("id", "bigint"), ("name", "string"), ("start", "date"), ("end", "date")],
)
out = df.with_column("pos", rs.position("li", rs.col("name")))
rows = out.collect()
assert rows[0]["pos"] == 2 df2 = spark._create_dataframe_from_rows(
[{"id": 1, "s": "hello%world"}],
[("id", "bigint"), ("s", "string")],
)
matched = df2.filter(rs.like(rs.col("s"), "hello%world"))
assert matched.count() == 1
df3 = spark._create_dataframe_from_rows(
[["2024-06-15", "2024-01-15"]],
[("end", "date"), ("start", "date")],
)
out3 = df3.with_column("mo", rs.months_between(rs.col("end"), rs.col("start")))
assert out3.count() == 1
df4 = spark._create_dataframe_from_rows(
[{"id": 1, "val": 10}, {"id": 2, "val": 20}],
[("id", "bigint"), ("val", "bigint")],
)
out4 = df4.with_column(
"level",
rs.when(rs.col("val").gt(rs.lit(15)))
.then(rs.lit("high"))
.otherwise(rs.lit("low")),
)
rows4 = out4.collect()
assert rows4[0]["level"] == "low" and rows4[1]["level"] == "high"
def test_phase_b_functions() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark._create_dataframe_from_rows(
[{"id": 1, "x": -10, "name": "Alice", "d": "2024-01-15"}],
[("id", "bigint"), ("x", "bigint"), ("name", "string"), ("d", "date")],
)
out = df.with_column("abs_x", rs.abs(rs.col("x")))
assert out.collect()[0]["abs_x"] == 10
out = df.with_column("len", rs.char_length(rs.col("name")))
assert out.collect()[0]["len"] == 5
out = df.with_column("fmt", rs.date_format(rs.col("d"), "yyyy-MM"))
assert out.collect()[0]["fmt"] == "2024-01"
out = df.with_column("plus7", rs.date_add(rs.col("d"), 7))
assert out.collect()[0]["plus7"] == "2024-01-22"
df2 = spark._create_dataframe_from_rows(
[{"a": 1, "b": 10, "c": 100}],
[("a", "bigint"), ("b", "bigint"), ("c", "bigint")],
)
out2 = df2.with_column("arr", rs.array(rs.col("a"), rs.col("b"), rs.col("c")))
rows2 = out2.collect()
assert rows2[0]["arr"] == [1, 10, 100]
out3 = df2.with_column(
"has", rs.array_contains(rs.array(rs.col("a"), rs.col("b")), rs.lit(10))
)
assert out3.collect()[0]["has"] is True
def test_phase_c_reader_writer() -> None:
import tempfile
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe([(1, 10, "a"), (2, 20, "b")], ["id", "x", "label"])
with tempfile.TemporaryDirectory() as tmpdir:
csv_path = f"{tmpdir}/data.csv"
with open(csv_path, "w") as f:
f.write("id,x,label\n1,10,a\n2,20,b\n")
read_df = spark.read().option("header", "true").csv(csv_path)
assert read_df.count() == 2
parquet_path = f"{tmpdir}/out.parquet"
df.write().mode("overwrite").parquet(parquet_path)
back = spark.read().parquet(parquet_path)
assert back.count() == 2
try:
df.createOrReplaceTempView("phase_c_view")
tbl = spark.read().table("phase_c_view")
assert tbl.count() == 2
except (AttributeError, RuntimeError) as e:
if "sql" in str(e).lower():
pytest.skip("sql feature not built")
def test_phase_f_behavioral() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark._create_dataframe_from_rows(
[{"id": 1, "ok": True}],
[("id", "bigint"), ("ok", "boolean")],
)
out = df.with_column("_check", rs.assert_true(rs.col("ok")))
rows = out.collect()
assert len(rows) == 1
assert rows[0]["_check"] is None df2 = spark._create_dataframe_from_rows(
[{"id": 1, "ok": False}],
[("id", "bigint"), ("ok", "boolean")],
)
with pytest.raises(Exception):
df2.with_column("_check", rs.assert_true(rs.col("ok"))).collect()
df3 = spark._create_dataframe_from_rows(
[{"id": 1, "msg": "err"}],
[("id", "bigint"), ("msg", "string")],
)
with pytest.raises(Exception):
df3.with_column("_err", rs.raise_error(rs.col("msg"))).collect()
def test_phase_d_dataframe_methods() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe(
[(1, 10, "a"), (2, 20, "b"), (3, 30, "c")], ["id", "v", "name"]
)
try:
df.createOrReplaceTempView("phase_d_view")
read_back = spark.table("phase_d_view")
assert read_back.count() == 3
except (AttributeError, RuntimeError) as e:
if "sql" in str(e).lower() or "create" in str(e).lower():
pytest.skip("sql feature not built or no default session")
raise
r = df.corr("id", "v")
assert isinstance(r, float)
assert -1 <= r <= 1 or (r != r) c = df.cov("id", "v")
assert isinstance(c, float)
cols = df.columns()
assert cols == ["id", "v", "name"]
renamed = df.toDF(["a", "b", "c"])
assert renamed.columns() == ["a", "b", "c"]
js = df.toJSON()
assert isinstance(js, list) and len(js) == 3
def test_phase_e_spark_session_catalog() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("phase_e").get_or_create()
cat = spark.catalog()
assert cat is not None
assert cat.currentDatabase() == "default"
assert cat.currentCatalog() == "spark_catalog"
assert set(cat.listDatabases(None)) == {"default", "global_temp"}
assert cat.listCatalogs(None) == ["spark_catalog"]
tables = cat.listTables(None)
assert isinstance(tables, list)
assert cat.tableExists("nonexistent", None) is False
cat.dropTempView("x")
conf = spark.conf()
assert conf is not None
assert isinstance(conf.get("spark.app.name"), str)
assert isinstance(conf.getAll(), dict)
ver = spark.version()
assert isinstance(ver, str) and len(ver) > 0
r5 = spark.range(5)
assert r5.count() == 5
rows = r5.collect()
assert [r["id"] for r in rows] == [0, 1, 2, 3, 4]
r2_6 = spark.range(2, 6)
assert r2_6.count() == 4
assert [r["id"] for r in r2_6.collect()] == [2, 3, 4, 5]
r0_10_2 = spark.range(0, 10, 2)
assert [r["id"] for r in r0_10_2.collect()] == [0, 2, 4, 6, 8]
sess2 = spark.newSession()
assert sess2 is not None
assert sess2.version() == ver
active = rs.SparkSession.get_active_session()
assert active is not None
assert active.version() == ver
default = rs.SparkSession.get_default_session()
assert default is not None
spark.stop()
spark2 = rs.SparkSession.builder().app_name("e2").get_or_create()
udf_reg = spark2.udf()
assert udf_reg is not None
df = spark2.create_dataframe([(1, 2, "a")], ["a", "b", "c"])
try:
df.createOrReplaceTempView("phase_e_v")
assert spark2.catalog().tableExists("phase_e_v", None) is True
assert "phase_e_v" in spark2.catalog().listTables(None)
spark2.catalog().dropTempView("phase_e_v")
assert spark2.catalog().tableExists("phase_e_v", None) is False
except (AttributeError, RuntimeError) as e:
if "sql" in str(e).lower():
pytest.skip("sql feature not built")
def test_sparkless_parity_multiple_append_operations() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
a = spark.create_dataframe([(1, 0, "a")], ["id", "_", "name"]).drop(["_"])
b = spark.create_dataframe([(2, 0, "b")], ["id", "_", "name"]).drop(["_"])
c = spark.create_dataframe([(3, 0, "c")], ["id", "_", "name"]).drop(["_"])
combined = a.union(b).union(c)
rows = combined.collect()
assert len(rows) == 3
assert rows[0]["id"] == 1 and rows[1]["id"] == 2 and rows[2]["id"] == 3
def test_sql_select_where_returns_rows() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark.create_dataframe(
[(1, 10, "a"), (2, 20, "b"), (3, 30, "c")], ["id", "v", "name"]
)
try:
spark.create_or_replace_temp_view("t", df)
result = spark.sql("SELECT * FROM t WHERE id > 1")
except AttributeError:
pytest.skip("sql feature not built")
rows = result.collect()
assert len(rows) == 2
assert rows[0]["id"] == 2 and rows[1]["id"] == 3
assert rows[0]["name"] == "b" and rows[1]["name"] == "c"
def test_python_udf_register_and_call() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
def double(x):
return x * 2 if x is not None else None
my_udf = spark.udf().register("double", double, return_type="int")
assert my_udf is not None
df = spark.create_dataframe([(1, 10, "a"), (2, 20, "b")], ["id", "v", "name"])
df2 = df.with_column("doubled", my_udf(rs.col("id")))
rows = df2.collect()
assert rows[0]["doubled"] == 2
assert rows[1]["doubled"] == 4
df3 = df.with_column("d2", rs.call_udf("double", rs.col("id")))
rows3 = df3.collect()
assert rows3[0]["d2"] == 2
def test_vectorized_udf_with_column() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("vec_udf").get_or_create()
def double_list(col):
return [x * 2 if x is not None else None for x in col]
my_udf = spark.udf().register(
"double_vec", double_list, return_type="int", vectorized=True
)
assert my_udf is not None
df = spark.create_dataframe(
[(1, 10, "a"), (2, 20, "b"), (3, 30, "c")], ["id", "v", "name"]
)
df2 = df.with_column("d2", my_udf(rs.col("id")))
rows = df2.collect()
assert [r["d2"] for r in rows] == [2, 4, 6]
def test_vectorized_call_udf() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("vec_udf_call").get_or_create()
def triple_list(col):
return [x * 3 if x is not None else None for x in col]
spark.udf().register("triple_vec", triple_list, return_type="int", vectorized=True)
df = spark.create_dataframe(
[(1, 10, "a"), (2, 20, "b"), (3, 30, "c")], ["id", "v", "name"]
)
df2 = df.with_column("d3", rs.call_udf("triple_vec", rs.col("id")))
rows = df2.collect()
assert [r["d3"] for r in rows] == [3, 6, 9]
def test_grouped_vectorized_pandas_udf_grouped_agg() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("grouped_vec_udf").get_or_create()
def _mean(values):
import ast
xs = []
for v in values:
if v is None:
continue
if isinstance(v, str):
try:
inner = ast.literal_eval(v)
except Exception:
xs.append(float(v))
else:
if isinstance(inner, (list, tuple)):
xs.extend(float(iv) for iv in inner if iv is not None)
else:
xs.append(float(inner))
else:
xs.append(float(v))
if not xs:
return None
return sum(xs) / len(xs)
mean_udf = rs.pandas_udf(_mean, "double", function_type="grouped_agg")
schema = [("k", "bigint"), ("v", "double")]
rows = [
{"k": 1, "v": 10.0},
{"k": 1, "v": 20.0},
{"k": 2, "v": 5.0},
{"k": 2, "v": 15.0},
]
df = spark._create_dataframe_from_rows(rows, schema)
grouped = df.group_by(["k"])
result = grouped.agg([mean_udf(rs.col("v")).alias("mean_v")])
rows = sorted(result.collect(), key=lambda r: r["k"])
assert rows == [
{"k": 1, "mean_v": 15.0},
{"k": 2, "mean_v": 10.0},
]
def test_grouped_vectorized_pandas_udf_mixed_with_builtins_raises() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("grouped_vec_udf_mixed").get_or_create()
def _mean(values):
import ast
xs = []
for v in values:
if v is None:
continue
if isinstance(v, str):
try:
inner = ast.literal_eval(v)
except Exception:
xs.append(float(v))
else:
if isinstance(inner, (list, tuple)):
xs.extend(float(iv) for iv in inner if iv is not None)
else:
xs.append(float(inner))
else:
xs.append(float(v))
if not xs:
return None
return sum(xs) / len(xs)
mean_udf = rs.pandas_udf(_mean, "double", function_type="grouped_agg")
schema = [("k", "bigint"), ("v", "double")]
rows = [
{"k": 1, "v": 10.0},
{"k": 1, "v": 20.0},
{"k": 2, "v": 5.0},
{"k": 2, "v": 15.0},
]
df = spark._create_dataframe_from_rows(rows, schema)
grouped = df.group_by(["k"])
with pytest.raises(NotImplementedError, match="grouped Python UDF aggregations"):
grouped.agg(
[rs.sum(rs.col("v")).alias("sum_v"), mean_udf(rs.col("v")).alias("mean_v")]
)
def test_pandas_udf_wrong_function_type_raises() -> None:
import robin_sparkless as rs
def f(xs):
return xs
with pytest.raises(NotImplementedError, match="only function_type='grouped_agg'"):
rs.pandas_udf(f, "int", function_type="scalar")
def test_grouped_vectorized_pandas_udf_unsupported_with_column_context() -> None:
import robin_sparkless as rs
spark = (
rs.SparkSession.builder().app_name("grouped_vec_udf_bad_ctx").get_or_create()
)
def _mean(values):
import ast
xs = []
for v in values:
if v is None:
continue
if isinstance(v, str):
try:
inner = ast.literal_eval(v)
except Exception:
xs.append(float(v))
else:
if isinstance(inner, (list, tuple)):
xs.extend(float(iv) for iv in inner if iv is not None)
else:
xs.append(float(inner))
else:
xs.append(float(v))
if not xs:
return None
return sum(xs) / len(xs)
mean_udf = rs.pandas_udf(_mean, "double", function_type="grouped_agg")
schema = [("k", "bigint"), ("v", "double")]
rows = [
{"k": 1, "v": 10.0},
{"k": 2, "v": 20.0},
]
df = spark._create_dataframe_from_rows(rows, schema)
with pytest.raises(Exception, match="only supported in groupBy\\(\\)\\.agg"):
df.with_column("bad", mean_udf(rs.col("v"))).collect()
EXPECTED_CREATE_DATAFRAME_FROM_ROWS_PARITY = [
{"id": 1, "name": "Alice", "ok": True, "d": "2024-01-15"},
{"id": 2, "name": "Bob", "ok": False, "d": "2024-06-10"},
]
def test__create_dataframe_from_rows_schema_pyspark_parity() -> None:
import robin_sparkless as rs
schema = [("id", "int"), ("name", "string"), ("ok", "boolean"), ("d", "date")]
rows = [
{"id": 1, "name": "Alice", "ok": True, "d": "2024-01-15"},
{"id": 2, "name": "Bob", "ok": False, "d": "2024-06-10"},
]
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark._create_dataframe_from_rows(rows, schema)
result = sorted(df.collect(), key=lambda r: r["id"])
assert result == EXPECTED_CREATE_DATAFRAME_FROM_ROWS_PARITY
def test_cast_string_to_boolean() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark._create_dataframe_from_rows(
[
{"s": "true"},
{"s": "false"},
{"s": "1"},
{"s": "0"},
{"s": "TRUE"},
{"s": "invalid"},
],
[("s", "string")],
)
result = df.with_column("b", rs.col("s").try_cast("boolean"))
rows = result.collect()
assert rows[0]["b"] is True
assert rows[1]["b"] is False
assert rows[2]["b"] is True
assert rows[3]["b"] is False
assert rows[4]["b"] is True
assert rows[5]["b"] is None
def test_create_dataframe_from_rows_struct_and_array() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
schema = [
("id", "bigint"),
("person", "struct<name:string,age:bigint>"),
("scores", "array<bigint>"),
]
rows = [
{"id": 1, "person": {"name": "Alice", "age": 25}, "scores": [10, 20, 30]},
{"id": 2, "person": {"name": "Bob", "age": 30}, "scores": [5, 15]},
]
df = spark._create_dataframe_from_rows(rows, schema)
result = df.collect()
assert len(result) == 2
assert result[0]["id"] == 1
assert result[0]["person"] == {"name": "Alice", "age": 25}
assert result[0]["scores"] == [10, 20, 30]
assert result[1]["id"] == 2
assert result[1]["person"] == {"name": "Bob", "age": 30}
assert result[1]["scores"] == [5, 15]
def test_regexp_extract_all_and_select_with_expression() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
data = [
{"s": "a1 b22 c333"},
{"s": "no-digits"},
{"s": None},
]
schema = [("s", "string")]
df = spark._create_dataframe_from_rows(data, schema)
result = df.select([rs.regexp_extract_all(rs.col("s"), r"\d+", 0).alias("m")])
rows = result.collect()
assert len(rows) == 3
assert rows[0]["m"] == ["1", "22", "333"]
assert rows[1]["m"] == []
assert rows[2]["m"] is None
result2 = df.select(rs.regexp_extract_all(rs.col("s"), r"\d+", 0).alias("m"))
rows2 = result2.collect()
assert rows2 == rows
result3 = df.select(["s"])
assert result3.collect()[0]["s"] == "a1 b22 c333"
result4 = df.select("s")
assert result4.collect()[0]["s"] == "a1 b22 c333"
def test_pivot_raises_not_implemented() -> None:
import robin_sparkless as rs
spark = rs.SparkSession.builder().app_name("test").get_or_create()
df = spark._create_dataframe_from_rows(
[[1, "x", 10]], [("id", "bigint"), ("pcol", "string"), ("v", "bigint")]
)
with pytest.raises(NotImplementedError, match="pivot is not yet implemented"):
df.pivot("pcol")