import pytest
from sparkless.testing import (
Mode,
get_imports,
assert_dataframes_equal,
)
class TestUnifiedInfrastructure:
    """Smoke tests for the shared ``spark`` fixture, with and without markers."""

    @staticmethod
    def _single_id_frame(spark):
        """Return a one-row DataFrame built from ``{"id": 1}``."""
        return spark.createDataFrame([{"id": 1}])

    def test_basic_operation(self, spark):
        """A one-row DataFrame reports one row and the columns ``id``, ``name``."""
        records = [{"id": 1, "name": "Alice"}]
        frame = spark.createDataFrame(records)

        assert frame.count() == 1
        assert frame.columns == ["id", "name"]

    # NOTE(review): presumably limits this test to the sparkless backend --
    # confirm against the marker's definition in the test plugin.
    @pytest.mark.sparkless_only
    def test_sparkless_only(self, spark):
        """A one-row DataFrame can be created and counted."""
        frame = self._single_id_frame(spark)
        assert frame.count() == 1

    # NOTE(review): presumably limits this test to the PySpark backend --
    # confirm against the marker's definition in the test plugin.
    @pytest.mark.pyspark_only
    def test_pyspark_only(self, spark):
        """A one-row DataFrame can be created and counted."""
        frame = self._single_id_frame(spark)
        assert frame.count() == 1

    @pytest.mark.backend("sparkless")
    def test_with_backend_marker_sparkless(self, spark):
        """Same check, carrying the ``backend("sparkless")`` marker."""
        frame = self._single_id_frame(spark)
        assert frame.count() == 1

    @pytest.mark.backend("pyspark")
    def test_with_backend_marker_pyspark(self, spark):
        """Same check, carrying the ``backend("pyspark")`` marker."""
        frame = self._single_id_frame(spark)
        assert frame.count() == 1

    def test_with_mode_info(self, spark, spark_mode):
        """``spark_mode`` is one of the two ``Mode`` members and has a known value."""
        frame = self._single_id_frame(spark)
        assert frame.count() == 1

        known_modes = [Mode.SPARKLESS, Mode.PYSPARK]
        assert spark_mode in known_modes

        known_values = ["sparkless", "pyspark"]
        assert spark_mode.value in known_values
class TestUnifiedImports:
    """Tests that use the namespace returned by ``get_imports()``."""

    def test_with_unified_imports(self, spark):
        """``F.upper`` from the unified imports upper-cases a string column."""
        imports = get_imports()
        F = imports.F

        df = spark.createDataFrame([{"id": 1, "name": "test"}])
        result = df.select(F.upper("name").alias("upper_name"))

        assert result.collect()[0]["upper_name"] == "TEST"

    def test_with_data_types(self, spark):
        """A DataFrame built with an explicit two-field schema keeps both fields."""
        imports = get_imports()
        schema = imports.StructType(
            [
                imports.StructField("id", imports.IntegerType(), True),
                imports.StructField("name", imports.StringType(), True),
            ]
        )

        df = spark.createDataFrame([{"id": 1, "name": "Alice"}], schema=schema)

        assert df.count() == 1
        assert len(df.schema.fields) == 2

    def test_with_window_functions(self, spark):
        """``row_number`` over a partitioned, ordered window numbers each partition.

        The row count alone would pass even if ``row_number`` returned wrong
        values, so the per-partition numbering is asserted as well.
        """
        imports = get_imports()
        F = imports.F
        Window = imports.Window

        data = [
            {"category": "A", "value": 10},
            {"category": "A", "value": 20},
            {"category": "B", "value": 30},
        ]
        df = spark.createDataFrame(data)

        window = Window.partitionBy("category").orderBy("value")
        result = df.withColumn("row_num", F.row_number().over(window))

        rows = result.collect()
        assert len(rows) == 3

        # Key by (category, value) so the check does not depend on the order
        # in which collect() returns the rows.
        row_numbers = {
            (row["category"], row["value"]): row["row_num"] for row in rows
        }
        assert row_numbers == {("A", 10): 1, ("A", 20): 2, ("B", 30): 1}
class TestDataFrameComparison:
    """Tests that call the ``assert_dataframes_equal`` helper."""

    def test_dataframe_comparison(self, spark):
        """Two DataFrames built from identical rows compare equal."""
        left_rows = [
            {"id": 1, "value": 10.0},
            {"id": 2, "value": 20.0},
        ]
        left = spark.createDataFrame(left_rows)

        right_rows = [
            {"id": 1, "value": 10.0},
            {"id": 2, "value": 20.0},
        ]
        right = spark.createDataFrame(right_rows)

        assert_dataframes_equal(left, right)

    def test_comparison_with_tolerance(self, spark):
        """Values 1e-7 apart are accepted when ``tolerance=1e-6`` is passed."""
        left_rows = [{"id": 1, "value": 10.0000001}]
        left = spark.createDataFrame(left_rows)

        right_rows = [{"id": 1, "value": 10.0}]
        right = spark.createDataFrame(right_rows)

        assert_dataframes_equal(left, right, tolerance=1e-6)

    def test_comparison_ignore_order(self, spark):
        """The same rows in opposite order are accepted with ``check_order=False``."""
        left_rows = [
            {"id": 1, "value": 10},
            {"id": 2, "value": 20},
        ]
        left = spark.createDataFrame(left_rows)

        right_rows = [
            {"id": 2, "value": 20},
            {"id": 1, "value": 10},
        ]
        right = spark.createDataFrame(right_rows)

        assert_dataframes_equal(left, right, check_order=False)