1use bytes::Bytes;
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::sync::Arc;
4use tokio::io::{AsyncRead, AsyncWrite};
5use tokio::sync::{mpsc, watch};
6use tokio::time::Instant;
7
8use crate::config::ReplicationConfig;
9use crate::error::{PgWireError, Result};
10use crate::lsn::Lsn;
11use crate::protocol::framing::{
12 read_backend_message, write_copy_data, write_copy_done, write_password_message, write_query,
13 write_startup_message,
14};
15use crate::protocol::messages::{parse_auth_request, parse_error_response};
16use crate::protocol::replication::{
17 encode_standby_status_update, parse_copy_data, ReplicationCopyData, PG_EPOCH_MICROS,
18};
19
20pub struct SharedProgress {
25 applied: AtomicU64,
26}
27
28impl SharedProgress {
29 pub fn new(start: Lsn) -> Self {
30 Self {
31 applied: AtomicU64::new(start.as_u64()),
32 }
33 }
34
35 #[inline]
36 pub fn load_applied(&self) -> Lsn {
37 Lsn::from_u64(self.applied.load(Ordering::Acquire))
38 }
39
40 #[inline]
43 pub fn update_applied(&self, lsn: Lsn) {
44 let new = lsn.as_u64();
45 let mut cur = self.applied.load(Ordering::Relaxed);
46
47 while new > cur {
48 match self
49 .applied
50 .compare_exchange_weak(cur, new, Ordering::Release, Ordering::Relaxed)
51 {
52 Ok(_) => break,
53 Err(observed) => cur = observed,
54 }
55 }
56 }
57}
58
/// Events emitted by the replication worker on its output channel.
///
/// Produced from the server's CopyData messages: keepalives, pgoutput
/// transaction boundaries (`Begin`/`Commit`), other WAL data, and a final
/// `StoppedAt` once the configured `stop_at_lsn` has been reached.
#[derive(Debug, Clone)]
pub enum ReplicationEvent {
    /// A primary keepalive message from the server.
    ///
    /// When `reply_requested` is set, the worker has already sent a standby
    /// status update before emitting this event.
    KeepAlive {
        /// WAL end position reported in the keepalive.
        wal_end: Lsn,
        /// Whether the server asked the client to reply.
        reply_requested: bool,
        /// Server clock in microseconds (presumably since the PostgreSQL
        /// epoch, 2000-01-01 — TODO confirm against the parser).
        server_time_micros: i64,
    },

    /// A pgoutput `Begin` message (tag `'B'`) opening a transaction.
    Begin {
        /// Final LSN of the transaction, as carried by the Begin message.
        final_lsn: Lsn,
        /// Transaction id.
        xid: u32,
        /// Commit timestamp in microseconds (epoch presumably PostgreSQL's — TODO confirm).
        commit_time_micros: i64,
    },

    /// A WAL data message whose payload is not a pgoutput Begin or Commit.
    XLogData {
        /// Start position from the XLogData header.
        wal_start: Lsn,
        /// End position from the XLogData header.
        wal_end: Lsn,
        /// Server clock in microseconds, from the XLogData header.
        server_time_micros: i64,
        /// Raw, undecoded pgoutput message bytes.
        data: Bytes,
    },

    /// A pgoutput `Commit` message (tag `'C'`) closing a transaction.
    Commit {
        /// First LSN field of the Commit message (the commit LSN).
        lsn: Lsn,
        /// Second LSN field of the Commit message (the transaction end LSN).
        end_lsn: Lsn,
        /// Commit timestamp in microseconds.
        commit_time_micros: i64,
    },

    /// Emitted once when `stop_at_lsn` is reached; the worker then sends
    /// CopyDone and stops streaming.
    StoppedAt {
        /// LSN that satisfied the stop condition: `final_lsn` of a Begin,
        /// `end_lsn` of a Commit, or `wal_end` of an XLogData message.
        reached: Lsn,
    },
}
107
/// Receiving half of the worker's output channel.
///
/// Each item is either a [`ReplicationEvent`] or a [`PgWireError`]. The worker
/// code in this file only sends `Ok` events; errors presumably come from
/// whatever drives the worker (TODO confirm).
pub type ReplicationEventReceiver =
    mpsc::Receiver<std::result::Result<ReplicationEvent, PgWireError>>;
