squire 0.0.1-alpha.5

Safe and idiomatic SQLite bindings
Documentation
use core::{
    ffi::{CStr, c_char, c_int},
    fmt,
    marker::PhantomData,
    ptr,
};

#[cfg(target_pointer_width = "32")]
use sqlite::sqlite3_changes;
#[cfg(target_pointer_width = "64")]
use sqlite::sqlite3_changes64;
use sqlite::{
    SQLITE_DONE, SQLITE_OK, SQLITE_ROW, sqlite3, sqlite3_bind_parameter_count,
    sqlite3_bind_parameter_name, sqlite3_clear_bindings, sqlite3_column_count, sqlite3_column_name,
    sqlite3_data_count, sqlite3_db_handle, sqlite3_finalize, sqlite3_prepare_v3, sqlite3_reset,
    sqlite3_step, sqlite3_stmt,
};

use super::{
    bind::Bind,
    call::call,
    connection::{Connected, Connection},
    value::Fetch,
};
use crate::{
    error::{Error, ErrorCategory, Result},
    types::{BindIndex, ColumnIndex},
};

/// A thin wrapper around a [`sqlite3_stmt`] prepared statement pointer.
///
/// The handle is stored as a [`NonNull`](ptr::NonNull) and the type is
/// `#[repr(transparent)]`, so a `Statement` has the same layout as the raw
/// pointer and `Option<Statement>` uses the null pointer for `None`.
///
/// NOTE(review): no `Drop` impl is visible next to this type, and the
/// `must_use` message on [`Statement::prepare`] says a discarded statement
/// leaks — call [`Statement::close`] to finalize it. Confirm that no `Drop`
/// impl exists elsewhere.
#[repr(transparent)]
pub struct Statement<'c> {
    /// The prepared statement handle passed to every SQLite call.
    handle: ptr::NonNull<sqlite3_stmt>,
    // Ties the statement to the `Connection` it was prepared on. The `fn() ->`
    // wrapper keeps `Statement` covariant in `'c`. Because function pointers are
    // always `Send + Sync`, it also keeps `Connection`'s auto traits from leaking
    // into `Statement`; thread-safety is opted into explicitly instead.
    _connection: PhantomData<fn() -> &'c Connection>,
}

// SAFETY: assumes the `multi-thread` / `serialized` features build SQLite in a
// threading mode where a prepared statement may be used from a thread other
// than the one that created it. Confirm this against the feature definitions.
//
// NOTE(review): `Statement` references its `Connection` only through
// `PhantomData<fn() -> &'c Connection>`, which is `Send` whatever `Connection`
// is. Sending a statement therefore does not require the connection to be
// `Sync`, so the original thread may keep using the connection while the
// statement is stepped elsewhere. SQLite's multi-thread mode forbids exactly
// that concurrent use of a connection and its derived objects — verify that
// this is sound.
#[cfg(any(feature = "multi-thread", feature = "serialized"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "multi-thread", feature = "serialized")))
)]
unsafe impl<'c> Send for Statement<'c> {}

// SAFETY: assumes `serialized` builds SQLite in serialized mode, where SQLite
// internally serializes calls on a connection and the objects derived from it.
//
// NOTE(review): `&self` methods such as `row` advance the cursor. Two threads
// sharing a `&Statement` can therefore interleave steps and invalidate values
// that the other thread obtained via `fetch` — confirm this cannot happen.
#[cfg(feature = "serialized")]
#[cfg_attr(docsrs, doc(cfg(feature = "serialized")))]
unsafe impl<'c> Sync for Statement<'c> {}

impl<'c> Statement<'c> {
    /// Wrap a [`sqlite3_stmt`] prepared statement pointer.
    ///
    /// Returns `None` when `handle` is null.
    ///
    /// NOTE(review): this constructor is safe, yet every other method passes
    /// the wrapped pointer straight to SQLite, and the caller chooses `'c`
    /// freely. Callers must only pass a live statement handle whose connection
    /// outlives `'c`.
    #[inline]
    #[must_use]
    pub const fn new(handle: *mut sqlite3_stmt) -> Option<Self> {
        match ptr::NonNull::new(handle) {
            Some(handle) => Some(Self {
                handle,
                _connection: PhantomData,
            }),
            None => None,
        }
    }

