use crate::{Result, Error, EmbeddedDatabase, Tuple, Value, Schema};
/// Reports whether `s` begins with `prefix`, ignoring ASCII letter case.
///
/// The comparison is done on raw bytes, so non-ASCII characters only match
/// when they are byte-for-byte identical. An empty `prefix` always matches.
#[inline]
pub(super) fn starts_with_icase(s: &str, prefix: &str) -> bool {
    let wanted = prefix.as_bytes();
    // `get` yields `None` when `s` is shorter than `prefix`, which is exactly
    // the "cannot match" case.
    match s.as_bytes().get(..wanted.len()) {
        Some(head) => head.eq_ignore_ascii_case(wanted),
        None => false,
    }
}
use super::messages::{
FrontendMessage, BackendMessage, AuthenticationMessage,
TransactionStatus, FieldDescription,
};
use super::auth::{AuthManager, AuthMethod, ScramAuthState};
use super::catalog::PgCatalog;
use super::prepared::PreparedStatementManager;
use super::ssl::SecureConnection;
use bytes::{BytesMut, BufMut};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter};
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use std::sync::Arc;
/// Per-connection state for one PostgreSQL wire-protocol client.
///
/// `S` is the byte stream the handler reads requests from and writes
/// responses to; it defaults to a buffered TCP stream.
pub struct PgConnectionHandler<S = BufWriter<TcpStream>> {
// Transport used for all reads and writes on this connection.
stream: S,
// Database that queries and transaction commands are executed against.
pub(super) database: Arc<EmbeddedDatabase>,
// Supplies the configured authentication method and credential checks.
auth_manager: Arc<AuthManager>,
// Consulted first for each query; may answer it without touching the database.
pub(super) catalog: PgCatalog,
// Prepared-statement bookkeeping (presumably for the extended query
// protocol handlers defined elsewhere — confirm).
pub(super) prepared_statements: PreparedStatementManager,
// Set once the startup handshake has accepted the client.
authenticated: bool,
// Transaction state reported to the client in every ReadyForQuery message.
transaction_status: TransactionStatus,
// Bytes received from the client that have not been parsed into a message yet.
buffer: BytesMut,
// User name taken from the startup parameters (or from SCRAM), if any.
username: Option<String>,
// NOTE(review): only ever initialised to `None` in the constructors visible
// here; confirm whether anything still reads or writes this field.
scram_state: Option<ScramAuthState>,
// Scratch buffer reused to encode each outgoing message before writing it.
write_buf: BytesMut,
// While `true`, `send_ready_for_query` is a no-op; used to emit a single
// ReadyForQuery for a multi-statement query string.
suppress_ready_for_query: bool,
}
impl PgConnectionHandler<BufWriter<TcpStream>> {
    /// Builds a handler for a plain TCP client.
    ///
    /// `initial_data`, when present, is copied into the inbound buffer so the
    /// handler parses those bytes before reading anything more from `stream`.
    /// The connection starts unauthenticated and outside any transaction.
    pub fn new(
        stream: TcpStream,
        database: Arc<EmbeddedDatabase>,
        auth_manager: Arc<AuthManager>,
        initial_data: Option<&[u8]>,
    ) -> Self {
        let mut pending = BytesMut::with_capacity(8192);
        // Extending with an empty slice is a no-op, so `None` needs no branch.
        pending.extend_from_slice(initial_data.unwrap_or(&[]));

        let catalog = PgCatalog::with_database(Arc::clone(&database));

        Self {
            stream: BufWriter::new(stream),
            database,
            auth_manager,
            catalog,
            prepared_statements: PreparedStatementManager::new(),
            authenticated: false,
            transaction_status: TransactionStatus::Idle,
            buffer: pending,
            username: None,
            scram_state: None,
            write_buf: BytesMut::with_capacity(4096),
            suppress_ready_for_query: false,
        }
    }
}
#[cfg(unix)]
impl PgConnectionHandler<BufWriter<UnixStream>> {
    /// Builds a handler for a client connected over a Unix domain socket.
    ///
    /// The inbound buffer starts empty; the connection starts unauthenticated
    /// and outside any transaction.
    pub fn new_unix(
        stream: UnixStream,
        database: Arc<EmbeddedDatabase>,
        auth_manager: Arc<AuthManager>,
    ) -> Self {
        let catalog = PgCatalog::with_database(Arc::clone(&database));

        Self {
            stream: BufWriter::new(stream),
            database,
            auth_manager,
            catalog,
            prepared_statements: PreparedStatementManager::new(),
            authenticated: false,
            transaction_status: TransactionStatus::Idle,
            buffer: BytesMut::with_capacity(8192),
            username: None,
            scram_state: None,
            write_buf: BytesMut::with_capacity(4096),
            suppress_ready_for_query: false,
        }
    }
}
/// Serves one client connected over a Unix domain socket until it disconnects.
///
/// The connection is handled with `AuthMethod::Trust`, i.e. no credentials
/// are requested. `_connection_id` is accepted but not used.
#[cfg(unix)]
pub async fn handle_connection_unix(
    database: Arc<EmbeddedDatabase>,
    stream: UnixStream,
    _connection_id: u32,
) -> Result<()> {
    let trust_auth = Arc::new(AuthManager::new(AuthMethod::Trust));
    let mut connection = PgConnectionHandler::new_unix(stream, database, trust_auth);
    connection.handle().await
}
impl PgConnectionHandler<BufWriter<SecureConnection<TcpStream>>> {
    /// Builds a handler on top of a `SecureConnection`-wrapped TCP stream.
    ///
    /// `initial_data`, when present, is copied into the inbound buffer so the
    /// handler parses those bytes before reading anything more from `stream`.
    /// The connection starts unauthenticated and outside any transaction.
    pub fn new_with_stream(
        stream: SecureConnection<TcpStream>,
        database: Arc<EmbeddedDatabase>,
        auth_manager: Arc<AuthManager>,
        initial_data: Option<&[u8]>,
    ) -> Self {
        let mut pending = BytesMut::with_capacity(8192);
        // Extending with an empty slice is a no-op, so `None` needs no branch.
        pending.extend_from_slice(initial_data.unwrap_or(&[]));

        let catalog = PgCatalog::with_database(Arc::clone(&database));

        Self {
            stream: BufWriter::new(stream),
            database,
            auth_manager,
            catalog,
            prepared_statements: PreparedStatementManager::new(),
            authenticated: false,
            transaction_status: TransactionStatus::Idle,
            buffer: pending,
            username: None,
            scram_state: None,
            write_buf: BytesMut::with_capacity(4096),
            suppress_ready_for_query: false,
        }
    }
}
impl<S> PgConnectionHandler<S>
where
S: AsyncReadExt + AsyncWriteExt + Unpin,
{
/// Runs the connection: startup handshake first, then the message loop.
///
/// A startup failure is reported to the client as a `FATAL` error (best
/// effort) and returned. Inside the loop, a failure while handling a message
/// is reported to the client as an `ERROR` and the loop keeps going; the loop
/// ends when the client disconnects or a read fails, and the method then
/// returns `Ok(())`. An error is only returned from the loop when the error
/// report itself cannot be sent.
pub async fn handle(&mut self) -> Result<()> {
    tracing::info!("New PostgreSQL connection");

    if let Err(e) = self.handle_startup().await {
        tracing::error!("Startup failed: {}", e);
        // The client may already be gone; the startup error is what matters.
        let _ = self.send_error("FATAL", "08P01", &e.to_string(), None, None).await;
        return Err(e);
    }

    tracing::debug!("Entering main message loop");
    loop {
        tracing::trace!("Waiting for next message from client");
        let msg = match self.read_message().await {
            Ok(Some(msg)) => msg,
            Ok(None) => {
                tracing::info!("Client disconnected");
                break;
            }
            Err(e) => {
                tracing::error!("Error reading message: {}", e);
                break;
            }
        };

        tracing::debug!("Received message: {:?}", msg);
        if let Err(e) = self.handle_message(msg).await {
            tracing::error!("Error handling message: {}", e);
            self.send_error("ERROR", "XX000", &e.to_string(), None, None).await?;
        }
    }

    Ok(())
}
#[allow(clippy::indexing_slicing)]
async fn handle_startup(&mut self) -> Result<()> {
let len_buf: [u8; 4];
if self.buffer.len() >= 4 {
len_buf = [self.buffer[0], self.buffer[1], self.buffer[2], self.buffer[3]];
} else {
let mut buf = [0u8; 4];
self.stream.read_exact(&mut buf).await
.map_err(|e| Error::network(format!("Failed to read startup length: {}", e)))?;
len_buf = buf;
self.buffer.extend_from_slice(&len_buf);
}
let len = i32::from_be_bytes(len_buf) as usize;
let bytes_in_buffer = self.buffer.len();
let bytes_needed = len.saturating_sub(bytes_in_buffer);
if bytes_needed > 0 {
let mut remaining_buf = vec![0u8; bytes_needed];
self.stream.read_exact(&mut remaining_buf).await
.map_err(|e| Error::network(format!("Failed to read startup message: {}", e)))?;
self.buffer.extend_from_slice(&remaining_buf);
}
let msg = FrontendMessage::parse_startup(&mut self.buffer)?
.ok_or_else(|| Error::protocol("Invalid startup message"))?;
if let FrontendMessage::Startup { protocol_version, params } = msg {
tracing::info!("Protocol version: {}, params: {:?}", protocol_version, params);
self.username = params.get("user").cloned();
if let Some(requested) = params.get("database").cloned()
.or_else(|| params.get("user").cloned())
{
if !self.database.database_name_is_valid(&requested) {
return Err(Error::authentication(format!(
"database \"{requested}\" does not exist"
)));
}
}
match self.auth_manager.method() {
AuthMethod::Trust => {
self.authenticated = true;
self.send_auth_ok().await?;
}
AuthMethod::CleartextPassword => {
self.send_message(BackendMessage::Authentication(
AuthenticationMessage::CleartextPassword
)).await?;
self.flush().await?;
if let Some(FrontendMessage::PasswordMessage { password }) = self.read_message().await? {
let username = self.username.as_ref()
.ok_or_else(|| Error::authentication("No username provided"))?;
if self.auth_manager.verify_cleartext(username, &password)? {
self.authenticated = true;
self.send_auth_ok().await?;
} else {
return Err(Error::authentication("Invalid password"));
}
} else {
return Err(Error::protocol("Expected password message"));
}
}
AuthMethod::ScramSha256 => {
self.handle_scram_authentication().await?;
}
_ => {
self.authenticated = true;
self.send_auth_ok().await?;
}
}
self.send_parameter_status(
"server_version",
&format!("16.0 (HeliosDB Nano {})", env!("CARGO_PKG_VERSION")),
).await?;
self.send_parameter_status("server_encoding", "UTF8").await?;
self.send_parameter_status("client_encoding", "UTF8").await?;
self.send_parameter_status("DateStyle", "ISO, MDY").await?;
self.send_parameter_status("TimeZone", "UTC").await?;
self.send_parameter_status("integer_datetimes", "on").await?;
self.send_message(BackendMessage::BackendKeyData {
process_id: std::process::id() as i32,
secret_key: rand::random(),
}).await?;
self.send_ready_for_query().await?;
Ok(())
} else {
Err(Error::protocol("Expected startup message"))
}
}
/// Returns the next complete frontend message.
///
/// A message already sitting in the inbound buffer is returned without
/// touching the stream. Otherwise the stream is read in 4096-byte chunks
/// until the buffer parses into a message. `Ok(None)` means the stream
/// reached end-of-file before a complete message was available.
///
/// # Errors
///
/// Propagates parse errors and returns a network error for any read failure
/// other than `WouldBlock`, which is retried after a 10 ms sleep.
#[allow(clippy::indexing_slicing)]
async fn read_message(&mut self) -> Result<Option<FrontendMessage>> {
    tracing::trace!("read_message: Checking buffer, len={}", self.buffer.len());
    let buffered = FrontendMessage::parse(&mut self.buffer)?;
    if buffered.is_some() {
        tracing::trace!("read_message: Parsed message from existing buffer");
        return Ok(buffered);
    }

    let mut chunk = vec![0u8; 4096];
    loop {
        tracing::trace!("read_message: Attempting to read from stream");
        let received = match self.stream.read(&mut chunk).await {
            Ok(count) => count,
            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                tracing::trace!("read_message: WouldBlock, sleeping 10ms");
                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                continue;
            }
            Err(e) => {
                tracing::error!("read_message: Read error: {}", e);
                return Err(Error::network(format!("Read error: {}", e)));
            }
        };

        if received == 0 {
            tracing::debug!("read_message: EOF received (0 bytes)");
            return Ok(None);
        }

        tracing::trace!("read_message: Read {} bytes", received);
        self.buffer.extend_from_slice(&chunk[..received]);
        if let Some(msg) = FrontendMessage::parse(&mut self.buffer)? {
            tracing::trace!("read_message: Successfully parsed message after read");
            return Ok(Some(msg));
        }
        tracing::trace!("read_message: Insufficient data for complete message, continuing");
    }
}
/// Dispatches one frontend message to the matching handler.
///
/// Before authentication only password messages are let through; anything
/// else yields an authentication error. `Terminate` is acknowledged with
/// `Ok(())` and message kinds without a handler are logged and ignored.
async fn handle_message(&mut self, msg: FrontendMessage) -> Result<()> {
    let is_password = matches!(msg, FrontendMessage::PasswordMessage { .. });
    if !(self.authenticated || is_password) {
        return Err(Error::authentication("Not authenticated"));
    }

    match msg {
        FrontendMessage::Query { query } => self.handle_query(&query).await,
        FrontendMessage::Parse { statement_name, query, param_types } => {
            self.handle_parse_extended(statement_name, query, param_types).await
        }
        FrontendMessage::Bind { portal_name, statement_name, param_formats, params, result_formats } => {
            self.handle_bind_extended(portal_name, statement_name, param_formats, params, result_formats).await
        }
        FrontendMessage::Execute { portal_name, max_rows } => {
            self.handle_execute_extended(portal_name, max_rows).await
        }
        FrontendMessage::Describe { target, name } => {
            self.handle_describe_extended(target, name).await
        }
        FrontendMessage::Close { target, name } => self.handle_close(target, name).await,
        FrontendMessage::Sync => self.send_ready_for_query().await,
        FrontendMessage::Flush => self.flush().await,
        FrontendMessage::Terminate => Ok(()),
        _ => {
            tracing::warn!("Unhandled message type: {:?}", msg);
            Ok(())
        }
    }
}
/// Handles a simple-protocol query string, which may hold several statements.
///
/// A string with at most one statement is passed through unchanged. For
/// several statements each one is executed in order and ReadyForQuery is
/// suppressed for all but the last, so the client sees exactly one
/// ReadyForQuery for the whole string.
///
/// # Errors
///
/// Returns the first statement error; later statements are not executed.
/// The suppression flag is always cleared before returning, so the caller's
/// error report (and every later query) can emit ReadyForQuery again.
async fn handle_query(&mut self, query: &str) -> Result<()> {
    let statements = pg_split_sql_respecting_quotes(query);
    if statements.len() <= 1 {
        return self.handle_single_query(query).await;
    }

    let last_idx = statements.len() - 1;
    for (i, stmt) in statements.iter().enumerate() {
        // Only the final statement is allowed to emit ReadyForQuery.
        self.suppress_ready_for_query = i != last_idx;
        if let Err(e) = self.handle_single_query(stmt).await {
            // Without this reset the flag would stay set after a mid-batch
            // failure and the client would never get ReadyForQuery again.
            self.suppress_ready_for_query = false;
            return Err(e);
        }
    }
    self.suppress_ready_for_query = false;
    Ok(())
}
/// Executes one SQL statement and sends its complete response, ending with
/// ReadyForQuery (unless that is currently suppressed).
///
/// Handling order: DO blocks, empty queries, transaction control
/// (`BEGIN` / `START TRANSACTION` / `COMMIT` / `ROLLBACK`), `SET`, `SHOW`,
/// `PRAGMA`, the optional HA read-only gate, catalog queries, and finally
/// regular execution (`SELECT`, DML with `RETURNING`, everything else).
///
/// Transaction-control keywords are recognised with or without a trailing
/// `;`, so `BEGIN;` updates the connection's transaction status the same way
/// `BEGIN` does.
///
/// # Errors
///
/// Propagates database, catalog and network errors. Conditions that are
/// reported to the client through `send_error` return `Ok(())`.
#[allow(clippy::indexing_slicing)]
async fn handle_single_query(&mut self, query: &str) -> Result<()> {
    tracing::debug!("Executing query: {}", query);

    if pg_looks_like_do_block(query.trim()) {
        return self.handle_do_block(query).await;
    }
    if query.trim().is_empty() {
        self.send_message(BackendMessage::EmptyQueryResponse).await?;
        self.send_ready_for_query().await?;
        return Ok(());
    }

    let trimmed = query.trim();
    // Exact keyword comparisons must ignore a statement terminator; a lone
    // statement reaches this function with its `;` still attached.
    let keyword = trimmed.trim_end_matches(';').trim_end();

    if keyword.eq_ignore_ascii_case("BEGIN")
        || starts_with_icase(trimmed, "BEGIN ")
        || keyword.eq_ignore_ascii_case("START TRANSACTION")
        || starts_with_icase(trimmed, "START TRANSACTION ")
    {
        let isolation_level = Self::parse_isolation_level(trimmed);
        if self.transaction_status == TransactionStatus::InTransaction {
            self.send_message(BackendMessage::NoticeResponse {
                severity: "WARNING".to_string(),
                code: "25001".to_string(),
                message: "there is already a transaction in progress".to_string(),
            }).await?;
        } else {
            self.database.begin()?;
            self.transaction_status = TransactionStatus::InTransaction;
            if let Some(level) = isolation_level {
                tracing::debug!("Transaction started with isolation level: {}", level);
            }
        }
        self.send_command_complete("BEGIN").await?;
        self.send_ready_for_query().await?;
        return Ok(());
    } else if starts_with_icase(trimmed, "SET TRANSACTION ISOLATION LEVEL ")
        || starts_with_icase(trimmed, "SET SESSION CHARACTERISTICS")
    {
        let level = Self::parse_isolation_level(trimmed);
        if level.is_some() {
            self.send_command_complete("SET").await?;
        } else {
            // `send_error` already finishes the response with ReadyForQuery.
            self.send_error("ERROR", "22023", "Invalid isolation level", None, None).await?;
            return Ok(());
        }
        self.send_ready_for_query().await?;
        return Ok(());
    } else if starts_with_icase(trimmed, "SET ")
        && !starts_with_icase(trimmed, "SET TRANSACTION")
        && !starts_with_icase(trimmed, "SET SESSION")
    {
        // Other SET commands are acknowledged without changing any state.
        self.send_command_complete("SET").await?;
        self.send_ready_for_query().await?;
        return Ok(());
    } else if starts_with_icase(trimmed, "SHOW ") {
        // "SHOW " is five ASCII bytes, so slicing at 5 is on a char boundary.
        let param = trimmed[5..].trim().trim_end_matches(';').trim();
        let (col_name, value) = Self::resolve_show_parameter(param);
        let schema = Schema::new(vec![
            crate::Column::new(&col_name, crate::DataType::Text),
        ]);
        let row = Tuple::new(vec![Value::String(value)]);
        self.send_query_result(schema, vec![row]).await?;
        self.send_ready_for_query().await?;
        return Ok(());
    } else if starts_with_icase(trimmed, "PRAGMA ") || trimmed.eq_ignore_ascii_case("PRAGMA") {
        if let Some((name, arg)) = crate::sql::sqlite_compat::parse_pragma(trimmed) {
            match name.to_lowercase().as_str() {
                "table_info" => {
                    let table = arg.unwrap_or_default();
                    let table = table.trim().trim_matches(|c| c == '\'' || c == '"' || c == '`').to_string();
                    let rows = self.pragma_table_info(&table)?;
                    let schema = Schema::new(vec![
                        crate::Column::new("cid", crate::DataType::Int4),
                        crate::Column::new("name", crate::DataType::Text),
                        crate::Column::new("type", crate::DataType::Text),
                        crate::Column::new("notnull", crate::DataType::Int4),
                        crate::Column::new("dflt_value", crate::DataType::Text),
                        crate::Column::new("pk", crate::DataType::Int4),
                    ]);
                    self.send_query_result(schema, rows).await?;
                    self.send_ready_for_query().await?;
                    return Ok(());
                }
                _ => {
                    tracing::debug!("PRAGMA stubbed (no-op): {} = {:?}", name, arg);
                    self.send_command_complete("PRAGMA").await?;
                    self.send_ready_for_query().await?;
                    return Ok(());
                }
            }
        } else {
            self.send_command_complete("PRAGMA").await?;
            self.send_ready_for_query().await?;
            return Ok(());
        }
    } else if keyword.eq_ignore_ascii_case("COMMIT") {
        if self.transaction_status == TransactionStatus::InTransaction {
            self.database.commit()?;
        } else {
            self.send_message(BackendMessage::NoticeResponse {
                severity: "WARNING".to_string(),
                code: "25P01".to_string(),
                message: "there is no transaction in progress".to_string(),
            }).await?;
        }
        self.transaction_status = TransactionStatus::Idle;
        self.send_command_complete("COMMIT").await?;
        self.send_ready_for_query().await?;
        return Ok(());
    } else if keyword.eq_ignore_ascii_case("ROLLBACK") {
        if self.transaction_status == TransactionStatus::InTransaction {
            self.database.rollback()?;
        } else {
            self.send_message(BackendMessage::NoticeResponse {
                severity: "WARNING".to_string(),
                code: "25P01".to_string(),
                message: "there is no transaction in progress".to_string(),
            }).await?;
        }
        self.transaction_status = TransactionStatus::Idle;
        self.send_command_complete("ROLLBACK").await?;
        self.send_ready_for_query().await?;
        return Ok(());
    }

    // On a read-only node, writes are either forwarded to the primary
    // (sync / semi-sync mode) or rejected.
    #[cfg(feature = "ha-tier1")]
    {
        use crate::replication::ha_state::{ha_state, SyncMode};
        use crate::replication::query_forwarder::{query_forwarder, ForwardedResult};
        if ha_state().is_read_only() {
            let is_write = starts_with_icase(trimmed, "INSERT")
                || starts_with_icase(trimmed, "UPDATE")
                || starts_with_icase(trimmed, "DELETE")
                || starts_with_icase(trimmed, "CREATE")
                || starts_with_icase(trimmed, "DROP")
                || starts_with_icase(trimmed, "ALTER")
                || starts_with_icase(trimmed, "TRUNCATE");
            if is_write {
                let config = ha_state().get_config();
                let sync_mode = config.as_ref().map(|c| c.sync_mode).unwrap_or(SyncMode::Async);
                // Every error path below relies on `send_error` to emit the
                // single ReadyForQuery that terminates the response; sending
                // a second one would desynchronise the client.
                if matches!(sync_mode, SyncMode::Sync | SyncMode::SemiSync) {
                    if let Some(forwarder) = query_forwarder() {
                        match forwarder.forward_query(query) {
                            Ok(ForwardedResult::Command { tag, .. }) => {
                                self.send_command_complete(&tag).await?;
                                self.send_ready_for_query().await?;
                                return Ok(());
                            }
                            Ok(ForwardedResult::Rows { columns, rows }) => {
                                self.send_forwarded_rows(&columns, &rows).await?;
                                self.send_ready_for_query().await?;
                                return Ok(());
                            }
                            Ok(ForwardedResult::Error { severity, code, message, detail, hint }) => {
                                self.send_error(&severity, &code, &message, detail, hint).await?;
                                return Ok(());
                            }
                            Err(e) => {
                                self.send_error(
                                    "ERROR",
                                    "08006",
                                    &format!("Failed to forward query to primary: {}", e),
                                    None,
                                    Some("Check primary connectivity".to_string()),
                                ).await?;
                                return Ok(());
                            }
                        }
                    } else {
                        self.send_error(
                            "ERROR",
                            "25006",
                            "cannot execute write operations: primary connection not established",
                            None,
                            Some("Standby is still connecting to primary".to_string()),
                        ).await?;
                        return Ok(());
                    }
                } else {
                    self.send_error(
                        "ERROR",
                        "25006",
                        "cannot execute write operations in read-only mode (async standby)",
                        None,
                        Some("Connect to the primary for write operations, or configure sync mode for transparent routing.".to_string()),
                    ).await?;
                    return Ok(());
                }
            }
        }
    }

    // Give the catalog a chance to answer the query itself.
    if let Some(result) = self.catalog.handle_query(query)? {
        self.send_query_result(result.0, result.1).await?;
        self.send_ready_for_query().await?;
        return Ok(());
    }

    let is_select = starts_with_icase(trimmed, "SELECT");
    let is_dml_returning = !is_select && {
        let upper = trimmed.to_uppercase();
        (starts_with_icase(trimmed, "INSERT")
            || starts_with_icase(trimmed, "UPDATE")
            || starts_with_icase(trimmed, "DELETE"))
            && upper.contains("RETURNING")
    };

    if is_select {
        let (results, columns) = self.database.query_with_columns(query)?;
        let schema = if !columns.is_empty() {
            // Column types are taken from the first row's values; with no
            // rows every column is reported as text.
            Schema::new(columns.iter().enumerate().map(|(i, name)| {
                let data_type = results.first()
                    .and_then(|r| r.values.get(i))
                    .map(Value::data_type)
                    .unwrap_or(crate::DataType::Text);
                crate::Column {
                    name: name.clone(),
                    data_type,
                    nullable: true,
                    primary_key: false,
                    source_table: None,
                    source_table_name: None,
                    default_expr: None,
                    unique: false,
                    storage_mode: crate::ColumnStorageMode::Default,
                }
            }).collect())
        } else if !results.is_empty() {
            results[0].schema()
        } else {
            Schema::new(vec![])
        };
        self.send_query_result(schema, results).await?;
    } else if is_dml_returning {
        let (affected, tuples) = self.database.execute_returning(query)?;
        if tuples.is_empty() {
            let tag = self.get_command_tag(query, affected);
            self.send_command_complete(&tag).await?;
        } else {
            // Prefer the planner-derived schema; fall back to the first
            // returned tuple's own schema when derivation fails.
            let schema = self.derive_returning_schema(query)
                .unwrap_or_else(|_| {
                    if let Some(first) = tuples.first() {
                        first.schema()
                    } else {
                        Schema::new(vec![])
                    }
                });
            self.send_query_result(schema, tuples).await?;
        }
    } else {
        let affected = self.database.execute(query)?;
        let tag = self.get_command_tag(query, affected);
        self.send_command_complete(&tag).await?;
    }

    self.send_ready_for_query().await?;
    Ok(())
}
/// Runs the server side of a SCRAM-SHA-256 exchange with the client.
///
/// Sequence: request SASL authentication, read the client-first message,
/// look up the user's stored credentials, send the server-first message,
/// read the client-final message, verify the client proof and send the
/// server-final message. On success the connection is marked authenticated
/// and `username` is set to the SCRAM user name.
///
/// Fails with a protocol error when an expected message is missing or
/// malformed, and with an authentication error when no password store is
/// configured or the user has no stored credentials. Errors from parsing and
/// proof verification are propagated unchanged.
///
/// NOTE(review): this function does not call `send_auth_ok` after the
/// server-final message, unlike the trust and cleartext paths in
/// `handle_startup`. Confirm whether encoding `ScramSha256Final` already
/// covers the AuthenticationOk message the client expects next.
#[allow(clippy::indexing_slicing)]
async fn handle_scram_authentication(&mut self) -> Result<()> {
self.send_message(BackendMessage::Authentication(
AuthenticationMessage::ScramSha256
)).await?;
// The request must be on the wire before we block waiting for the reply.
self.flush().await?;
let client_first = match self.read_message().await? {
Some(FrontendMessage::PasswordMessage { password }) => password,
_ => return Err(Error::protocol("Expected SASL initial response")),
};
// NOTE(review): the raw client messages are written to the debug log
// here and below; confirm that is acceptable for this deployment.
tracing::debug!("Received client-first-message: {}", client_first);
let (username_owned, client_nonce_owned) =
super::auth::parse_scram_client_first(&client_first)?;
let username: &str = &username_owned;
let client_nonce: &str = &client_nonce_owned;
tracing::info!("SCRAM authentication for user: {}", username);
let password_store = self.auth_manager.password_store()
.ok_or_else(|| Error::authentication("SCRAM password store not configured"))?;
let credentials = password_store.get_credentials(username)
.ok_or_else(|| Error::authentication("User not found"))?;
let mut scram_state = ScramAuthState::new(username.to_string());
scram_state.set_client_nonce(client_nonce.to_string());
// The bare client-first message is rebuilt from the parsed fields rather
// than sliced out of the received text.
// NOTE(review): this assumes the parsed user name is still in its wire
// form and that the client sent no extra attributes — confirm against
// `parse_scram_client_first`.
let client_first_bare = format!("n={},r={}", username, client_nonce);
scram_state.set_client_first_message_bare(client_first_bare);
let server_first = scram_state.build_server_first_message()?;
tracing::debug!("Sending server-first-message: {}", server_first);
self.send_message(BackendMessage::Authentication(
AuthenticationMessage::ScramSha256Continue {
data: server_first.as_bytes().to_vec(),
}
)).await?;
self.flush().await?;
let client_final = match self.read_message().await? {
Some(FrontendMessage::PasswordMessage { password }) => password,
_ => return Err(Error::protocol("Expected SASL response")),
};
tracing::debug!("Received client-final-message: {}", client_final);
// The client-final message is a comma-separated attribute list; at least
// three attributes are required here.
let final_parts: Vec<&str> = client_final.split(',').collect();
if final_parts.len() < 3 {
return Err(Error::protocol("Invalid SCRAM client-final-message"));
}
// Pull out the `p=` attribute, which carries the base64 client proof.
let proof_part = final_parts.iter()
.find(|p| p.starts_with("p="))
.ok_or_else(|| Error::protocol("Missing proof in client-final-message"))?;
let client_proof_b64 = proof_part.strip_prefix("p=")
.ok_or_else(|| Error::protocol("Invalid proof format"))?;
// Everything except the proof, re-joined in its original order, is the
// text handed to proof verification.
let client_final_without_proof: Vec<&str> = final_parts.iter()
.filter(|p| !p.starts_with("p="))
.copied()
.collect();
let client_final_without_proof = client_final_without_proof.join(",");
// A failed verification surfaces as an error through `?`.
let server_signature = scram_state.verify_client_proof(
client_proof_b64,
&client_final_without_proof,
&credentials.stored_key,
&credentials.server_key,
)?;
tracing::info!("SCRAM authentication successful for user: {}", username);
let server_final = scram_state.build_server_final_message(&server_signature)?;
tracing::debug!("Sending server-final-message: {}", server_final);
self.send_message(BackendMessage::Authentication(
AuthenticationMessage::ScramSha256Final {
data: server_final.as_bytes().to_vec(),
}
)).await?;
self.authenticated = true;
self.username = Some(username.to_string());
Ok(())
}
/// Sends a full result set: a RowDescription built from `schema`, one DataRow
/// per tuple in `rows`, and a `SELECT <row count>` CommandComplete.
///
/// ReadyForQuery is not sent here; that is left to the caller.
async fn send_query_result(&mut self, schema: Schema, rows: Vec<Tuple>) -> Result<()> {
    let row_description = BackendMessage::RowDescription {
        fields: schema_to_field_descriptions(&schema),
    };
    self.send_message(row_description).await?;

    for tuple in rows.iter() {
        self.send_data_row_direct(tuple).await?;
    }

    let completion_tag = format!("SELECT {}", rows.len());
    self.send_command_complete(&completion_tag).await
}
/// Relays a row result obtained from a forwarded query to the client.
///
/// Sends a RowDescription derived from `columns`, one DataRow per entry in
/// `rows` (a `None` cell is sent as SQL NULL, a `Some` cell as the string's
/// bytes), and a `SELECT <row count>` CommandComplete. ReadyForQuery is left
/// to the caller.
#[cfg(feature = "ha-tier1")]
async fn send_forwarded_rows(
    &mut self,
    columns: &[crate::replication::query_forwarder::ColumnInfo],
    rows: &[Vec<Option<String>>],
) -> Result<()> {
    use crate::protocol::postgres::messages::FieldDescription;

    let mut fields: Vec<FieldDescription> = Vec::with_capacity(columns.len());
    for column in columns {
        fields.push(FieldDescription {
            name: column.name.clone(),
            table_oid: 0,
            column_attr_num: 0,
            data_type_oid: column.type_oid,
            data_type_size: -1,
            type_modifier: -1,
            format_code: 0,
        });
    }
    self.send_message(BackendMessage::RowDescription { fields }).await?;

    for cells in rows {
        let mut values: Vec<Option<Vec<u8>>> = Vec::with_capacity(cells.len());
        for cell in cells {
            values.push(cell.as_ref().map(|text| text.as_bytes().to_vec()));
        }
        self.send_message(BackendMessage::DataRow { values }).await?;
    }

    let completion_tag = format!("SELECT {}", rows.len());
    self.send_command_complete(&completion_tag).await
}
pub(super) async fn send_message(&mut self, msg: BackendMessage) -> Result<()> {
self.write_buf.clear();
msg.encode(&mut self.write_buf);
self.stream.write_all(&self.write_buf).await
.map_err(|e| Error::network(format!("Failed to send message: {}", e)))?;
Ok(())
}
#[allow(clippy::indexing_slicing)] async fn send_data_row_direct(&mut self, tuple: &Tuple) -> Result<()> {
self.write_buf.clear();
self.write_buf.put_u8(b'D');
let length_pos = self.write_buf.len();
self.write_buf.put_i32(0);
self.write_buf.put_i16(tuple.values.len() as i16);
let mut itoa_buf = itoa::Buffer::new();
let mut ryu_buf = ryu::Buffer::new();
for val in &tuple.values {
match val {
Value::Null => {
self.write_buf.put_i32(-1);
}
Value::Boolean(b) => {
self.write_buf.put_i32(1);
self.write_buf.put_u8(if *b { b't' } else { b'f' });
}
Value::Int2(i) => {
let s = itoa_buf.format(*i);
self.write_buf.put_i32(s.len() as i32);
self.write_buf.put_slice(s.as_bytes());
}
Value::Int4(i) => {
let s = itoa_buf.format(*i);
self.write_buf.put_i32(s.len() as i32);
self.write_buf.put_slice(s.as_bytes());
}
Value::Int8(i) => {
let s = itoa_buf.format(*i);
self.write_buf.put_i32(s.len() as i32);
self.write_buf.put_slice(s.as_bytes());
}
Value::Float4(f) => {
let s = ryu_buf.format(*f);
self.write_buf.put_i32(s.len() as i32);
self.write_buf.put_slice(s.as_bytes());
}
Value::Float8(f) => {
let s = ryu_buf.format(*f);
self.write_buf.put_i32(s.len() as i32);
self.write_buf.put_slice(s.as_bytes());
}
Value::String(s) => {
self.write_buf.put_i32(s.len() as i32);
self.write_buf.put_slice(s.as_bytes());
}
Value::Bytes(b) => {
self.write_buf.put_i32(b.len() as i32);
self.write_buf.put_slice(b);
}
Value::Json(j) => {
let s = j.to_string();
self.write_buf.put_i32(s.len() as i32);
self.write_buf.put_slice(s.as_bytes());
}
Value::Numeric(n) => {
self.write_buf.put_i32(n.len() as i32);
self.write_buf.put_slice(n.as_bytes());
}
Value::Uuid(u) => {
let s = u.to_string();
self.write_buf.put_i32(s.len() as i32);
self.write_buf.put_slice(s.as_bytes());
}
Value::Timestamp(ts) => {
let s = ts.naive_utc().format("%Y-%m-%d %H:%M:%S%.6f").to_string();
self.write_buf.put_i32(s.len() as i32);
self.write_buf.put_slice(s.as_bytes());
}
Value::Date(d) => {
let s = d.format("%Y-%m-%d").to_string();
self.write_buf.put_i32(s.len() as i32);
self.write_buf.put_slice(s.as_bytes());
}
Value::Time(t) => {
let s = t.format("%H:%M:%S%.6f").to_string();
self.write_buf.put_i32(s.len() as i32);
self.write_buf.put_slice(s.as_bytes());
}
Value::Interval(micros) => {
let total_secs = micros / 1_000_000;
let days = total_secs / 86400;
let hours = (total_secs % 86400) / 3600;
let mins = (total_secs % 3600) / 60;
let secs = total_secs % 60;
let s = if days > 0 {
format!("{} days {:02}:{:02}:{:02}", days, hours, mins, secs)
} else {
format!("{:02}:{:02}:{:02}", hours, mins, secs)
};
self.write_buf.put_i32(s.len() as i32);
self.write_buf.put_slice(s.as_bytes());
}
Value::Array(arr) => {
let val_length_pos = self.write_buf.len();
self.write_buf.put_i32(0);
self.write_buf.put_u8(b'{');
for (i, v) in arr.iter().enumerate() {
if i > 0 { self.write_buf.put_u8(b','); }
match v {
Value::String(s) => {
self.write_buf.put_u8(b'"');
self.write_buf.put_slice(s.as_bytes());
self.write_buf.put_u8(b'"');
}
Value::Null => self.write_buf.put_slice(b"NULL"),
other => {
let s = other.to_string();
self.write_buf.put_slice(s.as_bytes());
}
}
}
self.write_buf.put_u8(b'}');
let val_len = (self.write_buf.len() - val_length_pos - 4) as i32;
self.write_buf[val_length_pos..val_length_pos + 4].copy_from_slice(&val_len.to_be_bytes());
}
Value::Vector(v) => {
let val_length_pos = self.write_buf.len();
self.write_buf.put_i32(0);
self.write_buf.put_u8(b'{');
for (i, x) in v.iter().enumerate() {
if i > 0 { self.write_buf.put_u8(b','); }
let s = ryu_buf.format(*x);
self.write_buf.put_slice(s.as_bytes());
}
self.write_buf.put_u8(b'}');
let val_len = (self.write_buf.len() - val_length_pos - 4) as i32;
self.write_buf[val_length_pos..val_length_pos + 4].copy_from_slice(&val_len.to_be_bytes());
}
Value::DictRef { dict_id } => {
let s = itoa_buf.format(*dict_id);
self.write_buf.put_i32(s.len() as i32);
self.write_buf.put_slice(s.as_bytes());
}
Value::CasRef { hash } => {
let s = hex::encode(hash);
self.write_buf.put_i32(s.len() as i32);
self.write_buf.put_slice(s.as_bytes());
}
Value::ColumnarRef => {
self.write_buf.put_i32(10);
self.write_buf.put_slice(b"<columnar>");
}
}
}
let msg_len = (self.write_buf.len() - length_pos) as i32;
self.write_buf[length_pos..length_pos + 4].copy_from_slice(&msg_len.to_be_bytes());
self.stream.write_all(&self.write_buf).await
.map_err(|e| Error::network(format!("Failed to send message: {}", e)))?;
Ok(())
}
async fn flush(&mut self) -> Result<()> {
self.stream.flush().await
.map_err(|e| Error::network(format!("Failed to flush stream: {}", e)))
}
/// Sends the AuthenticationOk message (written, not flushed).
async fn send_auth_ok(&mut self) -> Result<()> {
    let auth_ok = BackendMessage::Authentication(AuthenticationMessage::Ok);
    self.send_message(auth_ok).await
}
/// Sends a ParameterStatus message reporting `name` = `value` (written, not
/// flushed).
async fn send_parameter_status(&mut self, name: &str, value: &str) -> Result<()> {
    let status = BackendMessage::ParameterStatus {
        name: String::from(name),
        value: String::from(value),
    };
    self.send_message(status).await
}
/// Executes a `DO` block by running the plain SQL statements in its body.
///
/// The body is extracted, its outer `BEGIN`/`END` stripped and any
/// `EXCEPTION` section split off. Bodies using PL/pgSQL control flow are
/// rejected. Each statement is run directly against the database; a
/// statement error that matches the block's exception codes is logged and
/// skipped, any other error aborts the block and is returned. On success a
/// `DO` CommandComplete and ReadyForQuery are sent.
async fn handle_do_block(&mut self, query: &str) -> Result<()> {
    let body = pg_extract_do_block_body(query).unwrap_or("");
    let inner = pg_strip_begin_end(body.trim());
    let (main_body, exception_codes) = pg_split_exception(inner);

    if let Some(kw) = pg_detect_plpgsql(main_body) {
        return Err(Error::query_execution(format!(
            "PL/pgSQL control flow (`{kw}`) inside DO blocks is not yet \
             supported in HeliosDB Nano. Rewrite the block as plain SQL, \
             or execute each statement separately. \
             See: docs/compatibility/plpgsql.md"
        )));
    }

    let statements = pg_split_sql_respecting_quotes(main_body);
    if !statements.is_empty() {
        // Suppress ReadyForQuery while the body runs and restore the caller's
        // setting afterwards, whether the body succeeded or not.
        let prev = self.suppress_ready_for_query;
        self.suppress_ready_for_query = true;

        let mut failure = None;
        for stmt in &statements {
            match self.database.execute(stmt) {
                Ok(_) => {}
                Err(e) if pg_exception_matches(&exception_codes, &e.to_string()) => {
                    tracing::debug!(
                        "DO block: caught {:?} via EXCEPTION clause; continuing",
                        e.to_string()
                    );
                }
                Err(e) => {
                    failure = Some(e);
                    break;
                }
            }
        }

        self.suppress_ready_for_query = prev;
        if let Some(e) = failure {
            return Err(e);
        }
    }

    self.send_command_complete("DO").await?;
    self.send_ready_for_query().await?;
    Ok(())
}
/// Sends ReadyForQuery carrying the current transaction status and flushes
/// the stream.
///
/// Does nothing (and returns `Ok`) while `suppress_ready_for_query` is set.
async fn send_ready_for_query(&mut self) -> Result<()> {
    if !self.suppress_ready_for_query {
        let status = self.transaction_status;
        self.send_message(BackendMessage::ReadyForQuery { status }).await?;
        self.flush().await?;
    }
    Ok(())
}
/// Sends a CommandComplete message with the given command `tag` (written,
/// not flushed).
pub(super) async fn send_command_complete(&mut self, tag: &str) -> Result<()> {
    let complete = BackendMessage::CommandComplete { tag: String::from(tag) };
    self.send_message(complete).await
}
/// Sends an ErrorResponse followed by ReadyForQuery.
///
/// Callers must not send their own ReadyForQuery afterwards. The
/// ReadyForQuery part is skipped while `suppress_ready_for_query` is set.
/// The error position is always reported as absent.
async fn send_error(&mut self, severity: &str, code: &str, message: &str, detail: Option<String>, hint: Option<String>) -> Result<()> {
    let response = BackendMessage::ErrorResponse {
        severity: String::from(severity),
        code: String::from(code),
        message: String::from(message),
        detail,
        hint,
        position: None,
    };
    self.send_message(response).await?;
    self.send_ready_for_query().await
}
/// Builds the rows for `PRAGMA table_info(<table>)`.
///
/// One row per column, in schema order: column index, name, upper-cased
/// debug rendering of the data type, not-null flag (1 when the column is
/// not nullable), default expression or NULL, and primary-key flag.
///
/// # Errors
///
/// Propagates the catalog error when the table schema cannot be loaded.
fn pragma_table_info(&self, table: &str) -> Result<Vec<Tuple>> {
    let catalog = self.database.storage.catalog();
    let schema = catalog.get_table_schema(table)?;

    let rows = schema
        .columns
        .iter()
        .enumerate()
        .map(|(cid, column)| {
            let type_name = format!("{:?}", column.data_type).to_uppercase();
            let not_null = if column.nullable { 0 } else { 1 };
            let default_value = match column.default_expr.as_ref() {
                Some(expr) => Value::String(expr.clone()),
                None => Value::Null,
            };
            let is_pk = if column.primary_key { 1 } else { 0 };

            Tuple::new(vec![
                Value::Int4(cid as i32),
                Value::String(column.name.clone()),
                Value::String(type_name),
                Value::Int4(not_null),
                default_value,
                Value::Int4(is_pk),
            ])
        })
        .collect();

    Ok(rows)
}
/// Derives the result schema of a DML statement's `RETURNING` clause.
///
/// The statement is parsed and planned; for INSERT, INSERT ... SELECT,
/// UPDATE and DELETE plans the schema is computed from the target table's
/// schema and the returning items. A DML plan without returning items yields
/// an empty schema.
///
/// # Errors
///
/// Propagates parse, planning and catalog errors, and returns a
/// query-execution error when the plan is not one of the DML kinds above.
fn derive_returning_schema(&self, sql: &str) -> Result<Schema> {
    let catalog = self.database.storage.catalog();
    let planner = crate::sql::planner::Planner::with_catalog(&catalog)
        .with_sql(sql.to_string());
    let (statement, _) = self.database.parse_cached(sql)?;
    let plan = planner.statement_to_plan(statement)?;

    let (target_table, returning_items) = match &plan {
        crate::sql::LogicalPlan::Insert { table_name, returning, .. }
        | crate::sql::LogicalPlan::InsertSelect { table_name, returning, .. } => {
            (table_name.as_str(), returning.as_ref())
        }
        crate::sql::LogicalPlan::Update { table_name, returning, .. } => {
            (table_name.as_str(), returning.as_ref())
        }
        crate::sql::LogicalPlan::Delete { table_name, returning, .. } => {
            (table_name.as_str(), returning.as_ref())
        }
        _ => return Err(crate::Error::query_execution("Not a DML statement")),
    };

    match returning_items {
        None => Ok(Schema::new(vec![])),
        Some(items) => {
            let table_schema = catalog.get_table_schema(target_table)?;
            Ok(crate::EmbeddedDatabase::returning_schema(&table_schema, items))
        }
    }
}
/// Builds the CommandComplete tag for `query`.
///
/// `INSERT`, `UPDATE` and `DELETE` include the affected row count;
/// `CREATE TABLE`, `DROP TABLE` and `CREATE INDEX` are echoed as-is; any
/// other statement gets `OK <affected>`. Matching is by case-insensitive
/// prefix of the trimmed query.
pub(super) fn get_command_tag(&self, query: &str, affected: u64) -> String {
    // Tags that are sent verbatim, without a row count.
    const FIXED_TAGS: [&str; 3] = ["CREATE TABLE", "DROP TABLE", "CREATE INDEX"];

    let stmt = query.trim();
    if starts_with_icase(stmt, "INSERT") {
        return format!("INSERT 0 {}", affected);
    }
    if starts_with_icase(stmt, "UPDATE") {
        return format!("UPDATE {}", affected);
    }
    if starts_with_icase(stmt, "DELETE") {
        return format!("DELETE {}", affected);
    }

    for fixed in FIXED_TAGS.iter().copied() {
        if starts_with_icase(stmt, fixed) {
            return fixed.to_string();
        }
    }
    format!("OK {}", affected)
}
/// Resolves a `SHOW <param>` request to `(column name, value)`.
///
/// The column name is the lower-cased parameter. Known parameters map to
/// fixed values; an unknown parameter yields an empty string.
#[allow(clippy::indexing_slicing)]
fn resolve_show_parameter(param: &str) -> (String, String) {
    let key = param.to_lowercase();

    let value = match key.as_str() {
        "server_version" => format!("16.0 (HeliosDB Nano {})", env!("CARGO_PKG_VERSION")),
        "server_encoding" | "client_encoding" => String::from("UTF8"),
        "standard_conforming_strings" | "integer_datetimes" | "is_superuser" => {
            String::from("on")
        }
        "transaction_isolation"
        | "transaction isolation level"
        | "default_transaction_isolation" => String::from("read committed"),
        "datestyle" => String::from("ISO, MDY"),
        "timezone" | "time zone" => String::from("UTC"),
        "max_connections" => String::from("100"),
        "lc_collate" | "lc_ctype" => String::from("en_US.UTF-8"),
        "search_path" => String::from("\"$user\", public"),
        _ => String::new(),
    };

    (key, value)
}
fn parse_isolation_level(query: &str) -> Option<String> {
let query_bytes = query.as_bytes();
let needle = b"ISOLATION LEVEL";
let pos = query_bytes.windows(needle.len())
.position(|w| w.eq_ignore_ascii_case(needle))?;
let rest = query[pos + needle.len()..].trim();
if starts_with_icase(rest, "READ UNCOMMITTED") {
Some("READ UNCOMMITTED".to_string())
} else if starts_with_icase(rest, "READ COMMITTED") {
Some("READ COMMITTED".to_string())
} else if starts_with_icase(rest, "REPEATABLE READ") {
Some("REPEATABLE READ".to_string())
} else if starts_with_icase(rest, "SERIALIZABLE") {
Some("SERIALIZABLE".to_string())
} else {
None
}
}
}
/// Converts every column of `schema` into a [`FieldDescription`].
///
/// Each description carries the column name plus the type OID and size derived from the
/// column's data type. The table OID, attribute number and format code are set to `0`,
/// and the type modifier to `-1`.
pub(super) fn schema_to_field_descriptions(schema: &Schema) -> Vec<FieldDescription> {
    let mut fields: Vec<FieldDescription> = Vec::new();
    for column in schema.columns.iter() {
        let type_oid = datatype_to_oid(&column.data_type);
        let type_size = datatype_to_size(&column.data_type);
        fields.push(FieldDescription {
            name: column.name.clone(),
            table_oid: 0,
            column_attr_num: 0,
            data_type_oid: type_oid,
            data_type_size: type_size,
            type_modifier: -1,
            format_code: 0,
        });
    }
    fields
}
/// Maps an engine data type to the PostgreSQL type OID advertised to clients.
///
/// Types without an explicit mapping are reported as `705` (`unknown`).
pub(super) fn datatype_to_oid(dt: &crate::DataType) -> i32 {
    match dt {
        crate::DataType::Boolean => 16,
        crate::DataType::Int2 => 21,
        crate::DataType::Int4 => 23,
        crate::DataType::Int8 => 20,
        crate::DataType::Float4 => 700,
        crate::DataType::Float8 => 701,
        crate::DataType::Text => 25,
        crate::DataType::Varchar(_) => 1043,
        crate::DataType::Json => 114,
        crate::DataType::Jsonb => 3802,
        crate::DataType::Timestamp => 1114,
        crate::DataType::Date => 1082,
        crate::DataType::Time => 1083,
        crate::DataType::Uuid => 2950,
        // Vector values are rendered as a brace-delimited list of floats, i.e. array
        // text syntax, so advertise `float4[]` (1021). The previous value, 1000, is
        // `bool[]`, which makes clients decode the elements as booleans.
        // NOTE(review): assumes vector elements are f32 — use 1022 (`float8[]`) if f64.
        crate::DataType::Vector(_) => 1021,
        _ => 705,
    }
}
/// Returns the size in bytes reported for a data type, or `-1` when no fixed size is
/// reported for it.
pub(super) fn datatype_to_size(dt: &crate::DataType) -> i16 {
    match dt {
        crate::DataType::Boolean => 1,
        crate::DataType::Int2 => 2,
        crate::DataType::Int4 | crate::DataType::Float4 => 4,
        crate::DataType::Int8 | crate::DataType::Float8 => 8,
        crate::DataType::Uuid => 16,
        crate::DataType::Text | crate::DataType::Varchar(_) => -1,
        _ => -1,
    }
}
/// Encodes every value of `tuple` as an optional byte string for a data row.
///
/// `Value::Null` becomes `None`; every other value becomes `Some(bytes)` holding its
/// textual rendering (`Value::Bytes` is passed through unchanged).
///
/// Rendering notes:
/// * booleans are `t` / `f`;
/// * timestamps and times are printed with six fractional digits;
/// * intervals are printed as `[N days ]HH:MM:SS`, dropping any sub-second part;
/// * arrays and vectors use brace syntax (`{a,b,c}`); string elements of an array are
///   double-quoted, with embedded `"` and `\` backslash-escaped.
pub(super) fn tuple_to_pg_values(tuple: &Tuple) -> Vec<Option<Vec<u8>>> {
    // Appends `s` as a double-quoted array element. Embedded quotes and backslashes are
    // backslash-escaped so they cannot end the quoting early or swallow the next char.
    fn push_quoted_element(buf: &mut String, s: &str) {
        buf.push('"');
        for ch in s.chars() {
            if matches!(ch, '"' | '\\') {
                buf.push('\\');
            }
            buf.push(ch);
        }
        buf.push('"');
    }

    tuple.values.iter().map(|val| {
        match val {
            Value::Null => None,
            Value::Boolean(b) => Some(if *b { b"t" } else { b"f" }.to_vec()),
            Value::Int2(i) => Some(itoa::Buffer::new().format(*i).as_bytes().to_vec()),
            Value::Int4(i) => Some(itoa::Buffer::new().format(*i).as_bytes().to_vec()),
            Value::Int8(i) => Some(itoa::Buffer::new().format(*i).as_bytes().to_vec()),
            Value::Float4(f) => Some(ryu::Buffer::new().format(*f).as_bytes().to_vec()),
            Value::Float8(f) => Some(ryu::Buffer::new().format(*f).as_bytes().to_vec()),
            Value::String(s) => Some(s.as_bytes().to_vec()),
            Value::Bytes(b) => Some(b.clone()),
            Value::Json(j) => Some(j.to_string().into_bytes()),
            Value::Numeric(n) => Some(n.as_bytes().to_vec()),
            Value::Uuid(u) => Some(u.to_string().into_bytes()),
            Value::Timestamp(ts) => Some(
                ts.naive_utc()
                    .format("%Y-%m-%d %H:%M:%S%.6f")
                    .to_string()
                    .into_bytes()
            ),
            Value::Date(d) => Some(d.format("%Y-%m-%d").to_string().into_bytes()),
            Value::Time(t) => Some(t.format("%H:%M:%S%.6f").to_string().into_bytes()),
            Value::Interval(micros) => {
                // Only whole seconds are rendered.
                let total_secs = micros / 1_000_000;
                let days = total_secs / 86400;
                let hours = (total_secs % 86400) / 3600;
                let mins = (total_secs % 3600) / 60;
                let secs = total_secs % 60;
                let s = if days > 0 {
                    format!("{} days {:02}:{:02}:{:02}", days, hours, mins, secs)
                } else {
                    format!("{:02}:{:02}:{:02}", hours, mins, secs)
                };
                Some(s.into_bytes())
            }
            Value::Array(arr) => {
                let mut buf = String::with_capacity(arr.len() * 8 + 2);
                buf.push('{');
                for (i, v) in arr.iter().enumerate() {
                    if i > 0 { buf.push(','); }
                    match v {
                        Value::String(s) => push_quoted_element(&mut buf, s),
                        Value::Null => buf.push_str("NULL"),
                        other => buf.push_str(&other.to_string()),
                    }
                }
                buf.push('}');
                Some(buf.into_bytes())
            }
            Value::Vector(v) => {
                let mut buf = String::with_capacity(v.len() * 8 + 2);
                buf.push('{');
                let mut ryu_buf = ryu::Buffer::new();
                for (i, x) in v.iter().enumerate() {
                    if i > 0 { buf.push(','); }
                    buf.push_str(ryu_buf.format(*x));
                }
                buf.push('}');
                Some(buf.into_bytes())
            }
            Value::DictRef { dict_id } => Some(itoa::Buffer::new().format(*dict_id).as_bytes().to_vec()),
            Value::CasRef { hash } => Some(hex::encode(hash).into_bytes()),
            Value::ColumnarRef => Some(b"<columnar>".to_vec()),
        }
    }).collect()
}
/// Splits `sql` into individual statements on `;`, ignoring semicolons that appear
/// inside single-quoted strings or dollar-quoted bodies (`$$ ... $$`, `$tag$ ... $tag$`).
///
/// Each returned statement is trimmed and excludes the terminating `;`; empty statements
/// are dropped. Inside single quotes, a doubled quote (`''`) and a backslash followed by
/// any byte are treated as escapes. A `$` opens a dollar-quoted body only when the text
/// up to the next `$` consists solely of ASCII alphanumerics / `_` (possibly empty).
///
/// Statements are sliced out of the original string, so non-ASCII text is preserved
/// exactly. All delimiters are ASCII, so scanning bytes never splits a UTF-8 sequence.
fn pg_split_sql_respecting_quotes(sql: &str) -> Vec<String> {
    // Records one raw segment as a statement unless it is blank.
    fn push_statement(statements: &mut Vec<String>, segment: &str) {
        let trimmed = segment.trim();
        if !trimmed.is_empty() {
            statements.push(trimmed.to_string());
        }
    }

    let bytes = sql.as_bytes();
    let mut statements = Vec::new();
    let mut segment_start = 0usize;
    let mut in_single_quote = false;
    // Closing delimiter (`$tag$`) of the dollar-quoted body being scanned, if any.
    let mut dollar_closer: Option<String> = None;
    let mut i = 0usize;

    while let Some(&b) = bytes.get(i) {
        if let Some(closer) = dollar_closer.as_deref() {
            if b == b'$' && bytes.get(i..i + closer.len()) == Some(closer.as_bytes()) {
                i += closer.len();
                dollar_closer = None;
            } else {
                i += 1;
            }
            continue;
        }

        if in_single_quote {
            match b {
                // `''` is an escaped quote: stay inside the string.
                b'\'' if bytes.get(i + 1) == Some(&b'\'') => i += 2,
                b'\'' => {
                    in_single_quote = false;
                    i += 1;
                }
                // A backslash escapes whatever byte follows it.
                b'\\' if i + 1 < bytes.len() => i += 2,
                _ => i += 1,
            }
            continue;
        }

        match b {
            b'$' => {
                let tag = sql
                    .get(i + 1..)
                    .and_then(|rest| rest.find('$').and_then(|end| rest.get(..end)))
                    .filter(|tag| tag.chars().all(|c| c.is_ascii_alphanumeric() || c == '_'));
                match tag {
                    Some(tag) => {
                        // Skip the opening `$tag$` and remember how the body closes.
                        i += tag.len() + 2;
                        dollar_closer = Some(format!("${tag}$"));
                    }
                    None => i += 1,
                }
            }
            b'\'' => {
                in_single_quote = true;
                i += 1;
            }
            b';' => {
                push_statement(&mut statements, sql.get(segment_start..i).unwrap_or_default());
                segment_start = i + 1;
                i += 1;
            }
            _ => i += 1,
        }
    }

    // Whatever follows the last `;` (including an unterminated quote) is a statement too.
    push_statement(&mut statements, sql.get(segment_start..).unwrap_or_default());
    statements
}
/// Reports whether the statement begins (after leading whitespace, case-insensitively)
/// with `DO $` or `DO LANGUAGE `.
fn pg_looks_like_do_block(trimmed: &str) -> bool {
    let head = trimmed.trim_start().as_bytes();
    let has_prefix = |prefix: &str| {
        head.get(..prefix.len())
            .map_or(false, |lead| lead.eq_ignore_ascii_case(prefix.as_bytes()))
    };
    has_prefix("DO $") || has_prefix("DO LANGUAGE ")
}
/// Splits a block body at its `EXCEPTION` section.
///
/// Returns the text preceding the `EXCEPTION` keyword (excluding the single whitespace
/// byte in front of it) together with the lower-cased condition names found between
/// each `WHEN` and the following `THEN` of the handler section. Condition names may be
/// separated by whitespace, `,`, `|` or the word `OR`.
///
/// The keyword is matched case-insensitively and must stand alone: it has to be
/// delimited by ASCII whitespace (spaces, tabs, newlines) or by the start/end of
/// `body`, so the usual multi-line layout with `EXCEPTION` on its own line is
/// recognised. When no such keyword exists, `(body, vec![])` is returned.
fn pg_split_exception(body: &str) -> (&str, Vec<String>) {
    const KEYWORD: &str = "EXCEPTION";

    // ASCII upper-casing keeps byte offsets aligned with `body`.
    let upper = body.to_ascii_uppercase();
    let upper_bytes = upper.as_bytes();
    let is_boundary = |byte: Option<&u8>| byte.map_or(true, u8::is_ascii_whitespace);

    let keyword_start = upper
        .match_indices(KEYWORD)
        .map(|(pos, _)| pos)
        .find(|&pos| {
            let before = pos.checked_sub(1).and_then(|prev| upper_bytes.get(prev));
            let after = upper_bytes.get(pos + KEYWORD.len());
            is_boundary(before) && is_boundary(after)
        });
    let Some(keyword_start) = keyword_start else {
        return (body, Vec::new());
    };

    // The byte in front of the keyword (if any) is ASCII whitespace, so both cut points
    // fall on char boundaries.
    let main = body.get(..keyword_start.saturating_sub(1)).unwrap_or_default();
    let handlers = upper.get(keyword_start + KEYWORD.len()..).unwrap_or_default();

    let mut codes = Vec::new();
    let mut remaining = handlers;
    while let Some((_, after_when)) = remaining.split_once("WHEN") {
        // Conditions run up to the next THEN, or to the end when THEN is missing.
        let then_pos = after_when.find("THEN").unwrap_or(after_when.len());
        let (conditions, rest) = after_when.split_at(then_pos);
        codes.extend(
            conditions
                .split(|c: char| c == ',' || c == '|' || c.is_whitespace())
                .map(|token| {
                    token
                        .chars()
                        .filter(|c| c.is_ascii_alphanumeric() || *c == '_')
                        .collect::<String>()
                })
                .filter(|name| !name.is_empty() && !name.eq_ignore_ascii_case("OR"))
                .map(|name| name.to_lowercase()),
        );
        remaining = rest;
    }
    (main, codes)
}
/// Decides whether an error message is covered by any of the given condition names.
///
/// Matching is a heuristic on the lower-cased message text:
/// * `duplicate_*` conditions match "already exists";
/// * `unique_violation` matches "unique" or "duplicate key";
/// * `undefined_*` conditions match "does not exist" or "not found";
/// * `others` matches every message.
///
/// Condition names outside this table never match, and an empty list yields `false`.
fn pg_exception_matches(codes: &[String], error_message: &str) -> bool {
    if codes.is_empty() {
        return false;
    }
    let message = error_message.to_ascii_lowercase();
    let mentions = |needle: &str| message.contains(needle);

    codes.iter().any(|code| match code.as_str() {
        "others" => true,
        "unique_violation" => mentions("unique") || mentions("duplicate key"),
        "duplicate_object" | "duplicate_table" | "duplicate_column" | "duplicate_database"
        | "duplicate_schema" | "duplicate_function" | "duplicate_alias" => {
            mentions("already exists")
        }
        "undefined_table" | "undefined_object" | "undefined_column" | "undefined_function" => {
            mentions("does not exist") || mentions("not found")
        }
        _ => false,
    })
}
/// Looks for a procedural construct in a block body and names the first one found.
///
/// Keywords are checked in a fixed priority order, case-insensitively, and only count
/// when surrounded by spaces (the body is padded with one space on each side, so a
/// keyword at the very start or end also counts). If no keyword is present but the body
/// contains the assignment operator `:=`, that operator is reported instead.
fn pg_detect_plpgsql(body: &str) -> Option<&'static str> {
    const SPACED_KEYWORDS: &[&str] = &[
        " DECLARE ", " IF ", " LOOP ", " FOR ",
        " WHILE ", " RAISE ", " RETURN ", " PERFORM ",
        " EXIT ", " CONTINUE ",
    ];

    let haystack = format!(" {} ", body.to_ascii_uppercase());
    SPACED_KEYWORDS
        .iter()
        .copied()
        .find(|keyword| haystack.contains(*keyword))
        .map(str::trim)
        .or_else(|| body.contains(":=").then_some(":="))
}
/// Extracts the dollar-quoted body of a `DO` statement.
///
/// Accepts `DO $tag$ ... $tag$` and `DO LANGUAGE <ident> $tag$ ... $tag$` (keyword
/// matched case-insensitively, tag possibly empty) and returns the text between the
/// opening and the first matching closing delimiter, borrowed from `sql`.
///
/// Returns `None` when the statement is too short, no opening `$tag$` follows the
/// optional `LANGUAGE` clause, or the closing delimiter is missing. The first two
/// characters are skipped without checking that they spell `DO`.
///
/// All offsets are taken relative to the slice being scanned, so leading or trailing
/// whitespace around the statement does not shift the extracted body.
fn pg_extract_do_block_body(sql: &str) -> Option<&str> {
    const LANGUAGE: &str = "LANGUAGE";

    let after_do = sql.trim().get(2..)?.trim_start();
    let has_language_clause = after_do
        .as_bytes()
        .get(..LANGUAGE.len())
        .map_or(false, |lead| lead.eq_ignore_ascii_case(LANGUAGE.as_bytes()));
    let after_lang = if has_language_clause {
        // Skip the keyword and the language identifier that follows it.
        let after = after_do.get(LANGUAGE.len()..)?.trim_start();
        let ident_end = after.find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))?;
        after.get(ident_end..)?.trim_start()
    } else {
        after_do
    };

    // Opening delimiter: `$`, the tag, `$`.
    let rest = after_lang.strip_prefix('$')?;
    let (tag, after_open) = rest.split_once('$')?;
    let closer = format!("${tag}$");
    let close_rel = after_open.find(&closer)?;
    after_open.get(..close_rel)
}
/// Strips an outer `BEGIN` prefix and a trailing `END;` / `END` from a block body.
///
/// The body is trimmed first. Both markers are matched case-insensitively as plain
/// prefix/suffix text (no word-boundary check), and whitespace left next to a removed
/// marker is trimmed as well. At most one trailing marker is removed, `END;` being
/// tried before `END`.
fn pg_strip_begin_end(body: &str) -> &str {
    const OPENER: &str = "BEGIN";

    // Case-insensitive `starts_with` / `ends_with` for ASCII markers, without allocating.
    fn begins_ci(text: &str, prefix: &str) -> bool {
        text.as_bytes()
            .get(..prefix.len())
            .map_or(false, |lead| lead.eq_ignore_ascii_case(prefix.as_bytes()))
    }
    fn ends_ci(text: &str, suffix: &str) -> bool {
        text.len()
            .checked_sub(suffix.len())
            .and_then(|from| text.as_bytes().get(from..))
            .map_or(false, |tail| tail.eq_ignore_ascii_case(suffix.as_bytes()))
    }

    let mut inner = body.trim();
    if begins_ci(inner, OPENER) {
        inner = inner.get(OPENER.len()..).map_or(inner, str::trim_start);
    }

    let terminator = ["END;", "END"]
        .iter()
        .copied()
        .find(|suffix| ends_ci(inner, suffix));
    if let Some(suffix) = terminator {
        inner = inner
            .get(..inner.len() - suffix.len())
            .map_or(inner, str::trim_end);
    }
    inner
}