1use crate::buffer::BufferWriter;
33use crate::error::{ReplicationError, Result};
34use crate::types::{format_lsn, system_time_to_postgres_timestamp, XLogRecPtr};
35use libpq_sys::*;
36use std::ffi::{CStr, CString};
37use std::os::raw::c_void;
38use std::os::unix::io::RawFd;
39use std::time::SystemTime;
40use std::{ptr, slice};
41use tokio::io::unix::AsyncFd;
42use tokio_util::sync::CancellationToken;
43use tracing::{debug, info, warn};
44
45pub use crate::types::INVALID_XLOG_REC_PTR;
46
/// Outcome of one attempt to pull a CopyData message out of libpq's buffer.
#[derive(Debug)]
enum ReadResult {
    /// A message was available; holds an owned copy of its bytes.
    Data(Vec<u8>),
    /// Nothing could be returned right now; the caller waits on the socket and retries.
    WouldBlock,
    /// `PQgetCopyData` returned -1, i.e. the COPY stream has finished.
    CopyDone,
}
57
/// A PostgreSQL connection driven through raw libpq calls.
///
/// The connection is closed (and COPY mode ended, if active) when the value is dropped.
pub struct PgReplicationConnection {
    /// Raw libpq connection handle; reset to null once the connection has been closed.
    conn: *mut PGconn,
    /// Set to `true` by `start_replication` after START_REPLICATION succeeds.
    is_replication_conn: bool,
    /// Tokio registration of the libpq socket, created by `initialize_async_socket`.
    async_fd: Option<AsyncFd<RawFd>>,
}
95
96impl PgReplicationConnection {
97 pub fn connect(conninfo: &str) -> Result<Self> {
133 unsafe {
135 let library_version = PQlibVersion();
136 debug!("Using libpq version: {}", library_version);
137 }
138
139 let c_conninfo = CString::new(conninfo)
140 .map_err(|e| ReplicationError::connection(format!("Invalid connection string: {e}")))?;
141
142 let conn = unsafe { PQconnectdb(c_conninfo.as_ptr()) };
143
144 if conn.is_null() {
145 return Err(ReplicationError::transient_connection(
146 "Failed to allocate PostgreSQL connection object".to_string(),
147 ));
148 }
149
150 let status = unsafe { PQstatus(conn) };
151 if status != ConnStatusType::CONNECTION_OK {
152 let error_msg = unsafe {
153 let error_ptr = PQerrorMessage(conn);
154 if error_ptr.is_null() {
155 "Unknown connection error".to_string()
156 } else {
157 CStr::from_ptr(error_ptr).to_string_lossy().into_owned()
158 }
159 };
160 unsafe { PQfinish(conn) };
161
162 let error_msg_lower = error_msg.to_lowercase();
164 if error_msg_lower.contains("authentication failed")
165 || error_msg_lower.contains("password authentication failed")
166 || error_msg_lower.contains("role does not exist")
167 {
168 return Err(ReplicationError::authentication(format!(
169 "PostgreSQL authentication failed: {error_msg}"
170 )));
171 } else if error_msg_lower.contains("database does not exist")
172 || error_msg_lower.contains("invalid connection string")
173 || error_msg_lower.contains("unsupported")
174 {
175 return Err(ReplicationError::permanent_connection(format!(
176 "PostgreSQL connection failed (permanent): {error_msg}"
177 )));
178 } else {
179 return Err(ReplicationError::transient_connection(format!(
180 "PostgreSQL connection failed (transient): {error_msg}"
181 )));
182 }
183 }
184
185 let server_version = unsafe { PQserverVersion(conn) };
187 if server_version < 140000 {
188 unsafe { PQfinish(conn) };
189 return Err(ReplicationError::permanent_connection(format!(
190 "PostgreSQL version {server_version} is not supported. Logical replication requires PostgreSQL 14+"
191 )));
192 }
193
194 debug!("Connected to PostgreSQL server version: {}", server_version);
195
196 Ok(Self {
197 conn,
198 is_replication_conn: false,
199 async_fd: None,
200 })
201 }
202
203 pub fn exec(&self, query: &str) -> Result<PgResult> {
205 let c_query = CString::new(query)
206 .map_err(|e| ReplicationError::protocol(format!("Invalid query string: {e}")))?;
207
208 let result = unsafe { PQexec(self.conn, c_query.as_ptr()) };
209
210 if result.is_null() {
211 return Err(ReplicationError::protocol(
212 "Query execution failed - null result".to_string(),
213 ));
214 }
215
216 let pg_result = PgResult::new(result);
217 let status = pg_result.status();
219 info!(
220 "query : {} pg_result.status() : {:?}",
221 query,
222 pg_result.status()
223 );
224 if !matches!(
225 status,
226 ExecStatusType::PGRES_TUPLES_OK
227 | ExecStatusType::PGRES_COMMAND_OK
228 | ExecStatusType::PGRES_COPY_BOTH
229 ) {
230 let error_msg = pg_result
231 .error_message()
232 .unwrap_or_else(|| "Unknown error".to_string());
233 return Err(ReplicationError::protocol(format!(
234 "Query execution failed: {error_msg}"
235 )));
236 }
237
238 Ok(pg_result)
239 }
240
241 pub fn identify_system(&self) -> Result<PgResult> {
243 debug!("Sending IDENTIFY_SYSTEM command");
244 let result = self.exec("IDENTIFY_SYSTEM")?;
245
246 if result.ntuples() > 0 {
247 if let (Some(systemid), Some(timeline), Some(xlogpos)) = (
248 result.get_value(0, 0),
249 result.get_value(0, 1),
250 result.get_value(0, 2),
251 ) {
252 debug!(
253 "System identification: systemid={}, timeline={}, xlogpos={}",
254 systemid, timeline, xlogpos
255 );
256 }
257 }
258
259 Ok(result)
260 }
261
262 pub fn create_replication_slot(
264 &self,
265 slot_name: &str,
266 output_plugin: &str,
267 ) -> Result<PgResult> {
268 let create_slot_sql = format!(
269 "CREATE_REPLICATION_SLOT \"{slot_name}\" LOGICAL {output_plugin} NOEXPORT_SNAPSHOT;"
270 );
271
272 let result = self.exec(&create_slot_sql)?;
273
274 if result.ntuples() > 0 {
275 if let Some(slot_name_result) = result.get_value(0, 0) {
276 debug!("Replication slot created: {}", slot_name_result);
277 }
278 }
279
280 Ok(result)
281 }
282
283 pub fn start_replication(
285 &mut self,
286 slot_name: &str,
287 start_lsn: XLogRecPtr,
288 options: &[(&str, &str)],
289 ) -> Result<()> {
290 let mut options_str = String::new();
291 for (i, (key, value)) in options.iter().enumerate() {
292 if i > 0 {
293 options_str.push_str(", ");
294 }
295 options_str.push_str(&format!("\"{key}\" '{value}'"));
296 }
297
298 let start_replication_sql = if start_lsn == INVALID_XLOG_REC_PTR {
299 format!("START_REPLICATION SLOT \"{slot_name}\" LOGICAL 0/0 ({options_str})")
300 } else {
301 format!(
302 "START_REPLICATION SLOT \"{}\" LOGICAL {} ({})",
303 slot_name,
304 format_lsn(start_lsn),
305 options_str
306 )
307 };
308
309 debug!("Starting replication: {}", start_replication_sql);
310 let _result = self.exec(&start_replication_sql)?;
311
312 self.is_replication_conn = true;
313
314 self.initialize_async_socket()?;
316
317 debug!("Replication started successfully");
318 Ok(())
319 }
320
321 pub fn send_standby_status_update(
323 &self,
324 received_lsn: XLogRecPtr,
325 flushed_lsn: XLogRecPtr,
326 applied_lsn: XLogRecPtr,
327 reply_requested: bool,
328 ) -> Result<()> {
329 if !self.is_replication_conn {
330 return Err(ReplicationError::protocol(
331 "Connection is not in replication mode".to_string(),
332 ));
333 }
334
335 let timestamp = system_time_to_postgres_timestamp(SystemTime::now());
336
337 let mut buffer = BufferWriter::with_capacity(34); buffer.write_u8(b'r')?; buffer.write_u64(received_lsn)?;
342 buffer.write_u64(flushed_lsn)?;
343 buffer.write_u64(applied_lsn)?;
344 buffer.write_i64(timestamp)?;
345 buffer.write_u8(if reply_requested { 1 } else { 0 })?;
346
347 let reply_data = buffer.freeze();
348
349 let result = unsafe {
350 PQputCopyData(
351 self.conn,
352 reply_data.as_ptr() as *const std::os::raw::c_char,
353 reply_data.len() as i32,
354 )
355 };
356
357 if result != 1 {
358 let error_msg = self.last_error_message();
359 return Err(ReplicationError::protocol(format!(
360 "Failed to send standby status update: {error_msg}"
361 )));
362 }
363
364 let flush_result = unsafe { PQflush(self.conn) };
366 if flush_result != 0 {
367 let error_msg = self.last_error_message();
368 return Err(ReplicationError::protocol(format!(
369 "Failed to flush connection: {error_msg}"
370 )));
371 }
372
373 info!(
374 "Sent standby status update: received={}, flushed={}, applied={}, reply_requested={}",
375 format_lsn(received_lsn),
376 format_lsn(flushed_lsn),
377 format_lsn(applied_lsn),
378 reply_requested
379 );
380
381 Ok(())
382 }
383
384 fn initialize_async_socket(&mut self) -> Result<()> {
386 let sock: RawFd = unsafe { PQsocket(self.conn) };
387 if sock < 0 {
388 return Err(ReplicationError::protocol(
389 "Invalid PostgreSQL socket".to_string(),
390 ));
391 }
392
393 let async_fd = AsyncFd::new(sock)
394 .map_err(|e| ReplicationError::protocol(format!("Failed to create AsyncFd: {e}")))?;
395
396 self.async_fd = Some(async_fd);
397 Ok(())
398 }
399
/// Waits for the next CopyData message on the replication stream and returns its bytes.
///
/// Data already buffered inside libpq is returned first. Otherwise the call waits
/// until either `cancellation_token` fires or the socket becomes readable, feeds the
/// socket input to libpq with `PQconsumeInput`, and tries again.
///
/// # Errors
/// * `ReplicationError::protocol` if the connection is not in replication mode, the
///   `AsyncFd` has not been created, waiting for readability fails, or
///   `PQconsumeInput` returns 0.
/// * `ReplicationError::Cancelled` when the token is cancelled and no buffered message
///   is available, and also when the COPY stream has ended.
/// * Any error produced by `try_read_buffered_data`.
pub async fn get_copy_data_async(
    &mut self,
    cancellation_token: &CancellationToken,
) -> Result<Vec<u8>> {
    if !self.is_replication_conn {
        return Err(ReplicationError::protocol(
            "Connection is not in replication mode".to_string(),
        ));
    }

    let async_fd = self
        .async_fd
        .as_ref()
        .ok_or_else(|| ReplicationError::protocol("AsyncFd not initialized".to_string()))?;

    loop {
        // Drain whatever libpq already holds before waiting on the socket.
        match self.try_read_buffered_data()? {
            ReadResult::Data(data) => return Ok(data),
            ReadResult::CopyDone => {
                debug!("COPY stream ended gracefully");
                return Err(ReplicationError::Cancelled("COPY stream ended".to_string()));
            }
            ReadResult::WouldBlock => {}
        }

        tokio::select! {
            // `biased` makes select! poll the branches in the order written, so the
            // cancellation branch is checked before socket readiness.
            biased;

            _ = cancellation_token.cancelled() => {
                debug!("Cancellation detected in get_copy_data_async");
                // One last look at libpq's buffer so an already-received message is
                // handed to the caller instead of being dropped on cancellation.
                match self.try_read_buffered_data()? {
                    ReadResult::Data(data) => {
                        info!("Found buffered data after cancellation, returning it");
                        return Ok(data);
                    }
                    ReadResult::CopyDone => {
                        info!("Cancellation token triggered COPY stream ended during cancellation check");
                        return Err(ReplicationError::Cancelled(
                            "COPY stream ended".to_string(),
                        ));
                    }
                    ReadResult::WouldBlock => {
                        info!("Cancellation token triggered with no buffered data");
                    }
                }
                return Err(ReplicationError::Cancelled("Operation cancelled".to_string()));
            }

            guard_result = async_fd.readable() => {
                let mut guard = guard_result.map_err(|e| {
                    ReplicationError::protocol(format!("Failed to wait for socket readability: {e}"))
                })?;

                // Let libpq read from the socket; 0 is treated as a failure.
                let consumed = unsafe { PQconsumeInput(self.conn) };
                if consumed == 0 {
                    let error_msg = self.last_error_message();
                    return Err(ReplicationError::protocol(format!(
                        "PQconsumeInput failed: {error_msg}"
                    )));
                }

                match self.try_read_buffered_data()? {
                    ReadResult::Data(data) => {
                        return Ok(data);
                    }
                    ReadResult::CopyDone => {
                        debug!("COPY stream ended after consuming input");
                        return Err(ReplicationError::Cancelled(
                            "COPY stream ended".to_string(),
                        ));
                    }
                    ReadResult::WouldBlock => {
                        // Still no complete message: clear the readiness flag so the
                        // next `readable()` waits for a new event, then loop again.
                        guard.clear_ready();
                    }
                }
            }
        }
    }
}
503
504 #[inline]
515 fn try_read_buffered_data(&self) -> Result<ReadResult> {
516 let mut buffer: *mut std::os::raw::c_char = ptr::null_mut();
518 let result = unsafe { PQgetCopyData(self.conn, &mut buffer, 1) };
519
520 match result {
521 len if len > 0 => {
522 if buffer.is_null() {
523 return Err(ReplicationError::buffer(
524 "Received null buffer from PQgetCopyData".to_string(),
525 ));
526 }
527
528 let data =
529 unsafe { slice::from_raw_parts(buffer as *const u8, len as usize).to_vec() };
530
531 unsafe { PQfreemem(buffer as *mut c_void) };
533 Ok(ReadResult::Data(data))
534 }
535 0 | -2 => {
536 Ok(ReadResult::WouldBlock)
539 }
540 -1 => {
541 debug!("COPY stream finished (PQgetCopyData returned -1)");
543 Ok(ReadResult::CopyDone)
544 }
545 other => Err(ReplicationError::protocol(format!(
546 "Unexpected PQgetCopyData result: {other}"
547 ))),
548 }
549 }
550
551 fn last_error_message(&self) -> String {
553 unsafe {
554 let error_ptr = PQerrorMessage(self.conn);
555 if error_ptr.is_null() {
556 "Unknown error".to_string()
557 } else {
558 CStr::from_ptr(error_ptr).to_string_lossy().into_owned()
559 }
560 }
561 }
562
563 pub fn is_alive(&self) -> bool {
565 if self.conn.is_null() {
566 return false;
567 }
568
569 unsafe { PQstatus(self.conn) == ConnStatusType::CONNECTION_OK }
570 }
571
/// Returns the server version number as reported by `PQserverVersion`.
///
/// `connect` compares this value against 140000 to require PostgreSQL 14 or newer.
pub fn server_version(&self) -> i32 {
    unsafe { PQserverVersion(self.conn) }
}
576
577 fn close_replication_connection(&mut self) {
578 if !self.conn.is_null() {
579 info!("Closing PostgreSQL replication connection");
580
581 if self.is_replication_conn {
583 debug!("Ending COPY mode before closing connection");
584 unsafe {
585 let result = PQputCopyEnd(self.conn, ptr::null());
587 if result != 1 {
588 warn!(
589 "Failed to end COPY mode gracefully: {}",
590 self.last_error_message()
591 );
592 } else {
593 debug!("COPY mode ended gracefully");
594 }
595 }
596 self.is_replication_conn = false;
597 }
598
599 unsafe {
601 PQfinish(self.conn);
602 }
603
604 self.conn = std::ptr::null_mut();
606 self.async_fd = None;
607
608 info!("PostgreSQL replication connection closed and cleaned up");
609 } else {
610 info!("Connection already closed or was never initialized");
611 }
612 }
613}
614
/// Closes the connection when the value goes out of scope.
impl Drop for PgReplicationConnection {
    fn drop(&mut self) {
        // Delegates to the shared shutdown routine, which does nothing but log when
        // the connection handle is already null.
        self.close_replication_connection();
    }
}
620
// Raw pointers are not `Send` by default, so this opts the wrapper in manually.
// NOTE(review): soundness presumably rests on the handle only ever being used from one
// thread at a time (the type is not `Sync`) — confirm against libpq's threading rules.
unsafe impl Send for PgReplicationConnection {}
623
/// Owning wrapper around a libpq `PGresult`; the result is cleared on drop.
pub struct PgResult {
    /// Raw libpq result handle, passed to `PQclear` on drop when non-null.
    result: *mut PGresult,
}
628
629impl PgResult {
/// Wraps a raw libpq result pointer; the wrapper clears it on drop if it is non-null.
fn new(result: *mut PGresult) -> Self {
    Self { result }
}
633
/// Returns the result status reported by `PQresultStatus`.
pub fn status(&self) -> ExecStatusType {
    unsafe { PQresultStatus(self.result) }
}
638
639 pub fn is_ok(&self) -> bool {
641 matches!(
642 self.status(),
643 ExecStatusType::PGRES_TUPLES_OK | ExecStatusType::PGRES_COMMAND_OK
644 )
645 }
646
/// Returns the number of rows in the result, as reported by `PQntuples`.
pub fn ntuples(&self) -> i32 {
    unsafe { PQntuples(self.result) }
}
651
/// Returns the number of columns in the result, as reported by `PQnfields`.
pub fn nfields(&self) -> i32 {
    unsafe { PQnfields(self.result) }
}
656
657 pub fn get_value(&self, row: i32, col: i32) -> Option<String> {
659 if row >= self.ntuples() || col >= self.nfields() {
660 return None;
661 }
662
663 let value_ptr = unsafe { PQgetvalue(self.result, row, col) };
664 if value_ptr.is_null() {
665 None
666 } else {
667 unsafe { Some(CStr::from_ptr(value_ptr).to_string_lossy().into_owned()) }
668 }
669 }
670
671 pub fn error_message(&self) -> Option<String> {
673 let error_ptr = unsafe { PQresultErrorMessage(self.result) };
674 if error_ptr.is_null() {
675 None
676 } else {
677 unsafe { Some(CStr::from_ptr(error_ptr).to_string_lossy().into_owned()) }
678 }
679 }
680}
681
682impl Drop for PgResult {
683 fn drop(&mut self) {
684 if !self.result.is_null() {
685 unsafe {
686 PQclear(self.result);
687 }
688 }
689 }
690}