pgwire_replication/client/
worker.rs

1use bytes::Bytes;
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::sync::Arc;
4use tokio::io::{AsyncRead, AsyncWrite};
5use tokio::sync::{mpsc, watch};
6use tokio::time::Instant;
7
8use crate::config::ReplicationConfig;
9use crate::error::{PgWireError, Result};
10use crate::lsn::Lsn;
11use crate::protocol::framing::{
12    read_backend_message, write_copy_data, write_copy_done, write_password_message, write_query,
13    write_startup_message,
14};
15use crate::protocol::messages::{parse_auth_request, parse_error_response};
16use crate::protocol::replication::{
17    encode_standby_status_update, parse_copy_data, ReplicationCopyData, PG_EPOCH_MICROS,
18};
19
20/// Shared replication progress updated by the consumer and read by the worker.
21///
22/// Stored as an AtomicU64 so progress updates are cheap and monotonic
23/// without async backpressure.
24pub struct SharedProgress {
25    applied: AtomicU64,
26}
27
28impl SharedProgress {
29    pub fn new(start: Lsn) -> Self {
30        Self {
31            applied: AtomicU64::new(start.as_u64()),
32        }
33    }
34
35    #[inline]
36    pub fn load_applied(&self) -> Lsn {
37        Lsn::from_u64(self.applied.load(Ordering::Acquire))
38    }
39
40    /// Monotonic update: if `lsn` is lower than the currently stored applied LSN,
41    /// this is a no-op.
42    #[inline]
43    pub fn update_applied(&self, lsn: Lsn) {
44        let new = lsn.as_u64();
45        let mut cur = self.applied.load(Ordering::Relaxed);
46
47        while new > cur {
48            match self
49                .applied
50                .compare_exchange_weak(cur, new, Ordering::Release, Ordering::Relaxed)
51            {
52                Ok(_) => break,
53                Err(observed) => cur = observed,
54            }
55        }
56    }
57}
58
59/// Events emitted by the replication worker.
60#[derive(Debug, Clone)]
61pub enum ReplicationEvent {
62    /// Server heartbeat message.
63    KeepAlive {
64        /// Current server WAL end position
65        wal_end: Lsn,
66        /// Whether server requested a reply (already handled internally)
67        reply_requested: bool,
68        /// Server timestamp (microseconds since 2000-01-01)
69        server_time_micros: i64,
70    },
71
72    /// Start of a transaction (pgoutput Begin message).
73    Begin {
74        final_lsn: Lsn,
75        xid: u32,
76        commit_time_micros: i64,
77    },
78
79    /// WAL data containing transaction changes.
80    XLogData {
81        /// WAL position where this data starts
82        wal_start: Lsn,
83        /// WAL end position (may be 0 for mid-transaction messages)
84        wal_end: Lsn,
85        /// Server timestamp (microseconds since 2000-01-01)
86        server_time_micros: i64,
87        /// pgoutput-encoded change data
88        data: Bytes,
89    },
90
91    /// End of a transaction (pgoutput Commit message).
92    Commit {
93        lsn: Lsn,
94        end_lsn: Lsn,
95        commit_time_micros: i64,
96    },
97
98    /// Emitted when `stop_at_lsn` has been reached.
99    ///
100    /// After this event, no more events will be emitted and the
101    /// replication stream will be closed.
102    StoppedAt {
103        /// The LSN that triggered the stop condition
104        reached: Lsn,
105    },
106}
107
108/// Channel receiver type for replication events.
109pub type ReplicationEventReceiver =
110    mpsc::Receiver<std::result::Result<ReplicationEvent, PgWireError>>;
111
112/// Internal worker state.
113pub struct WorkerState {
114    cfg: ReplicationConfig,
115    progress: Arc<SharedProgress>,
116    stop_rx: watch::Receiver<bool>,
117    out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
118}
119
120impl WorkerState {
121    pub fn new(
122        cfg: ReplicationConfig,
123        progress: Arc<SharedProgress>,
124        stop_rx: watch::Receiver<bool>,
125        out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
126    ) -> Self {
127        Self {
128            cfg,
129            progress,
130            stop_rx,
131            out,
132        }
133    }
134
135    /// Run the replication protocol on the given stream.
136    pub async fn run_on_stream<S: AsyncRead + AsyncWrite + Unpin>(
137        &mut self,
138        stream: &mut S,
139    ) -> Result<()> {
140        self.startup(stream).await?;
141        self.authenticate(stream).await?;
142        self.start_replication(stream).await?;
143        self.stream_loop(stream).await
144    }
145
146    /// Send startup message with replication parameters.
147    async fn startup<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<()> {
148        let params = [
149            ("user", self.cfg.user.as_str()),
150            ("database", self.cfg.database.as_str()),
151            ("replication", "database"),
152            ("client_encoding", "UTF8"),
153            ("application_name", "pgwire-replication"),
154        ];
155        write_startup_message(stream, 196608, &params).await
156    }
157
158    /// Start the logical replication stream.
159    async fn start_replication<S: AsyncRead + AsyncWrite + Unpin>(
160        &self,
161        stream: &mut S,
162    ) -> Result<()> {
163        // Escape single quotes in publication name
164        let publication = self.cfg.publication.replace('\'', "''");
165        let sql = format!(
166            "START_REPLICATION SLOT {} LOGICAL {} (proto_version '1', publication_names '{}')",
167            self.cfg.slot, self.cfg.start_lsn, publication,
168        );
169        write_query(stream, &sql).await?;
170
171        // Wait for CopyBothResponse
172        loop {
173            let msg = read_backend_message(stream).await?;
174            match msg.tag {
175                b'W' => return Ok(()), // CopyBothResponse - ready to stream
176                b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
177                b'N' | b'S' | b'K' => continue, // Notice, ParameterStatus, BackendKeyData
178                _ => continue,
179            }
180        }
181    }
182
183    /// Main replication streaming loop.
184    async fn stream_loop<S: AsyncRead + AsyncWrite + Unpin>(
185        &mut self,
186        stream: &mut S,
187    ) -> Result<()> {
188        let mut last_status_sent = Instant::now() - self.cfg.status_interval;
189        let mut last_applied = self.progress.load_applied();
190
191        loop {
192            // Update applied LSN from client
193            let current_applied = self.progress.load_applied();
194            if current_applied != last_applied {
195                last_applied = current_applied;
196            }
197
198            // Send periodic status feedback
199            if last_status_sent.elapsed() >= self.cfg.status_interval {
200                self.send_feedback(stream, last_applied, false).await?;
201                last_status_sent = Instant::now();
202            }
203
204            // Use select! to check stop signal while waiting for messages.
205            // This makes stop immediately responsive instead of waiting up to
206            // idle_wakeup_interval.
207            let msg = tokio::select! {
208                biased; // Check stop first for immediate responsiveness
209
210                _ = self.stop_rx.changed() => {
211                    if *self.stop_rx.borrow() {
212                        let _ = write_copy_done(stream).await;
213                        return Ok(());
214                    }
215                    // Spurious wake or stop was reset to false; continue loop
216                    continue;
217                }
218
219                msg_result = tokio::time::timeout(
220                    self.cfg.idle_wakeup_interval,
221                    read_backend_message(stream),
222                ) => {
223                    match msg_result {
224                        Ok(res) => res?, // read_backend_message result
225                        Err(_) => {
226                            // No message received; keep the connection alive by sending feedback
227                            let applied = self.progress.load_applied();
228                            last_applied = applied;
229                            self.send_feedback(stream, applied, false).await?;
230                            last_status_sent = Instant::now();
231                            continue;
232                        }
233                    }
234                }
235            };
236
237            match msg.tag {
238                b'd' => {
239                    let should_stop = self
240                        .handle_copy_data(
241                            stream,
242                            msg.payload,
243                            &mut last_applied,
244                            &mut last_status_sent,
245                        )
246                        .await?;
247                    if should_stop {
248                        return Ok(());
249                    }
250                }
251                b'E' => {
252                    let err = PgWireError::Server(parse_error_response(&msg.payload));
253                    return Err(err);
254                }
255                _ => {
256                    // Unexpected in CopyBoth mode, but ignore gracefully
257                }
258            }
259        }
260    }
261
262    /// Handle a CopyData message. Returns true if we should stop.
263    async fn handle_copy_data<S: AsyncRead + AsyncWrite + Unpin>(
264        &mut self,
265        stream: &mut S,
266        payload: Bytes,
267        last_applied: &mut Lsn,
268        last_status_sent: &mut Instant,
269    ) -> Result<bool> {
270        let cd = parse_copy_data(payload)?;
271
272        match cd {
273            ReplicationCopyData::KeepAlive {
274                wal_end,
275                server_time_micros,
276                reply_requested,
277            } => {
278                // Respond immediately if server requests it
279                if reply_requested {
280                    let applied = self.progress.load_applied();
281                    *last_applied = applied;
282                    self.send_feedback(stream, applied, true).await?;
283                    *last_status_sent = Instant::now();
284                }
285
286                self.send_event(Ok(ReplicationEvent::KeepAlive {
287                    wal_end,
288                    reply_requested,
289                    server_time_micros,
290                }))
291                .await;
292
293                Ok(false)
294            }
295            ReplicationCopyData::XLogData {
296                wal_start,
297                wal_end,
298                server_time_micros,
299                data,
300            } => {
301                // If the payload is a pgoutput Begin/Commit message, emit only the boundary event.
302                if let Some(boundary_ev) = parse_pgoutput_boundary(&data)? {
303                    let reached_lsn = match boundary_ev {
304                        ReplicationEvent::Begin { final_lsn, .. } => final_lsn,
305                        ReplicationEvent::Commit { end_lsn, .. } => end_lsn,
306                        _ => wal_end, // should never happen if parser only returns Begin/Commit
307                    };
308
309                    self.send_event(Ok(boundary_ev)).await;
310
311                    // Stop condition (prefer boundary LSN semantics when available)
312                    if let Some(stop_lsn) = self.cfg.stop_at_lsn {
313                        if reached_lsn >= stop_lsn {
314                            self.send_event(Ok(ReplicationEvent::StoppedAt {
315                                reached: reached_lsn,
316                            }))
317                            .await;
318                            let _ = write_copy_done(stream).await;
319                            return Ok(true); // should stop.
320                        }
321                    }
322
323                    return Ok(false);
324                }
325                // Otherwise, emit raw payload
326                // Check stop condition
327                if let Some(stop_lsn) = self.cfg.stop_at_lsn {
328                    if wal_end >= stop_lsn {
329                        // Send final event, then stop signal
330                        self.send_event(Ok(ReplicationEvent::XLogData {
331                            wal_start,
332                            wal_end,
333                            server_time_micros,
334                            data,
335                        }))
336                        .await;
337
338                        self.send_event(Ok(ReplicationEvent::StoppedAt { reached: wal_end }))
339                            .await;
340
341                        let _ = write_copy_done(stream).await;
342                        return Ok(true);
343                    }
344                }
345
346                self.send_event(Ok(ReplicationEvent::XLogData {
347                    wal_start,
348                    wal_end,
349                    server_time_micros,
350                    data,
351                }))
352                .await;
353
354                Ok(false)
355            }
356        }
357    }
358
359    /// Send an event to the client channel.
360    ///
361    /// If the channel is full or closed, we log and continue - the client
362    /// may have stopped listening but we don't want to crash the worker.
363    async fn send_event(&self, event: std::result::Result<ReplicationEvent, PgWireError>) {
364        if self.out.send(event).await.is_err() {
365            tracing::debug!("event channel closed, client may have disconnected");
366        }
367    }
368
369    /// Handle PostgreSQL authentication exchange.
370    async fn authenticate<S: AsyncRead + AsyncWrite + Unpin>(
371        &mut self,
372        stream: &mut S,
373    ) -> Result<()> {
374        loop {
375            let msg = read_backend_message(stream).await?;
376            match msg.tag {
377                b'R' => {
378                    let (code, rest) = parse_auth_request(&msg.payload)?;
379                    self.handle_auth_request(stream, code, rest).await?;
380                }
381                b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
382                b'S' | b'K' => {}      // ParameterStatus, BackendKeyData - ignore
383                b'Z' => return Ok(()), // ReadyForQuery - auth complete
384                _ => {}
385            }
386        }
387    }
388
389    /// Handle a specific authentication request.
390    async fn handle_auth_request<S: AsyncRead + AsyncWrite + Unpin>(
391        &mut self,
392        stream: &mut S,
393        code: i32,
394        data: &[u8],
395    ) -> Result<()> {
396        match code {
397            0 => Ok(()), // AuthenticationOk
398            3 => {
399                // Cleartext password
400                let mut payload = Vec::from(self.cfg.password.as_bytes());
401                payload.push(0);
402                write_password_message(stream, &payload).await
403            }
404            10 => {
405                // SASL (SCRAM-SHA-256)
406                self.auth_scram(stream, data).await
407            }
408            #[cfg(feature = "md5")]
409            5 => {
410                // MD5 password
411                if data.len() != 4 {
412                    return Err(PgWireError::Protocol(
413                        "MD5 auth: expected 4-byte salt".into(),
414                    ));
415                }
416                let mut salt = [0u8; 4];
417                salt.copy_from_slice(&data[..4]);
418
419                let hash = postgres_md5(&self.cfg.password, &self.cfg.user, &salt);
420                let mut payload = hash.into_bytes();
421                payload.push(0);
422                write_password_message(stream, &payload).await
423            }
424            _ => Err(PgWireError::Auth(format!(
425                "unsupported auth method code: {code}"
426            ))),
427        }
428    }
429
430    /// Perform SCRAM-SHA-256 authentication.
431    async fn auth_scram<S: AsyncRead + AsyncWrite + Unpin>(
432        &mut self,
433        stream: &mut S,
434        mechanisms_data: &[u8],
435    ) -> Result<()> {
436        // Parse offered mechanisms
437        let mechanisms = parse_sasl_mechanisms(mechanisms_data);
438
439        if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
440            return Err(PgWireError::Auth(format!(
441                "server doesn't offer SCRAM-SHA-256, available: {mechanisms:?}"
442            )));
443        }
444
445        #[cfg(not(feature = "scram"))]
446        return Err(PgWireError::Auth(
447            "SCRAM authentication required but 'scram' feature not enabled".into(),
448        ));
449
450        #[cfg(feature = "scram")]
451        {
452            use crate::auth::scram::ScramClient;
453
454            let scram = ScramClient::new(&self.cfg.user);
455
456            // Send SASLInitialResponse
457            let mut init = Vec::new();
458            init.extend_from_slice(b"SCRAM-SHA-256\0");
459            init.extend_from_slice(&(scram.client_first.len() as i32).to_be_bytes());
460            init.extend_from_slice(scram.client_first.as_bytes());
461            write_password_message(stream, &init).await?;
462
463            // Receive AuthenticationSASLContinue (code 11)
464            let server_first = read_auth_data(stream, 11).await?;
465            let server_first_str = String::from_utf8_lossy(&server_first);
466
467            // Compute and send client-final
468            let (client_final, auth_message, salted_password) =
469                scram.client_final(&self.cfg.password, &server_first_str)?;
470            write_password_message(stream, client_final.as_bytes()).await?;
471
472            // Receive and verify AuthenticationSASLFinal (code 12)
473            let server_final = read_auth_data(stream, 12).await?;
474            let server_final_str = String::from_utf8_lossy(&server_final);
475            ScramClient::verify_server_final(&server_final_str, &salted_password, &auth_message)?;
476
477            Ok(())
478        }
479    }
480
481    /// Send standby status update to server.
482    async fn send_feedback<S: AsyncWrite + Unpin>(
483        &self,
484        stream: &mut S,
485        applied: Lsn,
486        reply_requested: bool,
487    ) -> Result<()> {
488        let client_time = current_pg_timestamp();
489        let payload = encode_standby_status_update(applied, client_time, reply_requested);
490        write_copy_data(stream, &payload).await
491    }
492}
493
/// Split the NUL-separated SASL mechanism list carried by AuthenticationSASL.
///
/// Parsing stops at the first empty name (the list terminator) or at trailing
/// bytes that are not NUL-terminated; such a tail is ignored.
fn parse_sasl_mechanisms(data: &[u8]) -> Vec<String> {
    let mut names = Vec::new();
    let mut rest = data;

    while let Some(nul_at) = rest.iter().position(|&b| b == 0) {
        let (name, tail) = rest.split_at(nul_at);
        if name.is_empty() {
            break;
        }
        names.push(String::from_utf8_lossy(name).to_string());
        // Skip the NUL itself.
        rest = &tail[1..];
    }

    names
}
513
514fn parse_pgoutput_boundary(data: &Bytes) -> Result<Option<ReplicationEvent>> {
515    if data.is_empty() {
516        return Ok(None);
517    }
518
519    let tag = data[0];
520    let mut p = &data[1..];
521
522    fn take_i8(p: &mut &[u8]) -> Result<i8> {
523        if p.is_empty() {
524            return Err(PgWireError::Protocol("pgoutput: truncated i8".into()));
525        }
526        let v = p[0] as i8;
527        *p = &p[1..];
528        Ok(v)
529    }
530
531    fn take_i32(p: &mut &[u8]) -> Result<i32> {
532        if p.len() < 4 {
533            return Err(PgWireError::Protocol("pgoutput: truncated i32".into()));
534        }
535        let (head, tail) = p.split_at(4);
536        *p = tail;
537        Ok(i32::from_be_bytes(head.try_into().unwrap()))
538    }
539
540    fn take_i64(p: &mut &[u8]) -> Result<i64> {
541        if p.len() < 8 {
542            return Err(PgWireError::Protocol("pgoutput: truncated i64".into()));
543        }
544        let (head, tail) = p.split_at(8);
545        *p = tail;
546        Ok(i64::from_be_bytes(head.try_into().unwrap()))
547    }
548
549    match tag {
550        b'B' => {
551            let final_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
552            let commit_time_micros = take_i64(&mut p)?;
553            let xid = take_i32(&mut p)? as u32;
554
555            Ok(Some(ReplicationEvent::Begin {
556                final_lsn,
557                commit_time_micros,
558                xid,
559            }))
560        }
561        b'C' => {
562            let _flags = take_i8(&mut p)?;
563            let lsn = Lsn::from_u64(take_i64(&mut p)? as u64); // should be safe
564            let end_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
565            let commit_time_micros = take_i64(&mut p)?;
566
567            Ok(Some(ReplicationEvent::Commit {
568                lsn,
569                end_lsn,
570                commit_time_micros,
571            }))
572        }
573        _ => Ok(None),
574    }
575}
576
577/// Read authentication response data for a specific auth code.
578async fn read_auth_data<S: AsyncRead + AsyncWrite + Unpin>(
579    stream: &mut S,
580    expected_code: i32,
581) -> Result<Vec<u8>> {
582    loop {
583        let msg = read_backend_message(stream).await?;
584        match msg.tag {
585            b'R' => {
586                let (code, data) = parse_auth_request(&msg.payload)?;
587                if code == expected_code {
588                    return Ok(data.to_vec());
589                }
590                return Err(PgWireError::Auth(format!(
591                    "unexpected auth code {code}, expected {expected_code}"
592                )));
593            }
594            b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
595            _ => {} // Skip other messages
596        }
597    }
598}
599
600/// Get current time as PostgreSQL timestamp (microseconds since 2000-01-01).
601fn current_pg_timestamp() -> i64 {
602    use std::time::{SystemTime, UNIX_EPOCH};
603
604    let now = SystemTime::now()
605        .duration_since(UNIX_EPOCH)
606        .unwrap_or_default();
607
608    let unix_micros = (now.as_secs() as i64) * 1_000_000 + (now.subsec_micros() as i64);
609    unix_micros - PG_EPOCH_MICROS
610}
611
/// Compute the PostgreSQL MD5 password response:
/// `"md5" || hex(md5(hex(md5(password || user)) || salt))`.
#[cfg(feature = "md5")]
fn postgres_md5(password: &str, user: &str, salt: &[u8; 4]) -> String {
    // Stage 1: hex digest of password followed by user name.
    let mut stage_one = Vec::with_capacity(password.len() + user.len());
    stage_one.extend_from_slice(password.as_bytes());
    stage_one.extend_from_slice(user.as_bytes());
    let inner_hex = format!("{:x}", md5::compute(&stage_one));

    // Stage 2: hex digest of the stage-1 hex string followed by the raw salt bytes.
    let mut stage_two = inner_hex.into_bytes();
    stage_two.extend_from_slice(salt);
    format!("md5{:x}", md5::compute(&stage_two))
}
628
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a pgoutput Begin payload: tag, final LSN, commit timestamp, xid (big-endian).
    fn begin_payload(final_lsn: u64, commit_time_micros: i64, xid: u32) -> Bytes {
        let mut buf = vec![b'B'];
        buf.extend_from_slice(&final_lsn.to_be_bytes());
        buf.extend_from_slice(&commit_time_micros.to_be_bytes());
        buf.extend_from_slice(&xid.to_be_bytes());
        Bytes::from(buf)
    }

    /// Build a pgoutput Commit payload: tag, flags, commit LSN, end LSN, commit timestamp.
    fn commit_payload(lsn: u64, end_lsn: u64, commit_time_micros: i64) -> Bytes {
        let mut buf = vec![b'C', 0u8];
        buf.extend_from_slice(&lsn.to_be_bytes());
        buf.extend_from_slice(&end_lsn.to_be_bytes());
        buf.extend_from_slice(&commit_time_micros.to_be_bytes());
        Bytes::from(buf)
    }

    #[test]
    fn parse_sasl_mechanisms_single() {
        let data = b"SCRAM-SHA-256\0\0";
        let mechs = parse_sasl_mechanisms(data);
        assert_eq!(mechs, vec!["SCRAM-SHA-256"]);
    }

    #[test]
    fn parse_sasl_mechanisms_multiple() {
        let data = b"SCRAM-SHA-256\0SCRAM-SHA-256-PLUS\0\0";
        let mechs = parse_sasl_mechanisms(data);
        assert_eq!(mechs, vec!["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"]);
    }

    #[test]
    fn parse_sasl_mechanisms_empty() {
        let mechs = parse_sasl_mechanisms(b"\0");
        assert!(mechs.is_empty());
    }

    #[test]
    fn parse_sasl_mechanisms_ignores_unterminated_tail() {
        let mechs = parse_sasl_mechanisms(b"SCRAM-SHA-256\0PARTIAL");
        assert_eq!(mechs, vec!["SCRAM-SHA-256"]);
    }

    #[test]
    fn parse_pgoutput_boundary_begin() {
        let payload = begin_payload(0x0000_0016_B374_D848, 1_234_567, u32::MAX);
        match parse_pgoutput_boundary(&payload) {
            Ok(Some(ReplicationEvent::Begin {
                final_lsn,
                xid,
                commit_time_micros,
            })) => {
                assert_eq!(final_lsn.as_u64(), 0x0000_0016_B374_D848);
                assert_eq!(xid, u32::MAX);
                assert_eq!(commit_time_micros, 1_234_567);
            }
            _ => panic!("expected a Begin event"),
        }
    }

    #[test]
    fn parse_pgoutput_boundary_commit() {
        let payload = commit_payload(100, 200, -5);
        match parse_pgoutput_boundary(&payload) {
            Ok(Some(ReplicationEvent::Commit {
                lsn,
                end_lsn,
                commit_time_micros,
            })) => {
                assert_eq!(lsn.as_u64(), 100);
                assert_eq!(end_lsn.as_u64(), 200);
                assert_eq!(commit_time_micros, -5);
            }
            _ => panic!("expected a Commit event"),
        }
    }

    #[test]
    fn parse_pgoutput_boundary_ignores_other_messages() {
        assert!(matches!(parse_pgoutput_boundary(&Bytes::new()), Ok(None)));
        assert!(matches!(
            parse_pgoutput_boundary(&Bytes::from_static(b"I\0\0\0\x01")),
            Ok(None)
        ));
    }

    #[test]
    fn parse_pgoutput_boundary_rejects_truncated() {
        assert!(parse_pgoutput_boundary(&Bytes::from_static(b"B\0\0\0")).is_err());
        assert!(parse_pgoutput_boundary(&Bytes::from_static(b"C")).is_err());
    }

    #[test]
    fn shared_progress_is_monotonic() {
        let progress = SharedProgress::new(Lsn::from_u64(100));
        progress.update_applied(Lsn::from_u64(50));
        assert_eq!(progress.load_applied().as_u64(), 100);
        progress.update_applied(Lsn::from_u64(150));
        assert_eq!(progress.load_applied().as_u64(), 150);
    }

    #[test]
    #[cfg(feature = "md5")]
    fn postgres_md5_known_value() {
        // Test vector: user="md5_user", password="md5_pass", salt=[0x01, 0x02, 0x03, 0x04]
        // Can verify with: SELECT 'md5' || md5(md5('md5_passmd5_user') || E'\\x01020304');
        let hash = postgres_md5("md5_pass", "md5_user", &[0x01, 0x02, 0x03, 0x04]);
        assert!(hash.starts_with("md5"));
        assert_eq!(hash.len(), 35); // "md5" + 32 hex chars
    }

    #[test]
    fn current_pg_timestamp_is_positive() {
        // Any time after 2000-01-01 should be positive
        let ts = current_pg_timestamp();
        assert!(ts > 0);
    }
}