# The Spark API names used by the tests below are read off the object returned
# by get_imports() instead of being imported directly.
# NOTE(review): presumably this lets the same tests run against different
# Spark-compatible backends — confirm against sparkless.testing.
from sparkless.testing import get_imports
imports = get_imports()
SparkSession = imports.SparkSession
StringType = imports.StringType
StructType = imports.StructType
StructField = imports.StructField
F = imports.F
class TestColumnSubstr:
    """Tests for ``Column.substr(start, length)`` on string columns.

    Each test takes a ``spark`` argument (presumably a pytest fixture that
    yields a session), builds a small DataFrame of nullable string columns,
    applies ``substr`` and checks the collected rows.
    """

    @staticmethod
    def _string_df(spark, column, values):
        """Return a DataFrame with a single nullable string column.

        Args:
            spark: Session object providing ``createDataFrame``.
            column: Name of the only column.
            values: Cell values (``str`` or ``None``); one row per value, in
                the given order.
        """
        schema = StructType([StructField(column, StringType(), True)])
        return spark.createDataFrame(
            [{column: value} for value in values],
            schema=schema,
        )

    def test_basic_substr(self, spark):
        """``substr(1, 2)`` yields the first two characters of each name."""
        df = self._string_df(spark, "name", ["Alice", "Bob", "Charlie"])
        result = df.select(F.col("name").substr(1, 2).alias("partial_name"))
        rows = result.collect()
        assert len(rows) == 3
        assert rows[0]["partial_name"] == "Al"
        assert rows[1]["partial_name"] == "Bo"
        assert rows[2]["partial_name"] == "Ch"

    def test_substr_from_second_position(self, spark):
        """A start of 2 with an oversized length returns the tail of the value."""
        df = self._string_df(spark, "name", ["Alice", "Bob"])
        result = df.select(F.col("name").substr(2, 100).alias("from_second"))
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["from_second"] == "lice"
        assert rows[1]["from_second"] == "ob"

    def test_substr_issue_238_example(self, spark):
        """``substr`` works as the expression passed to ``withColumn``."""
        df = self._string_df(spark, "name", ["Alice", "Bob"])
        result = df.withColumn("partial_name", F.col("name").substr(1, 2))
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["partial_name"] == "Al"
        assert rows[1]["partial_name"] == "Bo"

    def test_substr_start_at_one(self, spark):
        """Position 1 refers to the first character."""
        df = self._string_df(spark, "text", ["Hello", "World"])
        result = df.select(F.col("text").substr(1, 3).alias("first_three"))
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["first_three"] == "Hel"
        assert rows[1]["first_three"] == "Wor"

    def test_substr_start_beyond_length(self, spark):
        """A start past the end of every value yields empty strings."""
        df = self._string_df(spark, "text", ["Hi", "No"])
        result = df.select(F.col("text").substr(10, 5).alias("beyond"))
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["beyond"] == ""
        assert rows[1]["beyond"] == ""

    def test_substr_with_null(self, spark):
        """A null input stays null; neighbouring rows are unaffected."""
        df = self._string_df(spark, "name", ["Alice", None, "Bob"])
        result = df.select(F.col("name").substr(1, 2).alias("partial"))
        rows = result.collect()
        assert len(rows) == 3
        assert rows[0]["partial"] == "Al"
        assert rows[1]["partial"] is None
        assert rows[2]["partial"] == "Bo"

    def test_substr_length_zero(self, spark):
        """A length of 0 yields an empty string."""
        df = self._string_df(spark, "text", ["Hello"])
        result = df.select(F.col("text").substr(1, 0).alias("empty"))
        rows = result.collect()
        assert len(rows) == 1
        assert rows[0]["empty"] == ""

    def test_substr_length_exceeds_remaining(self, spark):
        """A length longer than the value returns the whole value."""
        df = self._string_df(spark, "text", ["Hi"])
        result = df.select(F.col("text").substr(1, 100).alias("long"))
        rows = result.collect()
        assert len(rows) == 1
        assert rows[0]["long"] == "Hi"

    def test_substr_in_select(self, spark):
        """Several ``substr`` expressions can sit next to the source column."""
        df = self._string_df(spark, "name", ["Alice", "Bob"])
        result = df.select(
            F.col("name"),
            F.col("name").substr(1, 2).alias("first_two"),
            F.col("name").substr(3, 100).alias("from_third"),
        )
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["name"] == "Alice"
        assert rows[0]["first_two"] == "Al"
        assert rows[0]["from_third"] == "ice"
        assert rows[1]["name"] == "Bob"
        assert rows[1]["first_two"] == "Bo"
        assert rows[1]["from_third"] == "b"

    def test_substr_in_withColumn(self, spark):
        """``withColumn`` can add a one-character prefix column."""
        df = self._string_df(spark, "name", ["Alice", "Bob"])
        result = df.withColumn("first_char", F.col("name").substr(1, 1))
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["first_char"] == "A"
        assert rows[1]["first_char"] == "B"

    def test_substr_in_filter(self, spark):
        """A ``substr`` comparison can be used as a filter predicate."""
        df = self._string_df(spark, "name", ["Alice", "Bob", "Charlie"])
        result = df.filter(F.col("name").substr(1, 1) == "A")
        rows = result.collect()
        assert len(rows) == 1
        assert rows[0]["name"] == "Alice"

    def test_substr_in_orderBy(self, spark):
        """Rows can be ordered by their first character."""
        df = self._string_df(spark, "name", ["Charlie", "Alice", "Bob"])
        result = df.orderBy(F.col("name").substr(1, 1))
        rows = result.collect()
        assert len(rows) == 3
        assert rows[0]["name"] == "Alice"
        assert rows[1]["name"] == "Bob"
        assert rows[2]["name"] == "Charlie"

    def test_substr_equals_substring_function(self, spark):
        """``Column.substr`` and ``F.substring`` agree for the same arguments."""
        df = self._string_df(spark, "name", ["Alice", "Bob"])
        result_substr = df.select(F.col("name").substr(1, 2).alias("partial"))
        result_substring = df.select(
            F.substring(F.col("name"), 1, 2).alias("partial")
        )
        rows_substr = result_substr.collect()
        rows_substring = result_substring.collect()
        assert len(rows_substr) == len(rows_substring)
        assert rows_substr[0]["partial"] == rows_substring[0]["partial"]
        assert rows_substr[1]["partial"] == rows_substring[1]["partial"]

    def test_substr_chained_operations(self, spark):
        """``substr`` followed by ``alias`` produces the aliased prefix column."""
        df = self._string_df(spark, "name", ["Alice", "Bob"])
        result = df.select(F.col("name").substr(1, 2).alias("partial"))
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["partial"] == "Al"
        assert rows[1]["partial"] == "Bo"

    def test_substr_empty_string(self, spark):
        """An empty input stays empty; a non-empty neighbour is sliced."""
        df = self._string_df(spark, "text", ["", "Hello"])
        result = df.select(F.col("text").substr(1, 2).alias("partial"))
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["partial"] == ""
        assert rows[1]["partial"] == "He"

    def test_substr_unicode(self, spark):
        """Positions and lengths count characters for non-ASCII text."""
        df = self._string_df(spark, "text", ["Hello世界", "测试"])
        result = df.select(F.col("text").substr(1, 5).alias("partial"))
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["partial"] == "Hello"
        assert rows[1]["partial"] == "测试"

    def test_substr_negative_start(self, spark):
        """Negative starts count back from the end of the value."""
        df = self._string_df(spark, "text", ["Hello"])
        test_cases = [
            (-5, 3, "Hel"),
            (-4, 3, "ell"),
            (-3, 3, "llo"),
            (-2, 3, "lo"),
            (-1, 3, "o"),
        ]
        for start, length, expected in test_cases:
            result = df.select(F.col("text").substr(start, length).alias("partial"))
            rows = result.collect()
            assert rows[0]["partial"] == expected, f"substr({start}, {length}) failed"

    def test_substr_zero_start(self, spark):
        """A start of 0 gives the same result as a start of 1."""
        df = self._string_df(spark, "text", ["Hello"])
        result = df.select(F.col("text").substr(0, 3).alias("partial"))
        rows = result.collect()
        assert len(rows) == 1
        assert rows[0]["partial"] == "Hel"

    def test_substr_with_alias(self, spark):
        """The alias given to a ``substr`` expression names the result field."""
        df = self._string_df(spark, "name", ["Alice"])
        result = df.select(F.col("name").substr(1, 2).alias("first_two_chars"))
        rows = result.collect()
        assert len(rows) == 1
        assert "first_two_chars" in rows[0]
        assert rows[0]["first_two_chars"] == "Al"

    def test_substr_pyspark_parity_comprehensive(self, spark):
        """Check a table of (start, length) pairs against expected outputs.

        The inputs cover ordinary strings, an empty string and a null. Each
        case lists the expected result per input row, in input order.
        """
        texts = ["Hello", "World", "Test", "", None]
        df = self._string_df(spark, "text", texts)
        test_cases = [
            (1, 3, ["Hel", "Wor", "Tes", "", None]),
            (1, 0, ["", "", "", "", None]),
            (1, 100, ["Hello", "World", "Test", "", None]),
            (2, 3, ["ell", "orl", "est", "", None]),
            (0, 3, ["Hel", "Wor", "Tes", "", None]),
            (-1, 3, ["o", "d", "t", "", None]),
            (-2, 3, ["lo", "ld", "st", "", None]),
        ]
        for start, length, expected in test_cases:
            result = df.select(F.col("text").substr(start, length).alias("result"))
            rows = result.collect()
            # Without this check a truncated or empty result would skip the
            # per-row comparisons below and the case would pass unnoticed.
            assert len(rows) == len(expected), (
                f"substr({start}, {length}) returned {len(rows)} rows, "
                f"expected {len(expected)}"
            )
            for i, (row, want) in enumerate(zip(rows, expected)):
                assert row["result"] == want, (
                    f"substr({start}, {length}) failed for row {i} (text={texts[i]!r}): "
                    f"expected {want!r}, got {row['result']!r}"
                )

    def test_substr_in_groupBy(self, spark):
        """A ``substr``-derived column can be used as a grouping key."""
        schema = StructType(
            [
                StructField("name", StringType(), True),
                StructField("value", StringType(), True),
            ]
        )
        df = spark.createDataFrame(
            [
                {"name": "Alice", "value": "A1"},
                {"name": "Alice", "value": "A2"},
                {"name": "Bob", "value": "B1"},
            ],
            schema=schema,
        )
        df_with_first_char = df.withColumn("first_char", F.col("name").substr(1, 1))
        result = df_with_first_char.groupBy("first_char").agg(
            F.count("*").alias("count")
        )
        rows = result.collect()
        assert len(rows) == 2
        # Group order is not asserted; look the counts up by key instead.
        first_chars = {row["first_char"]: row["count"] for row in rows}
        assert first_chars.get("A") == 2
        assert first_chars.get("B") == 1

    def test_substr_chained_with_other_operations(self, spark):
        """The output of ``substr`` can be passed to ``F.upper``."""
        df = self._string_df(spark, "name", ["Alice", "Bob"])
        result = df.select(F.upper(F.col("name").substr(1, 2)).alias("upper_partial"))
        rows = result.collect()
        assert len(rows) == 2
        assert rows[0]["upper_partial"] == "AL"
        assert rows[1]["upper_partial"] == "BO"

    def test_substr_very_long_string(self, spark):
        """Taking a 100-character prefix of a 2000-character value."""
        long_string = "A" * 1000 + "B" * 1000
        df = self._string_df(spark, "text", [long_string])
        result = df.select(F.col("text").substr(1, 100).alias("first_100"))
        rows = result.collect()
        assert len(rows) == 1
        assert rows[0]["first_100"] == "A" * 100

    def test_substr_start_exceeds_length(self, spark):
        """A positive start past the end of a single value yields ``""``."""
        df = self._string_df(spark, "text", ["Hi"])
        result = df.select(F.col("text").substr(10, 5).alias("beyond"))
        rows = result.collect()
        assert len(rows) == 1
        assert rows[0]["beyond"] == ""

    def test_substr_negative_start_exceeds_length(self, spark):
        """A negative start whose magnitude exceeds the value's length."""
        df = self._string_df(spark, "text", ["Hi"])
        result = df.select(F.col("text").substr(-10, 3).alias("result"))
        rows = result.collect()
        assert len(rows) == 1
        # Compare against an explicit value. A bare
        # ``assert a == b if cond else c`` parses as
        # ``assert ((a == b) if cond else c)``; with a false ``cond`` it only
        # asserts the truthy fallback string and can never fail.
        # NOTE(review): "Hi" is the value the previous expression computed.
        # Spark's own substring may return "" when the requested window lies
        # entirely before the start of the string — confirm against PySpark.
        assert rows[0]["result"] == "Hi"