1use bytes::Bytes;
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::sync::Arc;
4use tokio::io::{AsyncRead, AsyncWrite, BufReader};
5use tokio::sync::{mpsc, watch};
6use tokio::time::Instant;
7
8use crate::config::ReplicationConfig;
9use crate::error::{PgWireError, Result};
10use crate::lsn::Lsn;
11use crate::protocol::framing::{
12 read_backend_message, write_copy_data, write_copy_done, write_password_message, write_query,
13 write_startup_message, MessageReader,
14};
15use crate::protocol::messages::{parse_auth_request, parse_error_response};
16use crate::protocol::replication::{
17 encode_standby_status_update, parse_copy_data, ReplicationCopyData, PG_EPOCH_MICROS,
18};
19
20pub struct SharedProgress {
25 applied: AtomicU64,
26}
27
28impl SharedProgress {
29 pub fn new(start: Lsn) -> Self {
30 Self {
31 applied: AtomicU64::new(start.as_u64()),
32 }
33 }
34
35 #[inline]
36 pub fn load_applied(&self) -> Lsn {
37 Lsn::from_u64(self.applied.load(Ordering::Acquire))
38 }
39
40 #[inline]
43 pub fn update_applied(&self, lsn: Lsn) {
44 let new = lsn.as_u64();
45 let mut cur = self.applied.load(Ordering::Relaxed);
46
47 while new > cur {
48 match self
49 .applied
50 .compare_exchange_weak(cur, new, Ordering::Release, Ordering::Relaxed)
51 {
52 Ok(_) => break,
53 Err(observed) => cur = observed,
54 }
55 }
56 }
57}
58
/// Events delivered to the consumer of a replication stream.
///
/// Values are produced by `WorkerState::handle_copy_data`, either directly
/// from a parsed `ReplicationCopyData` frame or from `parse_pgoutput_boundary`
/// when the WAL payload starts with a recognised pgoutput tag.
///
/// NOTE(review): the `*_micros` fields are forwarded exactly as parsed; they
/// are presumably microseconds relative to the PostgreSQL epoch — confirm
/// against the protocol parser.
#[derive(Debug, Clone)]
pub enum ReplicationEvent {
    /// A keepalive frame received from the server.
    KeepAlive {
        /// `wal_end` value carried by the keepalive frame.
        wal_end: Lsn,
        /// Whether the server asked for an immediate status reply. The worker
        /// has already sent that reply by the time this event is emitted.
        reply_requested: bool,
        /// Server timestamp carried by the keepalive frame.
        server_time_micros: i64,
    },

    /// A pgoutput message tagged `B`.
    Begin {
        /// First 8-byte field of the message; this is the position compared
        /// against the configured stop LSN.
        final_lsn: Lsn,
        /// Trailing 4-byte field of the message.
        xid: u32,
        /// Second 8-byte field of the message.
        commit_time_micros: i64,
    },

    /// A WAL data frame whose payload was not recognised as a boundary
    /// message; the payload is forwarded unparsed.
    XLogData {
        /// `wal_start` value carried by the frame.
        wal_start: Lsn,
        /// `wal_end` value carried by the frame.
        wal_end: Lsn,
        /// Server timestamp carried by the frame.
        server_time_micros: i64,
        /// Raw payload bytes.
        data: Bytes,
    },

    /// A pgoutput message tagged `C`.
    Commit {
        /// First 8-byte field after the flags byte.
        lsn: Lsn,
        /// Second 8-byte field after the flags byte; this is the position
        /// compared against the configured stop LSN.
        end_lsn: Lsn,
        /// Third 8-byte field after the flags byte.
        commit_time_micros: i64,
    },

    /// A pgoutput message tagged `M`.
    Message {
        /// `true` when bit 0 of the flags byte is set.
        transactional: bool,
        /// 8-byte LSN field following the flags byte.
        lsn: Lsn,
        /// NUL-terminated prefix, decoded lossily as UTF-8.
        prefix: String,
        /// Length-prefixed content bytes.
        content: Bytes,
    },

