1use bytes::Bytes;
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::sync::Arc;
4use tokio::io::{AsyncRead, AsyncWrite};
5use tokio::sync::{mpsc, watch};
6use tokio::time::Instant;
7
8use crate::config::ReplicationConfig;
9use crate::error::{PgWireError, Result};
10use crate::lsn::Lsn;
11use crate::protocol::framing::{
12 read_backend_message, write_copy_data, write_copy_done, write_password_message, write_query,
13 write_startup_message,
14};
15use crate::protocol::messages::{parse_auth_request, parse_error_response};
16use crate::protocol::replication::{
17 encode_standby_status_update, parse_copy_data, ReplicationCopyData, PG_EPOCH_MICROS,
18};
19
/// Applied-LSN watermark shared (via `Arc`) between the replication worker and whoever
/// acknowledges processed events.
///
/// The worker reads it with [`SharedProgress::load_applied`] when building standby status
/// updates; [`SharedProgress::update_applied`] never moves it backwards.
pub struct SharedProgress {
    /// Applied LSN stored as its raw `u64` representation.
    applied: AtomicU64,
}
27
28impl SharedProgress {
29 pub fn new(start: Lsn) -> Self {
30 Self {
31 applied: AtomicU64::new(start.as_u64()),
32 }
33 }
34
35 #[inline]
36 pub fn load_applied(&self) -> Lsn {
37 Lsn::from_u64(self.applied.load(Ordering::Acquire))
38 }
39
40 #[inline]
43 pub fn update_applied(&self, lsn: Lsn) {
44 let new = lsn.as_u64();
45 let mut cur = self.applied.load(Ordering::Relaxed);
46
47 while new > cur {
48 match self
49 .applied
50 .compare_exchange_weak(cur, new, Ordering::Release, Ordering::Relaxed)
51 {
52 Ok(_) => break,
53 Err(observed) => cur = observed,
54 }
55 }
56 }
57}
58
/// Events the replication worker delivers to its consumer.
#[derive(Debug, Clone)]
pub enum ReplicationEvent {
    /// A primary keepalive message received from the server.
    KeepAlive {
        /// WAL end position reported by the server.
        wal_end: Lsn,
        /// Whether the server asked for an immediate status update. The worker already answers
        /// it before emitting this event.
        reply_requested: bool,
        /// Server send time in microseconds, as carried in the message (presumably relative to
        /// the PostgreSQL epoch, 2000-01-01 — confirm `parse_copy_data` does not convert it).
        server_time_micros: i64,
    },

    /// A pgoutput Begin (`'B'`) message; the raw XLogData is not forwarded separately.
    Begin {
        /// Final LSN of the transaction (first Int64 of the Begin message).
        final_lsn: Lsn,
        /// Transaction id (trailing Int32 of the Begin message).
        xid: u32,
        /// Commit timestamp in microseconds (second Int64 of the Begin message).
        commit_time_micros: i64,
    },

    /// Any XLogData payload that is not a pgoutput Begin/Commit, passed through unparsed.
    XLogData {
        /// Starting WAL position from the XLogData header.
        wal_start: Lsn,
        /// WAL end position from the XLogData header.
        wal_end: Lsn,
        /// Server send time in microseconds from the XLogData header.
        server_time_micros: i64,
        /// The raw pgoutput message bytes.
        data: Bytes,
    },

    /// A pgoutput Commit (`'C'`) message; the raw XLogData is not forwarded separately.
    Commit {
        /// LSN of the commit (first Int64 after the flags byte).
        lsn: Lsn,
        /// End LSN of the transaction (second Int64 after the flags byte).
        end_lsn: Lsn,
        /// Commit timestamp in microseconds (third Int64 after the flags byte).
        commit_time_micros: i64,
    },

