use crate::db::connection::sqlite_error;
use crate::db::row::{FromColumn, Row};
use crate::db::value::{bind_all, bind_named_all, ToSql};
use crate::db::DbError;
use crate::sqlite_vfs::ffi;
use std::ptr::NonNull;
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
use std::cell::RefCell;
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
use std::collections::BTreeMap;
// Thread-local registry of armed step failpoints, keyed by the context id that was active
// when `set_step_failpoint` was called. Each entry tracks how many `step` calls have been
// observed for that context so the failure can fire on a specific call.
// Only compiled for tests or when the `canister-api-test-failpoints` feature is enabled.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
thread_local! {
static STEP_FAILPOINTS: RefCell<BTreeMap<crate::stable::memory::ContextId, StepFailpointState>> = const { RefCell::new(BTreeMap::new()) };
}
/// Describes a one-shot failure to inject into the statement stepping path.
///
/// Armed with `set_step_failpoint`; once it fires it is removed from the registry.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct StepFailpoint {
/// Which `step` call (counted from 1 after arming, within the arming context) should fail.
/// The counter is incremented before it is compared, so an ordinal of 0 never fires.
pub ordinal: u64,
/// Result code carried by the injected `DbError::Sqlite` error.
pub code: std::ffi::c_int,
}
/// Registry entry for an armed failpoint: the requested failure plus the number of
/// `step` calls seen so far in the owning context.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
#[derive(Clone, Copy, Debug)]
struct StepFailpointState {
// The failure to inject once `count` reaches `failpoint.ordinal`.
failpoint: StepFailpoint,
// Number of `step` calls observed since the failpoint was armed (starts at 0).
count: u64,
}
/// A prepared SQLite statement tied to the lifetime of the connection that produced it.
///
/// The underlying handle is finalized when the value is dropped, unless ownership is
/// released first via `into_raw`.
pub struct Statement<'connection> {
// Raw connection handle; used to build detailed errors through `sqlite_error`.
db: *mut ffi::sqlite3,
// Non-null prepared statement handle owned by this value.
raw: NonNull<ffi::sqlite3_stmt>,
// Number of bind parameters reported by SQLite at construction time
// (0 if the reported count could not be converted to `usize`).
parameter_count: usize,
// Zero-sized marker that ties this statement to the `'connection` lifetime.
_connection: std::marker::PhantomData<&'connection ()>,
}
/// Cursor over the result rows of a bound statement.
///
/// Holds the statement mutably for its whole lifetime, so the statement cannot be
/// re-bound or re-executed while rows are still being read.
pub struct Rows<'statement, 'connection> {
// The statement being stepped.
statement: &'statement mut Statement<'connection>,
// Set once iteration has finished; `next_row` then returns `Ok(None)` without stepping.
done: bool,
}
/// Per-phase cost breakdown returned by the `*_profiled` single-text query helpers.
///
/// Each field is the difference between two `instruction_counter()` readings taken
/// around the phase. `instruction_counter()` returns 0 on non-wasm32 targets, so every
/// field is 0 there.
#[cfg(feature = "bench-profile")]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct QueryOptionalStringTextProfile {
/// Cost of resetting the statement and binding the single text parameter.
pub reset_bind: u64,
/// Cost of the single `step` call.
pub step: u64,
/// Cost of reading column 0 when a row was produced; stays 0 when no row was returned.
pub column_read: u64,
}
impl<'connection> Statement<'connection> {
/// Wraps an already-prepared statement handle belonging to connection `db`.
///
/// The number of bind parameters is queried once here and cached; a count that cannot
/// be represented as `usize` is stored as 0.
pub(crate) fn new(db: *mut ffi::sqlite3, raw: NonNull<ffi::sqlite3_stmt>) -> Self {
    // NOTE(review): assumes `raw` points at a live prepared statement — the caller must
    // guarantee this; only non-nullness is enforced by the type.
    let reported = unsafe { ffi::sqlite3_bind_parameter_count(raw.as_ptr()) };
    let parameter_count = usize::try_from(reported).unwrap_or_default();
    Self {
        _connection: std::marker::PhantomData,
        parameter_count,
        raw,
        db,
    }
}
/// Releases ownership of the underlying handle without finalizing it.
///
/// The destructor is suppressed, so the caller becomes responsible for the handle.
pub(crate) fn into_raw(self) -> NonNull<ffi::sqlite3_stmt> {
    // `ManuallyDrop` prevents `Drop for Statement` from finalizing the handle.
    let released = std::mem::ManuallyDrop::new(self);
    released.raw
}
pub fn execute(&mut self, values: &[&dyn ToSql]) -> Result<(), DbError> {
self.reset_and_bind(values)?;
let rc = step(self.raw.as_ptr())?;
if rc == ffi::SQLITE_DONE {
Ok(())
} else {
Err(sqlite_error(self.db, rc))
}
}
pub fn execute_named(&mut self, values: &[(&str, &dyn ToSql)]) -> Result<(), DbError> {
self.reset_and_bind_named(values)?;
let rc = step(self.raw.as_ptr())?;
if rc == ffi::SQLITE_DONE {
Ok(())
} else {
Err(sqlite_error(self.db, rc))
}
}
/// Resets the statement, binds `values` positionally and returns a cursor over the rows.
///
/// No step is performed here; rows are produced lazily by `Rows::next_row`.
pub fn query<'statement>(
    &'statement mut self,
    values: &[&dyn ToSql],
) -> Result<Rows<'statement, 'connection>, DbError> {
    match self.reset_and_bind(values) {
        Ok(()) => Ok(Rows {
            done: false,
            statement: self,
        }),
        Err(error) => Err(error),
    }
}
/// Resets the statement, binds `values` by parameter name and returns a cursor over the rows.
///
/// No step is performed here; rows are produced lazily by `Rows::next_row`.
pub fn query_named<'statement>(
    &'statement mut self,
    values: &[(&str, &dyn ToSql)],
) -> Result<Rows<'statement, 'connection>, DbError> {
    self.reset_and_bind_named(values)?;
    let cursor = Rows {
        statement: self,
        done: false,
    };
    Ok(cursor)
}
/// Runs the query with positional `values` and maps its first row through `f`.
///
/// # Errors
///
/// Returns `DbError::NotFound` when the query yields no row, and propagates any
/// binding, stepping or mapping error.
pub fn query_one<T, F>(&mut self, values: &[&dyn ToSql], f: F) -> Result<T, DbError>
where
    F: FnOnce(&Row<'_>) -> Result<T, DbError>,
{
    let mut cursor = self.query(values)?;
    let first = cursor.next_row()?.ok_or(DbError::NotFound)?;
    f(&first)
}
/// Runs the query with named `values` and maps its first row through `f`.
///
/// # Errors
///
/// Returns `DbError::NotFound` when the query yields no row, and propagates any
/// binding, stepping or mapping error.
pub fn query_one_named<T, F>(
    &mut self,
    values: &[(&str, &dyn ToSql)],
    f: F,
) -> Result<T, DbError>
where
    F: FnOnce(&Row<'_>) -> Result<T, DbError>,
{
    let mut cursor = self.query_named(values)?;
    let Some(first) = cursor.next_row()? else {
        return Err(DbError::NotFound);
    };
    f(&first)
}
/// Runs the query with positional `values` and maps its first row, if any, through `f`.
///
/// Returns `Ok(None)` when the query yields no row; `f` is not called in that case.
pub fn query_optional<T, F>(
    &mut self,
    values: &[&dyn ToSql],
    f: F,
) -> Result<Option<T>, DbError>
where
    F: FnOnce(&Row<'_>) -> Result<T, DbError>,
{
    let mut cursor = self.query(values)?;
    let first = cursor.next_row()?;
    // Option<Result<T>> -> Result<Option<T>>: a mapping failure is still reported as an error.
    first.map(|row| f(&row)).transpose()
}
/// Runs the query with named `values` and maps its first row, if any, through `f`.
///
/// Returns `Ok(None)` when the query yields no row; `f` is not called in that case.
pub fn query_optional_named<T, F>(
    &mut self,
    values: &[(&str, &dyn ToSql)],
    f: F,
) -> Result<Option<T>, DbError>
where
    F: FnOnce(&Row<'_>) -> Result<T, DbError>,
{
    let mut cursor = self.query_named(values)?;
    let Some(first) = cursor.next_row()? else {
        return Ok(None);
    };
    let mapped = f(&first)?;
    Ok(Some(mapped))
}
/// Runs the query with positional `values` and collects every row mapped through `f`.
///
/// Stops at the first stepping or mapping error and returns it; rows mapped before
/// the failure are discarded.
pub fn query_all<T, F>(&mut self, values: &[&dyn ToSql], mut f: F) -> Result<Vec<T>, DbError>
where
    F: FnMut(&Row<'_>) -> Result<T, DbError>,
{
    let mut cursor = self.query(values)?;
    let mut collected = Vec::new();
    loop {
        let Some(row) = cursor.next_row()? else {
            return Ok(collected);
        };
        collected.push(f(&row)?);
    }
}
/// Runs the query with named `values` and collects every row mapped through `f`.
///
/// Stops at the first stepping or mapping error and returns it; rows mapped before
/// the failure are discarded.
pub fn query_all_named<T, F>(
    &mut self,
    values: &[(&str, &dyn ToSql)],
    mut f: F,
) -> Result<Vec<T>, DbError>
where
    F: FnMut(&Row<'_>) -> Result<T, DbError>,
{
    let mut cursor = self.query_named(values)?;
    let mut collected = Vec::new();
    while let Some(row) = cursor.next_row()? {
        let item = f(&row)?;
        collected.push(item);
    }
    Ok(collected)
}
pub fn query_scalar<T: FromColumn>(&mut self, values: &[&dyn ToSql]) -> Result<T, DbError> {
self.query_one(values, |row| row.get(0))
}
pub fn query_scalar_named<T: FromColumn>(
&mut self,
values: &[(&str, &dyn ToSql)],
) -> Result<T, DbError> {
self.query_one_named(values, |row| row.get(0))
}
pub fn query_optional_scalar<T: FromColumn>(
&mut self,
values: &[&dyn ToSql],
) -> Result<Option<T>, DbError> {
self.query_optional(values, |row| row.get(0))
}
pub fn query_optional_string_text(&mut self, value: &str) -> Result<Option<String>, DbError> {
self.reset_and_bind_single_text(value)?;
match step(self.raw.as_ptr())? {
ffi::SQLITE_ROW => read_string_column_zero(self.raw.as_ptr()).map(Some),
ffi::SQLITE_DONE => Ok(None),
rc => Err(sqlite_error(self.db, rc)),
}
}
pub fn query_optional_string_text_len(
&mut self,
value: &str,
) -> Result<Option<usize>, DbError> {
self.reset_and_bind_single_text(value)?;
match step(self.raw.as_ptr())? {
ffi::SQLITE_ROW => read_string_column_zero_len(self.raw.as_ptr()).map(Some),
ffi::SQLITE_DONE => Ok(None),
rc => Err(sqlite_error(self.db, rc)),
}
}
/// Profiled twin of `query_optional_string_text`: same result, plus the
/// `instruction_counter()` delta spent in each phase (reset+bind, step, column read).
///
/// The profile is only returned on success; any error discards it.
///
/// After a failed step the statement is reset immediately. SQLite keeps the step's
/// error code on the statement and hands it back from the next `sqlite3_reset`;
/// without this, the next use of the same statement would fail with the stale error.
#[cfg(feature = "bench-profile")]
#[doc(hidden)]
pub fn query_optional_string_text_profiled(
    &mut self,
    value: &str,
) -> Result<(Option<String>, QueryOptionalStringTextProfile), DbError> {
    let mut profile = QueryOptionalStringTextProfile::default();

    let start = instruction_counter();
    self.reset_and_bind_single_text(value)?;
    profile.reset_bind = instruction_counter().saturating_sub(start);

    let start = instruction_counter();
    let rc = step(self.raw.as_ptr())?;
    profile.step = instruction_counter().saturating_sub(start);

    match rc {
        ffi::SQLITE_ROW => {
            let start = instruction_counter();
            let value = read_string_column_zero(self.raw.as_ptr()).map(Some);
            profile.column_read = instruction_counter().saturating_sub(start);
            value.map(|value| (value, profile))
        }
        ffi::SQLITE_DONE => Ok((None, profile)),
        rc => {
            // Capture the diagnostic first so the reset below cannot disturb it.
            let error = sqlite_error(self.db, rc);
            // SAFETY: `self.raw` is the non-null handle owned by this statement; it is
            // only finalized in `Drop`. The return value repeats the captured failure.
            let _ = unsafe { ffi::sqlite3_reset(self.raw.as_ptr()) };
            Err(error)
        }
    }
}
/// Profiled twin of `query_optional_string_text_len`: same result, plus the
/// `instruction_counter()` delta spent in each phase (reset+bind, step, column read).
///
/// The profile is only returned on success; any error discards it.
///
/// After a failed step the statement is reset immediately. SQLite keeps the step's
/// error code on the statement and hands it back from the next `sqlite3_reset`;
/// without this, the next use of the same statement would fail with the stale error.
#[cfg(feature = "bench-profile")]
#[doc(hidden)]
pub fn query_optional_string_text_len_profiled(
    &mut self,
    value: &str,
) -> Result<(Option<usize>, QueryOptionalStringTextProfile), DbError> {
    let mut profile = QueryOptionalStringTextProfile::default();

    let start = instruction_counter();
    self.reset_and_bind_single_text(value)?;
    profile.reset_bind = instruction_counter().saturating_sub(start);

    let start = instruction_counter();
    let rc = step(self.raw.as_ptr())?;
    profile.step = instruction_counter().saturating_sub(start);

    match rc {
        ffi::SQLITE_ROW => {
            let start = instruction_counter();
            let value = read_string_column_zero_len(self.raw.as_ptr()).map(Some);
            profile.column_read = instruction_counter().saturating_sub(start);
            value.map(|value| (value, profile))
        }
        ffi::SQLITE_DONE => Ok((None, profile)),
        rc => {
            // Capture the diagnostic first so the reset below cannot disturb it.
            let error = sqlite_error(self.db, rc);
            // SAFETY: `self.raw` is the non-null handle owned by this statement; it is
            // only finalized in `Drop`. The return value repeats the captured failure.
            let _ = unsafe { ffi::sqlite3_reset(self.raw.as_ptr()) };
            Err(error)
        }
    }
}
pub fn query_optional_scalar_named<T: FromColumn>(
&mut self,
values: &[(&str, &dyn ToSql)],
) -> Result<Option<T>, DbError> {
self.query_optional_named(values, |row| row.get(0))
}
pub fn query_column<T: FromColumn>(
&mut self,
values: &[&dyn ToSql],
) -> Result<Vec<T>, DbError> {
self.query_all(values, |row| row.get(0))
}
pub fn query_column_named<T: FromColumn>(
&mut self,
values: &[(&str, &dyn ToSql)],
) -> Result<Vec<T>, DbError> {
self.query_all_named(values, |row| row.get(0))
}
/// Prepares the statement for a fresh run: reset, clear old bindings, then bind
/// `values` positionally via `bind_all`.
///
/// Each stage runs only if the previous one reported `SQLITE_OK`.
fn reset_and_bind(&mut self, values: &[&dyn ToSql]) -> Result<(), DbError> {
    let raw = self.raw.as_ptr();
    let db = self.db;
    // Turns a SQLite status code into a `Result`, building the error lazily.
    let ensure_ok = |rc: std::ffi::c_int| -> Result<(), DbError> {
        if rc == ffi::SQLITE_OK {
            Ok(())
        } else {
            Err(sqlite_error(db, rc))
        }
    };
    // SAFETY (both calls): `raw` is the non-null handle owned by this statement; it is
    // only finalized in `Drop`.
    ensure_ok(unsafe { ffi::sqlite3_reset(raw) })?;
    ensure_ok(unsafe { ffi::sqlite3_clear_bindings(raw) })?;
    bind_all(raw, values)
}
/// Fast path for statements with exactly one parameter: reset, clear old bindings and
/// bind `value` as TEXT at index 1.
///
/// # Errors
///
/// - `DbError::ParameterCountMismatch` if the statement does not take exactly one
///   parameter (checked before touching the statement).
/// - `DbError::TextTooLarge` if `value.len()` does not fit in a C `int`.
/// - A SQLite error if reset, clear or bind reports a non-OK status.
fn reset_and_bind_single_text(&mut self, value: &str) -> Result<(), DbError> {
    let expected = self.parameter_count;
    if expected != 1 {
        return Err(DbError::ParameterCountMismatch {
            expected,
            actual: 1,
        });
    }

    let raw = self.raw.as_ptr();
    let db = self.db;
    // Turns a SQLite status code into a `Result`, building the error lazily.
    let ensure_ok = |rc: std::ffi::c_int| -> Result<(), DbError> {
        if rc == ffi::SQLITE_OK {
            Ok(())
        } else {
            Err(sqlite_error(db, rc))
        }
    };
    // SAFETY (both calls): `raw` is the non-null handle owned by this statement; it is
    // only finalized in `Drop`.
    ensure_ok(unsafe { ffi::sqlite3_reset(raw) })?;
    ensure_ok(unsafe { ffi::sqlite3_clear_bindings(raw) })?;

    let byte_len = match std::ffi::c_int::try_from(value.len()) {
        Ok(byte_len) => byte_len,
        Err(_) => return Err(DbError::TextTooLarge),
    };
    // SAFETY: the pointer/length pair describes `value`, which outlives the call.
    // `SQLITE_TRANSIENT` is passed as the destructor argument, asking SQLite to copy
    // the bytes rather than keep the pointer.
    let bind_rc = unsafe {
        ffi::sqlite3_bind_text(
            raw,
            1,
            value.as_ptr().cast(),
            byte_len,
            ffi::SQLITE_TRANSIENT(),
        )
    };
    if bind_rc != ffi::SQLITE_OK {
        return Err(DbError::Sqlite(bind_rc, "sqlite bind failed".to_string()));
    }
    Ok(())
}
/// Prepares the statement for a fresh run: reset, clear old bindings, then bind
/// `values` by parameter name via `bind_named_all`.
///
/// Bindings are cleared only if the reset succeeded, and binding happens only if
/// both earlier stages reported `SQLITE_OK`.
fn reset_and_bind_named(&mut self, values: &[(&str, &dyn ToSql)]) -> Result<(), DbError> {
    let raw = self.raw.as_ptr();
    // SAFETY: `raw` is the non-null handle owned by this statement; it is only
    // finalized in `Drop`.
    let reset_rc = unsafe { ffi::sqlite3_reset(raw) };
    // Carry the first non-OK status forward; bindings are cleared only after a clean reset.
    let status = if reset_rc == ffi::SQLITE_OK {
        // SAFETY: same handle as above.
        unsafe { ffi::sqlite3_clear_bindings(raw) }
    } else {
        reset_rc
    };
    if status != ffi::SQLITE_OK {
        return Err(sqlite_error(self.db, status));
    }
    bind_named_all(raw, values)
}
}
/// Reads column 0 of the current row as an owned `String`.
///
/// Invalid UTF-8 is not an error: offending bytes are replaced, as done by
/// `String::from_utf8_lossy`. An empty value or a null text pointer yields `""`.
///
/// # Errors
///
/// - `DbError::TypeMismatch` if column 0 is not TEXT.
/// - `DbError::TextTooLarge` if the reported byte count does not fit in `usize`.
fn read_string_column_zero(statement: *mut ffi::sqlite3_stmt) -> Result<String, DbError> {
    // NOTE(review): every FFI call below assumes `statement` is a live handle currently
    // positioned on a row — callers in this file only get here after `SQLITE_ROW`.
    let column_type = unsafe { ffi::sqlite3_column_type(statement, 0) };
    if column_type == ffi::SQLITE_TEXT {
        // The text pointer is fetched before the byte count, in that order.
        let data = unsafe { ffi::sqlite3_column_text(statement, 0) };
        let reported = unsafe { ffi::sqlite3_column_bytes(statement, 0) };
        let Ok(byte_len) = usize::try_from(reported) else {
            return Err(DbError::TextTooLarge);
        };
        let decoded = if byte_len == 0 || data.is_null() {
            String::new()
        } else {
            // SAFETY: `data` is non-null (checked above); relies on SQLite's contract that
            // it addresses `byte_len` readable bytes until the statement is next advanced.
            let bytes = unsafe { std::slice::from_raw_parts(data.cast::<u8>(), byte_len) };
            String::from_utf8_lossy(bytes).into_owned()
        };
        return Ok(decoded);
    }
    Err(DbError::TypeMismatch {
        index: 0,
        expected: "TEXT",
        actual: sqlite_type_name(column_type),
    })
}
/// Returns the byte length of column 0 of the current row, as reported by
/// `sqlite3_column_bytes`, without reading the text itself.
///
/// # Errors
///
/// - `DbError::TypeMismatch` if column 0 is not TEXT.
/// - `DbError::TextTooLarge` if the reported byte count does not fit in `usize`.
fn read_string_column_zero_len(statement: *mut ffi::sqlite3_stmt) -> Result<usize, DbError> {
    // NOTE(review): assumes `statement` is a live handle currently positioned on a row.
    let column_type = unsafe { ffi::sqlite3_column_type(statement, 0) };
    if column_type == ffi::SQLITE_TEXT {
        let reported = unsafe { ffi::sqlite3_column_bytes(statement, 0) };
        return match usize::try_from(reported) {
            Ok(byte_len) => Ok(byte_len),
            Err(_) => Err(DbError::TextTooLarge),
        };
    }
    Err(DbError::TypeMismatch {
        index: 0,
        expected: "TEXT",
        actual: sqlite_type_name(column_type),
    })
}
/// Maps a SQLite fundamental datatype code to a display name for error messages.
///
/// Codes outside the five known datatypes map to `"UNKNOWN"`.
fn sqlite_type_name(code: std::ffi::c_int) -> &'static str {
    let known = [
        (ffi::SQLITE_INTEGER, "INTEGER"),
        (ffi::SQLITE_FLOAT, "REAL"),
        (ffi::SQLITE_TEXT, "TEXT"),
        (ffi::SQLITE_BLOB, "BLOB"),
        (ffi::SQLITE_NULL, "NULL"),
    ];
    // First matching entry wins, mirroring top-to-bottom arm order.
    known
        .iter()
        .find(|(candidate, _)| *candidate == code)
        .map_or("UNKNOWN", |(_, name)| *name)
}
impl Rows<'_, '_> {
pub fn next_row(&mut self) -> Result<Option<Row<'_>>, DbError> {
if self.done {
return Ok(None);
}
let rc = step(self.statement.raw.as_ptr())?;
match rc {
ffi::SQLITE_ROW => Ok(Some(Row::new(self.statement.raw.as_ptr()))),
ffi::SQLITE_DONE => {
self.done = true;
Ok(None)
}
_ => Err(sqlite_error(self.statement.db, rc)),
}
}
}
/// Steps `statement` once and returns SQLite's raw result code.
///
/// In test/failpoint builds an armed failpoint may short-circuit the call: the
/// statement is then not stepped at all and `DbError::Sqlite` is returned with the
/// failpoint's code. In other builds this never returns `Err`.
fn step(statement: *mut ffi::sqlite3_stmt) -> Result<std::ffi::c_int, DbError> {
    #[cfg(any(test, feature = "canister-api-test-failpoints"))]
    {
        if let Some(code) = hit_step_failpoint() {
            return Err(DbError::Sqlite(code, "sqlite step failpoint".to_string()));
        }
    }
    // NOTE(review): assumes `statement` is a live prepared statement — every caller in
    // this file passes the handle owned by a `Statement`.
    let rc = unsafe { ffi::sqlite3_step(statement) };
    Ok(rc)
}
/// Reads the counter used by the `*_profiled` helpers to measure each phase.
///
/// On wasm32 this is `ic_cdk::api::performance_counter(0)`; on every other target it
/// is the constant 0, so profiles taken off-wasm are all zeros.
// NOTE(review): counter type 0 is presumably the current-call instruction counter,
// going by this function's name — confirm against the ic_cdk documentation.
#[cfg(feature = "bench-profile")]
fn instruction_counter() -> u64 {
#[cfg(target_arch = "wasm32")]
{
ic_cdk::api::performance_counter(0)
}
#[cfg(not(target_arch = "wasm32"))]
{
0
}
}
/// Arms `failpoint` for the currently active context, replacing any failpoint already
/// armed for it and restarting its step counter at zero.
///
/// Does nothing when no active context id can be obtained.
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
pub fn set_step_failpoint(failpoint: StepFailpoint) {
    let Ok(context) = crate::stable::memory::active_context_id() else {
        return;
    };
    let armed = StepFailpointState {
        failpoint,
        count: 0,
    };
    STEP_FAILPOINTS.with(|slot| {
        slot.borrow_mut().insert(context, armed);
    });
}
/// Disarms every failpoint on this thread, for all contexts (not only the active one).
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
pub fn clear_step_failpoint() {
    STEP_FAILPOINTS.with(|slot| {
        let mut registry = slot.borrow_mut();
        registry.clear();
    });
}
/// Records one `step` call against the active context's failpoint, if one is armed.
///
/// Returns the failpoint's code — and disarms it — when this call is exactly the
/// `ordinal`-th one since arming. Returns `None` when there is no active context, no
/// armed failpoint, or the ordinal has not been reached (or was already passed).
#[cfg(any(test, feature = "canister-api-test-failpoints"))]
fn hit_step_failpoint() -> Option<std::ffi::c_int> {
    let context = crate::stable::memory::active_context_id().ok()?;
    STEP_FAILPOINTS.with(|slot| {
        let mut registry = slot.borrow_mut();
        let armed = registry.get_mut(&context)?;
        armed.count += 1;
        if armed.count != armed.failpoint.ordinal {
            return None;
        }
        // One-shot: drop the entry and hand back the code it carried.
        registry
            .remove(&context)
            .map(|fired| fired.failpoint.code)
    })
}
impl Drop for Statement<'_> {
    /// Finalizes the underlying prepared statement; the result code is ignored.
    fn drop(&mut self) {
        let raw = self.raw.as_ptr();
        // SAFETY: `raw` is the non-null handle owned by this value. `into_raw` forgets
        // the value instead of dropping it, so this runs at most once per handle.
        let _ = unsafe { ffi::sqlite3_finalize(raw) };
    }
}