use std::{borrow::Cow, sync::Arc};
use base64::{engine::general_purpose::STANDARD as STD_BASE64_ENGINE, Engine};
use rsa::{Pkcs1v15Encrypt, RsaPublicKey};
use serde::{Serialize, Serializer};
use serde_json::value::RawValue;
use sqlx_core::sql_str::SqlStr;
use crate::{
arguments::ExaBuffer, options::ProtocolVersion, responses::ExaRwAttributes, ExaAttributes,
ExaTypeInfo, SqlxError, SqlxResult,
};
/// Pairs a request payload with the session attributes it may have to carry.
///
/// The `Serialize` impls on the concrete `WithAttributes<'_, REQ>` types decide, per
/// request kind, whether and how the attributes are written alongside the request.
pub struct WithAttributes<'attr, REQ> {
    /// Result of `ExaAttributes::needs_send`, captured when the wrapper is built.
    needs_send: bool,
    /// Read-write attributes, as returned by `ExaAttributes::read_write`.
    attributes: &'attr ExaRwAttributes<'static>,
    /// The wrapped request.
    request: &'attr mut REQ,
}

impl<'attr, REQ> WithAttributes<'attr, REQ> {
    /// Wraps `request` together with a view over `attributes`.
    pub fn new(request: &'attr mut REQ, attributes: &'attr ExaAttributes) -> Self {
        let needs_send = attributes.needs_send();
        let attributes = attributes.read_write();

        Self {
            needs_send,
            attributes,
            request,
        }
    }
}
/// Payload of `Command::Login`: the protocol version sent with it.
#[derive(Clone, Copy, Debug)]
pub struct LoginCreds(pub ProtocolVersion);

impl Serialize for WithAttributes<'_, LoginCreds> {
    /// Serializes as `Command::Login`; the attributes are not part of the output.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let LoginCreds(protocol_version) = *self.request;
        Command::Login { protocol_version }.serialize(serializer)
    }
}
/// Payload of `Command::LoginToken`: the protocol version sent with it.
#[derive(Clone, Copy, Debug)]
pub struct LoginToken(pub ProtocolVersion);

impl Serialize for WithAttributes<'_, LoginToken> {
    /// Serializes as `Command::LoginToken`; the attributes are not part of the output.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let LoginToken(protocol_version) = *self.request;
        Command::LoginToken { protocol_version }.serialize(serializer)
    }
}
/// Marker request for `Command::Disconnect`; it carries no data.
#[derive(Clone, Copy, Debug, Default)]
pub struct Disconnect;

impl Serialize for WithAttributes<'_, Disconnect> {
    /// Serializes as the field-less `Command::Disconnect`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let command = Command::Disconnect;
        command.serialize(serializer)
    }
}
/// Marker request for `Command::GetAttributes`; it carries no data.
#[derive(Clone, Copy, Debug, Default)]
pub struct GetAttributes;

impl Serialize for WithAttributes<'_, GetAttributes> {
    /// Serializes as the field-less `Command::GetAttributes`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let command = Command::GetAttributes;
        command.serialize(serializer)
    }
}
/// Marker request for `Command::SetAttributes`; the attributes come from the wrapper.
#[derive(Clone, Debug, Default)]
pub struct SetAttributes;

impl Serialize for WithAttributes<'_, SetAttributes> {
    /// Serializes as `Command::SetAttributes`, always including the attributes
    /// (the `needs_send` flag is not consulted here).
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let attributes = self.attributes;
        Command::SetAttributes { attributes }.serialize(serializer)
    }
}
/// Payload of `Command::CloseResultSet`: the handles of the result sets to close.
#[derive(Clone, Debug)]
pub struct CloseResultSets(pub Vec<u16>);

impl Serialize for WithAttributes<'_, CloseResultSets> {
    /// Serializes as `Command::CloseResultSet`, attaching the attributes only when
    /// `needs_send` is set.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let attributes = if self.needs_send {
            Some(self.attributes)
        } else {
            None
        };

        let command = Command::CloseResultSet {
            attributes,
            result_set_handles: self.request.0.as_slice(),
        };
        command.serialize(serializer)
    }
}
/// Payload of `Command::ClosePreparedStatement`: the handle of the statement to close.
#[derive(Copy, Clone, Debug)]
pub struct ClosePreparedStmt(pub u16);

