use crate::{
buffers::BufferDescription,
execute::execute_with_parameters,
handles::{self, State, Statement, StatementImpl},
parameter_collection::ParameterCollection,
CursorImpl, Error, Preallocated, Prepared,
};
use std::{borrow::Cow, str, thread::panicking};
use widestring::{U16Str, U16String};
/// Disconnects the wrapped connection handle when the `Connection` goes out of scope.
///
/// If the first disconnect attempt fails with the `INVALID_STATE_TRANSACTION` diagnostic state,
/// a rollback is issued and the disconnect is attempted a second time.
///
/// # Panics
///
/// Panics if disconnecting (or the recovery rollback) reports an error, unless the thread is
/// already panicking. The `panicking()` guard avoids a double panic, which would abort the
/// process.
impl<'conn> Drop for Connection<'conn> {
    fn drop(&mut self) {
        match self.connection.disconnect() {
            Ok(()) => (),
            Err(Error::Diagnostics(record)) if record.state == State::INVALID_STATE_TRANSACTION => {
                // Disconnect was refused because of the transaction state. Try to roll back and
                // then disconnect once more.
                if let Err(e) = self.connection.rollback() {
                    if !panicking() {
                        panic!(
                            "Unexpected error rolling back transaction (In order to recover \
                             from invalid transaction state during disconnect): {:?}",
                            e
                        )
                    }
                }
                // Second attempt. The message is distinct from the plain disconnect failure
                // below so the two paths can be told apart in a crash report.
                if let Err(e) = self.connection.disconnect() {
                    if !panicking() {
                        panic!(
                            "Unexpected error disconnecting (after rollback attempt): {:?}",
                            e
                        )
                    }
                }
            }
            Err(e) => {
                if !panicking() {
                    panic!("Unexpected error disconnecting: {:?}", e)
                }
            }
        }
    }
}
/// Owning wrapper around a `handles::Connection`.
///
/// All methods delegate to the wrapped handle. The `Drop` implementation of this type
/// disconnects the handle when the wrapper goes out of scope.
pub struct Connection<'c> {
/// The wrapped low-level connection handle every method forwards to.
connection: handles::Connection<'c>,
}
impl<'c> Connection<'c> {
    /// Wraps an existing connection handle. Only usable from within this crate.
    pub(crate) fn new(connection: handles::Connection<'c>) -> Self {
        Self { connection }
    }

    /// Executes a statement given as UTF-16 text, binding `params` as its parameters.
    ///
    /// Statement allocation is wrapped in a closure and handed to `execute_with_parameters`,
    /// so the handle is only allocated if and when that function invokes the closure.
    ///
    /// NOTE(review): presumably `Ok(None)` means the statement produced no cursor - confirm
    /// against `execute_with_parameters`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `execute_with_parameters` reports.
    pub fn execute_utf16(
        &self,
        query: &U16Str,
        params: impl ParameterCollection,
    ) -> Result<Option<CursorImpl<'_, StatementImpl<'_>>>, Error> {
        let lazy_statement = move || self.connection.allocate_statement();
        execute_with_parameters(lazy_statement, Some(query), params)
    }

    /// Executes a statement given as a Rust string slice.
    ///
    /// The text is converted to UTF-16 and forwarded to [`Connection::execute_utf16`].
    ///
    /// # Errors
    ///
    /// Same as [`Connection::execute_utf16`].
    pub fn execute(
        &self,
        query: &str,
        params: impl ParameterCollection,
    ) -> Result<Option<CursorImpl<'_, StatementImpl<'_>>>, Error> {
        let query = U16String::from_str(query);
        self.execute_utf16(&query, params)
    }

