use std::{convert, ffi::c_void, fmt, mem, os::raw::c_char, ptr, str};
use arrow::{array::StructArray, datatypes::SchemaRef};
use super::{AndThenRows, Connection, Error, MappedRows, Params, RawStatement, Result, Row, Rows, ValueRef, ffi};
#[cfg(feature = "polars")]
use crate::polars_dataframe::Polars;
use crate::{
arrow_batch::{Arrow, ArrowStream},
error::result_from_duckdb_prepare,
types::{ToSql, ToSqlOutput, binding_unsupported_value, value_ref_from_value},
};
#[cfg(feature = "polars")]
use polars_core::utils::arrow as polars_arrow;
/// A prepared SQL statement tied to the lifetime of the [`Connection`] it was
/// prepared on (see `Connection::prepare` usage in the tests).
pub struct Statement<'conn> {
    /// Connection this statement belongs to; the `'conn` borrow keeps the
    /// statement from outliving it.
    conn: &'conn Connection,
    /// Owned raw statement handle; its pointer is what gets passed to the
    /// `duckdb_bind_*` FFI calls when binding parameters.
    pub(crate) stmt: RawStatement,
}
impl Statement<'_> {
#[inline]
pub fn execute<P: Params>(&mut self, params: P) -> Result<usize> {
params.__bind_in(self)?;
self.execute_with_bound_parameters()
}
#[inline]
pub fn insert<P: Params>(&mut self, params: P) -> Result<()> {
let changes = self.execute(params)?;
match changes {
1 => Ok(()),
_ => Err(Error::StatementChangedRows(changes)),
}
}
#[inline]
pub fn query_arrow<P: Params>(&mut self, params: P) -> Result<Arrow<'_>> {
self.execute(params)?;
Ok(Arrow::new(self))
}
#[inline]
pub fn stream_arrow<P: Params>(&mut self, params: P, schema: SchemaRef) -> Result<ArrowStream<'_>> {
params.__bind_in(self)?;
self.stmt.execute_streaming()?;
Ok(ArrowStream::new(self, schema))
}
#[cfg(feature = "polars")]
#[inline]
pub fn query_polars<P: Params>(&mut self, params: P) -> Result<Polars<'_>> {
self.execute(params)?;
Ok(Polars::new(self))
}
#[inline]
pub fn query<P: Params>(&mut self, params: P) -> Result<Rows<'_>> {
self.execute(params)?;
Ok(Rows::new(self))
}
pub fn query_map<T, P, F>(&mut self, params: P, f: F) -> Result<MappedRows<'_, F>>
where
P: Params,
F: FnMut(&Row<'_>) -> Result<T>,
{
self.query(params).map(|rows| rows.mapped(f))
}
#[inline]
pub fn query_and_then<T, E, P, F>(&mut self, params: P, f: F) -> Result<AndThenRows<'_, F>>
where
P: Params,
E: convert::From<Error>,
F: FnMut(&Row<'_>) -> Result<T, E>,
{
self.query(params).map(|rows| rows.and_then(f))
}
#[inline]
pub fn exists<P: Params>(&mut self, params: P) -> Result<bool> {
let mut rows = self.query(params)?;
let exists = rows.next()?.is_some();
Ok(exists)
}
pub fn query_row<T, P, F>(&mut self, params: P, f: F) -> Result<T>
where
P: Params,
F: FnOnce(&Row<'_>) -> Result<T>,
{
self.query(params)?.get_expected_row().and_then(f)
}
pub fn query_one<T, P, F>(&mut self, params: P, f: F) -> Result<T>
where
P: Params,
F: FnOnce(&Row<'_>) -> Result<T>,
{
let mut rows = self.query(params)?;
let row = rows.get_expected_row().and_then(f)?;
if rows.next()?.is_some() {
return Err(Error::QueryReturnedMoreThanOneRow);
}
Ok(row)
}
#[inline]
pub fn row_count(&self) -> usize {
self.stmt.row_count()
}
#[inline]
pub fn step(&self) -> Option<StructArray> {
self.stmt.step()
}
#[inline]
pub fn stream_step(&self, schema: SchemaRef) -> Option<StructArray> {
self.stmt.streaming_step(schema)
}
#[cfg(feature = "polars")]
#[inline]
pub(crate) fn step_polars(&self) -> Option<polars_arrow::array::StructArray> {
self.stmt.step_polars()
}
#[inline]
pub(crate) fn bind_parameters<P>(&mut self, params: P) -> Result<()>
where
P: IntoIterator,
P::Item: ToSql,
{
let result = self.try_bind_parameters(params);
if result.is_err() {
let _ = self.stmt.clear_bindings();
}
result
}
fn try_bind_parameters<P>(&mut self, params: P) -> Result<()>
where
P: IntoIterator,
P::Item: ToSql,
{
let expected = self.stmt.bind_parameter_count();
let mut index = 0;
for p in params.into_iter() {
index += 1; if index > expected {
break;
}
self.bind_parameter(&p, index)?;
}
if index != expected {
Err(Error::InvalidParameterCount(index, expected))
} else {
Ok(())
}
}
#[inline]
pub fn parameter_count(&self) -> usize {
self.stmt.bind_parameter_count()
}
#[inline]
pub fn parameter_name(&self, idx: usize) -> Result<String> {
self.stmt.parameter_name(idx)
}
#[inline]
pub fn raw_bind_parameter<T: ToSql>(&mut self, one_based_col_index: usize, param: T) -> Result<()> {
self.bind_parameter(¶m, one_based_col_index)
}
#[inline]
pub fn raw_execute(&mut self) -> Result<usize> {
self.execute_with_bound_parameters()
}
#[inline]
pub fn raw_query(&self) -> Rows<'_> {
Rows::new(self)
}
#[inline]
pub fn schema(&self) -> SchemaRef {
self.stmt.schema()
}
fn bind_parameter<P: ?Sized + ToSql>(&self, param: &P, col: usize) -> Result<()> {
let value = param.to_sql()?;
let ptr = unsafe { self.stmt.ptr() };
let value = match value {
ToSqlOutput::Borrowed(v) => v,
ToSqlOutput::Owned(ref v) => value_ref_from_value(v, binding_unsupported_value)?,
};
let rc = match value {
ValueRef::Null => unsafe { ffi::duckdb_bind_null(ptr, col as u64) },
ValueRef::Boolean(i) => unsafe { ffi::duckdb_bind_boolean(ptr, col as u64, i) },
ValueRef::TinyInt(i) => unsafe { ffi::duckdb_bind_int8(ptr, col as u64, i) },
ValueRef::SmallInt(i) => unsafe { ffi::duckdb_bind_int16(ptr, col as u64, i) },
ValueRef::Int(i) => unsafe { ffi::duckdb_bind_int32(ptr, col as u64, i) },
ValueRef::BigInt(i) => unsafe { ffi::duckdb_bind_int64(ptr, col as u64, i) },
ValueRef::HugeInt(i) => unsafe {
let hi = ffi::duckdb_hugeint {
lower: i as u64,
upper: (i >> 64) as i64,
};
ffi::duckdb_bind_hugeint(ptr, col as u64, hi)
},
ValueRef::UTinyInt(i) => unsafe { ffi::duckdb_bind_uint8(ptr, col as u64, i) },
ValueRef::USmallInt(i) => unsafe { ffi::duckdb_bind_uint16(ptr, col as u64, i) },
ValueRef::UInt(i) => unsafe { ffi::duckdb_bind_uint32(ptr, col as u64, i) },
ValueRef::UBigInt(i) => unsafe { ffi::duckdb_bind_uint64(ptr, col as u64, i) },
ValueRef::Float(r) => unsafe { ffi::duckdb_bind_float(ptr, col as u64, r) },
ValueRef::Double(r) => unsafe { ffi::duckdb_bind_double(ptr, col as u64, r) },
ValueRef::Text(s) => unsafe {
ffi::duckdb_bind_varchar_length(ptr, col as u64, s.as_ptr() as *const c_char, s.len() as u64)
},
ValueRef::Blob(b) => unsafe {
ffi::duckdb_bind_blob(ptr, col as u64, b.as_ptr() as *const c_void, b.len() as u64)
},
ValueRef::Timestamp(u, i) => unsafe {
ffi::duckdb_bind_timestamp(ptr, col as u64, ffi::duckdb_timestamp { micros: u.to_micros(i) })
},
ValueRef::Interval { months, days, nanos } => unsafe {
let micros = nanos / 1_000;
ffi::duckdb_bind_interval(ptr, col as u64, ffi::duckdb_interval { months, days, micros })
},
ValueRef::Date32(days) => unsafe { ffi::duckdb_bind_date(ptr, col as u64, ffi::duckdb_date { days }) },
ValueRef::Time64(u, i) => unsafe {
ffi::duckdb_bind_time(ptr, col as u64, ffi::duckdb_time { micros: u.to_micros(i) })
},
ValueRef::Decimal(d) => unsafe {
let decimal = crate::types::to_duckdb_decimal(d);
ffi::duckdb_bind_decimal(ptr, col as u64, decimal)
},
_ => {
return Err(Error::ToSqlConversionFailure(
binding_unsupported_value(value.data_type()).into(),
));
}
};
result_from_duckdb_prepare(rc, ptr)
}
#[inline]
fn execute_with_bound_parameters(&mut self) -> Result<usize> {
self.stmt.execute()
}
#[inline]
pub(crate) unsafe fn into_raw(mut self) -> RawStatement {
let mut stmt = unsafe { RawStatement::new(ptr::null_mut()) };
mem::swap(&mut stmt, &mut self.stmt);
stmt
}
}
impl fmt::Debug for Statement<'_> {
    /// Formats the connection, the raw statement, and (when available) the SQL
    /// text of the statement.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `Debug` formatting must not panic. Fall back to an empty string when
        // the handle is null or no SQL text is available, instead of unwrapping.
        let sql = if self.stmt.is_null() {
            Ok("")
        } else {
            match self.stmt.sql() {
                Some(text) => str::from_utf8(text.to_bytes()),
                None => Ok(""),
            }
        };
        f.debug_struct("Statement")
            .field("conn", self.conn)
            .field("stmt", &self.stmt)
            .field("sql", &sql)
            .finish()
    }
}
impl Statement<'_> {
    /// Wraps an already-prepared `stmt` that belongs to `conn`.
    #[inline]
    pub(super) fn new(conn: &Connection, stmt: RawStatement) -> Statement<'_> {
        let statement = Statement { stmt, conn };
        statement
    }
}
#[cfg(test)]
mod test {
use arrow::{array::ListArray, datatypes::Int32Type};
use crate::{
Connection, Error, Result,
core::LogicalTypeId,
params_from_iter,
types::{ListType, ToSql, ToSqlOutput, Type, ValueRef},
};
use rust_decimal::Decimal;
struct BorrowedList(ListArray);
impl BorrowedList {
fn new() -> Self {
Self(ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(vec![
Some(1),
Some(2),
])]))
}
}
impl ToSql for BorrowedList {
fn to_sql(&self) -> Result<ToSqlOutput<'_>> {
Ok(ToSqlOutput::Borrowed(ValueRef::List(ListType::Regular(&self.0), 0)))
}
}
fn assert_binding_list_error(err: Error) {
match err {
Error::ToSqlConversionFailure(e) => {
assert!(
e.to_string().contains("binding List parameters is not yet supported"),
"unexpected message: {e}"
);
}
other => panic!("expected ToSqlConversionFailure, got {other:?}"),
}
}
fn assert_variant_decode_error(err: Error, expected_idx: usize) {
match err {
Error::FromSqlConversionFailure(idx, Type::Variant, e) => {
assert_eq!(idx, expected_idx);
assert!(
e.to_string().contains("decoding Variant columns is not supported"),
"unexpected message: {e}"
);
}
other => panic!("expected FromSqlConversionFailure for Variant, got {other:?}"),
}
}
#[test]
fn test_execute() -> Result<()> {
let db = Connection::open_in_memory()?;
db.execute_batch("CREATE TABLE foo(x INTEGER)")?;
assert_eq!(db.execute("INSERT INTO foo(x) VALUES (?)", [&2i32])?, 1);
assert_eq!(db.execute("INSERT INTO foo(x) VALUES (?)", [&3i32])?, 1);
assert_eq!(
5i32,
db.query_row::<i32, _, _>("SELECT SUM(x) FROM foo WHERE x > ?", [&0i32], |r| r.get(0))?
);
assert_eq!(
3i32,
db.query_row::<i32, _, _>("SELECT SUM(x) FROM foo WHERE x > ?", [&2i32], |r| r.get(0))?
);
Ok(())
}
#[test]
fn test_stmt_execute() -> Result<()> {
let db = Connection::open_in_memory()?;
let sql = r#"
CREATE SEQUENCE seq;
CREATE TABLE test (id INTEGER DEFAULT NEXTVAL('seq'), name TEXT NOT NULL, flag INTEGER);
"#;
db.execute_batch(sql)?;
let mut stmt = db.prepare("INSERT INTO test (name) VALUES (?)")?;
stmt.execute([&"one"])?;
let mut stmt = db.prepare("SELECT COUNT(*) FROM test WHERE name = ?")?;
assert_eq!(1i32, stmt.query_row::<i32, _, _>([&"one"], |r| r.get(0))?);
Ok(())
}
#[test]
fn test_query() -> Result<()> {
let db = Connection::open_in_memory()?;
let sql = r#"
CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL, flag INTEGER);
INSERT INTO test(id, name) VALUES (1, 'one');
"#;
db.execute_batch(sql)?;
let mut stmt = db.prepare("SELECT id FROM test where name = ?")?;
{
let id: i32 = stmt.query_one([&"one"], |r| r.get(0))?;
assert_eq!(id, 1);
}
Ok(())
}
#[test]
fn test_query_and_then() -> Result<()> {
let db = Connection::open_in_memory()?;
let sql = r#"
CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL, flag INTEGER);
INSERT INTO test(id, name) VALUES (1, 'one');
INSERT INTO test(id, name) VALUES (2, 'one');
"#;
db.execute_batch(sql)?;
let mut stmt = db.prepare("SELECT id FROM test where name = ? ORDER BY id ASC")?;
let mut rows = stmt.query_and_then([&"one"], |row| {
let id: i32 = row.get(0)?;
if id == 1 {
Ok(id)
} else {
Err(Error::ExecuteReturnedResults)
}
})?;
let doubled_id: i32 = rows.next().unwrap()?;
assert_eq!(1, doubled_id);
#[allow(clippy::match_wild_err_arm)]
match rows.next().unwrap() {
Ok(_) => panic!("invalid Ok"),
Err(Error::ExecuteReturnedResults) => (),
Err(_) => panic!("invalid Err"),
}
Ok(())
}
#[test]
fn test_unbound_parameters_are_error() -> Result<()> {
let db = Connection::open_in_memory()?;
let sql = "CREATE TABLE test (x TEXT, y TEXT)";
db.execute_batch(sql)?;
let mut stmt = db.prepare("INSERT INTO test (x, y) VALUES (?, ?)")?;
assert!(stmt.execute([&"one"]).is_err());
Ok(())
}
#[test]
fn test_insert_empty_text_is_none() -> Result<()> {
let db = Connection::open_in_memory()?;
let sql = "CREATE TABLE test (x TEXT, y TEXT)";
db.execute_batch(sql)?;
let mut stmt = db.prepare("INSERT INTO test (x) VALUES (?)")?;
stmt.execute([&"one"])?;
let result: Option<String> = db.query_row("SELECT y FROM test WHERE x = 'one'", [], |row| row.get(0))?;
assert!(result.is_none());
Ok(())
}
#[test]
fn test_raw_binding() -> Result<()> {
let db = Connection::open_in_memory()?;
db.execute_batch("CREATE TABLE test (name TEXT, value INTEGER)")?;
{
let mut stmt = db.prepare("INSERT INTO test (name, value) VALUES (?, ?)")?;
stmt.raw_bind_parameter(2, 50i32)?;
stmt.raw_bind_parameter(1, "example")?;
let n = stmt.raw_execute()?;
assert_eq!(n, 1);
}
{
let mut stmt = db.prepare("SELECT name, value FROM test WHERE value = ?")?;
stmt.raw_bind_parameter(1, 50)?;
stmt.raw_execute()?;
let mut rows = stmt.raw_query();
{
let row = rows.next()?.unwrap();
let name: String = row.get(0)?;
assert_eq!(name, "example");
let value: i32 = row.get(1)?;
assert_eq!(value, 50);
}
assert!(rows.next()?.is_none());
}
{
let db = Connection::open_in_memory()?;
db.execute_batch("CREATE TABLE test (name TEXT, value UINTEGER)")?;
let mut stmt = db.prepare("INSERT INTO test(name, value) VALUES (?, ?)")?;
stmt.raw_bind_parameter(1, "negative")?;
stmt.raw_bind_parameter(2, u32::MAX)?;
let n = stmt.raw_execute()?;
assert_eq!(n, 1);
assert_eq!(
u32::MAX,
db.query_row::<u32, _, _>("SELECT value FROM test", [], |r| r.get(0))?
);
}
{
let db = Connection::open_in_memory()?;
db.execute_batch("CREATE TABLE test (name TEXT, value UBIGINT)")?;
let mut stmt = db.prepare("INSERT INTO test(name, value) VALUES (?, ?)")?;
stmt.raw_bind_parameter(1, "negative")?;
stmt.raw_bind_parameter(2, u64::MAX)?;
let n = stmt.raw_execute()?;
assert_eq!(n, 1);
assert_eq!(
u64::MAX,
db.query_row::<u64, _, _>("SELECT value FROM test", [], |r| r.get(0))?
);
}
Ok(())
}
#[test]
#[cfg_attr(windows, ignore = "Windows doesn't allow concurrent writes to a file")]
fn test_insert_duplicate() -> Result<()> {
let db = Connection::open_in_memory()?;
db.execute_batch("CREATE TABLE foo(x INTEGER UNIQUE)")?;
let mut stmt = db.prepare("INSERT INTO foo (x) VALUES (?)")?;
stmt.insert([1i32])?;
stmt.insert([2i32])?;
assert!(stmt.insert([1i32]).is_err());
let mut multi = db.prepare("INSERT INTO foo (x) SELECT 3 UNION ALL SELECT 4")?;
match multi.insert([]).unwrap_err() {
Error::StatementChangedRows(2) => (),
err => panic!("Unexpected error {err}"),
}
Ok(())
}
#[test]
fn test_insert_different_tables() -> Result<()> {
let db = Connection::open_in_memory()?;
db.execute_batch(
r"
CREATE TABLE foo(x INTEGER);
CREATE TABLE bar(x INTEGER);
",
)?;
db.prepare("INSERT INTO foo VALUES (10)")?.insert([])?;
db.prepare("INSERT INTO bar VALUES (10)")?.insert([])?;
Ok(())
}
#[test]
fn test_insert_with_returning_clause() -> Result<()> {
let db = Connection::open_in_memory()?;
db.execute_batch(
"CREATE SEQUENCE location_id_seq START WITH 1 INCREMENT BY 1;
CREATE TABLE location (
id INTEGER PRIMARY KEY DEFAULT nextval('location_id_seq'),
name TEXT NOT NULL
)",
)?;
let changes = db.execute("INSERT INTO location (name) VALUES (?)", ["test1"])?;
assert_eq!(changes, 1);
let changes = db.execute("INSERT INTO location (name) VALUES (?) RETURNING id", ["test2"])?;
assert_eq!(changes, 0);
let count: i64 = db.query_row("SELECT COUNT(*) FROM location", [], |r| r.get(0))?;
assert_eq!(count, 2);
let mut stmt = db.prepare("INSERT INTO location (name) VALUES (?)")?;
stmt.insert(["test3"])?;
let mut stmt = db.prepare("INSERT INTO location (name) VALUES (?) RETURNING id")?;
let result = stmt.insert(["test4"]);
assert!(matches!(result, Err(Error::StatementChangedRows(0))));
let count: i64 = db.query_row("SELECT COUNT(*) FROM location", [], |r| r.get(0))?;
assert_eq!(count, 4);
let id: i64 = db.query_row("INSERT INTO location (name) VALUES (?) RETURNING id", ["test5"], |r| {
r.get(0)
})?;
assert_eq!(id, 5);
let mut stmt = db.prepare("INSERT INTO location (name) VALUES (?) RETURNING id")?;
let ids: Vec<i64> = stmt
.query_map(["test6"], |row| row.get(0))?
.collect::<Result<Vec<_>>>()?;
assert_eq!(ids.len(), 1);
assert_eq!(ids[0], 6);
let id: i64 = db
.prepare("INSERT INTO location (name) VALUES (?) RETURNING id")?
.query_one(["test7"], |r| r.get(0))?;
assert_eq!(id, 7);
let (id, name): (i64, String) = db.query_row(
"INSERT INTO location (name) VALUES (?) RETURNING id, name",
["test8"],
|r| Ok((r.get(0)?, r.get(1)?)),
)?;
assert_eq!(id, 8);
assert_eq!(name, "test8");
Ok(())
}
#[test]
fn test_exists() -> Result<()> {
let db = Connection::open_in_memory()?;
let sql = "BEGIN;
CREATE TABLE foo(x INTEGER);
INSERT INTO foo VALUES(1);
INSERT INTO foo VALUES(2);
END;";
db.execute_batch(sql)?;
let mut stmt = db.prepare("SELECT 1 FROM foo WHERE x = ?")?;
assert!(stmt.exists([1i32])?);
assert!(stmt.exists([2i32])?);
assert!(!stmt.exists([0i32])?);
Ok(())
}
#[test]
fn test_query_row() -> Result<()> {
let db = Connection::open_in_memory()?;
let sql = "BEGIN;
CREATE TABLE foo(x INTEGER, y INTEGER);
INSERT INTO foo VALUES(1, 3);
INSERT INTO foo VALUES(2, 4);
END;";
db.execute_batch(sql)?;
let mut stmt = db.prepare("SELECT y FROM foo WHERE x = ?")?;
let y: Result<i32> = stmt.query_row([1i32], |r| r.get(0));
assert_eq!(3i32, y?);
Ok(())
}
#[test]
fn test_query_one() -> Result<()> {
let db = Connection::open_in_memory()?;
let sql = "BEGIN;
CREATE TABLE foo(x INTEGER, y INTEGER);
INSERT INTO foo VALUES(1, 3);
INSERT INTO foo VALUES(2, 4);
END;";
db.execute_batch(sql)?;
let y: i32 = db
.prepare("SELECT y FROM foo WHERE x = ?")?
.query_one([1], |r| r.get(0))?;
assert_eq!(y, 3);
let res: Result<i32> = db
.prepare("SELECT y FROM foo WHERE x = ?")?
.query_one([99], |r| r.get(0));
assert_eq!(res.unwrap_err(), Error::QueryReturnedNoRows);
let res: Result<i32> = db.prepare("SELECT y FROM foo")?.query_one([], |r| r.get(0));
assert_eq!(res.unwrap_err(), Error::QueryReturnedMoreThanOneRow);
Ok(())
}
#[test]
fn test_query_one_optional() -> Result<()> {
use crate::OptionalExt;
let db = Connection::open_in_memory()?;
let sql = "BEGIN;
CREATE TABLE foo(x INTEGER, y INTEGER);
INSERT INTO foo VALUES(1, 3);
INSERT INTO foo VALUES(2, 4);
END;";
db.execute_batch(sql)?;
let y: Option<i32> = db
.prepare("SELECT y FROM foo WHERE x = ?")?
.query_one([1], |r| r.get(0))
.optional()?;
assert_eq!(y, Some(3));
let y: Option<i32> = db
.prepare("SELECT y FROM foo WHERE x = ?")?
.query_one([99], |r| r.get(0))
.optional()?;
assert_eq!(y, None);
let res = db
.prepare("SELECT y FROM foo")?
.query_one([], |r| r.get::<_, i32>(0))
.optional();
assert_eq!(res.unwrap_err(), Error::QueryReturnedMoreThanOneRow);
Ok(())
}
#[test]
fn test_query_by_column_name() -> Result<()> {
let db = Connection::open_in_memory()?;
let sql = "BEGIN;
CREATE TABLE foo(x INTEGER, y INTEGER);
INSERT INTO foo VALUES(1, 3);
END;";
db.execute_batch(sql)?;
let mut stmt = db.prepare("SELECT y FROM foo")?;
let y: Result<i64> = stmt.query_row([], |r| r.get("y"));
assert_eq!(3i64, y?);
Ok(())
}
#[test]
fn test_get_schema_of_executed_result() -> Result<()> {
use arrow::datatypes::{DataType, Field, Schema};
let db = Connection::open_in_memory()?;
let sql = "BEGIN;
CREATE TABLE foo(x STRING, y INTEGER);
INSERT INTO foo VALUES('hello', 3);
END;";
db.execute_batch(sql)?;
let mut stmt = db.prepare("SELECT x, y FROM foo")?;
let _ = stmt.execute([]);
let schema = stmt.schema();
assert_eq!(
*schema,
Schema::new(vec![
Field::new("x", DataType::Utf8, true),
Field::new("y", DataType::Int32, true)
])
);
Ok(())
}
#[test]
#[should_panic(expected = "called `Option::unwrap()` on a `None` value")]
fn test_unexecuted_schema_panics() {
let db = Connection::open_in_memory().unwrap();
let sql = "BEGIN;
CREATE TABLE foo(x STRING, y INTEGER);
INSERT INTO foo VALUES('hello', 3);
END;";
db.execute_batch(sql).unwrap();
let stmt = db.prepare("SELECT x, y FROM foo").unwrap();
let _ = stmt.schema();
}
#[test]
fn test_query_by_column_name_ignore_case() -> Result<()> {
let db = Connection::open_in_memory()?;
let sql = "BEGIN;
CREATE TABLE foo(x INTEGER, y INTEGER);
INSERT INTO foo VALUES(1, 3);
END;";
db.execute_batch(sql)?;
let mut stmt = db.prepare("SELECT y as Y FROM foo")?;
let y: Result<i64> = stmt.query_row([], |r| r.get("y"));
assert_eq!(3i64, y?);
Ok(())
}
#[test]
fn test_bind_parameters() -> Result<()> {
let db = Connection::open_in_memory()?;
db.query_row("SELECT ?1, ?2, ?3", [&1u8 as &dyn ToSql, &"one", &Some("one")], |row| {
row.get::<_, u8>(0)
})?;
let data = vec![1, 2, 3];
db.query_row("SELECT ?1, ?2, ?3", params_from_iter(&data), |row| row.get::<_, u8>(0))?;
db.query_row("SELECT ?1, ?2, ?3", params_from_iter(data.as_slice()), |row| {
row.get::<_, u8>(0)
})?;
db.query_row("SELECT ?1, ?2, ?3", params_from_iter(data), |row| row.get::<_, u8>(0))?;
let data: std::collections::BTreeSet<String> =
["one", "two", "three"].iter().map(|s| (*s).to_string()).collect();
db.query_row("SELECT ?1, ?2, ?3", params_from_iter(&data), |row| {
row.get::<_, String>(0)
})?;
let data = [0; 3];
db.query_row("SELECT ?1, ?2, ?3", params_from_iter(&data), |row| row.get::<_, u8>(0))?;
db.query_row("SELECT ?1, ?2, ?3", params_from_iter(data.iter()), |row| {
row.get::<_, u8>(0)
})?;
Ok(())
}
#[test]
fn test_named_parameters() -> Result<()> {
use std::collections::HashMap;
let named_params = HashMap::from([("foo", 42), ("bar", 23)]);
let db = Connection::open_in_memory()?;
let sql = r#"SELECT $foo > $bar"#;
let result: bool = db.query_row(sql, &named_params, |row| row.get(0))?;
assert!(result);
Ok(())
}
#[test]
fn test_named_parameters_macro() -> Result<()> {
let db = Connection::open_in_memory()?;
let name = "alice";
let params = crate::named_params! {
"foo": 42,
"name": name,
};
let result: bool = db.query_row("SELECT $foo > 40 AND $name = 'alice'", params, |row| row.get(0))?;
assert!(result);
Ok(())
}
#[test]
fn test_empty_named_parameters_macro() -> Result<()> {
let db = Connection::open_in_memory()?;
let result: i32 = db.query_row("SELECT 1", crate::named_params! {}, |row| row.get(0))?;
assert_eq!(result, 1);
Ok(())
}
#[test]
fn test_named_parameters_repeated_placeholder() -> Result<()> {
use std::collections::HashMap;
let db = Connection::open_in_memory()?;
let stmt = db.prepare("SELECT $foo + $foo")?;
assert_eq!(stmt.parameter_count(), 1);
assert_eq!(stmt.parameter_name(1)?, "foo");
let slice_result: i32 = db.query_row(
"SELECT $foo + $foo",
crate::named_params! {
"foo": 21,
},
|row| row.get(0),
)?;
assert_eq!(slice_result, 42);
let named_params = HashMap::from([("foo", 21)]);
let hashmap_result: i32 = db.query_row("SELECT $foo + $foo", &named_params, |row| row.get(0))?;
assert_eq!(hashmap_result, 42);
Ok(())
}
#[test]
fn test_named_parameters_reject_extra_keys() -> Result<()> {
use std::collections::HashMap;
let named_params = HashMap::from([("foo", 42), ("bar", 23), ("extra", 1)]);
let db = Connection::open_in_memory()?;
let err = db
.query_row("SELECT $foo > $bar", &named_params, |row| row.get::<_, bool>(0))
.unwrap_err();
assert_eq!(err, Error::InvalidParameterName("extra".to_string()));
Ok(())
}
#[test]
fn test_named_parameters_reject_extra_hashmap_key_deterministically() -> Result<()> {
use std::collections::HashMap;
let named_params = HashMap::from([("foo", 42), ("z_extra", 1), ("a_extra", 2)]);
let db = Connection::open_in_memory()?;
let err = db
.query_row("SELECT $foo", &named_params, |row| row.get::<_, i32>(0))
.unwrap_err();
assert_eq!(err, Error::InvalidParameterName("a_extra".to_string()));
Ok(())
}
#[test]
fn test_named_parameters_reject_extra_slice_key() -> Result<()> {
let db = Connection::open_in_memory()?;
let bar = 23;
let params = crate::named_params! {
"foo": 42,
"middle_extra": 0,
"bar": bar,
"first_extra": 1,
"second_extra": 2,
};
let err = db
.query_row("SELECT $foo + $bar", params, |row| row.get::<_, i32>(0))
.unwrap_err();
assert_eq!(err, Error::InvalidParameterName("middle_extra".to_string()));
Ok(())
}
#[test]
fn test_named_parameters_reject_missing_hashmap_key() -> Result<()> {
use std::collections::HashMap;
let named_params = HashMap::from([("foo", 42)]);
let db = Connection::open_in_memory()?;
let err = db
.query_row("SELECT $foo > $bar", &named_params, |row| row.get::<_, bool>(0))
.unwrap_err();
assert_eq!(err, Error::InvalidParameterName("bar".to_string()));
Ok(())
}
#[test]
fn test_named_parameters_reject_missing_slice_key() -> Result<()> {
let db = Connection::open_in_memory()?;
let params = crate::named_params! {
"foo": 42,
};
let err = db
.query_row("SELECT $foo > $bar", params, |row| row.get::<_, bool>(0))
.unwrap_err();
assert_eq!(err, Error::InvalidParameterName("bar".to_string()));
Ok(())
}
#[test]
fn test_named_parameters_reject_duplicate_slice_keys() -> Result<()> {
let db = Connection::open_in_memory()?;
let first = 42;
let second = 23;
let params = &[("foo", &first as &dyn ToSql), ("foo", &second as &dyn ToSql)] as &[(&str, &dyn ToSql)];
let err = db
.query_row("SELECT $foo", params, |row| row.get::<_, i32>(0))
.unwrap_err();
assert_eq!(
err,
Error::InvalidParameterName("duplicate parameter name: foo".to_string())
);
Ok(())
}
#[test]
fn test_named_parameters_reject_positional_placeholders() -> Result<()> {
use std::collections::HashMap;
let named_params = HashMap::from([("1", 42), ("2", 23)]);
let db = Connection::open_in_memory()?;
let err = db
.query_row("SELECT ? > ?", &named_params, |row| row.get::<_, bool>(0))
.unwrap_err();
assert_eq!(
err,
Error::InvalidParameterName("positional parameter 1 cannot be used with named parameters".to_string())
);
Ok(())
}
#[test]
fn test_named_parameters_reject_dollar_number_placeholders() -> Result<()> {
use std::collections::HashMap;
let named_params = HashMap::from([("1", 42), ("2", 23)]);
let db = Connection::open_in_memory()?;
let hashmap_err = db
.query_row("SELECT $1 + $2", &named_params, |row| row.get::<_, i32>(0))
.unwrap_err();
assert_eq!(
hashmap_err,
Error::InvalidParameterName("positional parameter 1 cannot be used with named parameters".to_string())
);
let slice_err = db
.query_row(
"SELECT $1 + $2",
crate::named_params! {
"1": 42,
"2": 23,
},
|row| row.get::<_, i32>(0),
)
.unwrap_err();
assert_eq!(
slice_err,
Error::InvalidParameterName("positional parameter 1 cannot be used with named parameters".to_string())
);
Ok(())
}
#[test]
fn test_named_parameters_reject_mixed_placeholders() -> Result<()> {
let db = Connection::open_in_memory()?;
let err = db.prepare("SELECT $foo + ?").unwrap_err();
assert!(err.to_string().contains("Mixing named and positional parameters"));
Ok(())
}
#[test]
fn test_named_parameters_bind_null_values() -> Result<()> {
use std::collections::HashMap;
let db = Connection::open_in_memory()?;
let named_params = HashMap::from([("x", None::<i32>)]);
let hashmap_result: bool = db.query_row("SELECT $x IS NULL", &named_params, |row| row.get(0))?;
assert!(hashmap_result);
let slice_result: bool = db.query_row(
"SELECT $x IS NULL",
crate::named_params! {
"x": Option::<i32>::None,
},
|row| row.get(0),
)?;
assert!(slice_result);
Ok(())
}
/// Named parameters supplied as a `HashMap<String, i64>` bind by key when used
/// with `query_map`.
#[test]
fn test_named_parameters_string_keys_query_map() -> Result<()> {
    use std::collections::HashMap;
    let db = Connection::open_in_memory()?;
    let params = HashMap::from([("min".to_string(), 2i64), ("max".to_string(), 3i64)]);
    let mut stmt = db.prepare("SELECT i FROM range(5) tbl(i) WHERE i BETWEEN $min AND $max ORDER BY i")?;
    let rows = stmt.query_map(&params, |row| row.get::<_, i64>(0))?;
    let values = rows.collect::<Result<Vec<_>>>()?;
    assert_eq!(values, [2, 3]);
    Ok(())
}
/// Named parameters work with `Cow<str>` keys and a non-default hasher.
#[test]
fn test_named_parameters_cow_keys_custom_hasher() -> Result<()> {
    use std::{
        borrow::Cow,
        collections::{HashMap, hash_map::DefaultHasher},
        hash::BuildHasherDefault,
    };
    let db = Connection::open_in_memory()?;
    let mut params: HashMap<Cow<'static, str>, i64, BuildHasherDefault<DefaultHasher>> = HashMap::default();
    params.insert(Cow::Borrowed("min"), 2);
    params.insert(Cow::Owned("max".to_string()), 3);
    let mut stmt = db.prepare("SELECT i FROM range(5) tbl(i) WHERE i BETWEEN $min AND $max ORDER BY i")?;
    let rows = stmt.query_map(&params, |row| row.get::<_, i64>(0))?;
    let values = rows.collect::<Result<Vec<_>>>()?;
    assert_eq!(values, [2, 3]);
    Ok(())
}
#[test]
fn test_empty_stmt() -> Result<()> {
let conn = Connection::open_in_memory()?;
let stmt = conn.prepare("");
assert!(stmt.is_err());
Ok(())
}
#[test]
fn test_comment_empty_stmt() -> Result<()> {
let conn = Connection::open_in_memory()?;
assert!(conn.prepare("/*SELECT 1;*/").is_err());
Ok(())
}
#[test]
fn test_comment_and_sql_stmt() -> Result<()> {
let conn = Connection::open_in_memory()?;
let mut stmt = conn.prepare("/*...*/ SELECT 1;")?;
stmt.execute([])?;
assert_eq!(1, stmt.column_count());
Ok(())
}
#[test]
fn test_nul_byte() -> Result<()> {
let db = Connection::open_in_memory()?;
let expected = "a\x00b";
let actual: String = db.query_row("SELECT CAST(? AS VARCHAR)", [expected], |row| row.get(0))?;
assert_eq!(expected, actual);
Ok(())
}
#[test]
fn test_parameter_name() -> Result<()> {
let db = Connection::open_in_memory()?;
{
let stmt = db.prepare("SELECT $foo, $bar")?;
assert_eq!(stmt.parameter_count(), 2);
assert_eq!(stmt.parameter_name(1)?, "foo");
assert_eq!(stmt.parameter_name(2)?, "bar");
assert!(matches!(stmt.parameter_name(0), Err(Error::InvalidParameterIndex(0))));
assert!(matches!(
stmt.parameter_name(100),
Err(Error::InvalidParameterIndex(100))
));
}
{
let stmt = db.prepare("SELECT ?, ?")?;
assert_eq!(stmt.parameter_count(), 2);
assert_eq!(stmt.parameter_name(1)?, "1");
assert_eq!(stmt.parameter_name(2)?, "2");
}
{
let stmt = db.prepare("SELECT ?1, ?2")?;
assert_eq!(stmt.parameter_count(), 2);
assert_eq!(stmt.parameter_name(1)?, "1");
assert_eq!(stmt.parameter_name(2)?, "2");
}
Ok(())
}
#[test]
fn test_bind_named_parameters_manually() -> Result<()> {
use std::collections::HashMap;
let db = Connection::open_in_memory()?;
let mut stmt = db.prepare("SELECT $foo > $bar")?;
let mut params: HashMap<String, i32> = HashMap::new();
params.insert("foo".to_string(), 42);
params.insert("bar".to_string(), 23);
for idx in 1..=stmt.parameter_count() {
let name = stmt.parameter_name(idx)?;
if let Some(value) = params.get(&name) {
stmt.raw_bind_parameter(idx, value)?;
}
}
stmt.raw_execute()?;
let mut rows = stmt.raw_query();
let row = rows.next()?.unwrap();
let result: bool = row.get(0)?;
assert!(result);
Ok(())
}
#[test]
fn test_execute_streaming_error_message() -> Result<()> {
let db = Connection::open_in_memory()?;
let mut stmt = db.prepare("SELECT CAST('not-a-number' AS INTEGER)")?;
let result = stmt.stmt.execute_streaming();
assert!(result.is_err());
let err = result.unwrap_err();
let error_string = format!("{}", err);
assert!(
error_string.contains("Conversion Error"),
"Expected descriptive error, got: {}",
error_string
);
Ok(())
}
#[test]
fn test_bind_date32() -> Result<()> {
use crate::types::Value;
let db = Connection::open_in_memory()?;
let result: bool = db.query_row("SELECT ? = DATE '2022-05-18'", [Value::Date32(19130)], |row| row.get(0))?;
assert!(result);
Ok(())
}
#[test]
fn test_bind_time64() -> Result<()> {
use crate::types::{TimeUnit, Value};
let db = Connection::open_in_memory()?;
let micros = 45_045_123_456i64;
let result: bool = db.query_row(
"SELECT ? = TIME '12:30:45.123456'",
[Value::Time64(TimeUnit::Microsecond, micros)],
|row| row.get(0),
)?;
assert!(result);
Ok(())
}
#[test]
fn test_execute_tuple() -> Result<()> {
let db = Connection::open_in_memory()?;
db.execute_batch("CREATE TABLE test (id INTEGER, name TEXT, score DOUBLE)")?;
let mut stmt = db.prepare("INSERT INTO test VALUES (?, ?, ?)")?;
stmt.execute((1i32, "alice", 95.5f64))?;
stmt.execute((2i32, "bob", 87.0f64))?;
let mut stmt = db.prepare("SELECT id, name, score FROM test ORDER BY id")?;
let mut rows = stmt.query([])?;
let row = rows.next()?.unwrap();
assert_eq!(row.get::<_, i32>(0)?, 1);
assert_eq!(row.get::<_, String>(1)?, "alice");
assert_eq!(row.get::<_, f64>(2)?, 95.5);
let row = rows.next()?.unwrap();
assert_eq!(row.get::<_, i32>(0)?, 2);
assert_eq!(row.get::<_, String>(1)?, "bob");
assert_eq!(row.get::<_, f64>(2)?, 87.0);
Ok(())
}
#[test]
fn test_query_row_tuple() -> Result<()> {
let db = Connection::open_in_memory()?;
db.execute_batch("CREATE TABLE test (id INTEGER, name TEXT)")?;
db.execute("INSERT INTO test VALUES (1, 'alice')", [])?;
let name: String = db.query_row("SELECT name FROM test WHERE id = ?", (1i32,), |r| r.get(0))?;
assert_eq!(name, "alice");
Ok(())
}
#[test]
fn test_execute_tuple_single_element() -> Result<()> {
let db = Connection::open_in_memory()?;
db.execute_batch("CREATE TABLE test (id INTEGER)")?;
db.execute("INSERT INTO test VALUES (?)", (42i32,))?;
let val: i32 = db.query_row("SELECT id FROM test", [], |r| r.get(0))?;
assert_eq!(val, 42);
Ok(())
}
#[test]
/// Binds a 16-element tuple of mixed types and checks the first and last columns.
fn test_execute_tuple_many_columns() -> Result<()> {
    let conn = Connection::open_in_memory()?;
    conn.execute_batch(
        "CREATE TABLE test (a INT, b TEXT, c DOUBLE, d INT, e TEXT, f DOUBLE, g INT, h TEXT, i DOUBLE, j INT, k TEXT, l DOUBLE, m INT, n TEXT, o DOUBLE, p INT)",
    )?;
    let row_values = (
        1i32, "a", 1.0f64, 2i32, "b", 2.0f64, 3i32, "c", 3.0f64, 4i32, "d", 4.0f64, 5i32, "e", 5.0f64, 6i32,
    );
    conn.execute(
        "INSERT INTO test VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
        row_values,
    )?;
    let (first, last): (i32, i32) = conn.query_row("SELECT a, p FROM test", [], |r| Ok((r.get(0)?, r.get(1)?)))?;
    assert_eq!((first, last), (1, 6));
    Ok(())
}
#[test]
/// A `None` inside a parameter tuple is stored as SQL NULL and read back as `None`.
fn test_execute_tuple_with_option() -> Result<()> {
    let conn = Connection::open_in_memory()?;
    conn.execute_batch("CREATE TABLE test (id INTEGER, name TEXT)")?;
    let missing_name: Option<String> = None;
    conn.execute("INSERT INTO test VALUES (?, ?)", (1i32, missing_name))?;
    let fetched: Option<String> = conn.query_row("SELECT name FROM test WHERE id = ?", (1i32,), |r| r.get(0))?;
    assert!(fetched.is_none());
    Ok(())
}
#[test]
/// The unit tuple `()` is accepted as an empty parameter list by `execute` and `query_row`.
fn test_execute_empty_tuple() -> Result<()> {
    let conn = Connection::open_in_memory()?;
    conn.execute_batch("CREATE TABLE test (id INTEGER DEFAULT 1)")?;
    conn.execute("INSERT INTO test DEFAULT VALUES", ())?;
    // The column default supplies the value, since nothing is bound.
    let id: i32 = conn.query_row("SELECT id FROM test", (), |r| r.get(0))?;
    assert_eq!(id, 1);
    Ok(())
}
#[test]
/// Round-trips a `Decimal` parameter through DECIMAL columns of different scales.
///
/// The same value is inserted into both tables: at scale 4 it is preserved exactly;
/// at scale 2 it comes back with only two fractional digits.
fn test_with_decimal() -> Result<()> {
    let db = Connection::open_in_memory()?;
    db.execute_batch(
        "BEGIN; \
         CREATE TABLE foo(x DECIMAL(18, 4)); \
         CREATE TABLE bar(y DECIMAL(18, 2)); \
         COMMIT;",
    )?;
    // 1.2345 — declared once and reused, instead of being redeclared with an identical value.
    let value = Decimal::from_i128_with_scale(12345, 4);

    // Scale 4 holds all four fractional digits, so the value round-trips unchanged.
    db.execute("INSERT INTO foo(x) VALUES (?)", [&value])?;
    let stored: Decimal = db.query_row("SELECT x FROM foo", [], |r| r.get::<_, Decimal>(0))?;
    assert_eq!(stored, value);

    // Scale 2 keeps only two fractional digits: 1.2345 is read back as 1.23.
    db.execute("INSERT INTO bar(y) VALUES (?)", [&value])?;
    let stored: Decimal = db.query_row("SELECT y FROM bar", [], |r| r.get::<_, Decimal>(0))?;
    assert_eq!(stored, Decimal::from_i128_with_scale(123, 2));
    Ok(())
}
#[test]
/// `Decimal` can be read from INTEGER, BIGINT, DOUBLE, VARCHAR and HUGEINT result columns.
fn test_decimal_from_sql_integer_and_float() -> Result<()> {
    let conn = Connection::open_in_memory()?;
    // Runs a single-value query and decodes column 0 as a `Decimal`.
    let fetch = |sql: &str| -> Result<Decimal> { conn.query_row(sql, [], |r| r.get(0)) };

    assert_eq!(fetch("SELECT 42")?, Decimal::from(42));
    assert_eq!(fetch("SELECT 9999999999::BIGINT")?, Decimal::from(9999999999_i64));

    // DOUBLE is inexact, so compare within a tolerance rather than for equality.
    let from_double = fetch("SELECT 3.14::DOUBLE")?;
    let target = Decimal::from_str_exact("3.14").unwrap();
    let tolerance = Decimal::from_str_exact("0.0001").unwrap();
    assert!((from_double - target).abs() < tolerance);

    assert_eq!(
        fetch("SELECT '123.456'::VARCHAR")?,
        Decimal::from_str_exact("123.456").unwrap()
    );
    assert_eq!(
        fetch("SELECT 12345678901234567890::HUGEINT")?,
        Decimal::from_i128_with_scale(12345678901234567890, 0)
    );
    Ok(())
}
#[test]
/// A prepared statement reports VARIANT column metadata without executing the query.
fn test_variant_column_logical_type_metadata() -> Result<()> {
    let conn = Connection::open_in_memory()?;
    let stmt = conn.prepare("SELECT {'a': 42}::VARIANT AS variant_col")?;
    let ty = stmt.column_logical_type(0);
    assert_eq!(ty.id(), LogicalTypeId::Variant);
    assert_eq!(ty.raw_id(), crate::ffi::DUCKDB_TYPE_DUCKDB_TYPE_VARIANT);
    // VARIANT is reported with no child types.
    assert_eq!(ty.num_children(), 0);
    let debug_repr = format!("{ty:?}");
    assert_eq!(debug_repr, "Variant");
    Ok(())
}
#[test]
#[should_panic(expected = "could not retrieve logical type for result column at index 1")]
/// Asking for the logical type of a nonexistent column panics with a descriptive message.
fn test_column_logical_type_bad_index_panics() {
    let conn = Connection::open_in_memory().unwrap();
    let stmt = conn.prepare("SELECT 1").unwrap();
    // `SELECT 1` has exactly one column (index 0), so index 1 is out of range.
    let out_of_range = 1;
    let _ = stmt.column_logical_type(out_of_range);
}
#[test]
/// Running a query with a top-level VARIANT result column must fail rather than decode it.
fn test_variant_result_decode_unsupported() -> Result<()> {
    let conn = Connection::open_in_memory()?;
    // `.err()` drops the `Rows` on the (unexpected) success path before the
    // temporary statement goes away, leaving only the owned error.
    let err = conn
        .prepare("SELECT 1 AS id, {'a': 42}::VARIANT AS variant_col")?
        .query([])
        .err()
        .expect("expected Variant query to fail");
    // The offending VARIANT is result column 1.
    assert_variant_decode_error(err, 1);
    Ok(())
}
#[test]
/// The Arrow streaming path rejects a VARIANT result column just like `query` does.
fn test_variant_streaming_result_decode_unsupported() -> Result<()> {
    use arrow::datatypes::Schema;
    use std::sync::Arc;
    let conn = Connection::open_in_memory()?;
    // An empty schema is enough here: the test only exercises the error path.
    let placeholder_schema = Arc::new(Schema::empty());
    let err = conn
        .prepare("SELECT 1 AS id, {'a': 42}::VARIANT AS variant_col")?
        .stream_arrow([], placeholder_schema)
        .err()
        .expect("expected Variant streaming query to fail");
    // The offending VARIANT is result column 1.
    assert_variant_decode_error(err, 1);
    Ok(())
}
#[test]
/// VARIANT values nested in LIST, STRUCT and MAP result columns are rejected as well.
fn test_nested_variant_result_decode_unsupported() -> Result<()> {
    let conn = Connection::open_in_memory()?;
    for sql in [
        "SELECT [123::VARIANT] AS variant_list",
        "SELECT {'v': 123::VARIANT} AS variant_struct",
        "SELECT map(['v'], [123::VARIANT]) AS variant_map",
    ] {
        let err = conn
            .prepare(sql)?
            .query([])
            .err()
            .unwrap_or_else(|| panic!("expected nested Variant query to fail for {sql}"));
        // Each query has a single column (0) that contains the nested VARIANT.
        assert_variant_decode_error(err, 0);
    }
    Ok(())
}
#[test]
/// Binding a plain text parameter into a VARIANT column succeeds and stores a non-NULL value.
fn test_bind_variant_parameter_delegates_to_duckdb() -> Result<()> {
    let conn = Connection::open_in_memory()?;
    conn.execute("CREATE TABLE t (id INTEGER, variant_col VARIANT)", [])?;
    conn.execute("INSERT INTO t VALUES (?, ?)", crate::params![1, "hello"])?;
    let non_null: i64 = conn.query_row(
        "SELECT COUNT(*) FROM t WHERE variant_col IS NOT NULL",
        [],
        |row| row.get(0),
    )?;
    assert_eq!(non_null, 1);
    Ok(())
}
#[test]
/// Binding a list value, whether owned or borrowed, returns an error instead of succeeding.
fn test_bind_unsupported_container_type_returns_error() -> Result<()> {
    use crate::types::Value;

    // A `ToSql` impl whose output is an owned `Value::List`.
    struct OwnedList;

    impl ToSql for OwnedList {
        fn to_sql(&self) -> Result<ToSqlOutput<'_>> {
            let items = vec![Value::Int(1), Value::Int(2)];
            Ok(ToSqlOutput::Owned(Value::List(items)))
        }
    }

    let conn = Connection::open_in_memory()?;
    conn.execute("CREATE TABLE t (id INTEGER, numbers INTEGER[])", [])?;

    // Owned list parameter.
    let owned = OwnedList;
    let owned_err = conn
        .execute("INSERT INTO t VALUES (?, ?)", crate::params![1, owned])
        .unwrap_err();
    assert_binding_list_error(owned_err);

    // Borrowed list parameter.
    let borrowed = BorrowedList::new();
    let borrowed_err = conn
        .execute("INSERT INTO t VALUES (?, ?)", crate::params![2, borrowed])
        .unwrap_err();
    assert_binding_list_error(borrowed_err);
    Ok(())
}
#[test]
/// A failed bind must not leave earlier parameters bound on the statement.
fn test_bind_error_clears_partial_parameter_state() -> Result<()> {
    let conn = Connection::open_in_memory()?;
    conn.execute("CREATE TABLE t (id INTEGER NOT NULL, name TEXT)", [])?;
    let mut stmt = conn.prepare("INSERT INTO t VALUES (?, ?)")?;

    // Binding fails on the list in position 2; the `1` in position 1 may already
    // have been bound when that happens.
    let bad = BorrowedList::new();
    assert_binding_list_error(stmt.execute(crate::params![1, bad]).unwrap_err());

    // Rebinding only parameter 2 must not be enough: the stale `1` must be gone,
    // so execution fails and nothing is inserted.
    stmt.raw_bind_parameter(2, "ok")?;
    assert!(stmt.raw_execute().is_err());
    let inserted: i32 = conn.query_row("SELECT COUNT(*) FROM t", [], |row| row.get(0))?;
    assert_eq!(inserted, 0);

    // Binding every parameter makes the statement usable again.
    stmt.raw_bind_parameter(1, 7)?;
    stmt.raw_bind_parameter(2, "ok")?;
    assert_eq!(stmt.raw_execute()?, 1);

    let (id, name) = conn.query_row("SELECT id, name FROM t", [], |row| {
        Ok((row.get::<_, i32>(0)?, row.get::<_, String>(1)?))
    })?;
    assert_eq!(id, 7);
    assert_eq!(name, "ok");
    Ok(())
}
}