1use crate::buffer::BufferWriter;
6use crate::error::{ReplicationError, Result};
7use crate::types::{format_lsn, system_time_to_postgres_timestamp, XLogRecPtr};
8use libpq_sys::*;
9use std::ffi::{CStr, CString};
10use std::os::unix::io::RawFd;
11use std::ptr;
12use std::time::SystemTime;
13use tokio::io::unix::AsyncFd;
14use tokio_util::sync::CancellationToken;
15use tracing::{debug, info, warn};
16
17pub use crate::types::INVALID_XLOG_REC_PTR;
18
/// A PostgreSQL connection driven through raw libpq calls.
///
/// The handle is created by `connect` and released with `PQfinish` when the
/// value is dropped.
pub struct PgReplicationConnection {
    /// Raw libpq connection handle; reset to null once the connection is closed.
    conn: *mut PGconn,
    /// Set to `true` after `start_replication` has issued `START_REPLICATION`
    /// successfully; cleared again while closing.
    is_replication_conn: bool,
    /// Tokio readiness wrapper around the libpq socket descriptor; `None` until
    /// `start_replication` registers the socket.
    async_fd: Option<AsyncFd<RawFd>>,
}
56
57impl PgReplicationConnection {
58 pub fn connect(conninfo: &str) -> Result<Self> {
94 unsafe {
96 let library_version = PQlibVersion();
97 debug!("Using libpq version: {}", library_version);
98 }
99
100 let c_conninfo = CString::new(conninfo)
101 .map_err(|e| ReplicationError::connection(format!("Invalid connection string: {e}")))?;
102
103 let conn = unsafe { PQconnectdb(c_conninfo.as_ptr()) };
104
105 if conn.is_null() {
106 return Err(ReplicationError::transient_connection(
107 "Failed to allocate PostgreSQL connection object".to_string(),
108 ));
109 }
110
111 let status = unsafe { PQstatus(conn) };
112 if status != ConnStatusType::CONNECTION_OK {
113 let error_msg = unsafe {
114 let error_ptr = PQerrorMessage(conn);
115 if error_ptr.is_null() {
116 "Unknown connection error".to_string()
117 } else {
118 CStr::from_ptr(error_ptr).to_string_lossy().into_owned()
119 }
120 };
121 unsafe { PQfinish(conn) };
122
123 let error_msg_lower = error_msg.to_lowercase();
125 if error_msg_lower.contains("authentication failed")
126 || error_msg_lower.contains("password authentication failed")
127 || error_msg_lower.contains("role does not exist")
128 {
129 return Err(ReplicationError::authentication(format!(
130 "PostgreSQL authentication failed: {error_msg}"
131 )));
132 } else if error_msg_lower.contains("database does not exist")
133 || error_msg_lower.contains("invalid connection string")
134 || error_msg_lower.contains("unsupported")
135 {
136 return Err(ReplicationError::permanent_connection(format!(
137 "PostgreSQL connection failed (permanent): {error_msg}"
138 )));
139 } else {
140 return Err(ReplicationError::transient_connection(format!(
141 "PostgreSQL connection failed (transient): {error_msg}"
142 )));
143 }
144 }
145
146 let server_version = unsafe { PQserverVersion(conn) };
148 if server_version < 140000 {
149 unsafe { PQfinish(conn) };
150 return Err(ReplicationError::permanent_connection(format!(
151 "PostgreSQL version {server_version} is not supported. Logical replication requires PostgreSQL 14+"
152 )));
153 }
154
155 debug!("Connected to PostgreSQL server version: {}", server_version);
156
157 Ok(Self {
158 conn,
159 is_replication_conn: false,
160 async_fd: None,
161 })
162 }
163
164 pub fn exec(&self, query: &str) -> Result<PgResult> {
166 let c_query = CString::new(query)
167 .map_err(|e| ReplicationError::protocol(format!("Invalid query string: {e}")))?;
168
169 let result = unsafe { PQexec(self.conn, c_query.as_ptr()) };
170
171 if result.is_null() {
172 return Err(ReplicationError::protocol(
173 "Query execution failed - null result".to_string(),
174 ));
175 }
176
177 let pg_result = PgResult::new(result);
178 let status = pg_result.status();
180 info!(
181 "query : {} pg_result.status() : {:?}",
182 query,
183 pg_result.status()
184 );
185 if !matches!(
186 status,
187 ExecStatusType::PGRES_TUPLES_OK
188 | ExecStatusType::PGRES_COMMAND_OK
189 | ExecStatusType::PGRES_COPY_BOTH
190 ) {
191 let error_msg = pg_result
192 .error_message()
193 .unwrap_or_else(|| "Unknown error".to_string());
194 return Err(ReplicationError::protocol(format!(
195 "Query execution failed: {error_msg}"
196 )));
197 }
198
199 Ok(pg_result)
200 }
201
202 pub fn identify_system(&self) -> Result<PgResult> {
204 debug!("Sending IDENTIFY_SYSTEM command");
205 let result = self.exec("IDENTIFY_SYSTEM")?;
206
207 if result.ntuples() > 0 {
208 if let (Some(systemid), Some(timeline), Some(xlogpos)) = (
209 result.get_value(0, 0),
210 result.get_value(0, 1),
211 result.get_value(0, 2),
212 ) {
213 debug!(
214 "System identification: systemid={}, timeline={}, xlogpos={}",
215 systemid, timeline, xlogpos
216 );
217 }
218 }
219
220 Ok(result)
221 }
222
223 pub fn create_replication_slot(
225 &self,
226 slot_name: &str,
227 output_plugin: &str,
228 ) -> Result<PgResult> {
229 let create_slot_sql = format!(
230 "CREATE_REPLICATION_SLOT \"{slot_name}\" LOGICAL {output_plugin} NOEXPORT_SNAPSHOT;"
231 );
232
233 let result = self.exec(&create_slot_sql)?;
234
235 if result.ntuples() > 0 {
236 if let Some(slot_name_result) = result.get_value(0, 0) {
237 debug!("Replication slot created: {}", slot_name_result);
238 }
239 }
240
241 Ok(result)
242 }
243
244 pub fn start_replication(
246 &mut self,
247 slot_name: &str,
248 start_lsn: XLogRecPtr,
249 options: &[(&str, &str)],
250 ) -> Result<()> {
251 let mut options_str = String::new();
252 for (i, (key, value)) in options.iter().enumerate() {
253 if i > 0 {
254 options_str.push_str(", ");
255 }
256 options_str.push_str(&format!("\"{key}\" '{value}'"));
257 }
258
259 let start_replication_sql = if start_lsn == INVALID_XLOG_REC_PTR {
260 format!("START_REPLICATION SLOT \"{slot_name}\" LOGICAL 0/0 ({options_str})")
261 } else {
262 format!(
263 "START_REPLICATION SLOT \"{}\" LOGICAL {} ({})",
264 slot_name,
265 format_lsn(start_lsn),
266 options_str
267 )
268 };
269
270 debug!("Starting replication: {}", start_replication_sql);
271 let _result = self.exec(&start_replication_sql)?;
272
273 self.is_replication_conn = true;
274
275 self.initialize_async_socket()?;
277
278 debug!("Replication started successfully");
279 Ok(())
280 }
281
282 pub fn send_standby_status_update(
284 &self,
285 received_lsn: XLogRecPtr,
286 flushed_lsn: XLogRecPtr,
287 applied_lsn: XLogRecPtr,
288 reply_requested: bool,
289 ) -> Result<()> {
290 if !self.is_replication_conn {
291 return Err(ReplicationError::protocol(
292 "Connection is not in replication mode".to_string(),
293 ));
294 }
295
296 let timestamp = system_time_to_postgres_timestamp(SystemTime::now());
297
298 let mut buffer = BufferWriter::with_capacity(34); buffer.write_u8(b'r')?; buffer.write_u64(received_lsn)?;
303 buffer.write_u64(flushed_lsn)?;
304 buffer.write_u64(applied_lsn)?;
305 buffer.write_i64(timestamp)?;
306 buffer.write_u8(if reply_requested { 1 } else { 0 })?;
307
308 let reply_data = buffer.freeze();
309
310 let result = unsafe {
311 PQputCopyData(
312 self.conn,
313 reply_data.as_ptr() as *const std::os::raw::c_char,
314 reply_data.len() as i32,
315 )
316 };
317
318 if result != 1 {
319 let error_msg = self.last_error_message();
320 return Err(ReplicationError::protocol(format!(
321 "Failed to send standby status update: {error_msg}"
322 )));
323 }
324
325 let flush_result = unsafe { PQflush(self.conn) };
327 if flush_result != 0 {
328 let error_msg = self.last_error_message();
329 return Err(ReplicationError::protocol(format!(
330 "Failed to flush connection: {error_msg}"
331 )));
332 }
333
334 info!(
335 "Sent standby status update: received={}, flushed={}, applied={}, reply_requested={}",
336 format_lsn(received_lsn),
337 format_lsn(flushed_lsn),
338 format_lsn(applied_lsn),
339 reply_requested
340 );
341
342 Ok(())
343 }
344
345 fn initialize_async_socket(&mut self) -> Result<()> {
347 let sock: RawFd = unsafe { PQsocket(self.conn) };
348 if sock < 0 {
349 return Err(ReplicationError::protocol(
350 "Invalid PostgreSQL socket".to_string(),
351 ));
352 }
353
354 let async_fd = AsyncFd::new(sock)
355 .map_err(|e| ReplicationError::protocol(format!("Failed to create AsyncFd: {e}")))?;
356
357 self.async_fd = Some(async_fd);
358 Ok(())
359 }
360
361 pub async fn get_copy_data_async(
372 &mut self,
373 cancellation_token: &CancellationToken,
374 ) -> Result<Option<Vec<u8>>> {
375 if !self.is_replication_conn {
376 return Err(ReplicationError::protocol(
377 "Connection is not in replication mode".to_string(),
378 ));
379 }
380
381 let async_fd = self
382 .async_fd
383 .as_ref()
384 .ok_or_else(|| ReplicationError::protocol("AsyncFd not initialized".to_string()))?;
385
386 if let Some(data) = self.try_read_buffered_data()? {
388 return Ok(Some(data));
389 }
390
391 tokio::select! {
393 biased;
394
395 _ = cancellation_token.cancelled() => {
396 info!("Cancellation detected in get_copy_data_async");
397 if let Some(data) = self.try_read_buffered_data()? {
398 info!("Found buffered data after cancellation, returning it");
399 return Ok(Some(data));
400 }
401 Ok(None)
402 }
403
404 guard_result = async_fd.readable() => {
406 let mut guard = guard_result.map_err(|e| {
407 ReplicationError::protocol(format!("Failed to wait for socket readability: {e}"))
408 })?;
409
410 let consumed = unsafe { PQconsumeInput(self.conn) };
412 if consumed == 0 {
413 return Err(ReplicationError::protocol(self.last_error_message()));
414 }
415
416 if let Some(data) = self.try_read_buffered_data()? {
418 return Ok(Some(data));
419 }
420
421 guard.clear_ready();
422 Ok(None)
423 }
424 }
425 }
426
427 fn try_read_buffered_data(&self) -> Result<Option<Vec<u8>>> {
430 if unsafe { PQisBusy(self.conn) } != 0 {
432 return Ok(None); }
434
435 let mut buffer: *mut std::os::raw::c_char = ptr::null_mut();
436 let result = unsafe { PQgetCopyData(self.conn, &mut buffer, 1) };
437
438 match result {
439 len if len > 0 => {
440 if buffer.is_null() {
441 return Err(ReplicationError::buffer(
442 "Received null buffer from PQgetCopyData".to_string(),
443 ));
444 }
445
446 let data = unsafe {
447 std::slice::from_raw_parts(buffer as *const u8, len as usize).to_vec()
448 };
449
450 unsafe { PQfreemem(buffer as *mut std::os::raw::c_void) };
451 Ok(Some(data))
452 }
453 0 | -2 => Ok(None), -1 => {
455 debug!("COPY finished or channel closed");
457 Ok(None)
458 }
459 _ => Err(ReplicationError::protocol(format!(
460 "Unexpected PQgetCopyData result: {result}"
461 ))),
462 }
463 }
464
465 fn last_error_message(&self) -> String {
467 unsafe {
468 let error_ptr = PQerrorMessage(self.conn);
469 if error_ptr.is_null() {
470 "Unknown error".to_string()
471 } else {
472 CStr::from_ptr(error_ptr).to_string_lossy().into_owned()
473 }
474 }
475 }
476
477 pub fn is_alive(&self) -> bool {
479 if self.conn.is_null() {
480 return false;
481 }
482
483 unsafe { PQstatus(self.conn) == ConnStatusType::CONNECTION_OK }
484 }
485
486 pub fn server_version(&self) -> i32 {
488 unsafe { PQserverVersion(self.conn) }
489 }
490
491 fn close_replication_connection(&mut self) {
492 if !self.conn.is_null() {
493 info!("Closing PostgreSQL replication connection");
494
495 if self.is_replication_conn {
497 debug!("Ending COPY mode before closing connection");
498 unsafe {
499 let result = PQputCopyEnd(self.conn, ptr::null());
501 if result != 1 {
502 warn!(
503 "Failed to end COPY mode gracefully: {}",
504 self.last_error_message()
505 );
506 } else {
507 debug!("COPY mode ended gracefully");
508 }
509 }
510 self.is_replication_conn = false;
511 }
512
513 unsafe {
515 PQfinish(self.conn);
516 }
517
518 self.conn = std::ptr::null_mut();
520 self.async_fd = None;
521
522 info!("PostgreSQL replication connection closed and cleaned up");
523 } else {
524 info!("Connection already closed or was never initialized");
525 }
526 }
527}
528
529impl Drop for PgReplicationConnection {
530 fn drop(&mut self) {
531 self.close_replication_connection();
532 }
533}
534
// The struct holds a raw `*mut PGconn`, so `Send` has to be asserted by hand.
// SAFETY: NOTE(review): presumably sound because the handle is owned
// exclusively by this value and `Sync` is not implemented, so only one thread
// can use the connection at a time - confirm against libpq's threading rules.
unsafe impl Send for PgReplicationConnection {}
537
/// Owning wrapper around a libpq `PGresult`; the result is cleared with
/// `PQclear` when the wrapper is dropped.
pub struct PgResult {
    /// Raw libpq result pointer; only cleared on drop when non-null.
    result: *mut PGresult,
}
542
543impl PgResult {
544 fn new(result: *mut PGresult) -> Self {
545 Self { result }
546 }
547
548 pub fn status(&self) -> ExecStatusType {
550 unsafe { PQresultStatus(self.result) }
551 }
552
553 pub fn is_ok(&self) -> bool {
555 matches!(
556 self.status(),
557 ExecStatusType::PGRES_TUPLES_OK | ExecStatusType::PGRES_COMMAND_OK
558 )
559 }
560
561 pub fn ntuples(&self) -> i32 {
563 unsafe { PQntuples(self.result) }
564 }
565
566 pub fn nfields(&self) -> i32 {
568 unsafe { PQnfields(self.result) }
569 }
570
571 pub fn get_value(&self, row: i32, col: i32) -> Option<String> {
573 if row >= self.ntuples() || col >= self.nfields() {
574 return None;
575 }
576
577 let value_ptr = unsafe { PQgetvalue(self.result, row, col) };
578 if value_ptr.is_null() {
579 None
580 } else {
581 unsafe { Some(CStr::from_ptr(value_ptr).to_string_lossy().into_owned()) }
582 }
583 }
584
585 pub fn error_message(&self) -> Option<String> {
587 let error_ptr = unsafe { PQresultErrorMessage(self.result) };
588 if error_ptr.is_null() {
589 None
590 } else {
591 unsafe { Some(CStr::from_ptr(error_ptr).to_string_lossy().into_owned()) }
592 }
593 }
594}
595
596impl Drop for PgResult {
597 fn drop(&mut self) {
598 if !self.result.is_null() {
599 unsafe {
600 PQclear(self.result);
601 }
602 }
603 }
604}