use crate::{
CursorImpl, CursorPolling, Error, ParameterCollectionRef, Preallocated, Prepared, Sleep,
buffers::BufferDesc,
execute::{
execute_columns, execute_foreign_keys, execute_tables, execute_with_parameters_polling,
},
handles::{
self, SqlText, State, Statement, StatementConnection, StatementImpl, StatementParent,
slice_to_utf8,
},
};
use log::error;
use std::{
borrow::Cow,
fmt::{self, Debug, Display},
mem::{ManuallyDrop, MaybeUninit},
ptr, str,
sync::Arc,
thread::panicking,
};
/// Disconnects from the data source when the `Connection` goes out of scope.
///
/// If disconnecting fails with `INVALID_STATE_TRANSACTION` (an open transaction), the transaction
/// is rolled back and the disconnect is attempted once more. Any other disconnect failure panics,
/// unless the current thread is already panicking; in that case the error is ignored so the
/// original panic is not masked (a panic inside `drop` while unwinding would abort the process).
impl Drop for Connection<'_> {
    fn drop(&mut self) {
        match self.connection.disconnect().into_result(&self.connection) {
            Ok(()) => (),
            Err(Error::Diagnostics {
                record,
                function: _,
            }) if record.state == State::INVALID_STATE_TRANSACTION => {
                // A transaction is still open. Roll it back so the disconnect can be retried. A
                // failed rollback is only logged: the retried disconnect below decides whether
                // this ends in a panic.
                if let Err(e) = self.rollback() {
                    error!(
                        "Error during rolling back transaction (In order to recover from \
                        invalid transaction state during disconnect): {}",
                        e
                    );
                }
                if let Err(e) = self.connection.disconnect().into_result(&self.connection) {
                    // Do not panic while already panicking; that would hide the original error.
                    if !panicking() {
                        panic!("Unexpected error disconnecting (after rollback attempt): {e:?}")
                    }
                }
            }
            Err(e) => {
                // Do not panic while already panicking; that would hide the original error.
                if !panicking() {
                    panic!("Unexpected error disconnecting: {e:?}")
                }
            }
        }
    }
}
/// A connection to a data source, wrapping a [`handles::Connection`] handle.
///
/// The `Drop` implementation of this type disconnects the handle (rolling back an open
/// transaction if the driver requires it). Use `into_handle` to take the raw handle without
/// disconnecting.
pub struct Connection<'c> {
    /// The underlying ODBC connection handle owned by this wrapper.
    connection: handles::Connection<'c>,
}
impl<'c> Connection<'c> {
    /// Wraps a connection handle. The returned wrapper disconnects the handle when dropped.
    pub(crate) fn new(connection: handles::Connection<'c>) -> Self {
        Self { connection }
    }