    /// Emitted once the configured `stop_at_lsn` has been reached, right
    /// before the worker sends CopyDone and stops.
    StoppedAt {
        /// The position that satisfied the stop condition.
        reached: Lsn,
    },
}
123
/// Receiving half of the channel on which the worker publishes events.
///
/// Each item is either a [`ReplicationEvent`] or a [`PgWireError`].
pub type ReplicationEventReceiver =
    mpsc::Receiver<std::result::Result<ReplicationEvent, PgWireError>>;
127
/// State owned by the replication worker while it drives one connection.
pub struct WorkerState {
    /// Connection and streaming settings (user, database, password, slot,
    /// publication, start/stop LSN, status and idle intervals).
    cfg: ReplicationConfig,
    /// Shared applied-LSN tracker; read whenever feedback is sent.
    progress: Arc<SharedProgress>,
    /// Stop signal; a value of `true` asks the streaming loop to finish.
    stop_rx: watch::Receiver<bool>,
    /// Channel on which events are published to the consumer.
    out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
}
135
136impl WorkerState {
137 pub fn new(
138 cfg: ReplicationConfig,
139 progress: Arc<SharedProgress>,
140 stop_rx: watch::Receiver<bool>,
141 out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
142 ) -> Self {
143 Self {
144 cfg,
145 progress,
146 stop_rx,
147 out,
148 }
149 }
150
151 pub async fn run_on_stream<S: AsyncRead + AsyncWrite + Unpin>(
153 &mut self,
154 stream: &mut S,
155 ) -> Result<()> {
156 let mut stream = BufReader::with_capacity(128 * 1024, stream);
160 self.startup(&mut stream).await?;
161 self.authenticate(&mut stream).await?;
162 self.start_replication(&mut stream).await?;
163 self.stream_loop(&mut stream).await
164 }
165
166 async fn startup<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<()> {
168 let params = [
169 ("user", self.cfg.user.as_str()),
170 ("database", self.cfg.database.as_str()),
171 ("replication", "database"),
172 ("client_encoding", "UTF8"),
173 ("application_name", "pgwire-replication"),
174 ];
175 write_startup_message(stream, 196608, ¶ms).await
176 }
177
178 async fn start_replication<S: AsyncRead + AsyncWrite + Unpin>(
180 &self,
181 stream: &mut S,
182 ) -> Result<()> {
183 let publication = self.cfg.publication.replace('\'', "''");
185 let sql = format!(
186 "START_REPLICATION SLOT {} LOGICAL {} \
187 (proto_version '1', publication_names '{}', messages 'true')",
188 self.cfg.slot, self.cfg.start_lsn, publication,
189 );
190 write_query(stream, &sql).await?;
191
192 loop {
194 let msg = read_backend_message(stream).await?;
195 match msg.tag {
196 b'W' => return Ok(()), b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
198 b'N' | b'S' | b'K' => continue, _ => continue,
200 }
201 }
202 }
203
204 async fn stream_loop<S: AsyncRead + AsyncWrite + Unpin>(
215 &mut self,
216 stream: &mut BufReader<S>,
217 ) -> Result<()> {
218 let mut last_status_sent = Instant::now() - self.cfg.status_interval;
219 let mut last_applied = self.progress.load_applied();
220 let mut reader = MessageReader::new();
222 const DRAIN_BATCH: usize = 256;
225
226 loop {
227 let current_applied = self.progress.load_applied();
229 if current_applied != last_applied {
230 last_applied = current_applied;
231 }
232
233 if last_status_sent.elapsed() >= self.cfg.status_interval {
235 self.send_feedback(stream, last_applied, false).await?;
236 last_status_sent = Instant::now();
237 }
238
239 let mut drained = 0usize;
244 while stream.buffer().len() >= 5 && drained < DRAIN_BATCH {
245 let msg = reader.read(stream).await?;
246 drained += 1;
247 if msg.tag == b'E' {
248 return Err(PgWireError::Server(parse_error_response(&msg.payload)));
249 }
250 if msg.tag == b'd'
251 && self
252 .handle_copy_data(
253 stream,
254 msg.payload,
255 &mut last_applied,
256 &mut last_status_sent,
257 )
258 .await?
259 {
260 return Ok(());
261 }
262 }
263
264 if drained > 0 {
267 if self.stop_rx.has_changed().unwrap_or(false) && *self.stop_rx.borrow() {
269 let _ = write_copy_done(stream).await;
270 return Ok(());
271 }
272 continue;
273 }
274
275 let msg = tokio::select! {
282 biased;
283
284 _ = self.stop_rx.changed() => {
285 if *self.stop_rx.borrow() {
286 let _ = write_copy_done(stream).await;
287 return Ok(());
288 }
289 continue;
290 }
291
292 msg_result = tokio::time::timeout(
293 self.cfg.idle_wakeup_interval,
294 reader.read(stream),
295 ) => {
296 match msg_result {
297 Ok(res) => res?,
298 Err(_) => {
299 let applied = self.progress.load_applied();
300 last_applied = applied;
301 self.send_feedback(stream, applied, false).await?;
302 last_status_sent = Instant::now();
303 continue;
304 }
305 }
306 }
307 };
308
309 if msg.tag == b'E' {
310 return Err(PgWireError::Server(parse_error_response(&msg.payload)));
311 }
312 if msg.tag == b'd'
313 && self
314 .handle_copy_data(
315 stream,
316 msg.payload,
317 &mut last_applied,
318 &mut last_status_sent,
319 )
320 .await?
321 {
322 return Ok(());
323 }
324 }
325 }
326
327 async fn handle_copy_data<S: AsyncRead + AsyncWrite + Unpin>(
329 &mut self,
330 stream: &mut BufReader<S>,
331 payload: Bytes,
332 last_applied: &mut Lsn,
333 last_status_sent: &mut Instant,
334 ) -> Result<bool> {
335 let cd = parse_copy_data(payload)?;
336
337 match cd {
338 ReplicationCopyData::KeepAlive {
339 wal_end,
340 server_time_micros,
341 reply_requested,
342 } => {
343 if reply_requested {
345 let applied = self.progress.load_applied();
346 *last_applied = applied;
347 self.send_feedback(stream, applied, true).await?;
348 *last_status_sent = Instant::now();
349 }
350
351 self.send_event(Ok(ReplicationEvent::KeepAlive {
352 wal_end,
353 reply_requested,
354 server_time_micros,
355 }))
356 .await;
357
358 Ok(false)
359 }
360 ReplicationCopyData::XLogData {
361 wal_start,
362 wal_end,
363 server_time_micros,
364 data,
365 } => {
366 if let Some(boundary_ev) = parse_pgoutput_boundary(&data)? {
368 let reached_lsn = match boundary_ev {
369 ReplicationEvent::Begin { final_lsn, .. } => final_lsn,
370 ReplicationEvent::Commit { end_lsn, .. } => end_lsn,
371 _ => wal_end, };
373
374 self.send_event(Ok(boundary_ev)).await;
375
376 if let Some(stop_lsn) = self.cfg.stop_at_lsn {
378 if reached_lsn >= stop_lsn {
379 self.send_event(Ok(ReplicationEvent::StoppedAt {
380 reached: reached_lsn,
381 }))
382 .await;
383 let _ = write_copy_done(stream).await;
384 return Ok(true); }
386 }
387
388 return Ok(false);
389 }
390 if let Some(stop_lsn) = self.cfg.stop_at_lsn {
393 if wal_end >= stop_lsn {
394 self.send_event(Ok(ReplicationEvent::XLogData {
396 wal_start,
397 wal_end,
398 server_time_micros,
399 data,
400 }))
401 .await;
402
403 self.send_event(Ok(ReplicationEvent::StoppedAt { reached: wal_end }))
404 .await;
405
406 let _ = write_copy_done(stream).await;
407 return Ok(true);
408 }
409 }
410
411 self.send_event(Ok(ReplicationEvent::XLogData {
412 wal_start,
413 wal_end,
414 server_time_micros,
415 data,
416 }))
417 .await;
418
419 Ok(false)
420 }
421 }
422 }
423
424 async fn send_event(&self, event: std::result::Result<ReplicationEvent, PgWireError>) {
429 if self.out.send(event).await.is_err() {
430 tracing::debug!("event channel closed, client may have disconnected");
431 }
432 }
433
434 async fn authenticate<S: AsyncRead + AsyncWrite + Unpin>(
436 &mut self,
437 stream: &mut S,
438 ) -> Result<()> {
439 loop {
440 let msg = read_backend_message(stream).await?;
441 match msg.tag {
442 b'R' => {
443 let (code, rest) = parse_auth_request(&msg.payload)?;
444 self.handle_auth_request(stream, code, rest).await?;
445 }
446 b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
447 b'S' | b'K' => {} b'Z' => return Ok(()), _ => {}
450 }
451 }
452 }
453
454 async fn handle_auth_request<S: AsyncRead + AsyncWrite + Unpin>(
456 &mut self,
457 stream: &mut S,
458 code: i32,
459 data: &[u8],
460 ) -> Result<()> {
461 match code {
462 0 => Ok(()), 3 => {
464 let mut payload = Vec::from(self.cfg.password.as_bytes());
466 payload.push(0);
467 write_password_message(stream, &payload).await
468 }
469 10 => {
470 self.auth_scram(stream, data).await
472 }
473 #[cfg(feature = "md5")]
474 5 => {
475 if data.len() != 4 {
477 return Err(PgWireError::Protocol(
478 "MD5 auth: expected 4-byte salt".into(),
479 ));
480 }
481 let mut salt = [0u8; 4];
482 salt.copy_from_slice(&data[..4]);
483
484 let hash = postgres_md5(&self.cfg.password, &self.cfg.user, &salt);
485 let mut payload = hash.into_bytes();
486 payload.push(0);
487 write_password_message(stream, &payload).await
488 }
489 _ => Err(PgWireError::Auth(format!(
490 "unsupported auth method code: {code}"
491 ))),
492 }
493 }
494
495 async fn auth_scram<S: AsyncRead + AsyncWrite + Unpin>(
497 &mut self,
498 stream: &mut S,
499 mechanisms_data: &[u8],
500 ) -> Result<()> {
501 let mechanisms = parse_sasl_mechanisms(mechanisms_data);
503
504 if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
505 return Err(PgWireError::Auth(format!(
506 "server doesn't offer SCRAM-SHA-256, available: {mechanisms:?}"
507 )));
508 }
509
510 #[cfg(not(feature = "scram"))]
511 return Err(PgWireError::Auth(
512 "SCRAM authentication required but 'scram' feature not enabled".into(),
513 ));
514
515 #[cfg(feature = "scram")]
516 {
517 use crate::auth::scram::ScramClient;
518
519 let scram = ScramClient::new(&self.cfg.user);
520
521 let mut init = Vec::new();
523 init.extend_from_slice(b"SCRAM-SHA-256\0");
524 init.extend_from_slice(&(scram.client_first.len() as i32).to_be_bytes());
525 init.extend_from_slice(scram.client_first.as_bytes());
526 write_password_message(stream, &init).await?;
527
528 let server_first = read_auth_data(stream, 11).await?;
530 let server_first_str = String::from_utf8_lossy(&server_first);
531
532 let (client_final, auth_message, salted_password) =
534 scram.client_final(&self.cfg.password, &server_first_str)?;
535 write_password_message(stream, client_final.as_bytes()).await?;
536
537 let server_final = read_auth_data(stream, 12).await?;
539 let server_final_str = String::from_utf8_lossy(&server_final);
540 ScramClient::verify_server_final(&server_final_str, &salted_password, &auth_message)?;
541
542 Ok(())
543 }
544 }
545
546 async fn send_feedback<S: AsyncWrite + Unpin>(
548 &self,
549 stream: &mut S,
550 applied: Lsn,
551 reply_requested: bool,
552 ) -> Result<()> {
553 let client_time = current_pg_timestamp();
554 let payload = encode_standby_status_update(applied, client_time, reply_requested);
555 write_copy_data(stream, &payload).await
556 }
557}
558
/// Extracts the mechanism names from a NUL-separated SASL mechanism list.
///
/// Names are collected in order until the first empty entry (the list
/// terminator). Bytes after the last NUL have no terminator and are ignored.
/// Invalid UTF-8 is replaced lossily.
fn parse_sasl_mechanisms(data: &[u8]) -> Vec<String> {
    // Only NUL-terminated entries count, so cut the input at the last NUL.
    let terminated = match data.iter().rposition(|&byte| byte == 0) {
        Some(last_nul) => &data[..last_nul],
        None => return Vec::new(),
    };

    terminated
        .split(|&byte| byte == 0)
        .take_while(|name| !name.is_empty())
        .map(|name| String::from_utf8_lossy(name).into_owned())
        .collect()
}
578
579fn parse_pgoutput_boundary(data: &Bytes) -> Result<Option<ReplicationEvent>> {
580 if data.is_empty() {
581 return Ok(None);
582 }
583
584 let tag = data[0];
585 let mut p = &data[1..];
586
587 fn take_i8(p: &mut &[u8]) -> Result<i8> {
588 if p.is_empty() {
589 return Err(PgWireError::Protocol("pgoutput: truncated i8".into()));
590 }
591 let v = p[0] as i8;
592 *p = &p[1..];
593 Ok(v)
594 }
595
596 fn take_i32(p: &mut &[u8]) -> Result<i32> {
597 if p.len() < 4 {
598 return Err(PgWireError::Protocol("pgoutput: truncated i32".into()));
599 }
600 let (head, tail) = p.split_at(4);
601 *p = tail;
602 Ok(i32::from_be_bytes(head.try_into().unwrap()))
603 }
604
605 fn take_i64(p: &mut &[u8]) -> Result<i64> {
606 if p.len() < 8 {
607 return Err(PgWireError::Protocol("pgoutput: truncated i64".into()));
608 }
609 let (head, tail) = p.split_at(8);
610 *p = tail;
611 Ok(i64::from_be_bytes(head.try_into().unwrap()))
612 }
613
614 match tag {
615 b'B' => {
616 let final_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
617 let commit_time_micros = take_i64(&mut p)?;
618 let xid = take_i32(&mut p)? as u32;
619
620 Ok(Some(ReplicationEvent::Begin {
621 final_lsn,
622 commit_time_micros,
623 xid,
624 }))
625 }
626 b'C' => {
627 let _flags = take_i8(&mut p)?;
628 let lsn = Lsn::from_u64(take_i64(&mut p)? as u64); let end_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
630 let commit_time_micros = take_i64(&mut p)?;
631
632 Ok(Some(ReplicationEvent::Commit {
633 lsn,
634 end_lsn,
635 commit_time_micros,
636 }))
637 }
638 b'M' => {
639 let flags = take_i8(&mut p)?;
642 let transactional = (flags & 1) != 0;
643 let lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
644
645 let prefix_end = p.iter().position(|&b| b == 0).ok_or_else(|| {
647 PgWireError::Protocol("pgoutput Message: missing null terminator for prefix".into())
648 })?;
649 let prefix = String::from_utf8_lossy(&p[..prefix_end]).into_owned();
650 p = &p[prefix_end + 1..]; let content_len = take_i32(&mut p)? as usize;
653 if p.len() < content_len {
654 return Err(PgWireError::Protocol(format!(
655 "pgoutput Message: expected {} content bytes, got {}",
656 content_len,
657 p.len()
658 )));
659 }
660 let content = Bytes::copy_from_slice(&p[..content_len]);
661
662 Ok(Some(ReplicationEvent::Message {
663 transactional,
664 lsn,
665 prefix,
666 content,
667 }))
668 }
669 _ => Ok(None),
670 }
671}
672
673async fn read_auth_data<S: AsyncRead + AsyncWrite + Unpin>(
675 stream: &mut S,
676 expected_code: i32,
677) -> Result<Vec<u8>> {
678 loop {
679 let msg = read_backend_message(stream).await?;
680 match msg.tag {
681 b'R' => {
682 let (code, data) = parse_auth_request(&msg.payload)?;
683 if code == expected_code {
684 return Ok(data.to_vec());
685 }
686 return Err(PgWireError::Auth(format!(
687 "unexpected auth code {code}, expected {expected_code}"
688 )));
689 }
690 b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
691 _ => {} }
693 }
694}
695
696fn current_pg_timestamp() -> i64 {
698 use std::time::{SystemTime, UNIX_EPOCH};
699
700 let now = SystemTime::now()
701 .duration_since(UNIX_EPOCH)
702 .unwrap_or_default();
703
704 let unix_micros = (now.as_secs() as i64) * 1_000_000 + (now.subsec_micros() as i64);
705 unix_micros - PG_EPOCH_MICROS
706}
707
/// Builds the MD5 password response:
/// `"md5" + hex(md5(hex(md5(password + user)) + salt))`.
#[cfg(feature = "md5")]
fn postgres_md5(password: &str, user: &str, salt: &[u8; 4]) -> String {
    // Inner digest over password followed by user name, as lowercase hex.
    let credentials = format!("{password}{user}");
    let inner_hex = format!("{:x}", md5::compute(credentials.as_bytes()));

    // Outer digest over the hex text followed by the raw salt bytes.
    let salted = [inner_hex.as_bytes(), &salt[..]].concat();
    let outer_hex = format!("{:x}", md5::compute(salted.as_slice()));

    format!("md5{}", outer_hex)
}
724
#[cfg(test)]
mod tests {
    use super::*;

