1use bytes::{Bytes, BytesMut};
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::sync::Arc;
4use tokio::io::{AsyncRead, AsyncWrite, BufReader};
5use tokio::sync::{mpsc, watch};
6use tokio::time::Instant;
7
8use crate::config::ReplicationConfig;
9use crate::error::{PgWireError, Result};
10use crate::lsn::Lsn;
11use crate::protocol::framing::{
12 read_backend_message, read_backend_message_into, write_copy_data, write_copy_done,
13 write_password_message, write_query, write_startup_message,
14};
15use crate::protocol::messages::{parse_auth_request, parse_error_response};
16use crate::protocol::replication::{
17 encode_standby_status_update, parse_copy_data, ReplicationCopyData, PG_EPOCH_MICROS,
18};
19
20pub struct SharedProgress {
25 applied: AtomicU64,
26}
27
28impl SharedProgress {
29 pub fn new(start: Lsn) -> Self {
30 Self {
31 applied: AtomicU64::new(start.as_u64()),
32 }
33 }
34
35 #[inline]
36 pub fn load_applied(&self) -> Lsn {
37 Lsn::from_u64(self.applied.load(Ordering::Acquire))
38 }
39
40 #[inline]
43 pub fn update_applied(&self, lsn: Lsn) {
44 let new = lsn.as_u64();
45 let mut cur = self.applied.load(Ordering::Relaxed);
46
47 while new > cur {
48 match self
49 .applied
50 .compare_exchange_weak(cur, new, Ordering::Release, Ordering::Relaxed)
51 {
52 Ok(_) => break,
53 Err(observed) => cur = observed,
54 }
55 }
56 }
57}
58
/// Events the replication worker publishes to its consumer.
#[derive(Debug, Clone)]
pub enum ReplicationEvent {
    /// A server keepalive. When the server asked for a reply, the worker sends
    /// its status update first and publishes this event afterwards.
    KeepAlive {
        /// `wal_end` value carried by the keepalive message.
        wal_end: Lsn,
        /// Whether the server asked for an immediate status reply.
        reply_requested: bool,
        /// Server clock value carried by the message.
        /// NOTE(review): presumably microseconds since the PostgreSQL epoch —
        /// confirm against `parse_copy_data`.
        server_time_micros: i64,
    },

    /// Start of a transaction, decoded from a pgoutput `B` message.
    Begin {
        /// First 8-byte field of the message (the transaction's final LSN).
        final_lsn: Lsn,
        /// Transaction id, read as a 4-byte value after the commit time.
        xid: u32,
        /// Commit timestamp field of the message.
        commit_time_micros: i64,
    },

    /// A WAL payload that was not decoded into `Begin`, `Commit` or `Message`;
    /// the bytes are passed through unchanged for the consumer to decode.
    XLogData {
        /// `wal_start` value carried by the copy-data message.
        wal_start: Lsn,
        /// `wal_end` value carried by the copy-data message.
        wal_end: Lsn,
        /// Server clock value carried by the message.
        server_time_micros: i64,
        /// Undecoded payload bytes.
        data: Bytes,
    },

    /// End of a transaction, decoded from a pgoutput `C` message.
    Commit {
        /// First LSN field of the message (read after the flags byte).
        lsn: Lsn,
        /// Second LSN field of the message; this is the value compared against
        /// the configured stop LSN.
        end_lsn: Lsn,
        /// Commit timestamp field of the message.
        commit_time_micros: i64,
    },

    /// A logical decoding message, decoded from a pgoutput `M` message.
    Message {
        /// `true` when bit 0 of the message's flags byte is set.
        transactional: bool,
        /// LSN field of the message.
        lsn: Lsn,
        /// NUL-terminated prefix; invalid UTF-8 is replaced lossily.
        prefix: String,
        /// Raw message content.
        content: Bytes,
    },