impl Serialize for WithAttributes<'_, ClosePreparedStmt> {
    /// Serializes as `Command::ClosePreparedStatement`, attaching the attributes only
    /// when `needs_send` is set.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let ClosePreparedStmt(statement_handle) = *self.request;
        let attributes = if self.needs_send {
            Some(self.attributes)
        } else {
            None
        };

        Command::ClosePreparedStatement {
            attributes,
            statement_handle,
        }
        .serialize(serializer)
    }
}
/// Payload of `Command::GetHosts`: the IP address sent as `host_ip`.
#[cfg(feature = "etl")]
#[derive(Clone, Copy, Debug)]
pub struct GetHosts(pub std::net::IpAddr);

#[cfg(feature = "etl")]
impl Serialize for WithAttributes<'_, GetHosts> {
    /// Serializes as `Command::GetHosts`, attaching the attributes only when
    /// `needs_send` is set.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let GetHosts(host_ip) = *self.request;
        let attributes = if self.needs_send {
            Some(self.attributes)
        } else {
            None
        };

        let command = Command::GetHosts {
            attributes,
            host_ip,
        };
        command.serialize(serializer)
    }
}
/// Payload of `Command::Fetch`.
#[derive(Clone, Copy, Debug)]
pub struct Fetch {
    /// Handle of the result set to read from.
    result_set_handle: u16,
    /// Sent as `start_position` of the command.
    start_position: usize,
    /// Sent as `num_bytes` of the command.
    num_bytes: usize,
}

impl Fetch {
    /// Builds a fetch request from its three components.
    pub fn new(result_set_handle: u16, start_position: usize, num_bytes: usize) -> Self {
        Fetch {
            result_set_handle,
            start_position,
            num_bytes,
        }
    }
}

impl Serialize for WithAttributes<'_, Fetch> {
    /// Serializes as `Command::Fetch`, attaching the attributes only when
    /// `needs_send` is set.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let Fetch {
            result_set_handle,
            start_position,
            num_bytes,
        } = *self.request;

        let attributes = if self.needs_send {
            Some(self.attributes)
        } else {
            None
        };

        Command::Fetch {
            attributes,
            result_set_handle,
            start_position,
            num_bytes,
        }
        .serialize(serializer)
    }
}
/// Payload of `Command::Execute`: the SQL text to send.
#[derive(Clone, Debug)]
pub struct Execute(pub SqlStr);

impl Serialize for WithAttributes<'_, Execute> {
    /// Serializes as `Command::Execute`, attaching the attributes only when
    /// `needs_send` is set.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let sql_text = self.request.0.as_str();
        let attributes = if self.needs_send {
            Some(self.attributes)
        } else {
            None
        };

        Command::Execute {
            attributes,
            sql_text,
        }
        .serialize(serializer)
    }
}
/// Payload of `Command::ExecuteBatch`: SQL text that may contain several statements.
#[derive(Clone, Debug)]
pub struct ExecuteBatch(pub SqlStr);

