from tests.tools.parity_base import ParityTestBase
class TestSetOperationsParity(ParityTestBase):
    """Parity tests for the DataFrame set operations.

    Every test follows the same recipe: load the recorded expectation for
    one operation, split its ``input_data`` rows into two operands, build a
    DataFrame from each, apply the operation, and hand the result to
    ``assert_parity`` together with the expectation.
    """

    @staticmethod
    def _split_rows(rows):
        """Split *rows* into the (left, right) operand row lists.

        With six or more rows the operands are rows 0-2 and rows 3-5 (any
        further rows are ignored); with fewer rows the list is cut at its
        midpoint, the right operand receiving the extra row when the count
        is odd.
        """
        if len(rows) >= 6:
            return rows[:3], rows[3:6]
        midpoint = len(rows) // 2
        return rows[:midpoint], rows[midpoint:]

    def _load_operands(self, spark, operation):
        """Return ``(expected, left_df, right_df)`` for *operation*.

        ``expected`` is whatever ``load_expected("set_operations", operation)``
        returns; the two DataFrames are created from the halves of its
        ``"input_data"`` entry as produced by ``_split_rows``.
        """
        expected = self.load_expected("set_operations", operation)
        left_rows, right_rows = self._split_rows(expected["input_data"])
        left_df = spark.createDataFrame(left_rows)
        right_df = spark.createDataFrame(right_rows)
        return expected, left_df, right_df

    def test_union(self, spark):
        """``union`` of the two operands matches the "union" expectation."""
        expected, left_df, right_df = self._load_operands(spark, "union")
        self.assert_parity(left_df.union(right_df), expected)

    def test_union_all(self, spark):
        """``union`` of the two operands matches the "union_all" expectation."""
        expected, left_df, right_df = self._load_operands(spark, "union_all")
        # NOTE(review): this calls ``union`` rather than ``unionAll``. In
        # PySpark the two are documented as equivalent, but ``unionAll``
        # itself is never exercised here - confirm that is intended.
        self.assert_parity(left_df.union(right_df), expected)

    def test_intersect(self, spark):
        """``intersect`` of the two operands matches the "intersect" expectation."""
        expected, left_df, right_df = self._load_operands(spark, "intersect")
        self.assert_parity(left_df.intersect(right_df), expected)

    def test_except(self, spark):
        """``exceptAll`` of the two operands matches the "except" expectation."""
        expected, left_df, right_df = self._load_operands(spark, "except")
        # NOTE(review): the fixture is named "except" but the call is
        # ``exceptAll`` (duplicate-preserving), not ``subtract`` (EXCEPT
        # DISTINCT). Confirm which one the fixture was generated with.
        self.assert_parity(left_df.exceptAll(right_df), expected)