reddb_server/wire/postgres/
server.rs1use std::sync::Arc;
13
14use tokio::io::{AsyncRead, AsyncWrite};
15use tokio::net::TcpListener;
16
17use super::catalog_views::translate_pg_catalog_query;
18use super::protocol::{
19 read_frame, read_startup, write_frame, write_raw_byte, BackendMessage, ColumnDescriptor,
20 FrontendMessage, PgWireError, TransactionStatus,
21};
22use super::types::{value_to_pg_wire_bytes, PgOid};
23use crate::runtime::RedDBRuntime;
24use crate::storage::query::unified::UnifiedRecord;
25use crate::storage::schema::Value;
26
/// Settings for the PostgreSQL wire-protocol listener.
#[derive(Clone, Debug)]
pub struct PgWireConfig {
    /// Socket address the listener binds to, e.g. `"127.0.0.1:5432"`.
    pub bind_addr: String,
    /// Text advertised to clients in the `server_version` ParameterStatus
    /// sent during startup.
    pub server_version: String,
}

impl Default for PgWireConfig {
    /// Loopback on the standard PostgreSQL port, advertising a 15.0-compatible
    /// server version.
    fn default() -> Self {
        let bind_addr = String::from("127.0.0.1:5432");
        let server_version = String::from("15.0 (RedDB 3.1)");
        PgWireConfig {
            bind_addr,
            server_version,
        }
    }
}
47
48pub async fn start_pg_wire_listener(
51 config: PgWireConfig,
52 runtime: Arc<RedDBRuntime>,
53) -> Result<(), Box<dyn std::error::Error>> {
54 let listener = TcpListener::bind(&config.bind_addr).await?;
55 tracing::info!(
56 transport = "pg-wire",
57 bind = %config.bind_addr,
58 "listener online"
59 );
60 let cfg = Arc::new(config);
61 loop {
62 let (stream, peer) = listener.accept().await?;
63 let rt = Arc::clone(&runtime);
64 let cfg = Arc::clone(&cfg);
65 let peer_str = peer.to_string();
66 tokio::spawn(async move {
67 if let Err(e) = handle_connection(stream, rt, cfg).await {
68 tracing::warn!(
69 transport = "pg-wire",
70 peer = %peer_str,
71 err = %e,
72 "connection failed"
73 );
74 }
75 });
76 }
77}
78
79pub(crate) async fn handle_connection<S>(
81 mut stream: S,
82 runtime: Arc<RedDBRuntime>,
83 config: Arc<PgWireConfig>,
84) -> Result<(), PgWireError>
85where
86 S: AsyncRead + AsyncWrite + Unpin + Send,
87{
88 loop {
93 match read_startup(&mut stream).await? {
94 FrontendMessage::SslRequest | FrontendMessage::GssEncRequest => {
95 write_raw_byte(&mut stream, b'N').await?;
98 continue;
99 }
100 FrontendMessage::Startup(params) => {
101 send_auth_ok(&mut stream, &config, ¶ms).await?;
102 break;
103 }
104 FrontendMessage::Unknown { .. } => {
105 return Ok(());
107 }
108 other => {
109 return Err(PgWireError::Protocol(format!(
110 "unexpected startup frame: {other:?}"
111 )));
112 }
113 }
114 }
115
116 loop {
118 let frame = match read_frame(&mut stream).await {
119 Ok(f) => f,
120 Err(PgWireError::Eof) => return Ok(()),
121 Err(e) => return Err(e),
122 };
123
124 match frame {
125 FrontendMessage::Query(sql) => {
126 handle_simple_query(&mut stream, &runtime, &sql).await?;
127 }
128 FrontendMessage::Terminate => return Ok(()),
129 FrontendMessage::Sync | FrontendMessage::Flush => {
130 write_frame(
134 &mut stream,
135 &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
136 )
137 .await?;
138 }
139 FrontendMessage::PasswordMessage(_) => {
140 continue;
142 }
143 FrontendMessage::Unknown { tag, .. } => {
144 send_error(
145 &mut stream,
146 "0A000",
147 &format!("unsupported frame tag 0x{tag:02x}"),
148 )
149 .await?;
150 write_frame(
151 &mut stream,
152 &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
153 )
154 .await?;
155 }
156 other => {
157 send_error(
158 &mut stream,
159 "0A000",
160 &format!("unsupported frame {other:?}"),
161 )
162 .await?;
163 write_frame(
164 &mut stream,
165 &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
166 )
167 .await?;
168 }
169 }
170 }
171}
172
173async fn send_auth_ok<S>(
174 stream: &mut S,
175 config: &PgWireConfig,
176 params: &super::protocol::StartupParams,
177) -> Result<(), PgWireError>
178where
179 S: AsyncWrite + Unpin,
180{
181 write_frame(stream, &BackendMessage::AuthenticationOk).await?;
183
184 for (name, value) in [
186 ("server_version", config.server_version.as_str()),
187 ("server_encoding", "UTF8"),
188 ("client_encoding", "UTF8"),
189 ("DateStyle", "ISO, MDY"),
190 ("TimeZone", "UTC"),
191 ("integer_datetimes", "on"),
192 ("standard_conforming_strings", "on"),
193 (
194 "application_name",
195 params.get("application_name").unwrap_or(""),
196 ),
197 ] {
198 write_frame(
199 stream,
200 &BackendMessage::ParameterStatus {
201 name: name.to_string(),
202 value: value.to_string(),
203 },
204 )
205 .await?;
206 }
207
208 write_frame(
211 stream,
212 &BackendMessage::BackendKeyData {
213 pid: std::process::id(),
214 key: 0xDEADBEEF,
215 },
216 )
217 .await?;
218
219 write_frame(
220 stream,
221 &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
222 )
223 .await?;
224 Ok(())
225}
226
227async fn handle_simple_query<S>(
228 stream: &mut S,
229 runtime: &RedDBRuntime,
230 sql: &str,
231) -> Result<(), PgWireError>
232where
233 S: AsyncWrite + Unpin,
234{
235 if sql.trim().is_empty() {
238 write_frame(stream, &BackendMessage::EmptyQueryResponse).await?;
239 write_frame(
240 stream,
241 &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
242 )
243 .await?;
244 return Ok(());
245 }
246
247 let query_result = match translate_pg_catalog_query(runtime, sql) {
248 Ok(Some(result)) => Ok(crate::runtime::RuntimeQueryResult {
249 query: sql.to_string(),
250 mode: crate::storage::query::modes::QueryMode::Sql,
251 statement: "select",
252 engine: "pg-catalog",
253 result,
254 affected_rows: 0,
255 statement_type: "select",
256 }),
257 Ok(None) => runtime.execute_query(sql),
258 Err(err) => Err(err),
259 };
260
261 match query_result {
262 Ok(result) => {
263 if result.statement_type == "select" {
264 emit_result_rows(stream, &result.result).await?;
265 write_frame(
266 stream,
267 &BackendMessage::CommandComplete(format!(
268 "SELECT {}",
269 result.result.records.len()
270 )),
271 )
272 .await?;
273 } else {
274 let tag = match result.statement_type {
279 "insert" => format!("INSERT 0 {}", result.affected_rows),
280 "update" => format!("UPDATE {}", result.affected_rows),
281 "delete" => format!("DELETE {}", result.affected_rows),
282 other => other.to_uppercase(),
283 };
284 write_frame(stream, &BackendMessage::CommandComplete(tag)).await?;
285 }
286 }
287 Err(err) => {
288 let code = classify_sqlstate(&err.to_string());
292 send_error(stream, code, &err.to_string()).await?;
293 }
294 }
295
296 write_frame(
297 stream,
298 &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
299 )
300 .await?;
301 Ok(())
302}
303
304async fn emit_result_rows<S>(
305 stream: &mut S,
306 result: &crate::storage::query::unified::UnifiedResult,
307) -> Result<(), PgWireError>
308where
309 S: AsyncWrite + Unpin,
310{
311 let columns: Vec<String> = if !result.columns.is_empty() {
315 result.columns.clone()
316 } else if let Some(first) = result.records.first() {
317 record_field_names(first)
318 } else {
319 Vec::new()
320 };
321
322 let type_oids: Vec<PgOid> = columns
326 .iter()
327 .map(|col| {
328 result
329 .records
330 .first()
331 .and_then(|r| record_get(r, col))
332 .map(PgOid::from_value)
333 .unwrap_or(PgOid::Text)
334 })
335 .collect();
336
337 let descriptors: Vec<ColumnDescriptor> = columns
338 .iter()
339 .zip(type_oids.iter())
340 .map(|(name, oid)| ColumnDescriptor {
341 name: name.clone(),
342 table_oid: 0,
343 column_attr: 0,
344 type_oid: oid.as_u32(),
345 type_size: -1,
346 type_mod: -1,
347 format: 0,
348 })
349 .collect();
350
351 write_frame(stream, &BackendMessage::RowDescription(descriptors)).await?;
352
353 for record in &result.records {
354 let fields: Vec<Option<Vec<u8>>> = columns
355 .iter()
356 .map(|col| record_get(record, col).and_then(value_to_pg_wire_bytes))
357 .collect();
358 write_frame(stream, &BackendMessage::DataRow(fields)).await?;
359 }
360
361 Ok(())
362}
363
364fn record_get<'a>(record: &'a UnifiedRecord, key: &str) -> Option<&'a Value> {
368 record.get(key)
369}
370
371fn record_field_names(record: &UnifiedRecord) -> Vec<String> {
380 record
384 .column_names()
385 .into_iter()
386 .map(|k| k.to_string())
387 .collect()
388}
389
390async fn send_error<S>(stream: &mut S, code: &str, message: &str) -> Result<(), PgWireError>
391where
392 S: AsyncWrite + Unpin,
393{
394 write_frame(
395 stream,
396 &BackendMessage::ErrorResponse {
397 severity: "ERROR".to_string(),
398 code: code.to_string(),
399 message: message.to_string(),
400 },
401 )
402 .await
403}
404
/// Guesses a PostgreSQL SQLSTATE code from a runtime error message.
///
/// Matching is case-insensitive substring search. The first rule that
/// matches wins:
/// - "not found" / "does not exist": `42703` (undefined_column) if the
///   message mentions a column, otherwise `42P01` (undefined_table).
/// - "parse" / "expected" / "syntax": `42601` (syntax_error).
/// - "already exists": `42P07` (duplicate_table).
/// - "permission": `42501` (insufficient_privilege).
/// - "authenticat…" / "authoriz…": `28000` (invalid_authorization_specification).
///   Word stems are used rather than the bare "auth", which also matched
///   identifiers such as "author".
/// - anything else: `XX000` (internal_error).
fn classify_sqlstate(msg: &str) -> &'static str {
    let lower = msg.to_ascii_lowercase();
    let has = |needle: &str| lower.contains(needle);

    if has("not found") || has("does not exist") {
        if has("column") {
            "42703"
        } else {
            "42P01"
        }
    } else if has("parse") || has("expected") || has("syntax") {
        "42601"
    } else if has("already exists") {
        "42P07"
    } else if has("permission") {
        "42501"
    } else if has("authenticat") || has("authoriz") {
        "28000"
    } else {
        "XX000"
    }
}