from __future__ import annotations
def test_rdd_flatmap_words(spark) -> None:
    """Splitting each line with flatMap yields every word, in row order."""
    lines = [("hello world",), ("hello pyspark",), ("flatMap is useful",)]
    frame = spark.createDataFrame(lines, ["line"])

    def to_words(row):
        # Each row contributes zero or more words; flatMap concatenates them.
        return row["line"].split()

    expected = ["hello", "world", "hello", "pyspark", "flatMap", "is", "useful"]
    assert frame.rdd.flatMap(to_words).collect() == expected
def test_rdd_flatmap_empty_iterable(spark) -> None:
    """A row mapped to an empty iterable contributes nothing to the output."""
    frame = spark.createDataFrame([(1,), (0,), (2,)], ["x"])

    def expand(row):
        n = row["x"]
        # Non-positive counts return an explicit empty list rather than a range,
        # so flatMap is exercised with both iterable kinds.
        if n > 0:
            return range(n)
        return []

    # 1 -> [0], 0 -> [], 2 -> [0, 1]
    assert frame.rdd.flatMap(expand).collect() == [0, 0, 1]
def test_rdd_flatmap_then_map(spark) -> None:
    """map() applied after flatMap() operates on each flattened element."""
    frame = spark.createDataFrame([("a b",), ("c",)], ["line"])
    tokens = frame.rdd.flatMap(lambda r: r["line"].split())
    shouted = tokens.map(lambda token: token.upper())
    assert shouted.collect() == ["A", "B", "C"]
def test_rdd_flatmap_empty_rdd(spark) -> None:
    """flatMap over a DataFrame with no rows yields an empty RDD."""
    # get_spark_imports supplies the schema types (StructType, StructField,
    # StringType) used to build an explicit schema for the empty DataFrame.
    from tests.fixtures.spark_imports import get_spark_imports

    types = get_spark_imports()
    line_field = types.StructField("line", types.StringType())
    frame = spark.createDataFrame([], types.StructType([line_field]))

    words = frame.rdd.flatMap(lambda r: r["line"].split())
    assert words.collect() == []
    assert words.count() == 0
def test_rdd_flatmap_one_element_per_row(spark) -> None:
    """Returning a one-element tuple per row makes flatMap act like map."""
    frame = spark.createDataFrame([(1,), (2,), (3,)], ["x"])
    scaled = frame.rdd.flatMap(lambda r: (r["x"] * 10,)).collect()
    assert scaled == [10, 20, 30]
def test_rdd_flatmap_tuples_for_pair_ops(spark) -> None:
    """flatMap can emit (word, 1) pairs, the usual input shape for pair-RDD ops."""
    frame = spark.createDataFrame([("a b",), ("a c",)], ["line"])

    def pairs(row):
        # A generator is an acceptable iterable return value for flatMap.
        for word in row["line"].split():
            yield word, 1

    # Sort before comparing so the check does not depend on output order.
    result = sorted(frame.rdd.flatMap(pairs).collect())
    assert result == [("a", 1), ("a", 1), ("b", 1), ("c", 1)]
def test_rdd_flatmap_then_filter(spark) -> None:
    """filter() after flatMap() sees every emitted variant of each word."""
    frame = spark.createDataFrame([("one",), ("two",), ("three",)], ["word"])
    variants = frame.rdd.flatMap(lambda r: [r["word"], r["word"].upper()])

    def keep(token):
        # Keep uppercase variants and any token longer than three characters.
        return token.isupper() or len(token) > 3

    survivors = sorted(variants.filter(keep).collect())
    # Only lowercase "one" and "two" are dropped. In the sorted result the
    # uppercase strings come before "three".
    assert survivors == ["ONE", "THREE", "TWO", "three"]
def test_rdd_flatmap_then_count_take_first(spark) -> None:
    """count/take/first operate on the flattened elements, not the input rows."""
    df = spark.createDataFrame([("x y",), ("z",)], ["line"])
    rdd = df.rdd.flatMap(lambda row: row["line"].split())
    # Two input rows flatten into three words.
    assert rdd.count() == 3
    # Run take(2) once. Comparing against the full expected list also pins its
    # length, so a separate len() check (and the extra job it triggers) is
    # redundant.
    head = rdd.take(2)
    assert head == ["x", "y"]
    assert rdd.first() == "x"
def test_rdd_flatmap_empty_string_split(spark) -> None:
    """An empty line splits to [] and so contributes no elements."""
    frame = spark.createDataFrame([("a b",), ("",), ("c",)], ["line"])
    tokens = frame.rdd.flatMap(lambda r: r["line"].split())
    assert tokens.collect() == ["a", "b", "c"]
def test_rdd_flatmap_then_reduce(spark) -> None:
    """reduce() after flatMap() folds over every emitted element."""
    frame = spark.createDataFrame([(1,), (2,), (3,)], ["x"])
    # Each row x contributes both x and x + 1.
    values = frame.rdd.flatMap(lambda r: (r["x"], r["x"] + 1))
    expected = sum(x + (x + 1) for x in (1, 2, 3))  # 15
    assert values.reduce(lambda left, right: left + right) == expected
def test_rdd_flatmap_chain_double_flatmap(spark) -> None:
    """Chained flatMaps expand each word into its original and uppercase forms."""
    frame = spark.createDataFrame([("ab cd",)], ["line"])
    words = frame.rdd.flatMap(lambda r: r["line"].split())
    doubled = words.flatMap(lambda w: [w, w.upper()])
    # Already-sorted expectation: uppercase sorts before lowercase.
    assert sorted(doubled.collect()) == ["AB", "CD", "ab", "cd"]
def test_rdd_flatmap_preserves_order(spark) -> None:
    """Elements emitted per row stay together and follow row order."""
    frame = spark.createDataFrame([("1",), ("2",), ("3",)], ["x"])

    def with_suffix(row):
        value = row["x"]
        return [value, value + "x"]

    expected = ["1", "1x", "2", "2x", "3", "3x"]
    assert frame.rdd.flatMap(with_suffix).collect() == expected