pgwire_replication/client/
worker.rs

1use bytes::Bytes;
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::sync::Arc;
4use tokio::io::{AsyncRead, AsyncWrite};
5use tokio::sync::{mpsc, watch};
6use tokio::time::Instant;
7
8use crate::config::ReplicationConfig;
9use crate::error::{PgWireError, Result};
10use crate::lsn::Lsn;
11use crate::protocol::framing::{
12    read_backend_message, write_copy_data, write_copy_done, write_password_message, write_query,
13    write_startup_message,
14};
15use crate::protocol::messages::{parse_auth_request, parse_error_response};
16use crate::protocol::replication::{
17    encode_standby_status_update, parse_copy_data, ReplicationCopyData, PG_EPOCH_MICROS,
18};
19
20/// Shared replication progress updated by the consumer and read by the worker.
21///
22/// Stored as an AtomicU64 so progress updates are cheap and monotonic
23/// without async backpressure.
24pub struct SharedProgress {
25    applied: AtomicU64,
26}
27
28impl SharedProgress {
29    pub fn new(start: Lsn) -> Self {
30        Self {
31            applied: AtomicU64::new(start.as_u64()),
32        }
33    }
34
35    #[inline]
36    pub fn load_applied(&self) -> Lsn {
37        Lsn::from_u64(self.applied.load(Ordering::Acquire))
38    }
39
40    /// Monotonic update: if `lsn` is lower than the currently stored applied LSN,
41    /// this is a no-op.
42    #[inline]
43    pub fn update_applied(&self, lsn: Lsn) {
44        let new = lsn.as_u64();
45        let mut cur = self.applied.load(Ordering::Relaxed);
46
47        while new > cur {
48            match self
49                .applied
50                .compare_exchange_weak(cur, new, Ordering::Release, Ordering::Relaxed)
51            {
52                Ok(_) => break,
53                Err(observed) => cur = observed,
54            }
55        }
56    }
57}
58
/// Events emitted by the replication worker.
///
/// For WAL data the worker emits exactly one event per message: a pgoutput
/// Begin or Commit payload is surfaced as [`ReplicationEvent::Begin`] /
/// [`ReplicationEvent::Commit`], and every other payload is surfaced as
/// [`ReplicationEvent::XLogData`].
#[derive(Debug, Clone)]
pub enum ReplicationEvent {
    /// Server heartbeat message.
    KeepAlive {
        /// Current server WAL end position
        wal_end: Lsn,
        /// Whether server requested a reply (already handled internally)
        reply_requested: bool,
        /// Server timestamp (microseconds since 2000-01-01)
        server_time_micros: i64,
    },

    /// Start of a transaction (pgoutput Begin message).
    Begin {
        /// Final LSN of the transaction, as carried in the Begin message
        final_lsn: Lsn,
        /// Transaction ID carried in the Begin message
        xid: u32,
        /// Commit timestamp carried in the Begin message
        /// (per the pgoutput format: microseconds since 2000-01-01)
        commit_time_micros: i64,
    },

    /// WAL data containing transaction changes.
    XLogData {
        /// WAL position where this data starts
        wal_start: Lsn,
        /// WAL end position (may be 0 for mid-transaction messages)
        wal_end: Lsn,
        /// Server timestamp (microseconds since 2000-01-01)
        server_time_micros: i64,
        /// pgoutput-encoded change data
        data: Bytes,
    },

    /// End of a transaction (pgoutput Commit message).
    Commit {
        /// LSN of the commit, as carried in the Commit message
        lsn: Lsn,
        /// End LSN of the transaction, as carried in the Commit message
        end_lsn: Lsn,
        /// Commit timestamp carried in the Commit message
        /// (per the pgoutput format: microseconds since 2000-01-01)
        commit_time_micros: i64,
    },

