1use bytes::Bytes;
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::sync::Arc;
4use tokio::io::{AsyncRead, AsyncWrite};
5use tokio::sync::{mpsc, watch};
6use tokio::time::Instant;
7
8use crate::config::ReplicationConfig;
9use crate::error::{PgWireError, Result};
10use crate::lsn::Lsn;
11use crate::protocol::framing::{
12 read_backend_message, write_copy_data, write_copy_done, write_password_message, write_query,
13 write_startup_message,
14};
15use crate::protocol::messages::{parse_auth_request, parse_error_response};
16use crate::protocol::replication::{
17 encode_standby_status_update, parse_copy_data, ReplicationCopyData, PG_EPOCH_MICROS,
18};
19
/// Applied-LSN position shared (via `Arc`) between the replication worker and other tasks.
///
/// The worker reads it with [`SharedProgress::load_applied`] whenever it builds a standby
/// status update. Presumably the event consumer advances it with
/// [`SharedProgress::update_applied`] after applying events — confirm against callers.
pub struct SharedProgress {
    /// Applied LSN stored in its raw `u64` form.
    applied: AtomicU64,
}
27
28impl SharedProgress {
29 pub fn new(start: Lsn) -> Self {
30 Self {
31 applied: AtomicU64::new(start.as_u64()),
32 }
33 }
34
35 #[inline]
36 pub fn load_applied(&self) -> Lsn {
37 Lsn::from_u64(self.applied.load(Ordering::Acquire))
38 }
39
40 #[inline]
43 pub fn update_applied(&self, lsn: Lsn) {
44 let new = lsn.as_u64();
45 let mut cur = self.applied.load(Ordering::Relaxed);
46
47 while new > cur {
48 match self
49 .applied
50 .compare_exchange_weak(cur, new, Ordering::Release, Ordering::Relaxed)
51 {
52 Ok(_) => break,
53 Err(observed) => cur = observed,
54 }
55 }
56 }
57}
58
/// Events produced by the replication worker and delivered on its event channel.
#[derive(Debug, Clone)]
pub enum ReplicationEvent {
    /// A primary keepalive message received from the server.
    KeepAlive {
        /// WAL end position reported by the server.
        wal_end: Lsn,
        /// Whether the server asked for an immediate status update. The worker has
        /// already sent one by the time this event is emitted.
        reply_requested: bool,
        /// Server send time as parsed from the wire (presumably microseconds since the
        /// PostgreSQL epoch, 2000-01-01 — confirm in `parse_copy_data`).
        server_time_micros: i64,
    },

    /// pgoutput Begin message: start of a transaction.
    Begin {
        /// LSN of the transaction's commit record, as sent by pgoutput.
        final_lsn: Lsn,
        /// Transaction id.
        xid: u32,
        /// Commit timestamp as sent by pgoutput (microseconds since 2000-01-01 per the
        /// PostgreSQL logical replication message format).
        commit_time_micros: i64,
    },

    /// An XLogData payload that is not a Begin, Commit or Message, forwarded undecoded.
    XLogData {
        /// WAL start position of this payload.
        wal_start: Lsn,
        /// WAL end position reported alongside this payload.
        wal_end: Lsn,
        /// Server send time as parsed from the wire (presumably microseconds since the
        /// PostgreSQL epoch — confirm in `parse_copy_data`).
        server_time_micros: i64,
        /// Raw pgoutput message bytes (tag byte included).
        data: Bytes,
    },

    /// pgoutput Commit message: end of a transaction.
    Commit {
        /// LSN of the commit record.
        lsn: Lsn,
        /// End LSN of the transaction.
        end_lsn: Lsn,
        /// Commit timestamp as sent by pgoutput.
        commit_time_micros: i64,
    },

    /// pgoutput logical decoding message (requested via `messages 'true'`).
    Message {
        /// Bit 0 of the message flags: `true` for a transactional message.
        transactional: bool,
        /// LSN of the message.
        lsn: Lsn,
        /// Prefix string (decoded lossily as UTF-8).
        prefix: String,
        /// Message content bytes.
        content: Bytes,
    },

