use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use super::catalog_views::translate_pg_catalog_query;
use super::protocol::{
read_frame, read_startup, write_frame, write_raw_byte, BackendMessage, ColumnDescriptor,
FrontendMessage, PgWireError, TransactionStatus,
};
use super::types::{value_to_pg_wire_bytes, PgOid};
use crate::runtime::RedDBRuntime;
use crate::storage::query::unified::UnifiedRecord;
use crate::storage::schema::Value;
/// Settings for the PostgreSQL wire-protocol listener.
#[derive(Clone, Debug)]
pub struct PgWireConfig {
    /// Address handed to `TcpListener::bind`, e.g. `"127.0.0.1:5432"`.
    pub bind_addr: String,
    /// Text advertised to clients in the `server_version` parameter status.
    pub server_version: String,
}

impl Default for PgWireConfig {
    /// Loopback on the standard PostgreSQL port, advertising `15.0 (RedDB 3.1)`.
    fn default() -> Self {
        let bind_addr = String::from("127.0.0.1:5432");
        let server_version = String::from("15.0 (RedDB 3.1)");
        Self {
            bind_addr,
            server_version,
        }
    }
}
/// Binds the pg-wire listener on `config.bind_addr` and serves every accepted
/// connection on its own spawned Tokio task.
///
/// Per-connection session failures are logged and never stop the accept loop.
///
/// # Errors
///
/// Returns the bind error, or any accept error other than connection
/// aborted/reset/refused. Those three only concern the single pending client,
/// so they are logged and skipped instead of shutting the listener down.
pub async fn start_pg_wire_listener(
    config: PgWireConfig,
    runtime: Arc<RedDBRuntime>,
) -> Result<(), Box<dyn std::error::Error>> {
    let listener = TcpListener::bind(&config.bind_addr).await?;
    tracing::info!(
        transport = "pg-wire",
        bind = %config.bind_addr,
        "listener online"
    );
    let cfg = Arc::new(config);
    loop {
        let (stream, peer) = match listener.accept().await {
            Ok(accepted) => accepted,
            // These kinds describe one client that went away before it could be
            // accepted; they say nothing about the listening socket itself, so
            // they must not tear down the whole pg-wire server.
            Err(err)
                if matches!(
                    err.kind(),
                    std::io::ErrorKind::ConnectionAborted
                        | std::io::ErrorKind::ConnectionReset
                        | std::io::ErrorKind::ConnectionRefused
                ) =>
            {
                tracing::warn!(
                    transport = "pg-wire",
                    err = %err,
                    "accept failed; skipping connection"
                );
                continue;
            }
            Err(err) => return Err(err.into()),
        };
        let rt = Arc::clone(&runtime);
        let cfg = Arc::clone(&cfg);
        let peer_str = peer.to_string();
        tokio::spawn(async move {
            if let Err(e) = handle_connection(stream, rt, cfg).await {
                tracing::warn!(
                    transport = "pg-wire",
                    peer = %peer_str,
                    err = %e,
                    "connection failed"
                );
            }
        });
    }
}
/// Drives one pg-wire session over `stream`: startup negotiation first, then
/// the frontend message loop until the client terminates or disconnects.
///
/// SSL and GSS encryption requests are declined with a single `'N'` byte. The
/// session is then accepted via `send_auth_ok`. Only simple `Query` messages are
/// executed; any other frame is answered with SQLSTATE `0A000`.
///
/// # Errors
///
/// Propagates errors from reading or writing frames. Returns
/// `PgWireError::Protocol` if the startup phase receives an unexpected frame.
/// End-of-stream after startup is treated as a normal close (`Ok(())`).
pub(crate) async fn handle_connection<S>(
    mut stream: S,
    runtime: Arc<RedDBRuntime>,
    config: Arc<PgWireConfig>,
) -> Result<(), PgWireError>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    // Startup phase: loop because a client may ask for encryption before it
    // sends the real startup packet.
    loop {
        match read_startup(&mut stream).await? {
            FrontendMessage::SslRequest | FrontendMessage::GssEncRequest => {
                // Encryption is not offered; 'N' tells the client to continue in plaintext.
                write_raw_byte(&mut stream, b'N').await?;
            }
            FrontendMessage::Startup(params) => {
                send_auth_ok(&mut stream, &config, &params).await?;
                break;
            }
            FrontendMessage::Unknown { .. } => return Ok(()),
            other => {
                return Err(PgWireError::Protocol(format!(
                    "unexpected startup frame: {other:?}"
                )));
            }
        }
    }

    // Ready for queries: serve frames until the client says goodbye or hangs up.
    loop {
        let frame = match read_frame(&mut stream).await {
            Ok(frame) => frame,
            Err(PgWireError::Eof) => return Ok(()),
            Err(e) => return Err(e),
        };
        match frame {
            FrontendMessage::Query(sql) => {
                handle_simple_query(&mut stream, &runtime, &sql).await?;
            }
            FrontendMessage::Terminate => return Ok(()),
            FrontendMessage::Sync => {
                write_frame(
                    &mut stream,
                    &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
                )
                .await?;
            }
            // Flush only asks the server to deliver pending output. The protocol
            // defines no reply to it, and a ReadyForQuery here would desync
            // clients that expect exactly one ReadyForQuery per Sync.
            FrontendMessage::Flush => {}
            // AuthenticationOk has already been sent, so a stray password
            // message is ignored.
            FrontendMessage::PasswordMessage(_) => {}
            // NOTE(review): extended-query messages (Parse/Bind/Execute…) presumably
            // land in this arm or `other`. Real PostgreSQL discards messages after an
            // error until Sync and then sends one ReadyForQuery. Consider matching
            // that here.
            FrontendMessage::Unknown { tag, .. } => {
                send_error(
                    &mut stream,
                    "0A000",
                    &format!("unsupported frame tag 0x{tag:02x}"),
                )
                .await?;
                write_frame(
                    &mut stream,
                    &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
                )
                .await?;
            }
            other => {
                send_error(
                    &mut stream,
                    "0A000",
                    &format!("unsupported frame {other:?}"),
                )
                .await?;
                write_frame(
                    &mut stream,
                    &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
                )
                .await?;
            }
        }
    }
}
/// Completes startup without any authentication challenge and reports the
/// session parameters to the client.
///
/// Writes, in order:
/// - `AuthenticationOk`;
/// - one `ParameterStatus` per advertised setting, echoing the client's
///   `application_name` (or `""` if none was sent);
/// - `BackendKeyData`;
/// - `ReadyForQuery(Idle)`.
///
/// # Errors
///
/// Propagates the first error returned by `write_frame`.
async fn send_auth_ok<S>(
    stream: &mut S,
    config: &PgWireConfig,
    params: &super::protocol::StartupParams,
) -> Result<(), PgWireError>
where
    S: AsyncWrite + Unpin,
{
    write_frame(stream, &BackendMessage::AuthenticationOk).await?;

    let application_name = params.get("application_name").unwrap_or("");
    let settings: [(&str, &str); 8] = [
        ("server_version", config.server_version.as_str()),
        ("server_encoding", "UTF8"),
        ("client_encoding", "UTF8"),
        ("DateStyle", "ISO, MDY"),
        ("TimeZone", "UTC"),
        ("integer_datetimes", "on"),
        ("standard_conforming_strings", "on"),
        ("application_name", application_name),
    ];
    for &(name, value) in settings.iter() {
        let status = BackendMessage::ParameterStatus {
            name: name.to_owned(),
            value: value.to_owned(),
        };
        write_frame(stream, &status).await?;
    }

    // Every connection gets the same process id and this constant key.
    // NOTE(review): acceptable only while cancel requests are not honoured — confirm.
    let key_data = BackendMessage::BackendKeyData {
        pid: std::process::id(),
        key: 0xDEADBEEF,
    };
    write_frame(stream, &key_data).await?;

    write_frame(
        stream,
        &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
    )
    .await
}
/// Executes one simple-protocol `Query` message and writes the complete reply.
///
/// Blank SQL yields `EmptyQueryResponse`. Otherwise the SQL is first offered to
/// the `pg_catalog` translator; if that declines (`Ok(None)`), the runtime
/// executes it. On success, `select` results stream their rows, and every
/// statement ends with a `CommandComplete` tag. On failure, an `ErrorResponse`
/// is sent with a SQLSTATE guessed from the error text. Every path finishes with
/// `ReadyForQuery(Idle)`.
///
/// # Errors
///
/// Only errors from writing to `stream` are returned. Query failures are
/// reported to the client instead.
async fn handle_simple_query<S>(
    stream: &mut S,
    runtime: &RedDBRuntime,
    sql: &str,
) -> Result<(), PgWireError>
where
    S: AsyncWrite + Unpin,
{
    let ready = BackendMessage::ReadyForQuery(TransactionStatus::Idle);

    if sql.trim().is_empty() {
        write_frame(stream, &BackendMessage::EmptyQueryResponse).await?;
        return write_frame(stream, &ready).await;
    }

    let outcome = translate_pg_catalog_query(runtime, sql).and_then(|catalog_hit| match catalog_hit {
        // The catalog shim answered by itself; dress it up as a runtime result.
        Some(result) => Ok(crate::runtime::RuntimeQueryResult {
            query: sql.to_string(),
            mode: crate::storage::query::modes::QueryMode::Sql,
            statement: "select",
            engine: "pg-catalog",
            result,
            affected_rows: 0,
            statement_type: "select",
        }),
        None => runtime.execute_query(sql),
    });

    match outcome {
        Ok(done) => {
            let tag = match done.statement_type {
                "select" => {
                    emit_result_rows(stream, &done.result).await?;
                    format!("SELECT {}", done.result.records.len())
                }
                "insert" => format!("INSERT 0 {}", done.affected_rows),
                "update" => format!("UPDATE {}", done.affected_rows),
                "delete" => format!("DELETE {}", done.affected_rows),
                other => other.to_uppercase(),
            };
            write_frame(stream, &BackendMessage::CommandComplete(tag)).await?;
        }
        Err(err) => {
            let text = err.to_string();
            send_error(stream, classify_sqlstate(&text), &text).await?;
        }
    }

    write_frame(stream, &ready).await
}
/// Streams a result set as a `RowDescription` followed by one `DataRow` per record.
///
/// Column names come from `result.columns`, or, when that list is empty, from
/// the first record's field names. Each column's type OID is derived from the
/// first record's value for that column. It falls back to `PgOid::Text` when
/// there is no such value. Every column is declared in text format
/// (`format: 0`). Fields that are missing, or that `value_to_pg_wire_bytes`
/// maps to `None`, are sent as `None` (presumably SQL NULL).
///
/// # Errors
///
/// Propagates the first error returned by `write_frame`.
async fn emit_result_rows<S>(
    stream: &mut S,
    result: &crate::storage::query::unified::UnifiedResult,
) -> Result<(), PgWireError>
where
    S: AsyncWrite + Unpin,
{
    let first_record = result.records.first();

    let columns: Vec<String> = match (result.columns.is_empty(), first_record) {
        (false, _) => result.columns.clone(),
        (true, Some(record)) => record_field_names(record),
        (true, None) => Vec::new(),
    };

    let mut descriptors: Vec<ColumnDescriptor> = Vec::with_capacity(columns.len());
    for column in &columns {
        // The type is sampled from the first record only.
        let oid = match first_record.and_then(|record| record_get(record, column)) {
            Some(value) => PgOid::from_value(value),
            None => PgOid::Text,
        };
        descriptors.push(ColumnDescriptor {
            name: column.clone(),
            table_oid: 0,
            column_attr: 0,
            type_oid: oid.as_u32(),
            type_size: -1,
            type_mod: -1,
            format: 0,
        });
    }
    write_frame(stream, &BackendMessage::RowDescription(descriptors)).await?;

    for record in result.records.iter() {
        let mut fields = Vec::with_capacity(columns.len());
        for column in &columns {
            fields.push(record_get(record, column).and_then(value_to_pg_wire_bytes));
        }
        write_frame(stream, &BackendMessage::DataRow(fields)).await?;
    }

    Ok(())
}
/// Returns the value stored under `key` in `record`, delegating to
/// `UnifiedRecord::get`. Yields `None` when that lookup finds nothing.
fn record_get<'a>(record: &'a UnifiedRecord, key: &str) -> Option<&'a Value> {
    record.get(key)
}
/// Collects the record's column names (as reported by `column_names()`) into
/// owned strings, preserving the order they are yielded in.
fn record_field_names(record: &UnifiedRecord) -> Vec<String> {
    let mut names = Vec::new();
    for name in record.column_names() {
        names.push(name.to_string());
    }
    names
}
/// Writes an `ErrorResponse` with severity `ERROR`, the given SQLSTATE `code`,
/// and the human-readable `message`.
///
/// # Errors
///
/// Returns whatever `write_frame` returns.
async fn send_error<S>(stream: &mut S, code: &str, message: &str) -> Result<(), PgWireError>
where
    S: AsyncWrite + Unpin,
{
    let response = BackendMessage::ErrorResponse {
        severity: String::from("ERROR"),
        code: code.to_owned(),
        message: message.to_owned(),
    };
    write_frame(stream, &response).await
}
/// Guesses a PostgreSQL SQLSTATE code from a free-form error message.
///
/// The message is ASCII-lower-cased and checked against the rules below in
/// order. The first rule with a matching substring wins. Messages that match
/// nothing map to `XX000` (internal_error). This is a text heuristic only.
fn classify_sqlstate(msg: &str) -> &'static str {
    const RULES: &[(&[&str], &str)] = &[
        // undefined_table
        (&["not found", "does not exist"], "42P01"),
        // syntax_error
        (&["parse", "expected", "syntax"], "42601"),
        // duplicate_table
        (&["already exists"], "42P07"),
        // insufficient_privilege: PostgreSQL reports "permission denied" as 42501,
        // not as an authorization-specification failure.
        (&["permission"], "42501"),
        // invalid_authorization_specification
        (&["auth"], "28000"),
    ];

    let lower = msg.to_ascii_lowercase();
    RULES
        .iter()
        .find(|(needles, _)| needles.iter().any(|needle| lower.contains(*needle)))
        .map(|&(_, code)| code)
        .unwrap_or("XX000")
}