use super::{
OutputStringBuffer, SqlResult,
any_handle::AnyHandle,
buffer::mut_buf_ptr,
drop_handle,
sql_char::{
SqlChar, SqlText, binary_length, is_truncated_bin, resize_to_fit_with_tz,
resize_to_fit_without_tz,
},
sql_result::ExtSqlReturn,
statement::StatementImpl,
};
use log::trace;
use odbc_sys::{
CompletionType, ConnectionAttribute, DriverConnectOption, HDbc, HEnv, HWnd, Handle, HandleType,
IS_UINTEGER, InfoType, Pointer, SQLAllocHandle, SQLDisconnect, SQLEndTran,
};
use std::{cmp::max, ffi::c_void, marker::PhantomData, mem::size_of, ptr::null_mut};
#[cfg(not(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows"))))]
use odbc_sys::{
SQLConnect as sql_connect, SQLDriverConnect as sql_driver_connect,
SQLGetConnectAttr as sql_get_connect_attr, SQLGetInfo as sql_get_info,
SQLSetConnectAttr as sql_set_connect_attr,
};
#[cfg(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows")))]
use odbc_sys::{
SQLConnectW as sql_connect, SQLDriverConnectW as sql_driver_connect,
SQLGetConnectAttrW as sql_get_connect_attr, SQLGetInfoW as sql_get_info,
SQLSetConnectAttrW as sql_set_connect_attr,
};
/// Owning wrapper around a raw ODBC connection handle (`HDbc`).
///
/// The lifetime `'c` is carried by a borrow of an environment handle (`HEnv`), so a
/// `Connection<'c>` cannot outlive `'c` — presumably the environment the connection was
/// allocated from.
pub struct Connection<'c> {
    /// Raw connection handle; released when the `Connection` is dropped.
    handle: HDbc,
    /// Zero-sized marker tying this value to the lifetime `'c`.
    parent: PhantomData<&'c HEnv>,
}
unsafe impl AnyHandle for Connection<'_> {
fn as_handle(&self) -> Handle {
self.handle.as_handle()
}
fn handle_type(&self) -> HandleType {
HandleType::Dbc
}
}
impl Drop for Connection<'_> {
    fn drop(&mut self) {
        // Release the ODBC connection handle owned by this value.
        let raw = self.handle.as_handle();
        // SAFETY: a `Connection` owns its handle and `drop` runs at most once, so the handle is
        // released exactly once and never used afterwards.
        unsafe {
            drop_handle(raw, HandleType::Dbc);
        }
    }
}

