pg_walstream/
connection.rs

1//! Low-level PostgreSQL connection using libpq-sys
2//!
3//! This module provides safe wrappers around libpq functions for logical replication.
4//! It's an optional feature that requires the `libpq` feature flag.
5//!
6//! # Async I/O Architecture
7//!
8//! This module implements truly async, non-blocking I/O using tokio's `AsyncFd` wrapper
9//! around libpq's file descriptor. The key design principles are:
10//!
11//! - **Non-blocking socket operations**: Uses `AsyncFd::readable()` with proper drain pattern
12//!   to handle edge-triggered epoll notifications correctly
13//! - **Edge-triggered drain**: When the socket becomes readable, ALL available messages are
14//!   drained from libpq's buffer before clearing the ready flag, preventing message loss
15//! - **Thread release**: When waiting for data, the task is suspended and the thread is
16//!   released back to the executor to run other tasks, preventing thread pool starvation
17//! - **Cancellation-aware**: All async operations support cancellation tokens for graceful
18//!   shutdown without resource leaks
19//! - **Graceful COPY termination**: Properly detects and handles COPY stream end
20//!
21//! ## How it works
22//!
23//! 1. `get_copy_data_async()` first checks libpq's internal buffer (non-blocking)
24//! 2. If no data available, it awaits `AsyncFd::readable()` which yields the task
25//! 3. When the socket becomes readable, tokio wakes the task
26//! 4. The task calls `PQconsumeInput()` to transfer data from OS socket to libpq's buffer
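//! 5. **Critical**: It then returns the first complete message, if any, *without* clearing the
//!    ready flag, so subsequent calls keep draining libpq's buffer (step 1) before waiting again
//! 6. If no complete message is available yet, `clear_ready()` is called and the loop repeats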
29//!
30//! This ensures that no thread is blocked waiting for network I/O, maximizing
31//! throughput and enabling efficient concurrent processing of multiple replication streams.
32use crate::buffer::BufferWriter;
33use crate::error::{ReplicationError, Result};
34use crate::types::{format_lsn, system_time_to_postgres_timestamp, XLogRecPtr};
35use libpq_sys::*;
36use std::ffi::{CStr, CString};
37use std::os::raw::c_void;
38use std::os::unix::io::RawFd;
39use std::time::SystemTime;
40use std::{ptr, slice};
41use tokio::io::unix::AsyncFd;
42use tokio_util::sync::CancellationToken;
43use tracing::{debug, info, warn};
44
45pub use crate::types::INVALID_XLOG_REC_PTR;
46
/// Outcome of a single non-blocking poll of libpq's internal COPY buffer.
#[derive(Debug)]
enum ReadResult {
    /// One complete COPY message, copied into an owned buffer.
    Data(Vec<u8>),
    /// No complete message is buffered yet; reading now would have to block.
    WouldBlock,
    /// The COPY stream has finished and no further messages will arrive.
    CopyDone,
}
57
58/// Safe wrapper around PostgreSQL connection for replication
59///
60/// This struct provides a safe, high-level interface to libpq for PostgreSQL
61/// logical replication. It handles connection management, replication slot
62/// creation, and COPY protocol communication.
63///
64/// # Safety
65///
66/// This struct safely wraps the unsafe libpq C API. All unsafe operations
67/// are properly encapsulated and validated.
68///
69/// # Example
70///
71/// ```no_run
72/// use pg_walstream::PgReplicationConnection;
73///
74/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
75/// let mut conn = PgReplicationConnection::connect(
76///     "postgresql://postgres:password@localhost/mydb?replication=database"
77/// )?;
78///
79/// // Identify the system
80/// conn.identify_system()?;
81///
82/// // Create a replication slot
83/// conn.create_replication_slot("my_slot", "pgoutput")?;
84///
85/// // Start replication
86/// conn.start_replication("my_slot", 0, &[("proto_version", "2")])?
87/// # ; Ok(())
88/// # }
89/// ```
90pub struct PgReplicationConnection {
91    conn: *mut PGconn,
92    is_replication_conn: bool,
93    async_fd: Option<AsyncFd<RawFd>>,
94}
95
96impl PgReplicationConnection {
97    /// Create a new PostgreSQL connection for logical replication
98    ///
99    /// Establishes a connection to PostgreSQL using the provided connection string.
100    /// The connection string must include the `replication=database` parameter to
101    /// enable logical replication.
102    ///
103    /// # Arguments
104    ///
105    /// * `conninfo` - PostgreSQL connection string. Must include `replication=database`.
106    ///   Example: `"postgresql://user:pass@host:5432/dbname?replication=database"`
107    ///
108    /// # Returns
109    ///
110    /// Returns a new `PgReplicationConnection` if successful.
111    ///
112    /// # Errors
113    ///
114    /// Returns an error if:
115    /// - Connection string is invalid
116    /// - Cannot connect to PostgreSQL server (transient or permanent)
117    /// - Authentication fails
118    /// - PostgreSQL version is too old (< 14.0)
119    ///
120    /// # Example
121    ///
122    /// ```no_run
123    /// use pg_walstream::PgReplicationConnection;
124    ///
125    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
126    /// let conn = PgReplicationConnection::connect(
127    ///     "postgresql://postgres:password@localhost:5432/mydb?replication=database"
128    /// )?;
129    /// # Ok(())
130    /// # }
131    /// ```
132    pub fn connect(conninfo: &str) -> Result<Self> {
133        // Ensure libpq is properly initialized
134        unsafe {
135            let library_version = PQlibVersion();
136            debug!("Using libpq version: {}", library_version);
137        }
138
139        let c_conninfo = CString::new(conninfo)
140            .map_err(|e| ReplicationError::connection(format!("Invalid connection string: {e}")))?;
141
142        let conn = unsafe { PQconnectdb(c_conninfo.as_ptr()) };
143
144        if conn.is_null() {
145            return Err(ReplicationError::transient_connection(
146                "Failed to allocate PostgreSQL connection object".to_string(),
147            ));
148        }
149
150        let status = unsafe { PQstatus(conn) };
151        if status != ConnStatusType::CONNECTION_OK {
152            let error_msg = unsafe {
153                let error_ptr = PQerrorMessage(conn);
154                if error_ptr.is_null() {
155                    "Unknown connection error".to_string()
156                } else {
157                    CStr::from_ptr(error_ptr).to_string_lossy().into_owned()
158                }
159            };
160            unsafe { PQfinish(conn) };
161
162            // Categorize the connection error
163            let error_msg_lower = error_msg.to_lowercase();
164            if error_msg_lower.contains("authentication failed")
165                || error_msg_lower.contains("password authentication failed")
166                || error_msg_lower.contains("role does not exist")
167            {
168                return Err(ReplicationError::authentication(format!(
169                    "PostgreSQL authentication failed: {error_msg}"
170                )));
171            } else if error_msg_lower.contains("database does not exist")
172                || error_msg_lower.contains("invalid connection string")
173                || error_msg_lower.contains("unsupported")
174            {
175                return Err(ReplicationError::permanent_connection(format!(
176                    "PostgreSQL connection failed (permanent): {error_msg}"
177                )));
178            } else {
179                return Err(ReplicationError::transient_connection(format!(
180                    "PostgreSQL connection failed (transient): {error_msg}"
181                )));
182            }
183        }
184
185        // Check server version - logical replication requires PostgreSQL 14+
186        let server_version = unsafe { PQserverVersion(conn) };
187        if server_version < 140000 {
188            unsafe { PQfinish(conn) };
189            return Err(ReplicationError::permanent_connection(format!(
190                "PostgreSQL version {server_version} is not supported. Logical replication requires PostgreSQL 14+"
191            )));
192        }
193
194        debug!("Connected to PostgreSQL server version: {}", server_version);
195
196        Ok(Self {
197            conn,
198            is_replication_conn: false,
199            async_fd: None,
200        })
201    }
202
203    /// Execute a replication command (like IDENTIFY_SYSTEM)
204    pub fn exec(&self, query: &str) -> Result<PgResult> {
205        let c_query = CString::new(query)
206            .map_err(|e| ReplicationError::protocol(format!("Invalid query string: {e}")))?;
207
208        let result = unsafe { PQexec(self.conn, c_query.as_ptr()) };
209
210        if result.is_null() {
211            return Err(ReplicationError::protocol(
212                "Query execution failed - null result".to_string(),
213            ));
214        }
215
216        let pg_result = PgResult::new(result);
217        // Check for errors
218        let status = pg_result.status();
219        info!(
220            "query : {} pg_result.status() : {:?}",
221            query,
222            pg_result.status()
223        );
224        if !matches!(
225            status,
226            ExecStatusType::PGRES_TUPLES_OK
227                | ExecStatusType::PGRES_COMMAND_OK
228                | ExecStatusType::PGRES_COPY_BOTH
229        ) {
230            let error_msg = pg_result
231                .error_message()
232                .unwrap_or_else(|| "Unknown error".to_string());
233            return Err(ReplicationError::protocol(format!(
234                "Query execution failed: {error_msg}"
235            )));
236        }
237
238        Ok(pg_result)
239    }
240
241    /// Send IDENTIFY_SYSTEM command
242    pub fn identify_system(&self) -> Result<PgResult> {
243        debug!("Sending IDENTIFY_SYSTEM command");
244        let result = self.exec("IDENTIFY_SYSTEM")?;
245
246        if result.ntuples() > 0 {
247            if let (Some(systemid), Some(timeline), Some(xlogpos)) = (
248                result.get_value(0, 0),
249                result.get_value(0, 1),
250                result.get_value(0, 2),
251            ) {
252                debug!(
253                    "System identification: systemid={}, timeline={}, xlogpos={}",
254                    systemid, timeline, xlogpos
255                );
256            }
257        }
258
259        Ok(result)
260    }
261
262    /// Create a replication slot
263    pub fn create_replication_slot(
264        &self,
265        slot_name: &str,
266        output_plugin: &str,
267    ) -> Result<PgResult> {
268        let create_slot_sql = format!(
269            "CREATE_REPLICATION_SLOT \"{slot_name}\" LOGICAL {output_plugin} NOEXPORT_SNAPSHOT;"
270        );
271
272        let result = self.exec(&create_slot_sql)?;
273
274        if result.ntuples() > 0 {
275            if let Some(slot_name_result) = result.get_value(0, 0) {
276                debug!("Replication slot created: {}", slot_name_result);
277            }
278        }
279
280        Ok(result)
281    }
282
283    /// Start logical replication
284    pub fn start_replication(
285        &mut self,
286        slot_name: &str,
287        start_lsn: XLogRecPtr,
288        options: &[(&str, &str)],
289    ) -> Result<()> {
290        let mut options_str = String::new();
291        for (i, (key, value)) in options.iter().enumerate() {
292            if i > 0 {
293                options_str.push_str(", ");
294            }
295            options_str.push_str(&format!("\"{key}\" '{value}'"));
296        }
297
298        let start_replication_sql = if start_lsn == INVALID_XLOG_REC_PTR {
299            format!("START_REPLICATION SLOT \"{slot_name}\" LOGICAL 0/0 ({options_str})")
300        } else {
301            format!(
302                "START_REPLICATION SLOT \"{}\" LOGICAL {} ({})",
303                slot_name,
304                format_lsn(start_lsn),
305                options_str
306            )
307        };
308
309        debug!("Starting replication: {}", start_replication_sql);
310        let _result = self.exec(&start_replication_sql)?;
311
312        self.is_replication_conn = true;
313
314        // Initialize async socket for non-blocking operations
315        self.initialize_async_socket()?;
316
317        debug!("Replication started successfully");
318        Ok(())
319    }
320
321    /// Send feedback to the server (standby status update)
322    pub fn send_standby_status_update(
323        &self,
324        received_lsn: XLogRecPtr,
325        flushed_lsn: XLogRecPtr,
326        applied_lsn: XLogRecPtr,
327        reply_requested: bool,
328    ) -> Result<()> {
329        if !self.is_replication_conn {
330            return Err(ReplicationError::protocol(
331                "Connection is not in replication mode".to_string(),
332            ));
333        }
334
335        let timestamp = system_time_to_postgres_timestamp(SystemTime::now());
336
337        // Build the standby status update message using BufferWriter
338        let mut buffer = BufferWriter::with_capacity(34); // 1 + 8 + 8 + 8 + 8 + 1
339
340        buffer.write_u8(b'r')?; // Message type
341        buffer.write_u64(received_lsn)?;
342        buffer.write_u64(flushed_lsn)?;
343        buffer.write_u64(applied_lsn)?;
344        buffer.write_i64(timestamp)?;
345        buffer.write_u8(if reply_requested { 1 } else { 0 })?;
346
347        let reply_data = buffer.freeze();
348
349        let result = unsafe {
350            PQputCopyData(
351                self.conn,
352                reply_data.as_ptr() as *const std::os::raw::c_char,
353                reply_data.len() as i32,
354            )
355        };
356
357        if result != 1 {
358            let error_msg = self.last_error_message();
359            return Err(ReplicationError::protocol(format!(
360                "Failed to send standby status update: {error_msg}"
361            )));
362        }
363
364        // Flush the connection
365        let flush_result = unsafe { PQflush(self.conn) };
366        if flush_result != 0 {
367            let error_msg = self.last_error_message();
368            return Err(ReplicationError::protocol(format!(
369                "Failed to flush connection: {error_msg}"
370            )));
371        }
372
373        info!(
374            "Sent standby status update: received={}, flushed={}, applied={}, reply_requested={}",
375            format_lsn(received_lsn),
376            format_lsn(flushed_lsn),
377            format_lsn(applied_lsn),
378            reply_requested
379        );
380
381        Ok(())
382    }
383
384    /// Initialize async socket for non-blocking operations
385    fn initialize_async_socket(&mut self) -> Result<()> {
386        let sock: RawFd = unsafe { PQsocket(self.conn) };
387        if sock < 0 {
388            return Err(ReplicationError::protocol(
389                "Invalid PostgreSQL socket".to_string(),
390            ));
391        }
392
393        let async_fd = AsyncFd::new(sock)
394            .map_err(|e| ReplicationError::protocol(format!("Failed to create AsyncFd: {e}")))?;
395
396        self.async_fd = Some(async_fd);
397        Ok(())
398    }
399
400    /// Get copy data from replication stream (truly async, non-blocking)
401    ///
402    /// This method uses proper async I/O patterns with AsyncFd to enable the tokio
403    /// scheduler to yield the task when no data is available, allowing other tasks
404    /// to run without blocking threads.
405    ///
406    /// # Arguments
407    /// * `cancellation_token` - Cancellation token to abort the operation
408    ///
409    /// # Returns
410    /// * `Ok(data)` - Successfully received data
411    /// * `Err(ReplicationError::Cancelled(_))` - Operation was cancelled or COPY stream ended
412    /// * `Err(_)` - Other errors occurred (connection issues, protocol errors)
413    pub async fn get_copy_data_async(
414        &mut self,
415        cancellation_token: &CancellationToken,
416    ) -> Result<Vec<u8>> {
417        if !self.is_replication_conn {
418            return Err(ReplicationError::protocol(
419                "Connection is not in replication mode".to_string(),
420            ));
421        }
422
423        let async_fd = self
424            .async_fd
425            .as_ref()
426            .ok_or_else(|| ReplicationError::protocol("AsyncFd not initialized".to_string()))?;
427
428        loop {
429            // First, try to read any buffered data without blocking
430            match self.try_read_buffered_data()? {
431                ReadResult::Data(data) => return Ok(data),
432                ReadResult::CopyDone => {
433                    debug!("COPY stream ended gracefully");
434                    return Err(ReplicationError::Cancelled("COPY stream ended".to_string()));
435                }
436                ReadResult::WouldBlock => {}
437            }
438
439            // If no buffered data, wait for either socket readability or cancellation
440            tokio::select! {
441                biased;
442
443                _ = cancellation_token.cancelled() => {
444                    debug!("Cancellation detected in get_copy_data_async");
445                    // Check one more time for buffered data before returning
446                    match self.try_read_buffered_data()? {
447                        ReadResult::Data(data) => {
448                            info!("Found buffered data after cancellation, returning it");
449                            return Ok(data);
450                        }
451                        ReadResult::CopyDone => {
452                            info!("Cancellation token triggered COPY stream ended during cancellation check");
453                            return Err(ReplicationError::Cancelled(
454                                "COPY stream ended".to_string(),
455                            ));
456                        }
457                        ReadResult::WouldBlock => {
458                            info!("Cancellation token triggered with no buffered data");
459                        }
460                    }
461                    return Err(ReplicationError::Cancelled("Operation cancelled".to_string()));
462                }
463
464                // Wait for socket to become readable
465                guard_result = async_fd.readable() => {
466                    let mut guard = guard_result.map_err(|e| {
467                        ReplicationError::protocol(format!("Failed to wait for socket readability: {e}"))
468                    })?;
469
470                    // Socket is readable - consume input from the OS socket
471                    // This is the ONLY place we call PQconsumeInput, avoiding busy-loops
472                    let consumed = unsafe { PQconsumeInput(self.conn) };
473                    if consumed == 0 {
474                        let error_msg = self.last_error_message();
475                        return Err(ReplicationError::protocol(format!(
476                            "PQconsumeInput failed: {error_msg}"
477                        )));
478                    }
479
480                    // Check if we got a complete message after consuming input.
481                    // If we got data, return it immediately (the guard drop will clear ready flag).
482                    // If no complete message yet, explicitly clear the ready flag to re-arm epoll.
483                    match self.try_read_buffered_data()? {
484                        ReadResult::Data(data) => {
485                            return Ok(data);
486                        }
487                        ReadResult::CopyDone => {
488                            debug!("COPY stream ended after consuming input");
489                            return Err(ReplicationError::Cancelled(
490                                "COPY stream ended".to_string(),
491                            ));
492                        }
493                        ReadResult::WouldBlock => {
494                            // No complete message available yet.
495                            // Clear the ready flag to re-arm epoll and continue waiting.
496                            guard.clear_ready();
497                        }
498                    }
499                }
500            }
501        }
502    }
503
504    /// Try to read copy data from libpq's internal buffer without consuming OS socket
505    /// This is a non-blocking operation that only checks libpq's internal buffer.
506    /// It should only be called after PQconsumeInput has been called to transfer
507    /// data from the OS socket to libpq's buffer.
508    ///
509    /// # Returns
510    /// * `Ok(ReadResult::Data(data))` - Complete message available in buffer
511    /// * `Ok(ReadResult::WouldBlock)` - No complete message yet, need to wait for more data
512    /// * `Ok(ReadResult::CopyDone)` - COPY stream has ended gracefully
513    /// * `Err(_)` - Protocol or buffer error
514    #[inline]
515    fn try_read_buffered_data(&self) -> Result<ReadResult> {
516        // PQgetCopyData with async=1 is already non-blocking, so we don't need PQisBusy check.
517        let mut buffer: *mut std::os::raw::c_char = ptr::null_mut();
518        let result = unsafe { PQgetCopyData(self.conn, &mut buffer, 1) };
519
520        match result {
521            len if len > 0 => {
522                if buffer.is_null() {
523                    return Err(ReplicationError::buffer(
524                        "Received null buffer from PQgetCopyData".to_string(),
525                    ));
526                }
527
528                let data =
529                    unsafe { slice::from_raw_parts(buffer as *const u8, len as usize).to_vec() };
530
531                // Free the buffer allocated by PostgreSQL
532                unsafe { PQfreemem(buffer as *mut c_void) };
533                Ok(ReadResult::Data(data))
534            }
535            0 | -2 => {
536                // 0 : According to libpq docs, 0 means async mode and no data ready
537                // 2 : No complete data available yet, would block
538                Ok(ReadResult::WouldBlock)
539            }
540            -1 => {
541                // COPY finished - this is a graceful shutdown signal
542                debug!("COPY stream finished (PQgetCopyData returned -1)");
543                Ok(ReadResult::CopyDone)
544            }
545            other => Err(ReplicationError::protocol(format!(
546                "Unexpected PQgetCopyData result: {other}"
547            ))),
548        }
549    }
550
551    /// Get the last error message from the connection
552    fn last_error_message(&self) -> String {
553        unsafe {
554            let error_ptr = PQerrorMessage(self.conn);
555            if error_ptr.is_null() {
556                "Unknown error".to_string()
557            } else {
558                CStr::from_ptr(error_ptr).to_string_lossy().into_owned()
559            }
560        }
561    }
562
563    /// Check if the connection is still alive
564    pub fn is_alive(&self) -> bool {
565        if self.conn.is_null() {
566            return false;
567        }
568
569        unsafe { PQstatus(self.conn) == ConnStatusType::CONNECTION_OK }
570    }
571
572    /// Get the server version
573    pub fn server_version(&self) -> i32 {
574        unsafe { PQserverVersion(self.conn) }
575    }
576
577    fn close_replication_connection(&mut self) {
578        if !self.conn.is_null() {
579            info!("Closing PostgreSQL replication connection");
580
581            // If we're in replication mode, try to end the copy gracefully
582            if self.is_replication_conn {
583                debug!("Ending COPY mode before closing connection");
584                unsafe {
585                    // Try to end the copy operation gracefully, This is important to properly close the replication stream
586                    let result = PQputCopyEnd(self.conn, ptr::null());
587                    if result != 1 {
588                        warn!(
589                            "Failed to end COPY mode gracefully: {}",
590                            self.last_error_message()
591                        );
592                    } else {
593                        debug!("COPY mode ended gracefully");
594                    }
595                }
596                self.is_replication_conn = false;
597            }
598
599            // Close the connection
600            unsafe {
601                PQfinish(self.conn);
602            }
603
604            // Clear the connection pointer and reset state
605            self.conn = std::ptr::null_mut();
606            self.async_fd = None;
607
608            info!("PostgreSQL replication connection closed and cleaned up");
609        } else {
610            info!("Connection already closed or was never initialized");
611        }
612    }
613}
614
615impl Drop for PgReplicationConnection {
616    fn drop(&mut self) {
617        self.close_replication_connection();
618    }
619}
620
621// Make the connection Send by ensuring exclusive access
622unsafe impl Send for PgReplicationConnection {}
623
624/// Safe wrapper for PostgreSQL result
625pub struct PgResult {
626    result: *mut PGresult,
627}
628
629impl PgResult {
630    fn new(result: *mut PGresult) -> Self {
631        Self { result }
632    }
633
634    /// Get the execution status
635    pub fn status(&self) -> ExecStatusType {
636        unsafe { PQresultStatus(self.result) }
637    }
638
639    /// Check if the result is OK
640    pub fn is_ok(&self) -> bool {
641        matches!(
642            self.status(),
643            ExecStatusType::PGRES_TUPLES_OK | ExecStatusType::PGRES_COMMAND_OK
644        )
645    }
646
647    /// Get number of tuples (rows)
648    pub fn ntuples(&self) -> i32 {
649        unsafe { PQntuples(self.result) }
650    }
651
652    /// Get number of fields (columns)
653    pub fn nfields(&self) -> i32 {
654        unsafe { PQnfields(self.result) }
655    }
656
657    /// Get a field value as string
658    pub fn get_value(&self, row: i32, col: i32) -> Option<String> {
659        if row >= self.ntuples() || col >= self.nfields() {
660            return None;
661        }
662
663        let value_ptr = unsafe { PQgetvalue(self.result, row, col) };
664        if value_ptr.is_null() {
665            None
666        } else {
667            unsafe { Some(CStr::from_ptr(value_ptr).to_string_lossy().into_owned()) }
668        }
669    }
670
671    /// Get error message if any
672    pub fn error_message(&self) -> Option<String> {
673        let error_ptr = unsafe { PQresultErrorMessage(self.result) };
674        if error_ptr.is_null() {
675            None
676        } else {
677            unsafe { Some(CStr::from_ptr(error_ptr).to_string_lossy().into_owned()) }
678        }
679    }
680}
681
682impl Drop for PgResult {
683    fn drop(&mut self) {
684        if !self.result.is_null() {
685            unsafe {
686                PQclear(self.result);
687            }
688        }
689    }
690}