    /// Prepare a [`Statement`] on a [`Connection`] from SQL `query` text.
    ///
    /// SQLite compiles only the first SQL statement in `query`. On success this
    /// returns the statement together with the byte offset of SQLite's tail
    /// pointer within `query`: the number of bytes consumed by that first
    /// statement. Any remaining SQL starts at that offset. If SQLite does not
    /// report a tail, the offset is `0`.
    ///
    /// `flags` is passed unchanged as the `prepFlags` argument of
    /// [`sqlite3_prepare_v3`](https://sqlite.org/c3ref/prepare.html).
    ///
    /// # Errors
    ///
    /// - [`ErrorCategory::TooBig`] if `query` is longer than `c_int::MAX` bytes.
    /// - The error derived via `Error::from_prepare` (falling back to
    ///   `Error::default()`) if SQLite fails. The same applies if SQLite
    ///   succeeds but produces no statement, e.g. when `query` contains only
    ///   whitespace or comments.
    #[doc(alias = "sqlite3_prepare_v3")]
    #[must_use = "a Statement will leak if prepared and discarded"]
    pub fn prepare(connection: &'c Connection, query: &str, flags: u32) -> Result<(Self, usize)> {
        // Rust strings are not NUL-terminated, so the byte length is passed
        // explicitly, as the C `int` that SQLite expects.
        let length = c_int::try_from(query.len()).map_err(|_| ErrorCategory::TooBig)?;
        let query_p = query.as_bytes().as_ptr().cast::<c_char>();
        let mut handle: *mut sqlite3_stmt = ptr::null_mut();
        let mut tail: *const c_char = ptr::null();

        // SAFETY: `query_p`/`length` describe the bytes of `query`, which
        // outlives this call. The out-pointers refer to live locals. Assumes
        // `Connection::as_ptr` yields an open `sqlite3*`.
        let result = unsafe {
            sqlite3_prepare_v3(
                connection.as_ptr(),
                query_p,
                length,
                flags,
                &mut handle,
                &mut tail,
            )
        };

        let sql_length = if tail.is_null() {
            0
        } else {
            // SAFETY: SQLite sets `tail` to a position inside the `query`
            // buffer at or after `query_p`. Both pointers therefore belong to the
            // same allocation, and the offset is non-negative.
            unsafe { tail.byte_offset_from_unsigned(query_p) }
        };

        // SQLite documents that `*ppStmt` is NULL whenever preparation fails,
        // so the fallback arm never discards a live statement.
        match Self::new(handle) {
            Some(statement) if result == SQLITE_OK => Ok((statement, sql_length)),
            _ => Err(Error::from_prepare(connection, result).unwrap_or_default()),
        }
    }

    /// Finalize the underlying statement handle in place.
    ///
    /// # Safety
    ///
    /// The handle must not be used again after this call. SQLite destroys the
    /// statement even when `sqlite3_finalize` returns an error code, so any
    /// later use of `self` — including a second finalize — would touch a
    /// dangling pointer.
    #[inline]
    pub(crate) unsafe fn finalize(&mut self) -> Result<()> {
        call! { sqlite3_finalize(self.as_ptr()) }
    }

    /// [Finalize][] (i.e., destroy) the prepared statement.
    ///
    /// Consumes the statement so it cannot be used afterwards. SQLite destroys
    /// the statement even when `sqlite3_finalize` returns an error code. That
    /// code reflects the most recent evaluation of the statement.
    ///
    /// [Finalize]: https://sqlite.org/c3ref/finalize.html
    #[doc(alias = "sqlite3_finalize")]
    pub fn close(mut self) -> Result<()> {
        // SAFETY: `self` is consumed, so the finalized handle cannot be reused
        // (assumes `Statement` has no finalizing `Drop` impl).
        unsafe { self.finalize() }
    }

