use std::collections::HashMap;
use std::ffi::CString;
use std::ptr::{self, NonNull};
use std::sync::Arc;
use std::time::Duration;
use quex_mariadb_sys as ffi;
use tokio::io::unix::AsyncFd;
use super::error::{Error, ExecuteResult, Result};
use super::options::ConnectOptions;
use super::rows::{Metadata, ResultSet};
use super::runtime::{
DriveOperation, DriveOutput, MysqlHandle, ParamScratch, SocketRef, StmtHandle, WAIT_EXCEPT,
WAIT_READ, WAIT_TIMEOUT, WAIT_WRITE, continue_operation, ensure_mysql_thread_ready,
ensure_mysql_thread_ready_for_drop, start_operation, to_cstring_ptr,
};
use super::statement::{CachedStatement, Statement};
/// A prepared statement stored in `Connection::statement_cache`, together with
/// the per-statement state kept next to it.
pub(crate) struct CachedStmtEntry {
/// Prepared statement handle; closed with `mysql_stmt_close` when the owning
/// connection is dropped.
pub(crate) stmt: StmtHandle,
/// Parameter count reported by `mysql_stmt_param_count` when the entry was created.
pub(crate) param_count: usize,
/// Result metadata for the statement; `None` when the entry is first inserted.
pub(crate) result_metadata: Option<Arc<Metadata>>,
/// Parameter scratch space, created with `ParamScratch::new(param_count)`.
pub(crate) scratch: ParamScratch,
}
/// A single connection driven through libmariadb's nonblocking API
/// (`MYSQL_OPT_NONBLOCK`), with socket readiness awaited through Tokio.
pub struct Connection {
/// Handle returned by `mysql_init`; closed with `mysql_close` on drop.
pub(crate) mysql: MysqlHandle,
/// Readiness watcher for the descriptor reported by `mysql_get_socket`.
/// `None` until the first wait, and reset when no valid descriptor is exposed.
socket: Option<AsyncFd<SocketRef>>,
/// Result metadata of text queries, keyed by the exact SQL text.
query_metadata_cache: HashMap<Box<str>, Arc<Metadata>>,
/// Prepared statements created by `prepare_cached`, keyed by the exact SQL text.
pub(crate) statement_cache: HashMap<Box<str>, CachedStmtEntry>,
}
impl Connection {
pub async fn connect(options: ConnectOptions) -> Result<Self> {
ensure_mysql_thread_ready()?;
unsafe {
let mysql = ffi::mysql_init(ptr::null_mut());
let mysql = MysqlHandle(
NonNull::new(mysql).ok_or_else(|| Error::new("mysql_init returned null"))?,
);
if ffi::mysql_optionsv(mysql.as_ptr(), ffi::mysql_option_MYSQL_OPT_NONBLOCK, 0usize)
!= 0
{
let error = Error::from_mysql(mysql.as_ptr(), "failed to enable nonblocking mode");
ffi::mysql_close(mysql.as_ptr());
return Err(error);
}
let host = to_cstring_ptr(options.host.as_deref())?;
let user = to_cstring_ptr(options.user.as_deref())?;
let password = to_cstring_ptr(options.password.as_deref())?;
let database = to_cstring_ptr(options.database.as_deref())?;
let unix_socket = to_cstring_ptr(options.unix_socket.as_deref())?;
let mut conn = Self {
mysql,
socket: None,
query_metadata_cache: HashMap::new(),
statement_cache: HashMap::new(),
};
let mut out = DriveOutput::connect(mysql);
conn.drive(
DriveOperation::Connect {
mysql,
host: &host,
user: &user,
password: &password,
database: &database,
port: options.port,
unix_socket: &unix_socket,
},
&mut out,
)
.await?;
Ok(conn)
}
}
pub async fn query(&mut self, sql_text: &str) -> Result<ResultSet> {
ensure_mysql_thread_ready()?;
unsafe {
let sql = CString::new(sql_text)?;
let mysql = self.mysql;
let mut output = DriveOutput::query();
self.drive(DriveOperation::Query { mysql, sql: &sql }, &mut output)
.await?;
if output.query_code() != 0 {
return Err(Error::from_mysql(mysql.as_ptr(), "query failed"));
}
if ffi::mysql_field_count(mysql.as_ptr()) == 0 {
return Ok(ResultSet::empty());
}
let mut output = DriveOutput::store_result();
self.drive(DriveOperation::StoreResult { mysql }, &mut output)
.await?;
let raw_result = NonNull::new(output.store_result_ptr()).ok_or_else(|| {
Error::from_mysql(mysql.as_ptr(), "query did not produce a buffered result")
})?;
let metadata = match self.query_metadata_cache.get(sql_text) {
Some(metadata) => Arc::clone(metadata),
None => {
let metadata = Arc::new(Metadata::from_result(raw_result));
self.query_metadata_cache
.insert(sql_text.into(), Arc::clone(&metadata));
metadata
}
};
Ok(ResultSet::text(raw_result, metadata))
}
}
/// Runs `sql_text` as a statement that is not expected to return rows.
///
/// Returns the affected-row count and last insert id reported by the client
/// library.
///
/// # Errors
///
/// Fails when `sql_text` contains an interior NUL byte, when waiting on the
/// socket fails, when the server rejects the statement, or when the statement
/// produced a result set (use `query` for those).
pub async fn execute(&mut self, sql_text: &str) -> Result<ExecuteResult> {
    ensure_mysql_thread_ready()?;
    // SAFETY: `self.mysql` is a live handle owned by this connection, and the
    // result pointer is freed at most once, only when non-null.
    unsafe {
        let sql = CString::new(sql_text)?;
        let mysql = self.mysql;
        let mut output = DriveOutput::query();
        self.drive(DriveOperation::Query { mysql, sql: &sql }, &mut output)
            .await?;
        if output.query_code() != 0 {
            return Err(Error::from_mysql(mysql.as_ptr(), "query failed"));
        }
        if ffi::mysql_field_count(mysql.as_ptr()) != 0 {
            // A result set is pending on the wire. It has to be consumed and
            // released, otherwise the next command on this connection is
            // rejected as out of sync. Draining is best effort: the caller
            // gets the same error either way.
            let mut pending = DriveOutput::store_result();
            let drained = self
                .drive(DriveOperation::StoreResult { mysql }, &mut pending)
                .await;
            if drained.is_ok() {
                let rows = pending.store_result_ptr();
                if !rows.is_null() {
                    ffi::mysql_free_result(rows);
                }
            }
            return Err(Error::new("statement returned rows; use query instead"));
        }
        Ok(ExecuteResult {
            rows_affected: ffi::mysql_affected_rows(mysql.as_ptr()) as u64,
            last_insert_id: ffi::mysql_insert_id(mysql.as_ptr()) as u64,
        })
    }
}
/// Prepares `sql` on the server and returns an uncached statement that
/// borrows this connection.
///
/// The statement handle is created by the shared `prepare_stmt` helper so
/// that cached and uncached preparation follow a single code path.
///
/// # Errors
///
/// Fails when the client thread state cannot be prepared or when
/// `prepare_stmt` fails.
pub async fn prepare(&mut self, sql: &str) -> Result<Statement<'_>> {
    ensure_mysql_thread_ready()?;
    let stmt = self.prepare_stmt(sql).await?;
    Ok(Statement {
        conn: self,
        stmt,
        result_metadata: None,
    })
}
/// Returns a statement backed by the per-connection statement cache,
/// preparing `sql` first when it is not cached yet.
///
/// # Errors
///
/// Fails when the statement has to be prepared and `prepare_stmt` fails.
pub async fn prepare_cached(&mut self, sql: &str) -> Result<CachedStatement<'_>> {
    let already_cached = self.statement_cache.contains_key(sql);
    if !already_cached {
        let handle = self.prepare_stmt(sql).await?;
        let params = unsafe { ffi::mysql_stmt_param_count(handle.as_ptr()) as usize };
        let entry = CachedStmtEntry {
            stmt: handle,
            param_count: params,
            result_metadata: None,
            scratch: ParamScratch::new(params),
        };
        self.statement_cache.insert(sql.into(), entry);
    }
    // The returned handle looks the entry up again by its SQL text.
    let key = sql.into();
    Ok(CachedStatement { conn: self, key })
}
/// Starts a transaction by issuing `START TRANSACTION` and returns a guard
/// that borrows this connection.
///
/// # Errors
///
/// Fails when the `START TRANSACTION` query fails.
pub async fn begin(&mut self) -> Result<Transaction<'_>> {
    let _ = self.query("START TRANSACTION").await?;
    let guard = Transaction {
        conn: self,
        finished: false,
    };
    Ok(guard)
}
pub async fn commit(&mut self) -> Result<()> {
ensure_mysql_thread_ready()?;
unsafe {
let mysql = self.mysql;
let mut output = DriveOutput::commit();
self.drive(DriveOperation::Commit { mysql }, &mut output)
.await?;
if output.commit_code() != 0 {
return Err(Error::from_mysql(mysql.as_ptr(), "commit failed"));
}
Ok(())
}
}
pub async fn rollback(&mut self) -> Result<()> {
ensure_mysql_thread_ready()?;
unsafe {
let mysql = self.mysql;
let mut output = DriveOutput::rollback();
self.drive(DriveOperation::Rollback { mysql }, &mut output)
.await?;
if output.rollback_code() != 0 {
return Err(Error::from_mysql(mysql.as_ptr(), "rollback failed"));
}
Ok(())
}
}
/// Runs one nonblocking operation to completion.
///
/// Starts `op`, then alternates between waiting for the events named in the
/// returned status and continuing the operation, until the status is zero.
/// The operation's result is written into `out`; this function only reports
/// whether the wait loop itself succeeded.
///
/// # Errors
///
/// Fails when the client thread state cannot be prepared or when waiting for
/// the requested events fails.
pub(crate) async fn drive(
    &mut self,
    op: DriveOperation<'_>,
    out: &mut DriveOutput,
) -> Result<()> {
    ensure_mysql_thread_ready()?;
    let mut pending = unsafe { start_operation(op, out) };
    loop {
        // Zero means the library has nothing left to wait for.
        if pending == 0 {
            return Ok(());
        }
        let observed = self.wait_for(pending).await?;
        // Re-checked after every await before calling back into the library.
        ensure_mysql_thread_ready()?;
        pending = unsafe { continue_operation(op, out, observed) };
    }
}
async fn wait_for(&mut self, status: i32) -> Result<i32> {
self.refresh_socket()?;
let timeout_ms = unsafe { ffi::mysql_get_timeout_value_ms(self.mysql.as_ptr()) } as u64;
let wants_read = (status & (WAIT_READ | WAIT_EXCEPT)) != 0;
let wants_write = (status & WAIT_WRITE) != 0;
let wants_timeout = (status & WAIT_TIMEOUT) != 0;
let socket = self
.socket
.as_ref()
.ok_or_else(|| Error::new("libmariadb did not expose a valid socket"))?;
match (wants_read, wants_write, wants_timeout) {
(true, true, true) => {
tokio::select! {
ready = socket.readable() => {
let mut ready = ready?;
ready.clear_ready();
Ok(WAIT_READ)
}
ready = socket.writable() => {
let mut ready = ready?;
ready.clear_ready();
Ok(WAIT_WRITE)
}
_ = tokio::time::sleep(Duration::from_millis(timeout_ms)) => Ok(WAIT_TIMEOUT),
}
}
(true, true, false) => {
tokio::select! {
ready = socket.readable() => {
let mut ready = ready?;
ready.clear_ready();
Ok(WAIT_READ)
}
ready = socket.writable() => {
let mut ready = ready?;
ready.clear_ready();
Ok(WAIT_WRITE)
}
}
}
(true, false, true) => {
tokio::select! {
ready = socket.readable() => {
let mut ready = ready?;
ready.clear_ready();
Ok(WAIT_READ)
}
_ = tokio::time::sleep(Duration::from_millis(timeout_ms)) => Ok(WAIT_TIMEOUT),
}
}
(false, true, true) => {
tokio::select! {
ready = socket.writable() => {
let mut ready = ready?;
ready.clear_ready();
Ok(WAIT_WRITE)
}
_ = tokio::time::sleep(Duration::from_millis(timeout_ms)) => Ok(WAIT_TIMEOUT),
}
}
(true, false, false) => {
let mut ready = socket.readable().await?;
ready.clear_ready();
Ok(WAIT_READ)
}
(false, true, false) => {
let mut ready = socket.writable().await?;
ready.clear_ready();
Ok(WAIT_WRITE)
}
(false, false, true) => {
tokio::time::sleep(Duration::from_millis(timeout_ms)).await;
Ok(WAIT_TIMEOUT)
}
(false, false, false) => {
Err(Error::new("libmariadb returned an unsupported wait status"))
}
}
}
/// Makes `self.socket` watch the descriptor currently reported by
/// `mysql_get_socket`.
///
/// The existing watcher is kept when it already wraps the same descriptor
/// number; otherwise a new `AsyncFd` is registered in its place.
///
/// # Errors
///
/// Fails (and clears `self.socket`) when the reported descriptor is negative,
/// or when registering the descriptor with `AsyncFd::new` fails.
fn refresh_socket(&mut self) -> Result<()> {
    let fd = unsafe { ffi::mysql_get_socket(self.mysql.as_ptr()) };
    if fd < 0 {
        self.socket = None;
        return Err(Error::new("libmariadb did not expose a valid socket"));
    }
    let already_watched = match &self.socket {
        Some(current) => current.get_ref().0 == fd,
        None => false,
    };
    if !already_watched {
        let watcher = AsyncFd::new(SocketRef(fd))?;
        self.socket = Some(watcher);
    }
    Ok(())
}
async fn prepare_stmt(&mut self, sql: &str) -> Result<StmtHandle> {
unsafe {
let stmt = ffi::mysql_stmt_init(self.mysql.as_ptr());
let stmt =
StmtHandle(NonNull::new(stmt).ok_or_else(|| {
Error::from_mysql(self.mysql.as_ptr(), "mysql_stmt_init failed")
})?);
let sql = CString::new(sql)?;
let mut output = DriveOutput::stmt_prepare();
self.drive(DriveOperation::StmtPrepare { stmt, sql: &sql }, &mut output)
.await?;
if output.stmt_prepare_code() != 0 {
let error = Error::from_stmt(stmt.as_ptr(), "mysql_stmt_prepare failed");
ffi::mysql_stmt_close(stmt.as_ptr());
return Err(error);
}
Ok(stmt)
}
}
}
/// Closes every cached statement and then the connection handle.
impl Drop for Connection {
    fn drop(&mut self) {
        // Release the readiness watcher while the descriptor is still open.
        // Left to the automatic field drop, it would deregister only after
        // `mysql_close` has closed the descriptor, by which time the number
        // may already belong to an unrelated file.
        self.socket = None;
        if !ensure_mysql_thread_ready_for_drop() {
            return;
        }
        // SAFETY: the statement handles and the connection handle are owned
        // by `self` and are not used again after this point.
        unsafe {
            for entry in self.statement_cache.values() {
                ffi::mysql_stmt_close(entry.stmt.as_ptr());
            }
            ffi::mysql_close(self.mysql.as_ptr());
        }
    }
}
/// Guard for a transaction started with `Connection::begin`.
///
/// Dropping the guard without calling `commit` or `rollback` issues a
/// rollback from `Drop`.
pub struct Transaction<'a> {
/// Connection the transaction runs on, borrowed for the guard's lifetime.
conn: &'a mut Connection,
/// Set once `commit` or `rollback` has been called; suppresses the rollback in `Drop`.
finished: bool,
}
impl Transaction<'_> {
    /// Gives access to the underlying connection so statements can be run
    /// inside the transaction.
    #[inline]
    pub fn connection(&mut self) -> &mut Connection {
        &mut *self.conn
    }

    /// Commits the transaction and consumes the guard.
    ///
    /// The guard is marked finished before the commit is awaited, so `Drop`
    /// does not issue a rollback afterwards, whatever the outcome.
    pub async fn commit(mut self) -> Result<()> {
        self.finished = true;
        self.conn.commit().await
    }

    /// Rolls the transaction back and consumes the guard.
    ///
    /// The guard is marked finished before the rollback is awaited, so `Drop`
    /// does not issue a second rollback.
    pub async fn rollback(mut self) -> Result<()> {
        self.finished = true;
        self.conn.rollback().await
    }
}
/// Rolls back a transaction that was neither committed nor rolled back.
impl Drop for Transaction<'_> {
    fn drop(&mut self) {
        // Already resolved through `commit` or `rollback`.
        if self.finished {
            return;
        }
        if !ensure_mysql_thread_ready_for_drop() {
            return;
        }
        // `Drop` cannot await, so this calls `mysql_rollback` directly rather
        // than going through `Connection::drive`; the result is ignored.
        let _ = unsafe { ffi::mysql_rollback(self.conn.mysql.as_ptr()) };
    }
}
// SAFETY: NOTE(review) — this asserts that a `Connection` may be moved to
// another thread. The connection keeps a libmariadb handle and a socket
// watcher, and `drive` calls `ensure_mysql_thread_ready` before every call
// back into the library. Presumably that is what makes use from a different
// thread sound; confirm against the `runtime` module.
unsafe impl Send for Connection {}