use bytes::BytesMut;
use qail_core::ast::{Action, AggregateFunc, Expr, Qail};
use super::dml::{encode_delete, encode_export, encode_insert, encode_select, encode_update};
use crate::protocol::EncodeError;
/// Rejects SQL text that contains a NUL byte.
///
/// The SQL is later framed as a NUL-terminated string, so an embedded NUL
/// would silently truncate it.
///
/// # Errors
/// [`EncodeError::NullByte`] if any byte of `sql` is `0`.
#[inline(always)]
fn validate_sql_bytes(sql: &[u8]) -> Result<(), EncodeError> {
    match sql.iter().position(|&byte| byte == 0) {
        Some(_) => Err(EncodeError::NullByte),
        None => Ok(()),
    }
}
/// Narrows a parameter count to the protocol's signed 16-bit count field.
///
/// # Errors
/// [`EncodeError::TooManyParameters`] carrying `count` when it exceeds `i16::MAX`.
#[inline(always)]
fn checked_i16_count(count: usize) -> Result<i16, EncodeError> {
    if count <= i16::MAX as usize {
        // Lossless: the guard above bounds `count` to 0..=i16::MAX.
        Ok(count as i16)
    } else {
        Err(EncodeError::TooManyParameters(count))
    }
}
/// Narrows a byte length to the protocol's signed 32-bit length field.
///
/// # Errors
/// [`EncodeError::MessageTooLarge`] carrying `len` when it exceeds `i32::MAX`.
#[inline(always)]
fn checked_i32_len(len: usize) -> Result<i32, EncodeError> {
    match i32::try_from(len) {
        Ok(narrowed) => Ok(narrowed),
        Err(_) => Err(EncodeError::MessageTooLarge(len)),
    }
}
/// Computes a message's `i32` length field from its payload length.
///
/// The length field counts its own 4 bytes in addition to `content_len`.
///
/// # Errors
/// - `MessageTooLarge(usize::MAX)` if adding the 4 header bytes overflows `usize`.
/// - `MessageTooLarge(total)` if the total does not fit in an `i32`.
#[inline(always)]
fn checked_wire_len(content_len: usize) -> Result<i32, EncodeError> {
    let Some(total) = content_len.checked_add(4) else {
        return Err(EncodeError::MessageTooLarge(usize::MAX));
    };
    i32::try_from(total).map_err(|_| EncodeError::MessageTooLarge(total))
}
/// Returns how many bytes the parameter values occupy inside a Bind message.
///
/// Each parameter costs a 4-byte length prefix plus its data. A `None`
/// (NULL) parameter contributes only the prefix.
///
/// # Errors
/// - `MessageTooLarge(len)` if one value is longer than `i32::MAX` bytes.
/// - `MessageTooLarge(usize::MAX)` if the running total overflows `usize`.
#[inline(always)]
fn params_wire_size(params: &[Option<Vec<u8>>]) -> Result<usize, EncodeError> {
    let mut running_total: usize = 0;
    for param in params {
        let value_len = match param {
            Some(bytes) => bytes.len(),
            None => 0,
        };
        if value_len > i32::MAX as usize {
            return Err(EncodeError::MessageTooLarge(value_len));
        }
        running_total = value_len
            .checked_add(4)
            .and_then(|field_size| running_total.checked_add(field_size))
            .ok_or(EncodeError::MessageTooLarge(usize::MAX))?;
    }
    Ok(running_total)
}
/// Size in bytes of the result-format section that `write_result_formats` emits.
///
/// For `result_format == 0` the section is only the `i16` count (2 bytes).
/// For any other value it is the count plus one `i16` code (4 bytes).
#[inline(always)]
fn result_format_wire_len(result_format: i16) -> usize {
    match result_format {
        0 => 2,
        _ => 4,
    }
}
/// Appends the Bind message's result-column format section to `buf`.
///
/// - `result_format == 0`: writes a code count of `0` and no codes.
/// - Any other value: writes a count of `1` followed by `result_format`.
///   PostgreSQL applies a single code to every result column.
#[inline(always)]
fn write_result_formats(buf: &mut BytesMut, result_format: i16) {
    let code_count: i16 = if result_format == 0 { 0 } else { 1 };
    buf.extend_from_slice(&code_count.to_be_bytes());
    if code_count == 1 {
        buf.extend_from_slice(&result_format.to_be_bytes());
    }
}
/// Builds a fresh buffer holding the extended-query message sequence for
/// `sql` and `params`, using result format code `0`.
///
/// # Errors
/// Same as [`build_extended_query_with_result_format`].
pub fn build_extended_query(
    sql: &[u8],
    params: &[Option<Vec<u8>>],
) -> Result<BytesMut, EncodeError> {
    let default_result_format: i16 = 0;
    build_extended_query_with_result_format(sql, params, default_result_format)
}
/// Builds a fresh buffer holding one Parse / Bind / Describe / Execute / Sync
/// sequence for `sql` and `params`, requesting `result_format` for result
/// columns.
///
/// This is the allocating counterpart of
/// [`build_extended_query_into_with_result_format`]. That function performs
/// all validation and reserves the buffer's capacity itself, so nothing is
/// pre-computed here.
///
/// # Errors
/// - [`EncodeError::TooManyParameters`] if `params` has more than `i16::MAX`
///   entries.
/// - [`EncodeError::NullByte`] if `sql` contains a NUL byte.
/// - [`EncodeError::MessageTooLarge`] if a value or message length does not
///   fit the protocol's `i32` length fields.
pub fn build_extended_query_with_result_format(
    sql: &[u8],
    params: &[Option<Vec<u8>>],
    result_format: i16,
) -> Result<BytesMut, EncodeError> {
    // Previously the SQL was scanned for NULs and the params were sized here
    // and then again inside the `_into` call. The estimate here was also one
    // byte short (Describe counted as 6 bytes instead of 7).
    let mut buf = BytesMut::new();
    build_extended_query_into_with_result_format(&mut buf, sql, params, result_format)?;
    Ok(buf)
}
/// Clears `buf` and fills it with the extended-query message sequence for
/// `sql` and `params`, using result format code `0`.
///
/// # Errors
/// Same as [`build_extended_query_into_with_result_format`].
pub fn build_extended_query_into(
    buf: &mut BytesMut,
    sql: &[u8],
    params: &[Option<Vec<u8>>],
) -> Result<(), EncodeError> {
    let default_result_format: i16 = 0;
    build_extended_query_into_with_result_format(buf, sql, params, default_result_format)
}
/// Encodes one parameterized query into `buf` as a single pipelined
/// Parse / Bind / Describe(portal) / Execute / Sync sequence.
///
/// - `buf` is cleared, then filled with the complete sequence.
/// - The prepared statement and the portal are both unnamed.
/// - Bind carries zero parameter format codes.
/// - A `None` parameter is written with length `-1` (SQL NULL) and no bytes.
/// - `result_format` is written by [`write_result_formats`].
///
/// All validation and length computation happens before `buf` is touched,
/// so on error `buf` is left exactly as the caller passed it.
///
/// # Errors
/// - [`EncodeError::TooManyParameters`] if `params` has more than `i16::MAX`
///   entries.
/// - [`EncodeError::NullByte`] if `sql` contains a NUL byte.
/// - [`EncodeError::MessageTooLarge`] if a value or message length does not
///   fit the protocol's `i32` length fields.
pub fn build_extended_query_into_with_result_format(
    buf: &mut BytesMut,
    sql: &[u8],
    params: &[Option<Vec<u8>>],
    result_format: i16,
) -> Result<(), EncodeError> {
    let param_count = checked_i16_count(params.len())?;
    validate_sql_bytes(sql)?;
    let params_size = params_wire_size(params)?;
    let result_formats_size = result_format_wire_len(result_format);

    // Exact number of bytes appended below. Each message is a type byte,
    // an i32 length, and a payload:
    //   Parse    = 9 + sql
    //   Bind     = 11 + params + result formats
    //   Describe = 7
    //   Execute  = 10
    //   Sync     = 5
    let total_size = 9usize
        .checked_add(sql.len())
        .and_then(|v| v.checked_add(11))
        .and_then(|v| v.checked_add(params_size))
        .and_then(|v| v.checked_add(result_formats_size))
        .and_then(|v| v.checked_add(7 + 10 + 5))
        .ok_or(EncodeError::MessageTooLarge(usize::MAX))?;

    // Parse payload: statement name "\0", sql, "\0", i16 parameter-type count.
    let parse_len = checked_wire_len(
        sql.len()
            .checked_add(1 + 1 + 2)
            .ok_or(EncodeError::MessageTooLarge(usize::MAX))?,
    )?;
    // Bind payload: portal "\0", statement "\0", i16 parameter-format count,
    // i16 parameter count, the parameter values, then the result formats.
    let bind_len = checked_wire_len(
        (1 + 1 + 2 + 2usize)
            .checked_add(params_size)
            .and_then(|v| v.checked_add(result_formats_size))
            .ok_or(EncodeError::MessageTooLarge(usize::MAX))?,
    )?;

    buf.clear();
    buf.reserve(total_size);

    // Parse: unnamed statement, zero parameter type OIDs.
    buf.extend_from_slice(b"P");
    buf.extend_from_slice(&parse_len.to_be_bytes());
    buf.extend_from_slice(&[0]);
    buf.extend_from_slice(sql);
    buf.extend_from_slice(&[0]);
    buf.extend_from_slice(&0i16.to_be_bytes());

    // Bind: unnamed portal bound to the unnamed statement, zero parameter
    // format codes.
    buf.extend_from_slice(b"B");
    buf.extend_from_slice(&bind_len.to_be_bytes());
    buf.extend_from_slice(&[0]);
    buf.extend_from_slice(&[0]);
    buf.extend_from_slice(&0i16.to_be_bytes());
    buf.extend_from_slice(&param_count.to_be_bytes());
    for param in params {
        match param {
            // Length -1 with no value bytes encodes SQL NULL.
            None => buf.extend_from_slice(&(-1i32).to_be_bytes()),
            Some(data) => {
                // Cannot fail: `params_wire_size` already rejected values
                // longer than `i32::MAX` bytes.
                let data_len = checked_i32_len(data.len())?;
                buf.extend_from_slice(&data_len.to_be_bytes());
                buf.extend_from_slice(data);
            }
        }
    }
    write_result_formats(buf, result_format);

    // Describe the unnamed portal: 'D', length 6, 'P', portal name "\0".
    buf.extend_from_slice(b"D");
    buf.extend_from_slice(&6i32.to_be_bytes());
    buf.extend_from_slice(b"P");
    buf.extend_from_slice(&[0]);

    // Execute the unnamed portal with a max-rows value of 0.
    buf.extend_from_slice(b"E");
    buf.extend_from_slice(&9i32.to_be_bytes());
    buf.extend_from_slice(&[0]);
    buf.extend_from_slice(&0i32.to_be_bytes());

    // Sync.
    buf.extend_from_slice(&[b'S', 0, 0, 0, 4]);
    Ok(())
}
/// Encodes `cmds` as a pipelined extended-protocol batch using result
/// format code `0`.
///
/// # Errors
/// Same as [`encode_batch_with_result_format`].
pub fn encode_batch(cmds: &[Qail]) -> Result<BytesMut, EncodeError> {
    let default_result_format: i16 = 0;
    encode_batch_with_result_format(cmds, default_result_format)
}
/// Encodes `cmds` as a pipeline of extended-protocol messages terminated by
/// a single Sync.
///
/// Each command is rendered to SQL text plus parameter values by the
/// per-action encoders, then emitted as:
/// - Parse: unnamed statement.
/// - Bind: unnamed portal, zero parameter format codes, `result_format`
///   written via [`write_result_formats`].
/// - Execute: max rows 0.
///
/// Unlike the single-query builder, no Describe message is sent for batch
/// commands. `Action::Cnt` is encoded as a `Get` whose only column is a
/// `Count` aggregate over `*`.
///
/// # Errors
/// - Propagates errors from the per-action encoders.
/// - [`EncodeError::NullByte`] if a command's SQL contains a NUL byte.
/// - [`EncodeError::TooManyParameters`] if a command has more than
///   `i16::MAX` parameters.
/// - [`EncodeError::MessageTooLarge`] if a length does not fit the
///   protocol's `i32` fields.
pub fn encode_batch_with_result_format(
    cmds: &[Qail],
    result_format: i16,
) -> Result<BytesMut, EncodeError> {
    let mut total_buf = BytesMut::with_capacity(cmds.len() * 256);
    let result_formats_size = result_format_wire_len(result_format);

    // Scratch buffers reused for every command. They are cleared at the top
    // of each iteration instead of being reallocated (and frozen) per command.
    let mut sql_buf = BytesMut::with_capacity(256);
    let mut params: Vec<Option<Vec<u8>>> = Vec::new();

    for cmd in cmds {
        sql_buf.clear();
        params.clear();
        match cmd.action {
            Action::Get => encode_select(cmd, &mut sql_buf, &mut params),
            Action::Add => encode_insert(cmd, &mut sql_buf, &mut params),
            Action::Set => encode_update(cmd, &mut sql_buf, &mut params),
            Action::Del => encode_delete(cmd, &mut sql_buf, &mut params),
            Action::Cnt => {
                let mut count_cmd = cmd.clone();
                count_cmd.action = Action::Get;
                count_cmd.columns = vec![Expr::Aggregate {
                    col: "*".to_string(),
                    func: AggregateFunc::Count,
                    distinct: false,
                    filter: None,
                    alias: None,
                }];
                encode_select(&count_cmd, &mut sql_buf, &mut params)
            }
            Action::Export => encode_export(cmd, &mut sql_buf, &mut params),
            _ => {
                super::AstEncoder::encode_cmd_sql_to(cmd, &mut sql_buf, &mut params)?;
                Ok(())
            }
        }?;

        let sql_bytes: &[u8] = &sql_buf;
        validate_sql_bytes(sql_bytes)?;
        let params_size = params_wire_size(&params)?;
        let param_count = checked_i16_count(params.len())?;

        // Parse payload: statement name "\0", sql, "\0", i16 parameter-type count.
        let parse_content_len = 1usize
            .checked_add(sql_bytes.len())
            .and_then(|v| v.checked_add(1))
            .and_then(|v| v.checked_add(2))
            .ok_or(EncodeError::MessageTooLarge(usize::MAX))?;
        let parse_len = checked_wire_len(parse_content_len)?;
        total_buf.extend_from_slice(b"P");
        total_buf.extend_from_slice(&parse_len.to_be_bytes());
        total_buf.extend_from_slice(&[0]);
        total_buf.extend_from_slice(sql_bytes);
        total_buf.extend_from_slice(&[0]);
        total_buf.extend_from_slice(&0i16.to_be_bytes());

        // Bind payload: portal "\0", statement "\0", i16 parameter-format
        // count (0), i16 parameter count, values, result formats.
        let bind_content_len = 1usize
            .checked_add(1)
            .and_then(|v| v.checked_add(2))
            .and_then(|v| v.checked_add(2))
            .and_then(|v| v.checked_add(params_size))
            .and_then(|v| v.checked_add(result_formats_size))
            .ok_or(EncodeError::MessageTooLarge(usize::MAX))?;
        let bind_len = checked_wire_len(bind_content_len)?;
        total_buf.extend_from_slice(b"B");
        total_buf.extend_from_slice(&bind_len.to_be_bytes());
        total_buf.extend_from_slice(&[0]);
        total_buf.extend_from_slice(&[0]);
        total_buf.extend_from_slice(&0i16.to_be_bytes());
        total_buf.extend_from_slice(&param_count.to_be_bytes());
        for param in &params {
            match param {
                // Length -1 with no value bytes encodes SQL NULL.
                None => total_buf.extend_from_slice(&(-1i32).to_be_bytes()),
                Some(data) => {
                    let data_len = checked_i32_len(data.len())?;
                    total_buf.extend_from_slice(&data_len.to_be_bytes());
                    total_buf.extend_from_slice(data);
                }
            }
        }
        write_result_formats(&mut total_buf, result_format);

        // Execute the unnamed portal with a max-rows value of 0.
        total_buf.extend_from_slice(b"E");
        total_buf.extend_from_slice(&9i32.to_be_bytes());
        total_buf.extend_from_slice(&[0]);
        total_buf.extend_from_slice(&0i32.to_be_bytes());
    }

    // One Sync closes the whole pipeline.
    total_buf.extend_from_slice(&[b'S', 0, 0, 0, 4]);
    Ok(total_buf)
}
/// Encodes `cmds` as one simple-protocol Query (`'Q'`) message.
///
/// The query text is each command's SQL followed by `;`, concatenated and
/// NUL-terminated. `Action::Cnt` is encoded as a `Get` whose only column is
/// a `Count` aggregate over `*`.
///
/// NOTE(review): the parameter values collected by the per-action encoders
/// are discarded here. If those encoders emit placeholders that refer to
/// them, the resulting SQL will not run as a simple query. Confirm that
/// callers only pass commands that produce no parameters.
///
/// # Errors
/// - Propagates errors from the per-action encoders.
/// - [`EncodeError::NullByte`] if the generated SQL contains a NUL byte.
/// - [`EncodeError::MessageTooLarge`] if the message length does not fit in
///   an `i32`.
pub fn encode_batch_simple(cmds: &[Qail]) -> Result<BytesMut, EncodeError> {
    // Type byte plus a 4-byte length placeholder that is patched at the end.
    const HEADER_LEN: usize = 5;

    let mut msg = BytesMut::with_capacity(HEADER_LEN + cmds.len() * 48 + 1);
    msg.extend_from_slice(&[b'Q', 0, 0, 0, 0]);

    // Scratch parameter list shared by all commands; its contents are never used.
    let mut scratch_params: Vec<Option<Vec<u8>>> = Vec::new();
    for cmd in cmds {
        scratch_params.clear();
        match cmd.action {
            Action::Get => encode_select(cmd, &mut msg, &mut scratch_params),
            Action::Add => encode_insert(cmd, &mut msg, &mut scratch_params),
            Action::Set => encode_update(cmd, &mut msg, &mut scratch_params),
            Action::Del => encode_delete(cmd, &mut msg, &mut scratch_params),
            Action::Cnt => {
                let mut as_select = cmd.clone();
                as_select.action = Action::Get;
                as_select.columns = vec![Expr::Aggregate {
                    col: "*".to_string(),
                    func: AggregateFunc::Count,
                    distinct: false,
                    filter: None,
                    alias: None,
                }];
                encode_select(&as_select, &mut msg, &mut scratch_params)
            }
            Action::Export => encode_export(cmd, &mut msg, &mut scratch_params),
            _ => {
                super::AstEncoder::encode_cmd_sql_to(cmd, &mut msg, &mut scratch_params)?;
                Ok(())
            }
        }?;
        msg.extend_from_slice(b";");
    }

    // Scan only the SQL text; the zeroed length placeholder is excluded.
    if msg[HEADER_LEN..].iter().any(|&byte| byte == 0) {
        return Err(EncodeError::NullByte);
    }
    msg.extend_from_slice(&[0]);

    // The length field counts itself and the payload, but not the type byte.
    let length_field = msg.len() - 1;
    let encoded_len =
        i32::try_from(length_field).map_err(|_| EncodeError::MessageTooLarge(length_field))?;
    msg[1..HEADER_LEN].copy_from_slice(&encoded_len.to_be_bytes());
    Ok(msg)
}