    /// Return the name of result column `index`, or `None` if SQLite returns a
    /// null pointer (for example, when `index` is out of range).
    ///
    /// NOTE(review): SQLite keeps this string valid only until the statement
    /// is finalized, is automatically re-prepared by `sqlite3_step`, or has the
    /// same column's name requested again. The returned borrow is tied to
    /// `&self`, which does not stop a caller from invoking [`Statement::row`]
    /// (also `&self`) while still holding the name.
    #[doc(alias = "sqlite3_column_name")]
    pub fn column_name(&self, index: ColumnIndex) -> Option<&CStr> {
        // SAFETY: `as_ptr` returns the non-null handle this wrapper holds.
        let ptr = unsafe { sqlite3_column_name(self.as_ptr(), index.value()) };

        if ptr.is_null() {
            None
        } else {
            // SAFETY: SQLite returns a NUL-terminated UTF-8 string; see the
            // lifetime note above.
            Some(unsafe { CStr::from_ptr(ptr) })
        }
    }

    /// Return the number of columns in this statement's result set. The count
    /// is `0` for statements that produce no rows.
    #[doc(alias = "sqlite3_column_count")]
    pub fn column_count(&self) -> c_int {
        // SAFETY: `as_ptr` returns the non-null handle this wrapper holds.
        unsafe { sqlite3_column_count(self.as_ptr()) }
    }

    /// Return the highest (1-based) parameter index used by this [`Statement`].
    #[doc(alias = "sqlite3_bind_parameter_count")]
    pub fn parameter_count(&self) -> c_int {
        // SAFETY: `as_ptr` returns the non-null handle this wrapper holds.
        unsafe { sqlite3_bind_parameter_count(self.as_ptr()) }
    }

    /// Return the name of the (1-based) parameter `index` as written in the
    /// SQL, including its prefix character. Returns `None` for nameless `?`
    /// parameters or an out-of-range `index`.
    ///
    /// NOTE(review): the returned string belongs to the compiled statement.
    /// Confirm that it survives an automatic re-prepare triggered by
    /// [`Statement::row`], which only needs `&self`.
    #[doc(alias = "sqlite3_bind_parameter_name")]
    pub fn parameter_name(&self, index: BindIndex) -> Option<&CStr> {
        // SAFETY: `as_ptr` returns the non-null handle this wrapper holds.
        let ptr = unsafe { sqlite3_bind_parameter_name(self.as_ptr(), index.value()) };

        if ptr.is_null() {
            None
        } else {
            // SAFETY: SQLite returns a NUL-terminated UTF-8 string.
            Some(unsafe { CStr::from_ptr(ptr) })
        }
    }

    /// Bind the parameter specified by `index` to the given `value`.
    ///
    /// # Safety
    ///
    /// Implementations access the `sqlite3_bind_*` API’s directly. If these
    /// API’s are used to bind a pointer non-`SQLITE_TRANSIENT`ly, the caller is
    /// responsible for ensuring the pointer remains valid for the duration of
    /// the binding; and if a [destructor](sqlite::sqlite3_destructor_type) is
    /// used, for SQLite to call it at the end of the binding lifecycle.
    pub unsafe fn bind<'b, B>(&self, index: BindIndex, value: B) -> Result<()>
    where
        B: Bind<'b>,
        'c: 'b,
    {
        unsafe { value.bind(self, index) }
    }

    /// Reset every bound parameter of this statement to NULL.
    ///
    /// This does not rewind the statement; see [`Statement::reset`] for that.
    #[doc(alias = "sqlite3_clear_bindings")]
    pub fn clear(&mut self) -> Result<()> {
        call! { sqlite3_clear_bindings(self.as_ptr()) }
    }

    /// [Step][step] the [statement](Statement) and read the next row.
    ///
    /// Returns:
    /// - `Ok(true)` if [`sqlite3_step`][step] returns `SQLITE_ROW`
    /// - `Ok(false)` if [`sqlite3_step`][step] returns `SQLITE_DONE`
    /// - an [`Error`] if [`sqlite3_step`][step] returns any other result code.
    ///   If no specific error can be derived from the connection,
    ///   `Error::default()` is returned, so a failure is never reported as the
    ///   end of the results.
    ///
    /// [step]: https://sqlite.org/c3ref/step.html
    ///
    /// # Safety
    ///
    /// Callers are responsible for managing the `ffi::Statement` lifecycle.
    /// Stepping may invalidate values previously obtained via
    /// [`fetch`](Self::fetch).
    #[doc(alias = "sqlite3_step")]
    pub unsafe fn row(&self) -> Result<bool> {
        // SAFETY: the caller upholds the lifecycle contract above.
        let result = unsafe { sqlite3_step(self.as_ptr()) };

        if result == SQLITE_ROW {
            Ok(true)
        } else if result == SQLITE_DONE {
            Ok(false)
        } else {
            // Any other code is a failure. Reporting `Ok(false)` here would be
            // indistinguishable from a normal end of results, so fall back to a
            // default error, as `execute` does.
            Err(Error::from_connection(self, result).unwrap_or_default())
        }
    }

