Skip to main content

reddb_server/wire/postgres/
server.rs

//! PostgreSQL wire-protocol listener (Phase 3.1 PG parity).
//!
//! Accepts TCP connections from PG-compatible clients, drives the startup
//! handshake, and routes simple-query frames into the existing
//! `RedDBRuntime::execute_query` path. Results are adapted back into PG
//! `RowDescription` + `DataRow` frames via `types::value_to_pg_wire_bytes`.
//!
//! Phase 3.1 intentionally supports only the simple-query subset; extended
//! query (Parse/Bind/Execute) arrives in 3.1.x once the prepared-statement
//! registry is reusable from this layer.
11
12use std::sync::Arc;
13
14use tokio::io::{AsyncRead, AsyncWrite};
15use tokio::net::TcpListener;
16
17use super::catalog_views::translate_pg_catalog_query;
18use super::protocol::{
19    read_frame, read_startup, write_frame, write_raw_byte, BackendMessage, ColumnDescriptor,
20    FrontendMessage, PgWireError, TransactionStatus,
21};
22use super::types::{value_to_pg_wire_bytes, PgOid};
23use crate::runtime::RedDBRuntime;
24use crate::storage::query::unified::UnifiedRecord;
25use crate::storage::schema::Value;
26
/// Settings for the PG wire listener, fixed when the listener starts.
#[derive(Clone, Debug)]
pub struct PgWireConfig {
    /// Address the TCP listener binds to, written as "host:port". The
    /// caller must choose one that does not collide with the native
    /// wire, gRPC, or HTTP listeners.
    pub bind_addr: String,
    /// Value reported to clients in the `server_version`
    /// `ParameterStatus` frame. Drivers often inspect it to switch
    /// features on or off, so it advertises a recent PG release.
    pub server_version: String,
}

