use crate::blr::{input_blr, message_blr, prepare_info_items};
use crate::connection::Connection;
use crate::error::{Error, Result};
use crate::message::{decode_row, encode_row};
use crate::transaction::Transaction;
use crate::value::{ColumnMeta, Value};
use crate::wire::consts::*;
use crate::wire::response::{read_op, read_response, read_response_body};
use crate::wire::stream::{op_name, op_packet};
use crate::wire::xdr::{read_le_int, read_le_int_signed};
/// SQL dialect sent with every prepare request.
const SQL_DIALECT: i32 = 3;
/// Size in bytes of the describe-info buffer requested when preparing (0xFB80).
const INFO_BUFFER_LEN: i32 = 64_384;
/// Default number of rows requested per fetch round trip.
const FETCH_BATCH: i32 = 200;
/// Size in bytes of the info buffer requested for row-count queries.
const RECORDS_BUFFER_LEN: i32 = 64;
/// Record counters for a statement, as returned by its `RECORDS` info item.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RowsAffected {
    /// Rows read by a select.
    pub selected: u64,
    /// Rows inserted.
    pub inserted: u64,
    /// Rows updated.
    pub updated: u64,
    /// Rows deleted.
    pub deleted: u64,
}

impl RowsAffected {
    /// Total number of rows changed: inserted + updated + deleted.
    /// Selected rows are not included.
    pub fn total_modified(&self) -> u64 {
        [self.inserted, self.updated, self.deleted].iter().sum()
    }
}
/// A prepared statement on the server, plus its client-side cursor and
/// row-buffer state.
///
/// If dropped without [`Statement::drop_statement`] (or without its handle
/// being forgotten), the `Drop` impl reports it as unclosed.
#[derive(Debug)]
pub struct Statement {
    /// Server-side statement handle obtained when the statement was allocated.
    handle: i32,
    /// Statement type code reported in the prepare describe-info.
    stmt_type: i32,
    /// Descriptions of the input parameters.
    params: Vec<ColumnMeta>,
    /// Descriptions of the output columns.
    columns: Vec<ColumnMeta>,
    /// Client-side record of whether a result cursor is open; `execute` sets
    /// it for SELECT-type statements.
    cursor_open: bool,
    /// Rows already received from the server but not yet handed to the caller.
    buffered: std::collections::VecDeque<Vec<Value>>,
    /// Set once the server reports that no further rows are available.
    exhausted: bool,
    /// Whether `execute` requests a scrollable cursor.
    scrollable: bool,
    /// Number of rows requested per fetch round trip.
    fetch_size: i32,
    /// Set once the handle has been dropped on the server or forgotten;
    /// suppresses the unclosed-statement warning in `Drop`.
    dropped: bool,
}
impl Statement {
    /// Fetch status the server reports once the cursor has no further rows.
    const STATUS_EOF: i32 = 100;

    /// Server-side handle of this statement.
    pub fn handle(&self) -> i32 {
        self.handle
    }

    /// Statement type code reported by the server when the statement was
    /// prepared (compare against the `stmt_type` constants).
    pub fn stmt_type(&self) -> i32 {
        self.stmt_type
    }

    /// Whether this is a `SELECT` or `SELECT ... FOR UPDATE` statement.
    /// For these, [`execute`](Self::execute) opens a cursor.
    pub fn is_select(&self) -> bool {
        self.stmt_type == stmt_type::SELECT || self.stmt_type == stmt_type::SELECT_FOR_UPD
    }

    /// Requests (or stops requesting) a scrollable cursor. Takes effect on the
    /// next [`execute`](Self::execute).
    pub fn set_scrollable(&mut self, yes: bool) {
        self.scrollable = yes;
    }

    /// Whether the next execution will request a scrollable cursor.
    pub fn is_scrollable(&self) -> bool {
        self.scrollable
    }

    /// Sets the number of rows requested per fetch round trip.
    /// Values below 1 are clamped to 1.
    pub fn set_fetch_size(&mut self, n: i32) {
        self.fetch_size = n.max(1);
    }

    /// Number of rows requested per fetch round trip.
    pub fn fetch_size(&self) -> i32 {
        self.fetch_size
    }

