from __future__ import annotations
import tempfile
import os
import pytest
from tests.sql.conftest import jdbc_available
pytestmark = pytest.mark.skipif(
not jdbc_available(),
reason="JDBC/SQLite support not enabled in this build",
)
def test_read_jdbc_sqlite_basic(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)")
conn.execute("INSERT INTO test_table (id, name) VALUES (1, 'Alice')")
conn.execute("INSERT INTO test_table (id, name) VALUES (2, 'Bob')")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df = spark.read.jdbc(url=url, table="test_table", properties={})
rows = df.collect()
assert len(rows) == 2
names = {r["name"] for r in rows}
assert names == {"Alice", "Bob"}
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_read_jdbc_with_options_api(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT)")
conn.execute("INSERT INTO users (id, email) VALUES (1, 'a@test.com')")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df = spark.read.option("fetchsize", "100").jdbc(
url=url, table="users", properties={}
)
rows = df.collect()
assert len(rows) == 1
assert rows[0]["email"] == "a@test.com"
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_write_jdbc_sqlite_append(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE append_test (id INTEGER, value TEXT)")
conn.execute("INSERT INTO append_test VALUES (1, 'initial')")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df = spark.createDataFrame([(2, "appended")], schema="id int, value string")
df.write.jdbc(url=url, table="append_test", properties={}, mode="append")
read_df = spark.read.jdbc(url=url, table="append_test", properties={})
rows = read_df.collect()
assert len(rows) == 2
values = {r["value"] for r in rows}
assert values == {"initial", "appended"}
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_write_jdbc_sqlite_overwrite(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE overwrite_test (id INTEGER, value TEXT)")
conn.execute("INSERT INTO overwrite_test VALUES (1, 'old')")
conn.execute("INSERT INTO overwrite_test VALUES (2, 'data')")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df = spark.createDataFrame([(100, "new")], schema="id int, value string")
df.write.jdbc(url=url, table="overwrite_test", properties={}, mode="overwrite")
read_df = spark.read.jdbc(url=url, table="overwrite_test", properties={})
rows = read_df.collect()
assert len(rows) == 1
assert rows[0]["value"] == "new"
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_format_jdbc_load(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE format_test (id INTEGER, name TEXT)")
conn.execute("INSERT INTO format_test VALUES (1, 'test')")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df = (
spark.read.format("jdbc")
.option("url", url)
.option("dbtable", "format_test")
.load(".")
)
rows = df.collect()
assert len(rows) == 1
assert rows[0]["name"] == "test"
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_with_query_option(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute(
"CREATE TABLE query_test (id INTEGER, category TEXT, amount INTEGER)"
)
conn.execute("INSERT INTO query_test VALUES (1, 'A', 100)")
conn.execute("INSERT INTO query_test VALUES (2, 'B', 200)")
conn.execute("INSERT INTO query_test VALUES (3, 'A', 150)")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df = (
spark.read.format("jdbc")
.option("url", url)
.option(
"query",
"SELECT category, SUM(amount) as total FROM query_test GROUP BY category",
)
.load(".")
)
rows = df.collect()
assert len(rows) == 2
totals = {r["category"]: r["total"] for r in rows}
assert totals["A"] == 250
assert totals["B"] == 200
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_session_init_statement(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE session_init_test (id INTEGER, value TEXT)")
conn.execute("INSERT INTO session_init_test VALUES (1, 'original')")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df = (
spark.read.format("jdbc")
.option("url", url)
.option("dbtable", "session_init_test")
.option("sessionInitStatement", "SELECT 1")
.load(".")
)
rows = df.collect()
assert len(rows) == 1
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_batchsize_write(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE batch_test (id INTEGER, value TEXT)")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
data = [(i, f"row_{i}") for i in range(100)]
df = spark.createDataFrame(data, schema="id int, value string")
df.write.jdbc(
url=url, table="batch_test", properties={"batchsize": "10"}, mode="append"
)
read_df = spark.read.jdbc(url=url, table="batch_test", properties={})
assert read_df.count() == 100
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_truncate_option(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE truncate_test (id INTEGER, name TEXT)")
conn.execute("INSERT INTO truncate_test VALUES (1, 'old')")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df = spark.createDataFrame([(2, "new")], schema="id int, name string")
df.write.jdbc(
url=url,
table="truncate_test",
properties={"truncate": "true"},
mode="overwrite",
)
read_df = spark.read.jdbc(url=url, table="truncate_test", properties={})
rows = read_df.collect()
assert len(rows) == 1
assert rows[0]["name"] == "new"
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_error_if_exists_mode(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE error_test (id INTEGER, name TEXT)")
conn.execute("INSERT INTO error_test VALUES (1, 'exists')")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df = spark.createDataFrame([(2, "new")], schema="id int, name string")
with pytest.raises(Exception) as exc_info:
df.write.jdbc(url=url, table="error_test", properties={}, mode="error")
err_msg = str(exc_info.value).lower()
assert "exists" in err_msg or "not empty" in err_msg or "data" in err_msg
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_ignore_mode(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE ignore_test (id INTEGER, name TEXT)")
conn.execute("INSERT INTO ignore_test VALUES (1, 'original')")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df = spark.createDataFrame([(2, "new")], schema="id int, name string")
df.write.jdbc(url=url, table="ignore_test", properties={}, mode="ignore")
read_df = spark.read.jdbc(url=url, table="ignore_test", properties={})
rows = read_df.collect()
assert len(rows) == 1
assert rows[0]["name"] == "original"
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_prepare_query(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE prepare_source (id INTEGER, value INTEGER)")
conn.execute("INSERT INTO prepare_source VALUES (1, 100)")
conn.execute("INSERT INTO prepare_source VALUES (2, 200)")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df = (
spark.read.format("jdbc")
.option("url", url)
.option(
"prepareQuery",
"CREATE TEMP VIEW prepare_view AS SELECT id, value * 2 as doubled FROM prepare_source",
)
.option("dbtable", "prepare_view")
.load(".")
)
rows = df.collect()
assert len(rows) == 2
doubled_values = {r["doubled"] for r in rows}
assert doubled_values == {200, 400}
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_memory_database(spark) -> None:
url = "jdbc:sqlite::memory:"
df = spark.createDataFrame([(1, "test")], schema="id int, name string")
try:
df.write.jdbc(url=url, table="mem_test", properties={}, mode="overwrite")
except Exception:
pass
def test_jdbc_integer_types_roundtrip(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE int_test (small_int INTEGER, big_int INTEGER)")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
data = [
(0, 0),
(1, 100),
(-1, -100),
(32767, 2147483647),
(-32768, -2147483648),
]
df = spark.createDataFrame(data, schema="small_int int, big_int bigint")
df.write.jdbc(url=url, table="int_test", properties={}, mode="append")
read_df = spark.read.jdbc(url=url, table="int_test", properties={})
rows = read_df.collect()
assert len(rows) == 5
small_vals = {r["small_int"] for r in rows}
assert 0 in small_vals
assert 1 in small_vals
assert -1 in small_vals
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_float_types_roundtrip(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE float_test (value REAL)")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
data = [(0.0,), (1.5,), (-1.5,), (3.14159,), (1e10,)]
df = spark.createDataFrame(data, schema="value double")
df.write.jdbc(url=url, table="float_test", properties={}, mode="append")
read_df = spark.read.jdbc(url=url, table="float_test", properties={})
rows = read_df.collect()
assert len(rows) == 5
values = [r["value"] for r in rows]
assert any(abs(v - 3.14159) < 0.0001 for v in values)
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_null_handling(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE null_test (id INTEGER, name TEXT, value REAL)")
conn.execute("INSERT INTO null_test VALUES (1, NULL, 1.0)")
conn.execute("INSERT INTO null_test VALUES (2, 'present', NULL)")
conn.execute("INSERT INTO null_test VALUES (NULL, 'no_id', 2.0)")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df = spark.read.jdbc(url=url, table="null_test", properties={})
rows = df.collect()
assert len(rows) == 3
names = [r["name"] for r in rows]
assert None in names or any(n is None for n in names)
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_special_characters_in_strings(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE special_chars (id INTEGER, text TEXT)")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
special_strings = [
(1, "Hello, World!"),
(2, "Quote's test"),
(3, 'Double "quotes"'),
(4, "Line\nbreak"),
(5, "Tab\there"),
(6, "Unicode: 日本語 🎉"),
(7, "Backslash: \\path\\to\\file"),
]
df = spark.createDataFrame(special_strings, schema="id int, text string")
df.write.jdbc(url=url, table="special_chars", properties={}, mode="append")
read_df = spark.read.jdbc(url=url, table="special_chars", properties={})
rows = read_df.collect()
assert len(rows) == 7
texts = {r["text"] for r in rows}
assert "Hello, World!" in texts
assert "Quote's test" in texts
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_options_dict_api(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE options_test (id INTEGER, name TEXT)")
conn.execute("INSERT INTO options_test VALUES (1, 'test')")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
opts = {
"url": url,
"dbtable": "options_test",
}
df = spark.read.format("jdbc").options(opts).load(".")
rows = df.collect()
assert len(rows) == 1
assert rows[0]["name"] == "test"
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_schema_fields(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute(
"CREATE TABLE schema_test (user_id INTEGER, user_name TEXT, balance REAL)"
)
conn.execute("INSERT INTO schema_test VALUES (1, 'Alice', 100.50)")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df = spark.read.jdbc(url=url, table="schema_test", properties={})
field_names = [f.name for f in df.schema.fields]
assert "user_id" in field_names
assert "user_name" in field_names
assert "balance" in field_names
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_count_operation(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE count_test (id INTEGER)")
for i in range(50):
conn.execute("INSERT INTO count_test VALUES (?)", (i,))
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df = spark.read.jdbc(url=url, table="count_test", properties={})
assert df.count() == 50
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_filter_after_read(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE filter_test (id INTEGER, category TEXT)")
conn.execute("INSERT INTO filter_test VALUES (1, 'A')")
conn.execute("INSERT INTO filter_test VALUES (2, 'B')")
conn.execute("INSERT INTO filter_test VALUES (3, 'A')")
conn.execute("INSERT INTO filter_test VALUES (4, 'C')")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df = spark.read.jdbc(url=url, table="filter_test", properties={})
filtered = df.filter(df["category"] == "A")
rows = filtered.collect()
assert len(rows) == 2
assert all(r["category"] == "A" for r in rows)
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_select_columns(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE select_test (id INTEGER, name TEXT, email TEXT)")
conn.execute("INSERT INTO select_test VALUES (1, 'Alice', 'alice@test.com')")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df = spark.read.jdbc(url=url, table="select_test", properties={})
selected = df.select("id", "name")
rows = selected.collect()
assert len(rows) == 1
assert "id" in rows[0].asDict()
assert "name" in rows[0].asDict()
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_multiple_writes(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE multi_write (id INTEGER, batch INTEGER)")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df1 = spark.createDataFrame([(1, 1), (2, 1)], schema="id int, batch int")
df1.write.jdbc(url=url, table="multi_write", properties={}, mode="append")
df2 = spark.createDataFrame([(3, 2), (4, 2)], schema="id int, batch int")
df2.write.jdbc(url=url, table="multi_write", properties={}, mode="append")
df3 = spark.createDataFrame([(5, 3)], schema="id int, batch int")
df3.write.jdbc(url=url, table="multi_write", properties={}, mode="append")
read_df = spark.read.jdbc(url=url, table="multi_write", properties={})
assert read_df.count() == 5
rows = read_df.collect()
batches = {r["batch"] for r in rows}
assert batches == {1, 2, 3}
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_empty_table_read(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE empty_table (id INTEGER, name TEXT)")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df = spark.read.jdbc(url=url, table="empty_table", properties={})
assert df.count() == 0
rows = df.collect()
assert len(rows) == 0
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_properties_passed_via_jdbc_method(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE props_test (id INTEGER)")
conn.execute("INSERT INTO props_test VALUES (1)")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
props = {"fetchsize": "500"}
df = spark.read.jdbc(url=url, table="props_test", properties=props)
rows = df.collect()
assert len(rows) == 1
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_chained_options(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE chained_test (id INTEGER, val TEXT)")
conn.execute("INSERT INTO chained_test VALUES (1, 'a')")
conn.execute("INSERT INTO chained_test VALUES (2, 'b')")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df = (
spark.read.format("jdbc")
.option("url", url)
.option("dbtable", "chained_test")
.option("fetchsize", "100")
.option("sessionInitStatement", "SELECT 1")
.load(".")
)
rows = df.collect()
assert len(rows) == 2
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_write_mode_default_append(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE default_mode (id INTEGER)")
conn.execute("INSERT INTO default_mode VALUES (1)")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df = spark.createDataFrame([(2,)], schema="id int")
df.write.jdbc(url=url, table="default_mode", properties={}, mode="append")
read_df = spark.read.jdbc(url=url, table="default_mode", properties={})
assert read_df.count() == 2
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_boolean_via_integer(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE bool_test (id INTEGER, active INTEGER)")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
data = [(1, True), (2, False), (3, True)]
df = spark.createDataFrame(data, schema="id int, active boolean")
df.write.jdbc(url=url, table="bool_test", properties={}, mode="append")
read_df = spark.read.jdbc(url=url, table="bool_test", properties={})
rows = read_df.collect()
assert len(rows) == 3
active_vals = [r["active"] for r in rows]
assert len(active_vals) == 3
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_large_string_values(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE large_str (id INTEGER, content TEXT)")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
large_string = "x" * 10000
medium_string = "y" * 1000
data = [(1, large_string), (2, medium_string), (3, "small")]
df = spark.createDataFrame(data, schema="id int, content string")
df.write.jdbc(url=url, table="large_str", properties={}, mode="append")
read_df = spark.read.jdbc(url=url, table="large_str", properties={})
rows = read_df.collect()
assert len(rows) == 3
contents = {r["id"]: r["content"] for r in rows}
assert len(contents[1]) == 10000
assert len(contents[2]) == 1000
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_groupby_after_read(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE groupby_test (category TEXT, amount INTEGER)")
conn.execute("INSERT INTO groupby_test VALUES ('A', 10)")
conn.execute("INSERT INTO groupby_test VALUES ('A', 20)")
conn.execute("INSERT INTO groupby_test VALUES ('B', 30)")
conn.execute("INSERT INTO groupby_test VALUES ('B', 40)")
conn.execute("INSERT INTO groupby_test VALUES ('B', 50)")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df = spark.read.jdbc(url=url, table="groupby_test", properties={})
from sparkless.sql import functions as F
grouped = df.groupBy("category").agg(F.sum("amount").alias("total"))
rows = grouped.collect()
assert len(rows) == 2
totals = {r["category"]: r["total"] for r in rows}
assert totals["A"] == 30
assert totals["B"] == 120
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_orderby_after_read(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE orderby_test (id INTEGER, name TEXT)")
conn.execute("INSERT INTO orderby_test VALUES (3, 'Charlie')")
conn.execute("INSERT INTO orderby_test VALUES (1, 'Alice')")
conn.execute("INSERT INTO orderby_test VALUES (2, 'Bob')")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
df = spark.read.jdbc(url=url, table="orderby_test", properties={})
ordered = df.orderBy("id")
rows = ordered.collect()
assert len(rows) == 3
ids = [r["id"] for r in rows]
assert ids == [1, 2, 3] or ids == sorted(ids)
finally:
if os.path.exists(db_path):
os.unlink(db_path)
def test_jdbc_write_options_via_option_method(spark) -> None:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
import sqlite3
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE write_opts (id INTEGER, name TEXT)")
conn.commit()
conn.close()
url = f"jdbc:sqlite:{db_path}"
data = [(i, f"name_{i}") for i in range(20)]
df = spark.createDataFrame(data, schema="id int, name string")
df.write.option("batchsize", "5").jdbc(
url=url, table="write_opts", properties={}, mode="append"
)
read_df = spark.read.jdbc(url=url, table="write_opts", properties={})
assert read_df.count() == 20
finally:
if os.path.exists(db_path):
os.unlink(db_path)