#![allow(
clippy::cast_lossless,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_sign_loss,
clippy::doc_markdown,
clippy::similar_names,
clippy::too_many_lines,
clippy::uninlined_format_args,
clippy::unreadable_literal,
clippy::assigning_clones
)]
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::sync::Arc;
use std::thread;
use spg_engine::{EngineError, QueryResult, Role};
use spg_storage::{ColumnSchema, DataType, Row, Value};
use crate::ServerState;
/// Protocol version word for PostgreSQL frontend/backend protocol 3.0
/// (major 3 in the high 16 bits, minor 0 in the low 16 bits: `3 << 16 == 196608`).
const PROTOCOL_V3: u32 = 196608;
/// Binds a pg-wire listener on `addr` and serves it from a background thread.
///
/// Every accepted socket gets its own thread running `handle_conn`; per-connection failures are
/// logged to stderr and never stop the accept loop. Accept errors are silently skipped.
///
/// Returns the actually bound address (useful when `addr` asks for port 0).
///
/// # Errors
/// Propagates failures from `TcpListener::bind` and `local_addr`.
pub fn spawn_listener(
    addr: &str,
    state: Arc<ServerState>,
) -> std::io::Result<std::net::SocketAddr> {
    let listener = TcpListener::bind(addr)?;
    let bound_addr = listener.local_addr()?;
    thread::spawn(move || {
        // `Incoming` never yields `None`; `flatten` drops the `Err` items from failed accepts.
        for socket in listener.incoming().flatten() {
            let conn_state = Arc::clone(&state);
            thread::spawn(move || {
                if let Err(err) = handle_conn(socket, &conn_state) {
                    eprintln!("spg-server: pg-wire conn error: {err}");
                }
            });
        }
    });
    Ok(bound_addr)
}
fn handle_conn(mut stream: TcpStream, state: &Arc<ServerState>) -> std::io::Result<()> {
let _ = stream.set_nodelay(true);
let (user, params) = read_startup(&mut stream)?;
let _ = params;
let conn_state = Arc::new(crate::ConnState {
pid: std::process::id().wrapping_add(state.active_connections.load(
std::sync::atomic::Ordering::Relaxed,
) as u32),
user: user.clone(),
started_at_us: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_micros() as i64)
.unwrap_or(0),
current_sql: std::sync::RwLock::new(String::new()),
wait_event: std::sync::atomic::AtomicU8::new(0),
last_query_start_us: std::sync::atomic::AtomicI64::new(0),
in_transaction: std::sync::atomic::AtomicBool::new(false),
});
if let Ok(mut conns) = state.connections.write() {
conns.push(Arc::clone(&conn_state));
}
struct ConnGuard {
state: Arc<ServerState>,
conn: Arc<crate::ConnState>,
}
impl Drop for ConnGuard {
fn drop(&mut self) {
if let Ok(mut conns) = self.state.connections.write() {
conns.retain(|x| !Arc::ptr_eq(x, &self.conn));
}
}
}
let _conn_guard = ConnGuard {
state: Arc::clone(state),
conn: Arc::clone(&conn_state),
};
let has_users = state.engine.read().is_ok_and(|e| !e.users().is_empty());
let role = if has_users {
let user_has_scram = state
.engine
.read()
.ok()
.and_then(|e| {
e.users()
.iter()
.find_map(|(n, r)| (n == user).then(|| r.scram().is_some()))
})
.unwrap_or(false);
let outcome = if user_has_scram {
scram_auth(&mut stream, state, &user)?
} else {
cleartext_auth(&mut stream, state, &user)?
};
match outcome {
Some(r) => r,
None => return Ok(()), }
} else {
Role::Admin
};
send_msg(&mut stream, b'R', &0u32.to_be_bytes())?;
send_parameter_status(&mut stream, "server_version", "16.0 (spg-4.3)")?;
send_parameter_status(&mut stream, "client_encoding", "UTF8")?;
send_parameter_status(&mut stream, "DateStyle", "ISO, MDY")?;
send_parameter_status(&mut stream, "integer_datetimes", "on")?;
send_parameter_status(&mut stream, "standard_conforming_strings", "on")?;
let mut bkd = Vec::with_capacity(8);
bkd.extend_from_slice(&std::process::id().to_be_bytes());
bkd.extend_from_slice(&0u32.to_be_bytes());
send_msg(&mut stream, b'K', &bkd)?;
send_ready_for_query(&mut stream, b'I')?;
let mut tx_state = b'I'; let mut prepared: std::collections::HashMap<String, PreparedStmt> =
std::collections::HashMap::default();
let mut portals: std::collections::HashMap<String, Portal> =
std::collections::HashMap::default();
let mut settings: std::collections::HashMap<String, String> =
std::collections::HashMap::default();
const PIPELINE_FLUSH_BYTES: usize = 4096;
let mut wbuf: Vec<u8> = Vec::with_capacity(8192);
loop {
let mut header = [0u8; 5];
if let Err(e) = stream.read_exact(&mut header) {
if e.kind() == std::io::ErrorKind::UnexpectedEof {
return Ok(());
}
return Err(e);
}
let msg_type = header[0];
let len = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
let body_len = len.saturating_sub(4);
let mut body = vec![0u8; body_len];
if body_len > 0 {
stream.read_exact(&mut body)?;
}
match msg_type {
b'Q' => {
let sql_bytes = body.strip_suffix(b"\0").unwrap_or(&body);
let now_us = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_micros() as i64)
.unwrap_or(0);
conn_state
.last_query_start_us
.store(now_us, std::sync::atomic::Ordering::Relaxed);
if let Ok(mut s) = conn_state.current_sql.write() {
*s = String::from_utf8_lossy(sql_bytes).to_string();
}
let Ok(sql_str) = std::str::from_utf8(sql_bytes) else {
send_error(&mut wbuf, "22021", "invalid UTF-8 in query")?;
send_ready_for_query(&mut wbuf, tx_state)?;
stream.write_all(&wbuf)?;
wbuf.clear();
continue;
};
let sql = sql_str.trim_end_matches(';').trim().to_string();
if let Some((name, value)) = parse_set_statement(&sql) {
settings.insert(name.to_ascii_lowercase(), value);
send_command_complete(&mut wbuf, "SET")?;
send_ready_for_query(&mut wbuf, tx_state)?;
stream.write_all(&wbuf)?;
wbuf.clear();
continue;
}
if let Some(name) = parse_show_statement(&sql) {
let resp = render_show(&name, &settings);
send_canned(&mut wbuf, &resp)?;
send_ready_for_query(&mut wbuf, tx_state)?;
stream.write_all(&wbuf)?;
wbuf.clear();
continue;
}
if let Some(copy) = parse_copy_intent(&sql) {
if !wbuf.is_empty() {
stream.write_all(&wbuf)?;
wbuf.clear();
}
match copy {
CopyIntent::From(table, opts) => {
handle_copy_from_stdin(
&mut stream,
state,
role,
&table,
&opts,
&mut tx_state,
)?;
}
CopyIntent::To(table) => {
handle_copy_to_stdout(&mut stream, state, role, &table, &mut tx_state)?;
}
}
send_ready_for_query(&mut wbuf, tx_state)?;
stream.write_all(&wbuf)?;
wbuf.clear();
continue;
}
if let Some(canned) = canned_response(&sql, state) {
send_canned(&mut wbuf, &canned)?;
send_ready_for_query(&mut wbuf, tx_state)?;
stream.write_all(&wbuf)?;
wbuf.clear();
continue;
}
conn_state
.wait_event
.store(1, std::sync::atomic::Ordering::Relaxed);
let result = execute_with_role(state, &sql, role);
conn_state
.wait_event
.store(0, std::sync::atomic::Ordering::Relaxed);
match result {
Ok(QueryResult::Rows { columns, rows }) => {
send_row_description(&mut wbuf, &columns)?;
let n = rows.len();
for row in &rows {
send_data_row(&mut wbuf, &columns, row)?;
}
send_command_complete(&mut wbuf, &format!("SELECT {n}"))?;
}
Ok(QueryResult::CommandOk { affected, modified_catalog }) => {
let tag = command_tag(&sql, affected);
send_command_complete(&mut wbuf, &tag)?;
if modified_catalog && state.audit_path.is_some() {
let _ = crate::append_audit_pub(state, &sql);
}
tx_state = if state.engine.read().is_ok_and(|e| e.in_transaction()) {
b'T'
} else {
b'I'
};
}
Err(e) => {
send_error(&mut wbuf, "42000", &e.to_string())?;
tx_state = if state.engine.read().is_ok_and(|e| e.in_transaction()) {
b'E'
} else {
b'I'
};
}
Ok(_) => {
send_error(&mut wbuf, "XX000", "unexpected QueryResult variant")?;
}
}
send_ready_for_query(&mut wbuf, tx_state)?;
stream.write_all(&wbuf)?;
wbuf.clear();
}
b'X' => {
if !wbuf.is_empty() {
let _ = stream.write_all(&wbuf);
}
return Ok(());
}
b'P' => {
if let Err(msg) = handle_parse(&body, &mut prepared, state) {
send_error(&mut wbuf, "42601", &msg)?;
} else {
send_msg(&mut wbuf, b'1', &[])?;
}
}
b'B' => {
match handle_bind(&body, &prepared) {
Ok(portal) => {
portals.insert(portal.0.clone(), portal.1);
send_msg(&mut wbuf, b'2', &[])?; }
Err(msg) => send_error(&mut wbuf, "42601", &msg)?,
}
}
b'D' => {
if !body.is_empty() {
let kind = body[0];
let name = cstring_at(&body, 1).unwrap_or_default();
let (param_oids, columns): (Vec<u32>, Vec<ColumnSchema>) = if kind == b'S' {
if let Some(stmt) = prepared.get(&name) {
let eng = state.engine.read().map_err(|_| {
std::io::Error::other("engine lock poisoned")
})?;
eng.describe_prepared(&stmt.ast)
} else {
(Vec::new(), Vec::new())
}
} else if kind == b'P' {
let cols = if let Some(portal) = portals.get(&name) {
if let Some(stmt) = prepared.get(&portal.stmt_name) {
let eng = state.engine.read().map_err(|_| {
std::io::Error::other("engine lock poisoned")
})?;
let (_, c) = eng.describe_prepared(&stmt.ast);
c
} else {
Vec::new()
}
} else {
Vec::new()
};
(Vec::new(), cols)
} else {
(Vec::new(), Vec::new())
};
if kind == b'S' {
let n = u16::try_from(param_oids.len()).map_err(|_| {
std::io::Error::other("too many parameters")
})?;
let mut pd = Vec::with_capacity(2 + param_oids.len() * 4);
pd.extend_from_slice(&n.to_be_bytes());
for oid in ¶m_oids {
pd.extend_from_slice(&oid.to_be_bytes());
}
send_msg(&mut wbuf, b't', &pd)?;
}
if columns.is_empty() {
send_msg(&mut wbuf, b'n', &[])?; } else {
send_row_description(&mut wbuf, &columns)?;
}
}
}
b'E' => {
if let Err(msg) = handle_execute(
&body,
&portals,
&prepared,
&mut wbuf,
state,
role,
&mut tx_state,
) {
send_error(&mut wbuf, "42000", &msg)?;
}
}
b'C' => {
if body.len() >= 2 {
let kind = body[0];
let name = cstring_at(&body, 1).unwrap_or_default();
if kind == b'S' {
prepared.remove(&name);
} else if kind == b'P' {
portals.remove(&name);
}
}
send_msg(&mut wbuf, b'3', &[])?; }
b'H' => {
if !wbuf.is_empty() {
stream.write_all(&wbuf)?;
wbuf.clear();
}
}
b'S' => {
send_ready_for_query(&mut wbuf, tx_state)?;
stream.write_all(&wbuf)?;
wbuf.clear();
}
b'd' | b'c' | b'f' => {
send_error(
&mut wbuf,
"08P01",
"unexpected CopyData/Done/Fail outside COPY mode",
)?;
send_ready_for_query(&mut wbuf, tx_state)?;
stream.write_all(&wbuf)?;
wbuf.clear();
}
_ => {
send_error(
&mut wbuf,
"08P01",
&format!("unknown frontend message type: 0x{msg_type:02x}"),
)?;
send_ready_for_query(&mut wbuf, tx_state)?;
stream.write_all(&wbuf)?;
wbuf.clear();
}
}
if wbuf.len() >= PIPELINE_FLUSH_BYTES {
stream.write_all(&wbuf)?;
wbuf.clear();
}
}
}
/// Runs a simple-protocol statement on the shared engine after enforcing `role`'s privileges.
///
/// The leading keyword decides the path:
/// - `SELECT` / `SHOW` run under the engine read lock via `execute_readonly`.
/// - Everything else needs `role.can_write()` and runs under the write lock via `execute`.
/// - `CREATE USER` / `DROP USER` additionally need `role.can_manage_users()`.
///
/// Keywords are extracted skipping whitespace and SQL comments, and punctuation splits words.
/// Without this, `/* x */ CREATE USER …` or `CREATE/**/USER …` would slip past the
/// user-management check.
///
/// # Errors
/// Returns `EngineError::Unsupported` for permission failures or a poisoned engine lock.
/// Otherwise returns whatever the engine reports.
fn execute_with_role(
    state: &Arc<ServerState>,
    sql: &str,
    role: Role,
) -> Result<QueryResult, EngineError> {
    /// Up to `count` leading words of `sql`, ASCII-lowercased. Whitespace, `--` line comments and
    /// (nested) `/* */` block comments are skipped. A word is a run of alphanumerics and `_`; any
    /// other character forms a one-character word.
    fn leading_words(sql: &str, count: usize) -> Vec<String> {
        let mut words = Vec::with_capacity(count);
        let mut rest = sql;
        while words.len() < count {
            rest = rest.trim_start();
            if rest.is_empty() {
                break;
            }
            if let Some(after) = rest.strip_prefix("--") {
                rest = after.find('\n').map_or("", |nl| &after[nl + 1..]);
            } else if rest.starts_with("/*") {
                rest = skip_block_comment(rest);
            } else {
                let word_len = rest
                    .find(|c: char| !(c.is_alphanumeric() || c == '_'))
                    .unwrap_or(rest.len());
                // Punctuation becomes its own word so the loop always makes progress.
                let word_len = if word_len == 0 {
                    rest.chars().next().map_or(0, char::len_utf8)
                } else {
                    word_len
                };
                words.push(rest[..word_len].to_ascii_lowercase());
                rest = &rest[word_len..];
            }
        }
        words
    }

    /// Returns the text after the block comment that `s` starts with (`s` begins with `/*`),
    /// honouring nesting; an unterminated comment consumes everything.
    fn skip_block_comment(s: &str) -> &str {
        let bytes = s.as_bytes();
        let mut depth = 0usize;
        let mut i = 0;
        while i + 1 < bytes.len() {
            match (bytes[i], bytes[i + 1]) {
                (b'/', b'*') => {
                    depth += 1;
                    i += 2;
                }
                (b'*', b'/') => {
                    depth = depth.saturating_sub(1);
                    i += 2;
                    if depth == 0 {
                        return &s[i..];
                    }
                }
                _ => i += 1,
            }
        }
        ""
    }

    crate::try_lazy_preload_cold(state);
    let words = leading_words(sql, 2);
    let first = words.first().map_or("", String::as_str);
    let second = words.get(1).map_or("", String::as_str);

    let is_read = matches!(first, "select" | "show");
    if !is_read && !role.can_write() {
        return Err(EngineError::Unsupported(
            "permission denied: write requires admin or readwrite role".into(),
        ));
    }
    let is_user_mgmt = matches!(first, "create" | "drop") && second == "user";
    if is_user_mgmt && !role.can_manage_users() {
        return Err(EngineError::Unsupported(
            "permission denied: user management requires admin role".into(),
        ));
    }
    if is_read {
        let engine = state
            .engine
            .read()
            .map_err(|_| EngineError::Unsupported("engine rwlock poisoned".into()))?;
        engine.execute_readonly(sql)
    } else {
        let mut engine = state
            .engine
            .write()
            .map_err(|_| EngineError::Unsupported("engine rwlock poisoned".into()))?;
        engine.execute(sql)
    }
}
/// Builds the CommandComplete tag for a simple-protocol statement that returned `CommandOk`.
///
/// The tag follows PostgreSQL conventions and matches `command_tag_for_ast`:
/// - `INSERT` / `UPDATE` / `DELETE` carry the affected-row count. `INSERT 0 n` keeps the
///   legacy OID field at 0.
/// - `CREATE` / `DROP` / `ALTER` name the object kind (`CREATE TABLE`, `DROP INDEX`, …).
///   Modifiers such as `UNIQUE`, `TEMP` or `OR REPLACE` are skipped.
/// - Transaction aliases map to canonical tags: `END` → `COMMIT`, `ABORT` → `ROLLBACK`,
///   `START …` → `START TRANSACTION`.
/// - `TRUNCATE` reports `TRUNCATE TABLE`.
/// - Anything else echoes the uppercased first keyword, or `""` for empty input.
fn command_tag(sql: &str, affected: usize) -> String {
    /// Words between CREATE/DROP/ALTER and the object kind that are not part of the tag.
    const MODIFIERS: &[&str] = &[
        "OR",
        "REPLACE",
        "UNIQUE",
        "TEMP",
        "TEMPORARY",
        "UNLOGGED",
        "GLOBAL",
        "LOCAL",
    ];
    let upper = sql.trim_start().to_ascii_uppercase();
    let mut words = upper.split_ascii_whitespace();
    let first = words.next().unwrap_or("");
    match first {
        "INSERT" => format!("INSERT 0 {affected}"),
        "UPDATE" => format!("UPDATE {affected}"),
        "DELETE" => format!("DELETE {affected}"),
        "END" => "COMMIT".to_string(),
        "ABORT" => "ROLLBACK".to_string(),
        "START" => "START TRANSACTION".to_string(),
        "TRUNCATE" => "TRUNCATE TABLE".to_string(),
        "CREATE" | "DROP" | "ALTER" => match words.find(|w| !MODIFIERS.contains(w)) {
            Some(kind) => format!("{first} {kind}"),
            None => first.to_string(),
        },
        other => other.to_string(),
    }
}
/// Answers driver/ORM probe statements without touching the SQL engine.
///
/// Matching is on the trimmed, ASCII-lowercased statement, in this order:
/// 1. Session/introspection probes (`version()`, `current_database()`, a few `SHOW`s, …) get a
///    one-row, one-column text answer. `current_user` is always reported as `admin`.
/// 2. Maintenance and session-reset commands (`DISCARD`, `RESET`, `VACUUM`, `ANALYZE`, `CLUSTER`,
///    `REINDEX`, `SET TRANSACTION …`) are acknowledged with a bare command tag.
/// 3. Queries reading `pg_class`, `pg_namespace`, `pg_database`, `pg_user`/`pg_roles` or
///    `pg_tables` get a synthesised catalog result.
///
/// Returns `None` when the statement should go to the engine.
fn canned_response(sql: &str, state: &Arc<ServerState>) -> Option<CannedResponse> {
    /// Prefixes acknowledged with just a CommandComplete tag; the first match wins.
    const TAGGED_PREFIXES: &[(&str, &str)] = &[
        ("discard all", "DISCARD ALL"),
        ("discard temp", "DISCARD"),
        ("discard sequences", "DISCARD"),
        ("discard plans", "DISCARD"),
        ("reset ", "RESET"),
        ("vacuum", "VACUUM"),
        ("analyze", "ANALYZE"),
        ("cluster", "CLUSTER"),
        ("reindex", "REINDEX"),
        ("set transaction isolation level", "SET"),
        ("set transaction read", "SET"),
        ("set transaction snapshot", "SET"),
    ];

    let lower = sql.trim().to_ascii_lowercase();
    let q = lower.as_str();

    let single = if q.starts_with("select version()") {
        Some(("version", "spg 4.6"))
    } else if q.starts_with("show transaction_isolation")
        || q.starts_with("show transaction isolation level")
    {
        Some(("transaction_isolation", "read committed"))
    } else if q.starts_with("show search_path") {
        Some(("search_path", "\"$user\", public"))
    } else if q.starts_with("show standard_conforming_strings") {
        Some(("standard_conforming_strings", "on"))
    } else if q.starts_with("select current_database()") {
        Some(("current_database", "spg"))
    } else if q.starts_with("select current_schema()") || q == "select current_schema" {
        Some(("current_schema", "public"))
    } else if q == "select current_user" || q == "select user" {
        Some(("current_user", "admin"))
    } else {
        None
    };
    if let Some((column, value)) = single {
        return Some(CannedResponse::single_text(column, value));
    }

    if let Some(&(_, tag)) = TAGGED_PREFIXES
        .iter()
        .find(|&&(prefix, _)| q.starts_with(prefix))
    {
        return Some(CannedResponse::Tag(tag));
    }

    if mentions_pg_table(q, "pg_class") {
        Some(pg_class_response(state))
    } else if mentions_pg_table(q, "pg_namespace") {
        Some(pg_namespace_response())
    } else if mentions_pg_table(q, "pg_database") {
        Some(pg_database_response())
    } else if mentions_pg_table(q, "pg_user") || mentions_pg_table(q, "pg_roles") {
        Some(pg_user_response(state))
    } else if mentions_pg_table(q, "pg_tables") {
        Some(pg_tables_response(state))
    } else {
        None
    }
}
/// Returns true when the lowercased SQL reads `table` in a FROM/JOIN position.
///
/// The keyword `from` or `join` must precede the table. Any amount of whitespace may separate
/// them, and an optional `pg_catalog.` qualifier may come before the table. The table name must
/// end at a non-identifier character, so `pg_user` does not match `pg_user_mappings` and
/// `pg_class` does not match `pg_classes`.
fn mentions_pg_table(sql_lower: &str, table: &str) -> bool {
    fn is_ident(c: char) -> bool {
        c.is_ascii_alphanumeric() || c == '_' || c == '$'
    }
    if table.is_empty() {
        return false;
    }
    sql_lower.match_indices(table).any(|(start, _)| {
        // The match must not be the prefix of a longer identifier.
        let after = &sql_lower[start + table.len()..];
        if after.chars().next().is_some_and(is_ident) {
            return false;
        }
        let before = &sql_lower[..start];
        let before = before.strip_suffix("pg_catalog.").unwrap_or(before);
        let keyword_end = before.trim_end();
        // Require whitespace between the keyword and the (possibly qualified) table name.
        if keyword_end.len() == before.len() {
            return false;
        }
        ["from", "join"].iter().any(|&kw| {
            keyword_end
                .strip_suffix(kw)
                .is_some_and(|prefix| !prefix.chars().next_back().is_some_and(is_ident))
        })
    })
}
/// A response produced without running the SQL engine.
enum CannedResponse {
    /// A result set: RowDescription, one DataRow per row, then `SELECT n`.
    Rows {
        columns: Vec<ColumnSchema>,
        rows: Vec<Row>,
    },
    /// Just a CommandComplete with this tag.
    Tag(&'static str),
}

impl CannedResponse {
    /// One non-nullable text column named `col` holding a single row with value `val`.
    fn single_text(col: &'static str, val: &'static str) -> Self {
        let column = ColumnSchema::new(col, DataType::Text, false);
        let row = Row::new(vec![Value::Text(String::from(val))]);
        Self::Rows {
            columns: vec![column],
            rows: vec![row],
        }
    }
}
/// Writes a canned response to `stream` (no ReadyForQuery; the caller adds that).
///
/// A `Rows` response becomes RowDescription + DataRows + `SELECT <row count>`.
/// A `Tag` response becomes just the CommandComplete.
fn send_canned(stream: &mut dyn Write, c: &CannedResponse) -> std::io::Result<()> {
    let CannedResponse::Rows { columns, rows } = c else {
        if let CannedResponse::Tag(tag) = c {
            send_command_complete(stream, tag)?;
        }
        return Ok(());
    };
    send_row_description(stream, columns)?;
    for row in rows {
        send_data_row(stream, columns, row)?;
    }
    let tag = format!("SELECT {}", rows.len());
    send_command_complete(stream, &tag)?;
    Ok(())
}
fn pg_class_response(state: &Arc<ServerState>) -> CannedResponse {
let columns = vec![
ColumnSchema::new("oid", DataType::BigInt, false),
ColumnSchema::new("relname", DataType::Text, false),
ColumnSchema::new("relkind", DataType::Text, false),
ColumnSchema::new("relnamespace", DataType::BigInt, false),
ColumnSchema::new("relowner", DataType::BigInt, false),
];
let rows = state
.engine
.read()
.map(|e| {
e.catalog()
.table_names()
.into_iter()
.enumerate()
.map(|(i, name)| {
Row::new(vec![
Value::BigInt(16384 + i as i64), Value::Text(name),
Value::Text("r".to_string()),
Value::BigInt(2200), Value::BigInt(10), ])
})
.collect()
})
.unwrap_or_default();
CannedResponse::Rows { columns, rows }
}
/// Synthesises `pg_namespace` with the single `public` schema (OID 2200, owner OID 10).
fn pg_namespace_response() -> CannedResponse {
    let public_row = Row::new(vec![
        Value::BigInt(2200),
        Value::Text("public".to_string()),
        Value::BigInt(10),
    ]);
    CannedResponse::Rows {
        columns: vec![
            ColumnSchema::new("oid", DataType::BigInt, false),
            ColumnSchema::new("nspname", DataType::Text, false),
            ColumnSchema::new("nspowner", DataType::BigInt, false),
        ],
        rows: vec![public_row],
    }
}
/// Synthesises `pg_database` with the single `spg` database (OID 16384, owner OID 10).
fn pg_database_response() -> CannedResponse {
    let database_row = Row::new(vec![
        Value::BigInt(16384),
        Value::Text("spg".to_string()),
        Value::BigInt(10),
    ]);
    CannedResponse::Rows {
        columns: vec![
            ColumnSchema::new("oid", DataType::BigInt, false),
            ColumnSchema::new("datname", DataType::Text, false),
            ColumnSchema::new("datdba", DataType::BigInt, false),
        ],
        rows: vec![database_row],
    }
}
fn pg_user_response(state: &Arc<ServerState>) -> CannedResponse {
let columns = vec![
ColumnSchema::new("usename", DataType::Text, false),
ColumnSchema::new("usesuper", DataType::Bool, false),
];
let rows = state
.engine
.read()
.map(|e| {
if e.users().is_empty() {
vec![Row::new(vec![
Value::Text("admin".to_string()),
Value::Bool(true),
])]
} else {
e.users()
.iter()
.map(|(name, rec)| {
Row::new(vec![
Value::Text(name.to_string()),
Value::Bool(matches!(rec.role, spg_engine::Role::Admin)),
])
})
.collect()
}
})
.unwrap_or_default();
CannedResponse::Rows { columns, rows }
}
fn pg_tables_response(state: &Arc<ServerState>) -> CannedResponse {
let columns = vec![
ColumnSchema::new("schemaname", DataType::Text, false),
ColumnSchema::new("tablename", DataType::Text, false),
ColumnSchema::new("tableowner", DataType::Text, false),
];
let rows = state
.engine
.read()
.map(|e| {
e.catalog()
.table_names()
.into_iter()
.map(|name| {
Row::new(vec![
Value::Text("public".to_string()),
Value::Text(name),
Value::Text("admin".to_string()),
])
})
.collect()
})
.unwrap_or_default();
CannedResponse::Rows { columns, rows }
}
/// Server-side state for a statement created by a Parse message, keyed by statement name in
/// `handle_conn`'s `prepared` map.
#[derive(Debug, Clone)]
struct PreparedStmt {
    /// Statement returned by the engine's `prepare_cached`; used by Describe
    /// (`describe_prepared`) and cloned into `execute_prepared` on every Execute.
    ast: spg_sql::ast::Statement,
    /// Highest `$n` index found by `count_placeholders`; Bind must supply exactly this many values.
    placeholder_count: u16,
    /// Parameter type OIDs declared in Parse, in order. May be shorter than `placeholder_count`
    /// (missing entries are read as OID 0 when decoding binary Bind parameters).
    param_type_oids: Vec<u32>,
}
/// Result of a Bind message: a prepared statement plus its decoded parameter values, keyed by
/// portal name in `handle_conn`'s `portals` map.
#[derive(Debug, Clone)]
struct Portal {
    /// Name of the prepared statement this portal was bound from.
    stmt_name: String,
    /// Parameter values decoded by `handle_bind` (SQL NULL as `Value::Null`).
    params: Vec<spg_storage::Value>,
}
/// Reads the NUL-terminated UTF-8 string that starts at byte `pos` of `body`.
///
/// Returns `None` in three cases: `pos` is past the end of `body`, no terminating NUL follows,
/// or the bytes are not valid UTF-8.
fn cstring_at(body: &[u8], pos: usize) -> Option<String> {
    let rest = body.get(pos..)?;
    let nul = rest.iter().position(|&b| b == 0)?;
    std::str::from_utf8(&rest[..nul]).ok().map(str::to_string)
}
/// Reads a NUL-terminated UTF-8 string at `*cursor` and advances the cursor past the NUL.
///
/// Returns `None` when the cursor is out of range or no NUL follows; the cursor is then left
/// untouched. When the NUL is found but the bytes are not UTF-8, the cursor still advances past
/// the NUL and `None` is returned.
fn read_cstring<'a>(body: &'a [u8], cursor: &mut usize) -> Option<&'a str> {
    let rest = body.get(*cursor..)?;
    let nul = rest.iter().position(|&b| b == 0)?;
    *cursor += nul + 1;
    std::str::from_utf8(&rest[..nul]).ok()
}
/// Handles an extended-protocol Parse message and records the statement in `prepared`.
///
/// Body layout: statement name (cstring), SQL text (cstring), Int16 count of declared parameter
/// types, then that many Int32 type OIDs. Trailing `;` and whitespace are stripped from the SQL
/// before it goes to the engine's `prepare_cached` (under the engine write lock). Re-parsing an
/// existing name replaces it.
///
/// # Errors
/// A human-readable message for malformed bodies, a poisoned lock, or an engine parse error.
fn handle_parse(
    body: &[u8],
    prepared: &mut std::collections::HashMap<String, PreparedStmt>,
    state: &Arc<ServerState>,
) -> Result<(), String> {
    let mut pos = 0;
    let name = read_cstring(body, &mut pos)
        .ok_or("Parse: name not null-terminated UTF-8")?
        .to_owned();
    let sql = read_cstring(body, &mut pos)
        .ok_or("Parse: SQL not null-terminated UTF-8")?
        .trim_end_matches(';')
        .trim()
        .to_owned();

    let count_bytes = body
        .get(pos..pos + 2)
        .ok_or("Parse: missing parameter type count")?;
    let declared = usize::from(u16::from_be_bytes([count_bytes[0], count_bytes[1]]));
    pos += 2;
    let oid_bytes = body
        .get(pos..pos + declared * 4)
        .ok_or("Parse: truncated parameter OIDs")?;
    let param_type_oids: Vec<u32> = oid_bytes
        .chunks_exact(4)
        .map(|c| u32::from_be_bytes([c[0], c[1], c[2], c[3]]))
        .collect();

    // Hold the engine write lock only for the prepare itself.
    let ast = {
        let mut engine = state
            .engine
            .write()
            .map_err(|_| "Parse: engine lock poisoned".to_string())?;
        engine
            .prepare_cached(&sql)
            .map_err(|e| format!("Parse: {e}"))?
    };
    let stmt = PreparedStmt {
        placeholder_count: count_placeholders(&sql),
        param_type_oids,
        ast,
    };
    prepared.insert(name, stmt);
    Ok(())
}
/// Returns the highest `$n` parameter placeholder referenced in `sql` (0 if none).
///
/// Placeholders inside the following are ignored, as PostgreSQL's lexer does:
/// - `'…'` string literals (with `''` doubling, and backslash escapes in `E'…'` strings);
/// - `"…"` quoted identifiers;
/// - `--` line comments and `/* … */` block comments;
/// - `$tag$ … $tag$` dollar-quoted bodies.
///
/// Without this, e.g. `VALUES ($1, 'costs $5')` would demand five Bind parameters.
/// Digit runs saturate instead of overflowing, and the result saturates at `u16::MAX`.
fn count_placeholders(sql: &str) -> u16 {
    /// Identifier byte (PostgreSQL allows `$` after the first char and any non-ASCII byte).
    fn is_ident(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_' || b == b'$' || b >= 0x80
    }

    /// Index just past the first `needle` at or after `from`, or `hay.len()` if absent.
    fn skip_past(hay: &[u8], from: usize, needle: &[u8]) -> usize {
        hay.get(from..)
            .and_then(|rest| rest.windows(needle.len()).position(|w| w == needle))
            .map_or(hay.len(), |p| from + p + needle.len())
    }

    /// Index just past the quoted run opened at `open`, or `bytes.len()` if unterminated.
    fn skip_quoted(bytes: &[u8], open: usize, quote: u8, backslash_escapes: bool) -> usize {
        let mut k = open + 1;
        while k < bytes.len() {
            match bytes[k] {
                b'\\' if backslash_escapes => k += 2,
                b if b == quote => return k + 1,
                _ => k += 1,
            }
        }
        bytes.len()
    }

    let bytes = sql.as_bytes();
    let mut max: u32 = 0;
    let mut i = 0;
    while i < bytes.len() {
        let next = bytes.get(i + 1).copied();
        let after_ident = i > 0 && is_ident(bytes[i - 1]);
        match bytes[i] {
            b'-' if next == Some(b'-') => i = skip_past(bytes, i + 2, b"\n"),
            b'/' if next == Some(b'*') => i = skip_past(bytes, i + 2, b"*/"),
            b'\'' => {
                // E'…' (standalone E prefix) enables backslash escapes; plain '…' does not,
                // because standard_conforming_strings is on.
                let escapes = after_ident
                    && matches!(bytes[i - 1], b'e' | b'E')
                    && !(i >= 2 && is_ident(bytes[i - 2]));
                i = skip_quoted(bytes, i, b'\'', escapes);
            }
            b'"' => i = skip_quoted(bytes, i, b'"', false),
            b'$' if next.is_some_and(|b| b.is_ascii_digit()) => {
                let mut j = i + 1;
                let mut n: u32 = 0;
                while j < bytes.len() && bytes[j].is_ascii_digit() {
                    n = n
                        .saturating_mul(10)
                        .saturating_add(u32::from(bytes[j] - b'0'));
                    j += 1;
                }
                max = max.max(n);
                i = j;
            }
            b'$' if !after_ident => {
                // Possible dollar quote: `$tag$` where the tag is empty or an identifier.
                let tag_len = bytes[i + 1..]
                    .iter()
                    .take_while(|&&b| is_ident(b) && b != b'$')
                    .count();
                let close = i + 1 + tag_len;
                if bytes.get(close) == Some(&b'$') {
                    let delimiter = &bytes[i..=close];
                    i = skip_past(bytes, close + 1, delimiter);
                } else {
                    i += 1;
                }
            }
            _ => i += 1,
        }
    }
    u16::try_from(max).unwrap_or(u16::MAX)
}
/// Decodes an extended-protocol Bind message into `(portal name, Portal)`.
///
/// Body layout: portal name (cstring), statement name (cstring), Int16 + Int16[] parameter format
/// codes, Int16 parameter count, then per parameter an Int32 length (negative = NULL) and that
/// many value bytes. Format codes follow protocol rules: none means all text, one applies to all,
/// otherwise one per parameter. Text values go through `text_param_to_value`; binary values go
/// through `decode_binary_param` using the OID declared at Parse time (0 if undeclared).
///
/// NOTE(review): the result-format codes that follow the parameter values are never read, so a
/// client asking for binary results is not honoured here.
///
/// # Errors
/// A message for an unknown statement, truncated body, count mismatch with the statement's
/// placeholders, unsupported format code, or undecodable value.
fn handle_bind(
    body: &[u8],
    prepared: &std::collections::HashMap<String, PreparedStmt>,
) -> Result<(String, Portal), String> {
    /// Big-endian u16 at `at`, or `None` when fewer than two bytes remain.
    fn be_u16(body: &[u8], at: usize) -> Option<u16> {
        body.get(at..at + 2).map(|b| u16::from_be_bytes([b[0], b[1]]))
    }

    let mut pos = 0;
    let portal_name = read_cstring(body, &mut pos)
        .ok_or("Bind: portal name not UTF-8")?
        .to_string();
    let stmt_name = read_cstring(body, &mut pos)
        .ok_or("Bind: statement name not UTF-8")?
        .to_string();
    let Some(stmt) = prepared.get(&stmt_name) else {
        return Err(format!("Bind: prepared statement {stmt_name:?} not found"));
    };

    let format_count = usize::from(be_u16(body, pos).ok_or("Bind: truncated format-code count")?);
    pos += 2;
    let format_bytes = body
        .get(pos..pos + format_count * 2)
        .ok_or("Bind: truncated format codes")?;
    let formats: Vec<u16> = format_bytes
        .chunks_exact(2)
        .map(|c| u16::from_be_bytes([c[0], c[1]]))
        .collect();
    pos += format_count * 2;

    let param_count = usize::from(be_u16(body, pos).ok_or("Bind: truncated parameter count")?);
    pos += 2;
    if param_count != usize::from(stmt.placeholder_count) {
        return Err(format!(
            "Bind: parameter count mismatch (SQL has {}, Bind has {param_count})",
            stmt.placeholder_count
        ));
    }

    let mut params: Vec<spg_storage::Value> = Vec::with_capacity(param_count);
    for idx in 0..param_count {
        let len_bytes = body
            .get(pos..pos + 4)
            .ok_or("Bind: truncated parameter length")?;
        let declared_len =
            i32::from_be_bytes([len_bytes[0], len_bytes[1], len_bytes[2], len_bytes[3]]);
        pos += 4;
        // A negative length encodes SQL NULL and carries no value bytes.
        let Ok(value_len) = usize::try_from(declared_len) else {
            params.push(spg_storage::Value::Null);
            continue;
        };
        let raw = body
            .get(pos..pos + value_len)
            .ok_or("Bind: parameter value truncated")?;
        let format = if formats.len() == 1 {
            formats[0]
        } else {
            formats.get(idx).copied().unwrap_or(0)
        };
        let value = match format {
            0 => {
                let text = std::str::from_utf8(raw)
                    .map_err(|_| "Bind: text parameter not valid UTF-8".to_string())?;
                text_param_to_value(text)
            }
            1 => {
                let oid = stmt.param_type_oids.get(idx).copied().unwrap_or(0);
                decode_binary_param(oid, raw)?
            }
            other => return Err(format!("Bind: unsupported parameter format code {other}")),
        };
        params.push(value);
        pos += value_len;
    }

    Ok((portal_name, Portal { stmt_name, params }))
}
/// Guesses a typed value for a text-format Bind parameter (no declared type is consulted).
///
/// The checks run in this order on the trimmed text:
/// 1. `true` / `false` (case-insensitive) → Bool.
/// 2. Fits i32 → Int; fits i64 → BigInt.
/// 3. A float literal containing at least one digit → Float.
/// 4. `[a, b, …]` of finite floats → Vector.
/// 5. Anything else → the original, untrimmed text.
///
/// Step 3 requires a digit because Rust's float parser also accepts `inf`, `infinity` and `nan`
/// case-insensitively. Without it, ordinary text such as the name "Nan" would become `NaN`.
///
/// NOTE(review): numeric-looking text is still converted, so e.g. a zip code "02134" becomes
/// Int(2134) — confirm the engine coerces back for text columns.
fn text_param_to_value(s: &str) -> spg_storage::Value {
    let trimmed = s.trim();
    if trimmed.eq_ignore_ascii_case("true") {
        return spg_storage::Value::Bool(true);
    }
    if trimmed.eq_ignore_ascii_case("false") {
        return spg_storage::Value::Bool(false);
    }
    if let Ok(n) = trimmed.parse::<i32>() {
        return spg_storage::Value::Int(n);
    }
    if let Ok(n) = trimmed.parse::<i64>() {
        return spg_storage::Value::BigInt(n);
    }
    if trimmed.bytes().any(|b| b.is_ascii_digit()) {
        if let Ok(x) = trimmed.parse::<f64>() {
            return spg_storage::Value::Float(x);
        }
    }
    if let Some(v) = parse_vector_text(trimmed) {
        return spg_storage::Value::Vector(v);
    }
    spg_storage::Value::Text(s.to_string())
}
/// Decodes a binary-format (format code 1) Bind parameter according to its declared type OID.
///
/// Supported OIDs:
/// - 16 bool;
/// - 17 bytea, returned as hex text `\x…`;
/// - 25 / 1042 / 1043 text, bpchar and varchar, as raw UTF-8;
/// - 20 / 21 / 23 int8, int2 and int4;
/// - 700 / 701 float4 and float8;
/// - 1082 date, 1114 / 1184 timestamp and timestamptz, both shifted from PostgreSQL's 2000-01-01
///   epoch to the Unix epoch;
/// - 1700 numeric.
///
/// PostgreSQL encodes ±infinity dates and timestamps as the integer extremes. These are rejected
/// rather than overflowing during the epoch shift.
///
/// # Errors
/// Wrong wire width, invalid UTF-8, out-of-range date or timestamp, OID 0 (undeclared), or an
/// unsupported OID.
fn decode_binary_param(oid: u32, bytes: &[u8]) -> Result<spg_storage::Value, String> {
    use spg_storage::Value;

    /// Days from 1970-01-01 (Unix epoch) to 2000-01-01 (PostgreSQL epoch).
    const PG_EPOCH_DAYS_FROM_UNIX: i32 = 10957;
    /// Microseconds from 1970-01-01 to 2000-01-01.
    const PG_EPOCH_MICROS_FROM_UNIX: i64 = 946_684_800_000_000;
    const HEX: &[u8; 16] = b"0123456789abcdef";

    /// Reinterprets `bytes` as exactly `N` bytes, or reports the wire-width mismatch.
    fn fixed<const N: usize>(bytes: &[u8], type_name: &str) -> Result<[u8; N], String> {
        bytes.try_into().map_err(|_| {
            let unit = if N == 1 { "byte" } else { "bytes" };
            format!("Bind binary {type_name} must be {N} {unit}, got {}", bytes.len())
        })
    }

    match oid {
        16 => Ok(Value::Bool(fixed::<1>(bytes, "BOOL")?[0] != 0)),
        17 => {
            let mut hex = String::with_capacity(2 + bytes.len() * 2);
            hex.push_str("\\x");
            for &b in bytes {
                hex.push(char::from(HEX[usize::from(b >> 4)]));
                hex.push(char::from(HEX[usize::from(b & 0x0f)]));
            }
            Ok(Value::Text(hex))
        }
        25 | 1042 | 1043 => std::str::from_utf8(bytes)
            .map(|s| Value::Text(s.to_string()))
            .map_err(|_| "Bind binary TEXT/VARCHAR: invalid UTF-8".to_string()),
        20 => Ok(Value::BigInt(i64::from_be_bytes(fixed(bytes, "BIGINT")?))),
        21 => Ok(Value::SmallInt(i16::from_be_bytes(fixed(bytes, "INT2")?))),
        23 => Ok(Value::Int(i32::from_be_bytes(fixed(bytes, "INT")?))),
        700 => Ok(Value::Float(f64::from(f32::from_be_bytes(fixed(
            bytes, "REAL",
        )?)))),
        701 => Ok(Value::Float(f64::from_be_bytes(fixed(bytes, "DOUBLE")?))),
        1082 => {
            let pg_days = i32::from_be_bytes(fixed(bytes, "DATE")?);
            if pg_days == i32::MAX || pg_days == i32::MIN {
                return Err("Bind binary DATE: infinity is not supported".into());
            }
            pg_days
                .checked_add(PG_EPOCH_DAYS_FROM_UNIX)
                .map(Value::Date)
                .ok_or_else(|| format!("Bind binary DATE: value {pg_days} out of range"))
        }
        1114 | 1184 => {
            let pg_micros = i64::from_be_bytes(fixed(bytes, "TIMESTAMP")?);
            if pg_micros == i64::MAX || pg_micros == i64::MIN {
                return Err("Bind binary TIMESTAMP: infinity is not supported".into());
            }
            pg_micros
                .checked_add(PG_EPOCH_MICROS_FROM_UNIX)
                .map(Value::Timestamp)
                .ok_or_else(|| format!("Bind binary TIMESTAMP: value {pg_micros} out of range"))
        }
        1700 => decode_binary_numeric(bytes),
        0 => Err(
            "Bind: binary format requires the parameter OID to be declared in Parse \
            (got OID=0 meaning unknown)".into(),
        ),
        _ => Err(format!(
            "Bind: binary format for OID {oid} not supported in v6.3.4"
        )),
    }
}
/// Decodes PostgreSQL's binary NUMERIC wire format into `Value::Numeric { scaled, scale }`.
///
/// Wire layout: Int16 ndigits, Int16 weight, UInt16 sign, UInt16 dscale, then `ndigits`
/// base-10000 digits. Digit `k` contributes `digit * 10000^(weight - k)`. The value is
/// accumulated as an integer scaled by `10^dscale`; digits below the declared scale are
/// truncated.
///
/// # Errors
/// - Truncated or inconsistent length.
/// - NaN, ±Infinity or any other unknown sign word.
/// - A digit ≥ 10000.
/// - dscale > 255.
/// - A value that does not fit in i128. This is checked before any power of ten is computed, so
///   hostile input cannot panic.
fn decode_binary_numeric(bytes: &[u8]) -> Result<spg_storage::Value, String> {
    const NUMERIC_POS: u16 = 0x0000;
    const NUMERIC_NEG: u16 = 0x4000;
    const NUMERIC_NAN: u16 = 0xC000;
    const NBASE: u16 = 10_000;

    if bytes.len() < 8 {
        return Err("Bind binary NUMERIC: header truncated".into());
    }
    let ndigits = usize::from(u16::from_be_bytes([bytes[0], bytes[1]]));
    let weight = i32::from(i16::from_be_bytes([bytes[2], bytes[3]]));
    let sign = u16::from_be_bytes([bytes[4], bytes[5]]);
    let dscale = u16::from_be_bytes([bytes[6], bytes[7]]);
    if bytes.len() != 8 + ndigits * 2 {
        return Err(format!(
            "Bind binary NUMERIC: declared ndigits={ndigits} but body has {} bytes",
            bytes.len()
        ));
    }
    if sign == NUMERIC_NAN {
        return Err("Bind binary NUMERIC: NaN sign not supported".into());
    }
    // PostgreSQL 14+ also sends 0xD000 / 0xF000 for ±Infinity; those must not decode as 0.
    if sign != NUMERIC_POS && sign != NUMERIC_NEG {
        return Err(format!("Bind binary NUMERIC: unsupported sign 0x{sign:04x}"));
    }
    let scale = u8::try_from(dscale).map_err(|_| "NUMERIC dscale too large".to_string())?;

    let mut unscaled: i128 = 0;
    for (k, pair) in bytes[8..].chunks_exact(2).enumerate() {
        let digit = u16::from_be_bytes([pair[0], pair[1]]);
        if digit >= NBASE {
            return Err(format!("Bind binary NUMERIC: invalid base-10000 digit {digit}"));
        }
        if digit == 0 {
            continue;
        }
        // Decimal exponent of this digit after scaling by 10^scale.
        let exp = (weight - k as i32) * 4 + i32::from(scale);
        let term = if exp >= 0 {
            10i128
                .checked_pow(exp as u32)
                .and_then(|p| i128::from(digit).checked_mul(p))
                .ok_or("NUMERIC overflow")?
        } else {
            // Below the declared scale: truncate (a divisor beyond i128 leaves nothing).
            10i128
                .checked_pow(exp.unsigned_abs())
                .map_or(0, |p| i128::from(digit) / p)
        };
        unscaled = unscaled.checked_add(term).ok_or("NUMERIC overflow")?;
    }
    let scaled = if sign == NUMERIC_NEG { -unscaled } else { unscaled };
    Ok(spg_storage::Value::Numeric { scaled, scale })
}
/// Parses pgvector-style text `[x, y, …]` into finite `f32`s.
///
/// `[]` or brackets containing only whitespace yield an empty vector. Missing brackets, an empty
/// or unparsable element, or any non-finite value (inf/NaN) yields `None`.
fn parse_vector_text(s: &str) -> Option<Vec<f32>> {
    let inner = s.strip_prefix('[')?.strip_suffix(']')?;
    if inner.trim().is_empty() {
        return Some(Vec::new());
    }
    inner
        .split(',')
        .map(|element| {
            element
                .trim()
                .parse::<f32>()
                .ok()
                .filter(|value| value.is_finite())
        })
        .collect()
}
/// True when `s`, ignoring surrounding whitespace, is non-empty and parses as an `i64` or `f64`.
///
/// Because the `f64` parser accepts them, `inf` and `nan` also count as numeric here.
fn looks_numeric(s: &str) -> bool {
    let candidate = s.trim();
    !candidate.is_empty()
        && (candidate.parse::<i64>().is_ok() || candidate.parse::<f64>().is_ok())
}
/// Runs an extended-protocol Execute for a bound portal and writes its results to `stream`.
///
/// Privileges are enforced before any engine lock is taken, using the same role predicates as the
/// simple-query path:
/// - non-SELECT statements need `role.can_write()`;
/// - `CREATE USER` / `DROP USER` additionally need `role.can_manage_users()`. Before this check,
///   the extended protocol let a read-write role manage users.
///
/// On `CommandOk` the transaction status becomes `b'T'` or `b'I'`. On an engine error it becomes
/// `b'E'` inside a transaction block, otherwise `b'I'`.
///
/// NOTE(review): the max-rows field is required but ignored; portals always run to completion.
///
/// # Errors
/// A message for malformed bodies, unknown portals or statements, permission failures, a
/// poisoned lock, engine errors, or write failures. The caller turns it into an ErrorResponse.
fn handle_execute(
    body: &[u8],
    portals: &std::collections::HashMap<String, Portal>,
    prepared: &std::collections::HashMap<String, PreparedStmt>,
    stream: &mut dyn Write,
    state: &Arc<ServerState>,
    role: Role,
    tx_state: &mut u8,
) -> Result<(), String> {
    use spg_sql::ast::Statement;

    let mut cur = 0;
    let portal_name = read_cstring(body, &mut cur)
        .ok_or("Execute: portal name not UTF-8")?
        .to_string();
    if cur + 4 > body.len() {
        return Err("Execute: missing max-rows".into());
    }
    let portal = portals
        .get(&portal_name)
        .ok_or_else(|| format!("Execute: portal {portal_name:?} not found"))?;
    let stmt = prepared.get(&portal.stmt_name).ok_or_else(|| {
        format!(
            "Execute: prepared statement {:?} dropped while a portal held a reference",
            portal.stmt_name
        )
    })?;

    let needs_write = !matches!(&stmt.ast, Statement::Select(_));
    if needs_write && !role.can_write() {
        return Err("permission denied: readonly role".into());
    }
    let is_user_mgmt = matches!(&stmt.ast, Statement::CreateUser(_) | Statement::DropUser(_));
    if is_user_mgmt && !role.can_manage_users() {
        return Err("permission denied: user management requires admin role".into());
    }

    let result = {
        // Reads take the write lock too, as before — presumably because `execute_prepared`
        // needs `&mut` engine access (TODO confirm; a read lock would allow concurrent SELECTs).
        let mut eng = state
            .engine
            .write()
            .map_err(|_| "Execute: engine lock poisoned".to_string())?;
        eng.execute_prepared(stmt.ast.clone(), &portal.params)
    };
    let in_transaction = || state.engine.read().is_ok_and(|e| e.in_transaction());
    match result {
        Ok(QueryResult::Rows { columns, rows }) => {
            send_row_description(stream, &columns).map_err(|e| e.to_string())?;
            let n = rows.len();
            for row in &rows {
                send_data_row(stream, &columns, row).map_err(|e| e.to_string())?;
            }
            send_command_complete(stream, &format!("SELECT {n}")).map_err(|e| e.to_string())?;
        }
        Ok(QueryResult::CommandOk { affected, .. }) => {
            let tag = command_tag_for_ast(&stmt.ast, affected);
            send_command_complete(stream, &tag).map_err(|e| e.to_string())?;
            *tx_state = if in_transaction() { b'T' } else { b'I' };
        }
        Err(e) => {
            *tx_state = if in_transaction() { b'E' } else { b'I' };
            return Err(e.to_string());
        }
        Ok(_) => return Err("unexpected QueryResult variant".to_string()),
    }
    Ok(())
}
/// CommandComplete tag for a prepared statement that returned `CommandOk`.
///
/// DML tags carry the affected-row count (`INSERT 0 n`, `UPDATE n`, `DELETE n`). DDL,
/// transaction-control, user and replication statements get fixed tags. Any other statement
/// kind is reported as `OK`.
fn command_tag_for_ast(stmt: &spg_sql::ast::Statement, affected: usize) -> String {
    use spg_sql::ast::Statement as S;
    let fixed_tag = match stmt {
        S::Insert(_) => return format!("INSERT 0 {affected}"),
        S::Update(_) => return format!("UPDATE {affected}"),
        S::Delete(_) => return format!("DELETE {affected}"),
        S::CreateTable(_) => "CREATE TABLE",
        S::CreateIndex(_) => "CREATE INDEX",
        S::AlterIndex(_) => "ALTER INDEX",
        S::Begin => "BEGIN",
        S::Commit => "COMMIT",
        S::Rollback | S::RollbackToSavepoint(_) => "ROLLBACK",
        S::Savepoint(_) => "SAVEPOINT",
        S::ReleaseSavepoint(_) => "RELEASE",
        S::CreateUser(_) => "CREATE USER",
        S::DropUser(_) => "DROP USER",
        S::CreatePublication(_) => "CREATE PUBLICATION",
        S::DropPublication(_) => "DROP PUBLICATION",
        S::CreateSubscription(_) => "CREATE SUBSCRIPTION",
        S::DropSubscription(_) => "DROP SUBSCRIPTION",
        _ => "OK",
    };
    fixed_tag.to_string()
}
/// Recognises `SET [SESSION | LOCAL] name { = | TO } value` and `SET [...] TIME ZONE value`.
///
/// Returns `(name, value)`:
/// - `name` is lowercased; `TIME ZONE` maps to `timezone`.
/// - `value` keeps its original case, is trimmed, and has surrounding `'`/`"` quotes stripped.
///
/// The separator is whichever of `=` or ` TO ` appears first, so `SET x TO 'a = b'` yields
/// `("x", "a = b")`. Returns `None` for anything else, e.g. `SET TRANSACTION …`.
fn parse_set_statement(sql: &str) -> Option<(String, String)> {
    let trimmed = sql.trim();
    // ASCII lowercasing keeps byte offsets, so indices found in `lower` are valid in `trimmed`.
    let lower = trimmed.to_ascii_lowercase();
    if !lower.starts_with("set ") {
        return None;
    }
    let mut start = "set ".len();
    for modifier in ["session ", "local "] {
        if lower[start..].starts_with(modifier) {
            start += modifier.len();
            break;
        }
    }
    let rest_lower = &lower[start..];
    let rest = &trimmed[start..];

    let (name, value) = if rest_lower.starts_with("time zone ") {
        ("timezone".to_string(), &rest["time zone ".len()..])
    } else {
        let eq = rest_lower.find('=');
        let to = rest_lower.find(" to ");
        let (idx, sep_len) = match (eq, to) {
            (Some(e), Some(t)) if t < e => (t, 4),
            (Some(e), _) => (e, 1),
            (None, Some(t)) => (t, 4),
            (None, None) => return None,
        };
        (rest_lower[..idx].trim().to_string(), &rest[idx + sep_len..])
    };
    if name.is_empty() {
        return None;
    }
    let value = value.trim().trim_matches('\'').trim_matches('"').to_string();
    Some((name, value))
}
/// Recognises `SHOW <name>` and returns the lowercased setting name.
///
/// The multi-word forms `SHOW TRANSACTION ISOLATION LEVEL` and `SHOW TIME ZONE` map to their
/// GUC names `transaction_isolation` and `timezone`. Previously they produced the bogus names
/// "transaction" and "time". `SHOW ALL` returns `all`. Returns `None` for non-SHOW statements
/// or a bare `SHOW`.
fn parse_show_statement(sql: &str) -> Option<String> {
    let lower = sql.trim().to_ascii_lowercase();
    let rest = lower.strip_prefix("show ")?;
    let words: Vec<&str> = rest.split_ascii_whitespace().collect();
    let name = match words.as_slice() {
        [] => return None,
        ["transaction", "isolation", "level", ..] => "transaction_isolation",
        ["time", "zone", ..] => "timezone",
        [first, ..] => *first,
    };
    Some(name.to_string())
}
/// Builds the result for `SHOW <name>`, with session overrides taking precedence over defaults.
///
/// - `SHOW all` returns `(name, setting, description)` rows for every known default plus every
///   session setting, sorted by name. The nullable description is always NULL.
/// - Any other name returns one column named after the setting. The value comes from the
///   session, else from `known_defaults()`, else an empty string (unknown names are not an error).
fn render_show(name: &str, settings: &std::collections::HashMap<String, String>) -> CannedResponse {
    if name != "all" {
        let value = settings.get(name).cloned().unwrap_or_else(|| {
            known_defaults()
                .iter()
                .find_map(|&(key, default)| (key == name).then(|| default.to_string()))
                .unwrap_or_default()
        });
        return CannedResponse::Rows {
            columns: vec![ColumnSchema::new(name.to_string(), DataType::Text, false)],
            rows: vec![Row::new(vec![Value::Text(value)])],
        };
    }

    // Names are unique, so an ordered map yields the same ordering as sorting the pairs.
    let mut merged: std::collections::BTreeMap<String, String> = known_defaults()
        .iter()
        .map(|&(key, default)| (key.to_string(), default.to_string()))
        .collect();
    for (key, value) in settings {
        merged.insert(key.clone(), value.clone());
    }
    let columns = vec![
        ColumnSchema::new("name", DataType::Text, false),
        ColumnSchema::new("setting", DataType::Text, false),
        ColumnSchema::new("description", DataType::Text, true),
    ];
    let rows: Vec<Row> = merged
        .into_iter()
        .map(|(setting_name, setting)| {
            Row::new(vec![Value::Text(setting_name), Value::Text(setting), Value::Null])
        })
        .collect();
    CannedResponse::Rows { columns, rows }
}
/// Built-in values reported by `SHOW` for parameters the session has not `SET`.
///
/// Entries are `(lowercase parameter name, value)` and are listed in name order.
fn known_defaults() -> &'static [(&'static str, &'static str)] {
    const DEFAULTS: &[(&str, &str)] = &[
        ("application_name", ""),
        ("client_encoding", "UTF8"),
        ("datestyle", "ISO, MDY"),
        ("default_transaction_isolation", "read committed"),
        ("default_transaction_read_only", "off"),
        ("intervalstyle", "postgres"),
        ("search_path", "\"$user\", public"),
        ("server_encoding", "UTF8"),
        ("server_version", "16.0 (spg-4.19)"),
        ("standard_conforming_strings", "on"),
        ("statement_timeout", "0"),
        ("timezone", "UTC"),
        ("transaction_isolation", "read committed"),
        ("transaction_read_only", "off"),
    ];
    DEFAULTS
}
/// Options parsed from the parenthesised list of `COPY ... FROM STDIN (...)`.
#[derive(Debug, Clone, Default)]
struct CopyOptions {
    /// Number of leading non-empty data lines to discard (`SKIP n`).
    pub skip: u64,
    /// `ON_ERROR set_null`: a row that fails to convert or insert is skipped
    /// instead of aborting the whole COPY.
    pub on_error_set_null: bool,
    /// `FORMAT json`: each line is a JSON object rather than tab-separated text.
    pub format_json: bool,
}