    /// Emitted once when the configured `stop_at_lsn` has been reached. The worker then
    /// sends CopyDone and stops streaming.
    StoppedAt {
        /// The LSN that met or passed `stop_at_lsn`.
        reached: Lsn,
    },
}
123
/// Receiving half of the worker's event channel.
///
/// Items are `Result`s. The worker code in this file only sends `Ok` events; session
/// failures are returned from `WorkerState::run_on_stream` instead.
pub type ReplicationEventReceiver =
    mpsc::Receiver<std::result::Result<ReplicationEvent, PgWireError>>;
127
/// State for one replication worker session.
pub struct WorkerState {
    /// Connection, slot, publication, stop and timing settings.
    cfg: ReplicationConfig,
    /// Applied-LSN tracker whose value is reported in standby status updates.
    progress: Arc<SharedProgress>,
    /// Stop signal; the stream loop ends once it observes `true`.
    stop_rx: watch::Receiver<bool>,
    /// Channel on which decoded events are delivered to the consumer.
    out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
}
135
136impl WorkerState {
137 pub fn new(
138 cfg: ReplicationConfig,
139 progress: Arc<SharedProgress>,
140 stop_rx: watch::Receiver<bool>,
141 out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
142 ) -> Self {
143 Self {
144 cfg,
145 progress,
146 stop_rx,
147 out,
148 }
149 }
150
151 pub async fn run_on_stream<S: AsyncRead + AsyncWrite + Unpin>(
153 &mut self,
154 stream: &mut S,
155 ) -> Result<()> {
156 self.startup(stream).await?;
157 self.authenticate(stream).await?;
158 self.start_replication(stream).await?;
159 self.stream_loop(stream).await
160 }
161
162 async fn startup<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<()> {
164 let params = [
165 ("user", self.cfg.user.as_str()),
166 ("database", self.cfg.database.as_str()),
167 ("replication", "database"),
168 ("client_encoding", "UTF8"),
169 ("application_name", "pgwire-replication"),
170 ];
171 write_startup_message(stream, 196608, ¶ms).await
172 }
173
174 async fn start_replication<S: AsyncRead + AsyncWrite + Unpin>(
176 &self,
177 stream: &mut S,
178 ) -> Result<()> {
179 let publication = self.cfg.publication.replace('\'', "''");
181 let sql = format!(
182 "START_REPLICATION SLOT {} LOGICAL {} \
183 (proto_version '1', publication_names '{}', messages 'true')",
184 self.cfg.slot, self.cfg.start_lsn, publication,
185 );
186 write_query(stream, &sql).await?;
187
188 loop {
190 let msg = read_backend_message(stream).await?;
191 match msg.tag {
192 b'W' => return Ok(()), b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
194 b'N' | b'S' | b'K' => continue, _ => continue,
196 }
197 }
198 }
199
200 async fn stream_loop<S: AsyncRead + AsyncWrite + Unpin>(
202 &mut self,
203 stream: &mut S,
204 ) -> Result<()> {
205 let mut last_status_sent = Instant::now() - self.cfg.status_interval;
206 let mut last_applied = self.progress.load_applied();
207
208 loop {
209 let current_applied = self.progress.load_applied();
211 if current_applied != last_applied {
212 last_applied = current_applied;
213 }
214
215 if last_status_sent.elapsed() >= self.cfg.status_interval {
217 self.send_feedback(stream, last_applied, false).await?;
218 last_status_sent = Instant::now();
219 }
220
221 let msg = tokio::select! {
225 biased; _ = self.stop_rx.changed() => {
228 if *self.stop_rx.borrow() {
229 let _ = write_copy_done(stream).await;
230 return Ok(());
231 }
232 continue;
234 }
235
236 msg_result = tokio::time::timeout(
237 self.cfg.idle_wakeup_interval,
238 read_backend_message(stream),
239 ) => {
240 match msg_result {
241 Ok(res) => res?, Err(_) => {
243 let applied = self.progress.load_applied();
245 last_applied = applied;
246 self.send_feedback(stream, applied, false).await?;
247 last_status_sent = Instant::now();
248 continue;
249 }
250 }
251 }
252 };
253
254 match msg.tag {
255 b'd' => {
256 let should_stop = self
257 .handle_copy_data(
258 stream,
259 msg.payload,
260 &mut last_applied,
261 &mut last_status_sent,
262 )
263 .await?;
264 if should_stop {
265 return Ok(());
266 }
267 }
268 b'E' => {
269 let err = PgWireError::Server(parse_error_response(&msg.payload));
270 return Err(err);
271 }
272 _ => {
273 }
275 }
276 }
277 }
278
279 async fn handle_copy_data<S: AsyncRead + AsyncWrite + Unpin>(
281 &mut self,
282 stream: &mut S,
283 payload: Bytes,
284 last_applied: &mut Lsn,
285 last_status_sent: &mut Instant,
286 ) -> Result<bool> {
287 let cd = parse_copy_data(payload)?;
288
289 match cd {
290 ReplicationCopyData::KeepAlive {
291 wal_end,
292 server_time_micros,
293 reply_requested,
294 } => {
295 if reply_requested {
297 let applied = self.progress.load_applied();
298 *last_applied = applied;
299 self.send_feedback(stream, applied, true).await?;
300 *last_status_sent = Instant::now();
301 }
302
303 self.send_event(Ok(ReplicationEvent::KeepAlive {
304 wal_end,
305 reply_requested,
306 server_time_micros,
307 }))
308 .await;
309
310 Ok(false)
311 }
312 ReplicationCopyData::XLogData {
313 wal_start,
314 wal_end,
315 server_time_micros,
316 data,
317 } => {
318 if let Some(boundary_ev) = parse_pgoutput_boundary(&data)? {
320 let reached_lsn = match boundary_ev {
321 ReplicationEvent::Begin { final_lsn, .. } => final_lsn,
322 ReplicationEvent::Commit { end_lsn, .. } => end_lsn,
323 _ => wal_end, };
325
326 self.send_event(Ok(boundary_ev)).await;
327
328 if let Some(stop_lsn) = self.cfg.stop_at_lsn {
330 if reached_lsn >= stop_lsn {
331 self.send_event(Ok(ReplicationEvent::StoppedAt {
332 reached: reached_lsn,
333 }))
334 .await;
335 let _ = write_copy_done(stream).await;
336 return Ok(true); }
338 }
339
340 return Ok(false);
341 }
342 if let Some(stop_lsn) = self.cfg.stop_at_lsn {
345 if wal_end >= stop_lsn {
346 self.send_event(Ok(ReplicationEvent::XLogData {
348 wal_start,
349 wal_end,
350 server_time_micros,
351 data,
352 }))
353 .await;
354
355 self.send_event(Ok(ReplicationEvent::StoppedAt { reached: wal_end }))
356 .await;
357
358 let _ = write_copy_done(stream).await;
359 return Ok(true);
360 }
361 }
362
363 self.send_event(Ok(ReplicationEvent::XLogData {
364 wal_start,
365 wal_end,
366 server_time_micros,
367 data,
368 }))
369 .await;
370
371 Ok(false)
372 }
373 }
374 }
375
376 async fn send_event(&self, event: std::result::Result<ReplicationEvent, PgWireError>) {
381 if self.out.send(event).await.is_err() {
382 tracing::debug!("event channel closed, client may have disconnected");
383 }
384 }
385
386 async fn authenticate<S: AsyncRead + AsyncWrite + Unpin>(
388 &mut self,
389 stream: &mut S,
390 ) -> Result<()> {
391 loop {
392 let msg = read_backend_message(stream).await?;
393 match msg.tag {
394 b'R' => {
395 let (code, rest) = parse_auth_request(&msg.payload)?;
396 self.handle_auth_request(stream, code, rest).await?;
397 }
398 b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
399 b'S' | b'K' => {} b'Z' => return Ok(()), _ => {}
402 }
403 }
404 }
405
406 async fn handle_auth_request<S: AsyncRead + AsyncWrite + Unpin>(
408 &mut self,
409 stream: &mut S,
410 code: i32,
411 data: &[u8],
412 ) -> Result<()> {
413 match code {
414 0 => Ok(()), 3 => {
416 let mut payload = Vec::from(self.cfg.password.as_bytes());
418 payload.push(0);
419 write_password_message(stream, &payload).await
420 }
421 10 => {
422 self.auth_scram(stream, data).await
424 }
425 #[cfg(feature = "md5")]
426 5 => {
427 if data.len() != 4 {
429 return Err(PgWireError::Protocol(
430 "MD5 auth: expected 4-byte salt".into(),
431 ));
432 }
433 let mut salt = [0u8; 4];
434 salt.copy_from_slice(&data[..4]);
435
436 let hash = postgres_md5(&self.cfg.password, &self.cfg.user, &salt);
437 let mut payload = hash.into_bytes();
438 payload.push(0);
439 write_password_message(stream, &payload).await
440 }
441 _ => Err(PgWireError::Auth(format!(
442 "unsupported auth method code: {code}"
443 ))),
444 }
445 }
446
447 async fn auth_scram<S: AsyncRead + AsyncWrite + Unpin>(
449 &mut self,
450 stream: &mut S,
451 mechanisms_data: &[u8],
452 ) -> Result<()> {
453 let mechanisms = parse_sasl_mechanisms(mechanisms_data);
455
456 if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
457 return Err(PgWireError::Auth(format!(
458 "server doesn't offer SCRAM-SHA-256, available: {mechanisms:?}"
459 )));
460 }
461
462 #[cfg(not(feature = "scram"))]
463 return Err(PgWireError::Auth(
464 "SCRAM authentication required but 'scram' feature not enabled".into(),
465 ));
466
467 #[cfg(feature = "scram")]
468 {
469 use crate::auth::scram::ScramClient;
470
471 let scram = ScramClient::new(&self.cfg.user);
472
473 let mut init = Vec::new();
475 init.extend_from_slice(b"SCRAM-SHA-256\0");
476 init.extend_from_slice(&(scram.client_first.len() as i32).to_be_bytes());
477 init.extend_from_slice(scram.client_first.as_bytes());
478 write_password_message(stream, &init).await?;
479
480 let server_first = read_auth_data(stream, 11).await?;
482 let server_first_str = String::from_utf8_lossy(&server_first);
483
484 let (client_final, auth_message, salted_password) =
486 scram.client_final(&self.cfg.password, &server_first_str)?;
487 write_password_message(stream, client_final.as_bytes()).await?;
488
489 let server_final = read_auth_data(stream, 12).await?;
491 let server_final_str = String::from_utf8_lossy(&server_final);
492 ScramClient::verify_server_final(&server_final_str, &salted_password, &auth_message)?;
493
494 Ok(())
495 }
496 }
497
498 async fn send_feedback<S: AsyncWrite + Unpin>(
500 &self,
501 stream: &mut S,
502 applied: Lsn,
503 reply_requested: bool,
504 ) -> Result<()> {
505 let client_time = current_pg_timestamp();
506 let payload = encode_standby_status_update(applied, client_time, reply_requested);
507 write_copy_data(stream, &payload).await
508 }
509}
510
/// Splits a SASL mechanism list into its names.
///
/// The list is a sequence of NUL-terminated names ended by an empty name (a lone NUL).
/// Parsing stops at that empty name or at trailing bytes that lack a terminator; such
/// trailing bytes are ignored. Names are decoded lossily as UTF-8.
fn parse_sasl_mechanisms(data: &[u8]) -> Vec<String> {
    let mut names = Vec::new();
    let mut rest = data;

    while let Some(nul_at) = rest.iter().position(|&byte| byte == 0) {
        if nul_at == 0 {
            // Empty name: end of the list.
            break;
        }
        let (name, tail) = rest.split_at(nul_at);
        names.push(String::from_utf8_lossy(name).into_owned());
        rest = &tail[1..];
    }

    names
}
530
531fn parse_pgoutput_boundary(data: &Bytes) -> Result<Option<ReplicationEvent>> {
532 if data.is_empty() {
533 return Ok(None);
534 }
535
536 let tag = data[0];
537 let mut p = &data[1..];
538
539 fn take_i8(p: &mut &[u8]) -> Result<i8> {
540 if p.is_empty() {
541 return Err(PgWireError::Protocol("pgoutput: truncated i8".into()));
542 }
543 let v = p[0] as i8;
544 *p = &p[1..];
545 Ok(v)
546 }
547
548 fn take_i32(p: &mut &[u8]) -> Result<i32> {
549 if p.len() < 4 {
550 return Err(PgWireError::Protocol("pgoutput: truncated i32".into()));
551 }
552 let (head, tail) = p.split_at(4);
553 *p = tail;
554 Ok(i32::from_be_bytes(head.try_into().unwrap()))
555 }
556
557 fn take_i64(p: &mut &[u8]) -> Result<i64> {
558 if p.len() < 8 {
559 return Err(PgWireError::Protocol("pgoutput: truncated i64".into()));
560 }
561 let (head, tail) = p.split_at(8);
562 *p = tail;
563 Ok(i64::from_be_bytes(head.try_into().unwrap()))
564 }
565
566 match tag {
567 b'B' => {
568 let final_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
569 let commit_time_micros = take_i64(&mut p)?;
570 let xid = take_i32(&mut p)? as u32;
571
572 Ok(Some(ReplicationEvent::Begin {
573 final_lsn,
574 commit_time_micros,
575 xid,
576 }))
577 }
578 b'C' => {
579 let _flags = take_i8(&mut p)?;
580 let lsn = Lsn::from_u64(take_i64(&mut p)? as u64); let end_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
582 let commit_time_micros = take_i64(&mut p)?;
583
584 Ok(Some(ReplicationEvent::Commit {
585 lsn,
586 end_lsn,
587 commit_time_micros,
588 }))
589 }
590 b'M' => {
591 let flags = take_i8(&mut p)?;
594 let transactional = (flags & 1) != 0;
595 let lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
596
597 let prefix_end = p.iter().position(|&b| b == 0).ok_or_else(|| {
599 PgWireError::Protocol("pgoutput Message: missing null terminator for prefix".into())
600 })?;
601 let prefix = String::from_utf8_lossy(&p[..prefix_end]).into_owned();
602 p = &p[prefix_end + 1..]; let content_len = take_i32(&mut p)? as usize;
605 if p.len() < content_len {
606 return Err(PgWireError::Protocol(format!(
607 "pgoutput Message: expected {} content bytes, got {}",
608 content_len,
609 p.len()
610 )));
611 }
612 let content = Bytes::copy_from_slice(&p[..content_len]);
613
614 Ok(Some(ReplicationEvent::Message {
615 transactional,
616 lsn,
617 prefix,
618 content,
619 }))
620 }
621 _ => Ok(None),
622 }
623}
624
/// Reads backend messages until the next AuthenticationRequest and returns its body.
///
/// The request must carry `expected_code` (11 = SASLContinue, 12 = SASLFinal in the SCRAM
/// exchange). Other non-error messages are skipped.
///
/// # Errors
/// * `PgWireError::Auth` if the request carries a different auth code.
/// * `PgWireError::Server` if the server sends an ErrorResponse.
/// * Any error from reading or parsing the backend message.
// Only the SCRAM exchange calls this. Gating it like `postgres_md5` keeps builds without the
// `scram` feature free of dead-code warnings.
#[cfg(feature = "scram")]
async fn read_auth_data<S: AsyncRead + AsyncWrite + Unpin>(
    stream: &mut S,
    expected_code: i32,
) -> Result<Vec<u8>> {
    loop {
        let msg = read_backend_message(stream).await?;
        match msg.tag {
            b'R' => {
                let (code, data) = parse_auth_request(&msg.payload)?;
                if code == expected_code {
                    return Ok(data.to_vec());
                }
                return Err(PgWireError::Auth(format!(
                    "unexpected auth code {code}, expected {expected_code}"
                )));
            }
            b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
            _ => {}
        }
    }
}
647
648fn current_pg_timestamp() -> i64 {
650 use std::time::{SystemTime, UNIX_EPOCH};
651
652 let now = SystemTime::now()
653 .duration_since(UNIX_EPOCH)
654 .unwrap_or_default();
655
656 let unix_micros = (now.as_secs() as i64) * 1_000_000 + (now.subsec_micros() as i64);
657 unix_micros - PG_EPOCH_MICROS
658}
659
/// Builds PostgreSQL's MD5 password response for `salt`.
///
/// The response is `"md5" + hex(md5(hex(md5(password + user)) + salt))`, using lowercase hex.
#[cfg(feature = "md5")]
fn postgres_md5(password: &str, user: &str, salt: &[u8; 4]) -> String {
    let to_hex = |bytes: &[u8]| -> String { format!("{:x}", md5::compute(bytes)) };

    let mut credentials = String::with_capacity(password.len() + user.len());
    credentials.push_str(password);
    credentials.push_str(user);

    let mut salted = to_hex(credentials.as_bytes()).into_bytes();
    salted.extend_from_slice(salt);

    let mut response = String::from("md5");
    response.push_str(&to_hex(&salted));
    response
}
676
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_sasl_mechanisms_single() {
        let data = b"SCRAM-SHA-256\0\0";
        let mechs = parse_sasl_mechanisms(data);
        assert_eq!(mechs, vec!["SCRAM-SHA-256"]);
    }

    #[test]
    fn parse_sasl_mechanisms_multiple() {
        let data = b"SCRAM-SHA-256\0SCRAM-SHA-256-PLUS\0\0";
        let mechs = parse_sasl_mechanisms(data);
        assert_eq!(mechs, vec!["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"]);
    }

    #[test]
    fn parse_sasl_mechanisms_empty() {
        let mechs = parse_sasl_mechanisms(b"\0");
        assert!(mechs.is_empty());
    }

    #[test]
    fn parse_sasl_mechanisms_ignores_unterminated_tail() {
        let mechs = parse_sasl_mechanisms(b"SCRAM-SHA-256\0SCRAM");
        assert_eq!(mechs, vec!["SCRAM-SHA-256"]);
    }

    #[test]
    #[cfg(feature = "md5")]
    fn postgres_md5_known_value() {
        let hash = postgres_md5("md5_pass", "md5_user", &[0x01, 0x02, 0x03, 0x04]);
        // "md5" followed by 32 lowercase hex digits.
        assert!(hash.starts_with("md5"));
        assert_eq!(hash.len(), 35);
        assert!(hash[3..]
            .bytes()
            .all(|b| b.is_ascii_digit() || (b'a'..=b'f').contains(&b)));
    }

    #[test]
    fn current_pg_timestamp_is_positive() {
        let ts = current_pg_timestamp();
        // Any current date is after the PostgreSQL epoch (2000-01-01).
        assert!(ts > 0);
    }

    #[test]
    fn shared_progress_never_moves_backwards() {
        let progress = SharedProgress::new(Lsn::from_u64(100));
        assert_eq!(progress.load_applied(), Lsn::from_u64(100));

        progress.update_applied(Lsn::from_u64(200));
        assert_eq!(progress.load_applied(), Lsn::from_u64(200));

        progress.update_applied(Lsn::from_u64(150));
        assert_eq!(progress.load_applied(), Lsn::from_u64(200));
    }

    #[test]
    fn pgoutput_begin_is_decoded() {
        let mut buf = vec![b'B'];
        buf.extend_from_slice(&0x1_0000_0028_i64.to_be_bytes());
        buf.extend_from_slice(&1_234_567_i64.to_be_bytes());
        buf.extend_from_slice(&42_i32.to_be_bytes());

        match parse_pgoutput_boundary(&Bytes::from(buf)) {
            Ok(Some(ReplicationEvent::Begin {
                final_lsn,
                xid,
                commit_time_micros,
            })) => {
                assert_eq!(final_lsn, Lsn::from_u64(0x1_0000_0028));
                assert_eq!(xid, 42);
                assert_eq!(commit_time_micros, 1_234_567);
            }
            _ => panic!("expected a Begin event"),
        }
    }

    #[test]
    fn pgoutput_commit_is_decoded() {
        let mut buf = vec![b'C', 0];
        buf.extend_from_slice(&0x10_i64.to_be_bytes());
        buf.extend_from_slice(&0x20_i64.to_be_bytes());
        buf.extend_from_slice(&99_i64.to_be_bytes());

        match parse_pgoutput_boundary(&Bytes::from(buf)) {
            Ok(Some(ReplicationEvent::Commit {
                lsn,
                end_lsn,
                commit_time_micros,
            })) => {
                assert_eq!(lsn, Lsn::from_u64(0x10));
                assert_eq!(end_lsn, Lsn::from_u64(0x20));
                assert_eq!(commit_time_micros, 99);
            }
            _ => panic!("expected a Commit event"),
        }
    }

    #[test]
    fn pgoutput_message_is_decoded() {
        let mut buf = vec![b'M', 1];
        buf.extend_from_slice(&0x30_i64.to_be_bytes());
        buf.extend_from_slice(b"app\0");
        buf.extend_from_slice(&3_i32.to_be_bytes());
        buf.extend_from_slice(b"xyz");

        match parse_pgoutput_boundary(&Bytes::from(buf)) {
            Ok(Some(ReplicationEvent::Message {
                transactional,
                lsn,
                prefix,
                content,
            })) => {
                assert!(transactional);
                assert_eq!(lsn, Lsn::from_u64(0x30));
                assert_eq!(prefix, "app");
                assert_eq!(&content[..], &b"xyz"[..]);
            }
            _ => panic!("expected a Message event"),
        }
    }

    #[test]
    fn pgoutput_malformed_messages_are_rejected() {
        // Begin truncated inside its first Int64.
        assert!(parse_pgoutput_boundary(&Bytes::from_static(b"B\0\0\0")).is_err());

        // Message prefix without a NUL terminator.
        let mut no_nul = vec![b'M', 0];
        no_nul.extend_from_slice(&0_i64.to_be_bytes());
        no_nul.extend_from_slice(b"prefix");
        assert!(parse_pgoutput_boundary(&Bytes::from(no_nul)).is_err());

        // Declared content length longer than the payload.
        let mut too_long = vec![b'M', 0];
        too_long.extend_from_slice(&0_i64.to_be_bytes());
        too_long.extend_from_slice(b"p\0");
        too_long.extend_from_slice(&10_i32.to_be_bytes());
        too_long.extend_from_slice(b"short");
        assert!(parse_pgoutput_boundary(&Bytes::from(too_long)).is_err());

        // Negative content length.
        let mut negative = vec![b'M', 0];
        negative.extend_from_slice(&0_i64.to_be_bytes());
        negative.extend_from_slice(b"p\0");
        negative.extend_from_slice(&(-1_i32).to_be_bytes());
        assert!(parse_pgoutput_boundary(&Bytes::from(negative)).is_err());
    }

    #[test]
    fn pgoutput_non_boundary_payloads_yield_none() {
        assert!(matches!(parse_pgoutput_boundary(&Bytes::new()), Ok(None)));
        assert!(matches!(
            parse_pgoutput_boundary(&Bytes::from_static(b"I\0\0\0\x01")),
            Ok(None)
        ));
    }
}