// SAFETY: the connection exclusively owns its handle; moving the value to another thread moves
// that ownership along with it.
unsafe impl Send for Connection<'_> {}
impl Connection<'_> {
    /// Wraps an already allocated ODBC connection handle.
    ///
    /// # Safety
    ///
    /// `handle` must be a valid, allocated connection handle that is owned by nothing else. The
    /// returned value takes over ownership and frees the handle via `drop_handle` when dropped.
    pub unsafe fn new(handle: HDbc) -> Self {
        Self {
            handle,
            parent: PhantomData,
        }
    }

    /// Raw ODBC connection handle. Ownership stays with `self`.
    pub fn as_sys(&self) -> HDbc {
        self.handle
    }

    /// Connects to a data source by name using `SQLConnect`.
    ///
    /// # Panics
    ///
    /// If the character length of `data_source_name`, `user` or `pwd` does not fit into the
    /// length parameter type of `SQLConnect`.
    pub fn connect(
        &mut self,
        data_source_name: &SqlText,
        user: &SqlText,
        pwd: &SqlText,
    ) -> SqlResult<()> {
        // SAFETY: every text pointer is passed together with its own length, and all borrowed
        // arguments outlive the call.
        unsafe {
            sql_connect(
                self.handle,
                data_source_name.ptr(),
                data_source_name.len_char().try_into().unwrap(),
                user.ptr(),
                user.len_char().try_into().unwrap(),
                pwd.ptr(),
                pwd.len_char().try_into().unwrap(),
            )
            .into_sql_result("SQLConnect")
        }
    }

    /// Connects using a connection string and never prompts the user.
    ///
    /// The completed connection string reported by the driver is discarded.
    pub fn connect_with_connection_string(&mut self, connection_string: &SqlText) -> SqlResult<()> {
        // SAFETY: with `NoPrompt` no dialog is shown, so passing no parent window (null) is
        // valid. The empty output buffer lives until the call returns.
        unsafe {
            let mut completed_connection_string = OutputStringBuffer::empty();
            // `driver_connect` already yields `SqlResult<()>`; no mapping needed.
            self.driver_connect(
                connection_string,
                null_mut(),
                &mut completed_connection_string,
                DriverConnectOption::NoPrompt,
            )
        }
    }

    /// Connects via `SQLDriverConnect`.
    ///
    /// The driver writes the completed connection string into `completed_connection_string`.
    ///
    /// # Safety
    ///
    /// `parent_window` is handed to the driver manager unchecked. It must be null or a window
    /// handle that is valid for the chosen `driver_completion` mode.
    ///
    /// # Panics
    ///
    /// If the character length of `connection_string` does not fit into the length parameter
    /// type of `SQLDriverConnect`.
    pub unsafe fn driver_connect(
        &mut self,
        connection_string: &SqlText,
        parent_window: HWnd,
        completed_connection_string: &mut OutputStringBuffer,
        driver_completion: DriverConnectOption,
    ) -> SqlResult<()> {
        // SAFETY: the input text is passed with its length, and the output buffer is passed
        // with its own pointer, capacity and length slot. The window contract is forwarded to
        // the caller.
        unsafe {
            sql_driver_connect(
                self.handle,
                parent_window,
                connection_string.ptr(),
                connection_string.len_char().try_into().unwrap(),
                completed_connection_string.mut_buf_ptr(),
                completed_connection_string.buf_len(),
                completed_connection_string.mut_actual_len_ptr(),
                driver_completion,
            )
            .into_sql_result("SQLDriverConnect")
        }
    }

    /// Closes the connection (`SQLDisconnect`) but keeps the handle allocated.
    pub fn disconnect(&mut self) -> SqlResult<()> {
        // SAFETY: `self.handle` is a valid connection handle owned by `self`.
        unsafe { SQLDisconnect(self.handle).into_sql_result("SQLDisconnect") }
    }

    /// Allocates a new statement handle on this connection.
    ///
    /// The returned statement borrows `self`, so it cannot outlive the connection.
    pub fn allocate_statement(&self) -> SqlResult<StatementImpl<'_>> {
        let mut out = Handle::null();
        // SAFETY: `out` is only turned into a statement if `SQLAllocHandle` succeeded, in which
        // case it holds a freshly allocated statement handle.
        unsafe {
            SQLAllocHandle(HandleType::Stmt, self.as_handle(), &mut out)
                .into_sql_result("SQLAllocHandle")
                .on_success(|| StatementImpl::new(out.as_hstmt()))
        }
    }

    /// Enables or disables autocommit mode for this connection.
    pub fn set_autocommit(&self, enabled: bool) -> SqlResult<()> {
        // SAFETY: the attribute carries its integer value inside the pointer argument and
        // declares it with `IS_UINTEGER`.
        unsafe { self.set_attribute(AutocommitConnectionAttribute(enabled)) }
    }

    /// Sets the login timeout, in seconds, used by subsequent connection attempts.
    pub fn set_login_timeout_sec(&self, timeout: u32) -> SqlResult<()> {
        // SAFETY: see `set_autocommit`; the value is an integer passed with `IS_UINTEGER`.
        unsafe { self.set_attribute(LoginTimeoutConnectionAttribute(timeout)) }
    }

    /// Sets the packet size connection attribute.
    pub fn set_packet_size(&self, packet_size: u32) -> SqlResult<()> {
        // SAFETY: see `set_autocommit`; the value is an integer passed with `IS_UINTEGER`.
        unsafe { self.set_attribute(PacketSizeConnectionAttribute(packet_size)) }
    }

    /// Commits the current transaction of this connection (`SQLEndTran`).
    pub fn commit(&self) -> SqlResult<()> {
        // SAFETY: `self` owns a valid connection handle of kind `Dbc`.
        unsafe {
            SQLEndTran(HandleType::Dbc, self.as_handle(), CompletionType::Commit)
                .into_sql_result("SQLEndTran")
        }
    }

    /// Rolls back the current transaction of this connection (`SQLEndTran`).
    pub fn rollback(&self) -> SqlResult<()> {
        // SAFETY: `self` owns a valid connection handle of kind `Dbc`.
        unsafe {
            SQLEndTran(HandleType::Dbc, self.as_handle(), CompletionType::Rollback)
                .into_sql_result("SQLEndTran")
        }
    }

    /// Fetches the name of the database management system (`SQLGetInfo` / `DbmsName`) into
    /// `buf`.
    ///
    /// `buf` is reused as the output buffer. Its capacity, at least 64 characters, is tried
    /// first. If the name was truncated, the buffer is grown and queried a second time. On
    /// success `buf` holds exactly the name, without a terminating zero.
    pub fn fetch_database_management_system_name(&self, buf: &mut Vec<SqlChar>) -> SqlResult<()> {
        // Length reported by the driver, in bytes.
        let mut string_length_in_bytes: i16 = 0;
        // `SQLGetInfo` takes the buffer length in bytes as an `i16`. Cap the initial buffer so a
        // large caller-supplied capacity cannot make the `i16` conversion below panic.
        let max_chars = i16::MAX as usize / size_of::<SqlChar>();
        let buffer_size = max(buf.capacity(), 64).min(max_chars);
        buf.resize(buffer_size, 0);

        // SAFETY: the buffer pointer is always passed together with the buffer's current byte
        // length, and `string_length_in_bytes` outlives both calls.
        unsafe {
            let mut res = sql_get_info(
                self.handle,
                InfoType::DbmsName,
                mut_buf_ptr(buf) as Pointer,
                binary_length(buf).try_into().unwrap(),
                &mut string_length_in_bytes as *mut i16,
            )
            .into_sql_result("SQLGetInfo");

            if res.is_err() {
                return res;
            }

            // The value did not fit: grow to the reported length (plus terminating zero) and
            // query again.
            if is_truncated_bin(buf, string_length_in_bytes.try_into().unwrap()) {
                resize_to_fit_with_tz(buf, string_length_in_bytes.try_into().unwrap());
                res = sql_get_info(
                    self.handle,
                    InfoType::DbmsName,
                    mut_buf_ptr(buf) as Pointer,
                    binary_length(buf).try_into().unwrap(),
                    &mut string_length_in_bytes as *mut i16,
                )
                .into_sql_result("SQLGetInfo");

                if res.is_err() {
                    return res;
                }
            }

            // Shrink to the exact string, dropping the terminating zero.
            resize_to_fit_without_tz(buf, string_length_in_bytes.try_into().unwrap());
            res
        }
    }

    /// Queries a 16-bit unsigned piece of driver information via `SQLGetInfo`.
    fn info_u16(&self, info_type: InfoType) -> SqlResult<u16> {
        let mut value = 0u16;
        // SAFETY: `value` outlives the call, and the buffer length passed matches its size.
        unsafe {
            sql_get_info(
                self.handle,
                info_type,
                &mut value as *mut u16 as Pointer,
                // Size of the value itself, not of a pointer to it.
                size_of::<u16>() as i16,
                null_mut(),
            )
            .into_sql_result("SQLGetInfo")
            .on_success(|| value)
        }
    }

    /// Maximum length of a catalog name as reported by the driver (`MaxCatalogNameLen`).
    pub fn max_catalog_name_len(&self) -> SqlResult<u16> {
        self.info_u16(InfoType::MaxCatalogNameLen)
    }

    /// Maximum length of a schema name as reported by the driver (`MaxSchemaNameLen`).
    pub fn max_schema_name_len(&self) -> SqlResult<u16> {
        self.info_u16(InfoType::MaxSchemaNameLen)
    }

    /// Maximum length of a table name as reported by the driver (`MaxTableNameLen`).
    pub fn max_table_name_len(&self) -> SqlResult<u16> {
        self.info_u16(InfoType::MaxTableNameLen)
    }

    /// Maximum length of a column name as reported by the driver (`MaxColumnNameLen`).
    pub fn max_column_name_len(&self) -> SqlResult<u16> {
        self.info_u16(InfoType::MaxColumnNameLen)
    }

    /// Fetches the current catalog (`CURRENT_CATALOG` attribute) into `buffer`.
    ///
    /// The whole capacity of `buffer` is tried first. If the value was truncated, the buffer is
    /// grown and the attribute is queried again. On success `buffer` holds exactly the catalog
    /// name, without a terminating zero.
    pub fn fetch_current_catalog(&self, buffer: &mut Vec<SqlChar>) -> SqlResult<()> {
        // Length reported by the driver, in bytes.
        let mut string_length_in_bytes: i32 = 0;
        // Make all of the existing capacity available to the driver.
        buffer.resize(buffer.capacity(), 0);

        // SAFETY: the buffer pointer is always passed together with the buffer's current byte
        // length, and `string_length_in_bytes` outlives both calls.
        unsafe {
            let mut res = sql_get_connect_attr(
                self.handle,
                ConnectionAttribute::CURRENT_CATALOG,
                mut_buf_ptr(buffer) as Pointer,
                binary_length(buffer).try_into().unwrap(),
                &mut string_length_in_bytes as *mut i32,
            )
            .into_sql_result("SQLGetConnectAttr");

            if res.is_err() {
                return res;
            }

            // The value did not fit: grow to the reported length (plus terminating zero) and
            // query again.
            if is_truncated_bin(buffer, string_length_in_bytes.try_into().unwrap()) {
                resize_to_fit_with_tz(buffer, string_length_in_bytes.try_into().unwrap());
                res = sql_get_connect_attr(
                    self.handle,
                    ConnectionAttribute::CURRENT_CATALOG,
                    mut_buf_ptr(buffer) as Pointer,
                    binary_length(buffer).try_into().unwrap(),
                    &mut string_length_in_bytes as *mut i32,
                )
                .into_sql_result("SQLGetConnectAttr");
            }

            if res.is_err() {
                return res;
            }

            // Shrink to the exact string, dropping the terminating zero.
            resize_to_fit_without_tz(buffer, string_length_in_bytes.try_into().unwrap());
            res
        }
    }

    /// Asks the driver whether the connection is known to be dead (`CONNECTION_DEAD`).
    ///
    /// # Panics
    ///
    /// If the driver reports a value other than `0` or `1`.
    pub fn is_dead(&self) -> SqlResult<bool> {
        // SAFETY: `CONNECTION_DEAD` is read as a 32-bit unsigned integer.
        let raw = unsafe { self.attribute_u32(ConnectionAttribute::CONNECTION_DEAD) };
        raw.map(|v| match v {
            0 => false,
            1 => true,
            other => panic!("Unexpected result value from SQLGetConnectAttr: {other}"),
        })
    }

    /// Current packet size connection attribute as reported by the driver.
    pub fn packet_size(&self) -> SqlResult<u32> {
        // SAFETY: `PACKET_SIZE` is read as a 32-bit unsigned integer.
        unsafe { self.attribute_u32(ConnectionAttribute::PACKET_SIZE) }
    }

    /// Sets a connection attribute via `SQLSetConnectAttr`.
    ///
    /// # Safety
    ///
    /// The pointer and length provided by `attribute` are handed to the driver unchecked. They
    /// must be valid for the attribute being set.
    pub unsafe fn set_attribute(&self, attribute: impl SetConnectionAttribute) -> SqlResult<()> {
        // SAFETY: validity of value and length is guaranteed by the caller.
        unsafe {
            sql_set_connect_attr(
                self.handle,
                attribute.attribute(),
                attribute.value(),
                attribute.len(),
            )
            .into_sql_result("SQLSetConnectAttr")
        }
    }

    /// Reads a connection attribute as a 32-bit unsigned integer and traces the result.
    ///
    /// # Safety
    ///
    /// `attribute` must be one whose value the driver writes as a 32-bit unsigned integer. The
    /// driver is given a `u32` to write into, flagged with `IS_UINTEGER`.
    unsafe fn attribute_u32(&self, attribute: ConnectionAttribute) -> SqlResult<u32> {
        let mut out: u32 = 0;
        // SAFETY: `out` outlives the call; the attribute type contract is upheld by the caller.
        unsafe {
            sql_get_connect_attr(
                self.handle,
                attribute,
                &mut out as *mut u32 as *mut c_void,
                IS_UINTEGER,
                null_mut(),
            )
        }
        .into_sql_result("SQLGetConnectAttr")
        .on_success(|| {
            let handle = self.handle;
            #[cfg(not(feature = "structured_logging"))]
            trace!(
                "SQLGetConnectAttr called with attribute '{attribute:?}' for connection \
                '{handle:?}' reported '{out}'."
            );
            #[cfg(feature = "structured_logging")]
            trace!(
                target: "odbc_api",
                attribute:? = attribute,
                handle:? = handle,
                value = out;
                "Connection attribute queried"
            );
            out
        })
    }
}
/// A connection attribute together with the value to apply via `SQLSetConnectAttr`.
///
/// # Safety
///
/// `value` and `len` are forwarded to the driver unchecked. Implementors must return a pair
/// that is valid for the attribute named by `attribute`. For example, an integer travels inside
/// the pointer itself, with `len` set to `IS_UINTEGER`.
pub unsafe trait SetConnectionAttribute {
    /// Identifies the attribute to set (the `Attribute` argument).
    fn attribute(&self) -> ConnectionAttribute;