    /// Descriptions of the result columns.
    pub fn columns(&self) -> &[ColumnMeta] {
        &self.columns
    }

    /// Descriptions of the input parameters.
    pub fn params(&self) -> &[ColumnMeta] {
        &self.params
    }

    /// Executes the prepared statement inside `tx`, binding `params` to its
    /// input parameters.
    ///
    /// For SELECT-type statements this opens a cursor; read its rows with
    /// [`fetch`](Self::fetch) and related methods. If a cursor from a previous
    /// execution is still open, it is closed first, so re-execution always
    /// starts from a closed server-side cursor.
    ///
    /// # Errors
    /// Fails if `params` cannot be encoded, on I/O errors, or if the server
    /// rejects the close or the execution.
    pub fn execute(
        &mut self,
        conn: &mut Connection,
        tx: &Transaction,
        params: &[Value],
    ) -> Result<()> {
        let has_params = !self.params.is_empty();
        let (in_blr, message) = if has_params {
            (
                input_blr(&self.params),
                encode_row(&self.params, params, conn.charset())?,
            )
        } else {
            (Vec::new(), Vec::new())
        };

        // The previous cursor stays open on the server until explicitly
        // closed, even after all its rows were fetched.
        if self.cursor_open {
            self.close(conn)?;
        }

        let mut w = op_packet(op::EXECUTE);
        w.put_i32(self.handle);
        w.put_i32(tx.handle());
        w.put_bytes(&in_blr);
        // NOTE(review): presumably message number and message count; confirm
        // against the wire spec.
        w.put_i32(0);
        w.put_i32(if has_params { 1 } else { 0 });
        if has_params {
            w.put_raw(&message);
            w.align();
        }
        // NOTE(review): the trailing fields are assumed to follow the server's
        // op_execute layout for the negotiated protocol version.
        w.put_bytes(&[]);
        w.put_i32(if self.scrollable {
            cursor_type::SCROLLABLE
        } else {
            0
        });
        if conn.protocol_version() >= 18 {
            w.put_i32(0);
        }
        conn.io().send(&w)?;
        read_response(conn.io())?;
        self.cursor_open = self.is_select();
        self.buffered.clear();
        self.exhausted = false;
        Ok(())
    }

    /// Returns the next row of the open cursor.
    ///
    /// Returns `Ok(None)` once the cursor is exhausted or when no cursor is
    /// open. Rows are requested from the server in batches of
    /// [`fetch_size`](Self::fetch_size) and handed out one at a time.
    ///
    /// Reaching the end does *not* mark the cursor closed. The server-side
    /// cursor is only released by [`close`](Self::close), by re-execution, or
    /// by dropping the statement.
    pub fn fetch(&mut self, conn: &mut Connection) -> Result<Option<Vec<Value>>> {
        loop {
            if let Some(row) = self.buffered.pop_front() {
                return Ok(Some(row));
            }
            // `cursor_open` is deliberately left untouched here. Clearing it
            // would make `close()` skip the server-side close.
            if self.exhausted || !self.cursor_open {
                return Ok(None);
            }
            self.fetch_batch(conn)?;
        }
    }

    /// Requests up to `fetch_size` rows and appends them to `buffered`.
    /// Sets `exhausted` when the server reports end of cursor.
    fn fetch_batch(&mut self, conn: &mut Connection) -> Result<()> {
        let out_blr = message_blr(&self.columns);
        let mut w = op_packet(op::FETCH);
        w.put_i32(self.handle);
        w.put_bytes(&out_blr);
        w.put_i32(0);
        w.put_i32(self.fetch_size);
        conn.io().send(&w)?;
        loop {
            let (status, count) = Self::read_fetch_header(conn)?;
            if count == 0 {
                self.exhausted = status == Self::STATUS_EOF;
                return Ok(());
            }
            let cs = conn.charset();
            let row = decode_row(conn.io(), &self.columns, cs)?;
            self.buffered.push_back(row);
            if status == Self::STATUS_EOF {
                self.exhausted = true;
                return Ok(());
            }
        }
    }