    /// A single mechanism followed by the list terminator.
    #[test]
    fn parse_sasl_mechanisms_single() {
        let parsed = parse_sasl_mechanisms(b"SCRAM-SHA-256\0\0");
        assert_eq!(parsed, vec!["SCRAM-SHA-256"]);
    }

    /// Several mechanisms are returned in wire order.
    #[test]
    fn parse_sasl_mechanisms_multiple() {
        let parsed = parse_sasl_mechanisms(b"SCRAM-SHA-256\0SCRAM-SHA-256-PLUS\0\0");
        assert_eq!(parsed, vec!["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"]);
    }

    /// A lone terminator yields no mechanisms.
    #[test]
    fn parse_sasl_mechanisms_empty() {
        assert!(parse_sasl_mechanisms(b"\0").is_empty());
    }

    /// Shape check only: "md5" prefix plus 32 hex digits.
    #[test]
    #[cfg(feature = "md5")]
    fn postgres_md5_known_value() {
        let salt = [0x01, 0x02, 0x03, 0x04];
        let response = postgres_md5("md5_pass", "md5_user", &salt);
        assert!(response.starts_with("md5"));
        assert_eq!(response.len(), 35);
    }

    /// The shifted timestamp is expected to be positive on a sane clock.
    #[test]
    fn current_pg_timestamp_is_positive() {
        assert!(current_pg_timestamp() > 0);
    }
}