// qail_pg/driver/replication.rs

//! Logical replication helpers for the PostgreSQL driver.
//!
//! Current scope:
//! - `IDENTIFY_SYSTEM`
//! - `CREATE_REPLICATION_SLOT ... LOGICAL ...`
//! - `DROP_REPLICATION_SLOT`
//! - `START_REPLICATION SLOT ... LOGICAL ...`
//! - Decoding of CopyBoth stream messages (`XLogData`, keepalive)
//! - Standby status updates (`'r'`)
10
11use super::{
12    PgConnection, PgDriver, PgError, PgResult, PgRow, is_ignorable_session_message,
13    unexpected_backend_message,
14};
15use crate::protocol::{BackendMessage, PgEncoder};
16
/// Parsed result of the `IDENTIFY_SYSTEM` replication command.
///
/// Populated from the first row the server returns; columns are read
/// positionally (0 = system id, 1 = timeline, 2 = WAL position, 3 = database).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct IdentifySystem {
    /// Identifier of the cluster, kept as the text the server sent.
    pub system_id: String,
    /// Timeline ID the server is currently on.
    pub timeline: u32,
    /// Current WAL/LSN position in textual form (e.g. `0/16B6C50`).
    pub xlog_pos: String,
    /// Database the session is bound to; `None` when the server sends no
    /// value or an empty one.
    pub dbname: Option<String>,
}
29
/// Parsed result of `CREATE_REPLICATION_SLOT ... LOGICAL ...`.
///
/// Columns are read positionally from the first returned row
/// (0 = slot name, 1 = consistent point, 2 = snapshot name, 3 = output plugin).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReplicationSlotInfo {
    /// Name of the slot that was created.
    pub slot_name: String,
    /// Consistent point (textual LSN) at which the slot became valid.
    pub consistent_point: String,
    /// Name of the exported snapshot; `None` when absent or empty.
    pub snapshot_name: Option<String>,
    /// Output plugin the logical slot was created with.
    pub output_plugin: String,
}
42
/// Contents of the `CopyBothResponse` that opens a logical replication stream.
///
/// `start_logical_replication` only hands this back when `format` is `0` and
/// `column_formats` is empty; any other combination is reported as an error.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReplicationStreamStart {
    /// Overall copy format code (0 = text, 1 = binary).
    pub format: u8,
    /// Format code for each column announced by the server.
    pub column_formats: Vec<u8>,
}
51
/// Decoded `CopyData('w' ...)` frame carrying WAL payload.
///
/// On the wire the frame is the tag byte, three big-endian 8-byte fields
/// (in the order of the first three fields below) and then the payload.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReplicationXLogData {
    /// WAL position at which this payload starts.
    pub wal_start: u64,
    /// WAL end position reported by the server for this frame.
    pub wal_end: u64,
    /// Server clock, in microseconds since the PostgreSQL epoch (2000-01-01).
    pub server_time_micros: i64,
    /// Bytes produced by the output plugin (everything after the 25-byte header).
    pub data: Vec<u8>,
}
64
/// Decoded `CopyData('k' ...)` primary keepalive frame.
///
/// The frame is exactly 18 bytes: tag, two big-endian 8-byte fields and a
/// single flag byte.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReplicationKeepalive {
    /// WAL end position currently reported by the server.
    pub wal_end: u64,
    /// Server clock, in microseconds since the PostgreSQL epoch (2000-01-01).
    pub server_time_micros: i64,
    /// `true` when the server asks for an immediate status reply
    /// (flag byte `1`; `0` means no reply is requested).
    pub reply_requested: bool,
}
75
76/// Logical replication stream message decoded from `CopyData`.
77#[derive(Debug, Clone, PartialEq, Eq)]
78pub enum ReplicationStreamMessage {
79    /// XLog data frame (`'w'`).
80    XLogData(ReplicationXLogData),
81    /// Primary keepalive frame (`'k'`).
82    Keepalive(ReplicationKeepalive),
83    /// Unknown CopyData sub-message tag.
84    Raw { tag: u8, payload: Vec<u8> },
85}
86
/// A single output-plugin option for
/// `START_REPLICATION ... LOGICAL ... (key 'value', ...)`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReplicationOption {
    /// Option name; must pass the strict identifier check (`[A-Za-z_][A-Za-z0-9_]*`).
    pub key: String,
    /// Option value; emitted as a single-quoted SQL string in the command text.
    pub value: String,
}
95
/// Upper bound on the number of plugin options accepted for `START_REPLICATION`.
const MAX_REPLICATION_OPTIONS: usize = 64;
/// Upper bound, in bytes, on a single plugin option value (16 KiB).
const MAX_REPLICATION_OPTION_VALUE_BYTES: usize = 16 * 1024;
/// Upper bound, in bytes, on the plugin payload of one XLogData frame (16 MiB).
const MAX_REPLICATION_XLOGDATA_BYTES: usize = 16 * 1024 * 1024;
99
100fn validate_ident(kind: &str, ident: &str) -> PgResult<()> {
101    if ident.is_empty() {
102        return Err(PgError::Query(format!("{} must not be empty", kind)));
103    }
104    if ident.len() > 63 {
105        return Err(PgError::Query(format!(
106            "{} '{}' exceeds PostgreSQL identifier limit (63 bytes)",
107            kind, ident
108        )));
109    }
110    let mut chars = ident.chars();
111    match chars.next() {
112        Some(c) if c == '_' || c.is_ascii_alphabetic() => {}
113        _ => {
114            return Err(PgError::Query(format!(
115                "{} '{}' must start with [A-Za-z_]",
116                kind, ident
117            )));
118        }
119    }
120    if !chars.all(|c| c == '_' || c.is_ascii_alphanumeric()) {
121        return Err(PgError::Query(format!(
122            "{} '{}' contains unsupported characters (allowed: [A-Za-z0-9_])",
123            kind, ident
124        )));
125    }
126    Ok(())
127}
128
129fn sql_single_quote_literal(value: &str) -> PgResult<String> {
130    if value.contains('\0') {
131        return Err(PgError::Query(
132            "replication option value contains NUL byte".to_string(),
133        ));
134    }
135    Ok(value.replace('\'', "''"))
136}
137
138fn parse_lsn_text(lsn: &str) -> PgResult<u64> {
139    let mut parts = lsn.split('/');
140    let high = parts
141        .next()
142        .ok_or_else(|| PgError::Query(format!("Invalid LSN '{}'", lsn)))?;
143    let low = parts
144        .next()
145        .ok_or_else(|| PgError::Query(format!("Invalid LSN '{}'", lsn)))?;
146    if parts.next().is_some() {
147        return Err(PgError::Query(format!("Invalid LSN '{}'", lsn)));
148    }
149    let high = u32::from_str_radix(high, 16)
150        .map_err(|_| PgError::Query(format!("Invalid LSN '{}'", lsn)))?;
151    let low = u32::from_str_radix(low, 16)
152        .map_err(|_| PgError::Query(format!("Invalid LSN '{}'", lsn)))?;
153    Ok(((high as u64) << 32) | (low as u64))
154}
155
/// Renders a 64-bit LSN as `HIGH/LOW` in upper-case hex, with the low half
/// zero-padded to 8 digits. Test-only helper.
#[cfg(test)]
fn format_lsn(lsn: u64) -> String {
    let high = (lsn >> 32) as u32;
    let low = (lsn & 0xFFFF_FFFF) as u32;
    format!("{high:X}/{low:08X}")
}
160
161fn required_text(row: &PgRow, idx: usize, field: &str) -> PgResult<String> {
162    row.get_string(idx).ok_or_else(|| {
163        PgError::Protocol(format!("Missing or invalid '{}' in replication row", field))
164    })
165}
166
167fn parse_identify_system_row(row: &PgRow) -> PgResult<IdentifySystem> {
168    let system_id = required_text(row, 0, "systemid")?;
169    let timeline = required_text(row, 1, "timeline")?
170        .parse::<u32>()
171        .map_err(|e| PgError::Protocol(format!("Invalid timeline value: {}", e)))?;
172    let xlog_pos = required_text(row, 2, "xlogpos")?;
173    let dbname = row
174        .get_string(3)
175        .and_then(|v| if v.is_empty() { None } else { Some(v) });
176
177    Ok(IdentifySystem {
178        system_id,
179        timeline,
180        xlog_pos,
181        dbname,
182    })
183}
184
185fn parse_create_slot_row(row: &PgRow) -> PgResult<ReplicationSlotInfo> {
186    let slot_name = required_text(row, 0, "slot_name")?;
187    let consistent_point = required_text(row, 1, "consistent_point")?;
188    let snapshot_name = row
189        .get_string(2)
190        .and_then(|v| if v.is_empty() { None } else { Some(v) });
191    let output_plugin = required_text(row, 3, "output_plugin")?;
192
193    Ok(ReplicationSlotInfo {
194        slot_name,
195        consistent_point,
196        snapshot_name,
197        output_plugin,
198    })
199}
200
201fn build_create_logical_replication_slot_sql(
202    slot_name: &str,
203    output_plugin: &str,
204    temporary: bool,
205    two_phase: bool,
206) -> PgResult<String> {
207    validate_ident("slot_name", slot_name)?;
208    validate_ident("output_plugin", output_plugin)?;
209
210    let mut sql = String::from("CREATE_REPLICATION_SLOT ");
211    sql.push_str(slot_name);
212    if temporary {
213        sql.push_str(" TEMPORARY");
214    }
215    sql.push_str(" LOGICAL ");
216    sql.push_str(output_plugin);
217    if two_phase {
218        sql.push_str(" TWO_PHASE");
219    }
220    Ok(sql)
221}
222
223fn build_drop_replication_slot_sql(slot_name: &str, wait: bool) -> PgResult<String> {
224    validate_ident("slot_name", slot_name)?;
225    let mut sql = String::from("DROP_REPLICATION_SLOT ");
226    sql.push_str(slot_name);
227    if wait {
228        sql.push_str(" WAIT");
229    }
230    Ok(sql)
231}
232
233fn build_start_logical_replication_sql(
234    slot_name: &str,
235    start_lsn: &str,
236    options: &[ReplicationOption],
237) -> PgResult<String> {
238    validate_ident("slot_name", slot_name)?;
239    let _ = parse_lsn_text(start_lsn)?;
240    if options.len() > MAX_REPLICATION_OPTIONS {
241        return Err(PgError::Query(format!(
242            "too many replication options: {} (max {})",
243            options.len(),
244            MAX_REPLICATION_OPTIONS
245        )));
246    }
247
248    let mut sql = format!("START_REPLICATION SLOT {} LOGICAL {}", slot_name, start_lsn);
249    if !options.is_empty() {
250        sql.push_str(" (");
251        for (idx, opt) in options.iter().enumerate() {
252            validate_ident("replication option key", &opt.key)?;
253            if opt.value.len() > MAX_REPLICATION_OPTION_VALUE_BYTES {
254                return Err(PgError::Query(format!(
255                    "replication option '{}' value too large: {} bytes (max {})",
256                    opt.key,
257                    opt.value.len(),
258                    MAX_REPLICATION_OPTION_VALUE_BYTES
259                )));
260            }
261            if idx > 0 {
262                sql.push_str(", ");
263            }
264            sql.push_str(&opt.key);
265            sql.push_str(" '");
266            sql.push_str(&sql_single_quote_literal(&opt.value)?);
267            sql.push('\'');
268        }
269        sql.push(')');
270    }
271    Ok(sql)
272}
273
274fn parse_copy_data_message(payload: &[u8]) -> PgResult<ReplicationStreamMessage> {
275    if payload.is_empty() {
276        return Err(PgError::Protocol(
277            "Replication CopyData payload is empty".to_string(),
278        ));
279    }
280    match payload[0] {
281        b'w' => {
282            if payload.len() < 25 {
283                return Err(PgError::Protocol(format!(
284                    "XLogData payload too short: {} bytes",
285                    payload.len()
286                )));
287            }
288            let wal_start = u64::from_be_bytes(
289                payload[1..9]
290                    .try_into()
291                    .map_err(|_| PgError::Protocol("Invalid wal_start bytes".to_string()))?,
292            );
293            let wal_end = u64::from_be_bytes(
294                payload[9..17]
295                    .try_into()
296                    .map_err(|_| PgError::Protocol("Invalid wal_end bytes".to_string()))?,
297            );
298            let server_time_micros = i64::from_be_bytes(
299                payload[17..25]
300                    .try_into()
301                    .map_err(|_| PgError::Protocol("Invalid server time bytes".to_string()))?,
302            );
303            if wal_end < wal_start {
304                return Err(PgError::Protocol(format!(
305                    "XLogData wal_end {} is behind wal_start {}",
306                    wal_end, wal_start
307                )));
308            }
309            let data_len = payload.len() - 25;
310            if data_len > MAX_REPLICATION_XLOGDATA_BYTES {
311                return Err(PgError::Protocol(format!(
312                    "XLogData payload too large: {} bytes (max {})",
313                    data_len, MAX_REPLICATION_XLOGDATA_BYTES
314                )));
315            }
316            Ok(ReplicationStreamMessage::XLogData(ReplicationXLogData {
317                wal_start,
318                wal_end,
319                server_time_micros,
320                data: payload[25..].to_vec(),
321            }))
322        }
323        b'k' => {
324            if payload.len() != 18 {
325                return Err(PgError::Protocol(format!(
326                    "Keepalive payload must be 18 bytes, got {}",
327                    payload.len()
328                )));
329            }
330            let wal_end =
331                u64::from_be_bytes(payload[1..9].try_into().map_err(|_| {
332                    PgError::Protocol("Invalid keepalive wal_end bytes".to_string())
333                })?);
334            let server_time_micros = i64::from_be_bytes(
335                payload[9..17]
336                    .try_into()
337                    .map_err(|_| PgError::Protocol("Invalid keepalive time bytes".to_string()))?,
338            );
339            let reply_requested = match payload[17] {
340                0 => false,
341                1 => true,
342                other => {
343                    return Err(PgError::Protocol(format!(
344                        "Keepalive reply_requested must be 0 or 1, got {}",
345                        other
346                    )));
347                }
348            };
349            Ok(ReplicationStreamMessage::Keepalive(ReplicationKeepalive {
350                wal_end,
351                server_time_micros,
352                reply_requested,
353            }))
354        }
355        tag => Err(PgError::Protocol(format!(
356            "Unsupported replication CopyData tag '{}'",
357            if tag.is_ascii_graphic() {
358                tag as char
359            } else {
360                '?'
361            }
362        ))),
363    }
364}
365
/// Current wall-clock time as microseconds since the PostgreSQL epoch
/// (2000-01-01 00:00:00 UTC).
///
/// A system clock set before the Unix epoch is treated as being exactly at
/// the Unix epoch rather than failing.
fn postgres_epoch_micros_now() -> i64 {
    const MICROS_PER_SEC: i64 = 1_000_000;
    // Seconds between 1970-01-01 and 2000-01-01.
    const PG_EPOCH_UNIX_SECS: i64 = 946_684_800;

    let since_unix = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    let whole_secs = since_unix.as_secs() as i64;
    let sub_micros = i64::from(since_unix.subsec_micros());

    (whole_secs - PG_EPOCH_UNIX_SECS) * MICROS_PER_SEC + sub_micros
}
374
375fn build_standby_status_update_payload(
376    write_lsn: u64,
377    flush_lsn: u64,
378    apply_lsn: u64,
379    reply_requested: bool,
380) -> Vec<u8> {
381    let mut payload = Vec::with_capacity(1 + 8 + 8 + 8 + 8 + 1);
382    payload.push(b'r');
383    payload.extend_from_slice(&write_lsn.to_be_bytes());
384    payload.extend_from_slice(&flush_lsn.to_be_bytes());
385    payload.extend_from_slice(&apply_lsn.to_be_bytes());
386    payload.extend_from_slice(&postgres_epoch_micros_now().to_be_bytes());
387    payload.push(if reply_requested { 1 } else { 0 });
388    payload
389}
390
391#[inline]
392fn return_with_desync<T>(conn: &mut PgConnection, err: PgError) -> PgResult<T> {
393    if matches!(
394        err,
395        PgError::Protocol(_) | PgError::Connection(_) | PgError::Timeout(_)
396    ) {
397        conn.mark_io_desynced();
398    }
399    Err(err)
400}
401
402impl PgConnection {
403    #[inline]
404    fn ensure_replication_mode(&self, operation: &str) -> PgResult<()> {
405        if self.replication_mode_enabled {
406            return Ok(());
407        }
408        Err(PgError::Protocol(format!(
409            "{} requires connection startup param replication=database",
410            operation
411        )))
412    }
413
414    #[inline]
415    fn ensure_replication_control_idle(&self, operation: &str) -> PgResult<()> {
416        if !self.replication_stream_active {
417            return Ok(());
418        }
419        Err(PgError::Protocol(format!(
420            "{} cannot run while replication stream is active",
421            operation
422        )))
423    }
424
425    #[inline]
426    fn advance_replication_wal_end(&mut self, source: &str, wal_end: u64) -> PgResult<()> {
427        if let Some(prev_wal_end) = self.last_replication_wal_end
428            && wal_end < prev_wal_end
429        {
430            self.replication_stream_active = false;
431            self.last_replication_wal_end = None;
432            return return_with_desync(
433                self,
434                PgError::Protocol(format!(
435                    "Replication {} wal_end regressed: previous {}, current {}",
436                    source, prev_wal_end, wal_end
437                )),
438            );
439        }
440        self.last_replication_wal_end = Some(wal_end);
441        Ok(())
442    }
443
444    /// Run `IDENTIFY_SYSTEM` on a replication connection.
445    pub async fn identify_system(&mut self) -> PgResult<IdentifySystem> {
446        self.ensure_replication_mode("IDENTIFY_SYSTEM")?;
447        self.ensure_replication_control_idle("IDENTIFY_SYSTEM")?;
448        let rows = self.simple_query("IDENTIFY_SYSTEM").await?;
449        let row = rows
450            .first()
451            .ok_or_else(|| PgError::Protocol("IDENTIFY_SYSTEM returned no rows".to_string()))?;
452        parse_identify_system_row(row)
453    }
454
455    /// Create a logical replication slot.
456    ///
457    /// `slot_name` and `output_plugin` are strict SQL identifiers.
458    pub async fn create_logical_replication_slot(
459        &mut self,
460        slot_name: &str,
461        output_plugin: &str,
462        temporary: bool,
463        two_phase: bool,
464    ) -> PgResult<ReplicationSlotInfo> {
465        self.ensure_replication_mode("CREATE_REPLICATION_SLOT")?;
466        self.ensure_replication_control_idle("CREATE_REPLICATION_SLOT")?;
467        let sql = build_create_logical_replication_slot_sql(
468            slot_name,
469            output_plugin,
470            temporary,
471            two_phase,
472        )?;
473        let rows = self.simple_query(&sql).await?;
474        let row = rows.first().ok_or_else(|| {
475            PgError::Protocol("CREATE_REPLICATION_SLOT returned no rows".to_string())
476        })?;
477        parse_create_slot_row(row)
478    }
479
480    /// Drop a replication slot.
481    ///
482    /// `wait=true` uses `DROP_REPLICATION_SLOT <slot> WAIT`.
483    pub async fn drop_replication_slot(&mut self, slot_name: &str, wait: bool) -> PgResult<()> {
484        self.ensure_replication_mode("DROP_REPLICATION_SLOT")?;
485        self.ensure_replication_control_idle("DROP_REPLICATION_SLOT")?;
486        let sql = build_drop_replication_slot_sql(slot_name, wait)?;
487        self.execute_simple(&sql).await
488    }
489
490    /// Start logical replication in CopyBoth mode.
491    ///
492    /// Requires a connection started with `replication=database`.
493    pub async fn start_logical_replication(
494        &mut self,
495        slot_name: &str,
496        start_lsn: &str,
497        options: &[ReplicationOption],
498    ) -> PgResult<ReplicationStreamStart> {
499        self.ensure_replication_mode("START_REPLICATION")?;
500        if self.replication_stream_active {
501            return Err(PgError::Protocol(
502                "logical replication stream already active".to_string(),
503            ));
504        }
505        let sql = build_start_logical_replication_sql(slot_name, start_lsn, options)?;
506        let bytes = PgEncoder::try_encode_query_string(&sql)?;
507        self.write_all_with_timeout(&bytes, "stream write").await?;
508
509        let mut startup_error: Option<PgError> = None;
510        loop {
511            let msg = self.recv().await?;
512            match msg {
513                BackendMessage::CopyBothResponse {
514                    format,
515                    column_formats,
516                } => {
517                    if let Some(err) = startup_error {
518                        return return_with_desync(self, err);
519                    }
520                    if format != 0 {
521                        return return_with_desync(
522                            self,
523                            PgError::Protocol(format!(
524                                "START_REPLICATION returned unsupported CopyBothResponse format {} (expected 0/text)",
525                                format
526                            )),
527                        );
528                    }
529                    if !column_formats.is_empty() {
530                        return return_with_desync(
531                            self,
532                            PgError::Protocol(format!(
533                                "START_REPLICATION returned unexpected CopyBothResponse column formats (expected none, got {})",
534                                column_formats.len()
535                            )),
536                        );
537                    }
538                    self.replication_stream_active = true;
539                    self.last_replication_wal_end = None;
540                    return Ok(ReplicationStreamStart {
541                        format,
542                        column_formats,
543                    });
544                }
545                BackendMessage::ReadyForQuery(_) => {
546                    return return_with_desync(
547                        self,
548                        startup_error.unwrap_or_else(|| {
549                            PgError::Protocol(
550                                "START_REPLICATION ended before CopyBothResponse".to_string(),
551                            )
552                        }),
553                    );
554                }
555                BackendMessage::ErrorResponse(err) => {
556                    if startup_error.is_none() {
557                        startup_error = Some(PgError::QueryServer(err.into()));
558                    }
559                }
560                msg if is_ignorable_session_message(&msg) => {}
561                other => {
562                    return return_with_desync(
563                        self,
564                        unexpected_backend_message("start replication", &other),
565                    );
566                }
567            }
568        }
569    }
570
571    /// Receive the next logical replication stream message.
572    ///
573    /// Uses a no-timeout read path so idle periods do not fail the stream.
574    pub async fn recv_replication_message(&mut self) -> PgResult<ReplicationStreamMessage> {
575        self.ensure_replication_mode("recv_replication_message")?;
576        if !self.replication_stream_active {
577            return Err(PgError::Protocol(
578                "replication stream is not active; call START_REPLICATION first".to_string(),
579            ));
580        }
581        loop {
582            let msg = self.recv_without_timeout().await?;
583            match msg {
584                BackendMessage::CopyData(payload) => match parse_copy_data_message(&payload) {
585                    Ok(ReplicationStreamMessage::XLogData(x)) => {
586                        self.advance_replication_wal_end("XLogData", x.wal_end)?;
587                        return Ok(ReplicationStreamMessage::XLogData(x));
588                    }
589                    Ok(ReplicationStreamMessage::Keepalive(k)) => {
590                        self.advance_replication_wal_end("keepalive", k.wal_end)?;
591                        return Ok(ReplicationStreamMessage::Keepalive(k));
592                    }
593                    Ok(parsed) => return Ok(parsed),
594                    Err(err) => {
595                        self.replication_stream_active = false;
596                        self.last_replication_wal_end = None;
597                        return return_with_desync(self, err);
598                    }
599                },
600                BackendMessage::ErrorResponse(err) => {
601                    self.replication_stream_active = false;
602                    self.last_replication_wal_end = None;
603                    return Err(PgError::QueryServer(err.into()));
604                }
605                BackendMessage::CopyDone => {
606                    self.replication_stream_active = false;
607                    self.last_replication_wal_end = None;
608                    return return_with_desync(
609                        self,
610                        PgError::Protocol("Replication stream ended with CopyDone".to_string()),
611                    );
612                }
613                BackendMessage::ReadyForQuery(_) => {
614                    self.replication_stream_active = false;
615                    self.last_replication_wal_end = None;
616                    return return_with_desync(
617                        self,
618                        PgError::Protocol(
619                            "Replication stream ended with ReadyForQuery".to_string(),
620                        ),
621                    );
622                }
623                msg if is_ignorable_session_message(&msg) => {}
624                other => {
625                    self.replication_stream_active = false;
626                    self.last_replication_wal_end = None;
627                    return return_with_desync(
628                        self,
629                        unexpected_backend_message("replication stream", &other),
630                    );
631                }
632            }
633        }
634    }
635
636    /// Send a standby status update (`CopyData('r' ...)`) to the server.
637    pub async fn send_standby_status_update(
638        &mut self,
639        write_lsn: u64,
640        flush_lsn: u64,
641        apply_lsn: u64,
642        reply_requested: bool,
643    ) -> PgResult<()> {
644        self.ensure_replication_mode("send_standby_status_update")?;
645        if !self.replication_stream_active {
646            return Err(PgError::Protocol(
647                "replication stream is not active; call START_REPLICATION first".to_string(),
648            ));
649        }
650        if flush_lsn > write_lsn {
651            return Err(PgError::Protocol(format!(
652                "Invalid standby status update: flush_lsn {} exceeds write_lsn {}",
653                flush_lsn, write_lsn
654            )));
655        }
656        if apply_lsn > flush_lsn {
657            return Err(PgError::Protocol(format!(
658                "Invalid standby status update: apply_lsn {} exceeds flush_lsn {}",
659                apply_lsn, flush_lsn
660            )));
661        }
662        if let Some(last_wal_end) = self.last_replication_wal_end
663            && write_lsn > last_wal_end
664        {
665            return Err(PgError::Protocol(format!(
666                "Invalid standby status update: write_lsn {} exceeds last seen server wal_end {}",
667                write_lsn, last_wal_end
668            )));
669        }
670        let payload =
671            build_standby_status_update_payload(write_lsn, flush_lsn, apply_lsn, reply_requested);
672        self.send_copy_data(&payload).await
673    }
674}
675
676impl PgDriver {
677    /// Driver wrapper for [`PgConnection::identify_system`].
678    pub async fn identify_system(&mut self) -> PgResult<IdentifySystem> {
679        self.connection.identify_system().await
680    }
681
682    /// Driver wrapper for [`PgConnection::create_logical_replication_slot`].
683    pub async fn create_logical_replication_slot(
684        &mut self,
685        slot_name: &str,
686        output_plugin: &str,
687        temporary: bool,
688        two_phase: bool,
689    ) -> PgResult<ReplicationSlotInfo> {
690        self.connection
691            .create_logical_replication_slot(slot_name, output_plugin, temporary, two_phase)
692            .await
693    }
694
695    /// Driver wrapper for [`PgConnection::drop_replication_slot`].
696    pub async fn drop_replication_slot(&mut self, slot_name: &str, wait: bool) -> PgResult<()> {
697        self.connection.drop_replication_slot(slot_name, wait).await
698    }
699
700    /// Driver wrapper for [`PgConnection::start_logical_replication`].
701    pub async fn start_logical_replication(
702        &mut self,
703        slot_name: &str,
704        start_lsn: &str,
705        options: &[ReplicationOption],
706    ) -> PgResult<ReplicationStreamStart> {
707        self.connection
708            .start_logical_replication(slot_name, start_lsn, options)
709            .await
710    }
711
712    /// Driver wrapper for [`PgConnection::recv_replication_message`].
713    pub async fn recv_replication_message(&mut self) -> PgResult<ReplicationStreamMessage> {
714        self.connection.recv_replication_message().await
715    }
716
717    /// Driver wrapper for [`PgConnection::send_standby_status_update`].
718    pub async fn send_standby_status_update(
719        &mut self,
720        write_lsn: u64,
721        flush_lsn: u64,
722        apply_lsn: u64,
723        reply_requested: bool,
724    ) -> PgResult<()> {
725        self.connection
726            .send_standby_status_update(write_lsn, flush_lsn, apply_lsn, reply_requested)
727            .await
728    }
729}
730
731#[cfg(test)]
732mod tests {
733    use super::*;
734
735    #[cfg(unix)]
736    fn test_conn() -> PgConnection {
737        use crate::driver::connection::StatementCache;
738        use crate::driver::stream::PgStream;
739        use bytes::BytesMut;
740        use std::collections::{HashMap, VecDeque};
741        use std::num::NonZeroUsize;
742        use tokio::net::UnixStream;
743
744        let (unix_stream, _peer) = UnixStream::pair().expect("unix stream pair");
745        PgConnection {
746            stream: PgStream::Unix(unix_stream),
747            buffer: BytesMut::with_capacity(1024),
748            write_buf: BytesMut::with_capacity(1024),
749            sql_buf: BytesMut::with_capacity(256),
750            params_buf: Vec::new(),
751            prepared_statements: HashMap::new(),
752            stmt_cache: StatementCache::new(NonZeroUsize::new(2).expect("non-zero")),
753            column_info_cache: HashMap::new(),
754            process_id: 0,
755            cancel_key_bytes: Vec::new(),
756            requested_protocol_minor: PgConnection::default_protocol_minor(),
757            negotiated_protocol_minor: PgConnection::default_protocol_minor(),
758            notifications: VecDeque::new(),
759            replication_stream_active: false,
760            replication_mode_enabled: true,
761            last_replication_wal_end: None,
762            io_desynced: false,
763            pending_statement_closes: Vec::new(),
764            draining_statement_closes: false,
765        }
766    }
767
768    fn text_row(values: &[Option<&str>]) -> PgRow {
769        PgRow {
770            columns: values
771                .iter()
772                .map(|v| v.map(|s| s.as_bytes().to_vec()))
773                .collect(),
774            column_info: None,
775        }
776    }
777
778    #[cfg(unix)]
779    #[tokio::test]
780    async fn replication_wal_regression_marks_connection_desynced() {
781        let mut conn = test_conn();
782        conn.replication_stream_active = true;
783        conn.last_replication_wal_end = Some(20);
784
785        let err = conn
786            .advance_replication_wal_end("XLogData", 10)
787            .expect_err("wal regression must fail");
788
789        assert!(err.to_string().contains("wal_end regressed"));
790        assert!(!conn.replication_stream_active);
791        assert!(conn.last_replication_wal_end.is_none());
792        assert!(conn.is_io_desynced());
793    }
794
795    #[test]
796    fn validate_ident_rejects_bad_names() {
797        assert!(validate_ident("slot_name", "").is_err());
798        assert!(validate_ident("slot_name", "9slot").is_err());
799        assert!(validate_ident("slot_name", "bad-name").is_err());
800        assert!(validate_ident("slot_name", "has space").is_err());
801    }
802
803    #[test]
804    fn validate_ident_accepts_safe_names() {
805        assert!(validate_ident("slot_name", "slot_a1").is_ok());
806        assert!(validate_ident("output_plugin", "pgoutput").is_ok());
807    }
808
809    #[test]
810    fn parse_and_format_lsn_roundtrip() {
811        let lsn = parse_lsn_text("16/B6C50").unwrap();
812        assert_eq!(format_lsn(lsn), "16/000B6C50");
813    }
814
815    #[test]
816    fn build_create_logical_replication_slot_sql_variants() {
817        let sql =
818            build_create_logical_replication_slot_sql("slot_main", "pgoutput", true, true).unwrap();
819        assert_eq!(
820            sql,
821            "CREATE_REPLICATION_SLOT slot_main TEMPORARY LOGICAL pgoutput TWO_PHASE"
822        );
823    }
824
825    #[test]
826    fn build_drop_replication_slot_sql_variants() {
827        let sql = build_drop_replication_slot_sql("slot_main", true).unwrap();
828        assert_eq!(sql, "DROP_REPLICATION_SLOT slot_main WAIT");
829    }
830
831    #[test]
832    fn build_start_logical_replication_sql_with_options() {
833        let sql = build_start_logical_replication_sql(
834            "slot_main",
835            "0/16B6C50",
836            &[
837                ReplicationOption {
838                    key: "proto_version".to_string(),
839                    value: "1".to_string(),
840                },
841                ReplicationOption {
842                    key: "publication_names".to_string(),
843                    value: "pub1,pub2".to_string(),
844                },
845            ],
846        )
847        .unwrap();
848        assert_eq!(
849            sql,
850            "START_REPLICATION SLOT slot_main LOGICAL 0/16B6C50 (proto_version '1', publication_names 'pub1,pub2')"
851        );
852    }
853
854    #[test]
855    fn build_start_logical_replication_sql_rejects_too_many_options() {
856        let options: Vec<ReplicationOption> = (0..=MAX_REPLICATION_OPTIONS)
857            .map(|i| ReplicationOption {
858                key: format!("opt{}", i),
859                value: "x".to_string(),
860            })
861            .collect();
862
863        let err =
864            build_start_logical_replication_sql("slot_main", "0/16B6C50", &options).unwrap_err();
865        assert!(err.to_string().contains("too many replication options"));
866    }
867
868    #[test]
869    fn build_start_logical_replication_sql_rejects_null_value() {
870        let options = vec![ReplicationOption {
871            key: "proto_version".to_string(),
872            value: "1\0oops".to_string(),
873        }];
874        let err =
875            build_start_logical_replication_sql("slot_main", "0/16B6C50", &options).unwrap_err();
876        assert!(err.to_string().contains("contains NUL byte"));
877    }
878
879    #[test]
880    fn build_start_logical_replication_sql_rejects_oversized_value() {
881        let options = vec![ReplicationOption {
882            key: "publication_names".to_string(),
883            value: "x".repeat(MAX_REPLICATION_OPTION_VALUE_BYTES + 1),
884        }];
885        let err =
886            build_start_logical_replication_sql("slot_main", "0/16B6C50", &options).unwrap_err();
887        assert!(err.to_string().contains("value too large"));
888    }
889
890    #[test]
891    fn parse_identify_system_row_happy_path() {
892        let row = text_row(&[
893            Some("7416469842679442267"),
894            Some("1"),
895            Some("0/16B6C50"),
896            Some("app"),
897        ]);
898        let parsed = parse_identify_system_row(&row).unwrap();
899        assert_eq!(parsed.system_id, "7416469842679442267");
900        assert_eq!(parsed.timeline, 1);
901        assert_eq!(parsed.xlog_pos, "0/16B6C50");
902        assert_eq!(parsed.dbname.as_deref(), Some("app"));
903    }
904
905    #[test]
906    fn parse_create_slot_row_happy_path() {
907        let row = text_row(&[
908            Some("slot_main"),
909            Some("0/16B6C88"),
910            Some("00000003-00000041-1"),
911            Some("pgoutput"),
912        ]);
913        let parsed = parse_create_slot_row(&row).unwrap();
914        assert_eq!(parsed.slot_name, "slot_main");
915        assert_eq!(parsed.consistent_point, "0/16B6C88");
916        assert_eq!(parsed.snapshot_name.as_deref(), Some("00000003-00000041-1"));
917        assert_eq!(parsed.output_plugin, "pgoutput");
918    }
919
920    #[test]
921    fn parse_copy_data_xlog_data() {
922        let mut payload = Vec::new();
923        payload.push(b'w');
924        payload.extend_from_slice(&0x10_u64.to_be_bytes());
925        payload.extend_from_slice(&0x20_u64.to_be_bytes());
926        payload.extend_from_slice(&123_i64.to_be_bytes());
927        payload.extend_from_slice(b"hello");
928
929        match parse_copy_data_message(&payload).unwrap() {
930            ReplicationStreamMessage::XLogData(x) => {
931                assert_eq!(x.wal_start, 0x10);
932                assert_eq!(x.wal_end, 0x20);
933                assert_eq!(x.server_time_micros, 123);
934                assert_eq!(x.data, b"hello");
935            }
936            _ => panic!("expected xlog data"),
937        }
938    }
939
940    #[test]
941    fn parse_copy_data_xlog_data_rejects_wal_end_before_start() {
942        let mut payload = Vec::new();
943        payload.push(b'w');
944        payload.extend_from_slice(&0x20_u64.to_be_bytes());
945        payload.extend_from_slice(&0x10_u64.to_be_bytes());
946        payload.extend_from_slice(&123_i64.to_be_bytes());
947        let err = parse_copy_data_message(&payload).unwrap_err();
948        assert!(err.to_string().contains("wal_end"));
949    }
950
951    #[test]
952    fn parse_copy_data_xlog_data_rejects_oversized_data() {
953        let mut payload = Vec::with_capacity(25 + MAX_REPLICATION_XLOGDATA_BYTES + 1);
954        payload.push(b'w');
955        payload.extend_from_slice(&0x10_u64.to_be_bytes());
956        payload.extend_from_slice(&0x20_u64.to_be_bytes());
957        payload.extend_from_slice(&123_i64.to_be_bytes());
958        payload.extend(std::iter::repeat_n(0u8, MAX_REPLICATION_XLOGDATA_BYTES + 1));
959        let err = parse_copy_data_message(&payload).unwrap_err();
960        assert!(err.to_string().contains("payload too large"));
961    }
962
963    #[test]
964    fn parse_copy_data_keepalive() {
965        let mut payload = Vec::new();
966        payload.push(b'k');
967        payload.extend_from_slice(&0xAB_u64.to_be_bytes());
968        payload.extend_from_slice(&456_i64.to_be_bytes());
969        payload.push(1);
970
971        match parse_copy_data_message(&payload).unwrap() {
972            ReplicationStreamMessage::Keepalive(k) => {
973                assert_eq!(k.wal_end, 0xAB);
974                assert_eq!(k.server_time_micros, 456);
975                assert!(k.reply_requested);
976            }
977            _ => panic!("expected keepalive"),
978        }
979    }
980
981    #[test]
982    fn parse_copy_data_keepalive_rejects_invalid_reply_requested() {
983        let mut payload = Vec::new();
984        payload.push(b'k');
985        payload.extend_from_slice(&0xAB_u64.to_be_bytes());
986        payload.extend_from_slice(&456_i64.to_be_bytes());
987        payload.push(2);
988        let err = parse_copy_data_message(&payload).unwrap_err();
989        assert!(err.to_string().contains("reply_requested must be 0 or 1"));
990    }
991
992    #[test]
993    fn parse_copy_data_unknown_tag_rejected() {
994        let payload = vec![b'x', 1, 2, 3];
995        let err = parse_copy_data_message(&payload).unwrap_err();
996        assert!(
997            err.to_string()
998                .contains("Unsupported replication CopyData tag")
999        );
1000    }
1001
1002    #[test]
1003    fn build_standby_status_update_payload_layout() {
1004        let payload = build_standby_status_update_payload(1, 2, 3, true);
1005        assert_eq!(payload.len(), 34);
1006        assert_eq!(payload[0], b'r');
1007        assert_eq!(u64::from_be_bytes(payload[1..9].try_into().unwrap()), 1);
1008        assert_eq!(u64::from_be_bytes(payload[9..17].try_into().unwrap()), 2);
1009        assert_eq!(u64::from_be_bytes(payload[17..25].try_into().unwrap()), 3);
1010        assert_eq!(payload[33], 1);
1011    }
1012}