    /// Reads the header of the next fetch response and returns
    /// `(status, count)`.
    ///
    /// If the server replies with a plain `op_response` instead, its body is
    /// decoded so a server error surfaces as-is. Any other opcode, or a
    /// non-error `op_response`, is reported as a protocol error.
    fn read_fetch_header(conn: &mut Connection) -> Result<(i32, i32)> {
        let code = read_op(conn.io())?;
        if code != op::FETCH_RESPONSE {
            if code == op::RESPONSE {
                read_response_body(conn.io())?.into_result()?;
            }
            return Err(Error::protocol(format!(
                "expected op_fetch_response, got {} ({code})",
                op_name(code)
            )));
        }
        let status = conn.io().read_i32()?;
        let count = conn.io().read_i32()?;
        Ok((status, count))
    }

    /// Fetches every remaining row of the open cursor.
    pub fn fetch_all(&mut self, conn: &mut Connection) -> Result<Vec<Vec<Value>>> {
        let mut rows = Vec::new();
        while let Some(row) = self.fetch(conn)? {
            rows.push(row);
        }
        Ok(rows)
    }

    /// Returns a [`RowStream`] that fetches rows of this statement through
    /// `conn`.
    pub fn rows<'a>(&'a mut self, conn: &'a mut Connection) -> RowStream<'a> {
        RowStream { stmt: self, conn }
    }

    /// Moves a scrollable cursor and returns the row at the new position.
    ///
    /// `direction` is one of the `scroll` constants. `offset` is the position
    /// or distance used by absolute/relative moves (the other helpers pass 0).
    /// Rows prefetched by [`fetch`](Self::fetch) are discarded first. Returns
    /// `Ok(None)` if no cursor is open or the server returned no row.
    pub fn fetch_scroll(
        &mut self,
        conn: &mut Connection,
        direction: i32,
        offset: i32,
    ) -> Result<Option<Vec<Value>>> {
        if !self.cursor_open {
            return Ok(None);
        }
        self.buffered.clear();
        let out_blr = message_blr(&self.columns);
        let mut w = op_packet(op::FETCH_SCROLL);
        w.put_i32(self.handle);
        w.put_bytes(&out_blr);
        w.put_i32(0);
        w.put_i32(1);
        w.put_i32(direction);
        w.put_i32(offset);
        conn.io().send(&w)?;
        let mut row = None;
        loop {
            let (status, count) = Self::read_fetch_header(conn)?;
            if count == 0 {
                break;
            }
            let cs = conn.charset();
            let r = decode_row(conn.io(), &self.columns, cs)?;
            // Keep only the first row. Later ones are still decoded so the
            // stream stays in sync.
            if row.is_none() {
                row = Some(r);
            }
            if status == Self::STATUS_EOF {
                break;
            }
        }
        // After repositioning, a plain `fetch` may continue from here.
        self.exhausted = false;
        Ok(row)
    }

    /// Scrolls to the next row.
    pub fn fetch_next(&mut self, conn: &mut Connection) -> Result<Option<Vec<Value>>> {
        self.fetch_scroll(conn, scroll::NEXT, 0)
    }

    /// Scrolls to the previous row.
    pub fn fetch_prior(&mut self, conn: &mut Connection) -> Result<Option<Vec<Value>>> {
        self.fetch_scroll(conn, scroll::PRIOR, 0)
    }

    /// Scrolls to the first row.
    pub fn fetch_first(&mut self, conn: &mut Connection) -> Result<Option<Vec<Value>>> {
        self.fetch_scroll(conn, scroll::FIRST, 0)
    }

    /// Scrolls to the last row.
    pub fn fetch_last(&mut self, conn: &mut Connection) -> Result<Option<Vec<Value>>> {
        self.fetch_scroll(conn, scroll::LAST, 0)
    }

    /// Scrolls to absolute position `pos`.
    pub fn fetch_absolute(
        &mut self,
        conn: &mut Connection,
        pos: i32,
    ) -> Result<Option<Vec<Value>>> {
        self.fetch_scroll(conn, scroll::ABSOLUTE, pos)
    }

