1use anyhow::{anyhow, Result};
16use bytes::{Buf, BytesMut};
17use log::{debug, info, trace, warn};
18use std::collections::HashMap;
19use tokio::io::{AsyncReadExt, AsyncWriteExt};
20use tokio::net::TcpStream;
21
22use super::protocol::{
23 parse_backend_message, AuthenticationMessage, BackendMessage, FrontendMessage, StartupMessage,
24 TransactionStatus,
25};
26use super::scram::ScramClient;
27use super::types::{ReplicationSlotInfo, StandbyStatusUpdate};
28
29pub struct ReplicationConnection {
30 stream: TcpStream,
31 read_buffer: BytesMut,
32 write_buffer: BytesMut,
33 parameters: HashMap<String, String>,
34 process_id: Option<i32>,
35 secret_key: Option<i32>,
36 transaction_status: TransactionStatus,
37 in_copy_mode: bool,
38}
39
40impl ReplicationConnection {
41 pub async fn connect(
42 host: &str,
43 port: u16,
44 database: &str,
45 user: &str,
46 password: &str,
47 ) -> Result<Self> {
48 info!("Connecting to PostgreSQL at {host}:{port}");
49
50 let stream = TcpStream::connect((host, port)).await?;
51 stream.set_nodelay(true)?;
52
53 let mut conn = Self {
54 stream,
55 read_buffer: BytesMut::with_capacity(8192),
56 write_buffer: BytesMut::with_capacity(8192),
57 parameters: HashMap::new(),
58 process_id: None,
59 secret_key: None,
60 transaction_status: TransactionStatus::Idle,
61 in_copy_mode: false,
62 };
63
64 conn.startup_replication(database, user, password).await?;
65
66 Ok(conn)
67 }
68
69 async fn startup_replication(
70 &mut self,
71 database: &str,
72 user: &str,
73 password: &str,
74 ) -> Result<()> {
75 debug!("Starting replication protocol handshake");
76
77 let startup = StartupMessage::new_replication(database, user);
79 self.send_message(FrontendMessage::StartupMessage(startup))
80 .await?;
81
82 loop {
84 let msg = self.read_message().await?;
85 match msg {
86 BackendMessage::Authentication(auth) => {
87 match auth {
88 AuthenticationMessage::Ok => {
89 debug!("Authentication successful");
90 break;
91 }
92 AuthenticationMessage::CleartextPassword => {
93 debug!("Server requested cleartext password");
94 self.send_message(FrontendMessage::PasswordMessage(
95 password.to_string(),
96 ))
97 .await?;
98 }
99 AuthenticationMessage::MD5Password(_) => {
100 return Err(anyhow!(
101 "MD5 authentication is not supported (insecure). \
102 Please configure PostgreSQL to use scram-sha-256 in pg_hba.conf"
103 ));
104 }
105 AuthenticationMessage::SASL(mechanisms) => {
106 if mechanisms.contains(&"SCRAM-SHA-256".to_string()) {
107 debug!("Server requested SCRAM-SHA-256 authentication");
108 let mut scram_client = ScramClient::new(user, password);
109
110 let client_first = scram_client.client_first_message();
112 self.send_sasl_initial_response("SCRAM-SHA-256", &client_first)
113 .await?;
114
115 loop {
117 let sasl_msg = self.read_message().await?;
118 match sasl_msg {
119 BackendMessage::Authentication(
120 AuthenticationMessage::SASLContinue(data),
121 ) => {
122 let server_first = String::from_utf8_lossy(&data);
123 scram_client
124 .process_server_first_message(&server_first)?;
125
126 let client_final =
127 scram_client.client_final_message()?;
128 self.send_sasl_response(&client_final).await?;
129 }
130 BackendMessage::Authentication(
131 AuthenticationMessage::SASLFinal(data),
132 ) => {
133 let server_final = String::from_utf8_lossy(&data);
134 scram_client.verify_server_final(&server_final)?;
135 debug!("SCRAM-SHA-256 authentication successful");
136 break;
137 }
138 BackendMessage::ErrorResponse(err) => {
139 return Err(anyhow!(
140 "SASL authentication failed: {}",
141 err.message
142 ));
143 }
144 _ => {
145 warn!("Unexpected message during SASL: {sasl_msg:?}");
146 }
147 }
148 }
149 } else {
150 return Err(anyhow!("No supported SASL mechanisms"));
151 }
152 }
153 _ => {
154 return Err(anyhow!("Unsupported authentication method"));
155 }
156 }
157 }
158 BackendMessage::ErrorResponse(err) => {
159 return Err(anyhow!("Authentication failed: {}", err.message));
160 }
161 _ => {
162 warn!("Unexpected message during authentication: {msg:?}");
163 }
164 }
165 }
166
167 loop {
169 let msg = self.read_message().await?;
170 match msg {
171 BackendMessage::BackendKeyData {
172 process_id,
173 secret_key,
174 } => {
175 self.process_id = Some(process_id);
176 self.secret_key = Some(secret_key);
177 debug!("Received backend key data: pid={process_id}");
178 }
179 BackendMessage::ParameterStatus { name, value } => {
180 debug!("Parameter: {name} = {value}");
181 self.parameters.insert(name, value);
182 }
183 BackendMessage::ReadyForQuery(status) => {
184 self.transaction_status = status;
185 debug!("Connection ready, status: {status:?}");
186 break;
187 }
188 BackendMessage::ErrorResponse(err) => {
189 return Err(anyhow!("Startup failed: {}", err.message));
190 }
191 BackendMessage::NoticeResponse(notice) => {
192 info!("Notice: {}", notice.message);
193 }
194 _ => {
195 warn!("Unexpected message during startup: {msg:?}");
196 }
197 }
198 }
199
200 Ok(())
201 }
202
203 pub async fn identify_system(&mut self) -> Result<HashMap<String, String>> {
204 debug!("Sending IDENTIFY_SYSTEM command");
205
206 self.send_message(FrontendMessage::Query("IDENTIFY_SYSTEM".to_string()))
207 .await?;
208
209 let mut system_info = HashMap::new();
210
211 loop {
212 let msg = self.read_message().await?;
213 match msg {
214 BackendMessage::RowDescription(_) => {
215 }
217 BackendMessage::DataRow(row) => {
218 if row.len() >= 4 {
220 if let Some(Some(systemid)) = row.first() {
221 system_info.insert(
222 "systemid".to_string(),
223 String::from_utf8_lossy(systemid).to_string(),
224 );
225 }
226 if let Some(Some(timeline)) = row.get(1) {
227 system_info.insert(
228 "timeline".to_string(),
229 String::from_utf8_lossy(timeline).to_string(),
230 );
231 }
232 if let Some(Some(xlogpos)) = row.get(2) {
233 system_info.insert(
234 "xlogpos".to_string(),
235 String::from_utf8_lossy(xlogpos).to_string(),
236 );
237 }
238 if let Some(Some(dbname)) = row.get(3) {
239 system_info.insert(
240 "dbname".to_string(),
241 String::from_utf8_lossy(dbname).to_string(),
242 );
243 }
244 }
245 }
246 BackendMessage::CommandComplete(_) => {
247 }
249 BackendMessage::ReadyForQuery(status) => {
250 self.transaction_status = status;
251 break;
252 }
253 BackendMessage::ErrorResponse(err) => {
254 return Err(anyhow!("IDENTIFY_SYSTEM failed: {}", err.message));
255 }
256 _ => {
257 warn!("Unexpected message during IDENTIFY_SYSTEM: {msg:?}");
258 }
259 }
260 }
261
262 Ok(system_info)
263 }
264
265 pub async fn create_replication_slot(
266 &mut self,
267 slot_name: &str,
268 temporary: bool,
269 ) -> Result<ReplicationSlotInfo> {
270 debug!("Creating replication slot: {slot_name}");
271
272 let query = if temporary {
273 format!("CREATE_REPLICATION_SLOT {slot_name} TEMPORARY LOGICAL pgoutput")
274 } else {
275 format!("CREATE_REPLICATION_SLOT {slot_name} LOGICAL pgoutput")
276 };
277
278 self.send_message(FrontendMessage::Query(query)).await?;
279
280 let mut slot_info = ReplicationSlotInfo {
281 slot_name: slot_name.to_string(),
282 consistent_point: String::new(),
283 snapshot_name: None,
284 output_plugin: "pgoutput".to_string(),
285 restart_lsn: None,
286 };
287
288 loop {
289 let msg = self.read_message().await?;
290 match msg {
291 BackendMessage::RowDescription(_) => {
292 }
294 BackendMessage::DataRow(row) => {
295 if row.len() >= 4 {
297 if let Some(Some(consistent_point)) = row.get(1) {
298 slot_info.consistent_point =
299 String::from_utf8_lossy(consistent_point).to_string();
300 }
301 if let Some(Some(snapshot_name)) = row.get(2) {
302 slot_info.snapshot_name =
303 Some(String::from_utf8_lossy(snapshot_name).to_string());
304 }
305 }
306 }
307 BackendMessage::CommandComplete(_) => {
308 }
310 BackendMessage::ReadyForQuery(status) => {
311 self.transaction_status = status;
312 break;
313 }
314 BackendMessage::ErrorResponse(err) => {
315 if err.message.contains("already exists") {
316 debug!("Replication slot already exists: {slot_name}");
317 loop {
319 let drain_msg = self.read_message().await?;
320 if let BackendMessage::ReadyForQuery(status) = drain_msg {
321 self.transaction_status = status;
322 break;
323 }
324 }
325 return self.get_replication_slot_info(slot_name).await;
326 }
327 return Err(anyhow!("CREATE_REPLICATION_SLOT failed: {}", err.message));
328 }
329 _ => {
330 warn!("Unexpected message during CREATE_REPLICATION_SLOT: {msg:?}");
331 }
332 }
333 }
334
335 Ok(slot_info)
336 }
337
338 pub async fn get_replication_slot_info(
339 &mut self,
340 slot_name: &str,
341 ) -> Result<ReplicationSlotInfo> {
342 debug!("Querying existing replication slot: {slot_name}");
343
344 let slot_name_escaped = slot_name.replace('\'', "''");
345 let query = format!(
346 "SELECT slot_name, confirmed_flush_lsn, restart_lsn, plugin FROM pg_replication_slots WHERE slot_name = '{slot_name_escaped}'"
347 );
348
349 self.send_message(FrontendMessage::Query(query)).await?;
350
351 let mut slot_info = ReplicationSlotInfo {
352 slot_name: slot_name.to_string(),
353 consistent_point: "0/0".to_string(),
354 snapshot_name: None,
355 output_plugin: "pgoutput".to_string(),
356 restart_lsn: None,
357 };
358 let mut found_row = false;
359
360 loop {
361 let msg = self.read_message().await?;
362 match msg {
363 BackendMessage::RowDescription(_) => {
364 }
366 BackendMessage::DataRow(row) => {
367 found_row = true;
368 if row.len() >= 4 {
369 if let Some(Some(confirmed_flush_lsn)) = row.get(1) {
370 let lsn = String::from_utf8_lossy(confirmed_flush_lsn).to_string();
371 if !lsn.is_empty() {
372 slot_info.consistent_point = lsn;
373 }
374 }
375 if let Some(Some(restart_lsn_val)) = row.get(2) {
376 let lsn = String::from_utf8_lossy(restart_lsn_val).to_string();
377 if !lsn.is_empty() {
378 slot_info.restart_lsn = Some(lsn.clone());
379 if slot_info.consistent_point == "0/0" {
381 slot_info.consistent_point = lsn;
382 }
383 }
384 }
385 if let Some(Some(plugin)) = row.get(3) {
386 slot_info.output_plugin = String::from_utf8_lossy(plugin).to_string();
387 }
388 }
389 }
390 BackendMessage::CommandComplete(_) => {
391 }
393 BackendMessage::ReadyForQuery(status) => {
394 self.transaction_status = status;
395 break;
396 }
397 BackendMessage::ErrorResponse(err) => {
398 return Err(anyhow!("Failed to query replication slot: {}", err.message));
399 }
400 _ => {
401 warn!("Unexpected message during slot query: {msg:?}");
402 }
403 }
404 }
405
406 if !found_row {
407 return Err(anyhow!("Replication slot not found: {slot_name}"));
408 }
409
410 info!(
411 "Using existing replication slot: {slot_name} at LSN {}",
412 slot_info.consistent_point
413 );
414 Ok(slot_info)
415 }
416
417 pub async fn start_replication(
418 &mut self,
419 slot_name: &str,
420 start_lsn: Option<u64>,
421 options: HashMap<String, String>,
422 ) -> Result<()> {
423 debug!("Starting replication from slot: {slot_name}");
424
425 let mut query = format!("START_REPLICATION SLOT {slot_name} LOGICAL");
426
427 if let Some(lsn) = start_lsn {
428 query.push_str(&format!(" {}", format_lsn(lsn)));
429 } else {
430 query.push_str(" 0/0");
431 }
432
433 if !options.is_empty() {
434 query.push_str(" (");
435 let opts: Vec<String> = options.iter().map(|(k, v)| format!("{k} '{v}'")).collect();
436 query.push_str(&opts.join(", "));
437 query.push(')');
438 }
439
440 self.send_message(FrontendMessage::Query(query)).await?;
441
442 loop {
444 let msg = self.read_message().await?;
445 match msg {
446 BackendMessage::CopyBothResponse => {
447 debug!("Entered COPY BOTH mode for replication");
448 self.in_copy_mode = true;
449 break;
450 }
451 BackendMessage::ErrorResponse(err) => {
452 return Err(anyhow!("START_REPLICATION failed: {}", err.message));
453 }
454 BackendMessage::ReadyForQuery(_) => {
455 debug!("Received ReadyForQuery before entering COPY mode");
457 }
458 _ => {
459 debug!("Message during START_REPLICATION: {msg:?}");
460 }
461 }
462 }
463
464 Ok(())
465 }
466
467 pub async fn read_replication_message(&mut self) -> Result<BackendMessage> {
468 if !self.in_copy_mode {
469 return Err(anyhow!("Not in COPY mode"));
470 }
471
472 self.read_message().await
473 }
474
475 pub async fn send_standby_status(&mut self, status: StandbyStatusUpdate) -> Result<()> {
476 if !self.in_copy_mode {
477 return Err(anyhow!("Not in COPY mode"));
478 }
479
480 let timestamp = chrono::Utc::now().timestamp_micros() - 946684800000000; self.send_message(FrontendMessage::StandbyStatusUpdate {
483 write_lsn: status.write_lsn,
484 flush_lsn: status.flush_lsn,
485 apply_lsn: status.apply_lsn,
486 timestamp,
487 reply: if status.reply_requested { 1 } else { 0 },
488 })
489 .await
490 }
491
492 async fn send_message(&mut self, msg: FrontendMessage) -> Result<()> {
493 self.write_buffer.clear();
494 msg.encode(&mut self.write_buffer)?;
495
496 self.stream.write_all(&self.write_buffer).await?;
497 self.stream.flush().await?;
498
499 trace!("Sent message: {msg:?}");
500 Ok(())
501 }
502
503 async fn send_sasl_initial_response(&mut self, mechanism: &str, response: &str) -> Result<()> {
504 self.send_message(FrontendMessage::SASLInitialResponse {
505 mechanism: mechanism.to_string(),
506 data: response.as_bytes().to_vec(),
507 })
508 .await
509 }
510
511 async fn send_sasl_response(&mut self, response: &str) -> Result<()> {
512 self.send_message(FrontendMessage::SASLResponse(response.as_bytes().to_vec()))
513 .await
514 }
515
516 async fn read_message(&mut self) -> Result<BackendMessage> {
517 loop {
518 if let Some(msg) = self.try_parse_message()? {
520 trace!("Received message: {msg:?}");
521 return Ok(msg);
522 }
523
524 let mut temp_buf = vec![0u8; 4096];
526 let n = self.stream.read(&mut temp_buf).await?;
527 if n == 0 {
528 return Err(anyhow!("Connection closed by server"));
529 }
530
531 self.read_buffer.extend_from_slice(&temp_buf[..n]);
532 }
533 }
534
535 fn try_parse_message(&mut self) -> Result<Option<BackendMessage>> {
536 if self.read_buffer.len() < 5 {
537 return Ok(None); }
539
540 let msg_type = self.read_buffer[0];
541 let length = u32::from_be_bytes([
542 self.read_buffer[1],
543 self.read_buffer[2],
544 self.read_buffer[3],
545 self.read_buffer[4],
546 ]) as usize;
547
548 if length < 4 {
549 return Err(anyhow!("Invalid message length: {length}"));
550 }
551
552 let total_length = 1 + length; if self.read_buffer.len() < total_length {
555 return Ok(None); }
557
558 let body = self.read_buffer[5..total_length].to_vec();
560 self.read_buffer.advance(total_length);
561
562 let msg = parse_backend_message(msg_type, &body)?;
564 Ok(Some(msg))
565 }
566
567 pub async fn close(mut self) -> Result<()> {
568 if self.in_copy_mode {
569 let _ = self.send_message(FrontendMessage::CopyDone).await;
570 }
571 let _ = self.send_message(FrontendMessage::Terminate).await;
572 let _ = self.stream.shutdown().await;
573 Ok(())
574 }
575}
576
/// Renders a 64-bit LSN as `HIGH/LOW`: the upper and lower 32 bits in
/// upper-case hexadecimal without zero padding.
pub(crate) fn format_lsn(lsn: u64) -> String {
    let high = (lsn >> 32) as u32;
    // Truncating cast keeps exactly the low 32 bits.
    let low = lsn as u32;
    format!("{high:X}/{low:X}")
}
581
582pub(crate) fn parse_lsn(lsn_str: &str) -> Result<u64> {
585 let parts: Vec<&str> = lsn_str.split('/').collect();
586 if parts.len() != 2 {
587 return Err(anyhow!("Invalid LSN format: {lsn_str}"));
588 }
589
590 let high = u64::from_str_radix(parts[0], 16)?;
591 let low = u64::from_str_radix(parts[1], 16)?;
592
593 Ok((high << 32) | low)
594}