use crate::blr::message_blr;
use crate::connection::Connection;
use crate::error::{DatabaseError, Error, Result, StatusVector};
use crate::message::{encode_row_into, message_buffer_len};
use crate::transaction::Transaction;
use crate::value::{ColumnMeta, Value};
use crate::wire::consts::*;
use crate::wire::response::{read_op, read_response, read_response_body, read_status_vector};
use crate::wire::stream::{op_name, op_packet};
use crate::wire::xdr::ParameterBuffer;
/// A server-side batch opened on one prepared statement.
///
/// Rows (messages), blob data and blob registrations are buffered locally and only shipped to
/// the server by `Batch::execute`.
pub struct Batch {
    // Identity and encoding context.
    /// Handle of the prepared statement the batch was created on; sent with every batch packet.
    handle: i32,
    /// Input-parameter metadata used to encode each queued row.
    params: Vec<ColumnMeta>,
    /// Connection charset captured at creation time, used when encoding rows.
    charset: crate::charset::Charset,

    // Queued messages.
    /// Encoded rows waiting to be sent with `op_batch_msg`.
    pending: Vec<u8>,
    /// Number of rows encoded into `pending`.
    pending_count: u32,

    // Blob handling.
    /// True when the statement has at least one BLOB parameter; gates the blob APIs.
    blob_stream: bool,
    /// Serialized blob-stream entries waiting to be sent with `op_batch_blob_stream`.
    pending_blobs: Vec<u8>,
    /// Declared length of the queued stream (each entry rounded up to `BLOB_ALIGN`).
    blob_stream_len: usize,
    /// Queued `(existing_id, batch_id)` registrations for `op_batch_regblob`.
    pending_regblobs: Vec<(u64, u64)>,
    /// Next batch-local blob id handed out by `add_blob` / `register_blob` (starts at 1).
    next_blob_id: u64,
    /// Default BPB still to be sent with `op_batch_set_bpb`, if any.
    default_bpb: Option<Vec<u8>>,
    /// Whether blob entries are currently encoded in the segmented layout.
    blob_segmented: bool,

    // Lifecycle.
    /// Set by `close`; suppresses the unclosed-handle warning in `Drop`.
    closed: bool,
}
/// Alignment, in bytes, of a segment header inside a segmented blob entry.
const BLOB_SEGHDR_ALIGN: usize = 2;