impl ExecuteBatch {
    /// Splits the wrapped SQL text into individual statements.
    ///
    /// The text is trimmed first. A statement ends at a `;` that is found outside of
    /// single-quoted text, double-quoted text, `--` line comments and `/* */` block
    /// comments; the `;` stays part of the statement. Whitespace directly following a
    /// terminating `;` is dropped. Doubled quotes (`''` / `""`) inside quoted text do not
    /// close it. Comments are not removed: they remain part of the slice they occur in.
    /// Any non-empty text after the last terminating `;` is returned as a final entry.
    ///
    /// All returned slices borrow from the wrapped SQL text.
    fn split_query(&self) -> Vec<&str> {
        /// Lexical context the scanner is currently in.
        #[derive(Clone, Copy)]
        enum Inside {
            Statement,
            LineComment,
            BlockComment,
            DoubleQuote,
            SingleQuote,
            Whitespace,
        }

        let query = self.0.as_str().trim();
        let mut chars = query.char_indices().peekable();
        let mut state = Inside::Statement;
        let mut statements = Vec::new();
        let mut start = 0;

        while let Some((i, c)) = chars.next() {
            // Byte offset just past `c`. Using the encoded width keeps the offset on a char
            // boundary even for multi-byte whitespace (e.g. U+00A0 or U+3000); a fixed
            // `i + 1` would land inside such a char and make the slicing below panic.
            let next = i + c.len_utf8();

            let mut peek = || chars.peek().map(|(_, c)| *c);
            let is_whitespace = |p: Option<char>| p.is_some_and(char::is_whitespace);

            #[allow(clippy::match_same_arms, reason = "better readability if split")]
            match (state, c) {
                // Comment openers are two chars wide: consume the second one as well.
                (Inside::Statement, '-') if Some('-') == peek() => {
                    chars.next();
                    state = Inside::LineComment;
                }
                (Inside::Statement, '/') if Some('*') == peek() => {
                    chars.next();
                    state = Inside::BlockComment;
                }
                (Inside::Statement, '"') => state = Inside::DoubleQuote,
                (Inside::Statement, '\'') => state = Inside::SingleQuote,
                // Statement terminator: emit everything up to and including the `;`.
                (Inside::Statement, ';') => {
                    statements.push(&query[start..next]);
                    start = next;

                    if is_whitespace(peek()) {
                        state = Inside::Whitespace;
                    }
                }
                // A doubled quote is an escaped quote: skip its second half and stay inside.
                (Inside::DoubleQuote, '"') if Some('"') == peek() => {
                    chars.next();
                }
                (Inside::SingleQuote, '\'') if Some('\'') == peek() => {
                    chars.next();
                }
                (Inside::DoubleQuote, '"') => state = Inside::Statement,
                (Inside::SingleQuote, '\'') => state = Inside::Statement,
                (Inside::LineComment, '\n') => state = Inside::Statement,
                (Inside::BlockComment, '*') if Some('/') == peek() => {
                    chars.next();
                    state = Inside::Statement;
                }
                // Last whitespace char of the run after a `;`: the next statement starts
                // right behind it.
                (Inside::Whitespace, _) if !is_whitespace(peek()) => {
                    start = next;
                    state = Inside::Statement;
                }
                _ => (),
            }
        }

        // Whatever follows the last terminating `;` (e.g. an unterminated statement or a
        // trailing comment) is returned as well.
        let remaining = &query[start..];
        if !remaining.is_empty() {
            statements.push(remaining);
        }

        statements
    }
}
impl Serialize for WithAttributes<'_, ExecuteBatch> {
    /// Serializes as `Command::ExecuteBatch`, sending the output of
    /// `ExecuteBatch::split_query` as the statement list and attaching the attributes
    /// only when `needs_send` is set.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let attributes = if self.needs_send {
            Some(self.attributes)
        } else {
            None
        };
        let statements = self.request.split_query();

        Command::ExecuteBatch {
            attributes,
            sql_texts: &statements,
        }
        .serialize(serializer)
    }
}
/// Payload of `Command::CreatePreparedStatement`: the SQL text to prepare.
#[derive(Clone, Debug)]
pub struct CreatePreparedStmt(pub SqlStr);

