#![allow(clippy::rc_buffer)]
use crate::connection::ConnectionHandle;
use crate::statement::StatementHandle;
use crate::{SqliteColumn, SqliteError};
use bytes::{Buf, Bytes};
use libsqlite3_sys::{
sqlite3, sqlite3_prepare_v3, sqlite3_stmt, SQLITE_OK, SQLITE_PREPARE_PERSISTENT,
};
use rbdc::err_protocol;
use rbdc::error::Error;
use rbdc::ext::ustr::UStr;
use smallvec::SmallVec;
use std::collections::HashMap;
use std::os::raw::c_char;
use std::ptr::{null, null_mut, NonNull};
use std::sync::Arc;
use std::{cmp, i32};
/// A SQL query string that may contain several statements, prepared lazily,
/// one statement at a time, by [`VirtualStatement::prepare_next`].
///
/// `handles`, `columns` and `column_names` are parallel: entry `i` of each
/// belongs to the `i`-th statement prepared from the query so far.
#[derive(Debug)]
pub struct VirtualStatement {
/// Whether statements are prepared with `SQLITE_PREPARE_PERSISTENT`.
persistent: bool,
/// Position of the current statement in `handles`. `None` before the first
/// `prepare_next` call and after `reset`; may equal `handles.len()` once every
/// statement has been consumed.
index: Option<usize>,
/// SQL text that has not been prepared yet.
tail: Bytes,
/// Prepared statement handles, in the order their SQL appears in the query.
pub(crate) handles: SmallVec<[StatementHandle; 1]>,
/// Column metadata of each prepared statement.
pub(crate) columns: SmallVec<[Arc<Vec<SqliteColumn>>; 1]>,
/// Column-name-to-ordinal lookup of each prepared statement.
pub(crate) column_names: SmallVec<[Arc<HashMap<UStr, usize>>; 1]>,
}
/// A borrowed view of one prepared statement owned by a [`VirtualStatement`],
/// bundled with that statement's column metadata.
pub struct PreparedStatement<'a> {
    /// Column descriptions for this statement.
    pub(crate) columns: &'a Arc<Vec<SqliteColumn>>,
    /// Lookup from column name to column ordinal for this statement.
    pub(crate) column_names: &'a Arc<HashMap<UStr, usize>>,
    /// Exclusive access to the underlying statement handle.
    pub(crate) handle: &'a mut StatementHandle,
}
impl VirtualStatement {
pub(crate) fn new(mut query: &str, persistent: bool) -> Result<Self, Error> {
query = query.trim();
if query.len() > i32::max_value() as usize {
return Err(err_protocol!(
"query string must be smaller than {} bytes",
i32::MAX
));
}
Ok(Self {
persistent,
tail: Bytes::from(String::from(query)),
handles: SmallVec::with_capacity(1),
index: None,
columns: SmallVec::with_capacity(1),
column_names: SmallVec::with_capacity(1),
})
}
pub(crate) fn prepare_next(
&mut self,
conn: &mut ConnectionHandle,
) -> Result<Option<PreparedStatement<'_>>, Error> {
self.index = self
.index
.map(|idx| cmp::min(idx + 1, self.handles.len()))
.or(Some(0));
while self.handles.len() <= self.index.unwrap_or(0) {
if self.tail.is_empty() {
return Ok(None);
}
if let Some(statement) = prepare(conn.as_ptr(), &mut self.tail, self.persistent)? {
let num = statement.column_count();
let mut columns = Vec::with_capacity(num);
let mut column_names = HashMap::with_capacity(num);
for i in 0..num {
let name: UStr = statement.column_name(i).to_owned().into();
let type_info = statement
.column_decltype(i)
.unwrap_or_else(|| statement.column_type_info(i));
columns.push(SqliteColumn {
ordinal: i,
name: name.clone(),
type_info,
});
column_names.insert(name, i);
}
self.handles.push(statement);
self.columns.push(Arc::new(columns));
self.column_names.push(Arc::new(column_names));
}
}
Ok(self.current())
}
pub fn current(&mut self) -> Option<PreparedStatement<'_>> {
self.index
.filter(|&idx| idx < self.handles.len())
.map(move |idx| PreparedStatement {
handle: &mut self.handles[idx],
columns: &self.columns[idx],
column_names: &self.column_names[idx],
})
}
pub fn reset(&mut self) -> Result<(), Error> {
self.index = None;
for handle in self.handles.iter_mut() {
handle.reset()?;
handle.clear_bindings();
}
Ok(())
}
}
fn prepare(
conn: *mut sqlite3,
query: &mut Bytes,
persistent: bool,
) -> Result<Option<StatementHandle>, Error> {
let mut flags = 0;
if persistent {
flags |= SQLITE_PREPARE_PERSISTENT;
}
while !query.is_empty() {
let mut statement_handle: *mut sqlite3_stmt = null_mut();
let mut tail: *const c_char = null();
let query_ptr = query.as_ptr() as *const c_char;
let query_len = query.len() as i32;
let status = unsafe {
sqlite3_prepare_v3(
conn,
query_ptr,
query_len,
flags as u32,
&mut statement_handle,
&mut tail,
)
};
if status != SQLITE_OK {
return Err(SqliteError::new(conn).into());
}
let n = (tail as usize) - (query_ptr as usize);
query.advance(n);
if let Some(handle) = NonNull::new(statement_handle) {
return Ok(Some(StatementHandle::new(handle)));
}
}
Ok(None)
}