from __future__ import annotations
import os
import pytest
from tests.sql.conftest import jdbc_available
pytestmark = pytest.mark.skipif(
not jdbc_available(),
reason="JDBC support not enabled in this build",
)
_HAS_ENV_URL = "SPARKLESS_TEST_JDBC_URL" in os.environ
def _use_env_url() -> bool:
return _HAS_ENV_URL
@pytest.fixture
def jdbc_url(request, spark):
if _use_env_url():
return os.environ["SPARKLESS_TEST_JDBC_URL"]
try:
conn = request.getfixturevalue("postgres_jdbc")
return conn.url
except pytest.FixtureLookupError:
pytest.skip(
"PostgreSQL container not available and SPARKLESS_TEST_JDBC_URL not set"
)
@pytest.fixture
def jdbc_props(request):
if _use_env_url():
return {
"user": os.getenv("SPARKLESS_TEST_JDBC_USER", ""),
"password": os.getenv("SPARKLESS_TEST_JDBC_PASSWORD", ""),
}
try:
conn = request.getfixturevalue("postgres_jdbc")
return conn.properties
except pytest.FixtureLookupError:
return {}
def test_read_jdbc_table_round_trip(spark, jdbc_url, jdbc_props) -> None:
df = spark.read.jdbc(
url=jdbc_url, table="sparkless_jdbc_test", properties=jdbc_props
)
rows = df.collect()
assert isinstance(rows, list)
assert len(rows) >= 2
schema_names = [f.name for f in df.schema.fields]
assert "id" in schema_names and "name" in schema_names
def test_read_jdbc_via_format_options_load(spark, jdbc_url, jdbc_props) -> None:
df = (
spark.read.format("jdbc")
.option("url", jdbc_url)
.option("dbtable", "sparkless_jdbc_test")
.options(jdbc_props)
.load(".")
)
rows = df.collect()
assert isinstance(rows, list)
assert len(rows) >= 2
def test_read_jdbc_with_query_option(spark, jdbc_url, jdbc_props) -> None:
df = (
spark.read.format("jdbc")
.option("url", jdbc_url)
.option("query", "SELECT id, name FROM sparkless_jdbc_test WHERE id = 1")
.options(jdbc_props)
.load(".")
)
rows = df.collect()
assert isinstance(rows, list)
assert len(rows) <= 1
if len(rows) == 1:
row = rows[0]
assert row["id"] == 1 or row["id"] == "1"
def test_write_jdbc_append(spark, jdbc_url, jdbc_props) -> None:
df = spark.createDataFrame(
[(100, "append_a"), (101, "append_b")], schema="id bigint, name string"
)
df.write.jdbc(
url=jdbc_url, table="sparkless_jdbc_test", properties=jdbc_props, mode="append"
)
read_df = spark.read.jdbc(
url=jdbc_url, table="sparkless_jdbc_test", properties=jdbc_props
)
rows = read_df.collect()
assert len(rows) >= 2
def _norm_id(val) -> int | None:
if val is None:
return None
if isinstance(val, str) and val.isdigit():
return int(val)
return int(val)
def test_write_then_read_back(spark, jdbc_url, jdbc_props) -> None:
table = "sparkless_jdbc_writeread_test"
data = [(10, "ten"), (20, "twenty")]
df = spark.createDataFrame(data, schema="id bigint, name string")
df.write.jdbc(url=jdbc_url, table=table, properties=jdbc_props, mode="overwrite")
read_df = spark.read.jdbc(url=jdbc_url, table=table, properties=jdbc_props)
rows = read_df.collect()
assert len(rows) == 2
ids = {_norm_id(r["id"]) for r in rows}
names = {r["name"] for r in rows}
assert ids == {10, 20}
assert names == {"ten", "twenty"}
def test_write_jdbc_overwrite_replaces_content(spark, jdbc_url, jdbc_props) -> None:
table = "sparkless_jdbc_writeread_test"
df1 = spark.createDataFrame([(1, "first")], schema="id bigint, name string")
df1.write.jdbc(url=jdbc_url, table=table, properties=jdbc_props, mode="overwrite")
_ = spark.read.jdbc(url=jdbc_url, table=table, properties=jdbc_props).count()
df2 = spark.createDataFrame(
[(2, "second"), (3, "third")], schema="id bigint, name string"
)
df2.write.jdbc(url=jdbc_url, table=table, properties=jdbc_props, mode="overwrite")
rows = spark.read.jdbc(url=jdbc_url, table=table, properties=jdbc_props).collect()
assert len(rows) == 2
names = {r["name"] for r in rows}
assert "second" in names and "third" in names
def test_jdbc_missing_url_raises(spark) -> None:
with pytest.raises(Exception) as exc_info:
spark.read.format("jdbc").option("dbtable", "t").load(".")
assert "url" in str(exc_info.value).lower() or "url" in str(exc_info.value)
def test_jdbc_empty_properties_allowed(spark, jdbc_url) -> None:
df = spark.read.jdbc(url=jdbc_url, table="sparkless_jdbc_test", properties={})
rows = df.collect()
assert isinstance(rows, list)