import os
import pytest
from tests.fixtures.comparison import assert_dataframes_equal
# Skip marker for tests that need the mock backend (alone or together with
# PySpark). Such tests are skipped when SPARKLESS_TEST_BACKEND selects a
# PySpark-only run. The value is compared after stripping surrounding
# whitespace and lower-casing; an unset variable is treated as "".
_skip_if_pyspark_only = pytest.mark.skipif(
    os.environ.get("SPARKLESS_TEST_BACKEND", "").strip().lower() == "pyspark",
    reason="Requires mock/both backends; skipped in PySpark-only run",
)
class TestUnifiedInfrastructure:
    """Exercise the backend-aware fixtures and the ``backend`` markers.

    Covers the ``spark`` / ``spark_backend`` fixtures and the paired
    ``mock_spark_session`` / ``pyspark_session`` fixtures.
    """

    def test_basic_operation(self, spark):
        """A one-row DataFrame reports one row and exactly the columns id, name."""
        frame = spark.createDataFrame([{"id": 1, "name": "Alice"}])
        assert frame.count() == 1
        expected_columns = ["id", "name"]
        assert frame.columns == expected_columns

    @pytest.mark.backend("mock")
    @_skip_if_pyspark_only
    def test_mock_only(self, spark):
        """Marked for the mock backend; skipped when SPARKLESS_TEST_BACKEND is pyspark."""
        frame = spark.createDataFrame([{"id": 1}])
        assert frame.count() == 1

    @pytest.mark.backend("pyspark")
    def test_pyspark_only(self, spark):
        """Marked for the pyspark backend (marker presumably interpreted by conftest)."""
        frame = spark.createDataFrame([{"id": 1}])
        assert frame.count() == 1

    @pytest.mark.backend("both")
    @_skip_if_pyspark_only
    def test_comparison(self, mock_spark_session, pyspark_session):
        """Mock and PySpark agree on a created DataFrame and on a simple filter."""
        rows = [{"id": 1, "value": 10}, {"id": 2, "value": 20}]
        mock_frame = mock_spark_session.createDataFrame(rows)
        real_frame = pyspark_session.createDataFrame(rows)
        assert_dataframes_equal(mock_frame, real_frame)

        # Keep rows with id > 1 on each backend, then compare the results.
        assert_dataframes_equal(
            mock_frame.filter(mock_frame.id > 1),
            real_frame.filter(real_frame.id > 1),
        )

    @pytest.mark.backend("both")
    @_skip_if_pyspark_only
    def test_aggregation_comparison(self, mock_spark_session, pyspark_session):
        """Both backends yield the same per-category sum of ``value``."""
        # Imported inside the test so these modules are only resolved when it runs.
        from tests.fixtures.spark_backend import BackendType
        from tests.fixtures.spark_imports import get_spark_imports

        mock_F = get_spark_imports(BackendType.MOCK).F
        pyspark_F = get_spark_imports(BackendType.PYSPARK).F

        rows = [
            {"category": category, "value": value}
            for category, value in (("A", 10), ("A", 20), ("B", 30))
        ]
        mock_frame = mock_spark_session.createDataFrame(rows)
        real_frame = pyspark_session.createDataFrame(rows)

        def total_per_category(frame, functions):
            # Each backend must use its own functions module (``F``).
            return frame.groupBy("category").agg(
                functions.sum("value").alias("total")
            )

        assert_dataframes_equal(
            total_per_category(mock_frame, mock_F),
            total_per_category(real_frame, pyspark_F),
            tolerance=1e-6,
            check_schema=False,
            # Row order after groupBy is not relied upon.
            check_order=False,
        )

    def test_with_backend_info(self, spark, spark_backend):
        """The active backend's ``value`` is one of the known backend names."""
        frame = spark.createDataFrame([{"id": 1}])
        assert frame.count() == 1
        assert spark_backend.value in ("mock", "pyspark", "robin")
class TestUnifiedImports:
    """Check that the unified import helpers can be called alongside ``spark``."""

    def test_with_unified_imports(self, spark):
        """``get_imports()`` unpacks into exactly three objects."""
        from tests.fixtures.spark_imports import get_imports

        # Only the 3-way unpacking is exercised; the bound names are unused.
        _session_cls, _functions, _struct_type = get_imports()
        frame = spark.createDataFrame([{"id": 1}])
        assert frame.count() == 1

    def test_with_full_imports_object(self, spark):
        """``get_spark_imports()`` can be called with no arguments."""
        from tests.fixtures.spark_imports import get_spark_imports

        get_spark_imports()  # result deliberately discarded
        frame = spark.createDataFrame([{"id": 1}])
        assert frame.count() == 1