1use super::{
12 PgConnection, PgDriver, PgError, PgResult, PgRow, is_ignorable_session_message,
13 unexpected_backend_message,
14};
15use crate::protocol::{BackendMessage, PgEncoder};
16
/// Decoded result row of the `IDENTIFY_SYSTEM` replication command
/// (see `parse_identify_system_row`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IdentifySystem {
    /// Column 0 (`systemid`), kept as text.
    pub system_id: String,
    /// Column 1 (`timeline`), parsed as a `u32`.
    pub timeline: u32,
    /// Column 2 (`xlogpos`), the LSN in its textual `HIGH/LOW` form.
    pub xlog_pos: String,
    /// Column 3 (`dbname`); `None` when the column yields no text or is empty.
    pub dbname: Option<String>,
}
29
/// Decoded result row of `CREATE_REPLICATION_SLOT ... LOGICAL`
/// (see `parse_create_slot_row`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReplicationSlotInfo {
    /// Column 0 (`slot_name`) as reported by the server.
    pub slot_name: String,
    /// Column 1 (`consistent_point`), an LSN in textual `HIGH/LOW` form.
    pub consistent_point: String,
    /// Column 2 (`snapshot_name`); `None` when the column yields no text or is empty.
    pub snapshot_name: Option<String>,
    /// Column 3 (`output_plugin`).
    pub output_plugin: String,
}
42
/// Contents of the `CopyBothResponse` that opens a replication stream.
///
/// `start_logical_replication` only returns this when `format == 0` and
/// `column_formats` is empty.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReplicationStreamStart {
    /// Overall copy format reported by the server.
    pub format: u8,
    /// Per-column format codes reported by the server.
    pub column_formats: Vec<u8>,
}
51
/// An XLogData (`'w'`) message received on a replication stream.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReplicationXLogData {
    /// Big-endian header field at bytes 1..9: start LSN of this message's data.
    pub wal_start: u64,
    /// Big-endian header field at bytes 9..17: server-reported WAL end.
    /// The parser rejects messages where this is below `wal_start`.
    pub wal_end: u64,
    /// Big-endian header field at bytes 17..25: server clock at send time,
    /// in microseconds since 2000-01-01 per the replication protocol.
    pub server_time_micros: i64,
    /// Payload bytes following the 25-byte header (for logical replication,
    /// the output plugin's message).
    pub data: Vec<u8>,
}
64
/// A primary keepalive (`'k'`) message received on a replication stream.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReplicationKeepalive {
    /// Server-reported WAL end (big-endian bytes 1..9).
    pub wal_end: u64,
    /// Server clock at send time (big-endian bytes 9..17), microseconds since
    /// 2000-01-01 per the replication protocol.
    pub server_time_micros: i64,
    /// Byte 17: `true` when the server asks for an immediate status update.
    pub reply_requested: bool,
}
75
/// One decoded CopyData message from a replication stream.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReplicationStreamMessage {
    /// A `'w'` XLogData message.
    XLogData(ReplicationXLogData),
    /// A `'k'` primary keepalive message.
    Keepalive(ReplicationKeepalive),
    /// An undecoded message (tag byte plus the full payload). Note:
    /// `parse_copy_data_message` in this file never produces this variant; it
    /// rejects unknown tags with a protocol error instead.
    Raw { tag: u8, payload: Vec<u8> },
}
86
/// A `key 'value'` option passed to the output plugin by `START_REPLICATION`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReplicationOption {
    /// Option name; must pass `validate_ident` because it is sent unquoted.
    pub key: String,
    /// Option value; sent as a single-quoted literal (quotes doubled, NUL
    /// rejected, at most `MAX_REPLICATION_OPTION_VALUE_BYTES` bytes).
    pub value: String,
}
95
/// Upper bound on the number of plugin options accepted by `START_REPLICATION`.
const MAX_REPLICATION_OPTIONS: usize = 64;
/// Upper bound, in bytes, on a single plugin option value.
const MAX_REPLICATION_OPTION_VALUE_BYTES: usize = 16 * 1024;
/// Upper bound on the data bytes carried by one XLogData message (header excluded).
const MAX_REPLICATION_XLOGDATA_BYTES: usize = 16 * 1024 * 1024;
99
100fn validate_ident(kind: &str, ident: &str) -> PgResult<()> {
101 if ident.is_empty() {
102 return Err(PgError::Query(format!("{} must not be empty", kind)));
103 }
104 if ident.len() > 63 {
105 return Err(PgError::Query(format!(
106 "{} '{}' exceeds PostgreSQL identifier limit (63 bytes)",
107 kind, ident
108 )));
109 }
110 let mut chars = ident.chars();
111 match chars.next() {
112 Some(c) if c == '_' || c.is_ascii_alphabetic() => {}
113 _ => {
114 return Err(PgError::Query(format!(
115 "{} '{}' must start with [A-Za-z_]",
116 kind, ident
117 )));
118 }
119 }
120 if !chars.all(|c| c == '_' || c.is_ascii_alphanumeric()) {
121 return Err(PgError::Query(format!(
122 "{} '{}' contains unsupported characters (allowed: [A-Za-z0-9_])",
123 kind, ident
124 )));
125 }
126 Ok(())
127}
128
129fn sql_single_quote_literal(value: &str) -> PgResult<String> {
130 if value.contains('\0') {
131 return Err(PgError::Query(
132 "replication option value contains NUL byte".to_string(),
133 ));
134 }
135 Ok(value.replace('\'', "''"))
136}
137
138fn parse_lsn_text(lsn: &str) -> PgResult<u64> {
139 let mut parts = lsn.split('/');
140 let high = parts
141 .next()
142 .ok_or_else(|| PgError::Query(format!("Invalid LSN '{}'", lsn)))?;
143 let low = parts
144 .next()
145 .ok_or_else(|| PgError::Query(format!("Invalid LSN '{}'", lsn)))?;
146 if parts.next().is_some() {
147 return Err(PgError::Query(format!("Invalid LSN '{}'", lsn)));
148 }
149 let high = u32::from_str_radix(high, 16)
150 .map_err(|_| PgError::Query(format!("Invalid LSN '{}'", lsn)))?;
151 let low = u32::from_str_radix(low, 16)
152 .map_err(|_| PgError::Query(format!("Invalid LSN '{}'", lsn)))?;
153 Ok(((high as u64) << 32) | (low as u64))
154}
155
/// Renders an LSN as `HIGH/LOW` text, with the low half zero-padded to eight
/// uppercase hex digits (test helper).
#[cfg(test)]
fn format_lsn(lsn: u64) -> String {
    let high = (lsn >> 32) as u32;
    let low = (lsn & 0xFFFF_FFFF) as u32;
    format!("{high:X}/{low:08X}")
}
160
161fn required_text(row: &PgRow, idx: usize, field: &str) -> PgResult<String> {
162 row.get_string(idx).ok_or_else(|| {
163 PgError::Protocol(format!("Missing or invalid '{}' in replication row", field))
164 })
165}
166
167fn parse_identify_system_row(row: &PgRow) -> PgResult<IdentifySystem> {
168 let system_id = required_text(row, 0, "systemid")?;
169 let timeline = required_text(row, 1, "timeline")?
170 .parse::<u32>()
171 .map_err(|e| PgError::Protocol(format!("Invalid timeline value: {}", e)))?;
172 let xlog_pos = required_text(row, 2, "xlogpos")?;
173 let dbname = row
174 .get_string(3)
175 .and_then(|v| if v.is_empty() { None } else { Some(v) });
176
177 Ok(IdentifySystem {
178 system_id,
179 timeline,
180 xlog_pos,
181 dbname,
182 })
183}
184
185fn parse_create_slot_row(row: &PgRow) -> PgResult<ReplicationSlotInfo> {
186 let slot_name = required_text(row, 0, "slot_name")?;
187 let consistent_point = required_text(row, 1, "consistent_point")?;
188 let snapshot_name = row
189 .get_string(2)
190 .and_then(|v| if v.is_empty() { None } else { Some(v) });
191 let output_plugin = required_text(row, 3, "output_plugin")?;
192
193 Ok(ReplicationSlotInfo {
194 slot_name,
195 consistent_point,
196 snapshot_name,
197 output_plugin,
198 })
199}
200
201fn build_create_logical_replication_slot_sql(
202 slot_name: &str,
203 output_plugin: &str,
204 temporary: bool,
205 two_phase: bool,
206) -> PgResult<String> {
207 validate_ident("slot_name", slot_name)?;
208 validate_ident("output_plugin", output_plugin)?;
209
210 let mut sql = String::from("CREATE_REPLICATION_SLOT ");
211 sql.push_str(slot_name);
212 if temporary {
213 sql.push_str(" TEMPORARY");
214 }
215 sql.push_str(" LOGICAL ");
216 sql.push_str(output_plugin);
217 if two_phase {
218 sql.push_str(" TWO_PHASE");
219 }
220 Ok(sql)
221}
222
223fn build_drop_replication_slot_sql(slot_name: &str, wait: bool) -> PgResult<String> {
224 validate_ident("slot_name", slot_name)?;
225 let mut sql = String::from("DROP_REPLICATION_SLOT ");
226 sql.push_str(slot_name);
227 if wait {
228 sql.push_str(" WAIT");
229 }
230 Ok(sql)
231}
232
233fn build_start_logical_replication_sql(
234 slot_name: &str,
235 start_lsn: &str,
236 options: &[ReplicationOption],
237) -> PgResult<String> {
238 validate_ident("slot_name", slot_name)?;
239 let _ = parse_lsn_text(start_lsn)?;
240 if options.len() > MAX_REPLICATION_OPTIONS {
241 return Err(PgError::Query(format!(
242 "too many replication options: {} (max {})",
243 options.len(),
244 MAX_REPLICATION_OPTIONS
245 )));
246 }
247
248 let mut sql = format!("START_REPLICATION SLOT {} LOGICAL {}", slot_name, start_lsn);
249 if !options.is_empty() {
250 sql.push_str(" (");
251 for (idx, opt) in options.iter().enumerate() {
252 validate_ident("replication option key", &opt.key)?;
253 if opt.value.len() > MAX_REPLICATION_OPTION_VALUE_BYTES {
254 return Err(PgError::Query(format!(
255 "replication option '{}' value too large: {} bytes (max {})",
256 opt.key,
257 opt.value.len(),
258 MAX_REPLICATION_OPTION_VALUE_BYTES
259 )));
260 }
261 if idx > 0 {
262 sql.push_str(", ");
263 }
264 sql.push_str(&opt.key);
265 sql.push_str(" '");
266 sql.push_str(&sql_single_quote_literal(&opt.value)?);
267 sql.push('\'');
268 }
269 sql.push(')');
270 }
271 Ok(sql)
272}
273
274fn parse_copy_data_message(payload: &[u8]) -> PgResult<ReplicationStreamMessage> {
275 if payload.is_empty() {
276 return Err(PgError::Protocol(
277 "Replication CopyData payload is empty".to_string(),
278 ));
279 }
280 match payload[0] {
281 b'w' => {
282 if payload.len() < 25 {
283 return Err(PgError::Protocol(format!(
284 "XLogData payload too short: {} bytes",
285 payload.len()
286 )));
287 }
288 let wal_start = u64::from_be_bytes(
289 payload[1..9]
290 .try_into()
291 .map_err(|_| PgError::Protocol("Invalid wal_start bytes".to_string()))?,
292 );
293 let wal_end = u64::from_be_bytes(
294 payload[9..17]
295 .try_into()
296 .map_err(|_| PgError::Protocol("Invalid wal_end bytes".to_string()))?,
297 );
298 let server_time_micros = i64::from_be_bytes(
299 payload[17..25]
300 .try_into()
301 .map_err(|_| PgError::Protocol("Invalid server time bytes".to_string()))?,
302 );
303 if wal_end < wal_start {
304 return Err(PgError::Protocol(format!(
305 "XLogData wal_end {} is behind wal_start {}",
306 wal_end, wal_start
307 )));
308 }
309 let data_len = payload.len() - 25;
310 if data_len > MAX_REPLICATION_XLOGDATA_BYTES {
311 return Err(PgError::Protocol(format!(
312 "XLogData payload too large: {} bytes (max {})",
313 data_len, MAX_REPLICATION_XLOGDATA_BYTES
314 )));
315 }
316 Ok(ReplicationStreamMessage::XLogData(ReplicationXLogData {
317 wal_start,
318 wal_end,
319 server_time_micros,
320 data: payload[25..].to_vec(),
321 }))
322 }
323 b'k' => {
324 if payload.len() != 18 {
325 return Err(PgError::Protocol(format!(
326 "Keepalive payload must be 18 bytes, got {}",
327 payload.len()
328 )));
329 }
330 let wal_end =
331 u64::from_be_bytes(payload[1..9].try_into().map_err(|_| {
332 PgError::Protocol("Invalid keepalive wal_end bytes".to_string())
333 })?);
334 let server_time_micros = i64::from_be_bytes(
335 payload[9..17]
336 .try_into()
337 .map_err(|_| PgError::Protocol("Invalid keepalive time bytes".to_string()))?,
338 );
339 let reply_requested = match payload[17] {
340 0 => false,
341 1 => true,
342 other => {
343 return Err(PgError::Protocol(format!(
344 "Keepalive reply_requested must be 0 or 1, got {}",
345 other
346 )));
347 }
348 };
349 Ok(ReplicationStreamMessage::Keepalive(ReplicationKeepalive {
350 wal_end,
351 server_time_micros,
352 reply_requested,
353 }))
354 }
355 tag => Err(PgError::Protocol(format!(
356 "Unsupported replication CopyData tag '{}'",
357 if tag.is_ascii_graphic() {
358 tag as char
359 } else {
360 '?'
361 }
362 ))),
363 }
364}
365
/// Current wall-clock time in microseconds since 2000-01-01 00:00:00 UTC
/// (the PostgreSQL epoch).
///
/// If the system clock reports a time before the Unix epoch, the Unix-epoch
/// offset is treated as zero, which yields a negative result.
fn postgres_epoch_micros_now() -> i64 {
    // 946_684_800 s separate 1970-01-01 from 2000-01-01.
    const PG_EPOCH_OFFSET_MICROS: i64 = 946_684_800 * 1_000_000;

    let since_unix = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => std::time::Duration::ZERO,
    };
    let whole_secs = since_unix.as_secs() as i64;
    let frac_micros = i64::from(since_unix.subsec_micros());
    whole_secs * 1_000_000 + frac_micros - PG_EPOCH_OFFSET_MICROS
}
374
375fn build_standby_status_update_payload(
376 write_lsn: u64,
377 flush_lsn: u64,
378 apply_lsn: u64,
379 reply_requested: bool,
380) -> Vec<u8> {
381 let mut payload = Vec::with_capacity(1 + 8 + 8 + 8 + 8 + 1);
382 payload.push(b'r');
383 payload.extend_from_slice(&write_lsn.to_be_bytes());
384 payload.extend_from_slice(&flush_lsn.to_be_bytes());
385 payload.extend_from_slice(&apply_lsn.to_be_bytes());
386 payload.extend_from_slice(&postgres_epoch_micros_now().to_be_bytes());
387 payload.push(if reply_requested { 1 } else { 0 });
388 payload
389}
390
391#[inline]
392fn return_with_desync<T>(conn: &mut PgConnection, err: PgError) -> PgResult<T> {
393 if matches!(
394 err,
395 PgError::Protocol(_) | PgError::Connection(_) | PgError::Timeout(_)
396 ) {
397 conn.mark_io_desynced();
398 }
399 Err(err)
400}
401
402impl PgConnection {
403 #[inline]
404 fn ensure_replication_mode(&self, operation: &str) -> PgResult<()> {
405 if self.replication_mode_enabled {
406 return Ok(());
407 }
408 Err(PgError::Protocol(format!(
409 "{} requires connection startup param replication=database",
410 operation
411 )))
412 }
413
414 #[inline]
415 fn ensure_replication_control_idle(&self, operation: &str) -> PgResult<()> {
416 if !self.replication_stream_active {
417 return Ok(());
418 }
419 Err(PgError::Protocol(format!(
420 "{} cannot run while replication stream is active",
421 operation
422 )))
423 }
424
425 #[inline]
426 fn advance_replication_wal_end(&mut self, source: &str, wal_end: u64) -> PgResult<()> {
427 if let Some(prev_wal_end) = self.last_replication_wal_end
428 && wal_end < prev_wal_end
429 {
430 self.replication_stream_active = false;
431 self.last_replication_wal_end = None;
432 return return_with_desync(
433 self,
434 PgError::Protocol(format!(
435 "Replication {} wal_end regressed: previous {}, current {}",
436 source, prev_wal_end, wal_end
437 )),
438 );
439 }
440 self.last_replication_wal_end = Some(wal_end);
441 Ok(())
442 }
443
444 pub async fn identify_system(&mut self) -> PgResult<IdentifySystem> {
446 self.ensure_replication_mode("IDENTIFY_SYSTEM")?;
447 self.ensure_replication_control_idle("IDENTIFY_SYSTEM")?;
448 let rows = self.simple_query("IDENTIFY_SYSTEM").await?;
449 let row = rows
450 .first()
451 .ok_or_else(|| PgError::Protocol("IDENTIFY_SYSTEM returned no rows".to_string()))?;
452 parse_identify_system_row(row)
453 }
454
455 pub async fn create_logical_replication_slot(
459 &mut self,
460 slot_name: &str,
461 output_plugin: &str,
462 temporary: bool,
463 two_phase: bool,
464 ) -> PgResult<ReplicationSlotInfo> {
465 self.ensure_replication_mode("CREATE_REPLICATION_SLOT")?;
466 self.ensure_replication_control_idle("CREATE_REPLICATION_SLOT")?;
467 let sql = build_create_logical_replication_slot_sql(
468 slot_name,
469 output_plugin,
470 temporary,
471 two_phase,
472 )?;
473 let rows = self.simple_query(&sql).await?;
474 let row = rows.first().ok_or_else(|| {
475 PgError::Protocol("CREATE_REPLICATION_SLOT returned no rows".to_string())
476 })?;
477 parse_create_slot_row(row)
478 }
479
480 pub async fn drop_replication_slot(&mut self, slot_name: &str, wait: bool) -> PgResult<()> {
484 self.ensure_replication_mode("DROP_REPLICATION_SLOT")?;
485 self.ensure_replication_control_idle("DROP_REPLICATION_SLOT")?;
486 let sql = build_drop_replication_slot_sql(slot_name, wait)?;
487 self.execute_simple(&sql).await
488 }
489
490 pub async fn start_logical_replication(
494 &mut self,
495 slot_name: &str,
496 start_lsn: &str,
497 options: &[ReplicationOption],
498 ) -> PgResult<ReplicationStreamStart> {
499 self.ensure_replication_mode("START_REPLICATION")?;
500 if self.replication_stream_active {
501 return Err(PgError::Protocol(
502 "logical replication stream already active".to_string(),
503 ));
504 }
505 let sql = build_start_logical_replication_sql(slot_name, start_lsn, options)?;
506 let bytes = PgEncoder::try_encode_query_string(&sql)?;
507 self.write_all_with_timeout(&bytes, "stream write").await?;
508
509 let mut startup_error: Option<PgError> = None;
510 loop {
511 let msg = self.recv().await?;
512 match msg {
513 BackendMessage::CopyBothResponse {
514 format,
515 column_formats,
516 } => {
517 if let Some(err) = startup_error {
518 return return_with_desync(self, err);
519 }
520 if format != 0 {
521 return return_with_desync(
522 self,
523 PgError::Protocol(format!(
524 "START_REPLICATION returned unsupported CopyBothResponse format {} (expected 0/text)",
525 format
526 )),
527 );
528 }
529 if !column_formats.is_empty() {
530 return return_with_desync(
531 self,
532 PgError::Protocol(format!(
533 "START_REPLICATION returned unexpected CopyBothResponse column formats (expected none, got {})",
534 column_formats.len()
535 )),
536 );
537 }
538 self.replication_stream_active = true;
539 self.last_replication_wal_end = None;
540 return Ok(ReplicationStreamStart {
541 format,
542 column_formats,
543 });
544 }
545 BackendMessage::ReadyForQuery(_) => {
546 return return_with_desync(
547 self,
548 startup_error.unwrap_or_else(|| {
549 PgError::Protocol(
550 "START_REPLICATION ended before CopyBothResponse".to_string(),
551 )
552 }),
553 );
554 }
555 BackendMessage::ErrorResponse(err) => {
556 if startup_error.is_none() {
557 startup_error = Some(PgError::QueryServer(err.into()));
558 }
559 }
560 msg if is_ignorable_session_message(&msg) => {}
561 other => {
562 return return_with_desync(
563 self,
564 unexpected_backend_message("start replication", &other),
565 );
566 }
567 }
568 }
569 }
570
571 pub async fn recv_replication_message(&mut self) -> PgResult<ReplicationStreamMessage> {
575 self.ensure_replication_mode("recv_replication_message")?;
576 if !self.replication_stream_active {
577 return Err(PgError::Protocol(
578 "replication stream is not active; call START_REPLICATION first".to_string(),
579 ));
580 }
581 loop {
582 let msg = self.recv_without_timeout().await?;
583 match msg {
584 BackendMessage::CopyData(payload) => match parse_copy_data_message(&payload) {
585 Ok(ReplicationStreamMessage::XLogData(x)) => {
586 self.advance_replication_wal_end("XLogData", x.wal_end)?;
587 return Ok(ReplicationStreamMessage::XLogData(x));
588 }
589 Ok(ReplicationStreamMessage::Keepalive(k)) => {
590 self.advance_replication_wal_end("keepalive", k.wal_end)?;
591 return Ok(ReplicationStreamMessage::Keepalive(k));
592 }
593 Ok(parsed) => return Ok(parsed),
594 Err(err) => {
595 self.replication_stream_active = false;
596 self.last_replication_wal_end = None;
597 return return_with_desync(self, err);
598 }
599 },
600 BackendMessage::ErrorResponse(err) => {
601 self.replication_stream_active = false;
602 self.last_replication_wal_end = None;
603 return Err(PgError::QueryServer(err.into()));
604 }
605 BackendMessage::CopyDone => {
606 self.replication_stream_active = false;
607 self.last_replication_wal_end = None;
608 return return_with_desync(
609 self,
610 PgError::Protocol("Replication stream ended with CopyDone".to_string()),
611 );
612 }
613 BackendMessage::ReadyForQuery(_) => {
614 self.replication_stream_active = false;
615 self.last_replication_wal_end = None;
616 return return_with_desync(
617 self,
618 PgError::Protocol(
619 "Replication stream ended with ReadyForQuery".to_string(),
620 ),
621 );
622 }
623 msg if is_ignorable_session_message(&msg) => {}
624 other => {
625 self.replication_stream_active = false;
626 self.last_replication_wal_end = None;
627 return return_with_desync(
628 self,
629 unexpected_backend_message("replication stream", &other),
630 );
631 }
632 }
633 }
634 }
635
636 pub async fn send_standby_status_update(
638 &mut self,
639 write_lsn: u64,
640 flush_lsn: u64,
641 apply_lsn: u64,
642 reply_requested: bool,
643 ) -> PgResult<()> {
644 self.ensure_replication_mode("send_standby_status_update")?;
645 if !self.replication_stream_active {
646 return Err(PgError::Protocol(
647 "replication stream is not active; call START_REPLICATION first".to_string(),
648 ));
649 }
650 if flush_lsn > write_lsn {
651 return Err(PgError::Protocol(format!(
652 "Invalid standby status update: flush_lsn {} exceeds write_lsn {}",
653 flush_lsn, write_lsn
654 )));
655 }
656 if apply_lsn > flush_lsn {
657 return Err(PgError::Protocol(format!(
658 "Invalid standby status update: apply_lsn {} exceeds flush_lsn {}",
659 apply_lsn, flush_lsn
660 )));
661 }
662 if let Some(last_wal_end) = self.last_replication_wal_end
663 && write_lsn > last_wal_end
664 {
665 return Err(PgError::Protocol(format!(
666 "Invalid standby status update: write_lsn {} exceeds last seen server wal_end {}",
667 write_lsn, last_wal_end
668 )));
669 }
670 let payload =
671 build_standby_status_update_payload(write_lsn, flush_lsn, apply_lsn, reply_requested);
672 self.send_copy_data(&payload).await
673 }
674}
675
676impl PgDriver {
677 pub async fn identify_system(&mut self) -> PgResult<IdentifySystem> {
679 self.connection.identify_system().await
680 }
681
682 pub async fn create_logical_replication_slot(
684 &mut self,
685 slot_name: &str,
686 output_plugin: &str,
687 temporary: bool,
688 two_phase: bool,
689 ) -> PgResult<ReplicationSlotInfo> {
690 self.connection
691 .create_logical_replication_slot(slot_name, output_plugin, temporary, two_phase)
692 .await
693 }
694
695 pub async fn drop_replication_slot(&mut self, slot_name: &str, wait: bool) -> PgResult<()> {
697 self.connection.drop_replication_slot(slot_name, wait).await
698 }
699
700 pub async fn start_logical_replication(
702 &mut self,
703 slot_name: &str,
704 start_lsn: &str,
705 options: &[ReplicationOption],
706 ) -> PgResult<ReplicationStreamStart> {
707 self.connection
708 .start_logical_replication(slot_name, start_lsn, options)
709 .await
710 }
711
712 pub async fn recv_replication_message(&mut self) -> PgResult<ReplicationStreamMessage> {
714 self.connection.recv_replication_message().await
715 }
716
717 pub async fn send_standby_status_update(
719 &mut self,
720 write_lsn: u64,
721 flush_lsn: u64,
722 apply_lsn: u64,
723 reply_requested: bool,
724 ) -> PgResult<()> {
725 self.connection
726 .send_standby_status_update(write_lsn, flush_lsn, apply_lsn, reply_requested)
727 .await
728 }
729}
730
731#[cfg(test)]
732mod tests {
733 use super::*;
734
735 #[cfg(unix)]
736 fn test_conn() -> PgConnection {
737 use crate::driver::connection::StatementCache;
738 use crate::driver::stream::PgStream;
739 use bytes::BytesMut;
740 use std::collections::{HashMap, VecDeque};
741 use std::num::NonZeroUsize;
742 use tokio::net::UnixStream;
743
744 let (unix_stream, _peer) = UnixStream::pair().expect("unix stream pair");
745 PgConnection {
746 stream: PgStream::Unix(unix_stream),
747 buffer: BytesMut::with_capacity(1024),
748 write_buf: BytesMut::with_capacity(1024),
749 sql_buf: BytesMut::with_capacity(256),
750 params_buf: Vec::new(),
751 prepared_statements: HashMap::new(),
752 stmt_cache: StatementCache::new(NonZeroUsize::new(2).expect("non-zero")),
753 column_info_cache: HashMap::new(),
754 process_id: 0,
755 cancel_key_bytes: Vec::new(),
756 requested_protocol_minor: PgConnection::default_protocol_minor(),
757 negotiated_protocol_minor: PgConnection::default_protocol_minor(),
758 notifications: VecDeque::new(),
759 replication_stream_active: false,
760 replication_mode_enabled: true,
761 last_replication_wal_end: None,
762 io_desynced: false,
763 pending_statement_closes: Vec::new(),
764 draining_statement_closes: false,
765 }
766 }
767
768 fn text_row(values: &[Option<&str>]) -> PgRow {
769 PgRow {
770 columns: values
771 .iter()
772 .map(|v| v.map(|s| s.as_bytes().to_vec()))
773 .collect(),
774 column_info: None,
775 }
776 }
777
778 #[cfg(unix)]
779 #[tokio::test]
780 async fn replication_wal_regression_marks_connection_desynced() {
781 let mut conn = test_conn();
782 conn.replication_stream_active = true;
783 conn.last_replication_wal_end = Some(20);
784
785 let err = conn
786 .advance_replication_wal_end("XLogData", 10)
787 .expect_err("wal regression must fail");
788
789 assert!(err.to_string().contains("wal_end regressed"));
790 assert!(!conn.replication_stream_active);
791 assert!(conn.last_replication_wal_end.is_none());
792 assert!(conn.is_io_desynced());
793 }
794
795 #[test]
796 fn validate_ident_rejects_bad_names() {
797 assert!(validate_ident("slot_name", "").is_err());
798 assert!(validate_ident("slot_name", "9slot").is_err());
799 assert!(validate_ident("slot_name", "bad-name").is_err());
800 assert!(validate_ident("slot_name", "has space").is_err());
801 }
802
803 #[test]
804 fn validate_ident_accepts_safe_names() {
805 assert!(validate_ident("slot_name", "slot_a1").is_ok());
806 assert!(validate_ident("output_plugin", "pgoutput").is_ok());
807 }
808
809 #[test]
810 fn parse_and_format_lsn_roundtrip() {
811 let lsn = parse_lsn_text("16/B6C50").unwrap();
812 assert_eq!(format_lsn(lsn), "16/000B6C50");
813 }
814
815 #[test]
816 fn build_create_logical_replication_slot_sql_variants() {
817 let sql =
818 build_create_logical_replication_slot_sql("slot_main", "pgoutput", true, true).unwrap();
819 assert_eq!(
820 sql,
821 "CREATE_REPLICATION_SLOT slot_main TEMPORARY LOGICAL pgoutput TWO_PHASE"
822 );
823 }
824
825 #[test]
826 fn build_drop_replication_slot_sql_variants() {
827 let sql = build_drop_replication_slot_sql("slot_main", true).unwrap();
828 assert_eq!(sql, "DROP_REPLICATION_SLOT slot_main WAIT");
829 }
830
831 #[test]
832 fn build_start_logical_replication_sql_with_options() {
833 let sql = build_start_logical_replication_sql(
834 "slot_main",
835 "0/16B6C50",
836 &[
837 ReplicationOption {
838 key: "proto_version".to_string(),
839 value: "1".to_string(),
840 },
841 ReplicationOption {
842 key: "publication_names".to_string(),
843 value: "pub1,pub2".to_string(),
844 },
845 ],
846 )
847 .unwrap();
848 assert_eq!(
849 sql,
850 "START_REPLICATION SLOT slot_main LOGICAL 0/16B6C50 (proto_version '1', publication_names 'pub1,pub2')"
851 );
852 }
853
854 #[test]
855 fn build_start_logical_replication_sql_rejects_too_many_options() {
856 let options: Vec<ReplicationOption> = (0..=MAX_REPLICATION_OPTIONS)
857 .map(|i| ReplicationOption {
858 key: format!("opt{}", i),
859 value: "x".to_string(),
860 })
861 .collect();
862
863 let err =
864 build_start_logical_replication_sql("slot_main", "0/16B6C50", &options).unwrap_err();
865 assert!(err.to_string().contains("too many replication options"));
866 }
867
868 #[test]
869 fn build_start_logical_replication_sql_rejects_null_value() {
870 let options = vec![ReplicationOption {
871 key: "proto_version".to_string(),
872 value: "1\0oops".to_string(),
873 }];
874 let err =
875 build_start_logical_replication_sql("slot_main", "0/16B6C50", &options).unwrap_err();
876 assert!(err.to_string().contains("contains NUL byte"));
877 }
878
879 #[test]
880 fn build_start_logical_replication_sql_rejects_oversized_value() {
881 let options = vec![ReplicationOption {
882 key: "publication_names".to_string(),
883 value: "x".repeat(MAX_REPLICATION_OPTION_VALUE_BYTES + 1),
884 }];
885 let err =
886 build_start_logical_replication_sql("slot_main", "0/16B6C50", &options).unwrap_err();
887 assert!(err.to_string().contains("value too large"));
888 }
889
890 #[test]
891 fn parse_identify_system_row_happy_path() {
892 let row = text_row(&[
893 Some("7416469842679442267"),
894 Some("1"),
895 Some("0/16B6C50"),
896 Some("app"),
897 ]);
898 let parsed = parse_identify_system_row(&row).unwrap();
899 assert_eq!(parsed.system_id, "7416469842679442267");
900 assert_eq!(parsed.timeline, 1);
901 assert_eq!(parsed.xlog_pos, "0/16B6C50");
902 assert_eq!(parsed.dbname.as_deref(), Some("app"));
903 }
904
905 #[test]
906 fn parse_create_slot_row_happy_path() {
907 let row = text_row(&[
908 Some("slot_main"),
909 Some("0/16B6C88"),
910 Some("00000003-00000041-1"),
911 Some("pgoutput"),
912 ]);
913 let parsed = parse_create_slot_row(&row).unwrap();
914 assert_eq!(parsed.slot_name, "slot_main");
915 assert_eq!(parsed.consistent_point, "0/16B6C88");
916 assert_eq!(parsed.snapshot_name.as_deref(), Some("00000003-00000041-1"));
917 assert_eq!(parsed.output_plugin, "pgoutput");
918 }
919
920 #[test]
921 fn parse_copy_data_xlog_data() {
922 let mut payload = Vec::new();
923 payload.push(b'w');
924 payload.extend_from_slice(&0x10_u64.to_be_bytes());
925 payload.extend_from_slice(&0x20_u64.to_be_bytes());
926 payload.extend_from_slice(&123_i64.to_be_bytes());
927 payload.extend_from_slice(b"hello");
928
929 match parse_copy_data_message(&payload).unwrap() {
930 ReplicationStreamMessage::XLogData(x) => {
931 assert_eq!(x.wal_start, 0x10);
932 assert_eq!(x.wal_end, 0x20);
933 assert_eq!(x.server_time_micros, 123);
934 assert_eq!(x.data, b"hello");
935 }
936 _ => panic!("expected xlog data"),
937 }
938 }
939
940 #[test]
941 fn parse_copy_data_xlog_data_rejects_wal_end_before_start() {
942 let mut payload = Vec::new();
943 payload.push(b'w');
944 payload.extend_from_slice(&0x20_u64.to_be_bytes());
945 payload.extend_from_slice(&0x10_u64.to_be_bytes());
946 payload.extend_from_slice(&123_i64.to_be_bytes());
947 let err = parse_copy_data_message(&payload).unwrap_err();
948 assert!(err.to_string().contains("wal_end"));
949 }
950
951 #[test]
952 fn parse_copy_data_xlog_data_rejects_oversized_data() {
953 let mut payload = Vec::with_capacity(25 + MAX_REPLICATION_XLOGDATA_BYTES + 1);
954 payload.push(b'w');
955 payload.extend_from_slice(&0x10_u64.to_be_bytes());
956 payload.extend_from_slice(&0x20_u64.to_be_bytes());
957 payload.extend_from_slice(&123_i64.to_be_bytes());
958 payload.extend(std::iter::repeat_n(0u8, MAX_REPLICATION_XLOGDATA_BYTES + 1));
959 let err = parse_copy_data_message(&payload).unwrap_err();
960 assert!(err.to_string().contains("payload too large"));
961 }
962
963 #[test]
964 fn parse_copy_data_keepalive() {
965 let mut payload = Vec::new();
966 payload.push(b'k');
967 payload.extend_from_slice(&0xAB_u64.to_be_bytes());
968 payload.extend_from_slice(&456_i64.to_be_bytes());
969 payload.push(1);
970
971 match parse_copy_data_message(&payload).unwrap() {
972 ReplicationStreamMessage::Keepalive(k) => {
973 assert_eq!(k.wal_end, 0xAB);
974 assert_eq!(k.server_time_micros, 456);
975 assert!(k.reply_requested);
976 }
977 _ => panic!("expected keepalive"),
978 }
979 }
980
981 #[test]
982 fn parse_copy_data_keepalive_rejects_invalid_reply_requested() {
983 let mut payload = Vec::new();
984 payload.push(b'k');
985 payload.extend_from_slice(&0xAB_u64.to_be_bytes());
986 payload.extend_from_slice(&456_i64.to_be_bytes());
987 payload.push(2);
988 let err = parse_copy_data_message(&payload).unwrap_err();
989 assert!(err.to_string().contains("reply_requested must be 0 or 1"));
990 }
991
992 #[test]
993 fn parse_copy_data_unknown_tag_rejected() {
994 let payload = vec![b'x', 1, 2, 3];
995 let err = parse_copy_data_message(&payload).unwrap_err();
996 assert!(
997 err.to_string()
998 .contains("Unsupported replication CopyData tag")
999 );
1000 }
1001
1002 #[test]
1003 fn build_standby_status_update_payload_layout() {
1004 let payload = build_standby_status_update_payload(1, 2, 3, true);
1005 assert_eq!(payload.len(), 34);
1006 assert_eq!(payload[0], b'r');
1007 assert_eq!(u64::from_be_bytes(payload[1..9].try_into().unwrap()), 1);
1008 assert_eq!(u64::from_be_bytes(payload[9..17].try_into().unwrap()), 2);
1009 assert_eq!(u64::from_be_bytes(payload[17..25].try_into().unwrap()), 3);
1010 assert_eq!(payload[33], 1);
1011 }
1012}