    /// Published once the configured `stop_at_lsn` has been reached, right
    /// before the worker ends the copy stream.
    StoppedAt {
        /// The LSN that met or exceeded the stop target.
        reached: Lsn,
    },
}
123
/// Receiving half of the channel on which a [`WorkerState`] publishes events.
///
/// Each item is either a decoded [`ReplicationEvent`] or a [`PgWireError`].
/// NOTE(review): the worker code in this file only ever sends `Ok` items;
/// `Err` items are presumably produced by the task that owns the worker —
/// confirm at the call site.
pub type ReplicationEventReceiver =
    mpsc::Receiver<std::result::Result<ReplicationEvent, PgWireError>>;
127
/// State owned by one replication worker for the lifetime of a connection.
pub struct WorkerState {
    /// Connection and replication settings (credentials, slot, publication,
    /// start/stop LSNs and timing intervals).
    cfg: ReplicationConfig,
    /// Shared applied-LSN tracker; its value is what status updates report.
    progress: Arc<SharedProgress>,
    /// Stop flag: the worker ends the stream once it observes `true`.
    stop_rx: watch::Receiver<bool>,
    /// Channel on which decoded events are published to the consumer.
    out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
}
135
136impl WorkerState {
137 pub fn new(
138 cfg: ReplicationConfig,
139 progress: Arc<SharedProgress>,
140 stop_rx: watch::Receiver<bool>,
141 out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
142 ) -> Self {
143 Self {
144 cfg,
145 progress,
146 stop_rx,
147 out,
148 }
149 }
150
151 pub async fn run_on_stream<S: AsyncRead + AsyncWrite + Unpin>(
153 &mut self,
154 stream: &mut S,
155 ) -> Result<()> {
156 let mut stream = BufReader::with_capacity(128 * 1024, stream);
160 self.startup(&mut stream).await?;
161 self.authenticate(&mut stream).await?;
162 self.start_replication(&mut stream).await?;
163 self.stream_loop(&mut stream).await
164 }
165
166 async fn startup<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<()> {
168 let params = [
169 ("user", self.cfg.user.as_str()),
170 ("database", self.cfg.database.as_str()),
171 ("replication", "database"),
172 ("client_encoding", "UTF8"),
173 ("application_name", "pgwire-replication"),
174 ];
175 write_startup_message(stream, 196608, ¶ms).await
176 }
177
178 async fn start_replication<S: AsyncRead + AsyncWrite + Unpin>(
180 &self,
181 stream: &mut S,
182 ) -> Result<()> {
183 let publication = self.cfg.publication.replace('\'', "''");
185 let sql = format!(
186 "START_REPLICATION SLOT {} LOGICAL {} \
187 (proto_version '1', publication_names '{}', messages 'true')",
188 self.cfg.slot, self.cfg.start_lsn, publication,
189 );
190 write_query(stream, &sql).await?;
191
192 loop {
194 let msg = read_backend_message(stream).await?;
195 match msg.tag {
196 b'W' => return Ok(()), b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
198 b'N' | b'S' | b'K' => continue, _ => continue,
200 }
201 }
202 }
203
204 async fn stream_loop<S: AsyncRead + AsyncWrite + Unpin>(
212 &mut self,
213 stream: &mut BufReader<S>,
214 ) -> Result<()> {
215 let mut last_status_sent = Instant::now() - self.cfg.status_interval;
216 let mut last_applied = self.progress.load_applied();
217 let mut read_buf = BytesMut::with_capacity(4096);
219 const DRAIN_BATCH: usize = 256;
222
223 loop {
224 let current_applied = self.progress.load_applied();
226 if current_applied != last_applied {
227 last_applied = current_applied;
228 }
229
230 if last_status_sent.elapsed() >= self.cfg.status_interval {
232 self.send_feedback(stream, last_applied, false).await?;
233 last_status_sent = Instant::now();
234 }
235
236 let mut drained = 0usize;
241 while stream.buffer().len() >= 5 && drained < DRAIN_BATCH {
242 let msg = read_backend_message_into(stream, &mut read_buf).await?;
243 drained += 1;
244 match msg.tag {
245 b'd' => {
246 if self
247 .handle_copy_data(
248 stream,
249 msg.payload,
250 &mut last_applied,
251 &mut last_status_sent,
252 )
253 .await?
254 {
255 return Ok(());
256 }
257 }
258 b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
259 _ => {}
260 }
261 }
262
263 if drained > 0 {
266 if self.stop_rx.has_changed().unwrap_or(false) && *self.stop_rx.borrow() {
268 let _ = write_copy_done(stream).await;
269 return Ok(());
270 }
271 continue;
272 }
273
274 let msg = tokio::select! {
276 biased;
277
278 _ = self.stop_rx.changed() => {
279 if *self.stop_rx.borrow() {
280 let _ = write_copy_done(stream).await;
281 return Ok(());
282 }
283 continue;
284 }
285
286 msg_result = tokio::time::timeout(
287 self.cfg.idle_wakeup_interval,
288 read_backend_message_into(stream, &mut read_buf),
289 ) => {
290 match msg_result {
291 Ok(res) => res?,
292 Err(_) => {
293 let applied = self.progress.load_applied();
294 last_applied = applied;
295 self.send_feedback(stream, applied, false).await?;
296 last_status_sent = Instant::now();
297 continue;
298 }
299 }
300 }
301 };
302
303 match msg.tag {
304 b'd' => {
305 if self
306 .handle_copy_data(
307 stream,
308 msg.payload,
309 &mut last_applied,
310 &mut last_status_sent,
311 )
312 .await?
313 {
314 return Ok(());
315 }
316 }
317 b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
318 _ => {}
319 }
320 }
321 }
322
323 async fn handle_copy_data<S: AsyncRead + AsyncWrite + Unpin>(
325 &mut self,
326 stream: &mut BufReader<S>,
327 payload: Bytes,
328 last_applied: &mut Lsn,
329 last_status_sent: &mut Instant,
330 ) -> Result<bool> {
331 let cd = parse_copy_data(payload)?;
332
333 match cd {
334 ReplicationCopyData::KeepAlive {
335 wal_end,
336 server_time_micros,
337 reply_requested,
338 } => {
339 if reply_requested {
341 let applied = self.progress.load_applied();
342 *last_applied = applied;
343 self.send_feedback(stream, applied, true).await?;
344 *last_status_sent = Instant::now();
345 }
346
347 self.send_event(Ok(ReplicationEvent::KeepAlive {
348 wal_end,
349 reply_requested,
350 server_time_micros,
351 }))
352 .await;
353
354 Ok(false)
355 }
356 ReplicationCopyData::XLogData {
357 wal_start,
358 wal_end,
359 server_time_micros,
360 data,
361 } => {
362 if let Some(boundary_ev) = parse_pgoutput_boundary(&data)? {
364 let reached_lsn = match boundary_ev {
365 ReplicationEvent::Begin { final_lsn, .. } => final_lsn,
366 ReplicationEvent::Commit { end_lsn, .. } => end_lsn,
367 _ => wal_end, };
369
370 self.send_event(Ok(boundary_ev)).await;
371
372 if let Some(stop_lsn) = self.cfg.stop_at_lsn {
374 if reached_lsn >= stop_lsn {
375 self.send_event(Ok(ReplicationEvent::StoppedAt {
376 reached: reached_lsn,
377 }))
378 .await;
379 let _ = write_copy_done(stream).await;
380 return Ok(true); }
382 }
383
384 return Ok(false);
385 }
386 if let Some(stop_lsn) = self.cfg.stop_at_lsn {
389 if wal_end >= stop_lsn {
390 self.send_event(Ok(ReplicationEvent::XLogData {
392 wal_start,
393 wal_end,
394 server_time_micros,
395 data,
396 }))
397 .await;
398
399 self.send_event(Ok(ReplicationEvent::StoppedAt { reached: wal_end }))
400 .await;
401
402 let _ = write_copy_done(stream).await;
403 return Ok(true);
404 }
405 }
406
407 self.send_event(Ok(ReplicationEvent::XLogData {
408 wal_start,
409 wal_end,
410 server_time_micros,
411 data,
412 }))
413 .await;
414
415 Ok(false)
416 }
417 }
418 }
419
420 async fn send_event(&self, event: std::result::Result<ReplicationEvent, PgWireError>) {
425 if self.out.send(event).await.is_err() {
426 tracing::debug!("event channel closed, client may have disconnected");
427 }
428 }
429
430 async fn authenticate<S: AsyncRead + AsyncWrite + Unpin>(
432 &mut self,
433 stream: &mut S,
434 ) -> Result<()> {
435 loop {
436 let msg = read_backend_message(stream).await?;
437 match msg.tag {
438 b'R' => {
439 let (code, rest) = parse_auth_request(&msg.payload)?;
440 self.handle_auth_request(stream, code, rest).await?;
441 }
442 b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
443 b'S' | b'K' => {} b'Z' => return Ok(()), _ => {}
446 }
447 }
448 }
449
450 async fn handle_auth_request<S: AsyncRead + AsyncWrite + Unpin>(
452 &mut self,
453 stream: &mut S,
454 code: i32,
455 data: &[u8],
456 ) -> Result<()> {
457 match code {
458 0 => Ok(()), 3 => {
460 let mut payload = Vec::from(self.cfg.password.as_bytes());
462 payload.push(0);
463 write_password_message(stream, &payload).await
464 }
465 10 => {
466 self.auth_scram(stream, data).await
468 }
469 #[cfg(feature = "md5")]
470 5 => {
471 if data.len() != 4 {
473 return Err(PgWireError::Protocol(
474 "MD5 auth: expected 4-byte salt".into(),
475 ));
476 }
477 let mut salt = [0u8; 4];
478 salt.copy_from_slice(&data[..4]);
479
480 let hash = postgres_md5(&self.cfg.password, &self.cfg.user, &salt);
481 let mut payload = hash.into_bytes();
482 payload.push(0);
483 write_password_message(stream, &payload).await
484 }
485 _ => Err(PgWireError::Auth(format!(
486 "unsupported auth method code: {code}"
487 ))),
488 }
489 }
490
491 async fn auth_scram<S: AsyncRead + AsyncWrite + Unpin>(
493 &mut self,
494 stream: &mut S,
495 mechanisms_data: &[u8],
496 ) -> Result<()> {
497 let mechanisms = parse_sasl_mechanisms(mechanisms_data);
499
500 if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
501 return Err(PgWireError::Auth(format!(
502 "server doesn't offer SCRAM-SHA-256, available: {mechanisms:?}"
503 )));
504 }
505
506 #[cfg(not(feature = "scram"))]
507 return Err(PgWireError::Auth(
508 "SCRAM authentication required but 'scram' feature not enabled".into(),
509 ));
510
511 #[cfg(feature = "scram")]
512 {
513 use crate::auth::scram::ScramClient;
514
515 let scram = ScramClient::new(&self.cfg.user);
516
517 let mut init = Vec::new();
519 init.extend_from_slice(b"SCRAM-SHA-256\0");
520 init.extend_from_slice(&(scram.client_first.len() as i32).to_be_bytes());
521 init.extend_from_slice(scram.client_first.as_bytes());
522 write_password_message(stream, &init).await?;
523
524 let server_first = read_auth_data(stream, 11).await?;
526 let server_first_str = String::from_utf8_lossy(&server_first);
527
528 let (client_final, auth_message, salted_password) =
530 scram.client_final(&self.cfg.password, &server_first_str)?;
531 write_password_message(stream, client_final.as_bytes()).await?;
532
533 let server_final = read_auth_data(stream, 12).await?;
535 let server_final_str = String::from_utf8_lossy(&server_final);
536 ScramClient::verify_server_final(&server_final_str, &salted_password, &auth_message)?;
537
538 Ok(())
539 }
540 }
541
542 async fn send_feedback<S: AsyncWrite + Unpin>(
544 &self,
545 stream: &mut S,
546 applied: Lsn,
547 reply_requested: bool,
548 ) -> Result<()> {
549 let client_time = current_pg_timestamp();
550 let payload = encode_standby_status_update(applied, client_time, reply_requested);
551 write_copy_data(stream, &payload).await
552 }
553}
554
/// Extracts the mechanism names from a SASL authentication request body.
///
/// The body is a sequence of NUL-terminated names. Parsing stops at the first
/// empty name, and trailing bytes without a terminating NUL are discarded.
/// Names that are not valid UTF-8 are converted lossily.
fn parse_sasl_mechanisms(data: &[u8]) -> Vec<String> {
    let mut pieces = data.split(|&byte| byte == 0);
    // The last piece is whatever follows the final NUL (or the whole input
    // when there is none), so it is never a terminated name.
    pieces.next_back();

    pieces
        .take_while(|name| !name.is_empty())
        .map(|name| String::from_utf8_lossy(name).to_string())
        .collect()
}
574
575fn parse_pgoutput_boundary(data: &Bytes) -> Result<Option<ReplicationEvent>> {
576 if data.is_empty() {
577 return Ok(None);
578 }
579
580 let tag = data[0];
581 let mut p = &data[1..];
582
583 fn take_i8(p: &mut &[u8]) -> Result<i8> {
584 if p.is_empty() {
585 return Err(PgWireError::Protocol("pgoutput: truncated i8".into()));
586 }
587 let v = p[0] as i8;
588 *p = &p[1..];
589 Ok(v)
590 }
591
592 fn take_i32(p: &mut &[u8]) -> Result<i32> {
593 if p.len() < 4 {
594 return Err(PgWireError::Protocol("pgoutput: truncated i32".into()));
595 }
596 let (head, tail) = p.split_at(4);
597 *p = tail;
598 Ok(i32::from_be_bytes(head.try_into().unwrap()))
599 }
600
601 fn take_i64(p: &mut &[u8]) -> Result<i64> {
602 if p.len() < 8 {
603 return Err(PgWireError::Protocol("pgoutput: truncated i64".into()));
604 }
605 let (head, tail) = p.split_at(8);
606 *p = tail;
607 Ok(i64::from_be_bytes(head.try_into().unwrap()))
608 }
609
610 match tag {
611 b'B' => {
612 let final_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
613 let commit_time_micros = take_i64(&mut p)?;
614 let xid = take_i32(&mut p)? as u32;
615
616 Ok(Some(ReplicationEvent::Begin {
617 final_lsn,
618 commit_time_micros,
619 xid,
620 }))
621 }
622 b'C' => {
623 let _flags = take_i8(&mut p)?;
624 let lsn = Lsn::from_u64(take_i64(&mut p)? as u64); let end_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
626 let commit_time_micros = take_i64(&mut p)?;
627
628 Ok(Some(ReplicationEvent::Commit {
629 lsn,
630 end_lsn,
631 commit_time_micros,
632 }))
633 }
634 b'M' => {
635 let flags = take_i8(&mut p)?;
638 let transactional = (flags & 1) != 0;
639 let lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
640
641 let prefix_end = p.iter().position(|&b| b == 0).ok_or_else(|| {
643 PgWireError::Protocol("pgoutput Message: missing null terminator for prefix".into())
644 })?;
645 let prefix = String::from_utf8_lossy(&p[..prefix_end]).into_owned();
646 p = &p[prefix_end + 1..]; let content_len = take_i32(&mut p)? as usize;
649 if p.len() < content_len {
650 return Err(PgWireError::Protocol(format!(
651 "pgoutput Message: expected {} content bytes, got {}",
652 content_len,
653 p.len()
654 )));
655 }
656 let content = Bytes::copy_from_slice(&p[..content_len]);
657
658 Ok(Some(ReplicationEvent::Message {
659 transactional,
660 lsn,
661 prefix,
662 content,
663 }))
664 }
665 _ => Ok(None),
666 }
667}
668
669async fn read_auth_data<S: AsyncRead + AsyncWrite + Unpin>(
671 stream: &mut S,
672 expected_code: i32,
673) -> Result<Vec<u8>> {
674 loop {
675 let msg = read_backend_message(stream).await?;
676 match msg.tag {
677 b'R' => {
678 let (code, data) = parse_auth_request(&msg.payload)?;
679 if code == expected_code {
680 return Ok(data.to_vec());
681 }
682 return Err(PgWireError::Auth(format!(
683 "unexpected auth code {code}, expected {expected_code}"
684 )));
685 }
686 b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
687 _ => {} }
689 }
690}
691
692fn current_pg_timestamp() -> i64 {
694 use std::time::{SystemTime, UNIX_EPOCH};
695
696 let now = SystemTime::now()
697 .duration_since(UNIX_EPOCH)
698 .unwrap_or_default();
699
700 let unix_micros = (now.as_secs() as i64) * 1_000_000 + (now.subsec_micros() as i64);
701 unix_micros - PG_EPOCH_MICROS
702}
703
/// Builds the response for PostgreSQL MD5 password authentication:
/// `"md5"` followed by `hex(md5(hex(md5(password ++ user)) ++ salt))`.
#[cfg(feature = "md5")]
fn postgres_md5(password: &str, user: &str, salt: &[u8; 4]) -> String {
    let hex_digest = |bytes: &[u8]| format!("{:x}", md5::compute(bytes));

    // Stage 1: digest of the password concatenated with the user name.
    let credentials_digest = hex_digest(format!("{password}{user}").as_bytes());

    // Stage 2: digest of the stage-1 hex text followed by the server's salt.
    let salted = [credentials_digest.as_bytes(), &salt[..]].concat();

    format!("md5{}", hex_digest(&salted))
}
720
#[cfg(test)]
mod tests {
    use super::*;

