use super::{
as_handle::AsHandle,
buffer::mut_buf_ptr,
drop_handle,
sql_char::{
binary_length, is_truncated_bin, resize_to_fit_with_tz, resize_to_fit_without_tz, SqlChar,
SqlText,
},
sql_result::ExtSqlReturn,
statement::StatementImpl,
OutputStringBuffer, SqlResult,
};
use log::debug;
use odbc_sys::{
CompletionType, ConnectionAttribute, DriverConnectOption, HDbc, HEnv, HStmt, HWnd, Handle,
HandleType, InfoType, Pointer, SQLAllocHandle, SQLDisconnect, SQLEndTran, IS_UINTEGER,
};
use std::{ffi::c_void, marker::PhantomData, mem::size_of, ptr::null_mut};
#[cfg(feature = "narrow")]
use odbc_sys::{
SQLConnect as sql_connect, SQLDriverConnect as sql_driver_connect,
SQLGetConnectAttr as sql_get_connect_attr, SQLGetInfo as sql_get_info,
SQLSetConnectAttr as sql_set_connect_attr,
};
#[cfg(not(feature = "narrow"))]
use odbc_sys::{
SQLConnectW as sql_connect, SQLDriverConnectW as sql_driver_connect,
SQLGetConnectAttrW as sql_get_connect_attr, SQLGetInfoW as sql_get_info,
SQLSetConnectAttrW as sql_set_connect_attr,
};
/// Owning wrapper around an ODBC connection handle (`HDbc`).
///
/// The wrapped handle is released via `drop_handle` when this value is dropped. The lifetime
/// `'c` comes from a phantom borrow of an environment handle (`HEnv`); presumably this keeps a
/// connection from outliving the environment it was allocated from — NOTE(review): confirm
/// against the code that constructs connections.
pub struct Connection<'c> {
    /// Raw ODBC connection handle owned by this value.
    handle: HDbc,
    /// Zero-sized marker binding `'c` to a borrowed `HEnv`.
    parent: PhantomData<&'c HEnv>,
}
unsafe impl<'c> AsHandle for Connection<'c> {
    /// Connections always report the ODBC handle type `Dbc`.
    fn handle_type(&self) -> HandleType {
        HandleType::Dbc
    }