111
112pub struct WorkerState {
114 cfg: ReplicationConfig,
115 progress: Arc<SharedProgress>,
116 stop_rx: watch::Receiver<bool>,
117 out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
118}
119
120impl WorkerState {
121 pub fn new(
122 cfg: ReplicationConfig,
123 progress: Arc<SharedProgress>,
124 stop_rx: watch::Receiver<bool>,
125 out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
126 ) -> Self {
127 Self {
128 cfg,
129 progress,
130 stop_rx,
131 out,
132 }
133 }
134
135 pub async fn run_on_stream<S: AsyncRead + AsyncWrite + Unpin>(
137 &mut self,
138 stream: &mut S,
139 ) -> Result<()> {
140 self.startup(stream).await?;
141 self.authenticate(stream).await?;
142 self.start_replication(stream).await?;
143 self.stream_loop(stream).await
144 }
145
146 async fn startup<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<()> {
148 let params = [
149 ("user", self.cfg.user.as_str()),
150 ("database", self.cfg.database.as_str()),
151 ("replication", "database"),
152 ("client_encoding", "UTF8"),
153 ("application_name", "pgwire-replication"),
154 ];
155 write_startup_message(stream, 196608, ¶ms).await
156 }
157
158 async fn start_replication<S: AsyncRead + AsyncWrite + Unpin>(
160 &self,
161 stream: &mut S,
162 ) -> Result<()> {
163 let publication = self.cfg.publication.replace('\'', "''");
165 let sql = format!(
166 "START_REPLICATION SLOT {} LOGICAL {} (proto_version '1', publication_names '{}')",
167 self.cfg.slot, self.cfg.start_lsn, publication,
168 );
169 write_query(stream, &sql).await?;
170
171 loop {
173 let msg = read_backend_message(stream).await?;
174 match msg.tag {
175 b'W' => return Ok(()), b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
177 b'N' | b'S' | b'K' => continue, _ => continue,
179 }
180 }
181 }
182
183 async fn stream_loop<S: AsyncRead + AsyncWrite + Unpin>(
185 &mut self,
186 stream: &mut S,
187 ) -> Result<()> {
188 let mut last_status_sent = Instant::now() - self.cfg.status_interval;
189 let mut last_applied = self.progress.load_applied();
190
191 loop {
192 let current_applied = self.progress.load_applied();
194 if current_applied != last_applied {
195 last_applied = current_applied;
196 }
197
198 if last_status_sent.elapsed() >= self.cfg.status_interval {
200 self.send_feedback(stream, last_applied, false).await?;
201 last_status_sent = Instant::now();
202 }
203
204 let msg = tokio::select! {
208 biased; _ = self.stop_rx.changed() => {
211 if *self.stop_rx.borrow() {
212 let _ = write_copy_done(stream).await;
213 return Ok(());
214 }
215 continue;
217 }
218
219 msg_result = tokio::time::timeout(
220 self.cfg.idle_wakeup_interval,
221 read_backend_message(stream),
222 ) => {
223 match msg_result {
224 Ok(res) => res?, Err(_) => {
226 let applied = self.progress.load_applied();
228 last_applied = applied;
229 self.send_feedback(stream, applied, false).await?;
230 last_status_sent = Instant::now();
231 continue;
232 }
233 }
234 }
235 };
236
237 match msg.tag {
238 b'd' => {
239 let should_stop = self
240 .handle_copy_data(
241 stream,
242 msg.payload,
243 &mut last_applied,
244 &mut last_status_sent,
245 )
246 .await?;
247 if should_stop {
248 return Ok(());
249 }
250 }
251 b'E' => {
252 let err = PgWireError::Server(parse_error_response(&msg.payload));
253 return Err(err);
254 }
255 _ => {
256 }
258 }
259 }
260 }
261
262 async fn handle_copy_data<S: AsyncRead + AsyncWrite + Unpin>(
264 &mut self,
265 stream: &mut S,
266 payload: Bytes,
267 last_applied: &mut Lsn,
268 last_status_sent: &mut Instant,
269 ) -> Result<bool> {
270 let cd = parse_copy_data(payload)?;
271
272 match cd {
273 ReplicationCopyData::KeepAlive {
274 wal_end,
275 server_time_micros,
276 reply_requested,
277 } => {
278 if reply_requested {
280 let applied = self.progress.load_applied();
281 *last_applied = applied;
282 self.send_feedback(stream, applied, true).await?;
283 *last_status_sent = Instant::now();
284 }
285
286 self.send_event(Ok(ReplicationEvent::KeepAlive {
287 wal_end,
288 reply_requested,
289 server_time_micros,
290 }))
291 .await;
292
293 Ok(false)
294 }
295 ReplicationCopyData::XLogData {
296 wal_start,
297 wal_end,
298 server_time_micros,
299 data,
300 } => {
301 if let Some(boundary_ev) = parse_pgoutput_boundary(&data)? {
303 let reached_lsn = match boundary_ev {
304 ReplicationEvent::Begin { final_lsn, .. } => final_lsn,
305 ReplicationEvent::Commit { end_lsn, .. } => end_lsn,
306 _ => wal_end, };
308
309 self.send_event(Ok(boundary_ev)).await;
310
311 if let Some(stop_lsn) = self.cfg.stop_at_lsn {
313 if reached_lsn >= stop_lsn {
314 self.send_event(Ok(ReplicationEvent::StoppedAt {
315 reached: reached_lsn,
316 }))
317 .await;
318 let _ = write_copy_done(stream).await;
319 return Ok(true); }
321 }
322
323 return Ok(false);
324 }
325 if let Some(stop_lsn) = self.cfg.stop_at_lsn {
328 if wal_end >= stop_lsn {
329 self.send_event(Ok(ReplicationEvent::XLogData {
331 wal_start,
332 wal_end,
333 server_time_micros,
334 data,
335 }))
336 .await;
337
338 self.send_event(Ok(ReplicationEvent::StoppedAt { reached: wal_end }))
339 .await;
340
341 let _ = write_copy_done(stream).await;
342 return Ok(true);
343 }
344 }
345
346 self.send_event(Ok(ReplicationEvent::XLogData {
347 wal_start,
348 wal_end,
349 server_time_micros,
350 data,
351 }))
352 .await;
353
354 Ok(false)
355 }
356 }
357 }
358
359 async fn send_event(&self, event: std::result::Result<ReplicationEvent, PgWireError>) {
364 if self.out.send(event).await.is_err() {
365 tracing::debug!("event channel closed, client may have disconnected");
366 }
367 }
368
369 async fn authenticate<S: AsyncRead + AsyncWrite + Unpin>(
371 &mut self,
372 stream: &mut S,
373 ) -> Result<()> {
374 loop {
375 let msg = read_backend_message(stream).await?;
376 match msg.tag {
377 b'R' => {
378 let (code, rest) = parse_auth_request(&msg.payload)?;
379 self.handle_auth_request(stream, code, rest).await?;
380 }
381 b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
382 b'S' | b'K' => {} b'Z' => return Ok(()), _ => {}
385 }
386 }
387 }
388
389 async fn handle_auth_request<S: AsyncRead + AsyncWrite + Unpin>(
391 &mut self,
392 stream: &mut S,
393 code: i32,
394 data: &[u8],
395 ) -> Result<()> {
396 match code {
397 0 => Ok(()), 3 => {
399 let mut payload = Vec::from(self.cfg.password.as_bytes());
401 payload.push(0);
402 write_password_message(stream, &payload).await
403 }
404 10 => {
405 self.auth_scram(stream, data).await
407 }
408 #[cfg(feature = "md5")]
409 5 => {
410 if data.len() != 4 {
412 return Err(PgWireError::Protocol(
413 "MD5 auth: expected 4-byte salt".into(),
414 ));
415 }
416 let mut salt = [0u8; 4];
417 salt.copy_from_slice(&data[..4]);
418
419 let hash = postgres_md5(&self.cfg.password, &self.cfg.user, &salt);
420 let mut payload = hash.into_bytes();
421 payload.push(0);
422 write_password_message(stream, &payload).await
423 }
424 _ => Err(PgWireError::Auth(format!(
425 "unsupported auth method code: {code}"
426 ))),
427 }
428 }
429
430 async fn auth_scram<S: AsyncRead + AsyncWrite + Unpin>(
432 &mut self,
433 stream: &mut S,
434 mechanisms_data: &[u8],
435 ) -> Result<()> {
436 let mechanisms = parse_sasl_mechanisms(mechanisms_data);
438
439 if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
440 return Err(PgWireError::Auth(format!(
441 "server doesn't offer SCRAM-SHA-256, available: {mechanisms:?}"
442 )));
443 }
444
445 #[cfg(not(feature = "scram"))]
446 return Err(PgWireError::Auth(
447 "SCRAM authentication required but 'scram' feature not enabled".into(),
448 ));
449
450 #[cfg(feature = "scram")]
451 {
452 use crate::auth::scram::ScramClient;
453
454 let scram = ScramClient::new(&self.cfg.user);
455
456 let mut init = Vec::new();
458 init.extend_from_slice(b"SCRAM-SHA-256\0");
459 init.extend_from_slice(&(scram.client_first.len() as i32).to_be_bytes());
460 init.extend_from_slice(scram.client_first.as_bytes());
461 write_password_message(stream, &init).await?;
462
463 let server_first = read_auth_data(stream, 11).await?;
465 let server_first_str = String::from_utf8_lossy(&server_first);
466
467 let (client_final, auth_message, salted_password) =
469 scram.client_final(&self.cfg.password, &server_first_str)?;
470 write_password_message(stream, client_final.as_bytes()).await?;
471
472 let server_final = read_auth_data(stream, 12).await?;
474 let server_final_str = String::from_utf8_lossy(&server_final);
475 ScramClient::verify_server_final(&server_final_str, &salted_password, &auth_message)?;
476
477 Ok(())
478 }
479 }
480
481 async fn send_feedback<S: AsyncWrite + Unpin>(
483 &self,
484 stream: &mut S,
485 applied: Lsn,
486 reply_requested: bool,
487 ) -> Result<()> {
488 let client_time = current_pg_timestamp();
489 let payload = encode_standby_status_update(applied, client_time, reply_requested);
490 write_copy_data(stream, &payload).await
491 }
492}
493
/// Parses the SASL mechanism list from an AuthenticationSASL request.
///
/// The list is a sequence of NUL-terminated names ended by an empty name.
/// Parsing stops at that empty name or at any trailing bytes that lack a
/// terminating NUL. Names are decoded lossily as UTF-8.
fn parse_sasl_mechanisms(data: &[u8]) -> Vec<String> {
    data.split_inclusive(|&byte| byte == 0)
        // Only NUL-terminated chunks are names; an unterminated tail ends the scan.
        .map_while(|chunk| chunk.strip_suffix(b"\0"))
        // An empty name marks the end of the list.
        .take_while(|name| !name.is_empty())
        .map(|name| String::from_utf8_lossy(name).into_owned())
        .collect()
}
513
514fn parse_pgoutput_boundary(data: &Bytes) -> Result<Option<ReplicationEvent>> {
515 if data.is_empty() {
516 return Ok(None);
517 }
518
519 let tag = data[0];
520 let mut p = &data[1..];
521
522 fn take_i8(p: &mut &[u8]) -> Result<i8> {
523 if p.is_empty() {
524 return Err(PgWireError::Protocol("pgoutput: truncated i8".into()));
525 }
526 let v = p[0] as i8;
527 *p = &p[1..];
528 Ok(v)
529 }
530
531 fn take_i32(p: &mut &[u8]) -> Result<i32> {
532 if p.len() < 4 {
533 return Err(PgWireError::Protocol("pgoutput: truncated i32".into()));
534 }
535 let (head, tail) = p.split_at(4);
536 *p = tail;
537 Ok(i32::from_be_bytes(head.try_into().unwrap()))
538 }
539
540 fn take_i64(p: &mut &[u8]) -> Result<i64> {
541 if p.len() < 8 {
542 return Err(PgWireError::Protocol("pgoutput: truncated i64".into()));
543 }
544 let (head, tail) = p.split_at(8);
545 *p = tail;
546 Ok(i64::from_be_bytes(head.try_into().unwrap()))
547 }
548
549 match tag {
550 b'B' => {
551 let final_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
552 let commit_time_micros = take_i64(&mut p)?;
553 let xid = take_i32(&mut p)? as u32;
554
555 Ok(Some(ReplicationEvent::Begin {
556 final_lsn,
557 commit_time_micros,
558 xid,
559 }))
560 }
561 b'C' => {
562 let _flags = take_i8(&mut p)?;
563 let lsn = Lsn::from_u64(take_i64(&mut p)? as u64); let end_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
565 let commit_time_micros = take_i64(&mut p)?;
566
567 Ok(Some(ReplicationEvent::Commit {
568 lsn,
569 end_lsn,
570 commit_time_micros,
571 }))
572 }
573 _ => Ok(None),
574 }
575}
576
577async fn read_auth_data<S: AsyncRead + AsyncWrite + Unpin>(
579 stream: &mut S,
580 expected_code: i32,
581) -> Result<Vec<u8>> {
582 loop {
583 let msg = read_backend_message(stream).await?;
584 match msg.tag {
585 b'R' => {
586 let (code, data) = parse_auth_request(&msg.payload)?;
587 if code == expected_code {
588 return Ok(data.to_vec());
589 }
590 return Err(PgWireError::Auth(format!(
591 "unexpected auth code {code}, expected {expected_code}"
592 )));
593 }
594 b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
595 _ => {} }
597 }
598}
599
600fn current_pg_timestamp() -> i64 {
602 use std::time::{SystemTime, UNIX_EPOCH};
603
604 let now = SystemTime::now()
605 .duration_since(UNIX_EPOCH)
606 .unwrap_or_default();
607
608 let unix_micros = (now.as_secs() as i64) * 1_000_000 + (now.subsec_micros() as i64);
609 unix_micros - PG_EPOCH_MICROS
610}
611
/// Builds the PostgreSQL MD5 password response.
///
/// The result is `"md5" + hex(md5(hex(md5(password || user)) || salt))`.
#[cfg(feature = "md5")]
fn postgres_md5(password: &str, user: &str, salt: &[u8; 4]) -> String {
    // Lower-case hex encoding of an MD5 digest.
    let hex_digest = |bytes: &[u8]| format!("{:x}", md5::compute(bytes));

    // Stage 1: hash the password concatenated with the user name.
    let mut credentials = String::with_capacity(password.len() + user.len());
    credentials.push_str(password);
    credentials.push_str(user);
    let stage_one = hex_digest(credentials.as_bytes());

    // Stage 2: hash the stage-1 hex string followed by the server's salt.
    let mut salted = stage_one.into_bytes();
    salted.extend_from_slice(salt);

    format!("md5{}", hex_digest(&salted))
}
628
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_sasl_mechanisms_single() {
        let data = b"SCRAM-SHA-256\0\0";
        let mechs = parse_sasl_mechanisms(data);
        assert_eq!(mechs, vec!["SCRAM-SHA-256"]);
    }

    #[test]
    fn parse_sasl_mechanisms_multiple() {
        let data = b"SCRAM-SHA-256\0SCRAM-SHA-256-PLUS\0\0";
        let mechs = parse_sasl_mechanisms(data);
        assert_eq!(mechs, vec!["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"]);
    }

    #[test]
    fn parse_sasl_mechanisms_empty() {
        let mechs = parse_sasl_mechanisms(b"\0");
        assert!(mechs.is_empty());
    }

    #[test]
    fn parse_sasl_mechanisms_ignores_unterminated_tail() {
        // Trailing bytes without a NUL terminator are not a mechanism name.
        assert_eq!(
            parse_sasl_mechanisms(b"SCRAM-SHA-256\0SCRAM"),
            vec!["SCRAM-SHA-256"]
        );
        assert!(parse_sasl_mechanisms(b"").is_empty());
    }

    #[test]
    #[cfg(feature = "md5")]
    fn postgres_md5_known_value() {
        let hash = postgres_md5("md5_pass", "md5_user", &[0x01, 0x02, 0x03, 0x04]);
        // "md5" prefix followed by a 32-character lower-case hex digest.
        assert!(hash.starts_with("md5"));
        assert_eq!(hash.len(), 35);
        assert!(hash["md5".len()..]
            .bytes()
            .all(|b| b.is_ascii_digit() || (b'a'..=b'f').contains(&b)));
        // The salt must influence the result.
        assert_ne!(
            hash,
            postgres_md5("md5_pass", "md5_user", &[0x04, 0x03, 0x02, 0x01])
        );
    }

    #[test]
    fn current_pg_timestamp_is_positive() {
        let ts = current_pg_timestamp();
        assert!(ts > 0);
    }

    #[test]
    fn shared_progress_never_moves_backwards() {
        let progress = SharedProgress::new(Lsn::from_u64(100));
        progress.update_applied(Lsn::from_u64(50));
        assert_eq!(progress.load_applied(), Lsn::from_u64(100));
        progress.update_applied(Lsn::from_u64(200));
        assert_eq!(progress.load_applied(), Lsn::from_u64(200));
        progress.update_applied(Lsn::from_u64(200));
        assert_eq!(progress.load_applied(), Lsn::from_u64(200));
    }

    #[test]
    fn parse_pgoutput_boundary_begin() {
        let mut buf = vec![b'B'];
        buf.extend_from_slice(&0x1_0000_0010u64.to_be_bytes());
        buf.extend_from_slice(&1_234i64.to_be_bytes());
        buf.extend_from_slice(&42u32.to_be_bytes());

        match parse_pgoutput_boundary(&Bytes::from(buf)) {
            Ok(Some(ReplicationEvent::Begin {
                final_lsn,
                xid,
                commit_time_micros,
            })) => {
                assert_eq!(final_lsn, Lsn::from_u64(0x1_0000_0010));
                assert_eq!(xid, 42);
                assert_eq!(commit_time_micros, 1_234);
            }
            _ => panic!("expected Ok(Some(Begin))"),
        }
    }

    #[test]
    fn parse_pgoutput_boundary_commit() {
        let mut buf = vec![b'C', 0];
        buf.extend_from_slice(&0x10u64.to_be_bytes());
        buf.extend_from_slice(&0x20u64.to_be_bytes());
        buf.extend_from_slice(&99i64.to_be_bytes());

        match parse_pgoutput_boundary(&Bytes::from(buf)) {
            Ok(Some(ReplicationEvent::Commit {
                lsn,
                end_lsn,
                commit_time_micros,
            })) => {
                assert_eq!(lsn, Lsn::from_u64(0x10));
                assert_eq!(end_lsn, Lsn::from_u64(0x20));
                assert_eq!(commit_time_micros, 99);
            }
            _ => panic!("expected Ok(Some(Commit))"),
        }
    }

    #[test]
    fn parse_pgoutput_boundary_non_boundary_and_truncated() {
        assert!(matches!(
            parse_pgoutput_boundary(&Bytes::from_static(b"")),
            Ok(None)
        ));
        assert!(matches!(
            parse_pgoutput_boundary(&Bytes::from_static(b"I\0\0")),
            Ok(None)
        ));
        assert!(matches!(
            parse_pgoutput_boundary(&Bytes::from_static(b"B\0\0")),
            Err(PgWireError::Protocol(_))
        ));
        assert!(matches!(
            parse_pgoutput_boundary(&Bytes::from_static(b"C")),
            Err(PgWireError::Protocol(_))
        ));
    }
}