    /// Length or type indicator for `value` (the `StringLength` argument).
    fn len(&self) -> i32;

    /// The value to set (the `ValuePtr` argument).
    fn value(&self) -> Pointer;
}
/// Value for the `PACKET_SIZE` connection attribute.
struct PacketSizeConnectionAttribute(pub u32);

// SAFETY: the integer is passed by value inside the pointer argument and flagged with
// `IS_UINTEGER`, so no memory is dereferenced.
unsafe impl SetConnectionAttribute for PacketSizeConnectionAttribute {
    fn len(&self) -> i32 {
        IS_UINTEGER
    }

    fn value(&self) -> Pointer {
        let Self(size) = *self;
        size as usize as Pointer
    }

    fn attribute(&self) -> ConnectionAttribute {
        ConnectionAttribute::PACKET_SIZE
    }
}
/// Value for the `LOGIN_TIMEOUT` connection attribute, in seconds.
struct LoginTimeoutConnectionAttribute(pub u32);

// SAFETY: the integer is passed by value inside the pointer argument and flagged with
// `IS_UINTEGER`, so no memory is dereferenced.
unsafe impl SetConnectionAttribute for LoginTimeoutConnectionAttribute {
    fn len(&self) -> i32 {
        IS_UINTEGER
    }

    fn value(&self) -> Pointer {
        let Self(seconds) = *self;
        seconds as usize as Pointer
    }

    fn attribute(&self) -> ConnectionAttribute {
        ConnectionAttribute::LOGIN_TIMEOUT
    }
}
/// Value for the `AUTOCOMMIT` connection attribute: `true` maps to 1, `false` to 0.
struct AutocommitConnectionAttribute(pub bool);

// SAFETY: the flag is passed by value (0 or 1) inside the pointer argument and flagged with
// `IS_UINTEGER`, so no memory is dereferenced.
unsafe impl SetConnectionAttribute for AutocommitConnectionAttribute {
    fn len(&self) -> i32 {
        IS_UINTEGER
    }

    fn value(&self) -> Pointer {
        let Self(on) = *self;
        usize::from(on) as Pointer
    }

    fn attribute(&self) -> ConnectionAttribute {
        ConnectionAttribute::AUTOCOMMIT
    }
}