Skip to main content

drasi_source_postgres/
protocol.rs

1// Copyright 2025 The Drasi Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use anyhow::{anyhow, Result};
16use bytes::{BufMut, BytesMut};
17use std::collections::HashMap;
18
/// Protocol version sent in the startup packet: major version 3 in the high 16 bits,
/// minor version 0 in the low 16 bits (i.e. protocol 3.0).
pub const PROTOCOL_VERSION: u32 = 3 << 16;
20
/// A message sent from the client (frontend) to the PostgreSQL server.
///
/// Variant names follow the PostgreSQL frontend/backend protocol message names;
/// `FrontendMessage::encode` serializes a value into its wire form.
#[derive(Debug, Clone)]
pub enum FrontendMessage {
    /// Untagged startup packet: protocol version plus connection parameters.
    StartupMessage(StartupMessage),
    /// Cleartext (or pre-hashed) password reply, sent as a NUL-terminated string.
    PasswordMessage(String),
    /// Simple-query request carrying the SQL text as a NUL-terminated string.
    Query(String),
    /// Extended-query `Parse` message.
    Parse {
        /// Prepared-statement name (presumably empty for the unnamed statement).
        name: String,
        /// SQL text to prepare.
        query: String,
        /// NOTE(review): presumably parameter type OIDs (0 = let the server infer) — confirm.
        param_types: Vec<u32>,
    },
    /// Extended-query `Bind` message.
    Bind {
        /// Destination portal name.
        portal: String,
        /// Source prepared-statement name.
        statement: String,
        /// Parameter format codes (presumably 0 = text, 1 = binary — confirm).
        formats: Vec<i16>,
        /// Parameter values; `None` represents SQL NULL.
        values: Vec<Option<Vec<u8>>>,
        /// Requested result-column format codes.
        result_formats: Vec<i16>,
    },
    /// Extended-query `Execute` message.
    Execute {
        /// Portal to execute.
        portal: String,
        /// Row limit (presumably 0 means "no limit" — confirm).
        max_rows: i32,
    },
    /// Extended-query `Sync` message.
    Sync,
    /// Orderly connection shutdown (`'X'`, no body).
    Terminate,
    /// Raw COPY payload, framed with tag `'d'`.
    CopyData(Vec<u8>),
    /// End of client COPY data (`'c'`, no body).
    CopyDone,
    /// Abort a COPY with the given reason (`'f'`, NUL-terminated string).
    CopyFail(String),
    // SASL authentication
    /// First SASL message: chosen mechanism name and the client-first data.
    SASLInitialResponse {
        mechanism: String,
        data: Vec<u8>,
    },
    /// Subsequent SASL message carrying raw mechanism data.
    SASLResponse(Vec<u8>),
    // Replication specific
    /// Standby feedback; encoded inside a `CopyData` frame as sub-tag `'r'` followed by
    /// three big-endian u64 LSNs, an i64 timestamp and a one-byte reply flag.
    StandbyStatusUpdate {
        write_lsn: u64,
        flush_lsn: u64,
        apply_lsn: u64,
        /// NOTE(review): per the replication protocol this is microseconds since
        /// 2000-01-01 UTC — confirm callers use that epoch.
        timestamp: i64,
        /// Presumably nonzero asks the server to reply immediately — confirm.
        reply: u8,
    },
}
62
/// Connection parameters carried by the untagged startup packet.
///
/// `FrontendMessage::encode` writes each entry as a pair of NUL-terminated strings in
/// `HashMap` iteration order, i.e. in no particular order.
#[derive(Debug, Clone)]
pub struct StartupMessage {
    /// Parameter name to value (e.g. `user`, `database`, `replication`).
    pub parameters: HashMap<String, String>,
}
67
68impl StartupMessage {
69    pub fn new_replication(database: &str, user: &str) -> Self {
70        let mut parameters = HashMap::new();
71        parameters.insert("user".to_string(), user.to_string());
72        parameters.insert("database".to_string(), database.to_string());
73        parameters.insert("replication".to_string(), "database".to_string());
74        Self { parameters }
75    }
76}
77
78impl FrontendMessage {
79    pub fn encode(&self, buf: &mut BytesMut) -> Result<()> {
80        match self {
81            FrontendMessage::StartupMessage(msg) => {
82                let mut msg_buf = BytesMut::new();
83                msg_buf.put_u32(PROTOCOL_VERSION);
84
85                for (key, value) in &msg.parameters {
86                    msg_buf.put_slice(key.as_bytes());
87                    msg_buf.put_u8(0);
88                    msg_buf.put_slice(value.as_bytes());
89                    msg_buf.put_u8(0);
90                }
91                msg_buf.put_u8(0); // Final terminator
92
93                // Length includes itself
94                buf.put_u32((msg_buf.len() + 4) as u32);
95                buf.put_slice(&msg_buf);
96            }
97
98            FrontendMessage::PasswordMessage(password) => {
99                buf.put_u8(b'p');
100                buf.put_u32((4 + password.len() + 1) as u32);
101                buf.put_slice(password.as_bytes());
102                buf.put_u8(0);
103            }
104
105            FrontendMessage::Query(query) => {
106                buf.put_u8(b'Q');
107                buf.put_u32((4 + query.len() + 1) as u32);
108                buf.put_slice(query.as_bytes());
109                buf.put_u8(0);
110            }
111
112            FrontendMessage::Terminate => {
113                buf.put_u8(b'X');
114                buf.put_u32(4);
115            }
116
117            FrontendMessage::CopyData(data) => {
118                buf.put_u8(b'd');
119                buf.put_u32((4 + data.len()) as u32);
120                buf.put_slice(data);
121            }
122
123            FrontendMessage::CopyDone => {
124                buf.put_u8(b'c');
125                buf.put_u32(4);
126            }
127
128            FrontendMessage::CopyFail(msg) => {
129                buf.put_u8(b'f');
130                buf.put_u32((4 + msg.len() + 1) as u32);
131                buf.put_slice(msg.as_bytes());
132                buf.put_u8(0);
133            }
134
135            FrontendMessage::SASLInitialResponse { mechanism, data } => {
136                buf.put_u8(b'p');
137                let mut msg_buf = BytesMut::new();
138                msg_buf.put_slice(mechanism.as_bytes());
139                msg_buf.put_u8(0);
140                msg_buf.put_u32(data.len() as u32);
141                msg_buf.put_slice(data);
142                buf.put_u32((4 + msg_buf.len()) as u32);
143                buf.put_slice(&msg_buf);
144            }
145
146            FrontendMessage::SASLResponse(data) => {
147                buf.put_u8(b'p');
148                buf.put_u32((4 + data.len()) as u32);
149                buf.put_slice(data);
150            }
151
152            FrontendMessage::StandbyStatusUpdate {
153                write_lsn,
154                flush_lsn,
155                apply_lsn,
156                timestamp,
157                reply,
158            } => {
159                let mut data = BytesMut::new();
160                data.put_u8(b'r'); // Standby status update
161                data.put_u64(*write_lsn);
162                data.put_u64(*flush_lsn);
163                data.put_u64(*apply_lsn);
164                data.put_i64(*timestamp);
165                data.put_u8(*reply);
166
167                buf.put_u8(b'd'); // CopyData
168                buf.put_u32((4 + data.len()) as u32);
169                buf.put_slice(&data);
170            }
171
172            _ => return Err(anyhow!("Unsupported message type for encoding")),
173        }
174
175        Ok(())
176    }
177}
178
/// A decoded message received from the PostgreSQL server (see `parse_backend_message`).
// NOTE(review): the lint is presumably allowed because `ErrorResponse` (17 fields) is far
// larger than the other variants; boxing it would change the public variant type.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum BackendMessage {
    /// Authentication request or result (`'R'`).
    Authentication(AuthenticationMessage),
    /// Cancellation key for this session (`'K'`).
    BackendKeyData {
        process_id: i32,
        secret_key: i32,
    },
    /// `'2'`, no body.
    BindComplete,
    /// `'3'`, no body.
    CloseComplete,
    /// Command tag such as `INSERT 0 1` (`'C'`).
    CommandComplete(String),
    /// Start of a bidirectional COPY (`'W'`); format codes are not retained.
    CopyBothResponse,
    /// Raw COPY payload (`'d'`); replication data arrives here.
    CopyData(Vec<u8>),
    /// `'c'`, no body.
    CopyDone,
    /// Unit variant: any payload is not retained.
    CopyInResponse,
    /// Unit variant: any payload is not retained.
    CopyOutResponse,
    /// One result row (`'D'`); `None` is SQL NULL.
    DataRow(Vec<Option<Vec<u8>>>),
    /// `'I'`, no body.
    EmptyQueryResponse,
    /// Server error (`'E'`).
    ErrorResponse(ErrorResponse),
    /// `'n'`, no body.
    NoData,
    /// Non-fatal server notice (`'N'`).
    NoticeResponse(NoticeResponse),
    /// Unit variant: channel/payload are not retained.
    NotificationResponse,
    /// Unit variant: parameter type list is not retained.
    ParameterDescription,
    /// Runtime parameter report (`'S'`).
    ParameterStatus {
        name: String,
        value: String,
    },
    /// `'1'`, no body.
    ParseComplete,
    /// `'s'`, no body.
    PortalSuspended,
    /// Server is ready for the next query (`'Z'`).
    ReadyForQuery(TransactionStatus),
    /// Result-set column metadata (`'T'`).
    RowDescription(Vec<FieldDescription>),
    // Replication specific
    /// NOTE(review): not constructed by the parsers in this file; presumably decoded from
    /// `CopyData` payloads elsewhere — confirm.
    PrimaryKeepaliveMessage {
        wal_end: u64,
        timestamp: i64,
        reply: u8,
    },
}
217
/// Payload of an Authentication (`'R'`) message, keyed by its Int32 auth code.
#[derive(Debug)]
pub enum AuthenticationMessage {
    /// Code 0: authentication succeeded.
    Ok,
    /// Code 2 per the PostgreSQL protocol.
    KerberosV5,
    /// Code 3: server wants a cleartext password.
    CleartextPassword,
    /// Code 5: server wants an MD5-hashed password; holds the 4-byte salt.
    MD5Password([u8; 4]),
    /// Code 6 per the PostgreSQL protocol.
    SCMCredential,
    /// Code 7 per the PostgreSQL protocol.
    GSS,
    /// Code 8 per the PostgreSQL protocol; holds the GSSAPI/SSPI data.
    GSSContinue(Vec<u8>),
    /// Code 9 per the PostgreSQL protocol.
    SSPI,
    /// Code 10: SASL requested; holds the advertised mechanism names.
    SASL(Vec<String>),
    /// Code 11: server-side SASL challenge data.
    SASLContinue(Vec<u8>),
    /// Code 12: final SASL outcome data from the server.
    SASLFinal(Vec<u8>),
}
232
/// Decoded ErrorResponse (`'E'`) fields.
///
/// Each field comes from the one-letter field code noted below. Required-looking text
/// fields are empty strings when the server omitted them. Numeric fields are `None` when
/// absent or not a valid integer.
#[derive(Debug)]
pub struct ErrorResponse {
    /// Field `S`.
    pub severity: String,
    /// Field `C` (SQLSTATE code per the PostgreSQL protocol).
    pub code: String,
    /// Field `M`.
    pub message: String,
    /// Field `D`.
    pub detail: Option<String>,
    /// Field `H`.
    pub hint: Option<String>,
    /// Field `P`, parsed as an integer.
    pub position: Option<i32>,
    /// Field `p`, parsed as an integer.
    pub internal_position: Option<i32>,
    /// Field `q`.
    pub internal_query: Option<String>,
    /// Field `W` (named `where_` because `where` is a Rust keyword).
    pub where_: Option<String>,
    /// Field `s`.
    pub schema: Option<String>,
    /// Field `t`.
    pub table: Option<String>,
    /// Field `c`.
    pub column: Option<String>,
    /// Field `d`.
    pub datatype: Option<String>,
    /// Field `n`.
    pub constraint: Option<String>,
    /// Field `F` (server source file).
    pub file: Option<String>,
    /// Field `L`, parsed as an integer (server source line).
    pub line: Option<i32>,
    /// Field `R` (server source routine).
    pub routine: Option<String>,
}
253
/// Decoded NoticeResponse (`'N'`); only a subset of the notice fields is kept.
#[derive(Debug)]
pub struct NoticeResponse {
    /// Field `S` (empty if absent).
    pub severity: String,
    /// Field `C` (empty if absent).
    pub code: String,
    /// Field `M` (empty if absent).
    pub message: String,
    /// Field `D`.
    pub detail: Option<String>,
    /// Field `H`.
    pub hint: Option<String>,
}
262
/// Transaction state reported by ReadyForQuery (`'Z'`).
#[derive(Debug, Clone, Copy)]
pub enum TransactionStatus {
    /// Indicator byte `'I'`: not in a transaction block.
    Idle,
    /// Indicator byte `'T'`: inside a transaction block.
    Transaction,
    /// Indicator byte `'E'`: inside a failed transaction block.
    Failed,
}
269
/// Metadata for one result column from a RowDescription (`'T'`) message.
#[derive(Debug)]
pub struct FieldDescription {
    /// Column name.
    pub name: String,
    /// OID of the source table (presumably 0 when not a table column — confirm).
    pub table_oid: u32,
    /// Attribute number within the source table.
    pub column_id: i16,
    /// Data type OID.
    pub type_oid: u32,
    /// Type size in bytes (negative values denote variable width per the protocol).
    pub type_size: i16,
    /// Type-specific modifier.
    pub type_modifier: i32,
    /// Format code (presumably 0 = text, 1 = binary — confirm).
    pub format: i16,
}
280
281pub fn parse_backend_message(msg_type: u8, body: &[u8]) -> Result<BackendMessage> {
282    match msg_type {
283        b'R' => parse_authentication(body),
284        b'K' => parse_backend_key_data(body),
285        b'Z' => parse_ready_for_query(body),
286        b'S' => parse_parameter_status(body),
287        b'E' => parse_error_response(body),
288        b'N' => parse_notice_response(body),
289        b'C' => parse_command_complete(body),
290        b'T' => parse_row_description(body),
291        b'D' => parse_data_row(body),
292        b'W' => parse_copy_both_response(body),
293        b'd' => Ok(BackendMessage::CopyData(body.to_vec())),
294        b'c' => Ok(BackendMessage::CopyDone),
295        b'1' => Ok(BackendMessage::ParseComplete),
296        b'2' => Ok(BackendMessage::BindComplete),
297        b'3' => Ok(BackendMessage::CloseComplete),
298        b'n' => Ok(BackendMessage::NoData),
299        b'I' => Ok(BackendMessage::EmptyQueryResponse),
300        b's' => Ok(BackendMessage::PortalSuspended),
301        _ => Err(anyhow!(
302            "Unknown backend message type: {}",
303            msg_type as char
304        )),
305    }
306}
307
308fn parse_authentication(body: &[u8]) -> Result<BackendMessage> {
309    if body.len() < 4 {
310        return Err(anyhow!("Authentication message too short"));
311    }
312
313    let auth_type = u32::from_be_bytes([body[0], body[1], body[2], body[3]]);
314    let auth = match auth_type {
315        0 => AuthenticationMessage::Ok,
316        3 => AuthenticationMessage::CleartextPassword,
317        5 => {
318            if body.len() < 8 {
319                return Err(anyhow!("MD5 authentication message too short"));
320            }
321            let mut salt = [0u8; 4];
322            salt.copy_from_slice(&body[4..8]);
323            AuthenticationMessage::MD5Password(salt)
324        }
325        10 => {
326            if body.len() > 4 {
327                let mechanisms = parse_sasl_mechanisms(&body[4..])?;
328                AuthenticationMessage::SASL(mechanisms)
329            } else {
330                AuthenticationMessage::SASL(vec![])
331            }
332        }
333        11 => {
334            if body.len() > 4 {
335                AuthenticationMessage::SASLContinue(body[4..].to_vec())
336            } else {
337                AuthenticationMessage::SASLContinue(vec![])
338            }
339        }
340        12 => {
341            if body.len() > 4 {
342                AuthenticationMessage::SASLFinal(body[4..].to_vec())
343            } else {
344                AuthenticationMessage::SASLFinal(vec![])
345            }
346        }
347        _ => return Err(anyhow!("Unsupported authentication type: {auth_type}")),
348    };
349
350    Ok(BackendMessage::Authentication(auth))
351}
352
353fn parse_sasl_mechanisms(body: &[u8]) -> Result<Vec<String>> {
354    let mut mechanisms = Vec::new();
355    let mut pos = 0;
356
357    while pos < body.len() {
358        let end = body[pos..]
359            .iter()
360            .position(|&b| b == 0)
361            .ok_or_else(|| anyhow!("Unterminated SASL mechanism"))?;
362
363        if end == 0 {
364            break; // Double null terminator
365        }
366
367        mechanisms.push(String::from_utf8_lossy(&body[pos..pos + end]).to_string());
368        pos += end + 1;
369    }
370
371    Ok(mechanisms)
372}
373
374fn parse_backend_key_data(body: &[u8]) -> Result<BackendMessage> {
375    if body.len() != 8 {
376        return Err(anyhow!("BackendKeyData message wrong size"));
377    }
378
379    let process_id = i32::from_be_bytes([body[0], body[1], body[2], body[3]]);
380    let secret_key = i32::from_be_bytes([body[4], body[5], body[6], body[7]]);
381
382    Ok(BackendMessage::BackendKeyData {
383        process_id,
384        secret_key,
385    })
386}
387
388fn parse_ready_for_query(body: &[u8]) -> Result<BackendMessage> {
389    if body.len() != 1 {
390        return Err(anyhow!("ReadyForQuery message wrong size"));
391    }
392
393    let status = match body[0] {
394        b'I' => TransactionStatus::Idle,
395        b'T' => TransactionStatus::Transaction,
396        b'E' => TransactionStatus::Failed,
397        _ => return Err(anyhow!("Unknown transaction status: {}", body[0])),
398    };
399
400    Ok(BackendMessage::ReadyForQuery(status))
401}
402
403fn parse_parameter_status(body: &[u8]) -> Result<BackendMessage> {
404    let name_end = body
405        .iter()
406        .position(|&b| b == 0)
407        .ok_or_else(|| anyhow!("Unterminated parameter name"))?;
408
409    let name = String::from_utf8_lossy(&body[..name_end]).to_string();
410
411    let value_start = name_end + 1;
412    let value_end = body[value_start..]
413        .iter()
414        .position(|&b| b == 0)
415        .ok_or_else(|| anyhow!("Unterminated parameter value"))?;
416
417    let value = String::from_utf8_lossy(&body[value_start..value_start + value_end]).to_string();
418
419    Ok(BackendMessage::ParameterStatus { name, value })
420}
421
422fn parse_error_response(body: &[u8]) -> Result<BackendMessage> {
423    let fields = parse_notice_fields(body)?;
424    Ok(BackendMessage::ErrorResponse(ErrorResponse {
425        severity: fields.get("S").cloned().unwrap_or_default(),
426        code: fields.get("C").cloned().unwrap_or_default(),
427        message: fields.get("M").cloned().unwrap_or_default(),
428        detail: fields.get("D").cloned(),
429        hint: fields.get("H").cloned(),
430        position: fields.get("P").and_then(|s| s.parse().ok()),
431        internal_position: fields.get("p").and_then(|s| s.parse().ok()),
432        internal_query: fields.get("q").cloned(),
433        where_: fields.get("W").cloned(),
434        schema: fields.get("s").cloned(),
435        table: fields.get("t").cloned(),
436        column: fields.get("c").cloned(),
437        datatype: fields.get("d").cloned(),
438        constraint: fields.get("n").cloned(),
439        file: fields.get("F").cloned(),
440        line: fields.get("L").and_then(|s| s.parse().ok()),
441        routine: fields.get("R").cloned(),
442    }))
443}
444
445fn parse_notice_response(body: &[u8]) -> Result<BackendMessage> {
446    let fields = parse_notice_fields(body)?;
447    Ok(BackendMessage::NoticeResponse(NoticeResponse {
448        severity: fields.get("S").cloned().unwrap_or_default(),
449        code: fields.get("C").cloned().unwrap_or_default(),
450        message: fields.get("M").cloned().unwrap_or_default(),
451        detail: fields.get("D").cloned(),
452        hint: fields.get("H").cloned(),
453    }))
454}
455
456fn parse_notice_fields(body: &[u8]) -> Result<HashMap<String, String>> {
457    let mut fields = HashMap::new();
458    let mut pos = 0;
459
460    while pos < body.len() && body[pos] != 0 {
461        let field_type = body[pos] as char;
462        pos += 1;
463
464        let end = body[pos..]
465            .iter()
466            .position(|&b| b == 0)
467            .ok_or_else(|| anyhow!("Unterminated field value"))?;
468
469        let value = String::from_utf8_lossy(&body[pos..pos + end]).to_string();
470        fields.insert(field_type.to_string(), value);
471
472        pos += end + 1;
473    }
474
475    Ok(fields)
476}
477
478fn parse_command_complete(body: &[u8]) -> Result<BackendMessage> {
479    let end = body
480        .iter()
481        .position(|&b| b == 0)
482        .ok_or_else(|| anyhow!("Unterminated command tag"))?;
483
484    let tag = String::from_utf8_lossy(&body[..end]).to_string();
485    Ok(BackendMessage::CommandComplete(tag))
486}
487
488fn parse_row_description(body: &[u8]) -> Result<BackendMessage> {
489    let mut pos = 0;
490    let field_count = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
491    pos += 2;
492
493    let mut fields = Vec::with_capacity(field_count);
494
495    for _ in 0..field_count {
496        let name_end = body[pos..]
497            .iter()
498            .position(|&b| b == 0)
499            .ok_or_else(|| anyhow!("Unterminated field name"))?;
500
501        let name = String::from_utf8_lossy(&body[pos..pos + name_end]).to_string();
502        pos += name_end + 1;
503
504        if pos + 18 > body.len() {
505            return Err(anyhow!("Row description truncated"));
506        }
507
508        let table_oid =
509            u32::from_be_bytes([body[pos], body[pos + 1], body[pos + 2], body[pos + 3]]);
510        pos += 4;
511
512        let column_id = i16::from_be_bytes([body[pos], body[pos + 1]]);
513        pos += 2;
514
515        let type_oid = u32::from_be_bytes([body[pos], body[pos + 1], body[pos + 2], body[pos + 3]]);
516        pos += 4;
517
518        let type_size = i16::from_be_bytes([body[pos], body[pos + 1]]);
519        pos += 2;
520
521        let type_modifier =
522            i32::from_be_bytes([body[pos], body[pos + 1], body[pos + 2], body[pos + 3]]);
523        pos += 4;
524
525        let format = i16::from_be_bytes([body[pos], body[pos + 1]]);
526        pos += 2;
527
528        fields.push(FieldDescription {
529            name,
530            table_oid,
531            column_id,
532            type_oid,
533            type_size,
534            type_modifier,
535            format,
536        });
537    }
538
539    Ok(BackendMessage::RowDescription(fields))
540}
541
542fn parse_data_row(body: &[u8]) -> Result<BackendMessage> {
543    let mut pos = 0;
544    let column_count = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
545    pos += 2;
546
547    let mut columns = Vec::with_capacity(column_count);
548
549    for _ in 0..column_count {
550        if pos + 4 > body.len() {
551            return Err(anyhow!("Data row truncated"));
552        }
553
554        let length = i32::from_be_bytes([body[pos], body[pos + 1], body[pos + 2], body[pos + 3]]);
555        pos += 4;
556
557        if length == -1 {
558            columns.push(None);
559        } else {
560            let length = length as usize;
561            if pos + length > body.len() {
562                return Err(anyhow!("Data row value truncated"));
563            }
564            columns.push(Some(body[pos..pos + length].to_vec()));
565            pos += length;
566        }
567    }
568
569    Ok(BackendMessage::DataRow(columns))
570}
571
572fn parse_copy_both_response(_body: &[u8]) -> Result<BackendMessage> {
573    // Note: PostgreSQL CopyBothResponse includes format codes (overall format + per-column formats),
574    // but we don't need to parse them for replication streaming because:
575    // 1. Replication protocol uses a known binary format
576    // 2. We handle the actual data parsing in the replication-specific message handlers
577    // 3. Format codes are primarily useful for COPY operations with variable formats
578    //
579    // PostgreSQL protocol spec for CopyBothResponse ('W'):
580    // - Int32: message length
581    // - Int8: overall copy format (0=text, 1=binary)
582    // - Int16: number of columns
583    // - Int16[]: format code for each column (0=text, 1=binary)
584    //
585    // If variable format support is needed in the future, parse format codes here.
586    Ok(BackendMessage::CopyBothResponse)
587}