    /// Emitted when `stop_at_lsn` has been reached.
    ///
    /// After this event, no more events will be emitted and the
    /// replication stream will be closed.
    StoppedAt {
        /// The LSN that triggered the stop condition
        reached: Lsn,
    },
}
107
/// Channel receiver type for replication events.
///
/// Each received item is either `Ok` with a [`ReplicationEvent`] or `Err`
/// with a [`PgWireError`].
pub type ReplicationEventReceiver =
    mpsc::Receiver<std::result::Result<ReplicationEvent, PgWireError>>;
111
/// Internal worker state.
///
/// Bundles everything the worker needs to drive one replication session over
/// a stream: configuration, shared progress, the stop signal and the outgoing
/// event channel.
pub struct WorkerState {
    /// Connection and replication settings (credentials, slot, publication,
    /// start/stop LSNs, feedback intervals).
    cfg: ReplicationConfig,
    /// Applied-LSN tracker read whenever a standby status update is sent.
    progress: Arc<SharedProgress>,
    /// Stop signal; once it holds `true` the streaming loop sends CopyDone and returns.
    stop_rx: watch::Receiver<bool>,
    /// Channel on which replication events are delivered to the consumer.
    out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
}
119
120impl WorkerState {
121    pub fn new(
122        cfg: ReplicationConfig,
123        progress: Arc<SharedProgress>,
124        stop_rx: watch::Receiver<bool>,
125        out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
126    ) -> Self {
127        Self {
128            cfg,
129            progress,
130            stop_rx,
131            out,
132        }
133    }
134
    /// Run the replication protocol on the given stream.
    ///
    /// Executes the four phases in order: startup message, authentication,
    /// `START_REPLICATION`, and the streaming loop. The first failing phase
    /// aborts the run and its error is returned; otherwise the result of the
    /// streaming loop is returned.
    pub async fn run_on_stream<S: AsyncRead + AsyncWrite + Unpin>(
        &mut self,
        stream: &mut S,
    ) -> Result<()> {
        self.startup(stream).await?;
        self.authenticate(stream).await?;
        self.start_replication(stream).await?;
        self.stream_loop(stream).await
    }
145
146    /// Send startup message with replication parameters.
147    async fn startup<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<()> {
148        let params = [
149            ("user", self.cfg.user.as_str()),
150            ("database", self.cfg.database.as_str()),
151            ("replication", "database"),
152            ("client_encoding", "UTF8"),
153            ("application_name", "pgwire-replication"),
154        ];
155        write_startup_message(stream, 196608, &params).await
156    }
157
158    /// Start the logical replication stream.
159    async fn start_replication<S: AsyncRead + AsyncWrite + Unpin>(
160        &self,
161        stream: &mut S,
162    ) -> Result<()> {
163        // Escape single quotes in publication name
164        let publication = self.cfg.publication.replace('\'', "''");
165        let sql = format!(
166            "START_REPLICATION SLOT {} LOGICAL {} (proto_version '1', publication_names '{}')",
167            self.cfg.slot, self.cfg.start_lsn, publication,
168        );
169        write_query(stream, &sql).await?;
170
171        // Wait for CopyBothResponse
172        loop {
173            let msg = read_backend_message(stream).await?;
174            match msg.tag {
175                b'W' => return Ok(()), // CopyBothResponse - ready to stream
176                b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
177                b'N' | b'S' | b'K' => continue, // Notice, ParameterStatus, BackendKeyData
178                _ => continue,
179            }
180        }
181    }
182
183    /// Main replication streaming loop.
184    async fn stream_loop<S: AsyncRead + AsyncWrite + Unpin>(
185        &mut self,
186        stream: &mut S,
187    ) -> Result<()> {
188        let mut last_status_sent = Instant::now() - self.cfg.status_interval;
189        let mut last_applied = self.progress.load_applied();
190
191        loop {
192            // Check for stop request
193            if *self.stop_rx.borrow() {
194                let _ = write_copy_done(stream).await;
195                return Ok(());
196            }
197
198            // Update applied LSN from client
199            let current_applied = self.progress.load_applied();
200            if current_applied != last_applied {
201                last_applied = current_applied;
202            }
203
204            // Send periodic status feedback
205            if last_status_sent.elapsed() >= self.cfg.status_interval {
206                self.send_feedback(stream, last_applied, false).await?;
207                last_status_sent = Instant::now();
208            }
209
210            // Read next message with idle timeout
211            let msg = match tokio::time::timeout(
212                self.cfg.idle_wakeup_interval,
213                read_backend_message(stream),
214            )
215            .await
216            {
217                Ok(res) => res?, // read_backend_message result
218                Err(_) => {
219                    // No message received; keep the connection alive by sending feedback
220                    // wakeup: send feedback even while idle
221                    let applied = self.progress.load_applied();
222                    last_applied = applied;
223                    self.send_feedback(stream, applied, false).await?;
224                    last_status_sent = Instant::now();
225                    continue;
226                }
227            };
228
229            match msg.tag {
230                b'd' => {
231                    let should_stop = self
232                        .handle_copy_data(
233                            stream,
234                            msg.payload,
235                            &mut last_applied,
236                            &mut last_status_sent,
237                        )
238                        .await?;
239                    if should_stop {
240                        return Ok(());
241                    }
242                }
243                b'E' => {
244                    let err = PgWireError::Server(parse_error_response(&msg.payload));
245                    return Err(err);
246                }
247                _ => {
248                    // Unexpected in CopyBoth mode, but ignore gracefully
249                }
250            }
251        }
252    }
253
254    /// Handle a CopyData message. Returns true if we should stop.
255    async fn handle_copy_data<S: AsyncRead + AsyncWrite + Unpin>(
256        &mut self,
257        stream: &mut S,
258        payload: Bytes,
259        last_applied: &mut Lsn,
260        last_status_sent: &mut Instant,
261    ) -> Result<bool> {
262        let cd = parse_copy_data(payload)?;
263
264        match cd {
265            ReplicationCopyData::KeepAlive {
266                wal_end,
267                server_time_micros,
268                reply_requested,
269            } => {
270                // Respond immediately if server requests it
271                if reply_requested {
272                    let applied = self.progress.load_applied();
273                    *last_applied = applied;
274                    self.send_feedback(stream, applied, true).await?;
275                    *last_status_sent = Instant::now();
276                }
277
278                self.send_event(Ok(ReplicationEvent::KeepAlive {
279                    wal_end,
280                    reply_requested,
281                    server_time_micros,
282                }))
283                .await;
284
285                Ok(false)
286            }
287            ReplicationCopyData::XLogData {
288                wal_start,
289                wal_end,
290                server_time_micros,
291                data,
292            } => {
293                // If the payload is a pgoutput Begin/Commit message, emit only the boundary event.
294                if let Some(boundary_ev) = parse_pgoutput_boundary(&data)? {
295                    let reached_lsn = match boundary_ev {
296                        ReplicationEvent::Begin { final_lsn, .. } => final_lsn,
297                        ReplicationEvent::Commit { end_lsn, .. } => end_lsn,
298                        _ => wal_end, // should never happen if parser only returns Begin/Commit
299                    };
300
301                    self.send_event(Ok(boundary_ev)).await;
302
303                    // Stop condition (prefer boundary LSN semantics when available)
304                    if let Some(stop_lsn) = self.cfg.stop_at_lsn {
305                        if reached_lsn >= stop_lsn {
306                            self.send_event(Ok(ReplicationEvent::StoppedAt {
307                                reached: reached_lsn,
308                            }))
309                            .await;
310                            let _ = write_copy_done(stream).await;
311                            return Ok(true); // should stop.
312                        }
313                    }
314
315                    return Ok(false);
316                }
317                // Otherwise, emit raw payload
318                // Check stop condition
319                if let Some(stop_lsn) = self.cfg.stop_at_lsn {
320                    if wal_end >= stop_lsn {
321                        // Send final event, then stop signal
322                        self.send_event(Ok(ReplicationEvent::XLogData {
323                            wal_start,
324                            wal_end,
325                            server_time_micros,
326                            data,
327                        }))
328                        .await;
329
330                        self.send_event(Ok(ReplicationEvent::StoppedAt { reached: wal_end }))
331                            .await;
332
333                        let _ = write_copy_done(stream).await;
334                        return Ok(true);
335                    }
336                }
337
338                self.send_event(Ok(ReplicationEvent::XLogData {
339                    wal_start,
340                    wal_end,
341                    server_time_micros,
342                    data,
343                }))
344                .await;
345
346                Ok(false)
347            }
348        }
349    }
350
351    /// Send an event to the client channel.
352    ///
353    /// If the channel is full or closed, we log and continue - the client
354    /// may have stopped listening but we don't want to crash the worker.
355    async fn send_event(&self, event: std::result::Result<ReplicationEvent, PgWireError>) {
356        if self.out.send(event).await.is_err() {
357            tracing::debug!("event channel closed, client may have disconnected");
358        }
359    }
360
361    /// Handle PostgreSQL authentication exchange.
362    async fn authenticate<S: AsyncRead + AsyncWrite + Unpin>(
363        &mut self,
364        stream: &mut S,
365    ) -> Result<()> {
366        loop {
367            let msg = read_backend_message(stream).await?;
368            match msg.tag {
369                b'R' => {
370                    let (code, rest) = parse_auth_request(&msg.payload)?;
371                    self.handle_auth_request(stream, code, rest).await?;
372                }
373                b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
374                b'S' | b'K' => {}      // ParameterStatus, BackendKeyData - ignore
375                b'Z' => return Ok(()), // ReadyForQuery - auth complete
376                _ => {}
377            }
378        }
379    }
380
381    /// Handle a specific authentication request.
382    async fn handle_auth_request<S: AsyncRead + AsyncWrite + Unpin>(
383        &mut self,
384        stream: &mut S,
385        code: i32,
386        data: &[u8],
387    ) -> Result<()> {
388        match code {
389            0 => Ok(()), // AuthenticationOk
390            3 => {
391                // Cleartext password
392                let mut payload = Vec::from(self.cfg.password.as_bytes());
393                payload.push(0);
394                write_password_message(stream, &payload).await
395            }
396            10 => {
397                // SASL (SCRAM-SHA-256)
398                self.auth_scram(stream, data).await
399            }
400            #[cfg(feature = "md5")]
401            5 => {
402                // MD5 password
403                if data.len() != 4 {
404                    return Err(PgWireError::Protocol(
405                        "MD5 auth: expected 4-byte salt".into(),
406                    ));
407                }
408                let mut salt = [0u8; 4];
409                salt.copy_from_slice(&data[..4]);
410
411                let hash = postgres_md5(&self.cfg.password, &self.cfg.user, &salt);
412                let mut payload = hash.into_bytes();
413                payload.push(0);
414                write_password_message(stream, &payload).await
415            }
416            _ => Err(PgWireError::Auth(format!(
417                "unsupported auth method code: {code}"
418            ))),
419        }
420    }
421
422    /// Perform SCRAM-SHA-256 authentication.
423    async fn auth_scram<S: AsyncRead + AsyncWrite + Unpin>(
424        &mut self,
425        stream: &mut S,
426        mechanisms_data: &[u8],
427    ) -> Result<()> {
428        // Parse offered mechanisms
429        let mechanisms = parse_sasl_mechanisms(mechanisms_data);
430
431        if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
432            return Err(PgWireError::Auth(format!(
433                "server doesn't offer SCRAM-SHA-256, available: {mechanisms:?}"
434            )));
435        }
436
437        #[cfg(not(feature = "scram"))]
438        return Err(PgWireError::Auth(
439            "SCRAM authentication required but 'scram' feature not enabled".into(),
440        ));
441
442        #[cfg(feature = "scram")]
443        {
444            use crate::auth::scram::ScramClient;
445
446            let scram = ScramClient::new(&self.cfg.user);
447
448            // Send SASLInitialResponse
449            let mut init = Vec::new();
450            init.extend_from_slice(b"SCRAM-SHA-256\0");
451            init.extend_from_slice(&(scram.client_first.len() as i32).to_be_bytes());
452            init.extend_from_slice(scram.client_first.as_bytes());
453            write_password_message(stream, &init).await?;
454
455            // Receive AuthenticationSASLContinue (code 11)
456            let server_first = read_auth_data(stream, 11).await?;
457            let server_first_str = String::from_utf8_lossy(&server_first);
458
459            // Compute and send client-final
460            let (client_final, auth_message, salted_password) =
461                scram.client_final(&self.cfg.password, &server_first_str)?;
462            write_password_message(stream, client_final.as_bytes()).await?;
463
464            // Receive and verify AuthenticationSASLFinal (code 12)
465            let server_final = read_auth_data(stream, 12).await?;
466            let server_final_str = String::from_utf8_lossy(&server_final);
467            ScramClient::verify_server_final(&server_final_str, &salted_password, &auth_message)?;
468
469            Ok(())
470        }
471    }
472
473    /// Send standby status update to server.
474    async fn send_feedback<S: AsyncWrite + Unpin>(
475        &self,
476        stream: &mut S,
477        applied: Lsn,
478        reply_requested: bool,
479    ) -> Result<()> {
480        let client_time = current_pg_timestamp();
481        let payload = encode_standby_status_update(applied, client_time, reply_requested);
482        write_copy_data(stream, &payload).await
483    }
484}
485
/// Parse SASL mechanism list from auth data.
///
/// The list is a sequence of NUL-terminated names ended by an empty name.
/// Parsing stops at that terminator, and also at the first piece that lacks
/// its NUL terminator (such a trailing piece is dropped). Bytes that are not
/// valid UTF-8 are replaced lossily.
fn parse_sasl_mechanisms(data: &[u8]) -> Vec<String> {
    data.split_inclusive(|&byte| byte == 0)
        .map_while(|chunk| match chunk.split_last() {
            // A properly terminated, non-empty name.
            Some((&0, name)) if !name.is_empty() => {
                Some(String::from_utf8_lossy(name).into_owned())
            }
            // Empty name (list terminator) or unterminated tail: stop here.
            _ => None,
        })
        .collect()
}
505
506fn parse_pgoutput_boundary(data: &Bytes) -> Result<Option<ReplicationEvent>> {
507    if data.is_empty() {
508        return Ok(None);
509    }
510
511    let tag = data[0];
512    let mut p = &data[1..];
513
514    fn take_i8(p: &mut &[u8]) -> Result<i8> {
515        if p.is_empty() {
516            return Err(PgWireError::Protocol("pgoutput: truncated i8".into()));
517        }
518        let v = p[0] as i8;
519        *p = &p[1..];
520        Ok(v)
521    }
522
523    fn take_i32(p: &mut &[u8]) -> Result<i32> {
524        if p.len() < 4 {
525            return Err(PgWireError::Protocol("pgoutput: truncated i32".into()));
526        }
527        let (head, tail) = p.split_at(4);
528        *p = tail;
529        Ok(i32::from_be_bytes(head.try_into().unwrap()))
530    }
531
532    fn take_i64(p: &mut &[u8]) -> Result<i64> {
533        if p.len() < 8 {
534            return Err(PgWireError::Protocol("pgoutput: truncated i64".into()));
535        }
536        let (head, tail) = p.split_at(8);
537        *p = tail;
538        Ok(i64::from_be_bytes(head.try_into().unwrap()))
539    }
540
541    match tag {
542        b'B' => {
543            let final_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
544            let commit_time_micros = take_i64(&mut p)?;
545            let xid = take_i32(&mut p)? as u32;
546
547            Ok(Some(ReplicationEvent::Begin {
548                final_lsn,
549                commit_time_micros,
550                xid,
551            }))
552        }
553        b'C' => {
554            let _flags = take_i8(&mut p)?;
555            let lsn = Lsn::from_u64(take_i64(&mut p)? as u64); // should be safe
556            let end_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
557            let commit_time_micros = take_i64(&mut p)?;
558
559            Ok(Some(ReplicationEvent::Commit {
560                lsn,
561                end_lsn,
562                commit_time_micros,
563            }))
564        }
565        _ => Ok(None),
566    }
567}
568
569/// Read authentication response data for a specific auth code.
570async fn read_auth_data<S: AsyncRead + AsyncWrite + Unpin>(
571    stream: &mut S,
572    expected_code: i32,
573) -> Result<Vec<u8>> {
574    loop {
575        let msg = read_backend_message(stream).await?;
576        match msg.tag {
577            b'R' => {
578                let (code, data) = parse_auth_request(&msg.payload)?;
579                if code == expected_code {
580                    return Ok(data.to_vec());
581                }
582                return Err(PgWireError::Auth(format!(
583                    "unexpected auth code {code}, expected {expected_code}"
584                )));
585            }
586            b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
587            _ => {} // Skip other messages
588        }
589    }
590}
591
592/// Get current time as PostgreSQL timestamp (microseconds since 2000-01-01).
593fn current_pg_timestamp() -> i64 {
594    use std::time::{SystemTime, UNIX_EPOCH};
595
596    let now = SystemTime::now()
597        .duration_since(UNIX_EPOCH)
598        .unwrap_or_default();
599
600    let unix_micros = (now.as_secs() as i64) * 1_000_000 + (now.subsec_micros() as i64);
601    unix_micros - PG_EPOCH_MICROS
602}
603
/// Compute PostgreSQL MD5 password hash.
///
/// The result is `"md5"` followed by `md5(hex(md5(password + user)) + salt)`
/// rendered as lowercase hex.
#[cfg(feature = "md5")]
fn postgres_md5(password: &str, user: &str, salt: &[u8; 4]) -> String {
    let hex_digest = |bytes: &[u8]| format!("{:x}", md5::compute(bytes));

    // Stage 1: md5(password + username), as hex text.
    let mut credentials = Vec::with_capacity(password.len() + user.len());
    credentials.extend_from_slice(password.as_bytes());
    credentials.extend_from_slice(user.as_bytes());

    // Stage 2: md5(stage-1 hex + salt).
    let mut salted = hex_digest(&credentials).into_bytes();
    salted.extend_from_slice(salt);

    format!("md5{}", hex_digest(&salted))
}
620
#[cfg(test)]
mod tests {
    //! Unit tests for the pure helpers in this module: SASL mechanism parsing,
    //! pgoutput boundary detection, shared progress tracking, MD5 hashing and
    //! timestamp conversion.