    /// Scrolls by `offset` rows relative to the current position.
    pub fn fetch_relative(
        &mut self,
        conn: &mut Connection,
        offset: i32,
    ) -> Result<Option<Vec<Value>>> {
        self.fetch_scroll(conn, scroll::RELATIVE, offset)
    }

    /// Queries the server's record counters for this statement via an
    /// `isql::RECORDS` info request.
    pub fn rows_affected(&self, conn: &mut Connection) -> Result<RowsAffected> {
        let w = crate::connection::info_request(
            op::INFO_SQL,
            self.handle,
            &[isql::RECORDS],
            RECORDS_BUFFER_LEN,
        );
        conn.io().send(&w)?;
        let resp = read_response(conn.io())?;
        Ok(parse_records(&resp.data))
    }

    /// Closes the open cursor on the server (`free::CLOSE`) and discards any
    /// prefetched rows.
    ///
    /// Does nothing if no cursor is open. If the close fails, the cursor is
    /// still considered open and its buffered rows are kept.
    pub fn close(&mut self, conn: &mut Connection) -> Result<()> {
        if !self.cursor_open {
            return Ok(());
        }
        self.free(conn, free::CLOSE)?;
        self.cursor_open = false;
        // Rows prefetched from the closed cursor must not be returned later.
        self.buffered.clear();
        Ok(())
    }

    /// Frees the statement on the server (`free::DROP`), consuming it.
    ///
    /// If the request fails, the statement is still dropped locally and its
    /// `Drop` impl reports it as unclosed.
    pub fn drop_statement(mut self, conn: &mut Connection) -> Result<()> {
        self.free(conn, free::DROP)?;
        self.dropped = true;
        Ok(())
    }

    /// Sends `op_free_statement` with the given `free::*` mode and waits for
    /// the server's response.
    fn free(&mut self, conn: &mut Connection, mode: i32) -> Result<()> {
        let mut w = op_packet(op::FREE_STATEMENT);
        w.put_i32(self.handle);
        w.put_i32(mode);
        conn.io().send(&w)?;
        read_response(conn.io())?;
        Ok(())
    }

    /// Marks the handle as released without contacting the server, which
    /// silences the unclosed-statement warning.
    /// NOTE(review): presumably used when the connection itself is going away;
    /// confirm with callers.
    pub(crate) fn forget_handle(&mut self) {
        self.dropped = true;
    }
}
impl Drop for Statement {
    fn drop(&mut self) {
        // Only a handle that was neither dropped on the server nor explicitly
        // forgotten is reported as unclosed.
        if self.dropped {
            return;
        }
        crate::warn_unclosed("Statement", self.handle);
    }
}
/// Pairs a [`Statement`] with the [`Connection`] it fetches through, so rows
/// can be pulled without passing the connection on every call.
///
/// Created by [`Statement::rows`].
pub struct RowStream<'a> {
    stmt: &'a mut Statement,
    conn: &'a mut Connection,
}