impl Serialize for WithAttributes<'_, CreatePreparedStmt> {
    /// Serializes as `Command::CreatePreparedStatement`, attaching the attributes only
    /// when `needs_send` is set.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let sql_text = self.request.0.as_str();
        let attributes = if self.needs_send {
            Some(self.attributes)
        } else {
            None
        };

        Command::CreatePreparedStatement {
            attributes,
            sql_text,
        }
        .serialize(serializer)
    }
}
#[derive(Clone, Debug)]
pub struct ExecutePreparedStmt {
statement_handle: u16,
num_columns: usize,
num_rows: usize,
columns: Arc<[ExaTypeInfo]>,
data: PreparedStmtData,
}
impl ExecutePreparedStmt {
pub fn new(handle: u16, columns: Arc<[ExaTypeInfo]>, data: ExaBuffer) -> Self {
Self {
statement_handle: handle,
num_columns: columns.len(),
num_rows: data.num_param_sets(),
columns,
data: data.into(),
}
}
}
impl Serialize for WithAttributes<'_, ExecutePreparedStmt> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let command = Command::ExecutePreparedStatement {
attributes: self.needs_send.then_some(self.attributes),
statement_handle: self.request.statement_handle,
num_columns: self.request.num_columns,
num_rows: self.request.num_rows,
columns: &self.request.columns,
data: &self.request.data,
};
command.serialize(serializer)
}
}
/// Login request data.
///
/// Serialized fields use camelCase names. The first three fields are marked
/// `skip_serializing` and therefore never appear in the serialized output.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ExaLoginRequest<'a> {
/// Not serialized as part of this struct.
#[serde(skip_serializing)]
pub protocol_version: ProtocolVersion,
/// Not serialized as part of this struct.
#[serde(skip_serializing)]
pub fetch_size: usize,
/// Not serialized as part of this struct.
#[serde(skip_serializing)]
pub statement_cache_capacity: usize,
/// Login data; flattened, so the fields of the active [`LoginRef`] variant are written
/// directly into this struct's serialized object.
#[serde(flatten)]
pub login: LoginRef<'a>,
/// Serialized as `useCompression`.
pub use_compression: bool,
/// Serialized as `clientName`.
pub client_name: &'static str,
/// Serialized as `clientVersion`.
pub client_version: &'static str,
/// Serialized as `clientOs`.
pub client_os: &'static str,
/// Serialized as `clientRuntime`.
pub client_runtime: &'static str,
/// Serialized as `attributes`.
pub attributes: ExaRwAttributes<'a>,
}
impl ExaLoginRequest<'_> {
    /// Replaces a credentials password with its encrypted, base64-encoded form.
    ///
    /// The password bytes are encrypted with `key` using PKCS#1 v1.5 padding and the
    /// ciphertext is encoded with the standard base64 engine. Requests whose login is not
    /// [`LoginRef::Credentials`] are left untouched.
    ///
    /// # Errors
    ///
    /// Returns [`SqlxError::Protocol`] carrying the RSA error message if encryption
    /// fails; the password is not modified in that case.
    pub fn encrypt_password(&mut self, key: &RsaPublicKey) -> SqlxResult<()> {
        if let LoginRef::Credentials { password, .. } = &mut self.login {
            let ciphertext = key
                .encrypt(
                    &mut rand::thread_rng(),
                    Pkcs1v15Encrypt,
                    password.as_bytes(),
                )
                .map_err(|e| SqlxError::Protocol(e.to_string()))?;

            *password = Cow::Owned(STD_BASE64_ENGINE.encode(ciphertext));
        }

        Ok(())
    }
}
impl Serialize for WithAttributes<'_, ExaLoginRequest<'_>> {
    /// Serializes the wrapped login request as is; the wrapper's own attributes view
    /// and `needs_send` flag are not consulted.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let login_request: &ExaLoginRequest<'_> = &*self.request;
        login_request.serialize(serializer)
    }
}
/// Login data in one of three forms.
///
/// Marked `untagged`: only the fields of the active variant are serialized (with
/// camelCase names), without any variant name.
#[derive(Clone, Debug, Serialize)]
#[serde(untagged)]
pub enum LoginRef<'a> {
/// Username and password. The password is a `Cow` so that
/// `ExaLoginRequest::encrypt_password` can swap in an owned, encrypted value.
#[serde(rename_all = "camelCase")]
Credentials {
username: &'a str,
password: Cow<'a, str>,
},
/// Access token, serialized as `accessToken`.
#[serde(rename_all = "camelCase")]
AccessToken { access_token: &'a str },
/// Refresh token, serialized as `refreshToken`.
#[serde(rename_all = "camelCase")]
RefreshToken { refresh_token: &'a str },
}
/// Serialized form of every request built in this file.
///
/// The enum is internally tagged: each value is serialized with a `command` field
/// holding the camelCase variant name, next to the variant's own camelCase fields.
/// Wherever `attributes` is an `Option`, the field is left out when it is `None`.
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
#[serde(tag = "command")]
enum Command<'a> {
/// Built from `LoginCreds`.
#[serde(rename_all = "camelCase")]
Login {
protocol_version: ProtocolVersion,
},
/// Built from `LoginToken`.
#[serde(rename_all = "camelCase")]
LoginToken {
protocol_version: ProtocolVersion,
},
/// Built from `Disconnect`; has no fields besides the tag.
Disconnect,
/// Built from `GetAttributes`; has no fields besides the tag.
GetAttributes,
/// Built from `SetAttributes`; the attributes are mandatory here.
#[serde(rename_all = "camelCase")]
SetAttributes {
attributes: &'a ExaRwAttributes<'static>,
},
/// Built from `CloseResultSets`.
#[serde(rename_all = "camelCase")]
CloseResultSet {
#[serde(skip_serializing_if = "Option::is_none")]
attributes: Option<&'a ExaRwAttributes<'static>>,
result_set_handles: &'a [u16],
},
/// Built from `ClosePreparedStmt`.
#[serde(rename_all = "camelCase")]
ClosePreparedStatement {
#[serde(skip_serializing_if = "Option::is_none")]
attributes: Option<&'a ExaRwAttributes<'static>>,
statement_handle: u16,
},
/// Built from `GetHosts`; only available with the `etl` feature.
#[cfg(feature = "etl")]
#[serde(rename_all = "camelCase")]
GetHosts {
#[serde(skip_serializing_if = "Option::is_none")]
attributes: Option<&'a ExaRwAttributes<'static>>,
host_ip: std::net::IpAddr,
},
/// Built from `Fetch`.
#[serde(rename_all = "camelCase")]
Fetch {
#[serde(skip_serializing_if = "Option::is_none")]
attributes: Option<&'a ExaRwAttributes<'static>>,
result_set_handle: u16,
start_position: usize,
num_bytes: usize,
},
/// Built from `Execute`.
#[serde(rename_all = "camelCase")]
Execute {
#[serde(skip_serializing_if = "Option::is_none")]
attributes: Option<&'a ExaRwAttributes<'static>>,
sql_text: &'a str,
},
/// Built from `ExecuteBatch`; `sql_texts` holds the statements produced by
/// `ExecuteBatch::split_query`.
#[serde(rename_all = "camelCase")]
ExecuteBatch {
#[serde(skip_serializing_if = "Option::is_none")]
attributes: Option<&'a ExaRwAttributes<'static>>,
sql_texts: &'a [&'a str],
},
/// Built from `CreatePreparedStmt`.
#[serde(rename_all = "camelCase")]
CreatePreparedStatement {
#[serde(skip_serializing_if = "Option::is_none")]
attributes: Option<&'a ExaRwAttributes<'static>>,
sql_text: &'a str,
},
/// Built from `ExecutePreparedStmt`.
#[serde(rename_all = "camelCase")]
ExecutePreparedStatement {
#[serde(skip_serializing_if = "Option::is_none")]
attributes: Option<&'a ExaRwAttributes<'static>>,
statement_handle: u16,
num_columns: usize,
num_rows: usize,
// Left out when there are no columns; otherwise each entry is written through
// `serialize_params`.
#[serde(skip_serializing_if = "<[ExaTypeInfo]>::is_empty")]
#[serde(serialize_with = "serialize_params")]
columns: &'a [ExaTypeInfo],
// Left out when `PreparedStmtData::is_empty` reports no parameter sets.
#[serde(skip_serializing_if = "PreparedStmtData::is_empty")]
data: &'a PreparedStmtData,
},
}
/// Wrapper that serializes a column type as an object with a single `dataType` field.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct PreparedStmtParam<'a> {
    data_type: &'a ExaTypeInfo,
}