    /// Allocates a statement handle, prepares `query` (UTF-16 text) on it and wraps the
    /// result in a [`Prepared`].
    ///
    /// # Errors
    ///
    /// Fails if the statement cannot be allocated or if preparing the query fails.
    pub fn prepare_utf16(&self, query: &U16Str) -> Result<Prepared<'_>, Error> {
        let mut stmt = self.connection.allocate_statement()?;
        stmt.prepare(query)?;
        Ok(Prepared::new(stmt))
    }

    /// Prepares a statement given as a Rust string slice.
    ///
    /// The text is converted to UTF-16 and forwarded to [`Connection::prepare_utf16`].
    ///
    /// # Errors
    ///
    /// Same as [`Connection::prepare_utf16`].
    pub fn prepare(&self, query: &str) -> Result<Prepared<'_>, Error> {
        let query = U16String::from_str(query);
        self.prepare_utf16(&query)
    }

    /// Allocates a statement handle up front and wraps it in a [`Preallocated`].
    ///
    /// # Errors
    ///
    /// Fails if the statement handle cannot be allocated.
    pub fn preallocate(&self) -> Result<Preallocated<'_>, Error> {
        let stmt = self.connection.allocate_statement()?;
        Ok(Preallocated::new(stmt))
    }

    /// Forwards the autocommit setting (`enabled`) to the underlying connection handle.
    pub fn set_autocommit(&self, enabled: bool) -> Result<(), Error> {
        self.connection.set_autocommit(enabled)
    }

    /// Forwards a commit request to the underlying connection handle.
    pub fn commit(&self) -> Result<(), Error> {
        self.connection.commit()
    }

    /// Forwards a rollback request to the underlying connection handle.
    pub fn rollback(&self) -> Result<(), Error> {
        self.connection.rollback()
    }

    /// Returns the result of the underlying handle's `is_dead` query.
    pub fn is_dead(&self) -> Result<bool, Error> {
        self.connection.is_dead()
    }

    /// Consumes the connection and wraps it in `force_send_sync::Send`.
    ///
    /// # Safety
    ///
    /// The wrapper is created without any check performed here; the caller is responsible
    /// for it being sound to move this connection to another thread.
    /// NOTE(review): presumably this hinges on the thread-safety of the ODBC driver in use -
    /// confirm before relying on it.
    pub unsafe fn promote_to_send(self) -> force_send_sync::Send<Self> {
        force_send_sync::Send::new(self)
    }

    /// Asks the underlying handle to write the database management system name into `buf`
    /// as UTF-16 code units. Lets callers reuse one buffer across calls.
    pub fn fetch_database_management_system_name(&self, buf: &mut Vec<u16>) -> Result<(), Error> {
        self.connection.fetch_database_management_system_name(buf)
    }

    /// Returns the database management system name as a `String`.
    ///
    /// The UTF-16 data is decoded lossily: an ill-formed code unit becomes U+FFFD instead of
    /// causing a panic. Well-formed data is decoded exactly.
    ///
    /// # Errors
    ///
    /// Fails if fetching the name from the underlying handle fails.
    pub fn database_management_system_name(&self) -> Result<String, Error> {
        let mut buf = Vec::new();
        self.fetch_database_management_system_name(&mut buf)?;
        // The data originates from the driver, so do not assume it is valid UTF-16.
        Ok(String::from_utf16_lossy(&buf))
    }

    /// Maximum catalog name length as reported by the underlying handle, cast to `usize`.
    pub fn max_catalog_name_len(&self) -> Result<usize, Error> {
        self.connection.max_catalog_name_len().map(|v| v as usize)
    }

    /// Maximum schema name length as reported by the underlying handle, cast to `usize`.
    pub fn max_schema_name_len(&self) -> Result<usize, Error> {
        self.connection.max_schema_name_len().map(|v| v as usize)
    }

    /// Maximum table name length as reported by the underlying handle, cast to `usize`.
    pub fn max_table_name_len(&self) -> Result<usize, Error> {
        self.connection.max_table_name_len().map(|v| v as usize)
    }

    /// Maximum column name length as reported by the underlying handle, cast to `usize`.
    pub fn max_column_name_len(&self) -> Result<usize, Error> {
        self.connection.max_column_name_len().map(|v| v as usize)
    }

    /// Asks the underlying handle to write the current catalog name into `buf` as UTF-16
    /// code units. Lets callers reuse one buffer across calls.
    pub fn fetch_current_catalog(&self, buf: &mut Vec<u16>) -> Result<(), Error> {
        self.connection.fetch_current_catalog(buf)
    }

    /// Returns the current catalog name as a `String`.
    ///
    /// The UTF-16 data is decoded lossily: an ill-formed code unit becomes U+FFFD instead of
    /// causing a panic. Well-formed data is decoded exactly.
    ///
    /// # Errors
    ///
    /// Fails if fetching the catalog name from the underlying handle fails.
    pub fn current_catalog(&self) -> Result<String, Error> {
        let mut buf = Vec::new();
        self.fetch_current_catalog(&mut buf)?;
        // The data originates from the driver, so do not assume it is valid UTF-16.
        Ok(String::from_utf16_lossy(&buf))
    }

    /// Queries column metadata through the underlying handle.
    ///
    /// A fresh statement handle is allocated for the query and the four name arguments are
    /// converted to UTF-16 before being passed on.
    /// NOTE(review): whether the arguments are exact names or search patterns is decided by
    /// `handles::Connection::columns` - confirm there.
    ///
    /// # Errors
    ///
    /// Fails if the statement handle cannot be allocated or the underlying call fails.
    pub fn columns(
        &self,
        catalog_name: &str,
        schema_name: &str,
        table_name: &str,
        column_name: &str,
    ) -> Result<Option<CursorImpl<'_, StatementImpl<'_>>>, Error> {
        self.connection.columns(
            self.connection.allocate_statement()?,
            &U16String::from_str(catalog_name),
            &U16String::from_str(schema_name),
            &U16String::from_str(table_name),
            &U16String::from_str(column_name),
        )
    }

    /// Returns the buffer descriptions produced by the underlying handle for the given
    /// maximum lengths of the type name, remarks and column default fields.
    /// NOTE(review): presumably intended for fetching the result of [`Connection::columns`] -
    /// confirm against `handles::Connection::columns_buffer_description`.
    pub fn columns_buffer_description(
        &self,
        type_name_max_len: usize,
        remarks_max_len: usize,
        column_default_max_len: usize,
    ) -> Result<Vec<BufferDescription>, Error> {
        self.connection.columns_buffer_description(
            type_name_max_len,
            remarks_max_len,
            column_default_max_len,
        )
    }
}
/// Escapes `unescaped` so it can be appended as an attribute value (e.g. a password) when
/// building a connection string by hand.
///
/// A value is wrapped in curly braces, with every closing brace (`}`) inside it doubled, if
/// it either
///
/// * contains a semicolon (`;`), which would otherwise end the attribute, or
/// * starts with an opening brace (`{`), which would otherwise be read as the start of a
///   brace-quoted value.
///
/// Any other value is returned borrowed and unchanged, so the common case does not allocate.
pub fn escape_attribute_value(unescaped: &str) -> Cow<'_, str> {
    let needs_quoting = unescaped.starts_with('{') || unescaped.contains(';');
    if needs_quoting {
        // Inside a brace-quoted value a literal `}` is written as `}}`.
        let escaped = unescaped.replace('}', "}}");
        Cow::Owned(format!("{{{}}}", escaped))
    } else {
        Cow::Borrowed(unescaped)
    }
}