use super::{
CData, Descriptor, SqlChar, SqlResult, SqlText,
any_handle::AnyHandle,
bind::{CDataMut, DelayedInput, HasDataType},
buffer::{clamp_small_int, mut_buf_ptr},
column_description::{ColumnDescription, Nullability},
data_type::DataType,
drop_handle,
sql_char::{binary_length, is_truncated_bin, resize_to_fit_without_tz},
sql_result::ExtSqlReturn,
};
use log::trace;
use odbc_sys::{
Desc, FreeStmtOption, HDbc, HDesc, HStmt, Handle, HandleType, IS_POINTER, Len, ParamType,
Pointer, SQLBindCol, SQLBindParameter, SQLCloseCursor, SQLDescribeParam, SQLExecute, SQLFetch,
SQLFreeStmt, SQLGetData, SQLMoreResults, SQLNumParams, SQLNumResultCols, SQLParamData,
SQLPutData, SQLRowCount, SqlDataType, SqlReturn, StatementAttribute,
};
use std::{
ffi::c_void,
marker::PhantomData,
mem::ManuallyDrop,
num::NonZeroUsize,
ptr::{null, null_mut},
};
#[cfg(feature = "odbc_version_3_80")]
use odbc_sys::SQLCompleteAsync;
#[cfg(not(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows"))))]
use odbc_sys::{
SQLColAttribute as sql_col_attribute, SQLColumns as sql_columns,
SQLDescribeCol as sql_describe_col, SQLExecDirect as sql_exec_direc,
SQLForeignKeys as sql_foreign_keys, SQLGetStmtAttr as sql_get_stmt_attr,
SQLPrepare as sql_prepare, SQLPrimaryKeys as sql_primary_keys,
SQLSetStmtAttr as sql_set_stmt_attr, SQLTables as sql_tables,
};
#[cfg(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows")))]
use odbc_sys::{
SQLColAttributeW as sql_col_attribute, SQLColumnsW as sql_columns,
SQLDescribeColW as sql_describe_col, SQLExecDirectW as sql_exec_direc,
SQLForeignKeysW as sql_foreign_keys, SQLGetStmtAttrW as sql_get_stmt_attr,
SQLPrepareW as sql_prepare, SQLPrimaryKeysW as sql_primary_keys,
SQLSetStmtAttrW as sql_set_stmt_attr, SQLTablesW as sql_tables,
};
/// Owning wrapper around a raw ODBC statement handle (`HStmt`).
///
/// The handle is passed to `drop_handle` when the value is dropped. The lifetime `'s` is carried
/// only by a `PhantomData<&'s HDbc>` marker; no connection handle is stored in the struct.
#[derive(Debug)]
pub struct StatementImpl<'s> {
// Zero-sized marker tying this statement to the lifetime `'s` of a borrowed `HDbc`.
parent: PhantomData<&'s HDbc>,
// Raw statement handle handed to the ODBC functions.
handle: HStmt,
}
unsafe impl AnyHandle for StatementImpl<'_> {
    /// Exposes the owned statement handle as a generic ODBC `Handle`.
    fn as_handle(&self) -> Handle {
        let stmt = self.handle;
        stmt.as_handle()
    }

    /// Always reports `HandleType::Stmt`.
    fn handle_type(&self) -> HandleType {
        HandleType::Stmt
    }
}
impl Drop for StatementImpl<'_> {
    /// Hands the owned statement handle to `drop_handle`, tagged as a statement handle.
    fn drop(&mut self) {
        let raw = self.handle.as_handle();
        unsafe {
            drop_handle(raw, HandleType::Stmt);
        }
    }
}
impl StatementImpl<'_> {
    /// Wraps a raw statement handle and takes ownership of it.
    ///
    /// # Safety
    ///
    /// The wrapped handle is passed to `drop_handle` when the returned value is dropped.
    /// NOTE(review): presumably `handle` must be a valid, allocated statement handle that nothing
    /// else frees — confirm against the call sites.
    pub unsafe fn new(handle: HStmt) -> Self {
        let parent = PhantomData;
        Self { parent, handle }
    }

    /// Gives up ownership and returns the raw handle without freeing it.
    pub fn into_sys(self) -> HStmt {
        // `ManuallyDrop` suppresses the `Drop` impl, so the handle is not released here.
        let this = ManuallyDrop::new(self);
        this.handle
    }

    /// Creates a non-owning `StatementRef` that borrows `self` mutably for its lifetime.
    pub fn as_stmt_ref(&mut self) -> StatementRef<'_> {
        let parent = self.parent;
        let handle = self.handle;
        StatementRef { parent, handle }
    }
}
// Allows a `StatementImpl` to be moved to another thread.
// NOTE(review): soundness presumably relies on the ODBC driver tolerating a statement handle
// being used from a thread other than the one that allocated it — confirm.
unsafe impl Send for StatementImpl<'_> {}
/// Non-owning view of a raw ODBC statement handle.
///
/// No `Drop` impl is visible for this type in this file, so dropping a `StatementRef` does not
/// release the handle. The lifetime `'s` is carried only by the `PhantomData` marker.
#[derive(Debug)]
pub struct StatementRef<'s> {
// Zero-sized marker carrying the lifetime `'s`.
parent: PhantomData<&'s HDbc>,
// Raw statement handle handed to the ODBC functions.
handle: HStmt,
}
impl StatementRef<'_> {
    /// Wraps a raw statement handle without taking ownership of it.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably `handle` must stay valid for as long as the returned value is
    /// used; the lifetime is chosen by the caller and not checked — confirm.
    pub(crate) unsafe fn new(handle: HStmt) -> Self {
        let parent = PhantomData;
        Self { parent, handle }
    }
}
impl Statement for StatementRef<'_> {
    /// Returns a copy of the wrapped raw statement handle.
    fn as_sys(&self) -> HStmt {
        let Self { handle, .. } = self;
        *handle
    }
}
unsafe impl AnyHandle for StatementRef<'_> {
    /// Exposes the referenced statement handle as a generic ODBC `Handle`.
    fn as_handle(&self) -> Handle {
        let stmt = self.handle;
        stmt.as_handle()
    }

    /// Always reports `HandleType::Stmt`.
    fn handle_type(&self) -> HandleType {
        HandleType::Stmt
    }
}
// Allows a `StatementRef` to be moved to another thread.
// NOTE(review): soundness presumably relies on the ODBC driver tolerating a statement handle
// being used from a different thread — confirm.
unsafe impl Send for StatementRef<'_> {}
/// Implemented by types that can hand out a `StatementRef` to a statement handle they hold.
pub trait AsStatementRef {
/// Returns a `StatementRef` whose lifetime is bound to the mutable borrow of `self`.
fn as_stmt_ref(&mut self) -> StatementRef<'_>;
}
impl AsStatementRef for StatementImpl<'_> {
fn as_stmt_ref(&mut self) -> StatementRef<'_> {
self.as_stmt_ref()
}
}
impl AsStatementRef for &mut StatementImpl<'_> {
fn as_stmt_ref(&mut self) -> StatementRef<'_> {
(*self).as_stmt_ref()
}
}
impl AsStatementRef for StatementRef<'_> {
fn as_stmt_ref(&mut self) -> StatementRef<'_> {
unsafe { StatementRef::new(self.handle) }
}
}
pub trait Statement: AnyHandle {
/// Returns the raw statement handle that the provided methods of this trait pass to ODBC.
fn as_sys(&self) -> HStmt;
/// Binds `target` as the buffer for result column `column_number` by calling `SQLBindCol`.
///
/// C data type, value pointer, buffer length and indicator pointer are all taken from `target`.
///
/// # Safety
///
/// Raw pointers into `target` are handed to the driver.
/// NOTE(review): presumably `target` must stay valid and unmoved for as long as the binding is
/// in use — confirm against callers.
unsafe fn bind_col(&mut self, column_number: u16, target: &mut impl CDataMut) -> SqlResult<()> {
    let c_type = target.cdata_type();
    let value_ptr = target.mut_value_ptr();
    let buffer_len = target.buffer_length();
    let indicator_ptr = target.mut_indicator_ptr();
    let ret = unsafe {
        SQLBindCol(
            self.as_sys(),
            column_number,
            c_type,
            value_ptr,
            buffer_len,
            indicator_ptr,
        )
    };
    ret.into_sql_result("SQLBindCol")
}
/// Calls `SQLFetch` on this statement and converts the return code into a `SqlResult`.
///
/// # Safety
///
/// NOTE(review): presumably any buffers previously bound to this statement must still be valid,
/// since the driver may write to them — confirm.
unsafe fn fetch(&mut self) -> SqlResult<()> {
    let ret = unsafe { SQLFetch(self.as_sys()) };
    ret.into_sql_result("SQLFetch")
}
/// Calls `SQLGetData` for `col_or_param_num`, using the C type, value pointer, buffer length and
/// indicator pointer supplied by `target`.
fn get_data(&mut self, col_or_param_num: u16, target: &mut impl CDataMut) -> SqlResult<()> {
    let ret = unsafe {
        SQLGetData(
            self.as_sys(),
            col_or_param_num,
            target.cdata_type(),
            target.mut_value_ptr(),
            target.buffer_length(),
            target.mut_indicator_ptr(),
        )
    };
    ret.into_sql_result("SQLGetData")
}
/// Calls `SQLFreeStmt` with `FreeStmtOption::Unbind` on this statement.
fn unbind_cols(&mut self) -> SqlResult<()> {
    let ret = unsafe { SQLFreeStmt(self.as_sys(), FreeStmtOption::Unbind) };
    ret.into_sql_result("SQLFreeStmt")
}
/// Sets the `RowsFetchedPtr` statement attribute to the address of `num_rows`.
///
/// # Safety
///
/// The address of `num_rows` is stored in the statement.
/// NOTE(review): presumably `num_rows` must outlive the attribute setting, i.e. stay valid until
/// the pointer is unset or the statement is no longer used — confirm.
unsafe fn set_num_rows_fetched(&mut self, num_rows: &mut usize) -> SqlResult<()> {
    let rows_fetched_ptr: *mut usize = num_rows;
    let ret = unsafe {
        sql_set_stmt_attr(
            self.as_sys(),
            StatementAttribute::RowsFetchedPtr,
            rows_fetched_ptr as Pointer,
            IS_POINTER,
        )
    };
    ret.into_sql_result("SQLSetStmtAttr")
}
/// Sets the `QueryTimeout` statement attribute to `timeout_sec`.
fn set_query_timeout_sec(&mut self, timeout_sec: usize) -> SqlResult<()> {
    // The integer value itself travels in the pointer-typed argument; it is never dereferenced
    // on this side of the call.
    let attribute_value = timeout_sec as Pointer;
    let ret = unsafe {
        sql_set_stmt_attr(
            self.as_sys(),
            StatementAttribute::QueryTimeout,
            attribute_value,
            0,
        )
    };
    ret.into_sql_result("SQLSetStmtAttr")
}
/// Reads the `QueryTimeout` statement attribute and returns it on success.
fn query_timeout_sec(&mut self) -> SqlResult<usize> {
    let mut timeout_sec = 0usize;
    let out_ptr = &mut timeout_sec as *mut usize as Pointer;
    let ret = unsafe {
        sql_get_stmt_attr(
            self.as_sys(),
            StatementAttribute::QueryTimeout,
            out_ptr,
            0,
            null_mut(),
        )
    };
    ret.into_sql_result("SQLGetStmtAttr")
        .on_success(|| timeout_sec)
}
/// Sets the `RowsFetchedPtr` statement attribute to a null pointer.
fn unset_num_rows_fetched(&mut self) -> SqlResult<()> {
    let ret = unsafe {
        sql_set_stmt_attr(
            self.as_sys(),
            StatementAttribute::RowsFetchedPtr,
            null_mut(),
            IS_POINTER,
        )
    };
    ret.into_sql_result("SQLSetStmtAttr")
}
/// Queries `SQLDescribeCol` for result column `column_number` and stores name, data type and
/// nullability in `column_description`.
///
/// If the reported name length plus one extra element (presumably the terminating zero) does not
/// fit into the name buffer, the buffer is enlarged and the method calls itself again. If the
/// ODBC call fails, its result is returned right away; the name buffer has already been resized
/// to its capacity at that point.
fn describe_col(
&mut self,
column_number: u16,
column_description: &mut ColumnDescription,
) -> SqlResult<()> {
let name = &mut column_description.name;
// Expose the whole allocated capacity of the name buffer to the driver.
name.resize(name.capacity(), 0);
let mut name_length: i16 = 0;
let mut data_type = SqlDataType::UNKNOWN_TYPE;
let mut column_size = 0;
let mut decimal_digits = 0;
let mut nullable = odbc_sys::Nullability::UNKNOWN;
let res = unsafe {
sql_describe_col(
self.as_sys(),
column_number,
mut_buf_ptr(name),
clamp_small_int(name.len()),
&mut name_length,
&mut data_type,
&mut column_size,
&mut decimal_digits,
&mut nullable,
)
.into_sql_result("SQLDescribeCol")
};
if res.is_err() {
return res;
}
// Nullability is written before the retry check, so it is set on both paths below.
column_description.nullability = Nullability::new(nullable);
// NOTE(review): `name_length + 1` is computed in `i16`; a reported length of `i16::MAX` would
// overflow here — confirm whether drivers can report such a value.
if name_length + 1 > clamp_small_int(name.len()) {
// Buffer too small for the reported length plus one element: grow it and query again.
name.resize(name_length as usize + 1, 0);
self.describe_col(column_number, column_description)
} else {
// Shrink the buffer to exactly the reported name length.
name.resize(name_length as usize, 0);
column_description.data_type = DataType::new(data_type, column_size, decimal_digits);
res
}
}
/// Passes `statement` to the driver's `SQLExecDirect` entry point.
///
/// # Panics
///
/// Panics if the character length of `statement` does not fit the length parameter type.
///
/// # Safety
///
/// NOTE(review): presumably any parameter or column buffers bound to this statement must be
/// valid while the driver executes it — confirm.
unsafe fn exec_direct(&mut self, statement: &SqlText) -> SqlResult<()> {
    let ret = unsafe {
        let text_ptr = statement.ptr();
        let text_len = statement.len_char().try_into().unwrap();
        sql_exec_direc(self.as_sys(), text_ptr, text_len)
    };
    ret.into_sql_result("SQLExecDirect")
}
/// Calls `SQLCloseCursor` on this statement.
fn close_cursor(&mut self) -> SqlResult<()> {
    let ret = unsafe { SQLCloseCursor(self.as_sys()) };
    ret.into_sql_result("SQLCloseCursor")
}
/// Passes `statement` to the driver's `SQLPrepare` entry point.
///
/// # Panics
///
/// Panics if the character length of `statement` does not fit the length parameter type.
fn prepare(&mut self, statement: &SqlText) -> SqlResult<()> {
    let ret = unsafe {
        let text_ptr = statement.ptr();
        let text_len = statement.len_char().try_into().unwrap();
        sql_prepare(self.as_sys(), text_ptr, text_len)
    };
    ret.into_sql_result("SQLPrepare")
}
/// Calls `SQLExecute` on this statement.
///
/// # Safety
///
/// NOTE(review): presumably all buffers bound to this statement must be valid while the driver
/// executes it — confirm.
unsafe fn execute(&mut self) -> SqlResult<()> {
    let ret = unsafe { SQLExecute(self.as_sys()) };
    ret.into_sql_result("SQLExecute")
}
/// Returns the column count reported by `SQLNumResultCols`.
fn num_result_cols(&mut self) -> SqlResult<i16> {
    let mut column_count = 0i16;
    let ret = unsafe { SQLNumResultCols(self.as_sys(), &mut column_count) };
    ret.into_sql_result("SQLNumResultCols")
        .on_success(|| column_count)
}
/// Returns the parameter count reported by `SQLNumParams`.
///
/// # Panics
///
/// Panics if the call succeeds but the reported count is negative.
fn num_params(&mut self) -> SqlResult<u16> {
    let mut param_count = 0i16;
    let ret = unsafe { SQLNumParams(self.as_sys(), &mut param_count) };
    ret.into_sql_result("SQLNumParams")
        .on_success(|| u16::try_from(param_count).unwrap())
}
/// Sets the `RowArraySize` statement attribute to `size`.
///
/// # Panics
///
/// Panics if `size` is zero.
///
/// # Safety
///
/// NOTE(review): presumably the bound column buffers must be able to hold `size` rows —
/// confirm.
unsafe fn set_row_array_size(&mut self, size: usize) -> SqlResult<()> {
    assert!(size > 0);
    // The integer value itself travels in the pointer-typed argument.
    let attribute_value = size as Pointer;
    let ret = unsafe {
        sql_set_stmt_attr(
            self.as_sys(),
            StatementAttribute::RowArraySize,
            attribute_value,
            0,
        )
    };
    ret.into_sql_result("SQLSetStmtAttr")
}
/// Reads the `RowArraySize` statement attribute and returns it on success.
fn row_array_size(&mut self) -> SqlResult<usize> {
    let mut array_size = 0usize;
    let out_ptr = &mut array_size as *mut usize as Pointer;
    let ret = unsafe {
        sql_get_stmt_attr(
            self.as_sys(),
            StatementAttribute::RowArraySize,
            out_ptr,
            0,
            null_mut(),
        )
    };
    ret.into_sql_result("SQLGetStmtAttr")
        .on_success(|| array_size)
}
/// Sets the `ParamsetSize` statement attribute to `size`.
///
/// # Panics
///
/// Panics if `size` is zero.
///
/// # Safety
///
/// NOTE(review): presumably the bound parameter buffers must hold at least `size` entries —
/// confirm.
unsafe fn set_paramset_size(&mut self, size: usize) -> SqlResult<()> {
    assert!(size > 0);
    // The integer value itself travels in the pointer-typed argument.
    let attribute_value = size as Pointer;
    let ret = unsafe {
        sql_set_stmt_attr(
            self.as_sys(),
            StatementAttribute::ParamsetSize,
            attribute_value,
            0,
        )
    };
    ret.into_sql_result("SQLSetStmtAttr")
}
/// Sets the `RowBindType` statement attribute to `row_size`.
///
/// # Safety
///
/// NOTE(review): presumably `row_size` must match the layout of the buffers bound to this
/// statement — confirm.
unsafe fn set_row_bind_type(&mut self, row_size: usize) -> SqlResult<()> {
    // The integer value itself travels in the pointer-typed argument.
    let attribute_value = row_size as Pointer;
    let ret = unsafe {
        sql_set_stmt_attr(
            self.as_sys(),
            StatementAttribute::RowBindType,
            attribute_value,
            0,
        )
    };
    ret.into_sql_result("SQLSetStmtAttr")
}
/// Sets the `MetadataId` statement attribute to `1` if `metadata_id` is true and to `0`
/// otherwise.
fn set_metadata_id(&mut self, metadata_id: bool) -> SqlResult<()> {
    let attribute_value = usize::from(metadata_id) as Pointer;
    let ret = unsafe {
        sql_set_stmt_attr(
            self.as_sys(),
            StatementAttribute::MetadataId,
            attribute_value,
            0,
        )
    };
    ret.into_sql_result("SQLSetStmtAttr")
}
/// Sets the `AsyncEnable` statement attribute to `1` if `on` is true and to `0` otherwise.
fn set_async_enable(&mut self, on: bool) -> SqlResult<()> {
    let attribute_value = usize::from(on) as Pointer;
    let ret = unsafe {
        sql_set_stmt_attr(
            self.as_sys(),
            StatementAttribute::AsyncEnable,
            attribute_value,
            0,
        )
    };
    ret.into_sql_result("SQLSetStmtAttr")
}
/// Binds `parameter` as an input parameter (`ParamType::Input`) at position `parameter_number`
/// by calling `SQLBindParameter`.
///
/// SQL data type, column size and decimal digits come from `parameter.data_type()`; a missing
/// column size is passed as `0`.
///
/// # Safety
///
/// Raw pointers into `parameter` are handed to the driver.
/// NOTE(review): presumably `parameter` must stay valid and unmoved while it is bound —
/// confirm against callers.
unsafe fn bind_input_parameter(
    &mut self,
    parameter_number: u16,
    parameter: &(impl HasDataType + CData + ?Sized + Send),
) -> SqlResult<()> {
    let parameter_type = parameter.data_type();
    let sql_type = parameter_type.data_type();
    let column_size = parameter_type.column_size().map_or(0, NonZeroUsize::get);
    let decimal_digits = parameter_type.decimal_digits();
    let ret = unsafe {
        SQLBindParameter(
            self.as_sys(),
            parameter_number,
            ParamType::Input,
            parameter.cdata_type(),
            sql_type,
            column_size,
            decimal_digits,
            // The binding API only accepts mutable pointers, hence the casts.
            parameter.value_ptr() as *mut c_void,
            parameter.buffer_length(),
            parameter.indicator_ptr() as *mut isize,
        )
    };
    ret.into_sql_result("SQLBindParameter")
}
/// Binds `parameter` at position `parameter_number` with the direction given by
/// `input_output_type`, by calling `SQLBindParameter`.
///
/// SQL data type, column size and decimal digits come from `parameter.data_type()`; a missing
/// column size is passed as `0`.
///
/// # Safety
///
/// Raw pointers into `parameter` are handed to the driver.
/// NOTE(review): presumably `parameter` must stay valid and unmoved while it is bound —
/// confirm against callers.
unsafe fn bind_parameter(
    &mut self,
    parameter_number: u16,
    input_output_type: ParamType,
    parameter: &mut (impl CDataMut + HasDataType + Send),
) -> SqlResult<()> {
    let parameter_type = parameter.data_type();
    let sql_type = parameter_type.data_type();
    let column_size = parameter_type.column_size().map_or(0, NonZeroUsize::get);
    let decimal_digits = parameter_type.decimal_digits();
    let ret = unsafe {
        SQLBindParameter(
            self.as_sys(),
            parameter_number,
            input_output_type,
            parameter.cdata_type(),
            sql_type,
            column_size,
            decimal_digits,
            // NOTE(review): the value pointer comes from the shared accessor `value_ptr()` and
            // is cast to `*mut`, while the indicator uses `mut_indicator_ptr()`. Consider
            // `mut_value_ptr()` if every `CDataMut` implementation returns the same address.
            parameter.value_ptr() as *mut c_void,
            parameter.buffer_length(),
            parameter.mut_indicator_ptr(),
        )
    };
    ret.into_sql_result("SQLBindParameter")
}
/// Binds `parameter` as an input parameter (`ParamType::Input`) at position `parameter_number`,
/// passing `parameter.stream_ptr()` as the value pointer and `0` as the buffer length.
///
/// SQL data type, column size and decimal digits come from `parameter.data_type()`; a missing
/// column size is passed as `0`.
///
/// # Safety
///
/// Raw pointers obtained from `parameter` are handed to the driver.
/// NOTE(review): presumably `parameter` must stay valid and unmoved while it is bound —
/// confirm against callers.
unsafe fn bind_delayed_input_parameter(
    &mut self,
    parameter_number: u16,
    parameter: &mut (impl DelayedInput + HasDataType),
) -> SqlResult<()> {
    let parameter_type = parameter.data_type();
    let sql_type = parameter_type.data_type();
    let column_size = parameter_type.column_size().map_or(0, NonZeroUsize::get);
    let decimal_digits = parameter_type.decimal_digits();
    let ret = unsafe {
        SQLBindParameter(
            self.as_sys(),
            parameter_number,
            ParamType::Input,
            parameter.cdata_type(),
            sql_type,
            column_size,
            decimal_digits,
            parameter.stream_ptr(),
            0,
            // The binding API only accepts a mutable indicator pointer, hence the cast.
            parameter.indicator_ptr() as *mut isize,
        )
    };
    ret.into_sql_result("SQLBindParameter")
}
/// Queries the `Desc::Unsigned` attribute of column `column_number` and maps `0` to `false` and
/// `1` to `true`.
///
/// # Panics
///
/// Panics if the driver reports any value other than `0` or `1`.
fn is_unsigned_column(&mut self, column_number: u16) -> SqlResult<bool> {
    let attribute = unsafe { self.numeric_col_attribute(Desc::Unsigned, column_number) };
    attribute.map(|flag| {
        if flag == 0 {
            false
        } else if flag == 1 {
            true
        } else {
            panic!("Unsigned column attribute must be either 0 or 1.")
        }
    })
}
/// Queries the `Desc::Type` attribute of column `column_number` and returns it as a
/// `SqlDataType`.
///
/// # Panics
///
/// Panics if the value reported by the driver does not fit into the 16 bit integer wrapped by
/// `SqlDataType`.
fn col_type(&mut self, column_number: u16) -> SqlResult<SqlDataType> {
    unsafe { self.numeric_col_attribute(Desc::Type, column_number) }.map(|ret| {
        // Every line of the literal ends in `\` so the panic message is a single line of text,
        // matching the message used by `col_concise_type`.
        SqlDataType(ret.try_into().expect(
            "Failed to retrieve data type from ODBC driver. The SQLLEN could not be \
            converted to a 16 Bit integer. If you are on a 64Bit Platform, this may be \
            because your database driver being compiled against a SQLLEN with 32Bit size \
            instead of 64Bit. E.g. IBM offers libdb2o.* and libdb2.*. With libdb2o.* being \
            the one with the correct size.",
        ))
    })
}
/// Queries the `Desc::ConciseType` attribute of column `column_number` and returns it as a
/// `SqlDataType`.
///
/// # Panics
///
/// Panics if the value reported by the driver does not fit into the 16 bit integer wrapped by
/// `SqlDataType`.
fn col_concise_type(&mut self, column_number: u16) -> SqlResult<SqlDataType> {
    let concise_type = unsafe { self.numeric_col_attribute(Desc::ConciseType, column_number) };
    concise_type.map(|len| {
        let raw = len.try_into().expect(
            "Failed to retrieve data type from ODBC driver. The SQLLEN could not be \
            converted to a 16 Bit integer. If you are on a 64Bit Platform, this may be \
            because your database driver being compiled against a SQLLEN with 32Bit size \
            instead of 64Bit. E.g. IBM offers libdb2o.* and libdb2.*. With libdb2o.* being \
            the one with the correct size.",
        );
        SqlDataType(raw)
    })
}
/// Returns the `Desc::OctetLength` attribute of column `column_number`.
fn col_octet_length(&mut self, column_number: u16) -> SqlResult<isize> {
    let attribute = Desc::OctetLength;
    unsafe { self.numeric_col_attribute(attribute, column_number) }
}
/// Returns the `Desc::DisplaySize` attribute of column `column_number`.
fn col_display_size(&mut self, column_number: u16) -> SqlResult<isize> {
    let attribute = Desc::DisplaySize;
    unsafe { self.numeric_col_attribute(attribute, column_number) }
}
/// Returns the `Desc::Precision` attribute of column `column_number`.
fn col_precision(&mut self, column_number: u16) -> SqlResult<isize> {
    let attribute = Desc::Precision;
    unsafe { self.numeric_col_attribute(attribute, column_number) }
}
/// Returns the `Desc::Scale` attribute of column `column_number`.
fn col_scale(&mut self, column_number: u16) -> SqlResult<isize> {
    let attribute = Desc::Scale;
    unsafe { self.numeric_col_attribute(attribute, column_number) }
}
/// Queries the `Desc::Nullable` attribute of column `column_number` and converts it into a
/// `Nullability`.
fn col_nullability(&mut self, column_number: u16) -> SqlResult<Nullability> {
    let attribute = unsafe { self.numeric_col_attribute(Desc::Nullable, column_number) };
    attribute.map(|raw| {
        // The reported value is narrowed to 16 bits with a plain `as` cast.
        let nullable = odbc_sys::Nullability(raw as i16);
        Nullability::new(nullable)
    })
}
/// Writes the `Desc::Name` attribute of column `column_number` into `buffer`.
///
/// `buffer` is first grown to its full capacity and handed to `SQLColAttribute`. If
/// `is_truncated_bin` reports that the returned length does not fit, the buffer is resized and
/// the attribute is queried a second time. On success the buffer is finally trimmed with
/// `resize_to_fit_without_tz`. If the first call fails, its result is returned immediately.
///
/// # Panics
///
/// Panics if one of the length conversions (`try_into().unwrap()`) fails.
fn col_name(&mut self, column_number: u16, buffer: &mut Vec<SqlChar>) -> SqlResult<()> {
// Length reported by the driver; the variable name indicates bytes rather than characters.
let mut string_length_in_bytes: i16 = 0;
// Expose the whole allocated capacity of the buffer to the driver.
buffer.resize(buffer.capacity(), 0);
unsafe {
let mut res = sql_col_attribute(
self.as_sys(),
column_number,
Desc::Name,
mut_buf_ptr(buffer) as Pointer,
binary_length(buffer).try_into().unwrap(),
&mut string_length_in_bytes as *mut i16,
null_mut(),
)
.into_sql_result("SQLColAttribute");
if res.is_err() {
return res;
}
if is_truncated_bin(buffer, string_length_in_bytes.try_into().unwrap()) {
// The new element count is the reported length plus one, although that length is named
// as a byte count. NOTE(review): this looks like deliberate over-allocation — confirm.
buffer.resize((string_length_in_bytes + 1).try_into().unwrap(), 0);
// Second attempt with the enlarged buffer; its result replaces the first one.
res = sql_col_attribute(
self.as_sys(),
column_number,
Desc::Name,
mut_buf_ptr(buffer) as Pointer,
binary_length(buffer).try_into().unwrap(),
&mut string_length_in_bytes as *mut i16,
null_mut(),
)
.into_sql_result("SQLColAttribute");
}
// Trim the buffer to the reported length (helper name suggests: without terminating zero).
resize_to_fit_without_tz(buffer, string_length_in_bytes.try_into().unwrap());
res
}
}
/// Queries a numeric column attribute through `SQLColAttribute` and returns the reported value.
///
/// The character buffer arguments are passed as null / zero length; only the numeric output
/// argument is used. On success the value is logged with `trace!` before it is returned.
///
/// # Safety
///
/// NOTE(review): presumably `attribute` must be one the driver answers through the numeric
/// output argument rather than the character buffer — confirm.
unsafe fn numeric_col_attribute(
&mut self,
attribute: Desc,
column_number: u16,
) -> SqlResult<Len> {
let mut out: Len = 0;
unsafe {
sql_col_attribute(
self.as_sys(),
column_number,
attribute,
null_mut(),
0,
null_mut(),
&mut out as *mut Len,
)
}
.into_sql_result("SQLColAttribute")
.on_success(|| {
// Plain-text log line when the `structured_logging` feature is disabled.
#[cfg(not(feature = "structured_logging"))]
trace!(
"SQLColAttribute called with attribute '{attribute:?}' for column \
'{column_number}' reported {out}."
);
// Key-value log record when the `structured_logging` feature is enabled.
#[cfg(feature = "structured_logging")]
trace!(
target: "odbc_api",
attribute:? = attribute,
column_number = column_number,
value = out;
"Column attribute queried"
);
out
})
}
/// Calls `SQLFreeStmt` with `FreeStmtOption::ResetParams` on this statement.
fn reset_parameters(&mut self) -> SqlResult<()> {
    let ret = unsafe { SQLFreeStmt(self.as_sys(), FreeStmtOption::ResetParams) };
    ret.into_sql_result("SQLFreeStmt")
}
/// Calls `SQLDescribeParam` for parameter `parameter_number` and returns the reported data type
/// and nullability as a `ColumnType`.
fn describe_param(&mut self, parameter_number: u16) -> SqlResult<ColumnType> {
    let mut sql_type = SqlDataType::UNKNOWN_TYPE;
    let mut parameter_size = 0;
    let mut decimal_digits = 0;
    let mut nullable = odbc_sys::Nullability::UNKNOWN;
    let ret = unsafe {
        SQLDescribeParam(
            self.as_sys(),
            parameter_number,
            &mut sql_type,
            &mut parameter_size,
            &mut decimal_digits,
            &mut nullable,
        )
    };
    ret.into_sql_result("SQLDescribeParam").on_success(|| {
        let data_type = DataType::new(sql_type, parameter_size, decimal_digits);
        let nullability = Nullability::new(nullable);
        ColumnType {
            nullability,
            data_type,
        }
    })
}
/// Calls `SQLParamData`.
///
/// Returns `Some(pointer)` with the pointer written by the driver when the call reports
/// `SqlReturn::NEED_DATA`. Any other return code is converted with `into_sql_result`, yielding
/// `None` on success.
fn param_data(&mut self) -> SqlResult<Option<Pointer>> {
    let mut param_id: Pointer = null_mut();
    let ret = unsafe { SQLParamData(self.as_sys(), &mut param_id as *mut Pointer) };
    if ret == SqlReturn::NEED_DATA {
        SqlResult::Success(Some(param_id))
    } else {
        ret.into_sql_result("SQLParamData").on_success(|| None)
    }
}
/// Calls the driver's `SQLColumns` entry point with the given catalog, schema, table and column
/// name arguments.
///
/// # Panics
///
/// Panics if the character length of any argument does not fit the length parameter type.
fn columns(
    &mut self,
    catalog_name: &SqlText,
    schema_name: &SqlText,
    table_name: &SqlText,
    column_name: &SqlText,
) -> SqlResult<()> {
    let ret = unsafe {
        sql_columns(
            self.as_sys(),
            catalog_name.ptr(),
            catalog_name.len_char().try_into().unwrap(),
            schema_name.ptr(),
            schema_name.len_char().try_into().unwrap(),
            table_name.ptr(),
            table_name.len_char().try_into().unwrap(),
            column_name.ptr(),
            column_name.len_char().try_into().unwrap(),
        )
    };
    ret.into_sql_result("SQLColumns")
}
/// Calls the driver's `SQLTables` entry point with the given catalog, schema, table name and
/// table type arguments.
///
/// # Panics
///
/// Panics if the character length of any argument does not fit the length parameter type.
fn tables(
    &mut self,
    catalog_name: &SqlText,
    schema_name: &SqlText,
    table_name: &SqlText,
    table_type: &SqlText,
) -> SqlResult<()> {
    let ret = unsafe {
        sql_tables(
            self.as_sys(),
            catalog_name.ptr(),
            catalog_name.len_char().try_into().unwrap(),
            schema_name.ptr(),
            schema_name.len_char().try_into().unwrap(),
            table_name.ptr(),
            table_name.len_char().try_into().unwrap(),
            table_type.ptr(),
            table_type.len_char().try_into().unwrap(),
        )
    };
    ret.into_sql_result("SQLTables")
}
/// Calls the driver's `SQLPrimaryKeys` entry point for `table_name`.
///
/// A missing catalog or schema name is passed as a null pointer with length `0`.
///
/// # Panics
///
/// Panics if the character length of any argument does not fit the length parameter type.
fn primary_keys(
    &mut self,
    catalog_name: Option<&SqlText>,
    schema_name: Option<&SqlText>,
    table_name: &SqlText,
) -> SqlResult<()> {
    let ret = unsafe {
        let (catalog_ptr, catalog_len) = match catalog_name {
            Some(catalog) => (catalog.ptr(), catalog.len_char().try_into().unwrap()),
            None => (null(), 0),
        };
        let (schema_ptr, schema_len) = match schema_name {
            Some(schema) => (schema.ptr(), schema.len_char().try_into().unwrap()),
            None => (null(), 0),
        };
        sql_primary_keys(
            self.as_sys(),
            catalog_ptr,
            catalog_len,
            schema_ptr,
            schema_len,
            table_name.ptr(),
            table_name.len_char().try_into().unwrap(),
        )
    };
    ret.into_sql_result("SQLPrimaryKeys")
}
/// Calls the driver's `SQLForeignKeys` entry point with the catalog, schema and table names of
/// the primary-key side (`pk_*`) and the foreign-key side (`fk_*`).
///
/// # Panics
///
/// Panics if the character length of any argument does not fit the length parameter type.
fn foreign_keys(
    &mut self,
    pk_catalog_name: &SqlText,
    pk_schema_name: &SqlText,
    pk_table_name: &SqlText,
    fk_catalog_name: &SqlText,
    fk_schema_name: &SqlText,
    fk_table_name: &SqlText,
) -> SqlResult<()> {
    let ret = unsafe {
        sql_foreign_keys(
            self.as_sys(),
            pk_catalog_name.ptr(),
            pk_catalog_name.len_char().try_into().unwrap(),
            pk_schema_name.ptr(),
            pk_schema_name.len_char().try_into().unwrap(),
            pk_table_name.ptr(),
            pk_table_name.len_char().try_into().unwrap(),
            fk_catalog_name.ptr(),
            fk_catalog_name.len_char().try_into().unwrap(),
            fk_schema_name.ptr(),
            fk_schema_name.len_char().try_into().unwrap(),
            fk_table_name.ptr(),
            fk_table_name.len_char().try_into().unwrap(),
        )
    };
    ret.into_sql_result("SQLForeignKeys")
}
/// Sends `batch` to the driver with `SQLPutData`.
///
/// # Panics
///
/// Panics if `batch` is empty, or if its length does not fit the length parameter type.
fn put_binary_batch(&mut self, batch: &[u8]) -> SqlResult<()> {
    assert!(
        !batch.is_empty(),
        "Attempt to put empty batch into data source."
    );
    let ret = unsafe {
        SQLPutData(
            self.as_sys(),
            batch.as_ptr() as Pointer,
            batch.len().try_into().unwrap(),
        )
    };
    ret.into_sql_result("SQLPutData")
}
/// Returns the value reported by `SQLRowCount` for this statement.
fn row_count(&mut self) -> SqlResult<isize> {
    let mut count: isize = 0;
    let ret = unsafe { SQLRowCount(self.as_sys(), &mut count as *mut isize) };
    ret.into_sql_result("SQLRowCount").on_success(|| count)
}
/// Calls `SQLCompleteAsync` for this handle.
///
/// The outer result reflects the `SQLCompleteAsync` call itself. On success the inner result is
/// the return code written by the driver, converted with `function_name` as the reported
/// function. Only available with the `odbc_version_3_80` feature.
#[cfg(feature = "odbc_version_3_80")]
fn complete_async(&mut self, function_name: &'static str) -> SqlResult<SqlResult<()>> {
    // Starts out as `ERROR`; the driver writes the real return code through the pointer.
    let mut async_ret = SqlReturn::ERROR;
    let completion = unsafe {
        SQLCompleteAsync(
            self.handle_type(),
            self.as_handle(),
            &mut async_ret.0 as *mut _,
        )
    };
    completion
        .into_sql_result("SQLCompleteAsync")
        .on_success(|| async_ret.into_sql_result(function_name))
}
/// Calls `SQLMoreResults` on this statement.
///
/// # Safety
///
/// NOTE(review): presumably buffers bound to this statement must still be valid, since the
/// driver may continue to use them — confirm.
unsafe fn more_results(&mut self) -> SqlResult<()> {
    let ret = unsafe { SQLMoreResults(self.as_sys()) };
    ret.into_sql_result("SQLMoreResults")
}
/// Reads the `AppRowDesc` statement attribute and wraps the returned descriptor handle in a
/// `Descriptor` bound to the borrow of `self`.
///
/// This calls `odbc_sys::SQLGetStmtAttr` directly rather than the `sql_get_stmt_attr` alias.
fn application_row_descriptor(&mut self) -> SqlResult<Descriptor<'_>> {
    unsafe {
        let mut descriptor_handle = HDesc::null();
        let ret = odbc_sys::SQLGetStmtAttr(
            self.as_sys(),
            odbc_sys::StatementAttribute::AppRowDesc,
            &mut descriptor_handle as *mut HDesc as Pointer,
            0,
            null_mut(),
        );
        ret.into_sql_result("SQLGetStmtAttr")
            .on_success(|| Descriptor::new(descriptor_handle))
    }
}
/// Reads the `AppParamDesc` statement attribute and wraps the returned descriptor handle in a
/// `Descriptor` bound to the borrow of `self`.
///
/// This calls `odbc_sys::SQLGetStmtAttr` directly rather than the `sql_get_stmt_attr` alias.
fn application_parameter_descriptor(&mut self) -> SqlResult<Descriptor<'_>> {
    unsafe {
        let mut descriptor_handle = HDesc::null();
        let ret = odbc_sys::SQLGetStmtAttr(
            self.as_sys(),
            odbc_sys::StatementAttribute::AppParamDesc,
            &mut descriptor_handle as *mut HDesc as Pointer,
            0,
            null_mut(),
        );
        ret.into_sql_result("SQLGetStmtAttr")
            .on_success(|| Descriptor::new(descriptor_handle))
    }
}
}
impl Statement for StatementImpl<'_> {
    /// Returns a copy of the owned raw statement handle.
    fn as_sys(&self) -> HStmt {
        let Self { handle, .. } = self;
        *handle
    }
}
/// Data type together with nullability, as returned by `Statement::describe_param`.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct ColumnType {
/// Nullability as reported by the driver.
pub nullability: Nullability,
/// Data type built from the reported SQL type, size and decimal digits.
pub data_type: DataType,
}
/// Deprecated alias kept so code referring to the old name `ParameterDescription` still compiles.
#[deprecated(note = "Use `ColumnType` instead.")]
pub type ParameterDescription = ColumnType;