/// What a recognised `COPY` statement asks the server to do.
enum CopyIntent {
    /// `COPY <table> FROM STDIN`, together with its parsed options.
    From(String, CopyOptions),
    /// `COPY <table> TO STDOUT`.
    To(String),
}
/// Recognises `COPY <table> FROM STDIN [...]` and `COPY <table> TO STDOUT [...]`.
///
/// The statement is lowercased before matching, so the returned table name is
/// lowercase. The `STDIN`/`STDOUT` keyword may be glued to a trailing `;` or an
/// option list (`STDIN;`, `STDIN(FORMAT json)`). For `FROM STDIN` the
/// parenthesised options are parsed with `parse_copy_options`. Any other shape
/// (column lists, `COPY (query)`, file paths) returns `None`.
fn parse_copy_intent(sql: &str) -> Option<CopyIntent> {
    let lower = sql.trim().to_ascii_lowercase();
    let rest = lower.strip_prefix("copy ")?;
    let mut words = rest.split_ascii_whitespace();
    let table = words.next()?.to_string();
    let direction = words.next()?;
    // Cut the keyword at a glued-on terminator or option list.
    let endpoint = words.next()?.split([';', '(']).next().unwrap_or_default();
    match (direction, endpoint) {
        ("from", "stdin") => Some(CopyIntent::From(table, parse_copy_options(&lower))),
        ("to", "stdout") => Some(CopyIntent::To(table)),
        _ => None,
    }
}
/// Parses the `( key value, ... )` option list of an already-lowercased COPY
/// statement.
///
/// Only the text between the first `(` and the first `)` after it is examined.
/// Recognised options are `skip <n>` (unparsable numbers become 0),
/// `on_error set_null` and `format json`. Anything else is ignored. With no
/// option list, the defaults are returned.
fn parse_copy_options(lower: &str) -> CopyOptions {
    let mut opts = CopyOptions::default();
    let Some(inner) = lower.find('(').and_then(|open| {
        let after_open = &lower[open + 1..];
        after_open.find(')').map(|close| &after_open[..close])
    }) else {
        return opts;
    };
    for option in inner.split(',').map(str::trim).filter(|o| !o.is_empty()) {
        let mut words = option.split_ascii_whitespace();
        let key = words.next().unwrap_or("");
        let value = words.next().unwrap_or("");
        match key {
            "skip" => opts.skip = value.parse().unwrap_or(0),
            "on_error" if value == "set_null" => opts.on_error_set_null = true,
            "format" if value == "json" => opts.format_json = true,
            _ => {}
        }
    }
    opts
}
/// Runs the copy-in sub-protocol for `COPY <table> FROM STDIN`.
///
/// Steps:
/// 1. Reject roles without write access (42501).
/// 2. Reject unknown tables (42P01).
/// 3. Send CopyInResponse ('G', text format for every column).
/// 4. Read CopyData ('d') frames, inserting complete lines as they arrive,
///    until CopyDone ('c').
/// 5. Send `COPY <n>` with the inserted row count and refresh `tx_state` from
///    the engine.
///
/// CopyFail ('f') or any unexpected frame ends the COPY with an ErrorResponse.
/// Flush ('H') and Sync ('S') are ignored, as the protocol requires during
/// copy-in. Rows inserted before an error are not rolled back here.
fn handle_copy_from_stdin(
    stream: &mut TcpStream,
    state: &Arc<ServerState>,
    role: Role,
    table: &str,
    opts: &CopyOptions,
    tx_state: &mut u8,
) -> std::io::Result<()> {
    if !role.can_write() {
        send_error(
            stream,
            "42501",
            "permission denied: COPY FROM requires admin or readwrite",
        )?;
        return Ok(());
    }
    let Some(col_count) = state
        .engine
        .read()
        .ok()
        .and_then(|e| e.catalog().get(table).map(|t| t.schema().columns.len()))
    else {
        send_error(
            stream,
            "42P01",
            &format!("relation {table:?} does not exist"),
        )?;
        return Ok(());
    };
    // CopyInResponse: overall format 0 (text), column count, then a 0 (text)
    // format code per column.
    let mut body = Vec::with_capacity(3 + col_count * 2);
    body.push(0);
    body.extend_from_slice(&u16::try_from(col_count).unwrap_or(0).to_be_bytes());
    for _ in 0..col_count {
        body.extend_from_slice(&0u16.to_be_bytes());
    }
    send_msg(stream, b'G', &body)?;
    let mut buf: Vec<u8> = Vec::new();
    let mut inserted: u64 = 0;
    let mut skipped: u64 = 0;
    loop {
        let mut header = [0u8; 5];
        stream.read_exact(&mut header)?;
        let ty = header[0];
        let len = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
        let body_len = len.saturating_sub(4);
        let mut body = vec![0u8; body_len];
        if body_len > 0 {
            stream.read_exact(&mut body)?;
        }
        match ty {
            b'd' => buf.extend_from_slice(&body),
            b'c' => {
                // Make sure a final unterminated line is still processed.
                if !buf.is_empty() && !buf.ends_with(b"\n") {
                    buf.push(b'\n');
                }
                break;
            }
            b'f' => {
                send_error(stream, "57014", "client aborted COPY")?;
                return Ok(());
            }
            // Flush and Sync may legally be interleaved with copy-in data;
            // the backend must ignore them rather than fail the COPY.
            b'H' | b'S' => continue,
            other => {
                send_error(
                    stream,
                    "08P01",
                    &format!("unexpected frame 0x{other:02x} during COPY"),
                )?;
                return Ok(());
            }
        }
        if let Err(msg) =
            process_copy_chunk(state, table, &mut buf, &mut inserted, &mut skipped, opts)
        {
            send_error(stream, "22P02", &msg)?;
            return Ok(());
        }
    }
    if let Err(msg) =
        process_copy_chunk(state, table, &mut buf, &mut inserted, &mut skipped, opts)
    {
        send_error(stream, "22P02", &msg)?;
        return Ok(());
    }
    send_command_complete(stream, &format!("COPY {inserted}"))?;
    *tx_state = if state.engine.read().is_ok_and(|e| e.in_transaction()) {
        b'T'
    } else {
        b'I'
    };
    Ok(())
}
/// Inserts every complete (newline-terminated) line currently in `buf`.
///
/// Consumed lines are drained from `buf`; a trailing partial line stays for
/// the next call.
///
/// Per-line handling:
/// * A line ending in `\r\n` has the `\r` removed as well.
/// * Empty lines are ignored.
/// * The first `opts.skip` non-empty lines are counted in `skipped` and dropped.
/// * A `\.` line stops processing of this chunk.
///
/// Each remaining line is turned into an INSERT (text or JSON format) and
/// executed under the engine write lock; successes increment `inserted`. With
/// `opts.on_error_set_null`, rows that fail to convert or insert are skipped
/// instead of returning an error.
///
/// # Errors
/// Returns a message for non-UTF-8 lines, conversion or insert failures (when
/// not skipping), or a poisoned engine lock.
fn process_copy_chunk(
    state: &Arc<ServerState>,
    table: &str,
    buf: &mut Vec<u8>,
    inserted: &mut u64,
    skipped: &mut u64,
    opts: &CopyOptions,
) -> Result<(), String> {
    while let Some(nl) = buf.iter().position(|&b| b == b'\n') {
        let raw: Vec<u8> = buf.drain(..=nl).collect();
        // Drop the '\n', plus a preceding '\r' so CRLF-terminated input does
        // not leave a stray carriage return in the last cell.
        let line = &raw[..raw.len() - 1];
        let line = line.strip_suffix(b"\r").unwrap_or(line);
        if line == b"\\." {
            return Ok(());
        }
        if line.is_empty() {
            continue;
        }
        let row_text =
            std::str::from_utf8(line).map_err(|_| "COPY row not valid UTF-8".to_string())?;
        if *skipped < opts.skip {
            *skipped += 1;
            continue;
        }
        let sql = if opts.format_json {
            match build_copy_insert_from_json(state, table, row_text, opts.on_error_set_null) {
                Ok(s) => s,
                Err(e) => {
                    if opts.on_error_set_null {
                        continue;
                    }
                    return Err(format!("COPY FORMAT JSON: {e}"));
                }
            }
        } else {
            let values = decode_copy_text_row(row_text);
            build_copy_insert(table, &values)
        };
        let mut engine = state
            .engine
            .write()
            .map_err(|_| "engine rwlock poisoned".to_string())?;
        match engine.execute(&sql) {
            Ok(_) => *inserted += 1,
            Err(e) => {
                if opts.on_error_set_null {
                    continue;
                }
                return Err(format!("COPY row INSERT failed: {e}"));
            }
        }
    }
    Ok(())
}
fn build_copy_insert_from_json(
state: &Arc<ServerState>,
table: &str,
line: &str,
_on_error: bool,
) -> Result<String, String> {
let cols: Vec<String> = state
.engine
.read()
.ok()
.and_then(|e| {
e.catalog()
.get(table)
.map(|t| t.schema().columns.iter().map(|c| c.name.clone()).collect())
})
.ok_or_else(|| format!("relation {table:?} does not exist"))?;
let pairs = parse_json_object_top_level(line)?;
let mut sql = format!("INSERT INTO {table} (");
for (i, c) in cols.iter().enumerate() {
if i > 0 {
sql.push(',');
}
sql.push_str(c);
}
sql.push_str(") VALUES (");
for (i, c) in cols.iter().enumerate() {
if i > 0 {
sql.push(',');
}
let val = pairs.iter().find(|(k, _)| k == c).map(|(_, v)| v.clone());
match val {
None => sql.push_str("NULL"),
Some(v) => sql.push_str(&v),
}
}
sql.push(')');
Ok(sql)
}
/// Parses a flat JSON object into `(key, SQL literal)` pairs, in input order.
///
/// Each value is rendered by `read_json_value_as_sql`, which only accepts
/// strings, numbers, booleans and `null`. Commas between members are optional
/// and duplicate keys are kept.
///
/// # Errors
/// Fails when the text is not wrapped in `{...}` or a key or value is malformed.
fn parse_json_object_top_level(s: &str) -> Result<Vec<(String, String)>, String> {
    let Some(inner) = s
        .trim()
        .strip_prefix('{')
        .and_then(|rest| rest.strip_suffix('}'))
    else {
        return Err("expected JSON object {...}".to_string());
    };
    let mut members = Vec::new();
    let mut cursor = inner.chars().peekable();
    loop {
        skip_ws(&mut cursor);
        if cursor.peek().is_none() {
            break;
        }
        let key = read_json_string(&mut cursor)?;
        skip_ws(&mut cursor);
        if cursor.next() != Some(':') {
            return Err("expected ':' after key".into());
        }
        skip_ws(&mut cursor);
        let literal = read_json_value_as_sql(&mut cursor)?;
        members.push((key, literal));
        skip_ws(&mut cursor);
        let _ = cursor.next_if_eq(&',');
    }
    Ok(members)
}
/// Consumes leading whitespace (per `char::is_whitespace`), leaving the first
/// non-whitespace character, if any, as the next item.
fn skip_ws(chars: &mut std::iter::Peekable<std::str::Chars>) {
    while chars.next_if(|c| c.is_whitespace()).is_some() {}
}
/// Reads one JSON string literal (including its quotes) from `chars` and
/// returns the decoded text.
///
/// Decodes the standard escapes `\" \\ \/ \b \f \n \r \t` and `\uXXXX`,
/// including UTF-16 surrogate pairs (`\ud83d\ude00`). Any other escaped
/// character stands for itself (lenient).
///
/// # Errors
/// Fails when the input does not start with `"`, the string is unterminated,
/// a `\u` escape is malformed, or a surrogate is unpaired.
fn read_json_string(chars: &mut std::iter::Peekable<std::str::Chars>) -> Result<String, String> {
    // Reads exactly four hex digits of a `\uXXXX` escape.
    fn hex4(chars: &mut std::iter::Peekable<std::str::Chars>) -> Result<u32, String> {
        let mut code = 0u32;
        for _ in 0..4 {
            let digit = chars
                .next()
                .and_then(|c| c.to_digit(16))
                .ok_or("invalid \\u escape in JSON string")?;
            code = code * 16 + digit;
        }
        Ok(code)
    }

    if chars.next() != Some('"') {
        return Err("expected '\"' to start string".into());
    }
    let mut out = String::new();
    loop {
        match chars.next() {
            None => return Err("unterminated JSON string".into()),
            Some('"') => return Ok(out),
            Some('\\') => {
                let n = chars.next().ok_or("trailing escape")?;
                let decoded = match n {
                    '"' => '"',
                    '\\' => '\\',
                    '/' => '/',
                    'b' => '\u{08}',
                    'f' => '\u{0c}',
                    'n' => '\n',
                    'r' => '\r',
                    't' => '\t',
                    'u' => {
                        let high = hex4(chars)?;
                        let code = if (0xD800..0xDC00).contains(&high) {
                            // A high surrogate must be followed by `\u` + low surrogate.
                            if chars.next() != Some('\\') || chars.next() != Some('u') {
                                return Err("unpaired surrogate in \\u escape".into());
                            }
                            let low = hex4(chars)?;
                            if !(0xDC00..0xE000).contains(&low) {
                                return Err("unpaired surrogate in \\u escape".into());
                            }
                            0x10000 + ((high - 0xD800) << 10) + (low - 0xDC00)
                        } else {
                            high
                        };
                        // Rejects a lone low surrogate (not a valid scalar value).
                        char::from_u32(code).ok_or("unpaired surrogate in \\u escape")?
                    }
                    other => other,
                };
                out.push(decoded);
            }
            Some(c) => out.push(c),
        }
    }
}
/// Reads one scalar JSON value from `chars` and renders it as an SQL literal.
///
/// * Strings become single-quoted literals with `'` doubled.
/// * `true` / `false` become `TRUE` / `FALSE`; `null` becomes `NULL`.
/// * Numbers must match the JSON number grammar
///   `-?(0|[1-9][0-9]*)(\.[0-9]+)?([eE][+-]?[0-9]+)?` and are emitted verbatim.
///
/// Number validation matters because the token is spliced into an INSERT
/// statement unquoted. Without it, inputs such as `1-1` or `--1` would be
/// accepted and become an SQL expression or a comment.
///
/// # Errors
/// Fails on malformed tokens and on unsupported values (objects, arrays).
fn read_json_value_as_sql(
    chars: &mut std::iter::Peekable<std::str::Chars>,
) -> Result<String, String> {
    // True when `s` is exactly one JSON number (RFC 8259, section 6).
    fn is_json_number(s: &str) -> bool {
        let b = s.as_bytes();
        // Index just past the run of ASCII digits starting at `from`.
        let digits_end = |from: usize| {
            let mut end = from;
            while b.get(end).is_some_and(u8::is_ascii_digit) {
                end += 1;
            }
            end
        };
        let mut i = 0;
        if b.get(i) == Some(&b'-') {
            i += 1;
        }
        match b.get(i) {
            Some(b'0') => i += 1,
            Some(d) if d.is_ascii_digit() => i = digits_end(i),
            _ => return false,
        }
        if b.get(i) == Some(&b'.') {
            let end = digits_end(i + 1);
            if end == i + 1 {
                return false;
            }
            i = end;
        }
        if matches!(b.get(i), Some(b'e' | b'E')) {
            i += 1;
            if matches!(b.get(i), Some(b'+' | b'-')) {
                i += 1;
            }
            let end = digits_end(i);
            if end == i {
                return false;
            }
            i = end;
        }
        i == b.len()
    }

    skip_ws(chars);
    let Some(&first) = chars.peek() else {
        return Err("expected value".into());
    };
    match first {
        '"' => {
            let s = read_json_string(chars)?;
            Ok(format!("'{}'", s.replace('\'', "''")))
        }
        't' | 'f' => {
            let mut s = String::new();
            while let Some(&c) = chars.peek() {
                if c.is_ascii_alphabetic() {
                    s.push(c);
                    chars.next();
                } else {
                    break;
                }
            }
            if s == "true" {
                Ok("TRUE".to_string())
            } else if s == "false" {
                Ok("FALSE".to_string())
            } else {
                Err(format!("invalid bool token: {s}"))
            }
        }
        'n' => {
            for expected in ['n', 'u', 'l', 'l'] {
                if chars.next() != Some(expected) {
                    return Err("invalid null token".into());
                }
            }
            Ok("NULL".to_string())
        }
        c if c == '-' || c.is_ascii_digit() => {
            // Collect the whole number-ish token, then validate it as a unit.
            let mut s = String::new();
            while let Some(&c) = chars.peek() {
                if c == '-' || c == '+' || c == '.' || c == 'e' || c == 'E' || c.is_ascii_digit() {
                    s.push(c);
                    chars.next();
                } else {
                    break;
                }
            }
            if is_json_number(&s) {
                Ok(s)
            } else {
                Err(format!("invalid JSON number: {s}"))
            }
        }
        other => Err(format!("unsupported JSON value start: {other:?}")),
    }
}
/// Splits one COPY text-format line on TAB and decodes each cell.
///
/// A cell that is exactly `\N` becomes `None` (SQL NULL). Backslash escapes
/// follow PostgreSQL's text format:
/// * `\b \f \n \r \t \v \\` map to the usual control characters / backslash;
/// * `\ooo` (one to three octal digits) and `\xhh` (one or two hex digits)
///   give the byte with that value (`\x` without hex digits means `x`);
/// * any other escaped character stands for itself.
///
/// A lone trailing backslash is kept literally. Octal/hex escapes produce raw
/// bytes, so multi-byte UTF-8 characters can be spelled byte by byte. Byte
/// sequences that end up invalid UTF-8 are replaced with U+FFFD.
fn decode_copy_text_row(line: &str) -> Vec<Option<String>> {
    // Decodes a single cell (already split on TAB).
    fn decode_cell(cell: &str) -> Option<String> {
        if cell == "\\N" {
            return None;
        }
        let mut bytes: Vec<u8> = Vec::with_capacity(cell.len());
        let mut utf8 = [0u8; 4];
        let mut chars = cell.chars().peekable();
        while let Some(c) = chars.next() {
            if c != '\\' {
                bytes.extend_from_slice(c.encode_utf8(&mut utf8).as_bytes());
                continue;
            }
            let Some(esc) = chars.next() else {
                bytes.push(b'\\');
                break;
            };
            let decoded: u8 = match esc {
                'b' => 0x08,
                'f' => 0x0c,
                'n' => b'\n',
                'r' => b'\r',
                't' => b'\t',
                'v' => 0x0b,
                '\\' => b'\\',
                '0'..='7' => {
                    let mut value = esc.to_digit(8).unwrap_or(0);
                    for _ in 0..2 {
                        match chars.peek().and_then(|d| d.to_digit(8)) {
                            Some(d) => {
                                value = value * 8 + d;
                                chars.next();
                            }
                            None => break,
                        }
                    }
                    // Like PostgreSQL, keep only the low byte (`\777` -> 0xff).
                    (value & 0xff) as u8
                }
                'x' => match chars.peek().and_then(|d| d.to_digit(16)) {
                    Some(high) => {
                        chars.next();
                        let mut value = high;
                        if let Some(low) = chars.peek().and_then(|d| d.to_digit(16)) {
                            chars.next();
                            value = value * 16 + low;
                        }
                        value as u8
                    }
                    None => b'x',
                },
                other => {
                    bytes.extend_from_slice(other.encode_utf8(&mut utf8).as_bytes());
                    continue;
                }
            };
            bytes.push(decoded);
        }
        Some(match String::from_utf8(bytes) {
            Ok(s) => s,
            Err(e) => String::from_utf8_lossy(e.as_bytes()).into_owned(),
        })
    }

    line.split('\t').map(decode_cell).collect()
}
/// Builds `INSERT INTO <table> VALUES (...)` from decoded COPY text cells.
///
/// Each cell is rendered as follows:
/// * `None` becomes `NULL`.
/// * A value accepted by `looks_numeric`, or one of the exact words
///   `true`/`false`/`TRUE`/`FALSE`, is emitted unquoted.
/// * Anything else becomes a single-quoted literal with embedded `'` doubled.
fn build_copy_insert(table: &str, values: &[Option<String>]) -> String {
    let literals: Vec<String> = values
        .iter()
        .map(|cell| match cell {
            None => "NULL".to_string(),
            Some(text)
                if looks_numeric(text)
                    || matches!(text.as_str(), "true" | "false" | "TRUE" | "FALSE") =>
            {
                text.clone()
            }
            Some(text) => format!("'{}'", text.replace('\'', "''")),
        })
        .collect();
    format!("INSERT INTO {table} VALUES ({})", literals.join(", "))
}
fn handle_copy_to_stdout(
stream: &mut TcpStream,
state: &Arc<ServerState>,
role: Role,
table: &str,
tx_state: &mut u8,
) -> std::io::Result<()> {
let _ = role.can_read(); let sql = format!("SELECT * FROM {table}");
let result = execute_with_role(state, &sql, role);
let (columns, rows) = match result {
Ok(QueryResult::Rows { columns, rows }) => (columns, rows),
Ok(QueryResult::CommandOk { .. }) => {
send_error(stream, "42000", "COPY TO source produced no rows")?;
return Ok(());
}
Err(e) => {
send_error(stream, "42000", &e.to_string())?;
return Ok(());
}
Ok(_) => {
send_error(stream, "XX000", "unexpected QueryResult variant")?;
return Ok(());
}
};
let col_count = columns.len();
let mut body = Vec::with_capacity(3 + col_count * 2);
body.push(0);
body.extend_from_slice(&u16::try_from(col_count).unwrap_or(0).to_be_bytes());
for _ in 0..col_count {
body.extend_from_slice(&0u16.to_be_bytes());
}
send_msg(stream, b'H', &body)?;
let n = rows.len();
for row in &rows {
let mut line = String::new();
for (i, v) in row.values.iter().enumerate() {
if i > 0 {
line.push('\t');
}
line.push_str(&encode_copy_cell(v));
}
line.push('\n');
send_msg(stream, b'd', line.as_bytes())?;
}
send_msg(stream, b'c', &[])?; send_command_complete(stream, &format!("COPY {n}"))?;
let _ = tx_state;
Ok(())
}
/// Renders one value as a COPY text-format cell.
///
/// NULL is `\N` and booleans are `t`/`f`. Text, JSON and vector renderings are
/// escaped with `escape_copy_cell`. Numeric, date, timestamp and interval
/// values use the engine's `spg_engine::eval` formatters. Vectors are written
/// as `[a, b, c]`, with quantized and half-precision vectors expanded to f32
/// first. Any other variant falls back to its escaped `Debug` form.
fn encode_copy_cell(v: &spg_storage::Value) -> String {
    use spg_storage::Value;

    // Joins vector components as `[a, b, c]`.
    fn bracketed<I>(items: I) -> String
    where
        I: IntoIterator,
        I::Item: ToString,
    {
        let parts: Vec<String> = items.into_iter().map(|x| x.to_string()).collect();
        format!("[{}]", parts.join(", "))
    }

    match v {
        Value::Null => String::from("\\N"),
        Value::Bool(flag) => String::from(if *flag { "t" } else { "f" }),
        Value::SmallInt(n) => n.to_string(),
        Value::Int(n) => n.to_string(),
        Value::BigInt(n) => n.to_string(),
        Value::Float(x) => format!("{x}"),
        Value::Text(s) | Value::Json(s) => escape_copy_cell(s),
        Value::Numeric { scaled, scale } => spg_engine::eval::format_numeric(*scaled, *scale),
        Value::Date(d) => spg_engine::eval::format_date(*d),
        Value::Timestamp(t) => spg_engine::eval::format_timestamp(*t),
        Value::Interval { months, micros } => spg_engine::eval::format_interval(*months, *micros),
        Value::Vector(items) => escape_copy_cell(&bracketed(items.iter())),
        Value::Sq8Vector(q) => {
            escape_copy_cell(&bracketed(spg_storage::quantize::dequantize(q).iter()))
        }
        Value::HalfVector(h) => escape_copy_cell(&bracketed(h.to_f32_vec().iter())),
        _ => escape_copy_cell(&format!("{v:?}")),
    }
}
/// Escapes text for a COPY text-format cell.
///
/// Backslash, TAB, LF, CR, backspace, form feed and vertical tab are written
/// as `\\ \t \n \r \b \f \v`; every other character is copied unchanged.
fn escape_copy_cell(s: &str) -> String {
    s.chars().fold(String::with_capacity(s.len()), |mut out, ch| {
        let escape = match ch {
            '\\' => "\\\\",
            '\t' => "\\t",
            '\n' => "\\n",
            '\r' => "\\r",
            '\u{08}' => "\\b",
            '\u{0c}' => "\\f",
            '\u{0b}' => "\\v",
            _ => {
                out.push(ch);
                return out;
            }
        };
        out.push_str(escape);
        out
    })
}
/// Authenticates `user` with a cleartext password exchange.
///
/// Sends AuthenticationCleartextPassword (auth code 3), reads the client's
/// PasswordMessage and checks it with the engine's `verify_user`. On success it
/// returns the user's role. On failure (including a poisoned engine lock) it
/// sends a 28P01 ErrorResponse and returns `Ok(None)`.
fn cleartext_auth(
    stream: &mut TcpStream,
    state: &Arc<ServerState>,
    user: &str,
) -> std::io::Result<Option<Role>> {
    const AUTH_CLEARTEXT_PASSWORD: u32 = 3;
    send_msg(stream, b'R', &AUTH_CLEARTEXT_PASSWORD.to_be_bytes())?;
    let password = read_password_message(stream)?;
    let role = match state.engine.read() {
        Ok(engine) => engine.verify_user(user, &password),
        Err(_) => None,
    };
    if role.is_none() {
        send_error(stream, "28P01", "password authentication failed")?;
    }
    Ok(role)
}
/// Authenticates `user` with SCRAM-SHA-256 over the SASL message flow.
///
/// Flow:
/// 1. Send AuthenticationSASL (10) advertising `SCRAM-SHA-256`.
/// 2. Read SASLInitialResponse and parse the client-first message.
/// 3. Look up the user's stored SCRAM verifier.
/// 4. Send AuthenticationSASLContinue (11) carrying the server-first message
///    with a combined nonce.
/// 5. Read SASLResponse, check the nonce and verify the client proof.
/// 6. Send AuthenticationSASLFinal (12) with the server signature.
/// 7. Return the user's role.
///
/// Any failure is reported with an ErrorResponse and yields `Ok(None)`.
/// Message bodies from the (still unauthenticated) client are capped at 65535
/// bytes, so a forged 32-bit length field cannot force a huge allocation.
fn scram_auth(
    stream: &mut TcpStream,
    state: &Arc<ServerState>,
    user: &str,
) -> std::io::Result<Option<Role>> {
    // Same 65535-byte limit PostgreSQL applies to authentication messages.
    const MAX_SASL_BODY: usize = 65_535;
    let mut sasl_body = Vec::new();
    sasl_body.extend_from_slice(&10u32.to_be_bytes());
    sasl_body.extend_from_slice(b"SCRAM-SHA-256\0\0");
    send_msg(stream, b'R', &sasl_body)?;
    let mut header = [0u8; 5];
    stream.read_exact(&mut header)?;
    if header[0] != b'p' {
        send_error(stream, "28000", "expected SASLInitialResponse")?;
        return Ok(None);
    }
    let len = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
    if len.saturating_sub(4) > MAX_SASL_BODY {
        send_error(stream, "08P01", "SASLInitialResponse too large")?;
        return Ok(None);
    }
    let mut body = vec![0u8; len.saturating_sub(4)];
    stream.read_exact(&mut body)?;
    let Some(mech_end) = body.iter().position(|&b| b == 0) else {
        send_error(
            stream,
            "28000",
            "SASLInitial: mechanism name not null-terminated",
        )?;
        return Ok(None);
    };
    let mech = std::str::from_utf8(&body[..mech_end]).unwrap_or("");
    if mech != "SCRAM-SHA-256" {
        send_error(
            stream,
            "28000",
            &format!("only SCRAM-SHA-256 is supported, got {mech:?}"),
        )?;
        return Ok(None);
    }
    let mut cur = mech_end + 1;
    if cur + 4 > body.len() {
        send_error(stream, "28000", "SASLInitial: missing client-first length")?;
        return Ok(None);
    }
    let cf_len =
        u32::from_be_bytes([body[cur], body[cur + 1], body[cur + 2], body[cur + 3]]) as usize;
    cur += 4;
    if cur + cf_len > body.len() {
        send_error(stream, "28000", "SASLInitial: client-first truncated")?;
        return Ok(None);
    }
    let Ok(client_first_msg) = std::str::from_utf8(&body[cur..cur + cf_len]).map(str::to_string)
    else {
        send_error(stream, "28000", "SASLInitial: client-first not UTF-8")?;
        return Ok(None);
    };
    let client_first = match crate::scram::parse_client_first(&client_first_msg) {
        Ok(c) => c,
        Err(e) => {
            send_error(stream, "28000", &e.to_string())?;
            return Ok(None);
        }
    };
    let secrets = state
        .engine
        .read()
        .ok()
        .and_then(|e| {
            e.users()
                .iter()
                .find(|(n, _)| *n == user)
                .map(|(_, r)| r.scram().cloned())
        })
        .flatten();
    let Some(secrets) = secrets else {
        send_error(stream, "28P01", "user has no SCRAM verifier on file")?;
        return Ok(None);
    };
    let server_nonce = match random_nonce_b64(18) {
        Ok(n) => n,
        Err(e) => {
            send_error(stream, "58000", &format!("RNG failure: {e}"))?;
            return Ok(None);
        }
    };
    let combined_nonce = format!("{}{}", client_first.client_nonce, server_nonce);
    let server_first = crate::scram::build_server_first(&combined_nonce, &secrets);
    let mut cont_body = Vec::new();
    cont_body.extend_from_slice(&11u32.to_be_bytes());
    cont_body.extend_from_slice(server_first.as_bytes());
    send_msg(stream, b'R', &cont_body)?;
    let mut header = [0u8; 5];
    stream.read_exact(&mut header)?;
    if header[0] != b'p' {
        send_error(stream, "28000", "expected SASLResponse")?;
        return Ok(None);
    }
    let len = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
    if len.saturating_sub(4) > MAX_SASL_BODY {
        send_error(stream, "08P01", "SASLResponse too large")?;
        return Ok(None);
    }
    let mut body = vec![0u8; len.saturating_sub(4)];
    stream.read_exact(&mut body)?;
    let Ok(client_final_msg) = std::str::from_utf8(&body).map(str::to_string) else {
        send_error(stream, "28000", "SASLResponse: client-final not UTF-8")?;
        return Ok(None);
    };
    let client_final = match crate::scram::parse_client_final(&client_final_msg) {
        Ok(f) => f,
        Err(e) => {
            send_error(stream, "28000", &e.to_string())?;
            return Ok(None);
        }
    };
    if client_final.combined_nonce != combined_nonce {
        send_error(stream, "28000", "SCRAM: nonce mismatch")?;
        return Ok(None);
    }
    let server_signature = match crate::scram::verify_and_sign(
        &secrets,
        &client_first.bare,
        &server_first,
        &client_final.without_proof,
        &client_final.client_proof,
    ) {
        Ok(s) => s,
        Err(e) => {
            send_error(stream, "28P01", &e.to_string())?;
            return Ok(None);
        }
    };
    let mut final_body = Vec::new();
    final_body.extend_from_slice(&12u32.to_be_bytes());
    final_body.extend_from_slice(server_signature.as_bytes());
    send_msg(stream, b'R', &final_body)?;
    let role = state.engine.read().ok().and_then(|e| {
        e.users()
            .iter()
            .find(|(n, _)| *n == user)
            .map(|(_, r)| r.role)
    });
    Ok(role)
}
/// Returns `byte_len` random bytes from `/dev/urandom`, base64-encoded.
///
/// NOTE(review): `/dev/urandom` only exists on Unix-like systems; elsewhere
/// this returns the I/O error from opening it.
fn random_nonce_b64(byte_len: usize) -> std::io::Result<String> {
    let mut raw = vec![0u8; byte_len];
    let mut urandom = std::fs::File::open("/dev/urandom")?;
    urandom.read_exact(&mut raw)?;
    Ok(spg_crypto::base64::encode(&raw))
}
fn read_startup(stream: &mut TcpStream) -> std::io::Result<(String, Vec<(String, String)>)> {
loop {
let mut len_bytes = [0u8; 4];
stream.read_exact(&mut len_bytes)?;
let total = u32::from_be_bytes(len_bytes) as usize;
if total < 8 {
return Err(std::io::Error::other("startup message too short"));
}
let mut body = vec![0u8; total - 4];
stream.read_exact(&mut body)?;
let proto = u32::from_be_bytes([body[0], body[1], body[2], body[3]]);
if proto == 80877103 {
stream.write_all(b"N")?;
continue;
}
if proto == 80877104 {
stream.write_all(b"N")?;
continue;
}
if proto != PROTOCOL_V3 {
return Err(std::io::Error::other(format!(
"unsupported protocol version: {proto}"
)));
}
let mut params = Vec::new();
let mut user = String::new();
let mut p = 4;
while p < body.len() {
let k_end = body[p..]
.iter()
.position(|&b| b == 0)
.ok_or_else(|| std::io::Error::other("startup key not null-terminated"))?;
let key = std::str::from_utf8(&body[p..p + k_end])
.map_err(|_| std::io::Error::other("startup key not UTF-8"))?
.to_string();
p += k_end + 1;
if key.is_empty() {
break;
}
let v_end = body[p..]
.iter()
.position(|&b| b == 0)
.ok_or_else(|| std::io::Error::other("startup value not null-terminated"))?;
let value = std::str::from_utf8(&body[p..p + v_end])
.map_err(|_| std::io::Error::other("startup value not UTF-8"))?
.to_string();
p += v_end + 1;
if key == "user" {
user = value.clone();
}
params.push((key, value));
}
return Ok((user, params));
}
}
/// Reads a PasswordMessage ('p') and returns the password text.
///
/// A single trailing NUL terminator is stripped. Bodies larger than 65535
/// bytes are rejected before allocating, because the length field comes from
/// a client that has not yet authenticated. Any `Read` implementor is
/// accepted, so callers passing `&mut TcpStream` are unaffected.
///
/// # Errors
/// I/O errors, a message type other than 'p', an oversized body, or a
/// password that is not valid UTF-8.
fn read_password_message<R: Read + ?Sized>(stream: &mut R) -> std::io::Result<String> {
    const MAX_PASSWORD_BODY: usize = 65_535;
    let mut header = [0u8; 5];
    stream.read_exact(&mut header)?;
    if header[0] != b'p' {
        return Err(std::io::Error::other("expected PasswordMessage"));
    }
    let len = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
    let body_len = len.saturating_sub(4);
    if body_len > MAX_PASSWORD_BODY {
        return Err(std::io::Error::other("PasswordMessage too large"));
    }
    let mut body = vec![0u8; body_len];
    stream.read_exact(&mut body)?;
    let pw = body.strip_suffix(b"\0").unwrap_or(&body);
    std::str::from_utf8(pw)
        .map(str::to_string)
        .map_err(|_| std::io::Error::other("password not UTF-8"))
}
/// Writes one backend message: the type byte, a big-endian i32 length that
/// counts itself plus the body, and then the body, in a single `write_all`.
///
/// # Errors
/// Fails if the framed length does not fit in a `u32`, or if the write fails.
fn send_msg(stream: &mut dyn Write, ty: u8, body: &[u8]) -> std::io::Result<()> {
    let Ok(len) = u32::try_from(body.len() + 4) else {
        return Err(std::io::Error::other("PG message body too large"));
    };
    let frame: Vec<u8> = std::iter::once(ty)
        .chain(len.to_be_bytes())
        .chain(body.iter().copied())
        .collect();
    stream.write_all(&frame)
}
/// Sends ParameterStatus ('S') with the body `key\0value\0`.
fn send_parameter_status(stream: &mut dyn Write, key: &str, value: &str) -> std::io::Result<()> {
    let body: Vec<u8> = key
        .bytes()
        .chain([0])
        .chain(value.bytes())
        .chain([0])
        .collect();
    send_msg(stream, b'S', &body)
}
/// Sends ReadyForQuery ('Z') carrying the one-byte transaction status.
fn send_ready_for_query(stream: &mut dyn Write, state: u8) -> std::io::Result<()> {
    let status = [state];
    send_msg(stream, b'Z', &status)
}
/// Sends CommandComplete ('C') with the NUL-terminated command tag.
fn send_command_complete(stream: &mut dyn Write, tag: &str) -> std::io::Result<()> {
    let body: Vec<u8> = tag.bytes().chain(std::iter::once(0)).collect();
    send_msg(stream, b'C', &body)
}
/// Sends ErrorResponse ('E') with three fields, each as a code byte plus a
/// NUL-terminated string, followed by a final NUL:
/// * severity `S` = `ERROR`
/// * SQLSTATE `C`
/// * message `M`
fn send_error(stream: &mut dyn Write, sqlstate: &str, msg: &str) -> std::io::Result<()> {
    let fields: [(u8, &str); 3] = [(b'S', "ERROR"), (b'C', sqlstate), (b'M', msg)];
    let mut body = Vec::new();
    for (code, text) in fields {
        body.push(code);
        body.extend_from_slice(text.as_bytes());
        body.push(0);
    }
    // Terminates the field list.
    body.push(0);
    send_msg(stream, b'E', &body)
}
/// Sends RowDescription ('T') for `cols`, every field in text format.
///
/// # Errors
/// Fails if there are more than `u16::MAX` columns, or if the write fails.
fn send_row_description(stream: &mut dyn Write, cols: &[ColumnSchema]) -> std::io::Result<()> {
    let Ok(field_count) = u16::try_from(cols.len()) else {
        return Err(std::io::Error::other("RowDescription: too many columns"));
    };
    let mut body = Vec::with_capacity(2 + cols.len() * 24);
    body.extend_from_slice(&field_count.to_be_bytes());
    for col in cols {
        body.extend_from_slice(col.name.as_bytes());
        body.push(0);
        // Table OID (0: not a table column).
        body.extend_from_slice(&0u32.to_be_bytes());
        // Column attribute number.
        body.extend_from_slice(&0u16.to_be_bytes());
        // Data type OID.
        body.extend_from_slice(&pg_type_oid(col.ty).to_be_bytes());
        // Data type size.
        body.extend_from_slice(&pg_type_len(col.ty).to_be_bytes());
        // Type modifier (-1: none).
        body.extend_from_slice(&(-1i32).to_be_bytes());
        // Format code (0: text).
        body.extend_from_slice(&0u16.to_be_bytes());
    }
    send_msg(stream, b'T', &body)
}
/// Sends DataRow ('D') for `row`, rendering each cell with `value_to_pg_text`.
///
/// NULL cells are sent with length -1. The column type for cell *i* is taken
/// from `cols[i]` when present.
///
/// # Errors
/// Fails on more than `u16::MAX` cells, a cell longer than `i32::MAX` bytes,
/// or a write failure.
fn send_data_row(stream: &mut dyn Write, cols: &[ColumnSchema], row: &Row) -> std::io::Result<()> {
    let Ok(cell_count) = u16::try_from(row.values.len()) else {
        return Err(std::io::Error::other("DataRow: too many cells"));
    };
    let mut body = Vec::with_capacity(2 + row.values.len() * 8);
    body.extend_from_slice(&cell_count.to_be_bytes());
    for (idx, value) in row.values.iter().enumerate() {
        let col_ty = cols.get(idx).map(|c| c.ty);
        if let Some(text) = value_to_pg_text(value, col_ty) {
            let Ok(text_len) = i32::try_from(text.len()) else {
                return Err(std::io::Error::other("cell value too large"));
            };
            body.extend_from_slice(&text_len.to_be_bytes());
            body.extend_from_slice(text.as_bytes());
        } else {
            body.extend_from_slice(&(-1i32).to_be_bytes());
        }
    }
    send_msg(stream, b'D', &body)
}
/// Maps a storage column type to the PostgreSQL type OID advertised in
/// RowDescription (arms ordered by OID).
///
/// Character types and vectors are all advertised as `text` (25).
const fn pg_type_oid(ty: DataType) -> u32 {
    match ty {
        DataType::Bool => 16,
        DataType::BigInt => 20,
        DataType::SmallInt => 21,
        DataType::Int => 23,
        DataType::Text | DataType::Char(_) | DataType::Varchar(_) | DataType::Vector { .. } => 25,
        DataType::Json => 114,
        DataType::Float => 701,
        DataType::Date => 1082,
        DataType::Timestamp => 1114,
        DataType::Interval => 1186,
        DataType::Numeric { .. } => 1700,
    }
}
/// Returns the type size reported in RowDescription.
///
/// Fixed-width types report their byte width; every other type reports -1
/// (variable length).
const fn pg_type_len(ty: DataType) -> i16 {
    match ty {
        DataType::Interval => 16,
        DataType::BigInt | DataType::Float | DataType::Timestamp => 8,
        DataType::Int | DataType::Date => 4,
        DataType::SmallInt => 2,
        DataType::Bool => 1,
        _ => -1,
    }
}
fn value_to_pg_text(v: &Value, _ty: Option<DataType>) -> Option<String> {
Some(match v {
Value::Null => return None,
Value::Bool(b) => if *b { "t" } else { "f" }.to_string(),
Value::SmallInt(n) => n.to_string(),
Value::Int(n) => n.to_string(),
Value::BigInt(n) => n.to_string(),
Value::Float(f) => format!("{f}"),
Value::Text(s) | Value::Json(s) => s.clone(),
Value::Timestamp(micros) => format_timestamp(*micros),
Value::Date(days) => format_date(*days),
Value::Interval { months, micros } => format!("P{months}M{micros}U"),
Value::Numeric { scaled, scale } => format_numeric(*scaled, *scale),
Value::Vector(v) => {
let parts: Vec<String> = v.iter().map(std::string::ToString::to_string).collect();
format!("[{}]", parts.join(", "))
}
Value::Sq8Vector(q) => {
let parts: Vec<String> = spg_storage::quantize::dequantize(q)
.iter()
.map(std::string::ToString::to_string)
.collect();
format!("[{}]", parts.join(", "))
}
Value::HalfVector(h) => {
let parts: Vec<String> = h
.to_f32_vec()
.iter()
.map(std::string::ToString::to_string)
.collect();
format!("[{}]", parts.join(", "))
}
_ => format!("{v:?}"),
})
}
/// Formats microseconds since the Unix epoch as `YYYY-MM-DD HH:MM:SS`.
///
/// A `.ffffff` suffix (always six digits) is added only when the fractional
/// part is non-zero. No time-zone conversion or suffix is applied.
fn format_timestamp(micros: i64) -> String {
    let whole_secs = micros.div_euclid(1_000_000);
    let sub_micros = micros.rem_euclid(1_000_000) as u32;
    let (year, month, day, hour, minute, second) = secs_to_ymdhms(whole_secs);
    let base = format!("{year:04}-{month:02}-{day:02} {hour:02}:{minute:02}:{second:02}");
    if sub_micros == 0 {
        base
    } else {
        format!("{base}.{sub_micros:06}")
    }
}
/// Formats a day count relative to 1970-01-01 as `YYYY-MM-DD`.
fn format_date(days: i32) -> String {
    let (year, month, day, ..) = secs_to_ymdhms(i64::from(days) * 86_400);
    format!("{year:04}-{month:02}-{day:02}")
}
/// Converts seconds since 1970-01-01T00:00:00 into
/// `(year, month, day, hour, minute, second)` in the proleptic Gregorian
/// calendar.
///
/// Uses the days-to-civil algorithm over 400-year eras, counting months from
/// March so the leap day falls at the end of the year. Negative inputs work
/// because Euclidean division keeps the time of day in `0..86400`.
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
fn secs_to_ymdhms(secs: i64) -> (i32, u32, u32, u32, u32, u32) {
    let days = secs.div_euclid(86_400);
    let secs_of_day = secs.rem_euclid(86_400) as u32;
    let hour = secs_of_day / 3600;
    let minute = secs_of_day % 3600 / 60;
    let second = secs_of_day % 60;

    // Shift the epoch to 0000-03-01 and split into 400-year eras.
    let shifted = days + 719_468;
    let era = shifted.div_euclid(146_097);
    let day_of_era = shifted.rem_euclid(146_097) as u32;
    let year_of_era =
        (day_of_era - day_of_era / 1460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    // 0 = March, ..., 11 = February.
    let march_month = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * march_month + 2) / 5 + 1;
    let month = if march_month < 10 { march_month + 3 } else { march_month - 9 };
    // January and February belong to the following civil year.
    let year = year_of_era as i32 + (era as i32) * 400 + i32::from(month <= 2);
    (year, month, day, hour, minute, second)
}
/// Formats a fixed-point decimal stored as `scaled / 10^scale`.
///
/// Exactly `scale` fractional digits are always printed (`(100, 2)` gives
/// `"1.00"`, `(-5, 2)` gives `"-0.05"`); with `scale == 0` the integer is
/// printed unchanged.
///
/// Uses `unsigned_abs`, because `i128::MIN.abs()` overflows (it panics in
/// debug builds).
fn format_numeric(scaled: i128, scale: u8) -> String {
    if scale == 0 {
        return scaled.to_string();
    }
    let digits = scaled.unsigned_abs().to_string();
    let scale = usize::from(scale);
    let (int_part, frac_part) = if digits.len() > scale {
        digits.split_at(digits.len() - scale)
    } else {
        ("0", digits.as_str())
    };
    let sign = if scaled < 0 { "-" } else { "" };
    // Left-pad the fraction with zeros to exactly `scale` digits.
    format!("{sign}{int_part}.{frac_part:0>scale$}")
}