from tests.tools.parity_base import ParityTestBase
from sparkless.testing import get_imports
class TestArrayFunctionsParity(ParityTestBase):
    """Parity tests for array functions against recorded expected outputs.

    Every test loads a fixture from the ``"arrays"`` group via
    ``load_expected``, builds a DataFrame from the fixture's ``"input_data"``,
    applies one array function, and compares the result with
    ``assert_parity``.
    """

    # Scalar searched for / removed by the contains, position and remove tests.
    # NOTE(review): presumably this must match the value the fixtures were
    # recorded with — confirm before changing it.
    _SEARCH_VALUE = 90

    def _setup_case(self, spark, case):
        """Return ``(F, expected, df)`` for the ``"arrays"`` fixture *case*.

        ``F`` is the functions namespace from ``get_imports()``, ``expected``
        is the loaded fixture, and ``df`` is built from its ``"input_data"``
        with *spark*. The calls run in the same order as the original
        per-test setup.
        """
        F = get_imports().F
        expected = self.load_expected("arrays", case)
        df = spark.createDataFrame(expected["input_data"])
        return F, expected, df

    def test_array_contains(self, spark):
        """``array_contains`` on the ``scores`` column."""
        F, expected, df = self._setup_case(spark, "array_contains")
        result = df.select(F.array_contains(df.scores, self._SEARCH_VALUE))
        self.assert_parity(result, expected)

    def test_array_position(self, spark):
        """``array_position`` on the ``scores`` column."""
        F, expected, df = self._setup_case(spark, "array_position")
        result = df.select(F.array_position(df.scores, self._SEARCH_VALUE))
        self.assert_parity(result, expected)

    def test_size(self, spark):
        """``size`` of the ``scores`` column."""
        F, expected, df = self._setup_case(spark, "size")
        result = df.select(F.size(df.scores))
        self.assert_parity(result, expected)

    def test_element_at(self, spark):
        """``element_at`` index 2 of the ``scores`` column."""
        F, expected, df = self._setup_case(spark, "element_at")
        result = df.select(F.element_at(df.scores, 2))
        self.assert_parity(result, expected)

    def test_explode(self, spark):
        """``explode`` of ``scores`` alongside ``name``."""
        F, expected, df = self._setup_case(spark, "explode")
        result = df.select(df.name, F.explode(df.scores).alias("score"))
        self.assert_parity(result, expected)

    def test_array_distinct(self, spark):
        """``array_distinct`` on ``tags``, re-sorted before comparison."""
        F, expected, df = self._setup_case(spark, "array_distinct")
        result = df.select(F.array_distinct(df.tags))
        # The first select's output is referenced by its default column name,
        # then sorted and re-aliased to that same name. Sorting presumably
        # makes the comparison independent of the element order
        # array_distinct produces — TODO confirm against the fixture.
        result = result.select(
            F.array_sort(F.col("array_distinct(tags)")).alias("array_distinct(tags)")
        )
        self.assert_parity(result, expected)

    def test_array_join(self, spark):
        """``array_join`` of ``arr1`` with a ``"-"`` delimiter."""
        F, expected, df = self._setup_case(spark, "array_join")
        result = df.select(F.array_join(df.arr1, "-"))
        self.assert_parity(result, expected)

    def test_array_union(self, spark):
        """``array_union`` of ``arr1`` and ``arr2``."""
        F, expected, df = self._setup_case(spark, "array_union")
        result = df.select(F.array_union(df.arr1, df.arr2))
        self.assert_parity(result, expected)

    def test_array_sort(self, spark):
        """``array_sort`` on ``arr3``."""
        F, expected, df = self._setup_case(spark, "array_sort")
        result = df.select(F.array_sort(df.arr3))
        self.assert_parity(result, expected)

    def test_array_remove(self, spark):
        """``array_remove`` of the search value from ``scores``."""
        F, expected, df = self._setup_case(spark, "array_remove")
        result = df.select(F.array_remove(df.scores, self._SEARCH_VALUE))
        self.assert_parity(result, expected)