impl<'a> From<&'a ExaTypeInfo> for PreparedStmtParam<'a> {
    /// Wraps a borrowed type info without copying it.
    fn from(data_type: &'a ExaTypeInfo) -> Self {
        PreparedStmtParam { data_type }
    }
}

/// Serializes `params` as a sequence in which every entry is wrapped in a
/// [`PreparedStmtParam`].
fn serialize_params<S>(params: &[ExaTypeInfo], serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let wrapped = params
        .iter()
        .map(|data_type| PreparedStmtParam { data_type });

    serializer.collect_seq(wrapped)
}
/// Parameter data of a prepared statement execution, kept as an already built string.
#[derive(Debug, Clone)]
struct PreparedStmtData {
    /// Output of `ExaBuffer::finish`; serialized through `RawValue`.
    buffer: String,
    /// Number of parameter sets, as reported by `ExaBuffer::num_param_sets`.
    num_rows: usize,
}

impl PreparedStmtData {
    /// Returns `true` when there are no parameter sets.
    fn is_empty(&self) -> bool {
        self.num_rows == 0
    }
}

impl Serialize for PreparedStmtData {
    /// Serializes the buffer by viewing it as a `RawValue`, without copying it.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // SAFETY: NOTE(review) - this reinterprets `&str` as `&RawValue`, which relies on
        // `RawValue` having the same layout as `str` (presumably a transparent wrapper in
        // serde_json; confirm against the pinned version). Nothing here validates the
        // buffer, so it is assumed that `ExaBuffer::finish` yields well-formed JSON.
        #[allow(clippy::transmute_ptr_to_ptr)]
        let raw: &RawValue =
            unsafe { std::mem::transmute::<&str, &RawValue>(self.buffer.as_str()) };

        raw.serialize(serializer)
    }
}

