use bytes::{BufMut, Bytes, BytesMut};
use crate::codec::write_utf16_string;
use crate::prelude::*;
/// Byte length of the transaction-descriptor header.
///
/// It consists of its u32 length field, u16 type tag, u64 descriptor and
/// trailing u32 (per MS-TDS, the OutstandingRequestCount).
const TRANSACTION_DESCRIPTOR_HEADER_LEN: u32 = 4 + 2 + 8 + 4;

/// Header type tag written for the transaction-descriptor header.
const TRANSACTION_DESCRIPTOR_HEADER_TYPE: u16 = 0x0002;

/// Total byte length of the ALL_HEADERS section.
///
/// This is its own 4-byte total-length field plus the single
/// transaction-descriptor header.
const ALL_HEADERS_LEN: u32 = 4 + TRANSACTION_DESCRIPTOR_HEADER_LEN;

/// Encodes `sql` as a SQL batch payload using transaction descriptor `0`.
///
/// Shorthand for [`encode_sql_batch_with_transaction`] with a zero descriptor.
#[must_use]
pub fn encode_sql_batch(sql: &str) -> Bytes {
    encode_sql_batch_with_transaction(sql, 0)
}

/// Encodes `sql` as a SQL batch payload tagged with `transaction_descriptor`.
///
/// Layout, all integers little-endian:
/// - u32 total ALL_HEADERS length (always 22)
/// - u32 header length (18)
/// - u16 header type (`0x0002`)
/// - u64 `transaction_descriptor`
/// - u32 `1`
/// - the SQL text as written by [`write_utf16_string`]
#[must_use]
pub fn encode_sql_batch_with_transaction(sql: &str, transaction_descriptor: u64) -> Bytes {
    // Fixed header plus room for the text. Doubling the UTF-8 byte length is
    // an upper bound for plain UTF-16 output; BytesMut grows if it is exceeded.
    let mut out = BytesMut::with_capacity(22 + sql.len() * 2);

    // ALL_HEADERS: the total length is a compile-time constant, so it is
    // written up front instead of being patched in afterwards.
    out.put_u32_le(ALL_HEADERS_LEN);

    // The single transaction-descriptor header.
    out.put_u32_le(TRANSACTION_DESCRIPTOR_HEADER_LEN);
    out.put_u16_le(TRANSACTION_DESCRIPTOR_HEADER_TYPE);
    out.put_u64_le(transaction_descriptor);
    out.put_u32_le(1);

    // Batch body.
    write_utf16_string(&mut out, sql);
    out.freeze()
}
#[derive(Debug, Clone)]
pub struct SqlBatch {
sql: String,
}
impl SqlBatch {
#[must_use]
pub fn new(sql: impl Into<String>) -> Self {
Self { sql: sql.into() }
}
#[must_use]
pub fn sql(&self) -> &str {
&self.sql
}
#[must_use]
pub fn encode(&self) -> Bytes {
encode_sql_batch(&self.sql)
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_sql_batch() {
        let sql = "SELECT 1";
        let payload = encode_sql_batch(sql);
        // 22-byte ALL_HEADERS + 8 UTF-16 code units.
        assert_eq!(payload.len(), 38);
        assert_eq!(&payload[0..4], &[22, 0, 0, 0]);
        assert_eq!(&payload[4..8], &[18, 0, 0, 0]);
        assert_eq!(&payload[8..10], &[0x02, 0x00]);
        // Default transaction descriptor is zero.
        assert_eq!(&payload[10..18], &[0u8; 8]);
        // Trailing u32 of the header is always 1.
        assert_eq!(&payload[18..22], &[1, 0, 0, 0]);
        assert_eq!(payload[22], b'S');
        assert_eq!(payload[23], 0);
        assert_eq!(payload[24], b'E');
        assert_eq!(payload[25], 0);
    }

    #[test]
    fn test_encode_sql_batch_with_transaction_descriptor() {
        let payload = encode_sql_batch_with_transaction("SELECT 1", 0x0102_0304_0506_0708);
        assert_eq!(payload.len(), 38);
        // Descriptor is written little-endian right after the header type.
        assert_eq!(&payload[10..18], &[8, 7, 6, 5, 4, 3, 2, 1]);
        assert_eq!(&payload[18..22], &[1, 0, 0, 0]);
        // The rest of the header is unaffected by the descriptor value.
        assert_eq!(&payload[0..10], &encode_sql_batch("SELECT 1")[0..10]);
        assert_eq!(&payload[18..], &encode_sql_batch("SELECT 1")[18..]);
    }

    #[test]
    fn test_sql_batch_builder() {
        let batch = SqlBatch::new("SELECT @@VERSION");
        assert_eq!(batch.sql(), "SELECT @@VERSION");
        let payload = batch.encode();
        // 22-byte ALL_HEADERS + 16 UTF-16 code units.
        assert_eq!(payload.len(), 54);
        assert_eq!(payload, encode_sql_batch("SELECT @@VERSION"));
    }

    #[test]
    fn test_empty_batch() {
        let payload = encode_sql_batch("");
        assert_eq!(payload.len(), 22);
        assert_eq!(&payload[0..4], &[22, 0, 0, 0]);
    }
}