    /// Gives up the wrapper and returns the underlying connection handle **without**
    /// disconnecting it. The `Drop` implementation of `Connection` does not run.
    pub fn into_handle(self) -> handles::Connection<'c> {
        // `MaybeUninit` never runs the destructor of its content, so `Connection::drop` (which
        // would disconnect) is suppressed for `self`.
        let dont_drop_me = MaybeUninit::new(self);
        let self_ptr = dont_drop_me.as_ptr();
        // SAFETY: `self_ptr` points to a fully initialized `Connection`. `dont_drop_me` is never
        // read or dropped again, so moving the `connection` field out leaves exactly one owner of
        // the handle: the caller.
        unsafe { ptr::read(&(*self_ptr).connection) }
    }

    /// Executes `query` with `params` on a newly allocated statement.
    ///
    /// # Arguments
    ///
    /// * `query` - SQL text to execute.
    /// * `params` - Parameters bound to the query.
    /// * `query_timeout_sec` - If `Some`, set as the statement's query timeout before executing.
    ///
    /// Returns `Ok(None)` without allocating a statement if `params` holds zero parameter sets.
    /// Otherwise the result of executing the preallocated statement is returned (presumably
    /// `Ok(None)` there as well if the statement produces no result set).
    ///
    /// # Errors
    ///
    /// Fails if allocating the statement, setting the timeout, or executing fails.
    pub fn execute(
        &self,
        query: &str,
        params: impl ParameterCollectionRef,
        query_timeout_sec: Option<usize>,
    ) -> Result<Option<CursorImpl<StatementImpl<'_>>>, Error> {
        // Only allocate a statement if there is actually something to execute.
        if params.parameter_set_size() == 0 {
            return Ok(None);
        }
        let mut statement = self.preallocate()?;
        if let Some(seconds) = query_timeout_sec {
            statement.set_query_timeout_sec(seconds)?;
        }
        statement.into_cursor(query, params)
    }

    /// Asynchronous counterpart of [`Self::execute`].
    ///
    /// Allocates a statement, enables asynchronous execution on it and hands it, together with
    /// `sleep`, to the polling executor (presumably `sleep` is awaited between polls). Returns
    /// `Ok(None)` without allocating a statement if `params` holds zero parameter sets. No query
    /// timeout is applied.
    pub async fn execute_polling(
        &self,
        query: &str,
        params: impl ParameterCollectionRef,
        sleep: impl Sleep,
    ) -> Result<Option<CursorPolling<StatementImpl<'_>>>, Error> {
        // Only allocate a statement if there is actually something to execute.
        if params.parameter_set_size() == 0 {
            return Ok(None);
        }
        let query = SqlText::new(query);
        let mut statement = self.allocate_statement()?;
        statement.set_async_enable(true).into_result(&statement)?;
        execute_with_parameters_polling(statement, Some(&query), params, sleep).await
    }

    /// Like [`Self::execute`], but consumes the connection and moves it into the returned
    /// cursor, so the cursor owns the connection.
    ///
    /// # Errors
    ///
    /// On failure the connection is handed back inside the error (`previous`), so it can be
    /// reused.
    pub fn into_cursor(
        self,
        query: &str,
        params: impl ParameterCollectionRef,
        query_timeout_sec: Option<usize>,
    ) -> Result<Option<CursorImpl<StatementConnection<Connection<'c>>>>, ConnectionAndError<'c>>
    {
        // The outcome of `execute` is staged in locals instead of being returned from inside the
        // `match`. Presumably this is needed so the borrow checker accepts moving `self` after
        // the borrow held by `execute`'s result has ended.
        let mut error = None;
        let mut cursor = None;
        match self.execute(query, params, query_timeout_sec) {
            Ok(Some(c)) => cursor = Some(c),
            Ok(None) => return Ok(None),
            Err(e) => error = Some(e),
        };
        if let Some(e) = error {
            drop(cursor);
            return Err(ConnectionAndError {
                error: e,
                previous: self,
            });
        }
        let cursor = cursor.expect("cursor is set whenever execute neither failed nor returned None");
        // Do not drop the borrowing cursor: its statement handle is about to be owned by a new
        // cursor that also owns the connection.
        let mut cursor = ManuallyDrop::new(cursor);
        let handle = cursor.as_sys();
        // SAFETY: `handle` is a valid statement allocated from this connection. The borrowing
        // cursor is wrapped in `ManuallyDrop`, so the handle is not freed twice, and `self` is
        // moved in to keep the connection alive as long as the statement.
        let statement = unsafe { StatementConnection::new(handle, self) };
        // SAFETY: the statement was just executed and taken from a cursor, so it is in cursor
        // state.
        let cursor = unsafe { CursorImpl::new(statement) };
        Ok(Some(cursor))
    }

    /// Prepares `query` on a newly allocated statement for later (repeated) execution.
    ///
    /// # Errors
    ///
    /// Fails if allocating the statement or preparing the query fails.
    pub fn prepare(&self, query: &str) -> Result<Prepared<StatementImpl<'_>>, Error> {
        let query = SqlText::new(query);
        let mut stmt = self.allocate_statement()?;
        stmt.prepare(&query).into_result(&stmt)?;
        Ok(Prepared::new(stmt))
    }

    /// Like [`Self::prepare`], but consumes the connection and moves it into the prepared
    /// statement. On error the connection is dropped.
    pub fn into_prepared(
        self,
        query: &str,
    ) -> Result<Prepared<StatementConnection<Connection<'c>>>, Error> {
        let query = SqlText::new(query);
        let mut stmt = self.allocate_statement()?;
        stmt.prepare(&query).into_result(&stmt)?;
        // SAFETY: the handle was allocated from `self`, ownership of it is released by
        // `into_sys`, and `self` is moved into the statement to keep the connection alive.
        let stmt = unsafe { StatementConnection::new(stmt.into_sys(), self) };
        Ok(Prepared::new(stmt))
    }

    /// Allocates a statement which can be executed later, possibly multiple times.
    pub fn preallocate(&self) -> Result<Preallocated<StatementImpl<'_>>, Error> {
        let stmt = self.allocate_statement()?;
        // SAFETY: `stmt` was freshly allocated and has not been used yet.
        unsafe { Ok(Preallocated::new(stmt)) }
    }

    /// Like [`Self::preallocate`], but consumes the connection and moves it into the statement.
    /// On error the connection is dropped.
    pub fn into_preallocated(
        self,
    ) -> Result<Preallocated<StatementConnection<Connection<'c>>>, Error> {
        let stmt = self.allocate_statement()?;
        // SAFETY: the handle was freshly allocated from `self` and has not been used yet;
        // `self` is moved into the statement to keep the connection alive.
        unsafe {
            let stmt = StatementConnection::new(stmt.into_sys(), self);
            Ok(Preallocated::new(stmt))
        }
    }

    /// Enables or disables autocommit mode on the connection.
    pub fn set_autocommit(&self, enabled: bool) -> Result<(), Error> {
        self.connection
            .set_autocommit(enabled)
            .into_result(&self.connection)
    }

    /// Commits the current transaction.
    pub fn commit(&self) -> Result<(), Error> {
        self.connection.commit().into_result(&self.connection)
    }

    /// Rolls back the current transaction.
    pub fn rollback(&self) -> Result<(), Error> {
        self.connection.rollback().into_result(&self.connection)
    }

    /// Asks the underlying handle whether the connection is dead.
    pub fn is_dead(&self) -> Result<bool, Error> {
        self.connection.is_dead().into_result(&self.connection)
    }

    /// Packet size as reported by the underlying connection handle.
    pub fn packet_size(&self) -> Result<u32, Error> {
        self.connection.packet_size().into_result(&self.connection)
    }

    /// Name of the database management system the connection talks to.
    ///
    /// # Panics
    ///
    /// If the name returned by the driver cannot be decoded.
    pub fn database_management_system_name(&self) -> Result<String, Error> {
        let mut buf = Vec::new();
        self.connection
            .fetch_database_management_system_name(&mut buf)
            .into_result(&self.connection)?;
        let name = slice_to_utf8(&buf).expect("Returned DBMS name must be correctly encoded");
        Ok(name)
    }

    /// Maximum length of a catalog name as reported by the driver.
    pub fn max_catalog_name_len(&self) -> Result<u16, Error> {
        self.connection
            .max_catalog_name_len()
            .into_result(&self.connection)
    }

    /// Maximum length of a schema name as reported by the driver.
    pub fn max_schema_name_len(&self) -> Result<u16, Error> {
        self.connection
            .max_schema_name_len()
            .into_result(&self.connection)
    }

    /// Maximum length of a table name as reported by the driver.
    pub fn max_table_name_len(&self) -> Result<u16, Error> {
        self.connection
            .max_table_name_len()
            .into_result(&self.connection)
    }

    /// Maximum length of a column name as reported by the driver.
    pub fn max_column_name_len(&self) -> Result<u16, Error> {
        self.connection
            .max_column_name_len()
            .into_result(&self.connection)
    }

    /// Name of the catalog currently used by the connection.
    ///
    /// # Panics
    ///
    /// If the name returned by the driver cannot be decoded.
    pub fn current_catalog(&self) -> Result<String, Error> {
        let mut buf = Vec::new();
        self.connection
            .fetch_current_catalog(&mut buf)
            .into_result(&self.connection)?;
        let name = slice_to_utf8(&buf).expect("Return catalog must be correctly encoded");
        Ok(name)
    }

    /// Queries column metadata matching the given catalog, schema, table and column names and
    /// returns the result set as a cursor.
    pub fn columns(
        &self,
        catalog_name: &str,
        schema_name: &str,
        table_name: &str,
        column_name: &str,
    ) -> Result<CursorImpl<StatementImpl<'_>>, Error> {
        execute_columns(
            self.allocate_statement()?,
            &SqlText::new(catalog_name),
            &SqlText::new(schema_name),
            &SqlText::new(table_name),
            &SqlText::new(column_name),
        )
    }

    /// Queries table metadata matching the given catalog, schema, table name and table type and
    /// returns the result set as a cursor.
    pub fn tables(
        &self,
        catalog_name: &str,
        schema_name: &str,
        table_name: &str,
        table_type: &str,
    ) -> Result<CursorImpl<StatementImpl<'_>>, Error> {
        let statement = self.allocate_statement()?;
        execute_tables(
            statement,
            &SqlText::new(catalog_name),
            &SqlText::new(schema_name),
            &SqlText::new(table_name),
            &SqlText::new(table_type),
        )
    }

    /// Queries foreign key metadata relating the given primary-key table and foreign-key table
    /// and returns the result set as a cursor.
    pub fn foreign_keys(
        &self,
        pk_catalog_name: &str,
        pk_schema_name: &str,
        pk_table_name: &str,
        fk_catalog_name: &str,
        fk_schema_name: &str,
        fk_table_name: &str,
    ) -> Result<CursorImpl<StatementImpl<'_>>, Error> {
        let statement = self.allocate_statement()?;
        execute_foreign_keys(
            statement,
            &SqlText::new(pk_catalog_name),
            &SqlText::new(pk_schema_name),
            &SqlText::new(pk_table_name),
            &SqlText::new(fk_catalog_name),
            &SqlText::new(fk_schema_name),
            &SqlText::new(fk_table_name),
        )
    }

    /// Buffer descriptions for the 18 standard columns of a [`Self::columns`] result set, in
    /// result-set order.
    ///
    /// Name columns are sized by the driver-reported maximum name lengths; the other text
    /// columns use the lengths passed in.
    ///
    /// # Arguments
    ///
    /// * `type_name_max_len` - Maximum expected length of type names.
    /// * `remarks_max_len` - Maximum expected length of remarks.
    /// * `column_default_max_len` - Maximum expected length of column defaults.
    pub fn columns_buffer_descs(
        &self,
        type_name_max_len: usize,
        remarks_max_len: usize,
        column_default_max_len: usize,
    ) -> Result<Vec<BufferDesc>, Error> {
        let null_i16 = BufferDesc::I16 { nullable: true };
        let not_null_i16 = BufferDesc::I16 { nullable: false };
        let null_i32 = BufferDesc::I32 { nullable: true };

        // NOTE(review): per the ODBC `SQLGetInfo` documentation, drivers report 0 for the
        // maximum name lengths if there is no limit or it is unknown. That would yield
        // zero-length text buffers here; confirm whether a fallback length is wanted.
        let catalog_name_desc = BufferDesc::Text {
            max_str_len: self.max_catalog_name_len()? as usize,
        };
        let schema_name_desc = BufferDesc::Text {
            max_str_len: self.max_schema_name_len()? as usize,
        };
        let table_name_desc = BufferDesc::Text {
            max_str_len: self.max_table_name_len()? as usize,
        };
        let column_name_desc = BufferDesc::Text {
            max_str_len: self.max_column_name_len()? as usize,
        };

        let data_type_desc = not_null_i16;
        let type_name_desc = BufferDesc::Text {
            max_str_len: type_name_max_len,
        };
        let column_size_desc = null_i32;
        let buffer_len_desc = null_i32;
        let decimal_digits_desc = null_i16;
        let precision_radix_desc = null_i16;
        let nullable_desc = not_null_i16;
        let remarks_desc = BufferDesc::Text {
            max_str_len: remarks_max_len,
        };
        let column_default_desc = BufferDesc::Text {
            max_str_len: column_default_max_len,
        };
        let sql_data_type_desc = not_null_i16;
        let sql_datetime_sub_desc = null_i16;
        let char_octet_len_desc = null_i32;
        let ordinal_pos_desc = BufferDesc::I32 { nullable: false };

        // Long enough for "YES" (and "NO").
        const IS_NULLABLE_LEN_MAX_LEN: usize = 3;
        let is_nullable_desc = BufferDesc::Text {
            max_str_len: IS_NULLABLE_LEN_MAX_LEN,
        };

        Ok(vec![
            catalog_name_desc,
            schema_name_desc,
            table_name_desc,
            column_name_desc,
            data_type_desc,
            type_name_desc,
            column_size_desc,
            buffer_len_desc,
            decimal_digits_desc,
            precision_radix_desc,
            nullable_desc,
            remarks_desc,
            column_default_desc,
            sql_data_type_desc,
            sql_datetime_sub_desc,
            char_octet_len_desc,
            ordinal_pos_desc,
            is_nullable_desc,
        ])
    }

    /// Allocates a new statement handle on this connection.
    fn allocate_statement(&self) -> Result<StatementImpl<'_>, Error> {
        self.connection
            .allocate_statement()
            .into_result(&self.connection)
    }
}
impl Debug for Connection<'_> {
    /// Prints only the type name; no details of the underlying handle are included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Connection")
    }
}
// SAFETY: NOTE(review): the contract of `StatementParent` is defined elsewhere. Presumably it
// requires the parent to keep the connection handle alive for as long as statements allocated
// from it exist; `Connection` owns its handle and only disconnects it when dropped. Confirm
// against the trait's documentation.
unsafe impl StatementParent for Connection<'_> {}
// SAFETY: the `Arc` keeps the shared `Connection` (and thus its handle) alive for as long as any
// clone of it exists; see the assumption about the `StatementParent` contract on the line above.
unsafe impl StatementParent for Arc<Connection<'_>> {}
/// Optional connection attributes, applied to a connection handle by
/// [`ConnectionOptions::apply`]. The `Default` value leaves every option unset (`None`).
#[derive(Default, Clone, Copy)]
pub struct ConnectionOptions {
    /// Login timeout in seconds. `None` leaves the handle's setting untouched.
    pub login_timeout_sec: Option<u32>,
    /// Packet size. `None` leaves the handle's setting untouched.
    pub packet_size: Option<u32>,
}
impl ConnectionOptions {
    /// Sets every option that is `Some` on `handle`: the login timeout first, then the packet
    /// size. Options that are `None` are skipped.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by the handle; options after it are not applied.
    pub fn apply(&self, handle: &handles::Connection) -> Result<(), Error> {
        let Self {
            login_timeout_sec,
            packet_size,
        } = *self;

        if let Some(seconds) = login_timeout_sec {
            handle.set_login_timeout_sec(seconds).into_result(handle)?;
        }
        if let Some(size) = packet_size {
            handle.set_packet_size(size).into_result(handle)?;
        }
        Ok(())
    }
}
/// Escapes an attribute value so it can be embedded in an ODBC connection string.
///
/// A value needs escaping if it:
///
/// * contains `;`, which separates attributes,
/// * contains `+`, which is escaped as well, or
/// * starts with `{`. Left as is, a driver would parse such a value as an already-braced value
///   and strip or misinterpret its braces.
///
/// Escaping encloses the value in curly braces and doubles every closing brace `}` inside it.
/// Values that need no escaping are returned borrowed, without allocating.
pub fn escape_attribute_value(unescaped: &str) -> Cow<'_, str> {
    let needs_braces = unescaped.starts_with('{') || unescaped.contains(&[';', '+'][..]);
    if needs_braces {
        // Inside braces only `}` is special; it is escaped by repeating it.
        let escaped = unescaped.replace('}', "}}");
        Cow::Owned(format!("{{{escaped}}}"))
    } else {
        Cow::Borrowed(unescaped)
    }
}
/// Error of a consuming state transition that failed.
///
/// Besides the `error`, it hands the value the transition consumed back to the caller as
/// `previous`, so it is not lost. For example, `Connection::into_cursor` returns the connection
/// here on failure.
#[derive(Debug)]
pub struct FailedStateTransition<S> {
    /// The error which caused the transition to fail.
    pub error: Error,
    /// The state the transition started from, returned to the caller.
    pub previous: S,
}
impl<S> From<FailedStateTransition<S>> for Error {
fn from(value: FailedStateTransition<S>) -> Self {
value.error
}
}
impl<S> Display for FailedStateTransition<S> {
    /// Shows only the wrapped error; the previous state is not part of the message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let error = &self.error;
        write!(f, "{error}")
    }
}
impl<S> std::error::Error for FailedStateTransition<S>
where
S: Debug,
{
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
self.error.source()
}
}
/// Failed transition starting from a [`Connection`]. Returned by `Connection::into_cursor`, which
/// hands the connection back to the caller on failure.
type ConnectionAndError<'conn> = FailedStateTransition<Connection<'conn>>;
/// Transitions which consume a connection-like value (implemented here for `Connection` and
/// `Arc<Connection>`) and produce a statement that owns it.
pub trait ConnectionTransitions: Sized {
    /// The type owned by the resulting statements to keep the connection alive. Both
    /// implementations in this module use `Self`.
    type StatementParent: StatementParent;