    /// [Execute][step] the [statement](Statement), returning `()`, the
    /// last-inserted [`RowId`](crate::RowId), or the
    /// [number of changes](primitive@isize).
    ///
    /// Returns:
    /// - the [`Conclusion`] if [`sqlite3_step`][step] returns `SQLITE_DONE`
    /// - a [misuse error](crate::ErrorCategory::Misuse) if [`sqlite3_step`][step] returns `SQLITE_ROW`
    /// - an [`Error`] if [`sqlite3_step`][step] returns an error result code
    ///
    /// On `SQLITE_ROW` the statement is left positioned on that row; this
    /// method does not reset it.
    ///
    /// [step]: https://sqlite.org/c3ref/step.html
    ///
    /// # Safety
    ///
    /// Callers are responsible for managing the `ffi::Statement` lifecycle.
    #[doc(alias = "sqlite3_step")]
    pub unsafe fn execute<C: Conclusion>(&self) -> Result<C> {
        // SAFETY: the caller upholds the lifecycle contract above.
        let result = unsafe { sqlite3_step(self.as_ptr()) };

        if result == SQLITE_DONE {
            // SAFETY: the statement handle is live, so its connection is too.
            let connection_ptr = unsafe { self.connection_ptr() };
            Ok(unsafe { C::from_connection_ptr(connection_ptr) })
        } else if result == SQLITE_ROW {
            Err(ErrorCategory::Misuse.into())
        } else {
            Err(Error::from_connection(self, result).unwrap_or_default())
        }
    }

    /// [Reset][reset] the [statement](Statement) so it can be stepped again.
    ///
    /// Bound parameter values are kept; use [`Statement::clear`] to clear them.
    /// SQLite reports the error from the most recent step here if that step
    /// failed.
    ///
    /// [reset]: https://sqlite.org/c3ref/reset.html
    ///
    /// # Safety
    ///
    /// Callers are responsible for managing the `ffi::Statement` lifecycle.
    #[doc(alias = "sqlite3_reset")]
    pub unsafe fn reset(&mut self) -> Result<()> {
        call! { sqlite3_reset(self.as_ptr()) }
    }

    /// [Fetch][fetch] a column value from the [statement](Statement).
    ///
    /// [fetch]: https://sqlite.org/c3ref/column_blob.html
    ///
    /// # Safety
    ///
    /// Callers are responsible for managing the `ffi::Statement` lifecycle, and
    /// ensuring the [`ColumnIndex`] is in bounds. See [`fetch`](Fetch::fetch)
    /// for details.
    pub unsafe fn fetch<'r, T: Fetch<'r>>(&'r self, column: ColumnIndex) -> T {
        unsafe { T::fetch(self, column) }
    }

    /// Return the number of columns in the current result row. The count is
    /// `0` when no row is available, e.g. before the first step or after
    /// `SQLITE_DONE`.
    ///
    /// This is a read-only query, so it borrows `self` immutably, like
    /// [`Statement::column_count`].
    #[doc(alias = "sqlite3_data_count")]
    pub fn data_count(&self) -> c_int {
        // SAFETY: `as_ptr` returns the non-null handle this wrapper holds.
        unsafe { sqlite3_data_count(self.as_ptr()) }
    }

    /// Access the raw [`sqlite3_stmt`] pointer.
    #[inline]
    pub const fn as_ptr(&self) -> *mut sqlite3_stmt {
        self.handle.as_ptr()
    }

    /// Return the connection handle this statement belongs to
    /// (`sqlite3_db_handle`).
    ///
    /// # Safety
    ///
    /// The wrapped statement handle must still be live (not finalized).
    #[inline]
    pub(crate) unsafe fn connection_ptr(&self) -> *mut sqlite3 {
        unsafe { sqlite3_db_handle(self.as_ptr()) }
    }
}

