pg_walstream/
connection.rs

1//! Low-level PostgreSQL connection using libpq-sys
2//!
3//! This module provides safe wrappers around libpq functions for logical replication.
4//! It's an optional feature that requires the `libpq` feature flag.
5use crate::buffer::BufferWriter;
6use crate::error::{ReplicationError, Result};
7use crate::types::{format_lsn, system_time_to_postgres_timestamp, XLogRecPtr};
8use libpq_sys::*;
9use std::ffi::{CStr, CString};
10use std::os::unix::io::RawFd;
11use std::ptr;
12use std::time::SystemTime;
13use tokio::io::unix::AsyncFd;
14use tokio_util::sync::CancellationToken;
15use tracing::{debug, info, warn};
16
17pub use crate::types::INVALID_XLOG_REC_PTR;
18
19/// Safe wrapper around PostgreSQL connection for replication
20///
21/// This struct provides a safe, high-level interface to libpq for PostgreSQL
22/// logical replication. It handles connection management, replication slot
23/// creation, and COPY protocol communication.
24///
25/// # Safety
26///
27/// This struct safely wraps the unsafe libpq C API. All unsafe operations
28/// are properly encapsulated and validated.
29///
30/// # Example
31///
32/// ```no_run
33/// use pg_walstream::PgReplicationConnection;
34///
35/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
36/// let mut conn = PgReplicationConnection::connect(
37///     "postgresql://postgres:password@localhost/mydb?replication=database"
38/// )?;
39///
40/// // Identify the system
41/// conn.identify_system()?;
42///
43/// // Create a replication slot
44/// conn.create_replication_slot("my_slot", "pgoutput")?;
45///
46/// // Start replication
47/// conn.start_replication("my_slot", 0, &[("proto_version", "2")])?
48/// # ; Ok(())
49/// # }
50/// ```
51pub struct PgReplicationConnection {
52    conn: *mut PGconn,
53    is_replication_conn: bool,
54    async_fd: Option<AsyncFd<RawFd>>,
55}
56
57impl PgReplicationConnection {
58    /// Create a new PostgreSQL connection for logical replication
59    ///
60    /// Establishes a connection to PostgreSQL using the provided connection string.
61    /// The connection string must include the `replication=database` parameter to
62    /// enable logical replication.
63    ///
64    /// # Arguments
65    ///
66    /// * `conninfo` - PostgreSQL connection string. Must include `replication=database`.
67    ///   Example: `"postgresql://user:pass@host:5432/dbname?replication=database"`
68    ///
69    /// # Returns
70    ///
71    /// Returns a new `PgReplicationConnection` if successful.
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if:
76    /// - Connection string is invalid
77    /// - Cannot connect to PostgreSQL server (transient or permanent)
78    /// - Authentication fails
79    /// - PostgreSQL version is too old (< 14.0)
80    ///
81    /// # Example
82    ///
83    /// ```no_run
84    /// use pg_walstream::PgReplicationConnection;
85    ///
86    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
87    /// let conn = PgReplicationConnection::connect(
88    ///     "postgresql://postgres:password@localhost:5432/mydb?replication=database"
89    /// )?;
90    /// # Ok(())
91    /// # }
92    /// ```
93    pub fn connect(conninfo: &str) -> Result<Self> {
94        // Ensure libpq is properly initialized
95        unsafe {
96            let library_version = PQlibVersion();
97            debug!("Using libpq version: {}", library_version);
98        }
99
100        let c_conninfo = CString::new(conninfo)
101            .map_err(|e| ReplicationError::connection(format!("Invalid connection string: {e}")))?;
102
103        let conn = unsafe { PQconnectdb(c_conninfo.as_ptr()) };
104
105        if conn.is_null() {
106            return Err(ReplicationError::transient_connection(
107                "Failed to allocate PostgreSQL connection object".to_string(),
108            ));
109        }
110
111        let status = unsafe { PQstatus(conn) };
112        if status != ConnStatusType::CONNECTION_OK {
113            let error_msg = unsafe {
114                let error_ptr = PQerrorMessage(conn);
115                if error_ptr.is_null() {
116                    "Unknown connection error".to_string()
117                } else {
118                    CStr::from_ptr(error_ptr).to_string_lossy().into_owned()
119                }
120            };
121            unsafe { PQfinish(conn) };
122
123            // Categorize the connection error
124            let error_msg_lower = error_msg.to_lowercase();
125            if error_msg_lower.contains("authentication failed")
126                || error_msg_lower.contains("password authentication failed")
127                || error_msg_lower.contains("role does not exist")
128            {
129                return Err(ReplicationError::authentication(format!(
130                    "PostgreSQL authentication failed: {error_msg}"
131                )));
132            } else if error_msg_lower.contains("database does not exist")
133                || error_msg_lower.contains("invalid connection string")
134                || error_msg_lower.contains("unsupported")
135            {
136                return Err(ReplicationError::permanent_connection(format!(
137                    "PostgreSQL connection failed (permanent): {error_msg}"
138                )));
139            } else {
140                return Err(ReplicationError::transient_connection(format!(
141                    "PostgreSQL connection failed (transient): {error_msg}"
142                )));
143            }
144        }
145
146        // Check server version - logical replication requires PostgreSQL 14+
147        let server_version = unsafe { PQserverVersion(conn) };
148        if server_version < 140000 {
149            unsafe { PQfinish(conn) };
150            return Err(ReplicationError::permanent_connection(format!(
151                "PostgreSQL version {server_version} is not supported. Logical replication requires PostgreSQL 14+"
152            )));
153        }
154
155        debug!("Connected to PostgreSQL server version: {}", server_version);
156
157        Ok(Self {
158            conn,
159            is_replication_conn: false,
160            async_fd: None,
161        })
162    }
163
164    /// Execute a replication command (like IDENTIFY_SYSTEM)
165    pub fn exec(&self, query: &str) -> Result<PgResult> {
166        let c_query = CString::new(query)
167            .map_err(|e| ReplicationError::protocol(format!("Invalid query string: {e}")))?;
168
169        let result = unsafe { PQexec(self.conn, c_query.as_ptr()) };
170
171        if result.is_null() {
172            return Err(ReplicationError::protocol(
173                "Query execution failed - null result".to_string(),
174            ));
175        }
176
177        let pg_result = PgResult::new(result);
178        // Check for errors
179        let status = pg_result.status();
180        info!(
181            "query : {} pg_result.status() : {:?}",
182            query,
183            pg_result.status()
184        );
185        if !matches!(
186            status,
187            ExecStatusType::PGRES_TUPLES_OK
188                | ExecStatusType::PGRES_COMMAND_OK
189                | ExecStatusType::PGRES_COPY_BOTH
190        ) {
191            let error_msg = pg_result
192                .error_message()
193                .unwrap_or_else(|| "Unknown error".to_string());
194            return Err(ReplicationError::protocol(format!(
195                "Query execution failed: {error_msg}"
196            )));
197        }
198
199        Ok(pg_result)
200    }
201
202    /// Send IDENTIFY_SYSTEM command
203    pub fn identify_system(&self) -> Result<PgResult> {
204        debug!("Sending IDENTIFY_SYSTEM command");
205        let result = self.exec("IDENTIFY_SYSTEM")?;
206
207        if result.ntuples() > 0 {
208            if let (Some(systemid), Some(timeline), Some(xlogpos)) = (
209                result.get_value(0, 0),
210                result.get_value(0, 1),
211                result.get_value(0, 2),
212            ) {
213                debug!(
214                    "System identification: systemid={}, timeline={}, xlogpos={}",
215                    systemid, timeline, xlogpos
216                );
217            }
218        }
219
220        Ok(result)
221    }
222
223    /// Create a replication slot
224    pub fn create_replication_slot(
225        &self,
226        slot_name: &str,
227        output_plugin: &str,
228    ) -> Result<PgResult> {
229        let create_slot_sql = format!(
230            "CREATE_REPLICATION_SLOT \"{slot_name}\" LOGICAL {output_plugin} NOEXPORT_SNAPSHOT;"
231        );
232
233        let result = self.exec(&create_slot_sql)?;
234
235        if result.ntuples() > 0 {
236            if let Some(slot_name_result) = result.get_value(0, 0) {
237                debug!("Replication slot created: {}", slot_name_result);
238            }
239        }
240
241        Ok(result)
242    }
243
244    /// Start logical replication
245    pub fn start_replication(
246        &mut self,
247        slot_name: &str,
248        start_lsn: XLogRecPtr,
249        options: &[(&str, &str)],
250    ) -> Result<()> {
251        let mut options_str = String::new();
252        for (i, (key, value)) in options.iter().enumerate() {
253            if i > 0 {
254                options_str.push_str(", ");
255            }
256            options_str.push_str(&format!("\"{key}\" '{value}'"));
257        }
258
259        let start_replication_sql = if start_lsn == INVALID_XLOG_REC_PTR {
260            format!("START_REPLICATION SLOT \"{slot_name}\" LOGICAL 0/0 ({options_str})")
261        } else {
262            format!(
263                "START_REPLICATION SLOT \"{}\" LOGICAL {} ({})",
264                slot_name,
265                format_lsn(start_lsn),
266                options_str
267            )
268        };
269
270        debug!("Starting replication: {}", start_replication_sql);
271        let _result = self.exec(&start_replication_sql)?;
272
273        self.is_replication_conn = true;
274
275        // Initialize async socket for non-blocking operations
276        self.initialize_async_socket()?;
277
278        debug!("Replication started successfully");
279        Ok(())
280    }
281
282    /// Send feedback to the server (standby status update)
283    pub fn send_standby_status_update(
284        &self,
285        received_lsn: XLogRecPtr,
286        flushed_lsn: XLogRecPtr,
287        applied_lsn: XLogRecPtr,
288        reply_requested: bool,
289    ) -> Result<()> {
290        if !self.is_replication_conn {
291            return Err(ReplicationError::protocol(
292                "Connection is not in replication mode".to_string(),
293            ));
294        }
295
296        let timestamp = system_time_to_postgres_timestamp(SystemTime::now());
297
298        // Build the standby status update message using BufferWriter
299        let mut buffer = BufferWriter::with_capacity(34); // 1 + 8 + 8 + 8 + 8 + 1
300
301        buffer.write_u8(b'r')?; // Message type
302        buffer.write_u64(received_lsn)?;
303        buffer.write_u64(flushed_lsn)?;
304        buffer.write_u64(applied_lsn)?;
305        buffer.write_i64(timestamp)?;
306        buffer.write_u8(if reply_requested { 1 } else { 0 })?;
307
308        let reply_data = buffer.freeze();
309
310        let result = unsafe {
311            PQputCopyData(
312                self.conn,
313                reply_data.as_ptr() as *const std::os::raw::c_char,
314                reply_data.len() as i32,
315            )
316        };
317
318        if result != 1 {
319            let error_msg = self.last_error_message();
320            return Err(ReplicationError::protocol(format!(
321                "Failed to send standby status update: {error_msg}"
322            )));
323        }
324
325        // Flush the connection
326        let flush_result = unsafe { PQflush(self.conn) };
327        if flush_result != 0 {
328            let error_msg = self.last_error_message();
329            return Err(ReplicationError::protocol(format!(
330                "Failed to flush connection: {error_msg}"
331            )));
332        }
333
334        info!(
335            "Sent standby status update: received={}, flushed={}, applied={}, reply_requested={}",
336            format_lsn(received_lsn),
337            format_lsn(flushed_lsn),
338            format_lsn(applied_lsn),
339            reply_requested
340        );
341
342        Ok(())
343    }
344
345    /// Initialize async socket for non-blocking operations
346    fn initialize_async_socket(&mut self) -> Result<()> {
347        let sock: RawFd = unsafe { PQsocket(self.conn) };
348        if sock < 0 {
349            return Err(ReplicationError::protocol(
350                "Invalid PostgreSQL socket".to_string(),
351            ));
352        }
353
354        let async_fd = AsyncFd::new(sock)
355            .map_err(|e| ReplicationError::protocol(format!("Failed to create AsyncFd: {e}")))?;
356
357        self.async_fd = Some(async_fd);
358        Ok(())
359    }
360
361    /// Get copy data from replication stream (async non-blocking version)
362    ///
363    /// # Arguments
364    /// * `cancellation_token` - Optional cancellation token to abort the operation
365    ///
366    /// # Returns
367    /// * `Ok(Some(data))` - Successfully received data
368    /// * `Ok(None)` - No data available currently
369    /// * `Err(ReplicationError::Cancelled(_))` - Operation was cancelled
370    /// * `Err(_)` - Other errors occurred
371    pub async fn get_copy_data_async(
372        &mut self,
373        cancellation_token: &CancellationToken,
374    ) -> Result<Option<Vec<u8>>> {
375        if !self.is_replication_conn {
376            return Err(ReplicationError::protocol(
377                "Connection is not in replication mode".to_string(),
378            ));
379        }
380
381        let async_fd = self
382            .async_fd
383            .as_ref()
384            .ok_or_else(|| ReplicationError::protocol("AsyncFd not initialized".to_string()))?;
385
386        // First, try to read any buffered data without blocking
387        if let Some(data) = self.try_read_buffered_data()? {
388            return Ok(Some(data));
389        }
390
391        // If no buffered data, wait for either socket readability or cancellation
392        tokio::select! {
393            biased;
394
395            _ = cancellation_token.cancelled() => {
396                info!("Cancellation detected in get_copy_data_async");
397                if let Some(data) = self.try_read_buffered_data()? {
398                    info!("Found buffered data after cancellation, returning it");
399                    return Ok(Some(data));
400                }
401                Ok(None)
402            }
403
404            // Wait for socket to become readable
405            guard_result = async_fd.readable() => {
406                let mut guard = guard_result.map_err(|e| {
407                    ReplicationError::protocol(format!("Failed to wait for socket readability: {e}"))
408                })?;
409
410                // Socket is readable - consume input from the OS socket. This is the ONLY place we call PQconsumeInput, avoiding busy-loops
411                let consumed = unsafe { PQconsumeInput(self.conn) };
412                if consumed == 0 {
413                    return Err(ReplicationError::protocol(self.last_error_message()));
414                }
415
416                // Check if we have complete data now
417                if let Some(data) = self.try_read_buffered_data()? {
418                    return Ok(Some(data));
419                }
420
421                guard.clear_ready();
422                Ok(None)
423            }
424        }
425    }
426
427    /// Try to read copy data from libpq's internal buffer without consuming OS socket
428    /// This should only be called after PQconsumeInput has been called
429    fn try_read_buffered_data(&self) -> Result<Option<Vec<u8>>> {
430        // Check if data is ready without blocking. PQisBusy returns 0 if a complete message is available
431        if unsafe { PQisBusy(self.conn) } != 0 {
432            return Ok(None); // Buffer not complete, wait for next socket readable event
433        }
434
435        let mut buffer: *mut std::os::raw::c_char = ptr::null_mut();
436        let result = unsafe { PQgetCopyData(self.conn, &mut buffer, 1) };
437
438        match result {
439            len if len > 0 => {
440                if buffer.is_null() {
441                    return Err(ReplicationError::buffer(
442                        "Received null buffer from PQgetCopyData".to_string(),
443                    ));
444                }
445
446                let data = unsafe {
447                    std::slice::from_raw_parts(buffer as *const u8, len as usize).to_vec()
448                };
449
450                unsafe { PQfreemem(buffer as *mut std::os::raw::c_void) };
451                Ok(Some(data))
452            }
453            0 | -2 => Ok(None), // No complete data available, continue waiting
454            -1 => {
455                // COPY finished or channel closed
456                debug!("COPY finished or channel closed");
457                Ok(None)
458            }
459            _ => Err(ReplicationError::protocol(format!(
460                "Unexpected PQgetCopyData result: {result}"
461            ))),
462        }
463    }
464
465    /// Get the last error message from the connection
466    fn last_error_message(&self) -> String {
467        unsafe {
468            let error_ptr = PQerrorMessage(self.conn);
469            if error_ptr.is_null() {
470                "Unknown error".to_string()
471            } else {
472                CStr::from_ptr(error_ptr).to_string_lossy().into_owned()
473            }
474        }
475    }
476
477    /// Check if the connection is still alive
478    pub fn is_alive(&self) -> bool {
479        if self.conn.is_null() {
480            return false;
481        }
482
483        unsafe { PQstatus(self.conn) == ConnStatusType::CONNECTION_OK }
484    }
485
486    /// Get the server version
487    pub fn server_version(&self) -> i32 {
488        unsafe { PQserverVersion(self.conn) }
489    }
490
491    fn close_replication_connection(&mut self) {
492        if !self.conn.is_null() {
493            info!("Closing PostgreSQL replication connection");
494
495            // If we're in replication mode, try to end the copy gracefully
496            if self.is_replication_conn {
497                debug!("Ending COPY mode before closing connection");
498                unsafe {
499                    // Try to end the copy operation gracefully, This is important to properly close the replication stream
500                    let result = PQputCopyEnd(self.conn, ptr::null());
501                    if result != 1 {
502                        warn!(
503                            "Failed to end COPY mode gracefully: {}",
504                            self.last_error_message()
505                        );
506                    } else {
507                        debug!("COPY mode ended gracefully");
508                    }
509                }
510                self.is_replication_conn = false;
511            }
512
513            // Close the connection
514            unsafe {
515                PQfinish(self.conn);
516            }
517
518            // Clear the connection pointer and reset state
519            self.conn = std::ptr::null_mut();
520            self.async_fd = None;
521
522            info!("PostgreSQL replication connection closed and cleaned up");
523        } else {
524            info!("Connection already closed or was never initialized");
525        }
526    }
527}
528
529impl Drop for PgReplicationConnection {
530    fn drop(&mut self) {
531        self.close_replication_connection();
532    }
533}
534
// SAFETY: `PgReplicationConnection` exclusively owns its `PGconn` handle. The handle is
// created in `connect` and finished only in `close_replication_connection` (via `Drop`).
// No `Sync` impl is provided (the raw-pointer field keeps it `!Sync`), so the handle is
// only ever used from one thread at a time. The libpq documentation permits moving a
// connection between threads as long as no two threads use it concurrently.
// NOTE(review): this assumes the linked libpq was built thread-safe; confirm.
unsafe impl Send for PgReplicationConnection {}
537
/// Owning wrapper around a libpq `PGresult`.
///
/// The wrapped handle is released with `PQclear` when this value is dropped.
pub struct PgResult {
    /// Raw result handle returned by libpq (for example from `PQexec`).
    result: *mut PGresult,
}
542
543impl PgResult {
544    fn new(result: *mut PGresult) -> Self {
545        Self { result }
546    }
547
548    /// Get the execution status
549    pub fn status(&self) -> ExecStatusType {
550        unsafe { PQresultStatus(self.result) }
551    }
552
553    /// Check if the result is OK
554    pub fn is_ok(&self) -> bool {
555        matches!(
556            self.status(),
557            ExecStatusType::PGRES_TUPLES_OK | ExecStatusType::PGRES_COMMAND_OK
558        )
559    }
560
561    /// Get number of tuples (rows)
562    pub fn ntuples(&self) -> i32 {
563        unsafe { PQntuples(self.result) }
564    }
565
566    /// Get number of fields (columns)
567    pub fn nfields(&self) -> i32 {
568        unsafe { PQnfields(self.result) }
569    }
570
571    /// Get a field value as string
572    pub fn get_value(&self, row: i32, col: i32) -> Option<String> {
573        if row >= self.ntuples() || col >= self.nfields() {
574            return None;
575        }
576
577        let value_ptr = unsafe { PQgetvalue(self.result, row, col) };
578        if value_ptr.is_null() {
579            None
580        } else {
581            unsafe { Some(CStr::from_ptr(value_ptr).to_string_lossy().into_owned()) }
582        }
583    }
584
585    /// Get error message if any
586    pub fn error_message(&self) -> Option<String> {
587        let error_ptr = unsafe { PQresultErrorMessage(self.result) };
588        if error_ptr.is_null() {
589            None
590        } else {
591            unsafe { Some(CStr::from_ptr(error_ptr).to_string_lossy().into_owned()) }
592        }
593    }
594}
595
596impl Drop for PgResult {
597    fn drop(&mut self) {
598        if !self.result.is_null() {
599            unsafe {
600                PQclear(self.result);
601            }
602        }
603    }
604}