    /// Executes `query` with `params` (and an optional query timeout in seconds). If a cursor is
    /// produced, it owns the connection; `Ok(None)` is returned otherwise.
    ///
    /// On failure the connection is handed back inside the error (`previous`).
    fn into_cursor(
        self,
        query: &str,
        params: impl ParameterCollectionRef,
        query_timeout_sec: Option<usize>,
    ) -> Result<
        Option<CursorImpl<StatementConnection<Self::StatementParent>>>,
        FailedStateTransition<Self>,
    >;

    /// Prepares `query` and returns a prepared statement which owns the connection.
    fn into_prepared(
        self,
        query: &str,
    ) -> Result<Prepared<StatementConnection<Self::StatementParent>>, Error>;

    /// Allocates a statement which owns the connection, to be executed later.
    fn into_preallocated(
        self,
    ) -> Result<Preallocated<StatementConnection<Self::StatementParent>>, Error>;
}
impl<'env> ConnectionTransitions for Connection<'env> {
type StatementParent = Self;
fn into_cursor(
self,
query: &str,
params: impl ParameterCollectionRef,
query_timeout_sec: Option<usize>,
) -> Result<Option<CursorImpl<StatementConnection<Self>>>, FailedStateTransition<Self>> {
self.into_cursor(query, params, query_timeout_sec)
}
fn into_prepared(self, query: &str) -> Result<Prepared<StatementConnection<Self>>, Error> {
self.into_prepared(query)
}
fn into_preallocated(self) -> Result<Preallocated<StatementConnection<Self>>, Error> {
self.into_preallocated()
}
}
/// Transitions for a shared connection.
///
/// Each transition first uses the borrowing method of the [`Connection`] behind the `Arc`. It
/// then takes the raw statement handle out of the borrowing statement and re-wraps it in a
/// [`StatementConnection`] that owns the `Arc`.
impl<'c> ConnectionTransitions for Arc<Connection<'c>> {
    type StatementParent = Self;