    use super::*;

    #[test]
    fn parse_sasl_mechanisms_single() {
        let data = b"SCRAM-SHA-256\0\0";
        let mechs = parse_sasl_mechanisms(data);
        assert_eq!(mechs, vec!["SCRAM-SHA-256"]);
    }

    #[test]
    fn parse_sasl_mechanisms_multiple() {
        let data = b"SCRAM-SHA-256\0SCRAM-SHA-256-PLUS\0\0";
        let mechs = parse_sasl_mechanisms(data);
        assert_eq!(mechs, vec!["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"]);
    }

    #[test]
    fn parse_sasl_mechanisms_empty() {
        let mechs = parse_sasl_mechanisms(b"\0");
        assert!(mechs.is_empty());
        assert!(parse_sasl_mechanisms(b"").is_empty());
    }

    #[test]
    fn parse_sasl_mechanisms_drops_unterminated_tail() {
        // A name without its NUL terminator is not reported.
        assert!(parse_sasl_mechanisms(b"SCRAM-SHA-256").is_empty());
        assert_eq!(parse_sasl_mechanisms(b"A\0B"), vec!["A"]);
    }

    #[test]
    fn shared_progress_is_monotonic() {
        let progress = SharedProgress::new(Lsn::from_u64(100));
        assert_eq!(progress.load_applied().as_u64(), 100);

        // Lower values are ignored.
        progress.update_applied(Lsn::from_u64(50));
        assert_eq!(progress.load_applied().as_u64(), 100);

        // Higher values advance the position.
        progress.update_applied(Lsn::from_u64(150));
        assert_eq!(progress.load_applied().as_u64(), 150);
    }