impl Default for PgWireConfig {
    /// Binds to `127.0.0.1:5432` and advertises `15.0 (RedDB 3.1)`.
    fn default() -> Self {
        let bind_addr = String::from("127.0.0.1:5432");
        let server_version = String::from("15.0 (RedDB 3.1)");
        Self {
            bind_addr,
            server_version,
        }
    }
}
47
48/// Spawn the PG wire listener. Blocks until the listener errors out.
49/// Each connection is handled in its own tokio task.
50pub async fn start_pg_wire_listener(
51    config: PgWireConfig,
52    runtime: Arc<RedDBRuntime>,
53) -> Result<(), Box<dyn std::error::Error>> {
54    let listener = TcpListener::bind(&config.bind_addr).await?;
55    tracing::info!(
56        transport = "pg-wire",
57        bind = %config.bind_addr,
58        "listener online"
59    );
60    let cfg = Arc::new(config);
61    loop {
62        let (stream, peer) = listener.accept().await?;
63        let rt = Arc::clone(&runtime);
64        let cfg = Arc::clone(&cfg);
65        let peer_str = peer.to_string();
66        tokio::spawn(async move {
67            if let Err(e) = handle_connection(stream, rt, cfg).await {
68                tracing::warn!(
69                    transport = "pg-wire",
70                    peer = %peer_str,
71                    err = %e,
72                    "connection failed"
73                );
74            }
75        });
76    }
77}
78
79/// Drive one connection's lifetime: startup → authentication → query loop.
80pub(crate) async fn handle_connection<S>(
81    mut stream: S,
82    runtime: Arc<RedDBRuntime>,
83    config: Arc<PgWireConfig>,
84) -> Result<(), PgWireError>
85where
86    S: AsyncRead + AsyncWrite + Unpin + Send,
87{
88    // Handshake. The first frame may be SSLRequest / GSSENCRequest
89    // (pre-auth negotiation) or a plain Startup. Loop once to cover
90    // SSL-not-supported path: reply 'N' and expect the client to send
91    // a regular Startup next.
92    loop {
93        match read_startup(&mut stream).await? {
94            FrontendMessage::SslRequest | FrontendMessage::GssEncRequest => {
95                // 'N' = not supported — client continues in plaintext and
96                // re-sends a normal Startup on the same socket.
97                write_raw_byte(&mut stream, b'N').await?;
98                continue;
99            }
100            FrontendMessage::Startup(params) => {
101                send_auth_ok(&mut stream, &config, &params).await?;
102                break;
103            }
104            FrontendMessage::Unknown { .. } => {
105                // CancelRequest: no response expected; drop the socket.
106                return Ok(());
107            }
108            other => {
109                return Err(PgWireError::Protocol(format!(
110                    "unexpected startup frame: {other:?}"
111                )));
112            }
113        }
114    }
115
116    // Main query loop.
117    loop {
118        let frame = match read_frame(&mut stream).await {
119            Ok(f) => f,
120            Err(PgWireError::Eof) => return Ok(()),
121            Err(e) => return Err(e),
122        };
123
124        match frame {
125            FrontendMessage::Query(sql) => {
126                handle_simple_query(&mut stream, &runtime, &sql).await?;
127            }
128            FrontendMessage::Terminate => return Ok(()),
129            FrontendMessage::Sync | FrontendMessage::Flush => {
130                // These are part of the extended protocol. For simple-query
131                // sessions we still echo ReadyForQuery so robust clients
132                // (that mix S/H frames defensively) keep moving.
133                write_frame(
134                    &mut stream,
135                    &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
136                )
137                .await?;
138            }
139            FrontendMessage::PasswordMessage(_) => {
140                // Should only arrive during auth. Ignore post-auth.
141                continue;
142            }
143            FrontendMessage::Unknown { tag, .. } => {
144                send_error(
145                    &mut stream,
146                    "0A000",
147                    &format!("unsupported frame tag 0x{tag:02x}"),
148                )
149                .await?;
150                write_frame(
151                    &mut stream,
152                    &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
153                )
154                .await?;
155            }
156            other => {
157                send_error(
158                    &mut stream,
159                    "0A000",
160                    &format!("unsupported frame {other:?}"),
161                )
162                .await?;
163                write_frame(
164                    &mut stream,
165                    &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
166                )
167                .await?;
168            }
169        }
170    }
171}
172
173async fn send_auth_ok<S>(
174    stream: &mut S,
175    config: &PgWireConfig,
176    params: &super::protocol::StartupParams,
177) -> Result<(), PgWireError>
178where
179    S: AsyncWrite + Unpin,
180{
181    // Phase 3.1: trust auth. We always send AuthenticationOk.
182    write_frame(stream, &BackendMessage::AuthenticationOk).await?;
183
184    // Standard ParameterStatus frames. Drivers gate capabilities on these.
185    for (name, value) in [
186        ("server_version", config.server_version.as_str()),
187        ("server_encoding", "UTF8"),
188        ("client_encoding", "UTF8"),
189        ("DateStyle", "ISO, MDY"),
190        ("TimeZone", "UTC"),
191        ("integer_datetimes", "on"),
192        ("standard_conforming_strings", "on"),
193        (
194            "application_name",
195            params.get("application_name").unwrap_or(""),
196        ),
197    ] {
198        write_frame(
199            stream,
200            &BackendMessage::ParameterStatus {
201                name: name.to_string(),
202                value: value.to_string(),
203            },
204        )
205        .await?;
206    }
207
208    // BackendKeyData: (pid, secret_key). Used by CancelRequest; we don't
209    // honour cancels in 3.1 so random-ish values are fine.
210    write_frame(
211        stream,
212        &BackendMessage::BackendKeyData {
213            pid: std::process::id(),
214            key: 0xDEADBEEF,
215        },
216    )
217    .await?;
218
219    write_frame(
220        stream,
221        &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
222    )
223    .await?;
224    Ok(())
225}
226
227async fn handle_simple_query<S>(
228    stream: &mut S,
229    runtime: &RedDBRuntime,
230    sql: &str,
231) -> Result<(), PgWireError>
232where
233    S: AsyncWrite + Unpin,
234{
235    // Empty query convention: PG emits EmptyQueryResponse instead of a
236    // CommandComplete. Some clients (psql `\;`) rely on this.
237    if sql.trim().is_empty() {
238        write_frame(stream, &BackendMessage::EmptyQueryResponse).await?;
239        write_frame(
240            stream,
241            &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
242        )
243        .await?;
244        return Ok(());
245    }
246
247    let query_result = match translate_pg_catalog_query(runtime, sql) {
248        Ok(Some(result)) => Ok(crate::runtime::RuntimeQueryResult {
249            query: sql.to_string(),
250            mode: crate::storage::query::modes::QueryMode::Sql,
251            statement: "select",
252            engine: "pg-catalog",
253            result,
254            affected_rows: 0,
255            statement_type: "select",
256        }),
257        Ok(None) => runtime.execute_query(sql),
258        Err(err) => Err(err),
259    };
260
261    match query_result {
262        Ok(result) => {
263            if result.statement_type == "select" {
264                emit_result_rows(stream, &result.result).await?;
265                write_frame(
266                    stream,
267                    &BackendMessage::CommandComplete(format!(
268                        "SELECT {}",
269                        result.result.records.len()
270                    )),
271                )
272                .await?;
273            } else {
274                // DDL / DML / config statements: echo the runtime's
275                // high-level statement tag back. PG format is
276                // "<CMD> [<OID>] <COUNT>"; we keep the count where
277                // applicable and fall back to the runtime's message.
278                let tag = match result.statement_type {
279                    "insert" => format!("INSERT 0 {}", result.affected_rows),
280                    "update" => format!("UPDATE {}", result.affected_rows),
281                    "delete" => format!("DELETE {}", result.affected_rows),
282                    other => other.to_uppercase(),
283                };
284                write_frame(stream, &BackendMessage::CommandComplete(tag)).await?;
285            }
286        }
287        Err(err) => {
288            // PG SQLSTATE class 42 covers syntax / binding errors; we use
289            // 42P01 (undefined_table) and 42601 (syntax_error) when we can
290            // detect; otherwise fall back to XX000 (internal error).
291            let code = classify_sqlstate(&err.to_string());
292            send_error(stream, code, &err.to_string()).await?;
293        }
294    }
295
296    write_frame(
297        stream,
298        &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
299    )
300    .await?;
301    Ok(())
302}
303
304async fn emit_result_rows<S>(
305    stream: &mut S,
306    result: &crate::storage::query::unified::UnifiedResult,
307) -> Result<(), PgWireError>
308where
309    S: AsyncWrite + Unpin,
310{
311    // RowDescription: derived from the first record's column ordering.
312    // When `result.columns` is non-empty we honour that order; otherwise
313    // we synthesise one from the record's field order.
314    let columns: Vec<String> = if !result.columns.is_empty() {
315        result.columns.clone()
316    } else if let Some(first) = result.records.first() {
317        record_field_names(first)
318    } else {
319        Vec::new()
320    };
321
322    // Peek at the first record for per-column type OIDs. When there's no
323    // data row we fall back to TEXT for every column — clients render
324    // empty result sets happily.
325    let type_oids: Vec<PgOid> = columns
326        .iter()
327        .map(|col| {
328            result
329                .records
330                .first()
331                .and_then(|r| record_get(r, col))
332                .map(PgOid::from_value)
333                .unwrap_or(PgOid::Text)
334        })
335        .collect();
336
337    let descriptors: Vec<ColumnDescriptor> = columns
338        .iter()
339        .zip(type_oids.iter())
340        .map(|(name, oid)| ColumnDescriptor {
341            name: name.clone(),
342            table_oid: 0,
343            column_attr: 0,
344            type_oid: oid.as_u32(),
345            type_size: -1,
346            type_mod: -1,
347            format: 0,
348        })
349        .collect();
350
351    write_frame(stream, &BackendMessage::RowDescription(descriptors)).await?;
352
353    for record in &result.records {
354        let fields: Vec<Option<Vec<u8>>> = columns
355            .iter()
356            .map(|col| record_get(record, col).and_then(value_to_pg_wire_bytes))
357            .collect();
358        write_frame(stream, &BackendMessage::DataRow(fields)).await?;
359    }
360
361    Ok(())
362}
363
364/// Best-effort field lookup on a `UnifiedRecord`. The record API lives in
365/// `storage::query::unified` and today uses `HashMap<String, Value>` under
366/// the hood — we use `get` if it exists, else fall back to serialised map.
367fn record_get<'a>(record: &'a UnifiedRecord, key: &str) -> Option<&'a Value> {
368    record.get(key)
369}
370
371/// Extract column names in iteration order from a single record. When
372/// the caller didn't supply an explicit `columns` projection we use the
373/// first record's field ordering as the canonical tuple shape.
374///
375/// HashMap iteration order is non-deterministic — for Phase 3.1 we
376/// accept the shuffle since PG clients receive the ordered header via
377/// RowDescription and match cells positionally. A stable ordering
378/// would require keeping an insertion-order index alongside `values`.
379fn record_field_names(record: &UnifiedRecord) -> Vec<String> {
380    // `column_names()` merges the columnar scan side-channel with
381    // the HashMap so scan rows (which populate only columnar) still
382    // surface their field names in PG wire output.
383    record
384        .column_names()
385        .into_iter()
386        .map(|k| k.to_string())
387        .collect()
388}
389
390async fn send_error<S>(stream: &mut S, code: &str, message: &str) -> Result<(), PgWireError>
391where
392    S: AsyncWrite + Unpin,
393{
394    write_frame(
395        stream,
396        &BackendMessage::ErrorResponse {
397            severity: "ERROR".to_string(),
398            code: code.to_string(),
399            message: message.to_string(),
400        },
401    )
402    .await
403}
404
/// Heuristically map a runtime error message onto a PG SQLSTATE.
///
/// Matching is case-insensitive (ASCII) substring search, checked in
/// this order; the first hit wins:
///
/// | message contains                    | SQLSTATE | meaning                             |
/// |-------------------------------------|----------|-------------------------------------|
/// | `not found`, `does not exist`       | `42P01`  | undefined_table                     |
/// | `parse`, `expected`, `syntax`       | `42601`  | syntax_error                        |
/// | `already exists`                    | `42P07`  | duplicate_table                     |
/// | `permission`                        | `42501`  | insufficient_privilege              |
/// | `auth`                              | `28000`  | invalid_authorization_specification |
/// | anything else                       | `XX000`  | internal_error                      |
///
/// Full coverage would map every `RedDBError` variant; this is enough
/// for the common psql / JDBC paths.
fn classify_sqlstate(msg: &str) -> &'static str {
    let lower = msg.to_ascii_lowercase();
    let mentions = |needles: &[&str]| needles.iter().any(|needle| lower.contains(*needle));

    if mentions(&["not found", "does not exist"]) {
        // 42P01 undefined_table; close enough for collection-not-found.
        "42P01"
    } else if mentions(&["parse", "expected", "syntax"]) {
        "42601"
    } else if mentions(&["already exists"]) {
        "42P07"
    } else if mentions(&["permission"]) {
        // A denied operation on an established session is a privilege
        // problem (class 42), not a connection-authorization failure
        // (class 28).
        "42501"
    } else if mentions(&["auth"]) {
        "28000"
    } else {
        "XX000"
    }
}
422}