1use bytes::Bytes;
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::sync::Arc;
4use tokio::io::{AsyncRead, AsyncWrite, BufReader};
5use tokio::sync::{mpsc, watch};
6use tokio::time::Instant;
7
8use crate::config::ReplicationConfig;
9use crate::error::{PgWireError, Result};
10use crate::lsn::Lsn;
11use crate::protocol::framing::{
12 read_backend_message, write_copy_data, write_copy_done, write_password_message, write_query,
13 write_startup_message,
14};
15use crate::protocol::messages::{parse_auth_request, parse_error_response};
16use crate::protocol::replication::{
17 encode_standby_status_update, parse_copy_data, ReplicationCopyData, PG_EPOCH_MICROS,
18};
19
20pub struct SharedProgress {
25 applied: AtomicU64,
26}
27
28impl SharedProgress {
29 pub fn new(start: Lsn) -> Self {
30 Self {
31 applied: AtomicU64::new(start.as_u64()),
32 }
33 }
34
35 #[inline]
36 pub fn load_applied(&self) -> Lsn {
37 Lsn::from_u64(self.applied.load(Ordering::Acquire))
38 }
39
40 #[inline]
43 pub fn update_applied(&self, lsn: Lsn) {
44 let new = lsn.as_u64();
45 let mut cur = self.applied.load(Ordering::Relaxed);
46
47 while new > cur {
48 match self
49 .applied
50 .compare_exchange_weak(cur, new, Ordering::Release, Ordering::Relaxed)
51 {
52 Ok(_) => break,
53 Err(observed) => cur = observed,
54 }
55 }
56 }
57}
58
59#[derive(Debug, Clone)]
61pub enum ReplicationEvent {
62 KeepAlive {
64 wal_end: Lsn,
66 reply_requested: bool,
68 server_time_micros: i64,
70 },
71
72 Begin {
74 final_lsn: Lsn,
75 xid: u32,
76 commit_time_micros: i64,
77 },
78
79 XLogData {
81 wal_start: Lsn,
83 wal_end: Lsn,
85 server_time_micros: i64,
87 data: Bytes,
89 },
90
91 Commit {
93 lsn: Lsn,
94 end_lsn: Lsn,
95 commit_time_micros: i64,
96 },
97
98 Message {
104 transactional: bool,
106 lsn: Lsn,
108 prefix: String,
110 content: Bytes,
112 },
113
114 StoppedAt {
119 reached: Lsn,
121 },
122}
123
124pub type ReplicationEventReceiver =
126 mpsc::Receiver<std::result::Result<ReplicationEvent, PgWireError>>;
127
128pub struct WorkerState {
130 cfg: ReplicationConfig,
131 progress: Arc<SharedProgress>,
132 stop_rx: watch::Receiver<bool>,
133 out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
134}
135
136impl WorkerState {
137 pub fn new(
138 cfg: ReplicationConfig,
139 progress: Arc<SharedProgress>,
140 stop_rx: watch::Receiver<bool>,
141 out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
142 ) -> Self {
143 Self {
144 cfg,
145 progress,
146 stop_rx,
147 out,
148 }
149 }
150
151 pub async fn run_on_stream<S: AsyncRead + AsyncWrite + Unpin>(
153 &mut self,
154 stream: &mut S,
155 ) -> Result<()> {
156 let mut stream = BufReader::with_capacity(128 * 1024, stream);
160 self.startup(&mut stream).await?;
161 self.authenticate(&mut stream).await?;
162 self.start_replication(&mut stream).await?;
163 self.stream_loop(&mut stream).await
164 }
165
166 async fn startup<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<()> {
168 let params = [
169 ("user", self.cfg.user.as_str()),
170 ("database", self.cfg.database.as_str()),
171 ("replication", "database"),
172 ("client_encoding", "UTF8"),
173 ("application_name", "pgwire-replication"),
174 ];
175 write_startup_message(stream, 196608, ¶ms).await
176 }
177
178 async fn start_replication<S: AsyncRead + AsyncWrite + Unpin>(
180 &self,
181 stream: &mut S,
182 ) -> Result<()> {
183 let publication = self.cfg.publication.replace('\'', "''");
185 let sql = format!(
186 "START_REPLICATION SLOT {} LOGICAL {} \
187 (proto_version '1', publication_names '{}', messages 'true')",
188 self.cfg.slot, self.cfg.start_lsn, publication,
189 );
190 write_query(stream, &sql).await?;
191
192 loop {
194 let msg = read_backend_message(stream).await?;
195 match msg.tag {
196 b'W' => return Ok(()), b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
198 b'N' | b'S' | b'K' => continue, _ => continue,
200 }
201 }
202 }
203
204 async fn stream_loop<S: AsyncRead + AsyncWrite + Unpin>(
206 &mut self,
207 stream: &mut S,
208 ) -> Result<()> {
209 let mut last_status_sent = Instant::now() - self.cfg.status_interval;
210 let mut last_applied = self.progress.load_applied();
211
212 loop {
213 let current_applied = self.progress.load_applied();
215 if current_applied != last_applied {
216 last_applied = current_applied;
217 }
218
219 if last_status_sent.elapsed() >= self.cfg.status_interval {
221 self.send_feedback(stream, last_applied, false).await?;
222 last_status_sent = Instant::now();
223 }
224
225 let msg = tokio::select! {
229 biased; _ = self.stop_rx.changed() => {
232 if *self.stop_rx.borrow() {
233 let _ = write_copy_done(stream).await;
234 return Ok(());
235 }
236 continue;
238 }
239
240 msg_result = tokio::time::timeout(
241 self.cfg.idle_wakeup_interval,
242 read_backend_message(stream),
243 ) => {
244 match msg_result {
245 Ok(res) => res?, Err(_) => {
247 let applied = self.progress.load_applied();
249 last_applied = applied;
250 self.send_feedback(stream, applied, false).await?;
251 last_status_sent = Instant::now();
252 continue;
253 }
254 }
255 }
256 };
257
258 match msg.tag {
259 b'd' => {
260 let should_stop = self
261 .handle_copy_data(
262 stream,
263 msg.payload,
264 &mut last_applied,
265 &mut last_status_sent,
266 )
267 .await?;
268 if should_stop {
269 return Ok(());
270 }
271 }
272 b'E' => {
273 let err = PgWireError::Server(parse_error_response(&msg.payload));
274 return Err(err);
275 }
276 _ => {
277 }
279 }
280 }
281 }
282
283 async fn handle_copy_data<S: AsyncRead + AsyncWrite + Unpin>(
285 &mut self,
286 stream: &mut S,
287 payload: Bytes,
288 last_applied: &mut Lsn,
289 last_status_sent: &mut Instant,
290 ) -> Result<bool> {
291 let cd = parse_copy_data(payload)?;
292
293 match cd {
294 ReplicationCopyData::KeepAlive {
295 wal_end,
296 server_time_micros,
297 reply_requested,
298 } => {
299 if reply_requested {
301 let applied = self.progress.load_applied();
302 *last_applied = applied;
303 self.send_feedback(stream, applied, true).await?;
304 *last_status_sent = Instant::now();
305 }
306
307 self.send_event(Ok(ReplicationEvent::KeepAlive {
308 wal_end,
309 reply_requested,
310 server_time_micros,
311 }))
312 .await;
313
314 Ok(false)
315 }
316 ReplicationCopyData::XLogData {
317 wal_start,
318 wal_end,
319 server_time_micros,
320 data,
321 } => {
322 if let Some(boundary_ev) = parse_pgoutput_boundary(&data)? {
324 let reached_lsn = match boundary_ev {
325 ReplicationEvent::Begin { final_lsn, .. } => final_lsn,
326 ReplicationEvent::Commit { end_lsn, .. } => end_lsn,
327 _ => wal_end, };
329
330 self.send_event(Ok(boundary_ev)).await;
331
332 if let Some(stop_lsn) = self.cfg.stop_at_lsn {
334 if reached_lsn >= stop_lsn {
335 self.send_event(Ok(ReplicationEvent::StoppedAt {
336 reached: reached_lsn,
337 }))
338 .await;
339 let _ = write_copy_done(stream).await;
340 return Ok(true); }
342 }
343
344 return Ok(false);
345 }
346 if let Some(stop_lsn) = self.cfg.stop_at_lsn {
349 if wal_end >= stop_lsn {
350 self.send_event(Ok(ReplicationEvent::XLogData {
352 wal_start,
353 wal_end,
354 server_time_micros,
355 data,
356 }))
357 .await;
358
359 self.send_event(Ok(ReplicationEvent::StoppedAt { reached: wal_end }))
360 .await;
361
362 let _ = write_copy_done(stream).await;
363 return Ok(true);
364 }
365 }
366
367 self.send_event(Ok(ReplicationEvent::XLogData {
368 wal_start,
369 wal_end,
370 server_time_micros,
371 data,
372 }))
373 .await;
374
375 Ok(false)
376 }
377 }
378 }
379
380 async fn send_event(&self, event: std::result::Result<ReplicationEvent, PgWireError>) {
385 if self.out.send(event).await.is_err() {
386 tracing::debug!("event channel closed, client may have disconnected");
387 }
388 }
389
390 async fn authenticate<S: AsyncRead + AsyncWrite + Unpin>(
392 &mut self,
393 stream: &mut S,
394 ) -> Result<()> {
395 loop {
396 let msg = read_backend_message(stream).await?;
397 match msg.tag {
398 b'R' => {
399 let (code, rest) = parse_auth_request(&msg.payload)?;
400 self.handle_auth_request(stream, code, rest).await?;
401 }
402 b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
403 b'S' | b'K' => {} b'Z' => return Ok(()), _ => {}
406 }
407 }
408 }
409
410 async fn handle_auth_request<S: AsyncRead + AsyncWrite + Unpin>(
412 &mut self,
413 stream: &mut S,
414 code: i32,
415 data: &[u8],
416 ) -> Result<()> {
417 match code {
418 0 => Ok(()), 3 => {
420 let mut payload = Vec::from(self.cfg.password.as_bytes());
422 payload.push(0);
423 write_password_message(stream, &payload).await
424 }
425 10 => {
426 self.auth_scram(stream, data).await
428 }
429 #[cfg(feature = "md5")]
430 5 => {
431 if data.len() != 4 {
433 return Err(PgWireError::Protocol(
434 "MD5 auth: expected 4-byte salt".into(),
435 ));
436 }
437 let mut salt = [0u8; 4];
438 salt.copy_from_slice(&data[..4]);
439
440 let hash = postgres_md5(&self.cfg.password, &self.cfg.user, &salt);
441 let mut payload = hash.into_bytes();
442 payload.push(0);
443 write_password_message(stream, &payload).await
444 }
445 _ => Err(PgWireError::Auth(format!(
446 "unsupported auth method code: {code}"
447 ))),
448 }
449 }
450
451 async fn auth_scram<S: AsyncRead + AsyncWrite + Unpin>(
453 &mut self,
454 stream: &mut S,
455 mechanisms_data: &[u8],
456 ) -> Result<()> {
457 let mechanisms = parse_sasl_mechanisms(mechanisms_data);
459
460 if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
461 return Err(PgWireError::Auth(format!(
462 "server doesn't offer SCRAM-SHA-256, available: {mechanisms:?}"
463 )));
464 }
465
466 #[cfg(not(feature = "scram"))]
467 return Err(PgWireError::Auth(
468 "SCRAM authentication required but 'scram' feature not enabled".into(),
469 ));
470
471 #[cfg(feature = "scram")]
472 {
473 use crate::auth::scram::ScramClient;
474
475 let scram = ScramClient::new(&self.cfg.user);
476
477 let mut init = Vec::new();
479 init.extend_from_slice(b"SCRAM-SHA-256\0");
480 init.extend_from_slice(&(scram.client_first.len() as i32).to_be_bytes());
481 init.extend_from_slice(scram.client_first.as_bytes());
482 write_password_message(stream, &init).await?;
483
484 let server_first = read_auth_data(stream, 11).await?;
486 let server_first_str = String::from_utf8_lossy(&server_first);
487
488 let (client_final, auth_message, salted_password) =
490 scram.client_final(&self.cfg.password, &server_first_str)?;
491 write_password_message(stream, client_final.as_bytes()).await?;
492
493 let server_final = read_auth_data(stream, 12).await?;
495 let server_final_str = String::from_utf8_lossy(&server_final);
496 ScramClient::verify_server_final(&server_final_str, &salted_password, &auth_message)?;
497
498 Ok(())
499 }
500 }
501
502 async fn send_feedback<S: AsyncWrite + Unpin>(
504 &self,
505 stream: &mut S,
506 applied: Lsn,
507 reply_requested: bool,
508 ) -> Result<()> {
509 let client_time = current_pg_timestamp();
510 let payload = encode_standby_status_update(applied, client_time, reply_requested);
511 write_copy_data(stream, &payload).await
512 }
513}
514
/// Extracts the mechanism names from a NUL-separated SASL mechanism list.
///
/// Names are collected in order until the first empty name (the list
/// terminator). Bytes after the last NUL, which form an unterminated name,
/// are ignored. Each name is decoded lossily as UTF-8.
fn parse_sasl_mechanisms(data: &[u8]) -> Vec<String> {
    let mut pieces = data.split(|&byte| byte == 0);
    // The last piece is whatever follows the final NUL (or the whole input
    // when there is no NUL at all); it is never a terminated name.
    pieces.next_back();

    pieces
        .take_while(|name| !name.is_empty())
        .map(|name| String::from_utf8_lossy(name).into_owned())
        .collect()
}
534
535fn parse_pgoutput_boundary(data: &Bytes) -> Result<Option<ReplicationEvent>> {
536 if data.is_empty() {
537 return Ok(None);
538 }
539
540 let tag = data[0];
541 let mut p = &data[1..];
542
543 fn take_i8(p: &mut &[u8]) -> Result<i8> {
544 if p.is_empty() {
545 return Err(PgWireError::Protocol("pgoutput: truncated i8".into()));
546 }
547 let v = p[0] as i8;
548 *p = &p[1..];
549 Ok(v)
550 }
551
552 fn take_i32(p: &mut &[u8]) -> Result<i32> {
553 if p.len() < 4 {
554 return Err(PgWireError::Protocol("pgoutput: truncated i32".into()));
555 }
556 let (head, tail) = p.split_at(4);
557 *p = tail;
558 Ok(i32::from_be_bytes(head.try_into().unwrap()))
559 }
560
561 fn take_i64(p: &mut &[u8]) -> Result<i64> {
562 if p.len() < 8 {
563 return Err(PgWireError::Protocol("pgoutput: truncated i64".into()));
564 }
565 let (head, tail) = p.split_at(8);
566 *p = tail;
567 Ok(i64::from_be_bytes(head.try_into().unwrap()))
568 }
569
570 match tag {
571 b'B' => {
572 let final_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
573 let commit_time_micros = take_i64(&mut p)?;
574 let xid = take_i32(&mut p)? as u32;
575
576 Ok(Some(ReplicationEvent::Begin {
577 final_lsn,
578 commit_time_micros,
579 xid,
580 }))
581 }
582 b'C' => {
583 let _flags = take_i8(&mut p)?;
584 let lsn = Lsn::from_u64(take_i64(&mut p)? as u64); let end_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
586 let commit_time_micros = take_i64(&mut p)?;
587
588 Ok(Some(ReplicationEvent::Commit {
589 lsn,
590 end_lsn,
591 commit_time_micros,
592 }))
593 }
594 b'M' => {
595 let flags = take_i8(&mut p)?;
598 let transactional = (flags & 1) != 0;
599 let lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
600
601 let prefix_end = p.iter().position(|&b| b == 0).ok_or_else(|| {
603 PgWireError::Protocol("pgoutput Message: missing null terminator for prefix".into())
604 })?;
605 let prefix = String::from_utf8_lossy(&p[..prefix_end]).into_owned();
606 p = &p[prefix_end + 1..]; let content_len = take_i32(&mut p)? as usize;
609 if p.len() < content_len {
610 return Err(PgWireError::Protocol(format!(
611 "pgoutput Message: expected {} content bytes, got {}",
612 content_len,
613 p.len()
614 )));
615 }
616 let content = Bytes::copy_from_slice(&p[..content_len]);
617
618 Ok(Some(ReplicationEvent::Message {
619 transactional,
620 lsn,
621 prefix,
622 content,
623 }))
624 }
625 _ => Ok(None),
626 }
627}
628
629async fn read_auth_data<S: AsyncRead + AsyncWrite + Unpin>(
631 stream: &mut S,
632 expected_code: i32,
633) -> Result<Vec<u8>> {
634 loop {
635 let msg = read_backend_message(stream).await?;
636 match msg.tag {
637 b'R' => {
638 let (code, data) = parse_auth_request(&msg.payload)?;
639 if code == expected_code {
640 return Ok(data.to_vec());
641 }
642 return Err(PgWireError::Auth(format!(
643 "unexpected auth code {code}, expected {expected_code}"
644 )));
645 }
646 b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
647 _ => {} }
649 }
650}
651
652fn current_pg_timestamp() -> i64 {
654 use std::time::{SystemTime, UNIX_EPOCH};
655
656 let now = SystemTime::now()
657 .duration_since(UNIX_EPOCH)
658 .unwrap_or_default();
659
660 let unix_micros = (now.as_secs() as i64) * 1_000_000 + (now.subsec_micros() as i64);
661 unix_micros - PG_EPOCH_MICROS
662}
663
/// Builds the MD5 password response: the literal `md5` followed by
/// `hex(md5(hex(md5(password ++ user)) ++ salt))`.
#[cfg(feature = "md5")]
fn postgres_md5(password: &str, user: &str, salt: &[u8; 4]) -> String {
    // Lower-case hex digest of `bytes`.
    let hex_digest = |bytes: &[u8]| format!("{:x}", md5::compute(bytes));

    // Stage one hashes password and user name; the salt is appended to the
    // hex text of that digest before the second hash.
    let mut salted = hex_digest(format!("{password}{user}").as_bytes()).into_bytes();
    salted.extend_from_slice(salt);

    format!("md5{}", hex_digest(&salted))
}
680
#[cfg(test)]
mod tests {
    use super::*;