// Compile-time check: any length aligned to `BLOB_ALIGN` is then also aligned to
// `BLOB_SEGHDR_ALIGN`.
const _: () = {
    assert!(BLOB_ALIGN.is_multiple_of(BLOB_SEGHDR_ALIGN));
};
/// Rounds `n` up to the nearest multiple of `align`.
///
/// `align` may be any non-zero value. Power-of-two alignments behave exactly like the usual
/// `(n + align - 1) & !(align - 1)` mask trick, and other values are handled correctly
/// instead of silently yielding a wrong result.
///
/// # Panics
/// Panics if `align` is zero. In debug builds it also panics if the result overflows `usize`.
#[inline]
fn align_up(n: usize, align: usize) -> usize {
    n.next_multiple_of(align)
}
/// Reports whether a blob parameter buffer (BPB) selects segmented blobs.
///
/// `bpb[0]` is the version byte. The rest is a sequence of `tag, len, value[len]` items.
///
/// This follows Firebird's server-side rule, so the client encodes blob entries
/// (`Batch::add_blob`) in the layout the server will decode:
/// - A blob is a stream blob only when the `bpb::TYPE` value has the `bpb::TYPE_STREAM` flag
///   set.
/// - A BPB with no type item, or with an empty type value, means segmented.
fn bpb_is_segmented(bpb: &[u8]) -> bool {
    let mut items = bpb.get(1..).unwrap_or_default();
    while let [tag, len, rest @ ..] = items {
        // Clamp a truncated trailing item to the bytes actually present.
        let (value, next) = rest.split_at(usize::from(*len).min(rest.len()));
        if *tag == bpb::TYPE {
            // The value is a little-endian integer, so its flag bits live in the first byte;
            // an empty value reads as 0, i.e. segmented.
            let low = value.first().copied().unwrap_or(0);
            return (low & bpb::TYPE_STREAM) == 0;
        }
        items = next;
    }
    // No type item: Firebird's default blob type is segmented.
    true
}
/// Alignment, in bytes, of every entry in the batch blob stream.
///
/// `Batch::add_blob` rounds each entry's length up to this value before adding it to
/// `blob_stream_len`.
const BLOB_ALIGN: usize = 4;
impl Batch {
pub fn handle(&self) -> i32 {
self.handle
}
pub fn pending(&self) -> u32 {
self.pending_count
}
pub fn blob_stream_len(&self) -> usize {
self.blob_stream_len
}
pub fn params(&self) -> &[ColumnMeta] {
&self.params
}
pub fn add_blob(&mut self, data: &[u8]) -> Result<u64> {
if !self.blob_stream {
return Err(Error::protocol(
"add_blob exige uma coluna BLOB na instrução do batch",
));
}
let id = self.next_blob_id;
self.next_blob_id += 1;
self.pending_blobs.extend_from_slice(&id.to_be_bytes());
if self.blob_segmented {
let size_field = align_up(2 + data.len(), BLOB_SEGHDR_ALIGN);
self.pending_blobs
.extend_from_slice(&(size_field as u32).to_be_bytes());
self.pending_blobs.extend_from_slice(&0u32.to_be_bytes()); self.pending_blobs
.extend_from_slice(&(data.len() as u32).to_be_bytes());
self.pending_blobs.extend_from_slice(data);
self.blob_stream_len += align_up(16 + size_field, BLOB_ALIGN);
} else {
self.pending_blobs
.extend_from_slice(&(data.len() as u32).to_be_bytes());
self.pending_blobs.extend_from_slice(&0u32.to_be_bytes()); self.pending_blobs.extend_from_slice(data);
self.blob_stream_len += align_up(16 + data.len(), BLOB_ALIGN);
}
Ok(id)
}
pub fn set_default_bpb(&mut self, bpb: Vec<u8>) -> Result<()> {
if !self.blob_stream {
return Err(Error::protocol(
"set_default_bpb exige uma coluna BLOB na instrução do batch",
));
}
self.blob_segmented = bpb_is_segmented(&bpb);
self.default_bpb = Some(bpb);
Ok(())
}
pub fn set_segmented(&mut self, segmented: bool) -> Result<()> {
let kind = if segmented {
bpb::TYPE_SEGMENTED
} else {
bpb::TYPE_STREAM
};
self.set_default_bpb(vec![BPB_VERSION1, bpb::TYPE, 4, kind, 0, 0, 0])
}
pub fn register_blob(&mut self, existing_id: u64) -> Result<u64> {
if !self.blob_stream {
return Err(Error::protocol(
"register_blob exige uma coluna BLOB na instrução do batch",
));
}
let batch_id = self.next_blob_id;
self.next_blob_id += 1;
self.pending_regblobs.push((existing_id, batch_id));
Ok(batch_id)
}
pub fn add(&mut self, values: &[Value]) -> Result<()> {
encode_row_into(&mut self.pending, &self.params, values, self.charset)?;
self.pending_count += 1;
Ok(())
}
pub fn execute(&mut self, conn: &mut Connection, tx: &Transaction) -> Result<BatchResult> {
if self.closed {
return Err(Error::protocol("batch já foi fechado"));
}
if let Some(bpb) = self.default_bpb.take() {
let mut w = op_packet(op::BATCH_SET_BPB);
w.put_i32(self.handle);
w.put_bytes(&bpb); conn.io().send(&w)?;
read_response(conn.io())?;
}
if !self.pending_blobs.is_empty() {
let mut w = op_packet(op::BATCH_BLOB_STREAM);
w.put_i32(self.handle);
w.put_i32(self.blob_stream_len as i32);
w.put_raw(&self.pending_blobs);
conn.io().send(&w)?;
read_response(conn.io())?;
self.pending_blobs.clear();
self.blob_stream_len = 0;
}
if !self.pending_regblobs.is_empty() {
for (existing_id, batch_id) in std::mem::take(&mut self.pending_regblobs) {
let mut w = op_packet(op::BATCH_REGBLOB);
w.put_i32(self.handle);
w.put_raw(&existing_id.to_be_bytes());
w.put_raw(&batch_id.to_be_bytes());
conn.io().send(&w)?;
read_response(conn.io())?;
}
}
if self.pending_count > 0 {
let mut w = op_packet(op::BATCH_MSG);
w.put_i32(self.handle);
w.put_i32(self.pending_count as i32);
w.put_raw(&self.pending);
w.align();
conn.io().send(&w)?;
read_response(conn.io())?;
self.pending.clear();
self.pending_count = 0;
}
let mut w = op_packet(op::BATCH_EXEC);
w.put_i32(self.handle);
w.put_i32(tx.handle());
conn.io().send(&w)?;
read_batch_cs(conn)
}
pub fn cancel(&mut self, conn: &mut Connection) -> Result<()> {
self.pending.clear();
self.pending_count = 0;
self.pending_blobs.clear();
self.blob_stream_len = 0;
self.pending_regblobs.clear();
let mut w = op_packet(op::BATCH_CANCEL);
w.put_i32(self.handle);
conn.io().send(&w)?;
read_response(conn.io())?;
Ok(())
}
pub fn close(mut self, conn: &mut Connection) -> Result<()> {
self.closed = true;
let mut w = op_packet(op::BATCH_RLS);
w.put_i32(self.handle);
conn.io().send(&w)?;
read_response(conn.io())?;
let mut w = op_packet(op::FREE_STATEMENT);
w.put_i32(self.handle);
w.put_i32(free::DROP);
conn.io().send(&w)?;
read_response(conn.io())?;
Ok(())
}
}
impl Drop for Batch {
    /// Warns through `crate::warn_unclosed` when a batch goes out of scope without `close`.
    fn drop(&mut self) {
        if self.closed {
            return;
        }
        crate::warn_unclosed("Batch", self.handle);
    }
}
/// Options applied when a batch is created (see [`Connection::create_batch_with`]).
#[derive(Debug, Clone, Copy)]
pub struct BatchOptions {
    /// When true, the `batch_tag::MULTIERROR` tag is sent at creation.
    /// NOTE(review): per Firebird's IBatch semantics this presumably keeps executing past
    /// failing messages — not visible in this file.
    pub multierror: bool,
    /// When true, the `batch_tag::RECORD_COUNTS` tag is sent at creation, presumably asking the
    /// server for per-message update counts.
    pub record_counts: bool,
    /// When set, sent as the `batch_tag::BUFFER_BYTES_SIZE` tag. `None` omits the tag.
    pub buffer_bytes: Option<u32>,
}
impl Default for BatchOptions {
fn default() -> Self {
BatchOptions {
multierror: false,
record_counts: true,
buffer_bytes: None,
}
}
}
impl BatchOptions {
    /// Same as [`BatchOptions::default`].
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns a copy with `multierror` set to `on`.
    pub fn multierror(self, on: bool) -> Self {
        Self {
            multierror: on,
            ..self
        }
    }

