class TestIssue360InputFileName:
    """Checks that ``F.input_file_name()`` works on DataFrames built in memory.

    The class name ties these tests to issue 360.
    """

    @staticmethod
    def _functions():
        """Return the ``F`` attribute of ``sparkless.testing.get_imports()``.

        The import happens at call time rather than at module import time.
        """
        from sparkless.testing import get_imports

        return get_imports().F

    def test_input_file_name_returns_string_column(self, spark):
        """``withColumn`` with ``input_file_name()`` gives a ``str`` in row 0."""
        F = self._functions()
        records = [
            {"dataset": "dataset_a", "table": "table_1"},
            {"dataset": "dataset_b", "table": "table_2"},
        ]
        source = spark.createDataFrame(records)
        collected = source.withColumn(
            "InputFileName", F.input_file_name()
        ).collect()

        assert len(collected) == 2
        first = collected[0]
        assert "InputFileName" in first
        assert isinstance(first["InputFileName"], str)

    def test_input_file_name_exact_issue_scenario(self, spark):
        """Tuple rows plus a column-name list keep their data and gain the column."""
        F = self._functions()
        data = [
            ("dataset_a", "table_1"),
            ("dataset_b", "table_2"),
        ]
        column_names = ["dataset", "table"]
        frame = spark.createDataFrame(data, column_names)
        frame = frame.withColumn("InputFileName", F.input_file_name())
        collected = frame.collect()

        assert len(collected) == 2
        first = collected[0]
        assert first["dataset"] == "dataset_a"
        assert first["table"] == "table_1"
        assert "InputFileName" in first

    def test_input_file_name_select_only(self, spark):
        """Selecting only the aliased expression yields one value per input row.

        The first value must be empty, start with ``/``, or contain a backslash.
        """
        F = self._functions()
        frame = spark.createDataFrame([{"a": 1}, {"a": 2}])
        path_column = F.input_file_name().alias("path")
        collected = frame.select(path_column).collect()

        assert len(collected) == 2
        first_path = collected[0]["path"]
        assert (
            first_path == ""
            or first_path.startswith("/")
            or "\\" in first_path
        )
class TestIssue360InputFileNameRobust:
    """Additional ``F.input_file_name()`` checks across DataFrame shapes and operations.

    Covers an empty frame, a single row, use after ``filter``/``select``,
    schema contents, ``show()``, several rows, mixed column selection and
    aliasing.
    """

    @staticmethod
    def _functions():
        """Return the ``F`` attribute of ``sparkless.testing.get_imports()``.

        The import happens at call time rather than at module import time.
        """
        from sparkless.testing import get_imports

        return get_imports().F

    def test_input_file_name_empty_dataframe(self, spark):
        """An empty DataFrame stays empty after the column is added."""
        F = self._functions()
        df = spark.createDataFrame([], "a int")
        result = df.withColumn("path", F.input_file_name())
        rows = result.collect()
        assert len(rows) == 0

    def test_input_file_name_single_row(self, spark):
        """A one-row DataFrame gets a ``str`` value in the new column."""
        F = self._functions()
        df = spark.createDataFrame([{"id": 1}])
        result = df.withColumn("file", F.input_file_name())
        rows = result.collect()
        assert len(rows) == 1
        assert "file" in rows[0]
        assert isinstance(rows[0]["file"], str)

    def test_input_file_name_after_filter(self, spark):
        """The column can be added after ``filter``; every kept row has a ``str``."""
        F = self._functions()
        df = spark.createDataFrame([{"a": 1}, {"a": 2}, {"a": 3}])
        result = df.filter(F.col("a") > 1).withColumn("path", F.input_file_name())
        rows = result.collect()
        assert len(rows) == 2
        # Separate asserts so a failure names the exact condition and row.
        for row in rows:
            assert "path" in row
            assert isinstance(row["path"], str)

    def test_input_file_name_after_select(self, spark):
        """The column can be added after projecting a subset of columns."""
        F = self._functions()
        df = spark.createDataFrame([{"x": 1, "y": 2}, {"x": 3, "y": 4}])
        result = df.select("x").withColumn("path", F.input_file_name())
        rows = result.collect()
        assert len(rows) == 2
        # An empty string is already a ``str``, so the type check alone suffices.
        assert isinstance(rows[0]["path"], str)

    def test_input_file_name_preserves_schema(self, spark):
        """Existing columns and values survive; the schema gains exactly one field."""
        F = self._functions()
        df = spark.createDataFrame([{"a": 1, "b": "x"}])
        result = df.withColumn("path", F.input_file_name())
        assert "path" in result.schema.fieldNames()
        assert len(result.schema.fields) == 3
        rows = result.collect()
        first = rows[0]
        assert first["a"] == 1
        assert first["b"] == "x"
        assert "path" in first

    def test_input_file_name_with_show(self, spark):
        """Smoke test: ``show()`` on a frame with the column must not raise."""
        F = self._functions()
        df = spark.createDataFrame([{"a": 1}])
        result = df.withColumn("path", F.input_file_name())
        result.show()

    def test_input_file_name_all_rows_same_type(self, spark):
        """Every one of the five rows carries a ``str`` in the new column."""
        F = self._functions()
        df = spark.createDataFrame([{"i": i} for i in range(5)])
        result = df.withColumn("path", F.input_file_name())
        rows = result.collect()
        # Guard against a vacuous pass: the loop below asserts nothing if
        # ``collect()`` returns no rows.
        assert len(rows) == 5
        for row in rows:
            assert "path" in row
            assert isinstance(row["path"], str)

    def test_input_file_name_multiple_columns(self, spark):
        """The expression can sit between ordinary columns in one ``select``."""
        F = self._functions()
        df = spark.createDataFrame([{"a": 1, "b": 2}, {"a": 3, "b": 4}])
        result = df.select(
            F.col("a"),
            F.input_file_name().alias("path"),
            F.col("b"),
        )
        rows = result.collect()
        assert len(rows) == 2
        first = rows[0]
        assert first["a"] == 1
        assert first["b"] == 2
        assert "path" in first

    def test_input_file_name_alias(self, spark):
        """An aliased expression is exposed under the alias name as a ``str``."""
        F = self._functions()
        df = spark.createDataFrame([{"x": 1}])
        result = df.select(F.input_file_name().alias("source_file"))
        rows = result.collect()
        assert len(rows) == 1
        assert "source_file" in rows[0]
        assert isinstance(rows[0]["source_file"], str)