    /// Emitted once when the configured `stop_at_lsn` has been reached; streaming then ends.
    StoppedAt {
        /// The position that triggered the stop: a Begin's `final_lsn`, a Commit's `end_lsn`,
        /// or an XLogData's `wal_end`.
        reached: Lsn,
    },
}
107
/// Receiving half of a worker's event channel.
///
/// Items are `Result`s so that errors can be delivered in-band alongside events.
pub type ReplicationEventReceiver =
    mpsc::Receiver<std::result::Result<ReplicationEvent, PgWireError>>;
111
/// State for one replication worker: its configuration plus the handles it uses to coordinate
/// with its owner.
pub struct WorkerState {
    /// Credentials, slot/publication names, start/stop LSNs and timing intervals.
    cfg: ReplicationConfig,
    /// Applied-LSN watermark reported to the server in standby status updates.
    progress: Arc<SharedProgress>,
    /// When it reads `true`, the streaming loop sends CopyDone and returns.
    stop_rx: watch::Receiver<bool>,
    /// Destination for replication events delivered to the consumer.
    out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
}
119
120impl WorkerState {
121 pub fn new(
122 cfg: ReplicationConfig,
123 progress: Arc<SharedProgress>,
124 stop_rx: watch::Receiver<bool>,
125 out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
126 ) -> Self {
127 Self {
128 cfg,
129 progress,
130 stop_rx,
131 out,
132 }
133 }
134
135 pub async fn run_on_stream<S: AsyncRead + AsyncWrite + Unpin>(
137 &mut self,
138 stream: &mut S,
139 ) -> Result<()> {
140 self.startup(stream).await?;
141 self.authenticate(stream).await?;
142 self.start_replication(stream).await?;
143 self.stream_loop(stream).await
144 }
145
146 async fn startup<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<()> {
148 let params = [
149 ("user", self.cfg.user.as_str()),
150 ("database", self.cfg.database.as_str()),
151 ("replication", "database"),
152 ("client_encoding", "UTF8"),
153 ("application_name", "pgwire-replication"),
154 ];
155 write_startup_message(stream, 196608, ¶ms).await
156 }
157
158 async fn start_replication<S: AsyncRead + AsyncWrite + Unpin>(
160 &self,
161 stream: &mut S,
162 ) -> Result<()> {
163 let publication = self.cfg.publication.replace('\'', "''");
165 let sql = format!(
166 "START_REPLICATION SLOT {} LOGICAL {} (proto_version '1', publication_names '{}')",
167 self.cfg.slot, self.cfg.start_lsn, publication,
168 );
169 write_query(stream, &sql).await?;
170
171 loop {
173 let msg = read_backend_message(stream).await?;
174 match msg.tag {
175 b'W' => return Ok(()), b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
177 b'N' | b'S' | b'K' => continue, _ => continue,
179 }
180 }
181 }
182
183 async fn stream_loop<S: AsyncRead + AsyncWrite + Unpin>(
185 &mut self,
186 stream: &mut S,
187 ) -> Result<()> {
188 let mut last_status_sent = Instant::now() - self.cfg.status_interval;
189 let mut last_applied = self.progress.load_applied();
190
191 loop {
192 if *self.stop_rx.borrow() {
194 let _ = write_copy_done(stream).await;
195 return Ok(());
196 }
197
198 let current_applied = self.progress.load_applied();
200 if current_applied != last_applied {
201 last_applied = current_applied;
202 }
203
204 if last_status_sent.elapsed() >= self.cfg.status_interval {
206 self.send_feedback(stream, last_applied, false).await?;
207 last_status_sent = Instant::now();
208 }
209
210 let msg = match tokio::time::timeout(
212 self.cfg.idle_wakeup_interval,
213 read_backend_message(stream),
214 )
215 .await
216 {
217 Ok(res) => res?, Err(_) => {
219 let applied = self.progress.load_applied();
222 last_applied = applied;
223 self.send_feedback(stream, applied, false).await?;
224 last_status_sent = Instant::now();
225 continue;
226 }
227 };
228
229 match msg.tag {
230 b'd' => {
231 let should_stop = self
232 .handle_copy_data(
233 stream,
234 msg.payload,
235 &mut last_applied,
236 &mut last_status_sent,
237 )
238 .await?;
239 if should_stop {
240 return Ok(());
241 }
242 }
243 b'E' => {
244 let err = PgWireError::Server(parse_error_response(&msg.payload));
245 return Err(err);
246 }
247 _ => {
248 }
250 }
251 }
252 }
253
254 async fn handle_copy_data<S: AsyncRead + AsyncWrite + Unpin>(
256 &mut self,
257 stream: &mut S,
258 payload: Bytes,
259 last_applied: &mut Lsn,
260 last_status_sent: &mut Instant,
261 ) -> Result<bool> {
262 let cd = parse_copy_data(payload)?;
263
264 match cd {
265 ReplicationCopyData::KeepAlive {
266 wal_end,
267 server_time_micros,
268 reply_requested,
269 } => {
270 if reply_requested {
272 let applied = self.progress.load_applied();
273 *last_applied = applied;
274 self.send_feedback(stream, applied, true).await?;
275 *last_status_sent = Instant::now();
276 }
277
278 self.send_event(Ok(ReplicationEvent::KeepAlive {
279 wal_end,
280 reply_requested,
281 server_time_micros,
282 }))
283 .await;
284
285 Ok(false)
286 }
287 ReplicationCopyData::XLogData {
288 wal_start,
289 wal_end,
290 server_time_micros,
291 data,
292 } => {
293 if let Some(boundary_ev) = parse_pgoutput_boundary(&data)? {
295 let reached_lsn = match boundary_ev {
296 ReplicationEvent::Begin { final_lsn, .. } => final_lsn,
297 ReplicationEvent::Commit { end_lsn, .. } => end_lsn,
298 _ => wal_end, };
300
301 self.send_event(Ok(boundary_ev)).await;
302
303 if let Some(stop_lsn) = self.cfg.stop_at_lsn {
305 if reached_lsn >= stop_lsn {
306 self.send_event(Ok(ReplicationEvent::StoppedAt {
307 reached: reached_lsn,
308 }))
309 .await;
310 let _ = write_copy_done(stream).await;
311 return Ok(true); }
313 }
314
315 return Ok(false);
316 }
317 if let Some(stop_lsn) = self.cfg.stop_at_lsn {
320 if wal_end >= stop_lsn {
321 self.send_event(Ok(ReplicationEvent::XLogData {
323 wal_start,
324 wal_end,
325 server_time_micros,
326 data,
327 }))
328 .await;
329
330 self.send_event(Ok(ReplicationEvent::StoppedAt { reached: wal_end }))
331 .await;
332
333 let _ = write_copy_done(stream).await;
334 return Ok(true);
335 }
336 }
337
338 self.send_event(Ok(ReplicationEvent::XLogData {
339 wal_start,
340 wal_end,
341 server_time_micros,
342 data,
343 }))
344 .await;
345
346 Ok(false)
347 }
348 }
349 }
350
351 async fn send_event(&self, event: std::result::Result<ReplicationEvent, PgWireError>) {
356 if self.out.send(event).await.is_err() {
357 tracing::debug!("event channel closed, client may have disconnected");
358 }
359 }
360
361 async fn authenticate<S: AsyncRead + AsyncWrite + Unpin>(
363 &mut self,
364 stream: &mut S,
365 ) -> Result<()> {
366 loop {
367 let msg = read_backend_message(stream).await?;
368 match msg.tag {
369 b'R' => {
370 let (code, rest) = parse_auth_request(&msg.payload)?;
371 self.handle_auth_request(stream, code, rest).await?;
372 }
373 b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
374 b'S' | b'K' => {} b'Z' => return Ok(()), _ => {}
377 }
378 }
379 }
380
381 async fn handle_auth_request<S: AsyncRead + AsyncWrite + Unpin>(
383 &mut self,
384 stream: &mut S,
385 code: i32,
386 data: &[u8],
387 ) -> Result<()> {
388 match code {
389 0 => Ok(()), 3 => {
391 let mut payload = Vec::from(self.cfg.password.as_bytes());
393 payload.push(0);
394 write_password_message(stream, &payload).await
395 }
396 10 => {
397 self.auth_scram(stream, data).await
399 }
400 #[cfg(feature = "md5")]
401 5 => {
402 if data.len() != 4 {
404 return Err(PgWireError::Protocol(
405 "MD5 auth: expected 4-byte salt".into(),
406 ));
407 }
408 let mut salt = [0u8; 4];
409 salt.copy_from_slice(&data[..4]);
410
411 let hash = postgres_md5(&self.cfg.password, &self.cfg.user, &salt);
412 let mut payload = hash.into_bytes();
413 payload.push(0);
414 write_password_message(stream, &payload).await
415 }
416 _ => Err(PgWireError::Auth(format!(
417 "unsupported auth method code: {code}"
418 ))),
419 }
420 }
421
422 async fn auth_scram<S: AsyncRead + AsyncWrite + Unpin>(
424 &mut self,
425 stream: &mut S,
426 mechanisms_data: &[u8],
427 ) -> Result<()> {
428 let mechanisms = parse_sasl_mechanisms(mechanisms_data);
430
431 if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
432 return Err(PgWireError::Auth(format!(
433 "server doesn't offer SCRAM-SHA-256, available: {mechanisms:?}"
434 )));
435 }
436
437 #[cfg(not(feature = "scram"))]
438 return Err(PgWireError::Auth(
439 "SCRAM authentication required but 'scram' feature not enabled".into(),
440 ));
441
442 #[cfg(feature = "scram")]
443 {
444 use crate::auth::scram::ScramClient;
445
446 let scram = ScramClient::new(&self.cfg.user);
447
448 let mut init = Vec::new();
450 init.extend_from_slice(b"SCRAM-SHA-256\0");
451 init.extend_from_slice(&(scram.client_first.len() as i32).to_be_bytes());
452 init.extend_from_slice(scram.client_first.as_bytes());
453 write_password_message(stream, &init).await?;
454
455 let server_first = read_auth_data(stream, 11).await?;
457 let server_first_str = String::from_utf8_lossy(&server_first);
458
459 let (client_final, auth_message, salted_password) =
461 scram.client_final(&self.cfg.password, &server_first_str)?;
462 write_password_message(stream, client_final.as_bytes()).await?;
463
464 let server_final = read_auth_data(stream, 12).await?;
466 let server_final_str = String::from_utf8_lossy(&server_final);
467 ScramClient::verify_server_final(&server_final_str, &salted_password, &auth_message)?;
468
469 Ok(())
470 }
471 }
472
473 async fn send_feedback<S: AsyncWrite + Unpin>(
475 &self,
476 stream: &mut S,
477 applied: Lsn,
478 reply_requested: bool,
479 ) -> Result<()> {
480 let client_time = current_pg_timestamp();
481 let payload = encode_standby_status_update(applied, client_time, reply_requested);
482 write_copy_data(stream, &payload).await
483 }
484}
485
/// Parses the NUL-terminated mechanism names of an AuthenticationSASL request.
///
/// Parsing stops at the empty name that terminates the list. A trailing fragment without a NUL
/// terminator is ignored, and so is anything after the terminator.
fn parse_sasl_mechanisms(data: &[u8]) -> Vec<String> {
    data.split_inclusive(|&byte| byte == 0)
        // Only NUL-terminated segments count; an unterminated tail ends the list.
        .map_while(|segment| segment.strip_suffix(&[0u8]))
        // An empty name is the list terminator.
        .take_while(|name| !name.is_empty())
        .map(|name| String::from_utf8_lossy(name).into_owned())
        .collect()
}
505
506fn parse_pgoutput_boundary(data: &Bytes) -> Result<Option<ReplicationEvent>> {
507 if data.is_empty() {
508 return Ok(None);
509 }
510
511 let tag = data[0];
512 let mut p = &data[1..];
513
514 fn take_i8(p: &mut &[u8]) -> Result<i8> {
515 if p.is_empty() {
516 return Err(PgWireError::Protocol("pgoutput: truncated i8".into()));
517 }
518 let v = p[0] as i8;
519 *p = &p[1..];
520 Ok(v)
521 }
522
523 fn take_i32(p: &mut &[u8]) -> Result<i32> {
524 if p.len() < 4 {
525 return Err(PgWireError::Protocol("pgoutput: truncated i32".into()));
526 }
527 let (head, tail) = p.split_at(4);
528 *p = tail;
529 Ok(i32::from_be_bytes(head.try_into().unwrap()))
530 }
531
532 fn take_i64(p: &mut &[u8]) -> Result<i64> {
533 if p.len() < 8 {
534 return Err(PgWireError::Protocol("pgoutput: truncated i64".into()));
535 }
536 let (head, tail) = p.split_at(8);
537 *p = tail;
538 Ok(i64::from_be_bytes(head.try_into().unwrap()))
539 }
540
541 match tag {
542 b'B' => {
543 let final_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
544 let commit_time_micros = take_i64(&mut p)?;
545 let xid = take_i32(&mut p)? as u32;
546
547 Ok(Some(ReplicationEvent::Begin {
548 final_lsn,
549 commit_time_micros,
550 xid,
551 }))
552 }
553 b'C' => {
554 let _flags = take_i8(&mut p)?;
555 let lsn = Lsn::from_u64(take_i64(&mut p)? as u64); let end_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
557 let commit_time_micros = take_i64(&mut p)?;
558
559 Ok(Some(ReplicationEvent::Commit {
560 lsn,
561 end_lsn,
562 commit_time_micros,
563 }))
564 }
565 _ => Ok(None),
566 }
567}
568
569async fn read_auth_data<S: AsyncRead + AsyncWrite + Unpin>(
571 stream: &mut S,
572 expected_code: i32,
573) -> Result<Vec<u8>> {
574 loop {
575 let msg = read_backend_message(stream).await?;
576 match msg.tag {
577 b'R' => {
578 let (code, data) = parse_auth_request(&msg.payload)?;
579 if code == expected_code {
580 return Ok(data.to_vec());
581 }
582 return Err(PgWireError::Auth(format!(
583 "unexpected auth code {code}, expected {expected_code}"
584 )));
585 }
586 b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
587 _ => {} }
589 }
590}
591
592fn current_pg_timestamp() -> i64 {
594 use std::time::{SystemTime, UNIX_EPOCH};
595
596 let now = SystemTime::now()
597 .duration_since(UNIX_EPOCH)
598 .unwrap_or_default();
599
600 let unix_micros = (now.as_secs() as i64) * 1_000_000 + (now.subsec_micros() as i64);
601 unix_micros - PG_EPOCH_MICROS
602}
603
/// Builds PostgreSQL's MD5 password response.
///
/// The format is `"md5" + hex(md5(hex(md5(password + user)) + salt))`.
#[cfg(feature = "md5")]
fn postgres_md5(password: &str, user: &str, salt: &[u8; 4]) -> String {
    let hex_digest = |bytes: &[u8]| format!("{:x}", md5::compute(bytes));

    let credentials = [password.as_bytes(), user.as_bytes()].concat();
    let mut salted = hex_digest(credentials.as_slice()).into_bytes();
    salted.extend_from_slice(salt);

    format!("md5{}", hex_digest(salted.as_slice()))
}
620
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_sasl_mechanisms_single() {
        let data = b"SCRAM-SHA-256\0\0";
        let mechs = parse_sasl_mechanisms(data);
        assert_eq!(mechs, vec!["SCRAM-SHA-256"]);
    }

    #[test]
    fn parse_sasl_mechanisms_multiple() {
        let data = b"SCRAM-SHA-256\0SCRAM-SHA-256-PLUS\0\0";
        let mechs = parse_sasl_mechanisms(data);
        assert_eq!(mechs, vec!["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"]);
    }

    #[test]
    fn parse_sasl_mechanisms_empty() {
        let mechs = parse_sasl_mechanisms(b"\0");
        assert!(mechs.is_empty());
    }

    #[test]
    fn parse_sasl_mechanisms_ignores_unterminated_tail() {
        let mechs = parse_sasl_mechanisms(b"SCRAM-SHA-256\0PARTIAL");
        assert_eq!(mechs, vec!["SCRAM-SHA-256"]);
    }

    #[test]
    #[cfg(feature = "md5")]
    fn postgres_md5_known_value() {
        let hash = postgres_md5("md5_pass", "md5_user", &[0x01, 0x02, 0x03, 0x04]);
        assert!(hash.starts_with("md5"));
        // "md5" prefix plus 32 hex digits.
        assert_eq!(hash.len(), 35);
    }

    #[test]
    fn current_pg_timestamp_is_positive() {
        let ts = current_pg_timestamp();
        assert!(ts > 0);
    }

    #[test]
    fn shared_progress_never_moves_backwards() {
        let progress = SharedProgress::new(Lsn::from_u64(100));
        progress.update_applied(Lsn::from_u64(50));
        assert_eq!(progress.load_applied().as_u64(), 100);
        progress.update_applied(Lsn::from_u64(200));
        assert_eq!(progress.load_applied().as_u64(), 200);
    }

    #[test]
    fn pgoutput_begin_is_parsed() {
        let mut buf = vec![b'B'];
        buf.extend_from_slice(&0x1_0000_0010u64.to_be_bytes());
        buf.extend_from_slice(&42i64.to_be_bytes());
        buf.extend_from_slice(&7u32.to_be_bytes());

        match parse_pgoutput_boundary(&Bytes::from(buf)).ok().flatten() {
            Some(ReplicationEvent::Begin {
                final_lsn,
                xid,
                commit_time_micros,
            }) => {
                assert_eq!(final_lsn.as_u64(), 0x1_0000_0010);
                assert_eq!(xid, 7);
                assert_eq!(commit_time_micros, 42);
            }
            other => panic!("expected Begin, got {other:?}"),
        }
    }

    #[test]
    fn pgoutput_commit_is_parsed() {
        let mut buf = vec![b'C', 0u8];
        buf.extend_from_slice(&0x16B_3748u64.to_be_bytes());
        buf.extend_from_slice(&0x16B_3780u64.to_be_bytes());
        buf.extend_from_slice(&123_456i64.to_be_bytes());

        match parse_pgoutput_boundary(&Bytes::from(buf)).ok().flatten() {
            Some(ReplicationEvent::Commit {
                lsn,
                end_lsn,
                commit_time_micros,
            }) => {
                assert_eq!(lsn.as_u64(), 0x16B_3748);
                assert_eq!(end_lsn.as_u64(), 0x16B_3780);
                assert_eq!(commit_time_micros, 123_456);
            }
            other => panic!("expected Commit, got {other:?}"),
        }
    }

    #[test]
    fn pgoutput_non_boundary_and_empty_yield_none() {
        assert!(matches!(parse_pgoutput_boundary(&Bytes::new()), Ok(None)));
        assert!(matches!(
            parse_pgoutput_boundary(&Bytes::from_static(b"I\0\0")),
            Ok(None)
        ));
    }

    #[test]
    fn pgoutput_truncated_boundary_is_an_error() {
        assert!(parse_pgoutput_boundary(&Bytes::from_static(b"B\0\0\0")).is_err());
        assert!(parse_pgoutput_boundary(&Bytes::from_static(b"C")).is_err());
    }
}
661}