impl From<ExaBuffer> for PreparedStmtData {
    /// Converts a parameter buffer into its serializable form.
    fn from(value: ExaBuffer) -> Self {
        // Query the count before calling `finish` (which presumably consumes the buffer).
        let num_rows = value.num_param_sets();
        let buffer = value.finish();

        Self { buffer, num_rows }
    }
}
#[cfg(test)]
mod tests {
use sqlx_core::sql_str::SqlStr;
use super::ExecuteBatch;
#[test]
fn test_simple_statements() {
assert_eq!(
ExecuteBatch(SqlStr::from_static(
"SELECT * FROM users; SELECT * FROM orders;"
))
.split_query(),
vec!["SELECT * FROM users;", "SELECT * FROM orders;"]
);
}
#[test]
fn test_semicolon_in_single_quote() {
assert_eq!(
ExecuteBatch(SqlStr::from_static(
"SELECT ';' AS val; SELECT 'abc;def' AS val2;"
))
.split_query(),
vec!["SELECT ';' AS val;", "SELECT 'abc;def' AS val2;"]
);
}
#[test]
fn test_semicolon_in_double_quote() {
assert_eq!(
ExecuteBatch(SqlStr::from_static("SELECT \"col;name\" FROM table;")).split_query(),
vec!["SELECT \"col;name\" FROM table;"]
);
}
#[test]
fn test_semicolon_in_line_comment() {
assert_eq!(
ExecuteBatch(SqlStr::from_static(
"SELECT 1; -- this is a comment; with a semicolon\nSELECT 2;"
))
.split_query(),
vec![
"SELECT 1;",
"-- this is a comment; with a semicolon\nSELECT 2;"
]
);
}
#[test]
fn test_semicolon_in_block_comment() {
assert_eq!(
ExecuteBatch(SqlStr::from_static(
"SELECT 1; /* multi-line ; comment */ SELECT 2;"
))
.split_query(),
vec!["SELECT 1;", "/* multi-line ; comment */ SELECT 2;"]
);
}
#[test]
fn test_escaped_quotes() {
assert_eq!(
ExecuteBatch(SqlStr::from_static(
"SELECT 'It''s a test; really'; SELECT \"escaped\"\"quote\" FROM dual;"
))
.split_query(),
vec![
"SELECT 'It''s a test; really';",
"SELECT \"escaped\"\"quote\" FROM dual;"
]
);
}
#[test]
fn test_trailing_semicolon_and_whitespace() {
assert_eq!(
ExecuteBatch(SqlStr::from_static("SELECT 1;; \n \n;")).split_query(),
vec!["SELECT 1;", ";", ";"]
);
}
#[test]
fn test_leading_semicolon() {
assert_eq!(
ExecuteBatch(SqlStr::from_static(";SELECT 1;")).split_query(),
vec![";", "SELECT 1;"]
);
}
#[test]
fn test_leading_semicolon_and_whitespace() {
assert_eq!(
ExecuteBatch(SqlStr::from_static(" ; SELECT 1;")).split_query(),
vec![";", "SELECT 1;"]
);
}
#[test]
fn test_no_semicolon() {
assert_eq!(
ExecuteBatch(SqlStr::from_static("SELECT 1")).split_query(),
vec!["SELECT 1"]
);
}
#[test]
fn test_no_whitespace_between_statements() {
assert_eq!(
ExecuteBatch(SqlStr::from_static("SELECT 1;SELECT 2")).split_query(),
vec!["SELECT 1;", "SELECT 2"]
);
}
#[test]
fn test_no_whitespace_between_stmt_and_comment() {
assert_eq!(
ExecuteBatch(SqlStr::from_static("SELECT 1;/*testing*/SELECT 2;")).split_query(),
vec!["SELECT 1;", "/*testing*/SELECT 2;"]
);
}
#[test]
fn test_trailing_comment() {
assert_eq!(
ExecuteBatch(SqlStr::from_static("SELECT 1;/*testing*/")).split_query(),
vec!["SELECT 1;", "/*testing*/"]
);
}
#[test]
fn test_whitespace_between_statements() {
let query = "
/* Writing some comments */
SELECT 1;
-- Then writing some more comments
SELECT 2;
";
assert_eq!(
ExecuteBatch(SqlStr::from_static(query)).split_query(),
vec![
"/* Writing some comments */
SELECT 1;",
"-- Then writing some more comments
SELECT 2;"
]
);
}
#[test]
fn test_empty_input() {
assert_eq!(
ExecuteBatch(SqlStr::from_static("")).split_query(),
Vec::<&str>::new()
);
}
#[test]
fn test_mixed_content() {
let query = r#"
SELECT 'test;--'; -- line comment with ;
/* block comment ;
over lines */
SELECT "str;with;semicolons";
"#;
assert_eq!(
ExecuteBatch(SqlStr::from_static(query)).split_query(),
vec![
"SELECT 'test;--';",
r#"-- line comment with ;
/* block comment ;
over lines */
SELECT "str;with;semicolons";"#
]
);
}
}