impl RowStream<'_> {
    /// Returns the next row, or `Ok(None)` once the statement has no more
    /// rows to return.
    pub fn try_next(&mut self) -> Result<Option<Vec<Value>>> {
        let conn = &mut *self.conn;
        self.stmt.fetch(conn)
    }

    /// Alias for [`try_next`](Self::try_next).
    // Not `Iterator::next`: it yields `Result<Option<_>>` rather than
    // `Option<_>`, so the lint's suggestion does not apply.
    #[allow(clippy::should_implement_trait)]
    pub fn next(&mut self) -> Result<Option<Vec<Value>>> {
        self.try_next()
    }

    /// Collects all remaining rows into a `Vec`, stopping at the first error.
    pub fn try_collect(mut self) -> Result<Vec<Vec<Value>>> {
        let mut collected = Vec::new();
        loop {
            match self.try_next()? {
                Some(row) => collected.push(row),
                None => return Ok(collected),
            }
        }
    }

    /// Calls `f` on each remaining row in order. Stops at the first error,
    /// whether it comes from the fetch or from `f`.
    pub fn try_for_each<F>(mut self, mut f: F) -> Result<()>
    where
        F: FnMut(Vec<Value>) -> Result<()>,
    {
        loop {
            match self.try_next()? {
                Some(row) => f(row)?,
                None => return Ok(()),
            }
        }
    }
}
impl Connection {
pub fn prepare(&mut self, tx: &Transaction, sql: &str) -> Result<Statement> {
let mut w = op_packet(op::ALLOCATE_STATEMENT);
w.put_i32(self.db_handle());
self.io().send(&w)?;
let handle = read_response(self.io())?.handle;
let mut w = op_packet(op::PREPARE_STATEMENT);
w.put_i32(tx.handle());
w.put_i32(handle);
w.put_i32(SQL_DIALECT);
w.put_str(sql);
w.put_bytes(prepare_info_items());
w.put_i32(INFO_BUFFER_LEN);
self.io().send(&w)?;
let resp = read_response(self.io())?;
let info = parse_prepare_response(&resp.data)?;
Ok(Statement {
handle,
stmt_type: info.stmt_type,
params: info.params,
columns: info.columns,
cursor_open: false,
buffered: std::collections::VecDeque::new(),
exhausted: false,
scrollable: false,
fetch_size: FETCH_BATCH,
dropped: false,
})
}
}
/// Result of decoding a prepare describe-info buffer.
struct PreparedInfo {
    /// Statement type code; 0 if the buffer did not report one.
    stmt_type: i32,
    /// Input parameters, in the order the `BIND` section described them.
    params: Vec<ColumnMeta>,
    /// Output columns, in the order the `SELECT` section described them.
    columns: Vec<ColumnMeta>,
}
/// Section of the prepare describe-info currently being decoded.
#[derive(PartialEq, Clone, Copy)]
enum Block {
    /// No `BIND`/`SELECT` marker seen yet; completed descriptions are discarded.
    None,
    /// Input-parameter descriptions (after `isql::BIND`).
    Bind,
    /// Output-column descriptions (after `isql::SELECT`).
    Select,
}
fn parse_prepare_response(data: &[u8]) -> Result<PreparedInfo> {
let mut stmt_type = 0;
let mut params = Vec::new();
let mut columns = Vec::new();
let mut block = Block::None;
let mut cur: Option<ColumnMeta> = None;
let mut i = 0;
while i < data.len() {
let tag = data[i];
i += 1;
match tag {
INFO_END => break,
INFO_TRUNCATED => {
return Err(Error::protocol(
"prepare describe-info truncated; buffer too small",
));
}
isql::SELECT => block = Block::Select,
isql::BIND => block = Block::Bind,
isql::DESCRIBE_END => {
if let Some(c) = cur.take() {
match block {
Block::Bind => params.push(c),
Block::Select => columns.push(c),
Block::None => {}
}
}
}
_ => {
if i + 2 > data.len() {
return Err(Error::protocol("prepare describe-info: short length"));
}
let len = u16::from_le_bytes([data[i], data[i + 1]]) as usize;
i += 2;
if i + len > data.len() {
return Err(Error::protocol("prepare describe-info: short value"));
}
let val = &data[i..i + len];
i += len;
apply_info_item(tag, val, &mut stmt_type, &mut cur);
}
}
}
Ok(PreparedInfo {
stmt_type,
params,
columns,
})
}
/// Applies one length-prefixed describe item to the decoding state.
///
/// - `isql::STMT_TYPE` sets `stmt_type`.
/// - `isql::SQLDA_SEQ` starts a new column, replacing any pending one; its
///   1-based sequence number becomes a 0-based `index`.
/// - Type, sub-type, scale, length and the name items update the pending
///   column and are ignored when none is pending.
/// - All other tags are ignored.
fn apply_info_item(tag: u8, val: &[u8], stmt_type: &mut i32, cur: &mut Option<ColumnMeta>) {
    match tag {
        isql::STMT_TYPE => *stmt_type = read_le_int(val) as i32,
        isql::SQLDA_SEQ => {
            let index = (read_le_int(val) as usize).saturating_sub(1);
            *cur = Some(ColumnMeta {
                index,
                ..Default::default()
            });
        }
        isql::TYPE | isql::SUB_TYPE | isql::SCALE | isql::LENGTH => {
            if let Some(column) = cur.as_mut() {
                match tag {
                    isql::TYPE => {
                        let raw_type = read_le_int(val) as i32;
                        column.sql_type = raw_type;
                        column.nullable = sql_type::is_nullable(raw_type);
                    }
                    // Sub-type and scale may be negative, so they are sign-extended.
                    isql::SUB_TYPE => column.sub_type = read_le_int_signed(val) as i32,
                    isql::SCALE => column.scale = read_le_int_signed(val) as i32,
                    _ => column.length = read_le_int(val) as i32,
                }
            }
        }
        isql::FIELD => set_name(cur, val, |c, s| c.field = s),
        isql::RELATION => set_name(cur, val, |c, s| c.relation = s),
        isql::ALIAS => set_name(cur, val, |c, s| c.alias = s),
        isql::OWNER => set_name(cur, val, |c, s| c.owner = s),
        _ => {}
    }
}
/// Decodes `val` as UTF-8 (invalid sequences are replaced) and stores it in
/// the pending column via `assign`. Does nothing when no column is pending.
fn set_name(cur: &mut Option<ColumnMeta>, val: &[u8], assign: impl Fn(&mut ColumnMeta, String)) {
    if let Some(column) = cur {
        let text = String::from_utf8_lossy(val).into_owned();
        assign(column, text);
    }
}
/// Extracts record counters from the response to an `isql::RECORDS` info
/// request.
///
/// The `RECORDS` item wraps a nested item list keyed by the
/// `info_req::*_COUNT` tags. Unrecognised tags at either level are skipped,
/// so counters the server did not report stay at 0.
fn parse_records(data: &[u8]) -> RowsAffected {
    InfoItems::new(data)
        .filter(|&(tag, _)| tag == isql::RECORDS)
        .flat_map(|(_, nested)| InfoItems::new(nested))
        .fold(RowsAffected::default(), |mut counts, (counter, raw)| {
            let n = read_le_int(raw) as u64;
            match counter {
                info_req::SELECT_COUNT => counts.selected = n,
                info_req::INSERT_COUNT => counts.inserted = n,
                info_req::UPDATE_COUNT => counts.updated = n,
                info_req::DELETE_COUNT => counts.deleted = n,
                _ => {}
            }
            counts
        })
}
struct InfoItems<'a> {
data: &'a [u8],
pos: usize,
}
impl<'a> InfoItems<'a> {
fn new(data: &'a [u8]) -> Self {
InfoItems { data, pos: 0 }
}
}
impl<'a> Iterator for InfoItems<'a> {
type Item = (u8, &'a [u8]);
fn next(&mut self) -> Option<Self::Item> {
let tag = *self.data.get(self.pos)?;
if tag == INFO_END {
return None;
}
self.pos += 1;
let lo = *self.data.get(self.pos)? as usize;
let hi = *self.data.get(self.pos + 1)? as usize;
let len = lo | (hi << 8);
self.pos += 2;
let val = self.data.get(self.pos..self.pos + len)?;
self.pos += len;
Some((tag, val))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes one info item: tag byte, little-endian `u16` length, then the value.
    fn item(tag: u8, val: &[u8]) -> Vec<u8> {
        let mut v = vec![tag];
        v.extend_from_slice(&(val.len() as u16).to_le_bytes());
        v.extend_from_slice(val);
        v
    }

    /// Describe buffer for a two-column SELECT with no input parameters.
    #[test]
    fn parses_select_describe_for_two_columns() {
        let mut data = Vec::new();
        data.extend(item(isql::STMT_TYPE, &stmt_type::SELECT.to_le_bytes()));
        // Empty BIND (parameter) section.
        data.push(isql::BIND);
        data.extend(item(isql::DESCRIBE_VARS, &0i32.to_le_bytes()));
        // SELECT (column) section with two columns.
        data.push(isql::SELECT);
        data.extend(item(isql::DESCRIBE_VARS, &2i32.to_le_bytes()));
        // Column 1: `| 1` on the type code makes it nullable, as asserted below.
        data.extend(item(isql::SQLDA_SEQ, &1i32.to_le_bytes()));
        data.extend(item(isql::TYPE, &(sql_type::SHORT | 1).to_le_bytes()));
        data.extend(item(isql::SUB_TYPE, &0i32.to_le_bytes()));
        data.extend(item(isql::SCALE, &0i32.to_le_bytes()));
        data.extend(item(isql::LENGTH, &2i32.to_le_bytes()));
        data.extend(item(isql::FIELD, b"EMP_NO"));
        data.extend(item(isql::ALIAS, b"EMP_NO"));
        data.push(isql::DESCRIBE_END);
        // Column 2: plain (non-nullable) VARYING of length 15.
        data.extend(item(isql::SQLDA_SEQ, &2i32.to_le_bytes()));
        data.extend(item(isql::TYPE, &sql_type::VARYING.to_le_bytes()));
        data.extend(item(isql::SUB_TYPE, &0i32.to_le_bytes()));
        data.extend(item(isql::SCALE, &0i32.to_le_bytes()));
        data.extend(item(isql::LENGTH, &15i32.to_le_bytes()));
        data.extend(item(isql::FIELD, b"FIRST_NAME"));
        data.extend(item(isql::ALIAS, b"FIRST_NAME"));
        data.push(isql::DESCRIBE_END);
        data.push(INFO_END);
        let info = parse_prepare_response(&data).unwrap();
        assert_eq!(info.stmt_type, stmt_type::SELECT);
        assert!(info.params.is_empty());
        assert_eq!(info.columns.len(), 2);
        let emp_no = &info.columns[0];
        // The 1-based SQLDA_SEQ becomes a 0-based index.
        assert_eq!(emp_no.index, 0);
        assert_eq!(sql_type::base(emp_no.sql_type), sql_type::SHORT);
        assert!(emp_no.nullable);
        assert_eq!(emp_no.name(), "EMP_NO");
        let first_name = &info.columns[1];
        assert_eq!(sql_type::base(first_name.sql_type), sql_type::VARYING);
        assert_eq!(first_name.length, 15);
        assert!(!first_name.nullable);
        assert_eq!(first_name.name(), "FIRST_NAME");
    }

    /// A bare truncation marker must be reported as an error, not as an empty describe.
    #[test]
    fn truncated_info_is_an_error() {
        let data = [INFO_TRUNCATED];
        assert!(parse_prepare_response(&data).is_err());
    }

    /// A `RECORDS` item wrapping the four nested counters.
    #[test]
    fn parses_record_counts() {
        /// Nested counter item: tag, length 4 (LE u16), then the LE i32 value.
        fn sub(tag: u8, n: i32) -> Vec<u8> {
            let mut v = vec![tag, 4, 0];
            v.extend_from_slice(&n.to_le_bytes());
            v
        }
        let mut nested = Vec::new();
        nested.extend(sub(info_req::SELECT_COUNT, 5));
        nested.extend(sub(info_req::INSERT_COUNT, 0));
        nested.extend(sub(info_req::UPDATE_COUNT, 5));
        nested.extend(sub(info_req::DELETE_COUNT, 0));
        let mut data = vec![isql::RECORDS];
        data.extend_from_slice(&(nested.len() as u16).to_le_bytes());
        data.extend_from_slice(&nested);
        data.push(INFO_END);
        let r = parse_records(&data);
        assert_eq!(r.selected, 5);
        assert_eq!(r.updated, 5);
        assert_eq!(r.inserted, 0);
        assert_eq!(r.deleted, 0);
        // Selected rows are not part of the modified total.
        assert_eq!(r.total_modified(), 5);
    }

    /// Scale is decoded with the signed reader, so -2 survives the round trip.
    #[test]
    fn negative_scale_is_sign_extended() {
        let mut data = Vec::new();
        data.push(isql::SELECT);
        data.extend(item(isql::SQLDA_SEQ, &1i32.to_le_bytes()));
        data.extend(item(isql::TYPE, &sql_type::INT64.to_le_bytes()));
        data.extend(item(isql::SCALE, &(-2i32).to_le_bytes()));
        data.extend(item(isql::LENGTH, &8i32.to_le_bytes()));
        data.push(isql::DESCRIBE_END);
        data.push(INFO_END);
        let info = parse_prepare_response(&data).unwrap();
        assert_eq!(info.columns[0].scale, -2);
    }
}