    /// Returns a copy with `record_counts` set to `on`.
    pub fn record_counts(self, on: bool) -> Self {
        Self {
            record_counts: on,
            ..self
        }
    }

    /// Returns a copy that requests a batch buffer of `bytes` bytes.
    pub fn buffer_bytes(self, bytes: u32) -> Self {
        Self {
            buffer_bytes: Some(bytes),
            ..self
        }
    }
}
impl Connection {
pub fn create_batch(&mut self, tx: &Transaction, sql: &str) -> Result<Batch> {
self.create_batch_with(tx, sql, BatchOptions::default())
}
pub fn create_batch_with(
&mut self,
tx: &Transaction,
sql: &str,
opts: BatchOptions,
) -> Result<Batch> {
let mut stmt = self.prepare(tx, sql)?;
let handle = stmt.handle();
let params: Vec<ColumnMeta> = stmt.params().to_vec();
stmt.forget_handle();
drop(stmt);
let blr = message_blr(¶ms);
let msglen = message_buffer_len(¶ms);
let blob_stream = params
.iter()
.any(|c| sql_type::base(c.sql_type) == sql_type::BLOB);
let mut pb = ParameterBuffer::new(1);
if opts.record_counts {
pb.bytes_be_len4(batch_tag::RECORD_COUNTS, &1u32.to_le_bytes());
}
if opts.multierror {
pb.bytes_be_len4(batch_tag::MULTIERROR, &1u32.to_le_bytes());
}
if let Some(bytes) = opts.buffer_bytes {
pb.bytes_be_len4(batch_tag::BUFFER_BYTES_SIZE, &bytes.to_le_bytes());
}
if blob_stream {
pb.bytes_be_len4(
batch_tag::BLOB_POLICY,
&(blob_policy::STREAM as u32).to_le_bytes(),
);
}
let mut w = op_packet(op::BATCH_CREATE);
w.put_i32(handle);
w.put_bytes(&blr); w.put_i32(msglen as i32);
w.put_bytes(pb.as_slice()); self.io().send(&w)?;
read_response(self.io())?;
Ok(Batch {
handle,
params,
pending: Vec::new(),
pending_count: 0,
pending_blobs: Vec::new(),
blob_stream_len: 0,
pending_regblobs: Vec::new(),
next_blob_id: 1,
blob_stream,
charset: self.charset(),
default_bpb: None,
blob_segmented: false,
closed: false,
})
}
}
/// Outcome of [`Batch::execute`], decoded from the server's `op_batch_cs` reply.
#[derive(Debug, Clone, Default)]
pub struct BatchResult {
    /// Record count reported by the server for this execution.
    pub total: u32,
    /// Per-message update counts, exactly as sent by the server.
    /// Negative entries are treated as status codes, not row counts: `total_affected` skips
    /// them and `all_succeeded` looks for `batch_cs::EXECUTE_FAILED`.
    pub update_counts: Vec<i32>,
    /// Failed messages. Those with a detailed status vector come first, followed by any
    /// reported only by index; the latter carry an empty status vector.
    pub errors: Vec<BatchError>,
}
impl BatchResult {
    /// True when no error was reported and no update count equals
    /// `batch_cs::EXECUTE_FAILED`.
    pub fn all_succeeded(&self) -> bool {
        let none_failed = self
            .update_counts
            .iter()
            .all(|&count| count != batch_cs::EXECUTE_FAILED);
        self.errors.is_empty() && none_failed
    }

