1use super::{
12 PgConnection, PgDriver, PgError, PgResult, PgRow, is_ignorable_session_message,
13 unexpected_backend_message,
14};
15use crate::protocol::{BackendMessage, PgEncoder};
16
/// Parsed result row of the `IDENTIFY_SYSTEM` replication command.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct IdentifySystem {
    /// Cluster system identifier (column 0, `systemid`), kept as text.
    pub system_id: String,
    /// Timeline ID (column 1, `timeline`).
    pub timeline: u32,
    /// WAL position reported by the server (column 2, `xlogpos`) in textual LSN form.
    pub xlog_pos: String,
    /// Database name (column 3); `None` when the column is absent or empty.
    pub dbname: Option<String>,
}
29
/// Parsed result row of `CREATE_REPLICATION_SLOT ... LOGICAL`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReplicationSlotInfo {
    /// Name of the slot (column 0).
    pub slot_name: String,
    /// Consistent point in textual LSN form (column 1).
    pub consistent_point: String,
    /// Snapshot name (column 2); `None` when the column is absent or empty.
    pub snapshot_name: Option<String>,
    /// Output plugin name (column 3).
    pub output_plugin: String,
}
42
/// Details of the `CopyBothResponse` that opened a replication stream.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReplicationStreamStart {
    /// Overall copy format byte (only `0` is accepted by `start_logical_replication`).
    pub format: u8,
    /// Per-column format codes (expected to be empty for replication streams).
    pub column_formats: Vec<u8>,
}
51
/// Decoded `w` (XLogData) CopyData message.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReplicationXLogData {
    /// Starting WAL position of the data in this message.
    pub wal_start: u64,
    /// WAL end position reported alongside the data.
    pub wal_end: u64,
    /// Server clock value in microseconds, as sent (presumably since the PostgreSQL epoch,
    /// 2000-01-01 — confirm against the protocol docs).
    pub server_time_micros: i64,
    /// Message body following the 25-byte header.
    pub data: Vec<u8>,
}
64
/// Decoded `k` (primary keepalive) CopyData message.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReplicationKeepalive {
    /// WAL end position reported by the server.
    pub wal_end: u64,
    /// Server clock value in microseconds, as sent.
    pub server_time_micros: i64,
    /// `true` when the trailing flag byte is `1`, i.e. the server asks for a status reply.
    pub reply_requested: bool,
}
75
76#[derive(Debug, Clone, PartialEq, Eq)]
78pub enum ReplicationStreamMessage {
79 XLogData(ReplicationXLogData),
81 Keepalive(ReplicationKeepalive),
83 Raw { tag: u8, payload: Vec<u8> },
85}
86
/// A `key 'value'` pair appended to `START_REPLICATION SLOT ... LOGICAL` (for example
/// `proto_version` or `publication_names`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReplicationOption {
    /// Option name; must be a plain identifier.
    pub key: String,
    /// Option value; emitted as a single-quoted literal.
    pub value: String,
}
95
/// Maximum number of options accepted by `build_start_logical_replication_sql`.
const MAX_REPLICATION_OPTIONS: usize = 64;
/// Maximum size, in bytes, of one replication option value (16 KiB).
const MAX_REPLICATION_OPTION_VALUE_BYTES: usize = 16 << 10;
/// Maximum size, in bytes, of the data part of one XLogData message (16 MiB).
const MAX_REPLICATION_XLOGDATA_BYTES: usize = 16 << 20;
99
100fn validate_ident(kind: &str, ident: &str) -> PgResult<()> {
101 if ident.is_empty() {
102 return Err(PgError::Query(format!("{} must not be empty", kind)));
103 }
104 if ident.len() > 63 {
105 return Err(PgError::Query(format!(
106 "{} '{}' exceeds PostgreSQL identifier limit (63 bytes)",
107 kind, ident
108 )));
109 }
110 let mut chars = ident.chars();
111 match chars.next() {
112 Some(c) if c == '_' || c.is_ascii_alphabetic() => {}
113 _ => {
114 return Err(PgError::Query(format!(
115 "{} '{}' must start with [A-Za-z_]",
116 kind, ident
117 )));
118 }
119 }
120 if !chars.all(|c| c == '_' || c.is_ascii_alphanumeric()) {
121 return Err(PgError::Query(format!(
122 "{} '{}' contains unsupported characters (allowed: [A-Za-z0-9_])",
123 kind, ident
124 )));
125 }
126 Ok(())
127}
128
129fn sql_single_quote_literal(value: &str) -> PgResult<String> {
130 if value.contains('\0') {
131 return Err(PgError::Query(
132 "replication option value contains NUL byte".to_string(),
133 ));
134 }
135 Ok(value.replace('\'', "''"))
136}
137
138fn parse_lsn_text(lsn: &str) -> PgResult<u64> {
139 let mut parts = lsn.split('/');
140 let high = parts
141 .next()
142 .ok_or_else(|| PgError::Query(format!("Invalid LSN '{}'", lsn)))?;
143 let low = parts
144 .next()
145 .ok_or_else(|| PgError::Query(format!("Invalid LSN '{}'", lsn)))?;
146 if parts.next().is_some() {
147 return Err(PgError::Query(format!("Invalid LSN '{}'", lsn)));
148 }
149 let high = u32::from_str_radix(high, 16)
150 .map_err(|_| PgError::Query(format!("Invalid LSN '{}'", lsn)))?;
151 let low = u32::from_str_radix(low, 16)
152 .map_err(|_| PgError::Query(format!("Invalid LSN '{}'", lsn)))?;
153 Ok(((high as u64) << 32) | (low as u64))
154}
155
/// Renders `lsn` as `HIGH/LOW` in upper-case hex, with the low half zero-padded to eight
/// digits. Only compiled for tests.
#[cfg(test)]
fn format_lsn(lsn: u64) -> String {
    let high = (lsn >> 32) as u32;
    let low = (lsn & 0xFFFF_FFFF) as u32;
    format!("{high:X}/{low:08X}")
}
160
161fn required_text(row: &PgRow, idx: usize, field: &str) -> PgResult<String> {
162 row.get_string(idx).ok_or_else(|| {
163 PgError::Protocol(format!("Missing or invalid '{}' in replication row", field))
164 })
165}
166
167fn parse_identify_system_row(row: &PgRow) -> PgResult<IdentifySystem> {
168 let system_id = required_text(row, 0, "systemid")?;
169 let timeline = required_text(row, 1, "timeline")?
170 .parse::<u32>()
171 .map_err(|e| PgError::Protocol(format!("Invalid timeline value: {}", e)))?;
172 let xlog_pos = required_text(row, 2, "xlogpos")?;
173 let dbname = row
174 .get_string(3)
175 .and_then(|v| if v.is_empty() { None } else { Some(v) });
176
177 Ok(IdentifySystem {
178 system_id,
179 timeline,
180 xlog_pos,
181 dbname,
182 })
183}
184
185fn parse_create_slot_row(row: &PgRow) -> PgResult<ReplicationSlotInfo> {
186 let slot_name = required_text(row, 0, "slot_name")?;
187 let consistent_point = required_text(row, 1, "consistent_point")?;
188 let snapshot_name = row
189 .get_string(2)
190 .and_then(|v| if v.is_empty() { None } else { Some(v) });
191 let output_plugin = required_text(row, 3, "output_plugin")?;
192
193 Ok(ReplicationSlotInfo {
194 slot_name,
195 consistent_point,
196 snapshot_name,
197 output_plugin,
198 })
199}
200
201fn build_create_logical_replication_slot_sql(
202 slot_name: &str,
203 output_plugin: &str,
204 temporary: bool,
205 two_phase: bool,
206) -> PgResult<String> {
207 validate_ident("slot_name", slot_name)?;
208 validate_ident("output_plugin", output_plugin)?;
209
210 let mut sql = String::from("CREATE_REPLICATION_SLOT ");
211 sql.push_str(slot_name);
212 if temporary {
213 sql.push_str(" TEMPORARY");
214 }
215 sql.push_str(" LOGICAL ");
216 sql.push_str(output_plugin);
217 if two_phase {
218 sql.push_str(" TWO_PHASE");
219 }
220 Ok(sql)
221}
222
223fn build_drop_replication_slot_sql(slot_name: &str, wait: bool) -> PgResult<String> {
224 validate_ident("slot_name", slot_name)?;
225 let mut sql = String::from("DROP_REPLICATION_SLOT ");
226 sql.push_str(slot_name);
227 if wait {
228 sql.push_str(" WAIT");
229 }
230 Ok(sql)
231}
232
233fn build_start_logical_replication_sql(
234 slot_name: &str,
235 start_lsn: &str,
236 options: &[ReplicationOption],
237) -> PgResult<String> {
238 validate_ident("slot_name", slot_name)?;
239 let _ = parse_lsn_text(start_lsn)?;
240 if options.len() > MAX_REPLICATION_OPTIONS {
241 return Err(PgError::Query(format!(
242 "too many replication options: {} (max {})",
243 options.len(),
244 MAX_REPLICATION_OPTIONS
245 )));
246 }
247
248 let mut sql = format!("START_REPLICATION SLOT {} LOGICAL {}", slot_name, start_lsn);
249 if !options.is_empty() {
250 sql.push_str(" (");
251 for (idx, opt) in options.iter().enumerate() {
252 validate_ident("replication option key", &opt.key)?;
253 if opt.value.len() > MAX_REPLICATION_OPTION_VALUE_BYTES {
254 return Err(PgError::Query(format!(
255 "replication option '{}' value too large: {} bytes (max {})",
256 opt.key,
257 opt.value.len(),
258 MAX_REPLICATION_OPTION_VALUE_BYTES
259 )));
260 }
261 if idx > 0 {
262 sql.push_str(", ");
263 }
264 sql.push_str(&opt.key);
265 sql.push_str(" '");
266 sql.push_str(&sql_single_quote_literal(&opt.value)?);
267 sql.push('\'');
268 }
269 sql.push(')');
270 }
271 Ok(sql)
272}
273
274fn parse_copy_data_message(payload: &[u8]) -> PgResult<ReplicationStreamMessage> {
275 if payload.is_empty() {
276 return Err(PgError::Protocol(
277 "Replication CopyData payload is empty".to_string(),
278 ));
279 }
280 match payload[0] {
281 b'w' => {
282 if payload.len() < 25 {
283 return Err(PgError::Protocol(format!(
284 "XLogData payload too short: {} bytes",
285 payload.len()
286 )));
287 }
288 let wal_start = u64::from_be_bytes(
289 payload[1..9]
290 .try_into()
291 .map_err(|_| PgError::Protocol("Invalid wal_start bytes".to_string()))?,
292 );
293 let wal_end = u64::from_be_bytes(
294 payload[9..17]
295 .try_into()
296 .map_err(|_| PgError::Protocol("Invalid wal_end bytes".to_string()))?,
297 );
298 let server_time_micros = i64::from_be_bytes(
299 payload[17..25]
300 .try_into()
301 .map_err(|_| PgError::Protocol("Invalid server time bytes".to_string()))?,
302 );
303 if wal_end < wal_start {
304 return Err(PgError::Protocol(format!(
305 "XLogData wal_end {} is behind wal_start {}",
306 wal_end, wal_start
307 )));
308 }
309 let data_len = payload.len() - 25;
310 if data_len > MAX_REPLICATION_XLOGDATA_BYTES {
311 return Err(PgError::Protocol(format!(
312 "XLogData payload too large: {} bytes (max {})",
313 data_len, MAX_REPLICATION_XLOGDATA_BYTES
314 )));
315 }
316 Ok(ReplicationStreamMessage::XLogData(ReplicationXLogData {
317 wal_start,
318 wal_end,
319 server_time_micros,
320 data: payload[25..].to_vec(),
321 }))
322 }
323 b'k' => {
324 if payload.len() != 18 {
325 return Err(PgError::Protocol(format!(
326 "Keepalive payload must be 18 bytes, got {}",
327 payload.len()
328 )));
329 }
330 let wal_end =
331 u64::from_be_bytes(payload[1..9].try_into().map_err(|_| {
332 PgError::Protocol("Invalid keepalive wal_end bytes".to_string())
333 })?);
334 let server_time_micros = i64::from_be_bytes(
335 payload[9..17]
336 .try_into()
337 .map_err(|_| PgError::Protocol("Invalid keepalive time bytes".to_string()))?,
338 );
339 let reply_requested = match payload[17] {
340 0 => false,
341 1 => true,
342 other => {
343 return Err(PgError::Protocol(format!(
344 "Keepalive reply_requested must be 0 or 1, got {}",
345 other
346 )));
347 }
348 };
349 Ok(ReplicationStreamMessage::Keepalive(ReplicationKeepalive {
350 wal_end,
351 server_time_micros,
352 reply_requested,
353 }))
354 }
355 tag => Err(PgError::Protocol(format!(
356 "Unsupported replication CopyData tag '{}'",
357 if tag.is_ascii_graphic() {
358 tag as char
359 } else {
360 '?'
361 }
362 ))),
363 }
364}
365
/// Returns the current wall-clock time in microseconds since 2000-01-01 00:00:00 UTC (the
/// PostgreSQL epoch). This is the timestamp written into standby status updates.
///
/// A system clock set before the Unix epoch is treated as exactly the Unix epoch.
fn postgres_epoch_micros_now() -> i64 {
    // 946_684_800 seconds separate 1970-01-01 from 2000-01-01.
    const PG_EPOCH_OFFSET_MICROS: i64 = 946_684_800 * 1_000_000;
    let since_unix = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => std::time::Duration::ZERO,
    };
    let whole_secs = since_unix.as_secs() as i64;
    let unix_micros = whole_secs * 1_000_000 + i64::from(since_unix.subsec_micros());
    unix_micros - PG_EPOCH_OFFSET_MICROS
}
374
375fn build_standby_status_update_payload(
376 write_lsn: u64,
377 flush_lsn: u64,
378 apply_lsn: u64,
379 reply_requested: bool,
380) -> Vec<u8> {
381 let mut payload = Vec::with_capacity(1 + 8 + 8 + 8 + 8 + 1);
382 payload.push(b'r');
383 payload.extend_from_slice(&write_lsn.to_be_bytes());
384 payload.extend_from_slice(&flush_lsn.to_be_bytes());
385 payload.extend_from_slice(&apply_lsn.to_be_bytes());
386 payload.extend_from_slice(&postgres_epoch_micros_now().to_be_bytes());
387 payload.push(if reply_requested { 1 } else { 0 });
388 payload
389}
390
391impl PgConnection {
392 #[inline]
393 fn ensure_replication_mode(&self, operation: &str) -> PgResult<()> {
394 if self.replication_mode_enabled {
395 return Ok(());
396 }
397 Err(PgError::Protocol(format!(
398 "{} requires connection startup param replication=database",
399 operation
400 )))
401 }
402
403 #[inline]
404 fn ensure_replication_control_idle(&self, operation: &str) -> PgResult<()> {
405 if !self.replication_stream_active {
406 return Ok(());
407 }
408 Err(PgError::Protocol(format!(
409 "{} cannot run while replication stream is active",
410 operation
411 )))
412 }
413
414 #[inline]
415 fn advance_replication_wal_end(&mut self, source: &str, wal_end: u64) -> PgResult<()> {
416 if let Some(prev_wal_end) = self.last_replication_wal_end
417 && wal_end < prev_wal_end
418 {
419 self.replication_stream_active = false;
420 self.last_replication_wal_end = None;
421 return Err(PgError::Protocol(format!(
422 "Replication {} wal_end regressed: previous {}, current {}",
423 source, prev_wal_end, wal_end
424 )));
425 }
426 self.last_replication_wal_end = Some(wal_end);
427 Ok(())
428 }
429
430 pub async fn identify_system(&mut self) -> PgResult<IdentifySystem> {
432 self.ensure_replication_mode("IDENTIFY_SYSTEM")?;
433 self.ensure_replication_control_idle("IDENTIFY_SYSTEM")?;
434 let rows = self.simple_query("IDENTIFY_SYSTEM").await?;
435 let row = rows
436 .first()
437 .ok_or_else(|| PgError::Protocol("IDENTIFY_SYSTEM returned no rows".to_string()))?;
438 parse_identify_system_row(row)
439 }
440
441 pub async fn create_logical_replication_slot(
445 &mut self,
446 slot_name: &str,
447 output_plugin: &str,
448 temporary: bool,
449 two_phase: bool,
450 ) -> PgResult<ReplicationSlotInfo> {
451 self.ensure_replication_mode("CREATE_REPLICATION_SLOT")?;
452 self.ensure_replication_control_idle("CREATE_REPLICATION_SLOT")?;
453 let sql = build_create_logical_replication_slot_sql(
454 slot_name,
455 output_plugin,
456 temporary,
457 two_phase,
458 )?;
459 let rows = self.simple_query(&sql).await?;
460 let row = rows.first().ok_or_else(|| {
461 PgError::Protocol("CREATE_REPLICATION_SLOT returned no rows".to_string())
462 })?;
463 parse_create_slot_row(row)
464 }
465
466 pub async fn drop_replication_slot(&mut self, slot_name: &str, wait: bool) -> PgResult<()> {
470 self.ensure_replication_mode("DROP_REPLICATION_SLOT")?;
471 self.ensure_replication_control_idle("DROP_REPLICATION_SLOT")?;
472 let sql = build_drop_replication_slot_sql(slot_name, wait)?;
473 self.execute_simple(&sql).await
474 }
475
476 pub async fn start_logical_replication(
480 &mut self,
481 slot_name: &str,
482 start_lsn: &str,
483 options: &[ReplicationOption],
484 ) -> PgResult<ReplicationStreamStart> {
485 self.ensure_replication_mode("START_REPLICATION")?;
486 if self.replication_stream_active {
487 return Err(PgError::Protocol(
488 "logical replication stream already active".to_string(),
489 ));
490 }
491 let sql = build_start_logical_replication_sql(slot_name, start_lsn, options)?;
492 let bytes = PgEncoder::try_encode_query_string(&sql)?;
493 self.write_all_with_timeout(&bytes, "stream write").await?;
494
495 let mut startup_error: Option<PgError> = None;
496 loop {
497 let msg = self.recv().await?;
498 match msg {
499 BackendMessage::CopyBothResponse {
500 format,
501 column_formats,
502 } => {
503 if let Some(err) = startup_error {
504 return Err(err);
505 }
506 if format != 0 {
507 return Err(PgError::Protocol(format!(
508 "START_REPLICATION returned unsupported CopyBothResponse format {} (expected 0/text)",
509 format
510 )));
511 }
512 if !column_formats.is_empty() {
513 return Err(PgError::Protocol(format!(
514 "START_REPLICATION returned unexpected CopyBothResponse column formats (expected none, got {})",
515 column_formats.len()
516 )));
517 }
518 self.replication_stream_active = true;
519 self.last_replication_wal_end = None;
520 return Ok(ReplicationStreamStart {
521 format,
522 column_formats,
523 });
524 }
525 BackendMessage::ReadyForQuery(_) => {
526 return Err(startup_error.unwrap_or_else(|| {
527 PgError::Protocol(
528 "START_REPLICATION ended before CopyBothResponse".to_string(),
529 )
530 }));
531 }
532 BackendMessage::ErrorResponse(err) => {
533 if startup_error.is_none() {
534 startup_error = Some(PgError::QueryServer(err.into()));
535 }
536 }
537 msg if is_ignorable_session_message(&msg) => {}
538 other => return Err(unexpected_backend_message("start replication", &other)),
539 }
540 }
541 }
542
543 pub async fn recv_replication_message(&mut self) -> PgResult<ReplicationStreamMessage> {
547 self.ensure_replication_mode("recv_replication_message")?;
548 if !self.replication_stream_active {
549 return Err(PgError::Protocol(
550 "replication stream is not active; call START_REPLICATION first".to_string(),
551 ));
552 }
553 loop {
554 let msg = self.recv_without_timeout().await?;
555 match msg {
556 BackendMessage::CopyData(payload) => match parse_copy_data_message(&payload) {
557 Ok(ReplicationStreamMessage::XLogData(x)) => {
558 self.advance_replication_wal_end("XLogData", x.wal_end)?;
559 return Ok(ReplicationStreamMessage::XLogData(x));
560 }
561 Ok(ReplicationStreamMessage::Keepalive(k)) => {
562 self.advance_replication_wal_end("keepalive", k.wal_end)?;
563 return Ok(ReplicationStreamMessage::Keepalive(k));
564 }
565 Ok(parsed) => return Ok(parsed),
566 Err(err) => {
567 self.replication_stream_active = false;
568 self.last_replication_wal_end = None;
569 return Err(err);
570 }
571 },
572 BackendMessage::ErrorResponse(err) => {
573 self.replication_stream_active = false;
574 self.last_replication_wal_end = None;
575 return Err(PgError::QueryServer(err.into()));
576 }
577 BackendMessage::CopyDone => {
578 self.replication_stream_active = false;
579 self.last_replication_wal_end = None;
580 return Err(PgError::Protocol(
581 "Replication stream ended with CopyDone".to_string(),
582 ));
583 }
584 BackendMessage::ReadyForQuery(_) => {
585 self.replication_stream_active = false;
586 self.last_replication_wal_end = None;
587 return Err(PgError::Protocol(
588 "Replication stream ended with ReadyForQuery".to_string(),
589 ));
590 }
591 msg if is_ignorable_session_message(&msg) => {}
592 other => {
593 self.replication_stream_active = false;
594 self.last_replication_wal_end = None;
595 return Err(unexpected_backend_message("replication stream", &other));
596 }
597 }
598 }
599 }
600
601 pub async fn send_standby_status_update(
603 &mut self,
604 write_lsn: u64,
605 flush_lsn: u64,
606 apply_lsn: u64,
607 reply_requested: bool,
608 ) -> PgResult<()> {
609 self.ensure_replication_mode("send_standby_status_update")?;
610 if !self.replication_stream_active {
611 return Err(PgError::Protocol(
612 "replication stream is not active; call START_REPLICATION first".to_string(),
613 ));
614 }
615 if flush_lsn > write_lsn {
616 return Err(PgError::Protocol(format!(
617 "Invalid standby status update: flush_lsn {} exceeds write_lsn {}",
618 flush_lsn, write_lsn
619 )));
620 }
621 if apply_lsn > flush_lsn {
622 return Err(PgError::Protocol(format!(
623 "Invalid standby status update: apply_lsn {} exceeds flush_lsn {}",
624 apply_lsn, flush_lsn
625 )));
626 }
627 if let Some(last_wal_end) = self.last_replication_wal_end
628 && write_lsn > last_wal_end
629 {
630 return Err(PgError::Protocol(format!(
631 "Invalid standby status update: write_lsn {} exceeds last seen server wal_end {}",
632 write_lsn, last_wal_end
633 )));
634 }
635 let payload =
636 build_standby_status_update_payload(write_lsn, flush_lsn, apply_lsn, reply_requested);
637 self.send_copy_data(&payload).await
638 }
639}
640
641impl PgDriver {
642 pub async fn identify_system(&mut self) -> PgResult<IdentifySystem> {
644 self.connection.identify_system().await
645 }
646
647 pub async fn create_logical_replication_slot(
649 &mut self,
650 slot_name: &str,
651 output_plugin: &str,
652 temporary: bool,
653 two_phase: bool,
654 ) -> PgResult<ReplicationSlotInfo> {
655 self.connection
656 .create_logical_replication_slot(slot_name, output_plugin, temporary, two_phase)
657 .await
658 }
659
660 pub async fn drop_replication_slot(&mut self, slot_name: &str, wait: bool) -> PgResult<()> {
662 self.connection.drop_replication_slot(slot_name, wait).await
663 }
664
665 pub async fn start_logical_replication(
667 &mut self,
668 slot_name: &str,
669 start_lsn: &str,
670 options: &[ReplicationOption],
671 ) -> PgResult<ReplicationStreamStart> {
672 self.connection
673 .start_logical_replication(slot_name, start_lsn, options)
674 .await
675 }
676
677 pub async fn recv_replication_message(&mut self) -> PgResult<ReplicationStreamMessage> {
679 self.connection.recv_replication_message().await
680 }
681
682 pub async fn send_standby_status_update(
684 &mut self,
685 write_lsn: u64,
686 flush_lsn: u64,
687 apply_lsn: u64,
688 reply_requested: bool,
689 ) -> PgResult<()> {
690 self.connection
691 .send_standby_status_update(write_lsn, flush_lsn, apply_lsn, reply_requested)
692 .await
693 }
694}
695
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a text-format row whose columns are the given optional strings.
    fn text_row(values: &[Option<&str>]) -> PgRow {
        PgRow {
            columns: values
                .iter()
                .map(|v| v.map(|s| s.as_bytes().to_vec()))
                .collect(),
            column_info: None,
        }
    }

    /// Builds an XLogData payload with the given header fields and body.
    fn xlog_payload(wal_start: u64, wal_end: u64, time: i64, body: &[u8]) -> Vec<u8> {
        let mut payload = Vec::with_capacity(25 + body.len());
        payload.push(b'w');
        payload.extend_from_slice(&wal_start.to_be_bytes());
        payload.extend_from_slice(&wal_end.to_be_bytes());
        payload.extend_from_slice(&time.to_be_bytes());
        payload.extend_from_slice(body);
        payload
    }

    #[test]
    fn validate_ident_rejects_bad_names() {
        assert!(validate_ident("slot_name", "").is_err());
        assert!(validate_ident("slot_name", "9slot").is_err());
        assert!(validate_ident("slot_name", "bad-name").is_err());
        assert!(validate_ident("slot_name", "has space").is_err());
    }

    #[test]
    fn validate_ident_accepts_safe_names() {
        assert!(validate_ident("slot_name", "slot_a1").is_ok());
        assert!(validate_ident("output_plugin", "pgoutput").is_ok());
    }

    #[test]
    fn validate_ident_enforces_63_byte_limit() {
        assert!(validate_ident("slot_name", &"a".repeat(63)).is_ok());
        assert!(validate_ident("slot_name", &"a".repeat(64)).is_err());
    }

    #[test]
    fn validate_ident_rejects_non_ascii() {
        assert!(validate_ident("slot_name", "sl\u{f6}t").is_err());
        assert!(validate_ident("slot_name", "\u{f1}ame").is_err());
    }

    #[test]
    fn parse_and_format_lsn_roundtrip() {
        let lsn = parse_lsn_text("16/B6C50").unwrap();
        assert_eq!(format_lsn(lsn), "16/000B6C50");
    }

    #[test]
    fn parse_lsn_text_combines_halves() {
        assert_eq!(parse_lsn_text("0/0").unwrap(), 0);
        assert_eq!(parse_lsn_text("1/2").unwrap(), (1u64 << 32) | 2);
        assert_eq!(parse_lsn_text("FFFFFFFF/FFFFFFFF").unwrap(), u64::MAX);
    }

    #[test]
    fn parse_lsn_text_rejects_malformed_input() {
        for bad in ["", "16", "/", "1/", "/1", "1/2/3", "zz/1", "1/g", "100000000/0"] {
            assert!(parse_lsn_text(bad).is_err(), "expected '{}' to be rejected", bad);
        }
    }

    #[test]
    fn format_lsn_extremes() {
        assert_eq!(format_lsn(0), "0/00000000");
        assert_eq!(format_lsn(u64::MAX), "FFFFFFFF/FFFFFFFF");
    }

    #[test]
    fn build_create_logical_replication_slot_sql_variants() {
        let sql =
            build_create_logical_replication_slot_sql("slot_main", "pgoutput", true, true).unwrap();
        assert_eq!(
            sql,
            "CREATE_REPLICATION_SLOT slot_main TEMPORARY LOGICAL pgoutput TWO_PHASE"
        );
    }

    #[test]
    fn build_create_logical_replication_slot_sql_without_flags() {
        let sql = build_create_logical_replication_slot_sql("slot_main", "pgoutput", false, false)
            .unwrap();
        assert_eq!(sql, "CREATE_REPLICATION_SLOT slot_main LOGICAL pgoutput");
    }

    #[test]
    fn build_create_logical_replication_slot_sql_rejects_bad_plugin() {
        assert!(
            build_create_logical_replication_slot_sql("slot_main", "pg-output", false, false)
                .is_err()
        );
    }

    #[test]
    fn build_drop_replication_slot_sql_variants() {
        let sql = build_drop_replication_slot_sql("slot_main", true).unwrap();
        assert_eq!(sql, "DROP_REPLICATION_SLOT slot_main WAIT");
        let sql = build_drop_replication_slot_sql("slot_main", false).unwrap();
        assert_eq!(sql, "DROP_REPLICATION_SLOT slot_main");
    }

    #[test]
    fn build_start_logical_replication_sql_with_options() {
        let sql = build_start_logical_replication_sql(
            "slot_main",
            "0/16B6C50",
            &[
                ReplicationOption {
                    key: "proto_version".to_string(),
                    value: "1".to_string(),
                },
                ReplicationOption {
                    key: "publication_names".to_string(),
                    value: "pub1,pub2".to_string(),
                },
            ],
        )
        .unwrap();
        assert_eq!(
            sql,
            "START_REPLICATION SLOT slot_main LOGICAL 0/16B6C50 (proto_version '1', publication_names 'pub1,pub2')"
        );
    }

    #[test]
    fn build_start_logical_replication_sql_without_options() {
        let sql = build_start_logical_replication_sql("slot_main", "0/0", &[]).unwrap();
        assert_eq!(sql, "START_REPLICATION SLOT slot_main LOGICAL 0/0");
    }

    #[test]
    fn build_start_logical_replication_sql_escapes_quotes() {
        let options = [ReplicationOption {
            key: "publication_names".to_string(),
            value: "it's".to_string(),
        }];
        let sql = build_start_logical_replication_sql("slot_main", "0/0", &options).unwrap();
        assert_eq!(
            sql,
            "START_REPLICATION SLOT slot_main LOGICAL 0/0 (publication_names 'it''s')"
        );
    }

    #[test]
    fn build_start_logical_replication_sql_rejects_bad_lsn_and_key() {
        assert!(build_start_logical_replication_sql("slot_main", "not-an-lsn", &[]).is_err());
        let options = [ReplicationOption {
            key: "bad key".to_string(),
            value: "x".to_string(),
        }];
        assert!(build_start_logical_replication_sql("slot_main", "0/0", &options).is_err());
    }

    #[test]
    fn build_start_logical_replication_sql_rejects_too_many_options() {
        let options: Vec<ReplicationOption> = (0..=MAX_REPLICATION_OPTIONS)
            .map(|i| ReplicationOption {
                key: format!("opt{}", i),
                value: "x".to_string(),
            })
            .collect();

        let err =
            build_start_logical_replication_sql("slot_main", "0/16B6C50", &options).unwrap_err();
        assert!(err.to_string().contains("too many replication options"));
    }

    #[test]
    fn build_start_logical_replication_sql_rejects_null_value() {
        let options = vec![ReplicationOption {
            key: "proto_version".to_string(),
            value: "1\0oops".to_string(),
        }];
        let err =
            build_start_logical_replication_sql("slot_main", "0/16B6C50", &options).unwrap_err();
        assert!(err.to_string().contains("contains NUL byte"));
    }

    #[test]
    fn build_start_logical_replication_sql_rejects_oversized_value() {
        let options = vec![ReplicationOption {
            key: "publication_names".to_string(),
            value: "x".repeat(MAX_REPLICATION_OPTION_VALUE_BYTES + 1),
        }];
        let err =
            build_start_logical_replication_sql("slot_main", "0/16B6C50", &options).unwrap_err();
        assert!(err.to_string().contains("value too large"));
    }

    #[test]
    fn parse_identify_system_row_happy_path() {
        let row = text_row(&[
            Some("7416469842679442267"),
            Some("1"),
            Some("0/16B6C50"),
            Some("app"),
        ]);
        let parsed = parse_identify_system_row(&row).unwrap();
        assert_eq!(parsed.system_id, "7416469842679442267");
        assert_eq!(parsed.timeline, 1);
        assert_eq!(parsed.xlog_pos, "0/16B6C50");
        assert_eq!(parsed.dbname.as_deref(), Some("app"));
    }

    #[test]
    fn parse_identify_system_row_empty_dbname_is_none() {
        let row = text_row(&[Some("1"), Some("1"), Some("0/0"), Some("")]);
        assert_eq!(parse_identify_system_row(&row).unwrap().dbname, None);
    }

    #[test]
    fn parse_identify_system_row_rejects_bad_timeline() {
        let row = text_row(&[Some("1"), Some("abc"), Some("0/0"), Some("app")]);
        assert!(parse_identify_system_row(&row).is_err());
    }

    #[test]
    fn parse_create_slot_row_happy_path() {
        let row = text_row(&[
            Some("slot_main"),
            Some("0/16B6C88"),
            Some("00000003-00000041-1"),
            Some("pgoutput"),
        ]);
        let parsed = parse_create_slot_row(&row).unwrap();
        assert_eq!(parsed.slot_name, "slot_main");
        assert_eq!(parsed.consistent_point, "0/16B6C88");
        assert_eq!(parsed.snapshot_name.as_deref(), Some("00000003-00000041-1"));
        assert_eq!(parsed.output_plugin, "pgoutput");
    }

    #[test]
    fn parse_copy_data_xlog_data() {
        let payload = xlog_payload(0x10, 0x20, 123, b"hello");
        match parse_copy_data_message(&payload).unwrap() {
            ReplicationStreamMessage::XLogData(x) => {
                assert_eq!(x.wal_start, 0x10);
                assert_eq!(x.wal_end, 0x20);
                assert_eq!(x.server_time_micros, 123);
                assert_eq!(x.data, b"hello");
            }
            _ => panic!("expected xlog data"),
        }
    }

    #[test]
    fn parse_copy_data_xlog_data_allows_empty_body_and_equal_positions() {
        // Logical walsenders send dataStart == walEnd; a header-only message is valid.
        let payload = xlog_payload(0x30, 0x30, 1, b"");
        match parse_copy_data_message(&payload).unwrap() {
            ReplicationStreamMessage::XLogData(x) => {
                assert_eq!(x.wal_start, 0x30);
                assert_eq!(x.wal_end, 0x30);
                assert!(x.data.is_empty());
            }
            _ => panic!("expected xlog data"),
        }
    }

    #[test]
    fn parse_copy_data_xlog_data_rejects_wal_end_before_start() {
        let payload = xlog_payload(0x20, 0x10, 123, b"");
        let err = parse_copy_data_message(&payload).unwrap_err();
        assert!(err.to_string().contains("wal_end"));
    }

    #[test]
    fn parse_copy_data_xlog_data_rejects_oversized_data() {
        let mut payload = Vec::with_capacity(25 + MAX_REPLICATION_XLOGDATA_BYTES + 1);
        payload.push(b'w');
        payload.extend_from_slice(&0x10_u64.to_be_bytes());
        payload.extend_from_slice(&0x20_u64.to_be_bytes());
        payload.extend_from_slice(&123_i64.to_be_bytes());
        payload.extend(std::iter::repeat_n(0u8, MAX_REPLICATION_XLOGDATA_BYTES + 1));
        let err = parse_copy_data_message(&payload).unwrap_err();
        assert!(err.to_string().contains("payload too large"));
    }

    #[test]
    fn parse_copy_data_rejects_empty_and_wrong_length_payloads() {
        assert!(parse_copy_data_message(&[]).is_err());
        assert!(parse_copy_data_message(&[b'w'; 24]).is_err());
        assert!(parse_copy_data_message(&[b'k'; 17]).is_err());
        assert!(parse_copy_data_message(&[b'k'; 19]).is_err());
    }

    #[test]
    fn parse_copy_data_keepalive() {
        let mut payload = Vec::new();
        payload.push(b'k');
        payload.extend_from_slice(&0xAB_u64.to_be_bytes());
        payload.extend_from_slice(&456_i64.to_be_bytes());
        payload.push(1);

        match parse_copy_data_message(&payload).unwrap() {
            ReplicationStreamMessage::Keepalive(k) => {
                assert_eq!(k.wal_end, 0xAB);
                assert_eq!(k.server_time_micros, 456);
                assert!(k.reply_requested);
            }
            _ => panic!("expected keepalive"),
        }
    }

    #[test]
    fn parse_copy_data_keepalive_rejects_invalid_reply_requested() {
        let mut payload = Vec::new();
        payload.push(b'k');
        payload.extend_from_slice(&0xAB_u64.to_be_bytes());
        payload.extend_from_slice(&456_i64.to_be_bytes());
        payload.push(2);
        let err = parse_copy_data_message(&payload).unwrap_err();
        assert!(err.to_string().contains("reply_requested must be 0 or 1"));
    }

    #[test]
    fn parse_copy_data_unknown_tag_rejected() {
        let payload = vec![b'x', 1, 2, 3];
        let err = parse_copy_data_message(&payload).unwrap_err();
        assert!(
            err.to_string()
                .contains("Unsupported replication CopyData tag")
        );
    }

    #[test]
    fn build_standby_status_update_payload_layout() {
        let payload = build_standby_status_update_payload(1, 2, 3, true);
        assert_eq!(payload.len(), 34);
        assert_eq!(payload[0], b'r');
        assert_eq!(u64::from_be_bytes(payload[1..9].try_into().unwrap()), 1);
        assert_eq!(u64::from_be_bytes(payload[9..17].try_into().unwrap()), 2);
        assert_eq!(u64::from_be_bytes(payload[17..25].try_into().unwrap()), 3);
        assert_eq!(payload[33], 1);
    }

    #[test]
    fn build_standby_status_update_payload_reply_flag_off() {
        let payload = build_standby_status_update_payload(0, 0, 0, false);
        assert_eq!(payload.len(), 34);
        assert_eq!(payload[33], 0);
    }
}
927}