    fn into_cursor(
        self,
        query: &str,
        params: impl ParameterCollectionRef,
        query_timeout_sec: Option<usize>,
    ) -> Result<Option<CursorImpl<StatementConnection<Self>>>, FailedStateTransition<Self>> {
        let stmt_ptr = match self.execute(query, params, query_timeout_sec) {
            // Detach the raw handle so nothing borrows `self` any longer.
            Ok(Some(cursor)) => cursor.into_stmt().into_sys(),
            Ok(None) => return Ok(None),
            Err(error) => {
                return Err(FailedStateTransition {
                    error,
                    previous: Arc::clone(&self),
                });
            }
        };
        // SAFETY: `stmt_ptr` was allocated from the connection behind `self`, and the cursor
        // gave up ownership of it via `into_sys`. Moving the `Arc` into the statement keeps the
        // connection alive for as long as the statement.
        let stmt = unsafe { StatementConnection::new(stmt_ptr, self) };
        // SAFETY: the handle came out of a cursor, so the statement is in cursor state.
        let cursor = unsafe { CursorImpl::new(stmt) };
        Ok(Some(cursor))
    }

    fn into_prepared(self, query: &str) -> Result<Prepared<StatementConnection<Self>>, Error> {
        let stmt_ptr = self.prepare(query)?.into_handle().into_sys();
        // SAFETY: the handle was allocated from the connection behind `self` and released via
        // `into_sys`; the `Arc` moved in keeps that connection alive.
        let stmt = unsafe { StatementConnection::new(stmt_ptr, self) };
        Ok(Prepared::new(stmt))
    }

    fn into_preallocated(self) -> Result<Preallocated<StatementConnection<Self>>, Error> {
        let stmt_ptr = self.preallocate()?.into_handle().into_sys();
        // SAFETY: the handle was allocated from the connection behind `self` and released via
        // `into_sys`; the `Arc` moved in keeps that connection alive.
        let stmt = unsafe { StatementConnection::new(stmt_ptr, self) };
        // SAFETY: the statement was freshly allocated and has not been executed.
        Ok(unsafe { Preallocated::new(stmt) })
    }
}