use serde::{Deserialize, Serialize};
use tokio_postgres::{Client, SimpleQueryMessage};
use uuid::Uuid;
use crate::postgres::PgError;
/// Describes one column of a result set.
///
/// Built in this file either from a prepared statement's column list
/// (`open_query`) or from a simple-query row description
/// (`run_multi_statement`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnMeta {
/// Column name as reported by the driver.
pub name: String,
/// Type OID of the column. `open_query` fills it from the prepared
/// statement; `run_multi_statement` always stores `0`.
pub type_oid: u32,
/// Type name of the column. `open_query` fills it from the prepared
/// statement; `run_multi_statement` always stores an empty string.
pub type_name: String,
}
/// Handle to a cursor that `open_query` declared and left open so further
/// pages can be read with `fetch_page` and released with `close_query`.
#[derive(Debug, Clone)]
pub struct ActiveCursor {
/// Name of the cursor, interpolated verbatim into the `FETCH` / `CLOSE`
/// statements built in this file.
pub cursor_id: String,
/// Number of columns per row; `fetch_page` passes it to `collect_rows` as
/// the row width.
pub column_count: usize,
}
/// Result of executing a query: column metadata plus the first page of rows.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionOutcome {
/// Metadata for each result column; empty when the statement returns no
/// columns.
pub columns: Vec<ColumnMeta>,
/// Row data in text form, one inner `Vec` per row and one entry per column.
/// A cell is `None` when the driver returns no text for it (presumably SQL
/// NULL; confirm against the driver documentation).
pub rows: Vec<Vec<Option<String>>>,
/// Count taken from a `CommandComplete` message. On the cursor path of
/// `open_query` this is the number of rows fetched for the first page.
pub rows_affected: Option<u64>,
/// Name of the cursor left open for further paging, or `None` when no
/// cursor remains open.
pub cursor_id: Option<String>,
}
/// One page of rows read from an open cursor by `fetch_page`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PageResult {
/// Row data in text form, one inner `Vec` per row.
pub rows: Vec<Vec<Option<String>>>,
/// `true` when the fetch returned exactly as many rows as were requested,
/// so another page may exist.
pub has_more: bool,
}
/// Reports whether `sql` contains more than one statement.
///
/// The input is scanned with a small lexer so that semicolons inside
/// single-quoted strings, double-quoted identifiers, dollar-quoted literals
/// (`$$ ... $$`, `$tag$ ... $tag$`), `--` line comments and `/* */` block
/// comments are not treated as statement separators.
///
/// After the first top-level `;`, only whitespace and comments may follow for
/// the input to still count as a single statement. Anything else (including a
/// second `;`) makes it multi-statement.
///
/// Known limitations: block comments are treated as non-nesting and
/// backslash escapes inside `E'...'` strings are not interpreted.
pub(crate) fn is_multi_statement(sql: &str) -> bool {
    enum LexState {
        Normal,
        SingleQuote,
        DoubleQuote,
        LineComment,
        BlockComment,
    }

    /// Returns the byte length of the dollar-quoted literal that opens at
    /// `bytes[start]` (which must be `$`), both delimiters included. An
    /// unterminated literal extends to the end of the input. Returns `None`
    /// when the `$` does not open a dollar quote (positional parameter such
    /// as `$1`, or a `$` that is part of an identifier).
    fn dollar_quoted_len(bytes: &[u8], start: usize) -> Option<usize> {
        fn is_tag_byte(b: u8) -> bool {
            b.is_ascii_alphanumeric() || b == b'_' || b >= 0x80
        }
        // A `$` directly after an identifier character belongs to that identifier.
        if start > 0 && (is_tag_byte(bytes[start - 1]) || bytes[start - 1] == b'$') {
            return None;
        }
        let tag = &bytes[start + 1..];
        let tag_len = tag.iter().take_while(|&&b| is_tag_byte(b)).count();
        let starts_with_digit = matches!(tag.first(), Some(b) if b.is_ascii_digit());
        if tag.get(tag_len) != Some(&b'$') || starts_with_digit {
            return None;
        }
        let delim_len = tag_len + 2;
        let delim = &bytes[start..start + delim_len];
        let body = &bytes[start + delim_len..];
        Some(match body.windows(delim_len).position(|w| w == delim) {
            Some(offset) => delim_len + offset + delim_len,
            None => bytes.len() - start,
        })
    }

    let bytes = sql.as_bytes();
    let mut state = LexState::Normal;
    // Set once the first top-level `;` has been seen.
    let mut terminated = false;
    let mut i = 0usize;
    while i < bytes.len() {
        let c = bytes[i];
        let next = bytes.get(i + 1).copied();
        match state {
            LexState::Normal => match c {
                // Comments are skipped both before and after the terminator.
                b'-' if next == Some(b'-') => {
                    state = LexState::LineComment;
                    i += 1;
                }
                b'/' if next == Some(b'*') => {
                    state = LexState::BlockComment;
                    i += 1;
                }
                // Real content after the first terminator: a second statement.
                _ if terminated && !c.is_ascii_whitespace() => return true,
                b'\'' => state = LexState::SingleQuote,
                b'"' => state = LexState::DoubleQuote,
                b'$' => {
                    if let Some(len) = dollar_quoted_len(bytes, i) {
                        // Land on the literal's last byte; the shared
                        // increment below steps past it.
                        i += len - 1;
                    }
                }
                b';' => terminated = true,
                _ => {}
            },
            LexState::SingleQuote => {
                if c == b'\'' {
                    if next == Some(b'\'') {
                        // Doubled quote is an escaped quote, not the end.
                        i += 1;
                    } else {
                        state = LexState::Normal;
                    }
                }
            }
            LexState::DoubleQuote => {
                if c == b'"' {
                    if next == Some(b'"') {
                        i += 1;
                    } else {
                        state = LexState::Normal;
                    }
                }
            }
            LexState::LineComment => {
                if c == b'\n' {
                    state = LexState::Normal;
                }
            }
            LexState::BlockComment => {
                if c == b'*' && next == Some(b'/') {
                    state = LexState::Normal;
                    i += 1;
                }
            }
        }
        i += 1;
    }
    false
}
/// Splits `sql` at its last top-level `;` into `(preamble, main)`.
///
/// `preamble` is everything up to and including that semicolon; `main` is the
/// trimmed text after it. Semicolons inside single-quoted strings,
/// double-quoted identifiers, dollar-quoted literals (`$$ ... $$`,
/// `$tag$ ... $tag$`) and comments are ignored, so a function or `DO` body is
/// never cut in half.
///
/// Returns `None` when:
/// - there is no top-level `;`,
/// - nothing but whitespace/comments follows the last `;`, or
/// - the text after the last `;` is itself reported as multi-statement by
///   `is_multi_statement` (a defensive check).
pub(crate) fn split_at_last_statement(sql: &str) -> Option<(&str, &str)> {
    enum LexState {
        Normal,
        SingleQuote,
        DoubleQuote,
        LineComment,
        BlockComment,
    }

    /// Returns the byte length of the dollar-quoted literal that opens at
    /// `bytes[start]` (which must be `$`), both delimiters included. An
    /// unterminated literal extends to the end of the input. Returns `None`
    /// when the `$` does not open a dollar quote (positional parameter such
    /// as `$1`, or a `$` that is part of an identifier).
    fn dollar_quoted_len(bytes: &[u8], start: usize) -> Option<usize> {
        fn is_tag_byte(b: u8) -> bool {
            b.is_ascii_alphanumeric() || b == b'_' || b >= 0x80
        }
        // A `$` directly after an identifier character belongs to that identifier.
        if start > 0 && (is_tag_byte(bytes[start - 1]) || bytes[start - 1] == b'$') {
            return None;
        }
        let tag = &bytes[start + 1..];
        let tag_len = tag.iter().take_while(|&&b| is_tag_byte(b)).count();
        let starts_with_digit = matches!(tag.first(), Some(b) if b.is_ascii_digit());
        if tag.get(tag_len) != Some(&b'$') || starts_with_digit {
            return None;
        }
        let delim_len = tag_len + 2;
        let delim = &bytes[start..start + delim_len];
        let body = &bytes[start + delim_len..];
        Some(match body.windows(delim_len).position(|w| w == delim) {
            Some(offset) => delim_len + offset + delim_len,
            None => bytes.len() - start,
        })
    }

    let bytes = sql.as_bytes();
    let mut state = LexState::Normal;
    let mut last_delim: Option<usize> = None;
    let mut i = 0usize;
    while i < bytes.len() {
        let c = bytes[i];
        let next = bytes.get(i + 1).copied();
        match state {
            LexState::Normal => match c {
                b'\'' => state = LexState::SingleQuote,
                b'"' => state = LexState::DoubleQuote,
                b'-' if next == Some(b'-') => {
                    state = LexState::LineComment;
                    i += 1;
                }
                b'/' if next == Some(b'*') => {
                    state = LexState::BlockComment;
                    i += 1;
                }
                b'$' => {
                    if let Some(len) = dollar_quoted_len(bytes, i) {
                        // Land on the literal's last byte; the shared
                        // increment below steps past it.
                        i += len - 1;
                    }
                }
                b';' => last_delim = Some(i),
                _ => {}
            },
            LexState::SingleQuote => {
                if c == b'\'' {
                    if next == Some(b'\'') {
                        // Doubled quote is an escaped quote, not the end.
                        i += 1;
                    } else {
                        state = LexState::Normal;
                    }
                }
            }
            LexState::DoubleQuote => {
                if c == b'"' {
                    if next == Some(b'"') {
                        i += 1;
                    } else {
                        state = LexState::Normal;
                    }
                }
            }
            LexState::LineComment => {
                if c == b'\n' {
                    state = LexState::Normal;
                }
            }
            LexState::BlockComment => {
                if c == b'*' && next == Some(b'/') {
                    state = LexState::Normal;
                    i += 1;
                }
            }
        }
        i += 1;
    }

    // `;` is ASCII, so slicing right after it is always on a char boundary.
    let split = last_delim?;
    let main = &sql[split + 1..];
    if is_effectively_empty(main) {
        return None;
    }
    let main = main.trim();
    let preamble = &sql[..=split];
    if is_multi_statement(main) {
        return None;
    }
    Some((preamble, main))
}
/// Returns `true` when `sql` holds nothing but blanks (space, tab, CR, LF),
/// `--` line comments and `/* */` block comments.
///
/// A comment that is never terminated swallows the rest of the input, so the
/// text still counts as empty. Any other byte makes the text non-empty.
fn is_effectively_empty(sql: &str) -> bool {
    let mut rest = sql.as_bytes();
    loop {
        rest = match rest {
            [] => return true,
            [b' ' | b'\t' | b'\n' | b'\r', tail @ ..] => tail,
            // Line comment: drop everything through the next newline.
            [b'-', b'-', tail @ ..] => match tail.iter().position(|&b| b == b'\n') {
                Some(newline) => &tail[newline + 1..],
                None => return true,
            },
            // Block comment: drop everything through the closing `*/`.
            [b'/', b'*', tail @ ..] => match tail.windows(2).position(|w| w == b"*/") {
                Some(close) => &tail[close + 2..],
                None => return true,
            },
            _ => return false,
        };
    }
}
/// Runs `sql` through the simple-query protocol and reports the most recent
/// result set seen in the response.
///
/// Each row description replaces the column list and discards rows gathered
/// so far. At most `page_size` rows are kept per result set. `rows_affected`
/// is the count from the last `CommandComplete` message. No cursor is opened,
/// and column type information is left blank (`type_oid` 0, empty
/// `type_name`).
///
/// # Errors
/// Returns `PgError::Driver` when the query fails.
async fn run_multi_statement(
    client: &Client,
    sql: &str,
    page_size: usize,
) -> Result<ExecutionOutcome, PgError> {
    let messages = client.simple_query(sql).await.map_err(PgError::Driver)?;

    let mut columns: Vec<ColumnMeta> = Vec::new();
    let mut rows: Vec<Vec<Option<String>>> = Vec::new();
    let mut rows_affected: Option<u64> = None;

    for message in messages {
        match message {
            SimpleQueryMessage::RowDescription(description) => {
                // A new result set starts: forget the previous one.
                rows.clear();
                columns = description
                    .iter()
                    .map(|column| ColumnMeta {
                        name: column.name().to_string(),
                        type_oid: 0,
                        type_name: String::new(),
                    })
                    .collect();
            }
            // Rows beyond the page limit fall through to the catch-all arm.
            SimpleQueryMessage::Row(row) if rows.len() < page_size => {
                let cells: Vec<Option<String>> = (0..columns.len())
                    .map(|idx| row.get(idx).map(str::to_string))
                    .collect();
                rows.push(cells);
            }
            SimpleQueryMessage::CommandComplete(count) => rows_affected = Some(count),
            _ => {}
        }
    }

    Ok(ExecutionOutcome {
        columns,
        rows,
        rows_affected,
        cursor_id: None,
    })
}
/// Executes `sql` and returns its first page of results, plus a cursor handle
/// when more rows may remain.
///
/// Steps, in order:
/// 1. If `previous` is given, its cursor is closed and its transaction
///    committed; when that fails a `ROLLBACK` is issued instead.
/// 2. Multi-statement input is either split into a preamble (run with
///    `batch_execute`) and a final statement that is handled by a recursive
///    call, or, when it cannot be split, run as a whole through
///    `run_multi_statement` without a cursor.
/// 3. A single statement is prepared to learn its columns. With no columns it
///    is run through the simple-query protocol and only `rows_affected` is
///    reported.
/// 4. Otherwise a `NO SCROLL` cursor is declared inside a new transaction and
///    `page_size` rows are fetched. If fewer rows come back, the cursor is
///    closed right away and no handle is returned.
///
/// # Errors
/// Returns `PgError::Driver` when the preamble, the prepare, the statement
/// itself, the cursor declaration or the first fetch fails. After a failed
/// declaration or fetch a `ROLLBACK` is attempted before returning.
pub async fn open_query(
    client: &Client,
    sql: &str,
    page_size: usize,
    previous: Option<ActiveCursor>,
) -> Result<(ExecutionOutcome, Option<ActiveCursor>), PgError> {
    // Best-effort release of whatever the previous query left open.
    if let Some(prev) = previous.as_ref() {
        let cleanup = format!("CLOSE {}; COMMIT", prev.cursor_id);
        if let Err(e) = client.batch_execute(&cleanup).await {
            tracing::debug!(
                cursor = %prev.cursor_id,
                error = %e,
                "previous cursor cleanup failed; continuing with ROLLBACK"
            );
            let _ = client.batch_execute("ROLLBACK").await;
        }
    }

    if is_multi_statement(sql) {
        if let Some((preamble, main)) = split_at_last_statement(sql) {
            client
                .batch_execute(preamble)
                .await
                .map_err(PgError::Driver)?;
            // A recursive call inside an `async fn` must be boxed.
            return Box::pin(open_query(client, main, page_size, None)).await;
        }
        let outcome = run_multi_statement(client, sql, page_size).await?;
        return Ok((outcome, None));
    }

    // Preparing does not execute the statement; it only reveals its columns.
    let stmt = client.prepare(sql).await.map_err(PgError::Driver)?;
    let columns: Vec<ColumnMeta> = stmt
        .columns()
        .iter()
        .map(|c| ColumnMeta {
            name: c.name().to_string(),
            type_oid: c.type_().oid(),
            type_name: c.type_().name().to_string(),
        })
        .collect();

    if columns.is_empty() {
        // Nothing to page through: run it directly and report the count.
        let stream = client.simple_query(sql).await.map_err(PgError::Driver)?;
        let rows_affected = stream.into_iter().find_map(|m| match m {
            SimpleQueryMessage::CommandComplete(n) => Some(n),
            _ => None,
        });
        return Ok((
            ExecutionOutcome {
                columns: vec![],
                rows: vec![],
                rows_affected,
                cursor_id: None,
            },
            None,
        ));
    }

    let cursor_id = format!("c_{}", Uuid::new_v4().simple());
    let begin = format!("BEGIN; DECLARE {} NO SCROLL CURSOR FOR {}", cursor_id, sql);
    if let Err(e) = client.batch_execute(&begin).await {
        let _ = client.batch_execute("ROLLBACK").await;
        return Err(PgError::Driver(e));
    }

    let fetch_sql = format!("FETCH FORWARD {} FROM {}", page_size, cursor_id);
    let stream = match client.simple_query(&fetch_sql).await {
        Ok(s) => s,
        Err(e) => {
            let _ = client.batch_execute("ROLLBACK").await;
            return Err(PgError::Driver(e));
        }
    };
    let (rows, fetched) = collect_rows(stream, columns.len());

    if fetched < page_size {
        // The cursor is exhausted, so release it now. No handle is returned
        // for this path, so a failed close must be rolled back here or the
        // session would stay inside the transaction.
        let close = format!("CLOSE {}; COMMIT", cursor_id);
        if let Err(e) = client.batch_execute(&close).await {
            tracing::debug!(
                cursor = %cursor_id,
                error = %e,
                "exhausted cursor cleanup failed; continuing with ROLLBACK"
            );
            let _ = client.batch_execute("ROLLBACK").await;
        }
        return Ok((
            ExecutionOutcome {
                columns,
                rows,
                rows_affected: Some(fetched as u64),
                cursor_id: None,
            },
            None,
        ));
    }

    let active = ActiveCursor {
        cursor_id: cursor_id.clone(),
        column_count: columns.len(),
    };
    Ok((
        ExecutionOutcome {
            columns,
            rows,
            rows_affected: Some(fetched as u64),
            cursor_id: Some(cursor_id),
        },
        Some(active),
    ))
}
/// Reads up to `count` further rows from an open cursor.
///
/// `has_more` is `true` when the server reports exactly `count` rows fetched,
/// meaning another page may exist.
///
/// # Errors
/// Returns `PgError::Driver` when the `FETCH` statement fails.
pub async fn fetch_page(
    client: &Client,
    cursor: &ActiveCursor,
    count: usize,
) -> Result<PageResult, PgError> {
    let statement = format!("FETCH FORWARD {} FROM {}", count, cursor.cursor_id);
    let messages = client
        .simple_query(&statement)
        .await
        .map_err(PgError::Driver)?;

    let (rows, row_total) = collect_rows(messages, cursor.column_count);
    let has_more = row_total == count;
    Ok(PageResult { rows, has_more })
}
/// Closes `cursor` and commits its transaction.
///
/// This is best-effort: a failure is logged at debug level and followed by a
/// `ROLLBACK` whose own result is ignored. Nothing is reported to the caller.
pub async fn close_query(client: &Client, cursor: &ActiveCursor) {
    let close_and_commit = format!("CLOSE {}; COMMIT", cursor.cursor_id);
    match client.batch_execute(&close_and_commit).await {
        Ok(_) => {}
        Err(failure) => {
            tracing::debug!(
                cursor = %cursor.cursor_id,
                error = %failure,
                "cursor close failed; session likely already broken"
            );
            let _ = client.batch_execute("ROLLBACK").await;
        }
    }
}
/// Turns a simple-query response into text rows plus the reported row count.
///
/// Each `Row` message contributes one row made of its first `width` cells.
/// The count is taken from the last `CommandComplete` message and stays `0`
/// when none is present. Other message kinds are ignored.
fn collect_rows(
    stream: Vec<SimpleQueryMessage>,
    width: usize,
) -> (Vec<Vec<Option<String>>>, usize) {
    let start: (Vec<Vec<Option<String>>>, usize) = (Vec::new(), 0);
    stream
        .into_iter()
        .fold(start, |(mut rows, fetched), message| match message {
            SimpleQueryMessage::Row(row) => {
                let cells: Vec<Option<String>> = (0..width)
                    .map(|idx| row.get(idx).map(str::to_string))
                    .collect();
                rows.push(cells);
                (rows, fetched)
            }
            SimpleQueryMessage::CommandComplete(count) => (rows, count as usize),
            _ => (rows, fetched),
        })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `is_multi_statement` yields `expected` for every input.
    fn assert_detector(expected: bool, cases: &[&str]) {
        for sql in cases {
            assert_eq!(
                is_multi_statement(sql),
                expected,
                "is_multi_statement({:?})",
                sql
            );
        }
    }

    /// Asserts that none of the inputs can be split into preamble + main.
    fn assert_unsplittable(cases: &[&str]) {
        for sql in cases {
            assert!(
                split_at_last_statement(sql).is_none(),
                "split_at_last_statement({:?})",
                sql
            );
        }
    }

    #[test]
    fn execution_outcome_round_trips() {
        let id_column = ColumnMeta {
            name: String::from("id"),
            type_oid: 23,
            type_name: String::from("int4"),
        };
        let original = ExecutionOutcome {
            columns: vec![id_column],
            rows: vec![vec![Some(String::from("1"))], vec![None]],
            rows_affected: Some(2),
            cursor_id: Some(String::from("c_abc")),
        };

        let encoded = serde_json::to_string(&original).expect("serialize");
        let decoded: ExecutionOutcome = serde_json::from_str(&encoded).expect("deserialize");

        assert_eq!(decoded.columns.len(), 1);
        let first = &decoded.columns[0];
        assert_eq!(first.type_oid, 23);
        assert_eq!(first.type_name, "int4");
        assert_eq!(decoded.rows.len(), 2);
        assert_eq!(decoded.cursor_id.as_deref(), Some("c_abc"));
    }

    #[test]
    fn page_result_round_trips() {
        let original = PageResult {
            rows: vec![vec![Some(String::from("x"))]],
            has_more: false,
        };

        let encoded = serde_json::to_string(&original).expect("serialize");
        let decoded: PageResult = serde_json::from_str(&encoded).expect("deserialize");

        assert_eq!(decoded.rows.len(), 1);
        assert!(!decoded.has_more);
    }

    #[test]
    fn multi_statement_detector_handles_common_cases() {
        assert_detector(
            false,
            &["SELECT 1", "SELECT 1;", "SELECT 1;\n", "SELECT 1;\n \n"],
        );
        assert_detector(
            true,
            &["SET x = 1; SELECT 1", "BEGIN; UPDATE t SET v=1; COMMIT;"],
        );
    }

    #[test]
    fn multi_statement_detector_ignores_semicolons_in_strings() {
        assert_detector(
            false,
            &[
                "SELECT 'hello; world'",
                "INSERT INTO t VALUES ('a;b;c')",
                "SELECT 'it''s; fine'",
            ],
        );
        assert_detector(true, &["SELECT 'a;b'; SELECT 1"]);
    }

    #[test]
    fn multi_statement_detector_ignores_semicolons_in_identifiers() {
        assert_detector(false, &["SELECT \"weird;name\" FROM t"]);
        assert_detector(true, &["SELECT \"col\"; SELECT 1"]);
    }

    #[test]
    fn multi_statement_detector_ignores_semicolons_in_comments() {
        assert_detector(
            false,
            &[
                "SELECT 1 -- ; comment\n",
                "SELECT 1 /* ; */ FROM t",
                "/* ; */ SELECT 1",
            ],
        );
        assert_detector(true, &["SELECT 1 -- ;\n; SELECT 2"]);
    }

    #[test]
    fn smart_split_returns_preamble_and_main() {
        assert_eq!(
            split_at_last_statement("SET x = 1; SELECT * FROM t"),
            Some(("SET x = 1;", "SELECT * FROM t"))
        );
    }

    #[test]
    fn smart_split_handles_multiple_preamble_statements() {
        assert_eq!(
            split_at_last_statement("SET x = 1; SET y = 2; SELECT 1"),
            Some(("SET x = 1; SET y = 2;", "SELECT 1"))
        );
    }

    #[test]
    fn smart_split_returns_none_when_no_main_statement_after_delimiter() {
        assert_unsplittable(&[
            "SET x = 1; SELECT 1;",
            "SET x = 1;",
            "SET x = 1; -- trailing\n",
        ]);
    }

    #[test]
    fn smart_split_ignores_in_string_semicolons() {
        assert_eq!(
            split_at_last_statement("SET x = 'a;b'; SELECT 1"),
            Some(("SET x = 'a;b';", "SELECT 1"))
        );
    }

    #[test]
    fn smart_split_returns_none_for_single_statement() {
        assert_unsplittable(&["SELECT 1", "SELECT 1;"]);
    }
}