import pytest
from sparkless.testing import get_imports
# Resolve the DataFrame API objects once at import time through the
# sparkless testing helper, then bind them to module-level aliases.
# NOTE(review): presumably get_imports() selects which backend is under
# test; confirm in sparkless.testing.
_imports = get_imports()
# Session entry point used by the fixture to build a session.
SparkSession = _imports.SparkSession
# Column-function namespace (col, lit, when, sum, count, concat, coalesce).
F = _imports.F
# Schema building blocks used to declare the explicitly typed empty frames.
StructType = _imports.StructType
StructField = _imports.StructField
StringType = _imports.StringType
BooleanType = _imports.BooleanType
class TestDoubleJoinEmptyAggregated:
    """Checks that chained left joins keep the columns of aggregated empty frames."""

    @pytest.fixture(autouse=True)
    def setup_method(self):
        """Attach a session to the test instance and stop it once the test is done."""
        session = SparkSession.builder.appName("double_join_bug").getOrCreate()
        self.spark = session
        yield
        self.spark.stop()

    def _empty_frame(self, flag_column, label_column):
        """Build a zero-row frame with an explicit three-column schema.

        The schema is a non-nullable string ``patient_id``, a nullable
        boolean named ``flag_column`` and a nullable string named
        ``label_column``.
        """
        fields = [
            StructField("patient_id", StringType(), False),
            StructField(flag_column, BooleanType(), True),
            StructField(label_column, StringType(), True),
        ]
        return self.spark.createDataFrame([], StructType(fields))

    def test_columns_preserved_in_double_join_with_empty_aggregated(self):
        """Left-join two aggregated empty frames onto patients and verify the result.

        Verifies that every aggregate column is present after each join,
        that no column name is duplicated, that a derived coalesce column
        can be added, and that both patient rows come back with ``None``
        or ``0`` in the metric columns.
        """
        base_columns = (
            "patient_id",
            "first_name",
            "last_name",
            "age",
            "age_group",
            "gender",
            "insurance_provider",
        )
        lab_metric_names = ("total_labs", "abnormal_labs", "critical_labs")
        diagnosis_metric_names = (
            "total_diagnoses",
            "chronic_conditions",
            "risk_score_sum",
        )

        # Two patients, plus a derived full_name column.
        people = self.spark.createDataFrame(
            [
                ("PAT-001", "John", "Doe", 25, "adult", "M", "ProviderA"),
                ("PAT-002", "Jane", "Smith", 30, "adult", "F", "ProviderB"),
            ],
            list(base_columns),
        )
        display_name = F.concat(F.col("first_name"), F.lit(" "), F.col("last_name"))
        people = people.withColumn("full_name", display_name)

        # Both fact frames are intentionally empty but fully typed.
        labs = self._empty_frame("is_abnormal", "result_category")
        diagnoses = self._empty_frame("is_chronic", "risk_level")

        # Per-patient lab aggregates over the empty frame.
        labs_by_patient = labs.groupBy("patient_id")
        lab_count = F.count("*").alias("total_labs")
        abnormal_flag = F.when(F.col("is_abnormal"), 1).otherwise(0)
        critical_flag = F.when(
            F.col("result_category").isin(["critical_high", "critical_low"]),
            1,
        ).otherwise(0)
        lab_summary = labs_by_patient.agg(
            lab_count,
            F.sum(abnormal_flag).alias("abnormal_labs"),
            F.sum(critical_flag).alias("critical_labs"),
        )

        # Per-patient diagnosis aggregates over the empty frame.
        diagnoses_by_patient = diagnoses.groupBy("patient_id")
        diagnosis_count = F.count("*").alias("total_diagnoses")
        chronic_flag = F.when(F.col("is_chronic"), 1).otherwise(0)
        risk_weight = (
            F.when(F.col("risk_level") == "high", 3)
            .when(F.col("risk_level") == "medium", 2)
            .otherwise(1)
        )
        diagnosis_summary = diagnoses_by_patient.agg(
            diagnosis_count,
            F.sum(chronic_flag).alias("chronic_conditions"),
            F.sum(risk_weight).alias("risk_score_sum"),
        )

        # The aggregate columns must exist even though no rows were aggregated.
        for metric in lab_metric_names:
            assert metric in lab_summary.columns
        for metric in diagnosis_metric_names:
            assert metric in diagnosis_summary.columns

        # First join: the lab aggregates must survive.
        with_labs = people.join(lab_summary, "patient_id", "left")
        for metric in lab_metric_names:
            assert metric in with_labs.columns

        # Second join: every column from all three inputs must be present.
        result2 = with_labs.join(diagnosis_summary, "patient_id", "left")
        expected_columns = [
            *base_columns,
            "full_name",
            *lab_metric_names,
            *diagnosis_metric_names,
        ]
        for col_name in expected_columns:
            assert col_name in result2.columns, (
                f"Column '{col_name}' is missing from result. "
                f"Available columns: {result2.columns}"
            )
        assert len(result2.columns) == len(set(result2.columns)), (
            f"Duplicate columns found: {result2.columns}"
        )

        # A derived column that reads one of the joined aggregate columns.
        fallback_score = F.coalesce(F.col("risk_score_sum"), F.lit(0))
        with_score = result2.withColumn("overall_risk_score", fallback_score)
        assert "overall_risk_score" in with_score.columns

        collected = with_score.collect()
        assert len(collected) == 2

        required_keys = (
            "patient_id",
            "total_labs",
            "total_diagnoses",
            "risk_score_sum",
            "overall_risk_score",
        )
        empty_valued_keys = (
            "total_labs",
            "total_diagnoses",
            "risk_score_sum",
            "overall_risk_score",
        )
        for row in collected:
            for key in required_keys:
                assert key in row
            # No fact rows matched, so each metric is accepted as None or 0.
            for key in empty_valued_keys:
                assert row[key] in [None, 0]