    /// One mechanism followed by the list terminator.
    #[test]
    fn parse_sasl_mechanisms_single() {
        assert_eq!(
            parse_sasl_mechanisms(b"SCRAM-SHA-256\0\0"),
            ["SCRAM-SHA-256"]
        );
    }

    /// Several mechanisms are returned in the order the server listed them.
    #[test]
    fn parse_sasl_mechanisms_multiple() {
        assert_eq!(
            parse_sasl_mechanisms(b"SCRAM-SHA-256\0SCRAM-SHA-256-PLUS\0\0"),
            ["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"]
        );
    }

    /// A lone terminator means the server offered nothing.
    #[test]
    fn parse_sasl_mechanisms_empty() {
        assert_eq!(parse_sasl_mechanisms(b"\0"), Vec::<String>::new());
    }

    /// Checks the shape of the MD5 response: the `md5` prefix plus 32 more
    /// characters.
    #[test]
    #[cfg(feature = "md5")]
    fn postgres_md5_known_value() {
        let hash = postgres_md5("md5_pass", "md5_user", &[0x01, 0x02, 0x03, 0x04]);
        let (prefix, digest) = hash.split_at(3);
        assert_eq!(prefix, "md5");
        assert_eq!(digest.len(), 32);
    }

    /// The current time lies after the epoch offset, so the value is positive.
    #[test]
    fn current_pg_timestamp_is_positive() {
        assert!(current_pg_timestamp().is_positive());
    }
}