Skip to main content

qail_pg/driver/
replication.rs

1//! Replication helpers.
2//!
3//! Current scope:
4//! - `IDENTIFY_SYSTEM`
5//! - `CREATE_REPLICATION_SLOT ... LOGICAL ...`
6//! - `DROP_REPLICATION_SLOT`
7//! - `START_REPLICATION SLOT ... LOGICAL ...`
8//! - CopyBoth stream message decode (`XLogData`, keepalive)
9//! - Standby status updates (`'r'`)
10
11use super::{
12    PgConnection, PgDriver, PgError, PgResult, PgRow, is_ignorable_session_message,
13    unexpected_backend_message,
14};
15use crate::protocol::{BackendMessage, PgEncoder};
16
/// Parsed result row of the `IDENTIFY_SYSTEM` replication command.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct IdentifySystem {
    /// System identifier of the cluster, kept as the text the server sent.
    pub system_id: String,
    /// Timeline ID the server is currently on.
    pub timeline: u32,
    /// Current WAL position in textual LSN form (e.g. `0/16B6C50`).
    pub xlog_pos: String,
    /// Database name, when the server reports a non-empty one
    /// (logical replication sessions).
    pub dbname: Option<String>,
}
29
/// Parsed result row of `CREATE_REPLICATION_SLOT ... LOGICAL ...`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReplicationSlotInfo {
    /// Name of the slot that was created.
    pub slot_name: String,
    /// Consistent point (textual LSN) at which the slot became valid.
    pub consistent_point: String,
    /// Name of the exported snapshot, when the server returned a non-empty one.
    pub snapshot_name: Option<String>,
    /// Output plugin backing this logical slot.
    pub output_plugin: String,
}
42
/// Metadata taken from the `CopyBothResponse` that opens a logical
/// replication stream.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReplicationStreamStart {
    /// Overall copy format code (0 = text, 1 = binary).
    pub format: u8,
    /// Format code reported for each column.
    pub column_formats: Vec<u8>,
}
51
/// Decoded `CopyData('w' ...)` frame carrying WAL payload.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReplicationXLogData {
    /// WAL position at which this payload starts.
    pub wal_start: u64,
    /// WAL end position reported by the server.
    pub wal_end: u64,
    /// Server send time, in microseconds since the PostgreSQL epoch (2000-01-01).
    pub server_time_micros: i64,
    /// Raw bytes produced by the output plugin.
    pub data: Vec<u8>,
}
64
/// Decoded `CopyData('k' ...)` primary keepalive frame.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReplicationKeepalive {
    /// WAL end position reported by the server.
    pub wal_end: u64,
    /// Server send time, in microseconds since the PostgreSQL epoch (2000-01-01).
    pub server_time_micros: i64,
    /// `true` when the server asks for an immediate status reply.
    pub reply_requested: bool,
}
75
76/// Logical replication stream message decoded from `CopyData`.
77#[derive(Debug, Clone, PartialEq, Eq)]
78pub enum ReplicationStreamMessage {
79    /// XLog data frame (`'w'`).
80    XLogData(ReplicationXLogData),
81    /// Primary keepalive frame (`'k'`).
82    Keepalive(ReplicationKeepalive),
83    /// Unknown CopyData sub-message tag.
84    Raw { tag: u8, payload: Vec<u8> },
85}
86
/// One `key 'value'` plugin option for
/// `START_REPLICATION ... LOGICAL ... (k 'v', ...)`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReplicationOption {
    /// Option name; must be a strict identifier.
    pub key: String,
    /// Option value; emitted as a single-quoted SQL string in the command text.
    pub value: String,
}
95
/// Upper bound on the number of plugin options accepted per
/// `START_REPLICATION` command.
const MAX_REPLICATION_OPTIONS: usize = 64;
/// Upper bound on the byte length of a single plugin option value (16 KiB).
const MAX_REPLICATION_OPTION_VALUE_BYTES: usize = 16 * 1024;
/// Upper bound on the plugin payload carried by one XLogData frame (16 MiB).
const MAX_REPLICATION_XLOGDATA_BYTES: usize = 16 * 1024 * 1024;
99
100fn validate_ident(kind: &str, ident: &str) -> PgResult<()> {
101    if ident.is_empty() {
102        return Err(PgError::Query(format!("{} must not be empty", kind)));
103    }
104    if ident.len() > 63 {
105        return Err(PgError::Query(format!(
106            "{} '{}' exceeds PostgreSQL identifier limit (63 bytes)",
107            kind, ident
108        )));
109    }
110    let mut chars = ident.chars();
111    match chars.next() {
112        Some(c) if c == '_' || c.is_ascii_alphabetic() => {}
113        _ => {
114            return Err(PgError::Query(format!(
115                "{} '{}' must start with [A-Za-z_]",
116                kind, ident
117            )));
118        }
119    }
120    if !chars.all(|c| c == '_' || c.is_ascii_alphanumeric()) {
121        return Err(PgError::Query(format!(
122            "{} '{}' contains unsupported characters (allowed: [A-Za-z0-9_])",
123            kind, ident
124        )));
125    }
126    Ok(())
127}
128
129fn sql_single_quote_literal(value: &str) -> PgResult<String> {
130    if value.contains('\0') {
131        return Err(PgError::Query(
132            "replication option value contains NUL byte".to_string(),
133        ));
134    }
135    Ok(value.replace('\'', "''"))
136}
137
138fn parse_lsn_text(lsn: &str) -> PgResult<u64> {
139    let mut parts = lsn.split('/');
140    let high = parts
141        .next()
142        .ok_or_else(|| PgError::Query(format!("Invalid LSN '{}'", lsn)))?;
143    let low = parts
144        .next()
145        .ok_or_else(|| PgError::Query(format!("Invalid LSN '{}'", lsn)))?;
146    if parts.next().is_some() {
147        return Err(PgError::Query(format!("Invalid LSN '{}'", lsn)));
148    }
149    let high = u32::from_str_radix(high, 16)
150        .map_err(|_| PgError::Query(format!("Invalid LSN '{}'", lsn)))?;
151    let low = u32::from_str_radix(low, 16)
152        .map_err(|_| PgError::Query(format!("Invalid LSN '{}'", lsn)))?;
153    Ok(((high as u64) << 32) | (low as u64))
154}
155
/// Renders a 64-bit LSN as `HIGH/LOW` in uppercase hex, with the low half
/// zero-padded to eight digits (test-only helper).
#[cfg(test)]
fn format_lsn(lsn: u64) -> String {
    let high = (lsn >> 32) as u32;
    // Truncating cast keeps exactly the low 32 bits.
    let low = lsn as u32;
    format!("{:X}/{:08X}", high, low)
}
160
161fn required_text(row: &PgRow, idx: usize, field: &str) -> PgResult<String> {
162    row.get_string(idx).ok_or_else(|| {
163        PgError::Protocol(format!("Missing or invalid '{}' in replication row", field))
164    })
165}
166
167fn parse_identify_system_row(row: &PgRow) -> PgResult<IdentifySystem> {
168    let system_id = required_text(row, 0, "systemid")?;
169    let timeline = required_text(row, 1, "timeline")?
170        .parse::<u32>()
171        .map_err(|e| PgError::Protocol(format!("Invalid timeline value: {}", e)))?;
172    let xlog_pos = required_text(row, 2, "xlogpos")?;
173    let dbname = row
174        .get_string(3)
175        .and_then(|v| if v.is_empty() { None } else { Some(v) });
176
177    Ok(IdentifySystem {
178        system_id,
179        timeline,
180        xlog_pos,
181        dbname,
182    })
183}
184
185fn parse_create_slot_row(row: &PgRow) -> PgResult<ReplicationSlotInfo> {
186    let slot_name = required_text(row, 0, "slot_name")?;
187    let consistent_point = required_text(row, 1, "consistent_point")?;
188    let snapshot_name = row
189        .get_string(2)
190        .and_then(|v| if v.is_empty() { None } else { Some(v) });
191    let output_plugin = required_text(row, 3, "output_plugin")?;
192
193    Ok(ReplicationSlotInfo {
194        slot_name,
195        consistent_point,
196        snapshot_name,
197        output_plugin,
198    })
199}
200
201fn build_create_logical_replication_slot_sql(
202    slot_name: &str,
203    output_plugin: &str,
204    temporary: bool,
205    two_phase: bool,
206) -> PgResult<String> {
207    validate_ident("slot_name", slot_name)?;
208    validate_ident("output_plugin", output_plugin)?;
209
210    let mut sql = String::from("CREATE_REPLICATION_SLOT ");
211    sql.push_str(slot_name);
212    if temporary {
213        sql.push_str(" TEMPORARY");
214    }
215    sql.push_str(" LOGICAL ");
216    sql.push_str(output_plugin);
217    if two_phase {
218        sql.push_str(" TWO_PHASE");
219    }
220    Ok(sql)
221}
222
223fn build_drop_replication_slot_sql(slot_name: &str, wait: bool) -> PgResult<String> {
224    validate_ident("slot_name", slot_name)?;
225    let mut sql = String::from("DROP_REPLICATION_SLOT ");
226    sql.push_str(slot_name);
227    if wait {
228        sql.push_str(" WAIT");
229    }
230    Ok(sql)
231}
232
233fn build_start_logical_replication_sql(
234    slot_name: &str,
235    start_lsn: &str,
236    options: &[ReplicationOption],
237) -> PgResult<String> {
238    validate_ident("slot_name", slot_name)?;
239    let _ = parse_lsn_text(start_lsn)?;
240    if options.len() > MAX_REPLICATION_OPTIONS {
241        return Err(PgError::Query(format!(
242            "too many replication options: {} (max {})",
243            options.len(),
244            MAX_REPLICATION_OPTIONS
245        )));
246    }
247
248    let mut sql = format!("START_REPLICATION SLOT {} LOGICAL {}", slot_name, start_lsn);
249    if !options.is_empty() {
250        sql.push_str(" (");
251        for (idx, opt) in options.iter().enumerate() {
252            validate_ident("replication option key", &opt.key)?;
253            if opt.value.len() > MAX_REPLICATION_OPTION_VALUE_BYTES {
254                return Err(PgError::Query(format!(
255                    "replication option '{}' value too large: {} bytes (max {})",
256                    opt.key,
257                    opt.value.len(),
258                    MAX_REPLICATION_OPTION_VALUE_BYTES
259                )));
260            }
261            if idx > 0 {
262                sql.push_str(", ");
263            }
264            sql.push_str(&opt.key);
265            sql.push_str(" '");
266            sql.push_str(&sql_single_quote_literal(&opt.value)?);
267            sql.push('\'');
268        }
269        sql.push(')');
270    }
271    Ok(sql)
272}
273
274fn parse_copy_data_message(payload: &[u8]) -> PgResult<ReplicationStreamMessage> {
275    if payload.is_empty() {
276        return Err(PgError::Protocol(
277            "Replication CopyData payload is empty".to_string(),
278        ));
279    }
280    match payload[0] {
281        b'w' => {
282            if payload.len() < 25 {
283                return Err(PgError::Protocol(format!(
284                    "XLogData payload too short: {} bytes",
285                    payload.len()
286                )));
287            }
288            let wal_start = u64::from_be_bytes(
289                payload[1..9]
290                    .try_into()
291                    .map_err(|_| PgError::Protocol("Invalid wal_start bytes".to_string()))?,
292            );
293            let wal_end = u64::from_be_bytes(
294                payload[9..17]
295                    .try_into()
296                    .map_err(|_| PgError::Protocol("Invalid wal_end bytes".to_string()))?,
297            );
298            let server_time_micros = i64::from_be_bytes(
299                payload[17..25]
300                    .try_into()
301                    .map_err(|_| PgError::Protocol("Invalid server time bytes".to_string()))?,
302            );
303            if wal_end < wal_start {
304                return Err(PgError::Protocol(format!(
305                    "XLogData wal_end {} is behind wal_start {}",
306                    wal_end, wal_start
307                )));
308            }
309            let data_len = payload.len() - 25;
310            if data_len > MAX_REPLICATION_XLOGDATA_BYTES {
311                return Err(PgError::Protocol(format!(
312                    "XLogData payload too large: {} bytes (max {})",
313                    data_len, MAX_REPLICATION_XLOGDATA_BYTES
314                )));
315            }
316            Ok(ReplicationStreamMessage::XLogData(ReplicationXLogData {
317                wal_start,
318                wal_end,
319                server_time_micros,
320                data: payload[25..].to_vec(),
321            }))
322        }
323        b'k' => {
324            if payload.len() != 18 {
325                return Err(PgError::Protocol(format!(
326                    "Keepalive payload must be 18 bytes, got {}",
327                    payload.len()
328                )));
329            }
330            let wal_end =
331                u64::from_be_bytes(payload[1..9].try_into().map_err(|_| {
332                    PgError::Protocol("Invalid keepalive wal_end bytes".to_string())
333                })?);
334            let server_time_micros = i64::from_be_bytes(
335                payload[9..17]
336                    .try_into()
337                    .map_err(|_| PgError::Protocol("Invalid keepalive time bytes".to_string()))?,
338            );
339            let reply_requested = match payload[17] {
340                0 => false,
341                1 => true,
342                other => {
343                    return Err(PgError::Protocol(format!(
344                        "Keepalive reply_requested must be 0 or 1, got {}",
345                        other
346                    )));
347                }
348            };
349            Ok(ReplicationStreamMessage::Keepalive(ReplicationKeepalive {
350                wal_end,
351                server_time_micros,
352                reply_requested,
353            }))
354        }
355        tag => Err(PgError::Protocol(format!(
356            "Unsupported replication CopyData tag '{}'",
357            if tag.is_ascii_graphic() {
358                tag as char
359            } else {
360                '?'
361            }
362        ))),
363    }
364}
365
/// Current wall-clock time as microseconds since the PostgreSQL epoch
/// (2000-01-01).
///
/// If the system clock reports a time before the Unix epoch, the elapsed
/// duration is treated as zero.
fn postgres_epoch_micros_now() -> i64 {
    const MICROS_PER_SEC: i64 = 1_000_000;
    // Seconds between 1970-01-01 and 2000-01-01.
    const PG_UNIX_EPOCH_DIFF_SECS: i64 = 946_684_800;

    let since_unix = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => std::time::Duration::default(),
    };
    let whole_secs = since_unix.as_secs() as i64;
    let frac_micros = i64::from(since_unix.subsec_micros());
    whole_secs * MICROS_PER_SEC + frac_micros - PG_UNIX_EPOCH_DIFF_SECS * MICROS_PER_SEC
}
374
375fn build_standby_status_update_payload(
376    write_lsn: u64,
377    flush_lsn: u64,
378    apply_lsn: u64,
379    reply_requested: bool,
380) -> Vec<u8> {
381    let mut payload = Vec::with_capacity(1 + 8 + 8 + 8 + 8 + 1);
382    payload.push(b'r');
383    payload.extend_from_slice(&write_lsn.to_be_bytes());
384    payload.extend_from_slice(&flush_lsn.to_be_bytes());
385    payload.extend_from_slice(&apply_lsn.to_be_bytes());
386    payload.extend_from_slice(&postgres_epoch_micros_now().to_be_bytes());
387    payload.push(if reply_requested { 1 } else { 0 });
388    payload
389}
390
391impl PgConnection {
392    #[inline]
393    fn ensure_replication_mode(&self, operation: &str) -> PgResult<()> {
394        if self.replication_mode_enabled {
395            return Ok(());
396        }
397        Err(PgError::Protocol(format!(
398            "{} requires connection startup param replication=database",
399            operation
400        )))
401    }
402
403    #[inline]
404    fn ensure_replication_control_idle(&self, operation: &str) -> PgResult<()> {
405        if !self.replication_stream_active {
406            return Ok(());
407        }
408        Err(PgError::Protocol(format!(
409            "{} cannot run while replication stream is active",
410            operation
411        )))
412    }
413
414    #[inline]
415    fn advance_replication_wal_end(&mut self, source: &str, wal_end: u64) -> PgResult<()> {
416        if let Some(prev_wal_end) = self.last_replication_wal_end
417            && wal_end < prev_wal_end
418        {
419            self.replication_stream_active = false;
420            self.last_replication_wal_end = None;
421            return Err(PgError::Protocol(format!(
422                "Replication {} wal_end regressed: previous {}, current {}",
423                source, prev_wal_end, wal_end
424            )));
425        }
426        self.last_replication_wal_end = Some(wal_end);
427        Ok(())
428    }
429
430    /// Run `IDENTIFY_SYSTEM` on a replication connection.
431    pub async fn identify_system(&mut self) -> PgResult<IdentifySystem> {
432        self.ensure_replication_mode("IDENTIFY_SYSTEM")?;
433        self.ensure_replication_control_idle("IDENTIFY_SYSTEM")?;
434        let rows = self.simple_query("IDENTIFY_SYSTEM").await?;
435        let row = rows
436            .first()
437            .ok_or_else(|| PgError::Protocol("IDENTIFY_SYSTEM returned no rows".to_string()))?;
438        parse_identify_system_row(row)
439    }
440
441    /// Create a logical replication slot.
442    ///
443    /// `slot_name` and `output_plugin` are strict SQL identifiers.
444    pub async fn create_logical_replication_slot(
445        &mut self,
446        slot_name: &str,
447        output_plugin: &str,
448        temporary: bool,
449        two_phase: bool,
450    ) -> PgResult<ReplicationSlotInfo> {
451        self.ensure_replication_mode("CREATE_REPLICATION_SLOT")?;
452        self.ensure_replication_control_idle("CREATE_REPLICATION_SLOT")?;
453        let sql = build_create_logical_replication_slot_sql(
454            slot_name,
455            output_plugin,
456            temporary,
457            two_phase,
458        )?;
459        let rows = self.simple_query(&sql).await?;
460        let row = rows.first().ok_or_else(|| {
461            PgError::Protocol("CREATE_REPLICATION_SLOT returned no rows".to_string())
462        })?;
463        parse_create_slot_row(row)
464    }
465
466    /// Drop a replication slot.
467    ///
468    /// `wait=true` uses `DROP_REPLICATION_SLOT <slot> WAIT`.
469    pub async fn drop_replication_slot(&mut self, slot_name: &str, wait: bool) -> PgResult<()> {
470        self.ensure_replication_mode("DROP_REPLICATION_SLOT")?;
471        self.ensure_replication_control_idle("DROP_REPLICATION_SLOT")?;
472        let sql = build_drop_replication_slot_sql(slot_name, wait)?;
473        self.execute_simple(&sql).await
474    }
475
476    /// Start logical replication in CopyBoth mode.
477    ///
478    /// Requires a connection started with `replication=database`.
479    pub async fn start_logical_replication(
480        &mut self,
481        slot_name: &str,
482        start_lsn: &str,
483        options: &[ReplicationOption],
484    ) -> PgResult<ReplicationStreamStart> {
485        self.ensure_replication_mode("START_REPLICATION")?;
486        if self.replication_stream_active {
487            return Err(PgError::Protocol(
488                "logical replication stream already active".to_string(),
489            ));
490        }
491        let sql = build_start_logical_replication_sql(slot_name, start_lsn, options)?;
492        let bytes = PgEncoder::try_encode_query_string(&sql)?;
493        self.write_all_with_timeout(&bytes, "stream write").await?;
494
495        let mut startup_error: Option<PgError> = None;
496        loop {
497            let msg = self.recv().await?;
498            match msg {
499                BackendMessage::CopyBothResponse {
500                    format,
501                    column_formats,
502                } => {
503                    if let Some(err) = startup_error {
504                        return Err(err);
505                    }
506                    if format != 0 {
507                        return Err(PgError::Protocol(format!(
508                            "START_REPLICATION returned unsupported CopyBothResponse format {} (expected 0/text)",
509                            format
510                        )));
511                    }
512                    if !column_formats.is_empty() {
513                        return Err(PgError::Protocol(format!(
514                            "START_REPLICATION returned unexpected CopyBothResponse column formats (expected none, got {})",
515                            column_formats.len()
516                        )));
517                    }
518                    self.replication_stream_active = true;
519                    self.last_replication_wal_end = None;
520                    return Ok(ReplicationStreamStart {
521                        format,
522                        column_formats,
523                    });
524                }
525                BackendMessage::ReadyForQuery(_) => {
526                    return Err(startup_error.unwrap_or_else(|| {
527                        PgError::Protocol(
528                            "START_REPLICATION ended before CopyBothResponse".to_string(),
529                        )
530                    }));
531                }
532                BackendMessage::ErrorResponse(err) => {
533                    if startup_error.is_none() {
534                        startup_error = Some(PgError::QueryServer(err.into()));
535                    }
536                }
537                msg if is_ignorable_session_message(&msg) => {}
538                other => return Err(unexpected_backend_message("start replication", &other)),
539            }
540        }
541    }
542
543    /// Receive the next logical replication stream message.
544    ///
545    /// Uses a no-timeout read path so idle periods do not fail the stream.
546    pub async fn recv_replication_message(&mut self) -> PgResult<ReplicationStreamMessage> {
547        self.ensure_replication_mode("recv_replication_message")?;
548        if !self.replication_stream_active {
549            return Err(PgError::Protocol(
550                "replication stream is not active; call START_REPLICATION first".to_string(),
551            ));
552        }
553        loop {
554            let msg = self.recv_without_timeout().await?;
555            match msg {
556                BackendMessage::CopyData(payload) => match parse_copy_data_message(&payload) {
557                    Ok(ReplicationStreamMessage::XLogData(x)) => {
558                        self.advance_replication_wal_end("XLogData", x.wal_end)?;
559                        return Ok(ReplicationStreamMessage::XLogData(x));
560                    }
561                    Ok(ReplicationStreamMessage::Keepalive(k)) => {
562                        self.advance_replication_wal_end("keepalive", k.wal_end)?;
563                        return Ok(ReplicationStreamMessage::Keepalive(k));
564                    }
565                    Ok(parsed) => return Ok(parsed),
566                    Err(err) => {
567                        self.replication_stream_active = false;
568                        self.last_replication_wal_end = None;
569                        return Err(err);
570                    }
571                },
572                BackendMessage::ErrorResponse(err) => {
573                    self.replication_stream_active = false;
574                    self.last_replication_wal_end = None;
575                    return Err(PgError::QueryServer(err.into()));
576                }
577                BackendMessage::CopyDone => {
578                    self.replication_stream_active = false;
579                    self.last_replication_wal_end = None;
580                    return Err(PgError::Protocol(
581                        "Replication stream ended with CopyDone".to_string(),
582                    ));
583                }
584                BackendMessage::ReadyForQuery(_) => {
585                    self.replication_stream_active = false;
586                    self.last_replication_wal_end = None;
587                    return Err(PgError::Protocol(
588                        "Replication stream ended with ReadyForQuery".to_string(),
589                    ));
590                }
591                msg if is_ignorable_session_message(&msg) => {}
592                other => {
593                    self.replication_stream_active = false;
594                    self.last_replication_wal_end = None;
595                    return Err(unexpected_backend_message("replication stream", &other));
596                }
597            }
598        }
599    }
600
601    /// Send a standby status update (`CopyData('r' ...)`) to the server.
602    pub async fn send_standby_status_update(
603        &mut self,
604        write_lsn: u64,
605        flush_lsn: u64,
606        apply_lsn: u64,
607        reply_requested: bool,
608    ) -> PgResult<()> {
609        self.ensure_replication_mode("send_standby_status_update")?;
610        if !self.replication_stream_active {
611            return Err(PgError::Protocol(
612                "replication stream is not active; call START_REPLICATION first".to_string(),
613            ));
614        }
615        if flush_lsn > write_lsn {
616            return Err(PgError::Protocol(format!(
617                "Invalid standby status update: flush_lsn {} exceeds write_lsn {}",
618                flush_lsn, write_lsn
619            )));
620        }
621        if apply_lsn > flush_lsn {
622            return Err(PgError::Protocol(format!(
623                "Invalid standby status update: apply_lsn {} exceeds flush_lsn {}",
624                apply_lsn, flush_lsn
625            )));
626        }
627        if let Some(last_wal_end) = self.last_replication_wal_end
628            && write_lsn > last_wal_end
629        {
630            return Err(PgError::Protocol(format!(
631                "Invalid standby status update: write_lsn {} exceeds last seen server wal_end {}",
632                write_lsn, last_wal_end
633            )));
634        }
635        let payload =
636            build_standby_status_update_payload(write_lsn, flush_lsn, apply_lsn, reply_requested);
637        self.send_copy_data(&payload).await
638    }
639}
640
641impl PgDriver {
642    /// Driver wrapper for [`PgConnection::identify_system`].
643    pub async fn identify_system(&mut self) -> PgResult<IdentifySystem> {
644        self.connection.identify_system().await
645    }
646
647    /// Driver wrapper for [`PgConnection::create_logical_replication_slot`].
648    pub async fn create_logical_replication_slot(
649        &mut self,
650        slot_name: &str,
651        output_plugin: &str,
652        temporary: bool,
653        two_phase: bool,
654    ) -> PgResult<ReplicationSlotInfo> {
655        self.connection
656            .create_logical_replication_slot(slot_name, output_plugin, temporary, two_phase)
657            .await
658    }
659
660    /// Driver wrapper for [`PgConnection::drop_replication_slot`].
661    pub async fn drop_replication_slot(&mut self, slot_name: &str, wait: bool) -> PgResult<()> {
662        self.connection.drop_replication_slot(slot_name, wait).await
663    }
664
665    /// Driver wrapper for [`PgConnection::start_logical_replication`].
666    pub async fn start_logical_replication(
667        &mut self,
668        slot_name: &str,
669        start_lsn: &str,
670        options: &[ReplicationOption],
671    ) -> PgResult<ReplicationStreamStart> {
672        self.connection
673            .start_logical_replication(slot_name, start_lsn, options)
674            .await
675    }
676
677    /// Driver wrapper for [`PgConnection::recv_replication_message`].
678    pub async fn recv_replication_message(&mut self) -> PgResult<ReplicationStreamMessage> {
679        self.connection.recv_replication_message().await
680    }
681
682    /// Driver wrapper for [`PgConnection::send_standby_status_update`].
683    pub async fn send_standby_status_update(
684        &mut self,
685        write_lsn: u64,
686        flush_lsn: u64,
687        apply_lsn: u64,
688        reply_requested: bool,
689    ) -> PgResult<()> {
690        self.connection
691            .send_standby_status_update(write_lsn, flush_lsn, apply_lsn, reply_requested)
692            .await
693    }
694}
695
696#[cfg(test)]
697mod tests {
698    use super::*;
699
700    fn text_row(values: &[Option<&str>]) -> PgRow {
701        PgRow {
702            columns: values
703                .iter()
704                .map(|v| v.map(|s| s.as_bytes().to_vec()))
705                .collect(),
706            column_info: None,
707        }
708    }
709
710    #[test]
711    fn validate_ident_rejects_bad_names() {
712        assert!(validate_ident("slot_name", "").is_err());
713        assert!(validate_ident("slot_name", "9slot").is_err());
714        assert!(validate_ident("slot_name", "bad-name").is_err());
715        assert!(validate_ident("slot_name", "has space").is_err());
716    }
717
718    #[test]
719    fn validate_ident_accepts_safe_names() {
720        assert!(validate_ident("slot_name", "slot_a1").is_ok());
721        assert!(validate_ident("output_plugin", "pgoutput").is_ok());
722    }
723
724    #[test]
725    fn parse_and_format_lsn_roundtrip() {
726        let lsn = parse_lsn_text("16/B6C50").unwrap();
727        assert_eq!(format_lsn(lsn), "16/000B6C50");
728    }
729
730    #[test]
731    fn build_create_logical_replication_slot_sql_variants() {
732        let sql =
733            build_create_logical_replication_slot_sql("slot_main", "pgoutput", true, true).unwrap();
734        assert_eq!(
735            sql,
736            "CREATE_REPLICATION_SLOT slot_main TEMPORARY LOGICAL pgoutput TWO_PHASE"
737        );
738    }
739
740    #[test]
741    fn build_drop_replication_slot_sql_variants() {
742        let sql = build_drop_replication_slot_sql("slot_main", true).unwrap();
743        assert_eq!(sql, "DROP_REPLICATION_SLOT slot_main WAIT");
744    }
745
746    #[test]
747    fn build_start_logical_replication_sql_with_options() {
748        let sql = build_start_logical_replication_sql(
749            "slot_main",
750            "0/16B6C50",
751            &[
752                ReplicationOption {
753                    key: "proto_version".to_string(),
754                    value: "1".to_string(),
755                },
756                ReplicationOption {
757                    key: "publication_names".to_string(),
758                    value: "pub1,pub2".to_string(),
759                },
760            ],
761        )
762        .unwrap();
763        assert_eq!(
764            sql,
765            "START_REPLICATION SLOT slot_main LOGICAL 0/16B6C50 (proto_version '1', publication_names 'pub1,pub2')"
766        );
767    }
768
769    #[test]
770    fn build_start_logical_replication_sql_rejects_too_many_options() {
771        let options: Vec<ReplicationOption> = (0..=MAX_REPLICATION_OPTIONS)
772            .map(|i| ReplicationOption {
773                key: format!("opt{}", i),
774                value: "x".to_string(),
775            })
776            .collect();
777
778        let err =
779            build_start_logical_replication_sql("slot_main", "0/16B6C50", &options).unwrap_err();
780        assert!(err.to_string().contains("too many replication options"));
781    }
782
783    #[test]
784    fn build_start_logical_replication_sql_rejects_null_value() {
785        let options = vec![ReplicationOption {
786            key: "proto_version".to_string(),
787            value: "1\0oops".to_string(),
788        }];
789        let err =
790            build_start_logical_replication_sql("slot_main", "0/16B6C50", &options).unwrap_err();
791        assert!(err.to_string().contains("contains NUL byte"));
792    }
793
794    #[test]
795    fn build_start_logical_replication_sql_rejects_oversized_value() {
796        let options = vec![ReplicationOption {
797            key: "publication_names".to_string(),
798            value: "x".repeat(MAX_REPLICATION_OPTION_VALUE_BYTES + 1),
799        }];
800        let err =
801            build_start_logical_replication_sql("slot_main", "0/16B6C50", &options).unwrap_err();
802        assert!(err.to_string().contains("value too large"));
803    }
804
805    #[test]
806    fn parse_identify_system_row_happy_path() {
807        let row = text_row(&[
808            Some("7416469842679442267"),
809            Some("1"),
810            Some("0/16B6C50"),
811            Some("app"),
812        ]);
813        let parsed = parse_identify_system_row(&row).unwrap();
814        assert_eq!(parsed.system_id, "7416469842679442267");
815        assert_eq!(parsed.timeline, 1);
816        assert_eq!(parsed.xlog_pos, "0/16B6C50");
817        assert_eq!(parsed.dbname.as_deref(), Some("app"));
818    }
819
820    #[test]
821    fn parse_create_slot_row_happy_path() {
822        let row = text_row(&[
823            Some("slot_main"),
824            Some("0/16B6C88"),
825            Some("00000003-00000041-1"),
826            Some("pgoutput"),
827        ]);
828        let parsed = parse_create_slot_row(&row).unwrap();
829        assert_eq!(parsed.slot_name, "slot_main");
830        assert_eq!(parsed.consistent_point, "0/16B6C88");
831        assert_eq!(parsed.snapshot_name.as_deref(), Some("00000003-00000041-1"));
832        assert_eq!(parsed.output_plugin, "pgoutput");
833    }
834
835    #[test]
836    fn parse_copy_data_xlog_data() {
837        let mut payload = Vec::new();
838        payload.push(b'w');
839        payload.extend_from_slice(&0x10_u64.to_be_bytes());
840        payload.extend_from_slice(&0x20_u64.to_be_bytes());
841        payload.extend_from_slice(&123_i64.to_be_bytes());
842        payload.extend_from_slice(b"hello");
843
844        match parse_copy_data_message(&payload).unwrap() {
845            ReplicationStreamMessage::XLogData(x) => {
846                assert_eq!(x.wal_start, 0x10);
847                assert_eq!(x.wal_end, 0x20);
848                assert_eq!(x.server_time_micros, 123);
849                assert_eq!(x.data, b"hello");
850            }
851            _ => panic!("expected xlog data"),
852        }
853    }
854
855    #[test]
856    fn parse_copy_data_xlog_data_rejects_wal_end_before_start() {
857        let mut payload = Vec::new();
858        payload.push(b'w');
859        payload.extend_from_slice(&0x20_u64.to_be_bytes());
860        payload.extend_from_slice(&0x10_u64.to_be_bytes());
861        payload.extend_from_slice(&123_i64.to_be_bytes());
862        let err = parse_copy_data_message(&payload).unwrap_err();
863        assert!(err.to_string().contains("wal_end"));
864    }
865
866    #[test]
867    fn parse_copy_data_xlog_data_rejects_oversized_data() {
868        let mut payload = Vec::with_capacity(25 + MAX_REPLICATION_XLOGDATA_BYTES + 1);
869        payload.push(b'w');
870        payload.extend_from_slice(&0x10_u64.to_be_bytes());
871        payload.extend_from_slice(&0x20_u64.to_be_bytes());
872        payload.extend_from_slice(&123_i64.to_be_bytes());
873        payload.extend(std::iter::repeat_n(0u8, MAX_REPLICATION_XLOGDATA_BYTES + 1));
874        let err = parse_copy_data_message(&payload).unwrap_err();
875        assert!(err.to_string().contains("payload too large"));
876    }
877
878    #[test]
879    fn parse_copy_data_keepalive() {
880        let mut payload = Vec::new();
881        payload.push(b'k');
882        payload.extend_from_slice(&0xAB_u64.to_be_bytes());
883        payload.extend_from_slice(&456_i64.to_be_bytes());
884        payload.push(1);
885
886        match parse_copy_data_message(&payload).unwrap() {
887            ReplicationStreamMessage::Keepalive(k) => {
888                assert_eq!(k.wal_end, 0xAB);
889                assert_eq!(k.server_time_micros, 456);
890                assert!(k.reply_requested);
891            }
892            _ => panic!("expected keepalive"),
893        }
894    }
895
896    #[test]
897    fn parse_copy_data_keepalive_rejects_invalid_reply_requested() {
898        let mut payload = Vec::new();
899        payload.push(b'k');
900        payload.extend_from_slice(&0xAB_u64.to_be_bytes());
901        payload.extend_from_slice(&456_i64.to_be_bytes());
902        payload.push(2);
903        let err = parse_copy_data_message(&payload).unwrap_err();
904        assert!(err.to_string().contains("reply_requested must be 0 or 1"));
905    }
906
907    #[test]
908    fn parse_copy_data_unknown_tag_rejected() {
909        let payload = vec![b'x', 1, 2, 3];
910        let err = parse_copy_data_message(&payload).unwrap_err();
911        assert!(
912            err.to_string()
913                .contains("Unsupported replication CopyData tag")
914        );
915    }
916
917    #[test]
918    fn build_standby_status_update_payload_layout() {
919        let payload = build_standby_status_update_payload(1, 2, 3, true);
920        assert_eq!(payload.len(), 34);
921        assert_eq!(payload[0], b'r');
922        assert_eq!(u64::from_be_bytes(payload[1..9].try_into().unwrap()), 1);
923        assert_eq!(u64::from_be_bytes(payload[9..17].try_into().unwrap()), 2);
924        assert_eq!(u64::from_be_bytes(payload[17..25].try_into().unwrap()), 3);
925        assert_eq!(payload[33], 1);
926    }
927}