use std::ptr::NonNull;
use std::sync::Arc;
use quex_mariadb_sys as ffi;
use super::connection::Connection;
use super::error::{Error, ExecuteResult, Result};
use super::rows::{Metadata, ResultSet};
use super::runtime::{
DriveOperation, DriveOutput, ParamBindings, StmtHandle, enable_stmt_max_length,
ensure_mysql_thread_ready, ensure_mysql_thread_ready_for_drop,
};
use super::value::{ParamRefSlice, ParamSource, Value, ValueRef, values_as_refs};
/// A prepared statement that mutably borrows the [`Connection`] it runs on.
///
/// The native statement handle is owned by this value and is closed in its `Drop` impl.
pub struct Statement<'a> {
/// Connection used to drive the execute / store-result operations of this statement.
pub(crate) conn: &'a mut Connection,
/// Native MariaDB statement handle; closed by `Drop for Statement`.
pub(crate) stmt: StmtHandle,
/// Result-column metadata, populated on the first execution that returns rows and
/// reused by later executions of the same statement.
pub(crate) result_metadata: Option<Arc<Metadata>>,
}
impl<'a> Statement<'a> {
pub async fn execute(&mut self, params: &[Value]) -> Result<ResultSet> {
let refs = values_as_refs(params);
self.execute_ref(&refs).await
}
pub async fn execute_ref(&mut self, params: &[ValueRef<'_>]) -> Result<ResultSet> {
self.execute_source(&ParamRefSlice(params)).await
}
pub async fn execute_source<P>(&mut self, params: &P) -> Result<ResultSet>
where
P: ParamSource + ?Sized,
{
ensure_mysql_thread_ready()?;
unsafe {
let stmt = self.stmt;
let expected = ffi::mysql_stmt_param_count(stmt.as_ptr()) as usize;
if expected != params.len() {
return Err(Error::new(format!(
"statement expects {} parameters but got {}",
expected,
params.len()
)));
}
let mut bindings = ParamBindings::new(params);
if expected != 0
&& ffi::mysql_stmt_bind_param(stmt.as_ptr(), bindings.binds.as_mut_ptr()) != 0
{
return Err(Error::from_stmt(
stmt.as_ptr(),
"mysql_stmt_bind_param failed",
));
}
let mut output = DriveOutput::stmt_execute();
self.conn
.drive(DriveOperation::StmtExecute { stmt }, &mut output)
.await?;
if output.stmt_execute_code() != 0 {
return Err(Error::from_stmt(stmt.as_ptr(), "mysql_stmt_execute failed"));
}
if ffi::mysql_stmt_field_count(stmt.as_ptr()) == 0 {
return Ok(ResultSet::empty());
}
enable_stmt_max_length(stmt.as_ptr())?;
let mut output = DriveOutput::stmt_store_result();
self.conn
.drive(DriveOperation::StmtStoreResult { stmt }, &mut output)
.await?;
if output.stmt_store_result_code() != 0 {
return Err(Error::from_stmt(
stmt.as_ptr(),
"mysql_stmt_store_result failed",
));
}
let metadata = match &self.result_metadata {
Some(metadata) => Arc::clone(metadata),
None => {
let meta = NonNull::new(ffi::mysql_stmt_result_metadata(stmt.as_ptr()))
.ok_or_else(|| {
Error::from_stmt(
stmt.as_ptr(),
"statement returned rows but no result metadata",
)
})?;
let metadata = Arc::new(Metadata::from_result(meta));
ffi::mysql_free_result(meta.as_ptr());
self.result_metadata = Some(Arc::clone(&metadata));
metadata
}
};
ResultSet::statement(self.stmt, metadata)
}
}
pub async fn exec(&mut self, params: &[Value]) -> Result<ExecuteResult> {
let refs = values_as_refs(params);
self.exec_ref(&refs).await
}
pub async fn exec_ref(&mut self, params: &[ValueRef<'_>]) -> Result<ExecuteResult> {
self.exec_source(&ParamRefSlice(params)).await
}
pub async fn exec_source<P>(&mut self, params: &P) -> Result<ExecuteResult>
where
P: ParamSource + ?Sized,
{
ensure_mysql_thread_ready()?;
unsafe {
let stmt = self.stmt;
let expected = ffi::mysql_stmt_param_count(stmt.as_ptr()) as usize;
if expected != params.len() {
return Err(Error::new(format!(
"statement expects {} parameters but got {}",
expected,
params.len()
)));
}
let mut bindings = ParamBindings::new(params);
if expected != 0
&& ffi::mysql_stmt_bind_param(stmt.as_ptr(), bindings.binds.as_mut_ptr()) != 0
{
return Err(Error::from_stmt(
stmt.as_ptr(),
"mysql_stmt_bind_param failed",
));
}
let mut output = DriveOutput::stmt_execute();
self.conn
.drive(DriveOperation::StmtExecute { stmt }, &mut output)
.await?;
if output.stmt_execute_code() != 0 {
return Err(Error::from_stmt(stmt.as_ptr(), "mysql_stmt_execute failed"));
}
if ffi::mysql_stmt_field_count(stmt.as_ptr()) != 0 {
return Err(Error::new("statement returned rows; use execute instead"));
}
Ok(ExecuteResult {
rows_affected: ffi::mysql_stmt_affected_rows(stmt.as_ptr()) as u64,
last_insert_id: ffi::mysql_stmt_insert_id(stmt.as_ptr()) as u64,
})
}
}
}
/// A prepared statement that lives in the connection's statement cache.
///
/// This value only holds the cache key. The native handle, parameter scratch
/// buffers, parameter count and cached result metadata are read from the cache
/// entry on every call.
pub struct CachedStatement<'a> {
/// Connection whose `statement_cache` holds the prepared statement.
pub(crate) conn: &'a mut Connection,
/// Key of this statement's entry in `conn.statement_cache`.
pub(crate) key: Box<str>,
}
impl CachedStatement<'_> {
pub async fn execute(&mut self, params: &[Value]) -> Result<ResultSet> {
let refs = values_as_refs(params);
self.execute_ref(&refs).await
}
pub async fn execute_ref(&mut self, params: &[ValueRef<'_>]) -> Result<ResultSet> {
self.execute_source(&ParamRefSlice(params)).await
}
pub async fn execute_source<P>(&mut self, params: &P) -> Result<ResultSet>
where
P: ParamSource + ?Sized,
{
ensure_mysql_thread_ready()?;
let key = self.key.clone();
let (stmt, param_count) = {
let entry = self
.conn
.statement_cache
.get(key.as_ref())
.ok_or_else(|| Error::new("cached statement missing"))?;
(entry.stmt, entry.param_count)
};
if param_count != params.len() {
return Err(Error::new(format!(
"statement expects {} parameters but got {}",
param_count,
params.len()
)));
}
{
let entry = self
.conn
.statement_cache
.get_mut(key.as_ref())
.ok_or_else(|| Error::new("cached statement missing"))?;
unsafe {
if ffi::mysql_stmt_reset(stmt.as_ptr()) != 0 {
return Err(Error::from_stmt(stmt.as_ptr(), "mysql_stmt_reset failed"));
}
}
entry.scratch.bind_source(params)?;
unsafe {
if param_count != 0
&& ffi::mysql_stmt_bind_param(stmt.as_ptr(), entry.scratch.binds.as_mut_ptr())
!= 0
{
return Err(Error::from_stmt(
stmt.as_ptr(),
"mysql_stmt_bind_param failed",
));
}
}
}
let mut output = DriveOutput::stmt_execute();
self.conn
.drive(DriveOperation::StmtExecute { stmt }, &mut output)
.await?;
if output.stmt_execute_code() != 0 {
return Err(unsafe { Error::from_stmt(stmt.as_ptr(), "mysql_stmt_execute failed") });
}
if unsafe { ffi::mysql_stmt_field_count(stmt.as_ptr()) } == 0 {
return Ok(ResultSet::empty());
}
enable_stmt_max_length(stmt.as_ptr())?;
let mut output = DriveOutput::stmt_store_result();
self.conn
.drive(DriveOperation::StmtStoreResult { stmt }, &mut output)
.await?;
if output.stmt_store_result_code() != 0 {
return Err(unsafe {
Error::from_stmt(stmt.as_ptr(), "mysql_stmt_store_result failed")
});
}
let metadata = match self
.conn
.statement_cache
.get(key.as_ref())
.and_then(|entry| entry.result_metadata.clone())
{
Some(metadata) => metadata,
None => {
let meta = NonNull::new(unsafe { ffi::mysql_stmt_result_metadata(stmt.as_ptr()) })
.ok_or_else(|| unsafe {
Error::from_stmt(
stmt.as_ptr(),
"statement returned rows but no result metadata",
)
})?;
let metadata = Arc::new(Metadata::from_result(meta));
unsafe {
ffi::mysql_free_result(meta.as_ptr());
}
self.conn
.statement_cache
.get_mut(key.as_ref())
.expect("cached statement missing")
.result_metadata = Some(Arc::clone(&metadata));
metadata
}
};
ResultSet::statement(stmt, metadata)
}
pub async fn exec(&mut self, params: &[Value]) -> Result<ExecuteResult> {
let refs = values_as_refs(params);
self.exec_ref(&refs).await
}
pub async fn exec_ref(&mut self, params: &[ValueRef<'_>]) -> Result<ExecuteResult> {
self.exec_source(&ParamRefSlice(params)).await
}
pub async fn exec_source<P>(&mut self, params: &P) -> Result<ExecuteResult>
where
P: ParamSource + ?Sized,
{
ensure_mysql_thread_ready()?;
let key = self.key.clone();
let (stmt, param_count) = {
let entry = self
.conn
.statement_cache
.get(key.as_ref())
.ok_or_else(|| Error::new("cached statement missing"))?;
(entry.stmt, entry.param_count)
};
if param_count != params.len() {
return Err(Error::new(format!(
"statement expects {} parameters but got {}",
param_count,
params.len()
)));
}
{
let entry = self
.conn
.statement_cache
.get_mut(key.as_ref())
.ok_or_else(|| Error::new("cached statement missing"))?;
unsafe {
if ffi::mysql_stmt_reset(stmt.as_ptr()) != 0 {
return Err(Error::from_stmt(stmt.as_ptr(), "mysql_stmt_reset failed"));
}
}
entry.scratch.bind_source(params)?;
unsafe {
if param_count != 0
&& ffi::mysql_stmt_bind_param(stmt.as_ptr(), entry.scratch.binds.as_mut_ptr())
!= 0
{
return Err(Error::from_stmt(
stmt.as_ptr(),
"mysql_stmt_bind_param failed",
));
}
}
}
let mut output = DriveOutput::stmt_execute();
self.conn
.drive(DriveOperation::StmtExecute { stmt }, &mut output)
.await?;
if output.stmt_execute_code() != 0 {
return Err(unsafe { Error::from_stmt(stmt.as_ptr(), "mysql_stmt_execute failed") });
}
if unsafe { ffi::mysql_stmt_field_count(stmt.as_ptr()) } != 0 {
return Err(Error::new("statement returned rows; use execute instead"));
}
let rows_affected = unsafe { ffi::mysql_stmt_affected_rows(stmt.as_ptr()) as u64 };
let last_insert_id = unsafe { ffi::mysql_stmt_insert_id(stmt.as_ptr()) as u64 };
Ok(ExecuteResult {
rows_affected,
last_insert_id,
})
}
}
impl Drop for Statement<'_> {
    /// Closes the native statement handle.
    ///
    /// If `ensure_mysql_thread_ready_for_drop` reports that the current thread is not
    /// ready, the handle is skipped and not closed.
    fn drop(&mut self) {
        if ensure_mysql_thread_ready_for_drop() {
            // SAFETY: `self.stmt` is owned by this `Statement`; this is the only place
            // in this module that closes it, and `drop` runs at most once.
            unsafe {
                ffi::mysql_stmt_close(self.stmt.as_ptr());
            }
        }
    }
}
#[cfg(test)]
mod tests {
use std::env;
use std::thread;
use super::super::options::ConnectOptions;
use super::super::value::Value;
use super::*;
/// Builds `ConnectOptions` for the integration test.
///
/// Starts from `ConnectOptions::new()` and overrides each setting whose
/// `QUEX_TEST_MYSQL_*` environment variable is set.
///
/// # Panics
///
/// Panics if `QUEX_TEST_MYSQL_PORT` is set but cannot be parsed.
fn mysql_test_options() -> ConnectOptions {
let mut options = ConnectOptions::new();
if let Ok(host) = env::var("QUEX_TEST_MYSQL_HOST") {
options = options.host(host);
}
if let Ok(port) = env::var("QUEX_TEST_MYSQL_PORT") {
options = options.port(port.parse().expect("valid mysql test port"));
}
if let Ok(user) = env::var("QUEX_TEST_MYSQL_USER") {
options = options.user(user);
}
if let Ok(password) = env::var("QUEX_TEST_MYSQL_PASSWORD") {
options = options.password(password);
}
if let Ok(database) = env::var("QUEX_TEST_MYSQL_DATABASE") {
options = options.database(database);
}
if let Ok(unix_socket) = env::var("QUEX_TEST_MYSQL_SOCKET") {
options = options.unix_socket(unix_socket);
}
options
}
/// The test has three steps:
/// 1. prepare a statement on the test's own thread;
/// 2. execute it inside one spawned task;
/// 3. drop the resulting rows inside another spawned task.
///
/// It asserts that each step ran on a different thread from the previous one.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore = "requires a local MariaDB instance and QUEX_TEST_MYSQL_* env if defaults are unsuitable"]
async fn mysql_statement_moves_across_tokio_tasks_before_execute_and_drop() {
let conn = Box::new(
Connection::connect(mysql_test_options())
.await
.expect("connect"),
);
// Turn the connection into a raw pointer so the `Statement` that borrows it can be
// moved into `tokio::spawn`, which requires a `'static` future. The box is reclaimed
// at the end of the test. If any `expect` below panics first, it is leaked (test-only).
let conn_ptr = Box::into_raw(conn);
let create_thread = thread::current().id();
let stmt = unsafe {
(&mut *conn_ptr)
.prepare("select ? as id, ? as name")
.await
.expect("prepare")
};
let (execute_thread, rows) = tokio::spawn(async move {
// Yield once so the scheduler has a chance to run the rest of the task elsewhere.
tokio::task::yield_now().await;
let execute_thread = thread::current().id();
let mut stmt = stmt;
let rows = stmt
.execute(&[Value::I64(11), Value::String("Ada".into())])
.await
.expect("execute");
// NOTE(review): `stmt` is dropped (and its handle closed) when this task ends, while
// `rows`, built from the same handle by `ResultSet::statement`, is dropped later.
// Confirm `ResultSet` does not touch the statement handle after that point.
(execute_thread, rows)
})
.await
.expect("join statement task");
// NOTE(review): tokio's multi-thread scheduler does not guarantee that a spawned task
// runs on a different worker than the spawner, so this and the assertion below may be
// scheduler-dependent. Confirm they are not flaky.
assert_ne!(
execute_thread, create_thread,
"statement execute stayed on the creator thread"
);
let drop_thread = tokio::spawn(async move {
tokio::task::yield_now().await;
let drop_thread = thread::current().id();
drop(rows);
drop_thread
})
.await
.expect("join rows drop task");
assert_ne!(
drop_thread, execute_thread,
"rows drop stayed on the execute thread"
);
// Reclaim and drop the connection leaked via `Box::into_raw` above.
unsafe {
drop(Box::from_raw(conn_ptr));
}
}
}