from sparkless.testing import get_imports

# Resolve the Spark entry points once, at import time, through the sparkless
# testing helper so every test in this module shares the same bindings.
# NOTE(review): presumably get_imports() decides which backend (sparkless or
# real PySpark) the names below come from -- confirm in sparkless.testing.
_imports = get_imports()
SparkSession = _imports.SparkSession  # session factory used by every test
F = _imports.F  # column-functions namespace (split, explode, col, ...)
class TestIssue328SplitLimit:
    """Tests for ``F.split`` and its optional ``limit`` argument.

    The class is named after issue 328. Every test starts its own session via
    ``_start_session`` and stops it in a ``finally`` block, so a failing
    assertion cannot leak a running session into the next test.
    """

    def _get_unique_app_name(self, test_name: str) -> str:
        """Return ``test_name`` suffixed with the current process and thread ids.

        Args:
            test_name: Human-readable prefix, normally the test method name.

        Returns:
            A string of the form ``"<test_name>_<pid>_<thread ident>"``.
        """
        import os
        import threading

        thread_id = threading.current_thread().ident
        process_id = os.getpid()
        return f"{test_name}_{process_id}_{thread_id}"

    def _start_session(self):
        """Create (or fetch) a session whose app name identifies the calling test.

        Must be called directly from a test method: the app name is derived
        from the immediate caller's function name.

        Returns:
            The object returned by ``SparkSession.builder...getOrCreate()``.
            The caller is responsible for calling ``stop()`` on it.
        """
        import inspect

        # Frame 0 is this helper, frame 1 is the test method that called it.
        # context=0 skips reading source lines for every frame on the stack.
        caller_name = inspect.stack(context=0)[1].function
        app_name = self._get_unique_app_name(caller_name)
        return SparkSession.builder.appName(app_name).getOrCreate()

    def test_split_with_limit(self):
        """limit=3 yields three parts; the last one keeps the unsplit remainder."""
        spark = self._start_session()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "StringValue": "A,B,C,D,E,F"},
                ]
            )
            df = df.withColumn("StringArray", F.split(F.col("StringValue"), ",", 3))
            df = df.withColumn("StringArray", F.explode(F.col("StringArray")))
            rows = df.collect()
            assert len(rows) == 3
            values = [r["StringArray"] for r in rows]
            assert "A" in values
            assert "B" in values
            assert "C,D,E,F" in values
            # A bare "C" would mean splitting continued past the limit.
            assert "C" not in values
        finally:
            spark.stop()

    def test_split_with_limit_1(self):
        """limit=1 leaves the whole string as the single array element."""
        spark = self._start_session()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": "A,B,C,D"},
                ]
            )
            df = df.withColumn("Array", F.split(F.col("Value"), ",", 1))
            df = df.withColumn("Array", F.explode(F.col("Array")))
            rows = df.collect()
            assert len(rows) == 1
            values = [r["Array"] for r in rows]
            assert "A,B,C,D" in values
        finally:
            spark.stop()

    def test_split_with_limit_2(self):
        """limit=2 splits once: first token plus the remainder."""
        spark = self._start_session()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": "A,B,C,D"},
                ]
            )
            df = df.withColumn("Array", F.split(F.col("Value"), ",", 2))
            df = df.withColumn("Array", F.explode(F.col("Array")))
            rows = df.collect()
            assert len(rows) == 2
            values = [r["Array"] for r in rows]
            assert "A" in values
            assert "B,C,D" in values
        finally:
            spark.stop()

    def test_split_without_limit(self):
        """Omitting ``limit`` splits on every delimiter."""
        spark = self._start_session()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": "A,B,C,D"},
                ]
            )
            df = df.withColumn("Array", F.split(F.col("Value"), ","))
            df = df.withColumn("Array", F.explode(F.col("Array")))
            rows = df.collect()
            assert len(rows) == 4
            values = [r["Array"] for r in rows]
            assert "A" in values
            assert "B" in values
            assert "C" in values
            assert "D" in values
        finally:
            spark.stop()

    def test_split_with_limit_larger_than_splits(self):
        """A limit above the number of tokens does not pad or alter the result."""
        spark = self._start_session()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": "A,B,C"},
                ]
            )
            df = df.withColumn("Array", F.split(F.col("Value"), ",", 10))
            df = df.withColumn("Array", F.explode(F.col("Array")))
            rows = df.collect()
            assert len(rows) == 3
            values = [r["Array"] for r in rows]
            assert "A" in values
            assert "B" in values
            assert "C" in values
        finally:
            spark.stop()

    def test_split_with_limit_minus_one(self):
        """limit=-1 splits on every delimiter, like omitting the limit."""
        spark = self._start_session()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": "A,B,C,D"},
                ]
            )
            df = df.withColumn("Array", F.split(F.col("Value"), ",", -1))
            df = df.withColumn("Array", F.explode(F.col("Array")))
            rows = df.collect()
            assert len(rows) == 4
            values = [r["Array"] for r in rows]
            assert "A" in values
            assert "B" in values
            assert "C" in values
            assert "D" in values
        finally:
            spark.stop()

    def test_split_with_null_values(self):
        """A null input yields a null array; non-null rows are still split."""
        spark = self._start_session()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": "A,B,C"},
                    {"Name": "Bob", "Value": None},
                ]
            )
            df = df.withColumn("Array", F.split(F.col("Value"), ",", 2))
            rows = df.collect()
            assert len(rows) == 2
            row_alice = next(r for r in rows if r["Name"] == "Alice")
            row_bob = next(r for r in rows if r["Name"] == "Bob")
            assert row_alice["Array"] is not None
            assert row_bob["Array"] is None
        finally:
            spark.stop()

    def test_split_with_empty_string(self):
        """An empty input string yields a one-element array holding ``""``."""
        spark = self._start_session()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": ""},
                ]
            )
            df = df.withColumn("Array", F.split(F.col("Value"), ",", 2))
            rows = df.collect()
            assert len(rows) == 1
            row_alice = next(r for r in rows if r["Name"] == "Alice")
            assert row_alice["Array"] == [""]
        finally:
            spark.stop()

    def test_split_in_select(self):
        """``split`` with a limit works inside ``select`` and preserves order."""
        spark = self._start_session()
        try:
            df = spark.createDataFrame(
                [
                    {"Name": "Alice", "Value": "A,B,C,D"},
                ]
            )
            df = df.select(
                "Name",
                F.split(F.col("Value"), ",", 2).alias("Array"),
            )
            rows = df.collect()
            assert len(rows) == 1
            row_alice = next(r for r in rows if r["Name"] == "Alice")
            assert len(row_alice["Array"]) == 2
            assert row_alice["Array"][0] == "A"
            assert row_alice["Array"][1] == "B,C,D"
        finally:
            spark.stop()

    def test_split_multi_char_delimiter(self):
        """A multi-character delimiter is honoured together with the limit."""
        spark = self._start_session()
        try:
            df = spark.createDataFrame(
                [
                    {"Value": "A::B::C::D"},
                ]
            )
            df = df.withColumn("Array", F.split(F.col("Value"), "::", 2))
            df = df.withColumn("Array", F.explode(F.col("Array")))
            rows = df.collect()
            assert len(rows) == 2
            values = [r["Array"] for r in rows]
            assert "A" in values
            assert "B::C::D" in values
        finally:
            spark.stop()

    def test_split_special_regex_characters(self):
        """An escaped regex metacharacter (``\\.``) splits on the literal dot."""
        spark = self._start_session()
        try:
            df = spark.createDataFrame(
                [
                    {"Value": "192.168.1.1"},
                ]
            )
            df = df.withColumn("Array", F.split(F.col("Value"), "\\.", 3))
            df = df.withColumn("Array", F.explode(F.col("Array")))
            rows = df.collect()
            assert len(rows) == 3
            values = [r["Array"] for r in rows]
            assert "192" in values
            assert "168" in values
            assert "1.1" in values
        finally:
            spark.stop()

    def test_split_whitespace_delimiter(self):
        """A single-space delimiter with limit=2 keeps the remaining words joined."""
        spark = self._start_session()
        try:
            df = spark.createDataFrame(
                [
                    {"Value": "one two three four"},
                ]
            )
            df = df.withColumn("Array", F.split(F.col("Value"), " ", 2))
            df = df.withColumn("Array", F.explode(F.col("Array")))
            rows = df.collect()
            assert len(rows) == 2
            values = [r["Array"] for r in rows]
            assert "one" in values
            assert "two three four" in values
        finally:
            spark.stop()

    def test_split_consecutive_delimiters(self):
        """Consecutive delimiters produce an empty token that counts toward the limit."""
        spark = self._start_session()
        try:
            df = spark.createDataFrame(
                [
                    {"Value": "A,,B,,C"},
                ]
            )
            df = df.withColumn("Array", F.split(F.col("Value"), ",", 3))
            df = df.withColumn("Array", F.explode(F.col("Array")))
            rows = df.collect()
            assert len(rows) == 3
            values = [r["Array"] for r in rows]
            assert "A" in values
            assert "" in values
            assert "B,,C" in values
        finally:
            spark.stop()

    def test_split_delimiter_not_found(self):
        """With no delimiter present the array holds the original string only."""
        spark = self._start_session()
        try:
            df = spark.createDataFrame(
                [
                    {"Value": "NoDelimiterHere"},
                ]
            )
            df = df.withColumn("Array", F.split(F.col("Value"), ",", 2))
            rows = df.collect()
            assert len(rows) == 1
            assert rows[0]["Array"] == ["NoDelimiterHere"]
        finally:
            spark.stop()

    def test_split_limit_zero(self):
        """limit=0 splits on every delimiter, like omitting the limit."""
        spark = self._start_session()
        try:
            df = spark.createDataFrame(
                [
                    {"Value": "A,B,C,D"},
                ]
            )
            df = df.withColumn("Array", F.split(F.col("Value"), ",", 0))
            df = df.withColumn("Array", F.explode(F.col("Array")))
            rows = df.collect()
            assert len(rows) == 4
            values = [r["Array"] for r in rows]
            assert "A" in values
            assert "B" in values
            assert "C" in values
            assert "D" in values
        finally:
            spark.stop()

    def test_split_unicode_characters(self):
        """Non-ASCII text is split correctly on an escaped ``|`` delimiter."""
        spark = self._start_session()
        try:
            # The literals were previously stored as mojibake (UTF-8 bytes
            # mis-decoded as Latin-1); they are restored to the intended text.
            df = spark.createDataFrame(
                [
                    {"Value": "José|María|José"},
                ]
            )
            df = df.withColumn("Array", F.split(F.col("Value"), "\\|", 2))
            df = df.withColumn("Array", F.explode(F.col("Array")))
            rows = df.collect()
            assert len(rows) == 2
            values = [r["Array"] for r in rows]
            assert "José" in values
            assert "María|José" in values
        finally:
            spark.stop()

    def test_split_very_long_string(self):
        """A 100-token string with limit=10 yields 10 parts, the last holding the tail."""
        spark = self._start_session()
        try:
            long_str = ",".join(f"item{i}" for i in range(100))
            df = spark.createDataFrame(
                [
                    {"Value": long_str},
                ]
            )
            df = df.withColumn("Array", F.split(F.col("Value"), ",", 10))
            df = df.withColumn("Array", F.explode(F.col("Array")))
            rows = df.collect()
            assert len(rows) == 10
            values = [r["Array"] for r in rows]
            assert "item0" in values
            # "item99" must survive inside the unsplit remainder element.
            assert any("item99" in value for value in values)
        finally:
            spark.stop()

    def test_split_empty_delimiter(self):
        """An empty pattern splits the string into individual characters."""
        spark = self._start_session()
        try:
            df = spark.createDataFrame(
                [
                    {"Value": "ABC"},
                ]
            )
            df = df.withColumn("Array", F.split(F.col("Value"), ""))
            rows = df.collect()
            assert len(rows) == 1
            assert rows[0]["Array"] == ["A", "B", "C"]
        finally:
            spark.stop()

    def test_split_leading_trailing_delimiters(self):
        """Leading delimiters yield an empty first token; the tail keeps its trailing one."""
        spark = self._start_session()
        try:
            df = spark.createDataFrame(
                [
                    {"Value": ",A,B,C,"},
                ]
            )
            df = df.withColumn("Array", F.split(F.col("Value"), ",", 4))
            df = df.withColumn("Array", F.explode(F.col("Array")))
            rows = df.collect()
            assert len(rows) == 4
            values = [r["Array"] for r in rows]
            assert "" in values
            assert "A" in values
            assert "B" in values
            assert "C," in values
        finally:
            spark.stop()

    def test_split_different_limit_values(self):
        """For positive limits the array length is ``min(limit, token count)``."""
        spark = self._start_session()
        try:
            base_value = "A,B,C,D,E"
            # Derived from the input rather than hard-coded, so the expectation
            # stays correct if base_value is ever edited.
            total_parts = len(base_value.split(","))
            df = spark.createDataFrame(
                [
                    {"Value": base_value},
                ]
            )
            for limit in [1, 2, 3, 4, 5, 6, 10]:
                result_df = df.withColumn("Array", F.split(F.col("Value"), ",", limit))
                rows = result_df.collect()
                arr = rows[0]["Array"]
                expected_parts = min(limit, total_parts)
                assert len(arr) == expected_parts, (
                    f"limit={limit}: expected {expected_parts} parts, got {len(arr)}"
                )
        finally:
            spark.stop()

    def test_split_in_filter_context(self):
        """A column derived from a limited split can drive a later ``filter``."""
        spark = self._start_session()
        try:
            df = spark.createDataFrame(
                [
                    {"Value": "A,B,C", "Category": "test1"},
                    {"Value": "X,Y", "Category": "test2"},
                    {"Value": "P,Q,R,S", "Category": "test3"},
                ]
            )
            # The index passed to element_at is 1, selecting the first token.
            df = df.withColumn(
                "First", F.element_at(F.split(F.col("Value"), ",", 2), 1)
            )
            df = df.filter(F.col("First") == "A")
            rows = df.collect()
            assert len(rows) == 1
            assert rows[0]["Category"] == "test1"
        finally:
            spark.stop()