impl<'c> Connected for Statement<'c> {
    fn as_connection_ptr(&self) -> *mut sqlite3 {
        unsafe { self.connection_ptr() }
    }
}

impl<'c> Connected for &Statement<'c> {
    fn as_connection_ptr(&self) -> *mut sqlite3 {
        unsafe { self.connection_ptr() }
    }
}

impl<'c> Connected for &mut Statement<'c> {
    fn as_connection_ptr(&self) -> *mut sqlite3 {
        unsafe { self.connection_ptr() }
    }
}

/// A handle that can drive a [`Statement`] through binding and execution.
///
/// This module implements it for an owned [`Statement`] and for
/// `&mut Statement`.
pub trait Execute<'c>: Connected {
    /// Access the raw [`sqlite3_stmt`] pointer.
    ///
    /// # Safety
    ///
    /// Callers are responsible for managing the `ffi::Statement` lifecycle.
    unsafe fn as_statement_ptr(&self) -> *mut sqlite3_stmt;

    /// Access the [`Statement`] being executed, borrowed mutably for `'e`.
    ///
    /// # Safety
    ///
    /// Callers are responsible for managing the `ffi::Statement` lifecycle.
    unsafe fn cursor<'e>(&'e mut self) -> &'e mut Statement<'c>
    where
        'c: 'e,
        Self: 'e;

    /// Reset the [`Statement`], preparing it for new binding and execution.
    ///
    /// Implementations may treat this as a no-op; the owned-`Statement` impl
    /// in this module returns `Ok(())` without calling SQLite.
    ///
    /// # Safety
    ///
    /// Callers are responsible for managing the `ffi::Statement` lifecycle.
    unsafe fn reset(&mut self) -> Result<()>;
}

impl<'c> Execute<'c> for Statement<'c> {
    unsafe fn as_statement_ptr(&self) -> *mut sqlite3_stmt {
        self.as_ptr()
    }

    unsafe fn cursor<'e>(&'e mut self) -> &'e mut Statement<'c>
    where
        'c: 'e,
        Self: 'e,
    {
        self
    }

    #[inline(always)]
    unsafe fn reset(&mut self) -> Result<()> {
        Ok(())
    }
}

impl<'c, 's> Execute<'c> for &'s mut Statement<'c>
where
    'c: 's,
{
    unsafe fn as_statement_ptr(&self) -> *mut sqlite3_stmt {
        self.as_ptr()
    }

    unsafe fn cursor<'e>(&'e mut self) -> &'e mut Statement<'c>
    where
        'c: 'e,
        Self: 'e,
    {
        self
    }

    #[inline]
    unsafe fn reset(&mut self) -> Result<()> {
        call! { sqlite3_reset(self.as_statement_ptr()) }
    }
}

impl fmt::Debug for Statement<'_> {
    /// Render as `Statement(<address>)`, showing only the raw handle address.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let handle = self.as_ptr();
        write!(f, "Statement({handle:p})")
    }
}

/// A value which represents the _conclusion_ of [`Statement`] execution.
///
/// [`Statement::execute`] builds one from the statement's connection once
/// `sqlite3_step` returns `SQLITE_DONE`. This module implements it for `()`
/// (which carries no information) and for `isize` (the connection's change
/// count).
pub trait Conclusion: Sized {
    /// Populate the [`Conclusion`] from a [`sqlite3`] connection pointer.
    ///
    /// # Safety
    ///
    /// Callers are responsible for managing the `ffi::Statement` lifecycle.
    unsafe fn from_connection_ptr(connection: *mut sqlite3) -> Self;
}

impl Conclusion for () {
    /// Executing for `()` ignores the connection entirely.
    #[inline(always)]
    unsafe fn from_connection_ptr(_: *mut sqlite3) -> Self {}
}

impl Conclusion for isize {
    /// Return the number of rows changed by the most recently completed
    /// statement on `connection`.
    #[inline(always)]
    unsafe fn from_connection_ptr(connection: *mut sqlite3) -> Self {
        // Choose the SQLite API whose return width matches `isize`, so the
        // final cast never truncates.
        #[cfg(target_pointer_width = "64")]
        let count = unsafe { sqlite3_changes64(connection) };
        #[cfg(target_pointer_width = "32")]
        let count = unsafe { sqlite3_changes(connection) };

        count as isize
    }
}