    /// A single mechanism followed by the list terminator.
    #[test]
    fn parse_sasl_mechanisms_single() {
        let parsed = parse_sasl_mechanisms(b"SCRAM-SHA-256\0\0");
        assert_eq!(parsed, vec!["SCRAM-SHA-256"]);
    }

    /// Several mechanisms come back in wire order.
    #[test]
    fn parse_sasl_mechanisms_multiple() {
        let parsed = parse_sasl_mechanisms(b"SCRAM-SHA-256\0SCRAM-SHA-256-PLUS\0\0");
        assert_eq!(parsed, vec!["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"]);
    }

    /// A bare terminator yields no mechanisms.
    #[test]
    fn parse_sasl_mechanisms_empty() {
        assert!(parse_sasl_mechanisms(b"\0").is_empty());
    }

    /// Shape check only: `md5` prefix plus 32 hex digits.
    #[test]
    #[cfg(feature = "md5")]
    fn postgres_md5_known_value() {
        let salt = [0x01, 0x02, 0x03, 0x04];
        let hash = postgres_md5("md5_pass", "md5_user", &salt);
        assert!(hash.starts_with("md5"));
        assert_eq!(hash.len(), 35);
    }

    /// The current time must lie after the PostgreSQL epoch.
    #[test]
    fn current_pg_timestamp_is_positive() {
        assert!(current_pg_timestamp() > 0);
    }
}