    #[test]
    fn parse_pgoutput_boundary_begin() {
        let mut raw = vec![b'B'];
        raw.extend_from_slice(&0x0000_0001_0000_0010u64.to_be_bytes());
        raw.extend_from_slice(&123_456_789i64.to_be_bytes());
        raw.extend_from_slice(&42u32.to_be_bytes());

        match parse_pgoutput_boundary(&Bytes::from(raw)) {
            Ok(Some(ReplicationEvent::Begin {
                final_lsn,
                xid,
                commit_time_micros,
            })) => {
                assert_eq!(final_lsn.as_u64(), 0x0000_0001_0000_0010);
                assert_eq!(xid, 42);
                assert_eq!(commit_time_micros, 123_456_789);
            }
            _ => panic!("expected a Begin event"),
        }
    }

    #[test]
    fn parse_pgoutput_boundary_commit() {
        let mut raw = vec![b'C', 0];
        raw.extend_from_slice(&100u64.to_be_bytes());
        raw.extend_from_slice(&200u64.to_be_bytes());
        raw.extend_from_slice(&777i64.to_be_bytes());

        match parse_pgoutput_boundary(&Bytes::from(raw)) {
            Ok(Some(ReplicationEvent::Commit {
                lsn,
                end_lsn,
                commit_time_micros,
            })) => {
                assert_eq!(lsn.as_u64(), 100);
                assert_eq!(end_lsn.as_u64(), 200);
                assert_eq!(commit_time_micros, 777);
            }
            _ => panic!("expected a Commit event"),
        }
    }