    /// Exposes the wrapped `HDbc` as a generic ODBC `Handle`.
    fn as_handle(&self) -> Handle {
        let dbc = self.handle;
        dbc as Handle
    }
}
impl<'c> Drop for Connection<'c> {
    /// Releases the underlying ODBC connection handle.
    fn drop(&mut self) {
        let handle = self.handle as Handle;
        unsafe {
            drop_handle(handle, HandleType::Dbc);
        }
    }
}
// SAFETY: NOTE(review): `HDbc` is presumably a raw pointer type, so the compiler does not
// derive `Send` automatically. This impl asserts that a `Connection` may be moved to another
// thread. Its soundness presumably relies on the ODBC driver manager allowing a connection
// handle to be used from a thread other than the one that allocated it — confirm. No `Sync`
// impl appears in this section of the file.
unsafe impl<'c> Send for Connection<'c> {}
impl<'c> Connection<'c> {
/// Wraps an already allocated ODBC connection handle.
///
/// # Safety
///
/// `handle` must be a valid ODBC connection handle. Every method of the returned value passes
/// it directly to ODBC functions, and `Drop` frees it, so ownership transfers to the returned
/// value. The caller must not free the handle itself.
pub unsafe fn new(handle: HDbc) -> Self {
    let parent = PhantomData;
    Self { parent, handle }
}
pub fn as_sys(&self) -> HDbc {
self.handle
}
/// Connects to the data source `data_source_name` using `user` and `pwd` (`SQLConnect`).
///
/// # Panics
///
/// If the length of any argument does not fit the ODBC length parameter type.
pub fn connect(
    &mut self,
    data_source_name: &SqlText,
    user: &SqlText,
    pwd: &SqlText,
) -> SqlResult<()> {
    // Convert the lengths in argument order, so a too-long argument panics before the FFI
    // call, exactly as when the conversions were inline arguments.
    let dsn_len = data_source_name.len_char().try_into().unwrap();
    let user_len = user.len_char().try_into().unwrap();
    let pwd_len = pwd.len_char().try_into().unwrap();
    let ret = unsafe {
        sql_connect(
            self.handle,
            data_source_name.ptr(),
            dsn_len,
            user.ptr(),
            user_len,
            pwd.ptr(),
            pwd_len,
        )
    };
    ret.into_sql_result("SQLConnect")
}
/// Connects using a complete connection string and never prompts the user.
///
/// Delegates to `driver_connect` with a null parent window and
/// `DriverConnectOption::NoPrompt`. The completed connection string is collected into a
/// local, empty `OutputStringBuffer` and discarded.
pub fn connect_with_connection_string(&mut self, connection_string: &SqlText) -> SqlResult<()> {
    unsafe {
        // Scratch output buffer; the completed connection string is not returned.
        let mut completed = OutputStringBuffer::empty();
        let no_parent_window = null_mut();
        // `driver_connect` already yields `SqlResult<()>`, so no further mapping is needed.
        self.driver_connect(
            connection_string,
            no_parent_window,
            &mut completed,
            DriverConnectOption::NoPrompt,
        )
    }
}
/// Connects via `SQLDriverConnect` and writes the completed connection string into
/// `completed_connection_string`.
///
/// # Safety
///
/// `parent_window` is passed to the driver unchecked. NOTE(review): it presumably must be
/// null or a valid window handle, depending on `driver_completion`; confirm against the ODBC
/// `SQLDriverConnect` documentation.
///
/// # Panics
///
/// If the length of `connection_string` does not fit the ODBC length parameter type.
pub unsafe fn driver_connect(
    &mut self,
    connection_string: &SqlText,
    parent_window: HWnd,
    completed_connection_string: &mut OutputStringBuffer,
    driver_completion: DriverConnectOption,
) -> SqlResult<()> {
    let in_len = connection_string.len_char().try_into().unwrap();
    // Output buffer pointer, its capacity, and where the driver stores the actual length.
    let out_buf = completed_connection_string.mut_buf_ptr();
    let out_capacity = completed_connection_string.buf_len();
    let out_len = completed_connection_string.mut_actual_len_ptr();
    sql_driver_connect(
        self.handle,
        parent_window,
        connection_string.ptr(),
        in_len,
        out_buf,
        out_capacity,
        out_len,
        driver_completion,
    )
    .into_sql_result("SQLDriverConnect")
}
/// Closes the connection (`SQLDisconnect`). The handle itself is only freed on drop.
pub fn disconnect(&mut self) -> SqlResult<()> {
    let ret = unsafe { SQLDisconnect(self.handle) };
    ret.into_sql_result("SQLDisconnect")
}
pub fn allocate_statement(&self) -> SqlResult<StatementImpl<'_>> {
let mut out = null_mut();
unsafe {
SQLAllocHandle(HandleType::Stmt, self.as_handle(), &mut out)
.into_sql_result("SQLAllocHandle")
.on_success(|| StatementImpl::new(out as HStmt))
}
}
/// Turns autocommit mode on (`true`) or off (`false`) for this connection.
pub fn set_autocommit(&self, enabled: bool) -> SqlResult<()> {
    // Integer attribute values are passed in the pointer argument itself.
    let value = u32::from(enabled);
    unsafe {
        sql_set_connect_attr(
            self.handle,
            ConnectionAttribute::AutoCommit,
            value as Pointer,
            // The string length argument is ignored for integer attributes.
            0,
        )
        .into_sql_result("SQLSetConnectAttr")
    }
}
/// Sets the `LoginTimeout` connection attribute to `timeout` seconds.
pub fn set_login_timeout_sec(&self, timeout: u32) -> SqlResult<()> {
    // Integer attribute values are passed in the pointer argument itself.
    let value = timeout as Pointer;
    unsafe {
        sql_set_connect_attr(self.handle, ConnectionAttribute::LoginTimeout, value, 0)
            .into_sql_result("SQLSetConnectAttr")
    }
}
/// Sets the `PacketSize` connection attribute to `packet_size`.
pub fn set_packet_size(&self, packet_size: u32) -> SqlResult<()> {
    // Integer attribute values are passed in the pointer argument itself.
    let value = packet_size as Pointer;
    unsafe {
        sql_set_connect_attr(self.handle, ConnectionAttribute::PacketSize, value, 0)
            .into_sql_result("SQLSetConnectAttr")
    }
}
pub fn commit(&self) -> SqlResult<()> {
unsafe {
SQLEndTran(HandleType::Dbc, self.as_handle(), CompletionType::Commit)
.into_sql_result("SQLEndTran")
}
}
pub fn rollback(&self) -> SqlResult<()> {
unsafe {
SQLEndTran(HandleType::Dbc, self.as_handle(), CompletionType::Rollback)
.into_sql_result("SQLEndTran")
}
}
/// Fetches the DBMS name (`SQLGetInfo` with `InfoType::DbmsName`) into `buf`.
///
/// The first attempt uses all capacity `buf` already has. If the reported length shows the
/// value was truncated, `buf` is grown and the query is repeated once. On success `buf` is
/// resized to the reported length (presumably without a terminating zero, per the
/// `resize_to_fit_without_tz` helper).
///
/// # Panics
///
/// If a buffer length or the reported length cannot be converted between the integer types
/// involved.
pub fn fetch_database_management_system_name(&self, buf: &mut Vec<SqlChar>) -> SqlResult<()> {
    let mut len_in_bytes: i16 = 0;
    buf.resize(buf.capacity(), 0);
    unsafe {
        // One `SQLGetInfo(DbmsName)` round trip into `target`, storing the reported length.
        let query = |target: &mut Vec<SqlChar>, reported: &mut i16| {
            sql_get_info(
                self.handle,
                InfoType::DbmsName,
                mut_buf_ptr(target) as Pointer,
                binary_length(target).try_into().unwrap(),
                reported as *mut i16,
            )
            .into_sql_result("SQLGetInfo")
        };

        let first = query(buf, &mut len_in_bytes);
        if first.is_err() {
            return first;
        }
        let res = if is_truncated_bin(buf, len_in_bytes.try_into().unwrap()) {
            // Too small: grow to fit the reported length and ask again.
            resize_to_fit_with_tz(buf, len_in_bytes.try_into().unwrap());
            let retry = query(buf, &mut len_in_bytes);
            if retry.is_err() {
                return retry;
            }
            retry
        } else {
            first
        };
        resize_to_fit_without_tz(buf, len_in_bytes.try_into().unwrap());
        res
    }
}
/// Queries a `u16`-valued piece of information via `SQLGetInfo`.
fn info_u16(&self, info_type: InfoType) -> SqlResult<u16> {
    let mut value = 0u16;
    unsafe {
        sql_get_info(
            self.handle,
            info_type,
            &mut value as *mut u16 as Pointer,
            // The output buffer is exactly one `u16`, so report that size. The previous
            // `size_of::<*mut u16>()` reported the size of a pointer, overstating the buffer
            // and allowing a driver that honours `BufferLength` to write past `value`.
            size_of::<u16>() as i16,
            // The returned length is not needed for a fixed-size value.
            null_mut(),
        )
        .into_sql_result("SQLGetInfo")
        .on_success(|| value)
    }
}
pub fn max_catalog_name_len(&self) -> SqlResult<u16> {
self.info_u16(InfoType::MaxCatalogNameLen)
}
pub fn max_schema_name_len(&self) -> SqlResult<u16> {
self.info_u16(InfoType::MaxSchemaNameLen)
}
pub fn max_table_name_len(&self) -> SqlResult<u16> {
self.info_u16(InfoType::MaxTableNameLen)
}
pub fn max_column_name_len(&self) -> SqlResult<u16> {
self.info_u16(InfoType::MaxColumnNameLen)
}
/// Fetches the current catalog (`ConnectionAttribute::CurrentCatalog`) into `buffer`.
///
/// The first attempt uses all capacity `buffer` already has. If the reported length shows the
/// value was truncated, `buffer` is grown and the attribute is read a second time. On success
/// `buffer` is resized to the reported length (presumably without a terminating zero, per the
/// `resize_to_fit_without_tz` helper).
///
/// # Panics
///
/// If a buffer length or the reported length cannot be converted between the integer types
/// involved.
pub fn fetch_current_catalog(&self, buffer: &mut Vec<SqlChar>) -> SqlResult<()> {
    buffer.resize(buffer.capacity(), 0);
    let mut len_in_bytes: i32 = 0;
    unsafe {
        // Reads the attribute into `target` and stores the length the driver reports.
        let read_attr = |target: &mut Vec<SqlChar>, reported: &mut i32| {
            sql_get_connect_attr(
                self.handle,
                ConnectionAttribute::CurrentCatalog,
                mut_buf_ptr(target) as Pointer,
                binary_length(target).try_into().unwrap(),
                reported as *mut i32,
            )
            .into_sql_result("SQLGetConnectAttr")
        };

        let mut res = read_attr(buffer, &mut len_in_bytes);
        if res.is_err() {
            return res;
        }
        if is_truncated_bin(buffer, len_in_bytes.try_into().unwrap()) {
            // Too small: grow to fit the reported length and read again.
            resize_to_fit_with_tz(buffer, len_in_bytes.try_into().unwrap());
            res = read_attr(buffer, &mut len_in_bytes);
            if res.is_err() {
                return res;
            }
        }
        resize_to_fit_without_tz(buffer, len_in_bytes.try_into().unwrap());
        res
    }
}
/// Asks the driver whether the connection is dead (`ConnectionAttribute::ConnectionDead`).
///
/// # Panics
///
/// If the driver reports a value other than `0` or `1`.
pub fn is_dead(&self) -> SqlResult<bool> {
    let raw = unsafe { self.attribute_u32(ConnectionAttribute::ConnectionDead) };
    raw.map(|flag| match flag {
        1 => true,
        0 => false,
        other => panic!("Unexpected result value from SQLGetConnectAttr: {other}"),
    })
}
/// Reads the `PacketSize` connection attribute.
pub fn packet_size(&self) -> SqlResult<u32> {
    let attribute = ConnectionAttribute::PacketSize;
    unsafe { self.attribute_u32(attribute) }
}
/// Reads an integer-valued connection attribute via `SQLGetConnectAttr` and logs the value at
/// debug level.
///
/// # Safety
///
/// The driver writes into a 4-byte `u32`. NOTE(review): `attribute` presumably must be one
/// whose value is an unsigned 32-bit integer (`IS_UINTEGER`) — confirm for each caller.
unsafe fn attribute_u32(&self, attribute: ConnectionAttribute) -> SqlResult<u32> {
    let mut out: u32 = 0;
    let out_ptr: *mut c_void = (&mut out as *mut u32).cast();
    let ret = sql_get_connect_attr(self.handle, attribute, out_ptr, IS_UINTEGER, null_mut());
    ret.into_sql_result("SQLGetConnectAttr").on_success(|| {
        let handle = self.handle;
        debug!(
            "SQLGetConnectAttr called with attribute '{attribute:?}' for connection \
             '{handle:?}' reported '{out}'."
        );
        out
    })
}
}