from tests.fixtures.parity_base import ParityTestBase
class TestSQLQueriesParity(ParityTestBase):
    """Parity tests for Spark SQL queries run against a saved table.

    Each test follows the same steps:

    1. Load the expected fixture for an ``sql_operations`` scenario.
    2. Write its ``input_data`` to the table ``{table_prefix}_test_table``
       (mode ``"overwrite"``).
    3. Run one SQL query against that table.
    4. Compare the result with the fixture via ``assert_parity``.
    """

    def _setup_sql_table(self, spark, table_prefix, scenario):
        """Load the fixture for *scenario* and save its input data as a table.

        Args:
            spark: The Spark session fixture.
            table_prefix: Prefix used to build the table name.
            scenario: Scenario name under the ``sql_operations`` fixture group.

        Returns:
            A ``(expected, table_name)`` tuple. ``expected`` is the loaded
            fixture object; ``table_name`` is ``f"{table_prefix}_test_table"``.
        """
        expected = self.load_expected("sql_operations", scenario)
        table_name = f"{table_prefix}_test_table"
        # Every test reuses the same table name. "overwrite" replaces whatever
        # an earlier test left behind.
        # NOTE(review): presumably table_prefix isolates parallel runs —
        # confirm against the fixture.
        source_df = spark.createDataFrame(expected["input_data"])
        source_df.write.mode("overwrite").saveAsTable(table_name)
        return expected, table_name

    def test_basic_select(self, spark, table_prefix):
        """Project three explicit columns from the table."""
        expected, tbl = self._setup_sql_table(spark, table_prefix, "basic_select")
        result = spark.sql(f"SELECT id, name, age FROM {tbl}")
        self.assert_parity(result, expected)

    def test_filtered_select(self, spark, table_prefix):
        """Select all columns for rows matching a WHERE predicate."""
        expected, tbl = self._setup_sql_table(spark, table_prefix, "filtered_select")
        result = spark.sql(f"SELECT * FROM {tbl} WHERE age > 30")
        self.assert_parity(result, expected)

    def test_group_by(self, spark, table_prefix):
        """Count rows grouped by a boolean expression (``age > 30``)."""
        expected, tbl = self._setup_sql_table(spark, table_prefix, "group_by")
        # NOTE(review): there is no ORDER BY, so the order of the group rows
        # is not guaranteed by Spark. Confirm assert_parity compares rows
        # order-insensitively, or this test may be flaky.
        result = spark.sql(f"SELECT COUNT(*) as count FROM {tbl} GROUP BY (age > 30)")
        self.assert_parity(result, expected)

    def test_aggregation(self, spark, table_prefix):
        """Compute a global AVG over the ``salary`` column."""
        expected, tbl = self._setup_sql_table(spark, table_prefix, "aggregation")
        result = spark.sql(f"SELECT AVG(salary) as avg_salary FROM {tbl}")
        self.assert_parity(result, expected)