    #[test]
    fn parse_pgoutput_boundary_ignores_other_payloads() {
        assert!(matches!(
            parse_pgoutput_boundary(&Bytes::new()),
            Ok(None)
        ));
        assert!(matches!(
            parse_pgoutput_boundary(&Bytes::from_static(b"I\x00\x01")),
            Ok(None)
        ));
    }

    #[test]
    fn parse_pgoutput_boundary_rejects_truncated_messages() {
        assert!(parse_pgoutput_boundary(&Bytes::from_static(b"B\x00\x01")).is_err());
        assert!(parse_pgoutput_boundary(&Bytes::from_static(b"C")).is_err());
    }

    #[test]
    #[cfg(feature = "md5")]
    fn postgres_md5_known_value() {
        // Test vector: user="md5_user", password="md5_pass", salt=[0x01, 0x02, 0x03, 0x04]
        // Can verify with: SELECT 'md5' || md5(md5('md5_passmd5_user') || E'\\x01020304');
        let hash = postgres_md5("md5_pass", "md5_user", &[0x01, 0x02, 0x03, 0x04]);
        assert!(hash.starts_with("md5"));
        assert_eq!(hash.len(), 35); // "md5" + 32 hex chars

        // The digest part is lowercase hex.
        assert!(hash[3..]
            .bytes()
            .all(|b| b.is_ascii_digit() || (b'a'..=b'f').contains(&b)));

        // Deterministic, and sensitive to the salt.
        assert_eq!(
            hash,
            postgres_md5("md5_pass", "md5_user", &[0x01, 0x02, 0x03, 0x04])
        );
        assert_ne!(
            hash,
            postgres_md5("md5_pass", "md5_user", &[0x04, 0x03, 0x02, 0x01])
        );
    }

    #[test]
    fn current_pg_timestamp_is_positive() {
        // Any time after 2000-01-01 should be positive
        let ts = current_pg_timestamp();
        assert!(ts > 0);
    }
}
659        assert!(ts > 0);
660    }
661}