use crate::binding::*;
use crate::chkerr;
use crate::error::dberror_from_dpi_error;
use crate::private;
use crate::sql_type::OracleType;
use crate::sql_type::ToSql;
use crate::sql_value::BufferRowIndex;
use crate::statement::QueryParams;
use crate::to_odpi_str;
use crate::to_rust_str;
use crate::Connection;
use crate::Error;
use crate::Result;
use crate::SqlValue;
#[cfg(doc)]
use crate::Statement;
use crate::StatementType;
use std::convert::TryFrom;
use std::fmt;
use std::mem::MaybeUninit;
use std::os::raw::c_char;
use std::ptr;
use std::slice;
/// Lower bound applied to every size passed to [`po2`].
///
/// Test builds use `1`, presumably so that tests reach the buffer-regrow path with short
/// values -- confirm against the test suite.
#[cfg(test)]
const MINIMUM_TYPE_LENGTH: u32 = 1;
/// Lower bound applied to every size passed to [`po2`].
#[cfg(not(test))]
const MINIMUM_TYPE_LENGTH: u32 = 64;

/// Rounds `size` up to a power of two, never returning less than `MINIMUM_TYPE_LENGTH`.
///
/// Sizes above `2^31` have no representable next power of two in `u32`. The previous
/// shift-based formula shifted by 32 bits in that case (a panic in debug builds and a
/// wrap-around to `1` in release builds); this version saturates at `u32::MAX` so the
/// result is never smaller than the requested size.
fn po2(size: u32) -> u32 {
    size.max(MINIMUM_TYPE_LENGTH)
        .checked_next_power_of_two()
        .unwrap_or(u32::MAX)
}
/// Returns the size carried by a sized Oracle type.
///
/// `Varchar2`, `NVarchar2`, `Char`, `NChar` and `Raw` yield `Some(size)`;
/// every other type yields `None`.
fn oratype_size(oratype: &OracleType) -> Option<u32> {
    let size = match oratype {
        OracleType::Varchar2(size)
        | OracleType::NVarchar2(size)
        | OracleType::Char(size)
        | OracleType::NChar(size)
        | OracleType::Raw(size) => *size,
        _ => return None,
    };
    Some(size)
}
#[derive(Clone)]
struct BindType {
oratype: Option<OracleType>,
}
impl BindType {
fn new(oratype: &OracleType) -> BindType {
BindType {
oratype: match oratype {
OracleType::Varchar2(size) => Some(OracleType::Varchar2(po2(*size))),
OracleType::NVarchar2(size) => Some(OracleType::NVarchar2(po2(*size))),
OracleType::Char(size) => Some(OracleType::Char(po2(*size))),
OracleType::NChar(size) => Some(OracleType::NChar(po2(*size))),
OracleType::Raw(size) => Some(OracleType::Raw(po2(*size))),
_ => None,
},
}
}
fn reset_size(&mut self, new_size: u32) {
self.oratype = match self.oratype {
Some(OracleType::Varchar2(_)) => Some(OracleType::Varchar2(po2(new_size))),
Some(OracleType::NVarchar2(_)) => Some(OracleType::NVarchar2(po2(new_size))),
Some(OracleType::Char(_)) => Some(OracleType::Char(po2(new_size))),
Some(OracleType::NChar(_)) => Some(OracleType::NChar(po2(new_size))),
Some(OracleType::Raw(_)) => Some(OracleType::Raw(po2(new_size))),
_ => None,
};
}
fn as_oratype(&self) -> Option<&OracleType> {
self.oratype.as_ref()
}
}
/// Builder that collects the settings of a [`Batch`] and creates it with `build`.
pub struct BatchBuilder<'conn, 'sql> {
// Connection the statement is prepared on.
conn: &'conn Connection,
// SQL text passed to the statement preparation in `build`.
sql: &'sql str,
// Requested number of rows per batch; converted to `u32` in `build`.
batch_size: usize,
// Copied into the built `Batch`; set by `with_batch_errors`.
with_batch_errors: bool,
// Copied into the built `Batch`; set by `with_row_counts`.
with_row_counts: bool,
// Cloned into every bind value and into the built `Batch`.
query_params: QueryParams,
}
impl<'conn, 'sql> BatchBuilder<'conn, 'sql> {
pub(crate) fn new(
conn: &'conn Connection,
sql: &'sql str,
batch_size: usize,
) -> BatchBuilder<'conn, 'sql> {
BatchBuilder {
conn,
sql,
batch_size,
with_batch_errors: false,
with_row_counts: false,
query_params: QueryParams::new(),
}
}
/// Enables batch-error mode for the batch to be built.
///
/// With this flag the batch is executed with `DPI_MODE_EXEC_BATCH_ERRORS`, failures are
/// reported as `Error::BatchErrors`, and appending a row never executes the batch
/// automatically when it becomes full.
///
/// Returns the builder itself so that calls can be chained.
pub fn with_batch_errors(&mut self) -> &mut BatchBuilder<'conn, 'sql> {
    self.with_batch_errors = true;
    self
}
/// Enables per-row affected-row counts for the batch to be built.
///
/// With this flag the batch is executed with `DPI_MODE_EXEC_ARRAY_DML_ROWCOUNTS`;
/// the counts are read with `Batch::row_counts`.
///
/// Returns the builder itself so that calls can be chained.
pub fn with_row_counts(&mut self) -> &mut BatchBuilder<'conn, 'sql> {
    self.with_row_counts = true;
    self
}
pub fn build(&self) -> Result<Batch<'conn>> {
let batch_size = u32::try_from(self.batch_size)
.map_err(|_| Error::OutOfRange(format!("too large batch_size: {}", self.batch_size)))?;
let conn = self.conn;
let sql = to_odpi_str(self.sql);
let mut handle: *mut dpiStmt = ptr::null_mut();
chkerr!(
conn.ctxt(),
dpiConn_prepareStmt(
conn.handle(),
0,
sql.ptr,
sql.len,
ptr::null(),
0,
&mut handle
)
);
let mut info = MaybeUninit::uninit();
chkerr!(
conn.ctxt(),
dpiStmt_getInfo(handle, info.as_mut_ptr()),
unsafe {
dpiStmt_release(handle);
}
);
let info = unsafe { info.assume_init() };
if info.isDML == 0 && info.isPLSQL == 0 {
unsafe {
dpiStmt_release(handle);
}
let msg = format!(
"Could not use {} statement",
StatementType::from_enum(info.statementType)
);
return Err(Error::InvalidOperation(msg));
};
let mut num = 0;
chkerr!(
conn.ctxt(),
dpiStmt_getBindCount(handle, &mut num),
unsafe {
dpiStmt_release(handle);
}
);
let bind_count = num as usize;
let mut bind_names = Vec::with_capacity(bind_count);
let mut bind_values = Vec::with_capacity(bind_count);
if bind_count > 0 {
let mut names: Vec<*const c_char> = vec![ptr::null_mut(); bind_count];
let mut lengths = vec![0; bind_count];
chkerr!(
conn.ctxt(),
dpiStmt_getBindNames(handle, &mut num, names.as_mut_ptr(), lengths.as_mut_ptr()),
unsafe {
dpiStmt_release(handle);
}
);
bind_names = Vec::with_capacity(num as usize);
for i in 0..(num as usize) {
bind_names.push(to_rust_str(names[i], lengths[i]));
bind_values.push(SqlValue::for_bind(
conn.conn.clone(),
self.query_params.clone(),
batch_size,
));
}
};
Ok(Batch {
conn,
handle,
statement_type: StatementType::from_enum(info.statementType),
bind_count,
bind_names,
bind_values,
bind_types: vec![None; bind_count],
batch_index: 0,
batch_size,
with_batch_errors: self.with_batch_errors,
with_row_counts: self.with_row_counts,
query_params: self.query_params.clone(),
})
}
}
/// Statement that buffers rows of bind values and executes them together
/// with `dpiStmt_executeMany`.
pub struct Batch<'conn> {
// Connection the statement was prepared on.
pub(crate) conn: &'conn Connection,
// Raw ODPI-C statement handle; set to null by `close` and released in `Drop`.
handle: *mut dpiStmt,
// Statement type reported by ODPI-C when the statement was prepared.
statement_type: StatementType,
// Number of bind variables reported by `dpiStmt_getBindCount`.
bind_count: usize,
// Bind names reported by `dpiStmt_getBindNames`.
bind_names: Vec<String>,
// One buffer per bind name, each created for `batch_size` rows.
bind_values: Vec<SqlValue>,
// Type bookkeeping per bind position; `None` until the position is first typed or bound.
bind_types: Vec<Option<BindType>>,
// Index of the row currently being filled, i.e. the number of rows appended so far.
batch_index: u32,
// Maximum number of rows that can be buffered before execution.
batch_size: u32,
// Execute with `DPI_MODE_EXEC_BATCH_ERRORS` and report `Error::BatchErrors`.
with_batch_errors: bool,
// Execute with `DPI_MODE_EXEC_ARRAY_DML_ROWCOUNTS`.
with_row_counts: bool,
// Cloned into bind values that are re-created when a buffer has to grow.
query_params: QueryParams,
}
impl<'conn> Batch<'conn> {
pub fn close(&mut self) -> Result<()> {
chkerr!(self.conn.ctxt(), dpiStmt_close(self.handle, ptr::null(), 0));
self.handle = ptr::null_mut();
Ok(())
}
/// Binds `params` by position (the first parameter is position 1) into the
/// current row and then advances to the next row.
///
/// # Errors
///
/// * `Error::OutOfRange` when the batch already holds `batch_size` rows.
/// * Any error produced while binding a value or, when the batch becomes full
///   and is executed automatically, while executing it.
pub fn append_row(&mut self, params: &[&dyn ToSql]) -> Result<()> {
    self.check_batch_index()?;
    for (position, param) in (1usize..).zip(params.iter()) {
        self.bind_internal(position, *param)?;
    }
    self.append_row_common()
}
/// Binds each `(name, value)` pair by bind name into the current row and then
/// advances to the next row.
///
/// # Errors
///
/// * `Error::OutOfRange` when the batch already holds `batch_size` rows.
/// * Any error produced while binding a value or, when the batch becomes full
///   and is executed automatically, while executing it.
pub fn append_row_named(&mut self, params: &[(&str, &dyn ToSql)]) -> Result<()> {
    self.check_batch_index()?;
    for &(name, value) in params.iter() {
        self.bind_internal(name, value)?;
    }
    self.append_row_common()
}
/// Finishes the current row: moves to the next row and, unless batch-error mode
/// is on, executes the batch as soon as it holds `batch_size` rows.
fn append_row_common(&mut self) -> Result<()> {
    // Both modes advance the row index; only the automatic execution differs.
    self.set_batch_index(self.batch_index + 1);
    let batch_is_full = self.batch_index == self.batch_size;
    if !self.with_batch_errors && batch_is_full {
        self.execute()?;
    }
    Ok(())
}
/// Executes the buffered rows and empties the batch.
///
/// The row index is reset to zero and every buffered row of every bind value is
/// set to null whether or not the execution succeeded; the execution result is
/// returned afterwards.
///
/// # Errors
///
/// Returns the execution error, if any. If setting a row to null fails, that
/// error is returned instead and the remaining rows are left as they are.
pub fn execute(&mut self) -> Result<()> {
    let outcome = self.execute_sub();
    let appended = std::mem::replace(&mut self.batch_index, 0);
    for value in self.bind_values.iter_mut() {
        for row in 0..appended {
            value.buffer_row_index = BufferRowIndex::Owned(row);
            value.set_null()?;
        }
        // Point back at the first row for the next round of appends.
        value.buffer_row_index = BufferRowIndex::Owned(0);
    }
    outcome
}
fn execute_sub(&mut self) -> Result<()> {
if self.batch_index == 0 {
return Ok(());
}
let mut exec_mode = DPI_MODE_EXEC_DEFAULT;
if self.conn.autocommit() {
exec_mode |= DPI_MODE_EXEC_COMMIT_ON_SUCCESS;
}
if self.with_batch_errors {
exec_mode |= DPI_MODE_EXEC_BATCH_ERRORS;
}
if self.with_row_counts {
exec_mode |= DPI_MODE_EXEC_ARRAY_DML_ROWCOUNTS;
}
chkerr!(
self.conn.ctxt(),
dpiStmt_executeMany(self.handle, exec_mode, self.batch_index)
);
self.conn.ctxt().set_warning();
if self.with_batch_errors {
let mut errnum = 0;
chkerr!(
self.conn.ctxt(),
dpiStmt_getBatchErrorCount(self.handle, &mut errnum)
);
if errnum != 0 {
let mut errs = Vec::with_capacity(errnum as usize);
chkerr!(
self.conn.ctxt(),
dpiStmt_getBatchErrors(self.handle, errnum, errs.as_mut_ptr())
);
unsafe { errs.set_len(errnum as usize) };
return Err(Error::BatchErrors(
errs.iter().map(dberror_from_dpi_error).collect(),
));
}
}
Ok(())
}
/// Returns the number of bind variables that was reported by
/// `dpiStmt_getBindCount` when the batch was built.
pub fn bind_count(&self) -> usize {
self.bind_count
}
/// Returns the bind names of the statement, in the order ODPI-C reported them.
pub fn bind_names(&self) -> Vec<&str> {
    let mut names = Vec::with_capacity(self.bind_names.len());
    names.extend(self.bind_names.iter().map(String::as_str));
    names
}
fn check_batch_index(&self) -> Result<()> {
if self.batch_index < self.batch_size {
Ok(())
} else {
Err(Error::OutOfRange(format!(
"Over the max batch size {}",
self.batch_size
)))
}
}
pub fn set_type<I>(&mut self, bindidx: I, oratype: &OracleType) -> Result<()>
where
I: BatchBindIndex,
{
let pos = bindidx.idx(self)?;
if self.bind_types[pos].is_some() {
return Err(Error::InvalidOperation(format!(
"The bind parameter type at {} has been specified already.",
bindidx
)));
}
self.bind_values[pos].init_handle(oratype)?;
chkerr!(
self.conn.ctxt(),
bindidx.bind(self.handle, self.bind_values[pos].handle)
);
self.bind_types[pos] = Some(BindType::new(oratype));
Ok(())
}
/// Binds `value` to the parameter `index` in the current row without advancing
/// to the next row.
///
/// # Errors
///
/// `Error::OutOfRange` when the batch already holds `batch_size` rows, otherwise
/// whatever binding the value reports.
pub fn set<I>(&mut self, index: I, value: &dyn ToSql) -> Result<()>
where
    I: BatchBindIndex,
{
    self.check_batch_index()
        .and_then(|()| self.bind_internal(index, value))
}
/// Stores `value` in the current row of the bind parameter `bindidx`.
///
/// On the first use of a parameter its Oracle type is taken from the value and the
/// variable is created and bound. When storing the value fails with a `DPI-1019`
/// error and the parameter has a sized type, a larger variable is created, the rows
/// appended so far are copied into it, and it replaces the old one.
fn bind_internal<I>(&mut self, bindidx: I, value: &dyn ToSql) -> Result<()>
where
I: BatchBindIndex,
{
let pos = bindidx.idx(self)?;
// First use of this position: derive the type from the value and bind a variable.
if self.bind_types[pos].is_none() {
let oratype = value.oratype(self.conn)?;
let bind_type = BindType::new(&oratype);
// Sized types are created with the size rounded by `BindType::new`;
// all other types are created exactly as reported by the value.
self.bind_values[pos].init_handle(bind_type.as_oratype().unwrap_or(&oratype))?;
chkerr!(
self.conn.ctxt(),
bindidx.bind(self.handle, self.bind_values[pos].handle)
);
self.bind_types[pos] = Some(bind_type);
}
match self.bind_values[pos].set(value) {
// NOTE(review): DPI-1019 is presumably ODPI-C's "buffer size too small" error -- confirm.
Err(Error::DpiError(dberr)) if dberr.message().starts_with("DPI-1019:") => {
// `bind_types[pos]` is always `Some` here: it is either set above or was set earlier.
let bind_type = self.bind_types[pos].as_mut().unwrap();
// Without a sized type there is nothing to grow; report the original error.
if bind_type.as_oratype().is_none() {
return Err(Error::DpiError(dberr));
}
let new_oratype = value.oratype(self.conn)?;
// The value must report a sized type too, otherwise the original error is returned.
let new_size = oratype_size(&new_oratype).ok_or(Error::DpiError(dberr))?;
bind_type.reset_size(new_size);
let mut new_sql_value = SqlValue::for_bind(
self.conn.conn.clone(),
self.query_params.clone(),
self.batch_size,
);
new_sql_value.init_handle(bind_type.as_oratype().unwrap())?;
// Carry over the rows that were appended before the current one.
for idx in 0..self.batch_index {
chkerr!(
self.conn.ctxt(),
dpiVar_copyData(
new_sql_value.handle,
idx,
self.bind_values[pos].handle,
idx
)
);
}
// Store the value that did not fit into the current row of the new variable.
new_sql_value.buffer_row_index = BufferRowIndex::Owned(self.batch_index);
new_sql_value.set(value)?;
chkerr!(
self.conn.ctxt(),
bindidx.bind(self.handle, new_sql_value.handle)
);
// The old variable is dropped only after the new one has been bound.
self.bind_values[pos] = new_sql_value;
Ok(())
}
// Success and every other error are passed through unchanged.
x => x,
}
}
/// Makes `batch_index` the current row of the batch and of every bind value.
fn set_batch_index(&mut self, batch_index: u32) {
    self.batch_index = batch_index;
    self.bind_values
        .iter_mut()
        .for_each(|value| value.buffer_row_index = BufferRowIndex::Owned(batch_index));
}
/// Returns the per-row affected-row counts reported by `dpiStmt_getRowCounts`.
///
/// An empty vector is returned when ODPI-C reports no counts.
///
/// # Errors
///
/// Returns the error taken from the connection context when the ODPI-C call fails.
pub fn row_counts(&self) -> Result<Vec<u64>> {
    let mut num_row_counts = 0;
    let mut row_counts: *mut u64 = ptr::null_mut();
    chkerr!(
        self.conn.ctxt(),
        dpiStmt_getRowCounts(self.handle, &mut num_row_counts, &mut row_counts)
    );
    // `slice::from_raw_parts` requires a non-null pointer even for an empty slice,
    // so the "no counts" case must not reach it.
    if row_counts.is_null() || num_row_counts == 0 {
        return Ok(Vec::new());
    }
    // SAFETY: the call succeeded and returned a non-null pointer together with the
    // number of elements it points to; the data is copied before returning.
    let counts = unsafe { slice::from_raw_parts(row_counts, num_row_counts as usize) };
    Ok(counts.to_vec())
}
/// Returns the statement type that was determined when the batch was built.
pub fn statement_type(&self) -> StatementType {
self.statement_type
}
pub fn is_plsql(&self) -> bool {
matches!(
self.statement_type,
StatementType::Begin | StatementType::Declare | StatementType::Call
)
}
pub fn is_dml(&self) -> bool {
matches!(
self.statement_type,
StatementType::Insert
| StatementType::Update
| StatementType::Delete
| StatementType::Merge
)
}
}
impl<'conn> Drop for Batch<'conn> {
    /// Releases the statement handle owned by the batch.
    ///
    /// `close` sets the handle to null; in that case there is nothing to release
    /// and the FFI call is skipped instead of relying on ODPI-C to reject a null handle.
    fn drop(&mut self) {
        if !self.handle.is_null() {
            unsafe { dpiStmt_release(self.handle) };
        }
    }
}
/// A value that identifies a bind parameter of a [`Batch`].
///
/// This file implements it for `usize` (bind position, starting at 1) and for
/// `&str` (bind name). The `private::Sealed` supertrait presumably prevents
/// implementations outside this crate -- confirm against the `private` module.
pub trait BatchBindIndex: private::Sealed + fmt::Display {
/// Translates the index into a zero-based position in the batch's bind vectors,
/// or returns an error when it does not identify a bind parameter.
#[doc(hidden)]
fn idx(&self, batch: &Batch) -> Result<usize>;
/// Binds `var_handle` to the parameter identified by this index and returns the
/// raw ODPI-C status code.
///
/// # Safety
///
/// The handles are passed to ODPI-C unchanged; callers are presumably expected to
/// pass handles obtained from ODPI-C -- confirm.
#[doc(hidden)]
unsafe fn bind(&self, stmt_handle: *mut dpiStmt, var_handle: *mut dpiVar) -> i32;
}
impl BatchBindIndex for usize {
#[doc(hidden)]
fn idx(&self, batch: &Batch) -> Result<usize> {
let num = batch.bind_count();
if 0 < num && *self <= num {
Ok(*self - 1)
} else {
Err(Error::InvalidBindIndex(*self))
}
}
#[doc(hidden)]
unsafe fn bind(&self, stmt_handle: *mut dpiStmt, var_handle: *mut dpiVar) -> i32 {
dpiStmt_bindByPos(stmt_handle, *self as u32, var_handle)
}
}
impl<'a> BatchBindIndex for &'a str {
#[doc(hidden)]
fn idx(&self, batch: &Batch) -> Result<usize> {
let bindname = self.to_uppercase();
batch
.bind_names()
.iter()
.position(|&name| name == bindname)
.ok_or_else(|| Error::InvalidBindName((*self).to_string()))
}
#[doc(hidden)]
unsafe fn bind(&self, stmt_handle: *mut dpiStmt, var_handle: *mut dpiVar) -> i32 {
let s = to_odpi_str(self);
dpiStmt_bindByName(stmt_handle, s.ptr, s.len, var_handle)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_util;
/// One row of test input together with the error code it is expected to cause.
#[derive(Debug)]
struct TestData {
    /// Value for the integer column.
    int_val: i32,
    /// Value for the string column.
    string_val: &'static str,
    /// Expected error code, or `None` when the row is expected to be accepted.
    error_code: Option<i32>,
}

impl TestData {
    /// Builds a row; `const` so that it can be used in constant tables.
    const fn new(int_val: i32, string_val: &'static str, error_code: Option<i32>) -> TestData {
        Self {
            int_val,
            string_val,
            error_code,
        }
    }
}
// Expected error code for rows that repeat an already used integer value.
// NOTE(review): presumably ORA-00001 (unique constraint violated) -- confirm.
const ERROR_UNIQUE_INDEX_VIOLATION: Option<i32> = Some(1);
// Expected error code for rows whose string value is too long for the column.
// NOTE(review): presumably ORA-12899 (value too large for column) -- confirm.
const ERROR_TOO_LARGE_VALUE: Option<i32> = Some(12899);
// Rows used by all tests below. Rows with `None` are expected to be inserted;
// rows with an error code are expected to be rejected with that code.
const TEST_DATA: [TestData; 10] = [
TestData::new(0, "0", None),
TestData::new(1, "1111", None),
TestData::new(2, "222222222222", None),
TestData::new(3, "3333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333", None),
TestData::new(4, "44444444444444444444444444444444444444444444444444444444444444444444444444444444444444444444444444444", ERROR_TOO_LARGE_VALUE),
TestData::new(1, "55555555555555", ERROR_UNIQUE_INDEX_VIOLATION),
TestData::new(6, "66666666666", None),
TestData::new(2, "7", ERROR_UNIQUE_INDEX_VIOLATION),
TestData::new(8, "8", None),
TestData::new(3, "9999999999999999999999999", ERROR_UNIQUE_INDEX_VIOLATION),
];
/// Appends every row to `batch` and then executes it, stopping at the first error.
fn append_rows_then_execute(batch: &mut Batch, rows: &[&TestData]) -> Result<()> {
    rows.iter()
        .try_for_each(|row| batch.append_row(&[&row.int_val, &row.string_val]))?;
    batch.execute()
}
/// Asserts that `TestTempTable` contains exactly `expected_rows`, compared in
/// ascending order of the integer column.
fn check_rows_inserted(conn: &Connection, expected_rows: &[&TestData]) -> Result<()> {
    let mut sorted = expected_rows.to_vec();
    sorted.sort_by_key(|row| row.int_val);
    let mut actual =
        conn.query_as::<(i32, String)>("select * from TestTempTable order by intCol", &[])?;
    for expected in sorted {
        let fetched = actual.next();
        assert!(fetched.is_some());
        let (int_val, string_val) = fetched.unwrap()?;
        assert_eq!(int_val, expected.int_val);
        assert_eq!(string_val, expected.string_val);
    }
    // No row beyond the expected ones may exist.
    assert!(actual.next().is_none());
    Ok(())
}
/// Inserts all error-free rows with a single batch and checks the table contents.
#[test]
fn batch_insert() {
    let conn = test_util::connect().unwrap();
    let valid_rows: Vec<&TestData> = TEST_DATA
        .iter()
        .filter(|data| data.error_code.is_none())
        .collect();
    let builder_sql = "insert into TestTempTable values(:1, :2)";
    let mut batch = conn.batch(builder_sql, valid_rows.len()).build().unwrap();
    append_rows_then_execute(&mut batch, &valid_rows).unwrap();
    check_rows_inserted(&conn, &valid_rows).unwrap();
}
/// Uses one batch for two consecutive rounds of rows and checks that the rows
/// of both rounds were inserted.
#[test]
fn batch_execute_twice() {
    let conn = test_util::connect().unwrap();
    let all_valid: Vec<&TestData> = TEST_DATA
        .iter()
        .filter(|data| data.error_code.is_none())
        .collect();
    let half = all_valid.len() / 2;
    let (first_round, second_round) = all_valid.split_at(half);
    // The batch size equals the size of the first round.
    let mut batch = conn
        .batch("insert into TestTempTable values(:1, :2)", first_round.len())
        .build()
        .unwrap();
    append_rows_then_execute(&mut batch, first_round).unwrap();
    append_rows_then_execute(&mut batch, second_round).unwrap();
    check_rows_inserted(&conn, &all_valid).unwrap();
}
/// Without batch-error mode the execution is expected to fail with the error code
/// of the first bad row, leaving only the rows before it in the table.
#[test]
fn batch_with_error() {
    let conn = test_util::connect().unwrap();
    let all_rows: Vec<&TestData> = TEST_DATA.iter().collect();
    let rows_before_first_error: Vec<&TestData> = TEST_DATA
        .iter()
        .take_while(|data| data.error_code.is_none())
        .collect();
    let first_error_code = TEST_DATA
        .iter()
        .find_map(|data| data.error_code)
        .unwrap();
    let mut batch = conn
        .batch("insert into TestTempTable values(:1, :2)", all_rows.len())
        .build()
        .unwrap();
    match append_rows_then_execute(&mut batch, &all_rows) {
        Err(Error::OciError(dberr)) => assert_eq!(dberr.code(), first_error_code),
        x => {
            panic!("got {:?}", x);
        }
    }
    check_rows_inserted(&conn, &rows_before_first_error).unwrap();
}
/// In batch-error mode every bad row is expected to be reported with its row offset
/// and error code, while all good rows are inserted.
#[test]
fn batch_with_batch_errors() {
    let conn = test_util::connect().unwrap();
    let all_rows: Vec<&TestData> = TEST_DATA.iter().collect();
    let valid_rows: Vec<&TestData> = TEST_DATA
        .iter()
        .filter(|row| row.error_code.is_none())
        .collect();
    // (row offset, error code) of every row that is expected to fail.
    let expected_errors: Vec<(u32, i32)> = TEST_DATA
        .iter()
        .enumerate()
        .filter_map(|(offset, row)| row.error_code.map(|code| (offset as u32, code)))
        .collect();
    let mut batch = conn
        .batch("insert into TestTempTable values(:1, :2)", all_rows.len())
        .with_batch_errors()
        .build()
        .unwrap();
    match append_rows_then_execute(&mut batch, &all_rows) {
        Err(Error::BatchErrors(errs)) => {
            let actual_errors: Vec<(u32, i32)> = errs
                .iter()
                .map(|dberr| (dberr.offset(), dberr.code()))
                .collect();
            assert_eq!(expected_errors, actual_errors);
        }
        x => {
            panic!("got {:?}", x);
        }
    }
    check_rows_inserted(&conn, &valid_rows).unwrap();
}
}