    /// Sum of all non-negative update counts; negative status codes are ignored.
    pub fn total_affected(&self) -> u64 {
        self.update_counts
            .iter()
            .filter_map(|&count| u64::try_from(count).ok())
            .sum()
    }
}
/// One failed message within a batch execution.
#[derive(Debug, Clone)]
pub struct BatchError {
    /// Position of the failed message as reported by the server.
    /// NOTE(review): presumably 0-based — confirm.
    pub message_index: u32,
    /// Server status for that message. Empty when the server reported only the index.
    pub error: DatabaseError,
}
fn read_batch_cs(conn: &mut Connection) -> Result<BatchResult> {
let code = read_op(conn.io())?;
if code == op::RESPONSE {
read_response_body(conn.io())?.into_result()?;
return Err(Error::protocol(
"op_batch_exec retornou op_response sem erro",
));
}
if code != op::BATCH_CS {
return Err(Error::protocol(format!(
"esperava op_batch_cs, veio {} ({code})",
op_name(code)
)));
}
let _stmt = conn.io().read_i32()?;
let reccount = conn.io().read_i32()? as u32;
let updates = conn.io().read_i32()? as u32;
let vectors = conn.io().read_i32()? as u32;
let errors = conn.io().read_i32()? as u32;
let mut update_counts = Vec::with_capacity(updates as usize);
for _ in 0..updates {
update_counts.push(conn.io().read_i32()?);
}
let mut batch_errors = Vec::with_capacity(vectors as usize);
for _ in 0..vectors {
let pos = conn.io().read_i32()? as u32;
let status = read_status_vector(conn.io())?;
batch_errors.push(BatchError {
message_index: pos,
error: DatabaseError::new(status),
});
}
for _ in 0..errors {
let pos = conn.io().read_i32()? as u32;
if !batch_errors.iter().any(|e| e.message_index == pos) {
let empty = StatusVector {
args: Vec::new(),
sql_state: None,
};
batch_errors.push(BatchError {
message_index: pos,
error: DatabaseError::new(empty),
});
}
}
Ok(BatchResult {
total: reccount,
update_counts,
errors: batch_errors,
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Batch` without a connection whose statement is treated as having a BLOB
    /// parameter (`blob_stream`). `closed` is true, so dropping it does not warn.
    fn fake_blob_batch() -> Batch {
        Batch {
            handle: 0,
            params: Vec::new(),
            pending: Vec::new(),
            pending_count: 0,
            pending_blobs: Vec::new(),
            blob_stream_len: 0,
            pending_regblobs: Vec::new(),
            next_blob_id: 1,
            blob_stream: true,
            charset: crate::charset::Charset::Utf8,
            default_bpb: None,
            blob_segmented: false,
            closed: true,
        }
    }

    /// Accounted length of one non-segmented blob-stream entry (16-byte header + data).
    fn entry_len(data_len: usize) -> usize {
        align_up(16 + data_len, BLOB_ALIGN)
    }

    /// Accounted length of one segmented entry holding `data_len` bytes in a single segment.
    fn segmented_entry_len(data_len: usize) -> usize {
        align_up(16 + align_up(2 + data_len, BLOB_SEGHDR_ALIGN), BLOB_ALIGN)
    }

    #[test]
    fn add_blob_atribui_ids_sequenciais_a_partir_de_1() {
        let mut b = fake_blob_batch();
        assert_eq!(b.add_blob(b"a").unwrap(), 1);
        assert_eq!(b.add_blob(b"bb").unwrap(), 2);
        assert_eq!(b.add_blob(b"ccc").unwrap(), 3);
    }

    #[test]
    fn add_blob_exige_coluna_blob() {
        let mut b = fake_blob_batch();
        b.blob_stream = false;
        assert!(b.add_blob(b"x").is_err());
    }

    #[test]
    fn blob_stream_len_acumula_tamanhos_alinhados() {
        let mut b = fake_blob_batch();
        for data in [&b"x"[..], &b"yy"[..], &b"zzz"[..], &b"wwww"[..]] {
            b.add_blob(data).unwrap();
        }
        let esperado = entry_len(1) + entry_len(2) + entry_len(3) + entry_len(4);
        assert_eq!(b.blob_stream_len(), esperado);
        assert_eq!(b.blob_stream_len() % BLOB_ALIGN, 0);
    }

    #[test]
    fn blob_stream_len_de_um_unico_blob() {
        let mut b = fake_blob_batch();
        b.add_blob(&[0u8; 30]).unwrap();
        assert_eq!(b.blob_stream_len(), 48);
        let mut b = fake_blob_batch();
        b.add_blob(&[0u8; 33]).unwrap();
        assert_eq!(b.blob_stream_len(), 52);
    }

    #[test]
    fn add_blob_serializa_cabecalho_nao_segmentado() {
        let mut b = fake_blob_batch();
        b.add_blob(b"abc").unwrap();
        let mut esperado = Vec::new();
        esperado.extend_from_slice(&1u64.to_be_bytes());
        esperado.extend_from_slice(&3u32.to_be_bytes());
        esperado.extend_from_slice(&0u32.to_be_bytes());
        esperado.extend_from_slice(b"abc");
        assert_eq!(b.pending_blobs, esperado);
        assert_eq!(b.blob_stream_len(), entry_len(3));
    }

    #[test]
    fn set_segmented_muda_codificacao_dos_blobs() {
        let mut b = fake_blob_batch();
        b.set_segmented(true).unwrap();
        assert!(b.blob_segmented);
        assert!(b.default_bpb.is_some());
        b.add_blob(b"abc").unwrap();
        let mut esperado = Vec::new();
        esperado.extend_from_slice(&1u64.to_be_bytes());
        // 2-byte segment header + 3 data bytes, padded to BLOB_SEGHDR_ALIGN
        esperado.extend_from_slice(&6u32.to_be_bytes());
        esperado.extend_from_slice(&0u32.to_be_bytes());
        esperado.extend_from_slice(&3u32.to_be_bytes());
        esperado.extend_from_slice(b"abc");
        assert_eq!(b.pending_blobs, esperado);
        assert_eq!(b.blob_stream_len(), segmented_entry_len(3));
        assert_eq!(b.blob_stream_len(), 24);
    }

    #[test]
    fn set_segmented_false_mantem_modo_stream() {
        let mut b = fake_blob_batch();
        b.set_segmented(false).unwrap();
        assert!(!b.blob_segmented);
        assert!(b.default_bpb.is_some());
        b.add_blob(b"abc").unwrap();
        assert_eq!(b.blob_stream_len(), entry_len(3));
    }

    #[test]
    fn apis_de_blob_exigem_coluna_blob() {
        let mut b = fake_blob_batch();
        b.blob_stream = false;
        assert!(b.set_default_bpb(vec![BPB_VERSION1]).is_err());
        assert!(b.set_segmented(true).is_err());
        assert!(b.register_blob(42).is_err());
    }

    #[test]
    fn register_blob_compartilha_contador_com_add_blob() {
        let mut b = fake_blob_batch();
        assert_eq!(b.add_blob(b"x").unwrap(), 1);
        assert_eq!(b.register_blob(0xABCD).unwrap(), 2);
        assert_eq!(b.pending_regblobs, vec![(0xABCD, 2)]);
        assert_eq!(b.add_blob(b"y").unwrap(), 3);
    }

    #[test]
    fn batch_options_buffer_bytes() {
        assert_eq!(BatchOptions::new().buffer_bytes, None);
        let opts = BatchOptions::new().buffer_bytes(16 * 1024 * 1024);
        assert_eq!(opts.buffer_bytes, Some(16 * 1024 * 1024));
    }
}