Skip to main content

pgwire_replication/client/
worker.rs

1use bytes::Bytes;
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::sync::Arc;
4use tokio::io::{AsyncRead, AsyncWrite};
5use tokio::sync::{mpsc, watch};
6use tokio::time::Instant;
7
8use crate::config::ReplicationConfig;
9use crate::error::{PgWireError, Result};
10use crate::lsn::Lsn;
11use crate::protocol::framing::{
12    read_backend_message, write_copy_data, write_copy_done, write_password_message, write_query,
13    write_startup_message,
14};
15use crate::protocol::messages::{parse_auth_request, parse_error_response};
16use crate::protocol::replication::{
17    encode_standby_status_update, parse_copy_data, ReplicationCopyData, PG_EPOCH_MICROS,
18};
19
20/// Shared replication progress updated by the consumer and read by the worker.
21///
22/// Stored as an AtomicU64 so progress updates are cheap and monotonic
23/// without async backpressure.
24pub struct SharedProgress {
25    applied: AtomicU64,
26}
27
28impl SharedProgress {
29    pub fn new(start: Lsn) -> Self {
30        Self {
31            applied: AtomicU64::new(start.as_u64()),
32        }
33    }
34
35    #[inline]
36    pub fn load_applied(&self) -> Lsn {
37        Lsn::from_u64(self.applied.load(Ordering::Acquire))
38    }
39
40    /// Monotonic update: if `lsn` is lower than the currently stored applied LSN,
41    /// this is a no-op.
42    #[inline]
43    pub fn update_applied(&self, lsn: Lsn) {
44        let new = lsn.as_u64();
45        let mut cur = self.applied.load(Ordering::Relaxed);
46
47        while new > cur {
48            match self
49                .applied
50                .compare_exchange_weak(cur, new, Ordering::Release, Ordering::Relaxed)
51            {
52                Ok(_) => break,
53                Err(observed) => cur = observed,
54            }
55        }
56    }
57}
58
59/// Events emitted by the replication worker.
60#[derive(Debug, Clone)]
61pub enum ReplicationEvent {
62    /// Server heartbeat message.
63    KeepAlive {
64        /// Current server WAL end position
65        wal_end: Lsn,
66        /// Whether server requested a reply (already handled internally)
67        reply_requested: bool,
68        /// Server timestamp (microseconds since 2000-01-01)
69        server_time_micros: i64,
70    },
71
72    /// Start of a transaction (pgoutput Begin message).
73    Begin {
74        final_lsn: Lsn,
75        xid: u32,
76        commit_time_micros: i64,
77    },
78
79    /// WAL data containing transaction changes.
80    XLogData {
81        /// WAL position where this data starts
82        wal_start: Lsn,
83        /// WAL end position (may be 0 for mid-transaction messages)
84        wal_end: Lsn,
85        /// Server timestamp (microseconds since 2000-01-01)
86        server_time_micros: i64,
87        /// pgoutput-encoded change data
88        data: Bytes,
89    },
90
91    /// End of a transaction (pgoutput Commit message).
92    Commit {
93        lsn: Lsn,
94        end_lsn: Lsn,
95        commit_time_micros: i64,
96    },
97
98    /// Logical decoding message emitted via `pg_logical_emit_message()`.
99    ///
100    /// Transactional messages are delivered only after the enclosing
101    /// transaction commits. Non-transactional messages are delivered
102    /// immediately.
103    Message {
104        /// Whether the message was emitted inside a transaction.
105        transactional: bool,
106        /// LSN of the message in the WAL.
107        lsn: Lsn,
108        /// Application-defined message prefix (e.g. `"myapp.checkpoint"`).
109        prefix: String,
110        /// Raw message content bytes.
111        content: Bytes,
112    },
113
114    /// Emitted when `stop_at_lsn` has been reached.
115    ///
116    /// After this event, no more events will be emitted and the
117    /// replication stream will be closed.
118    StoppedAt {
119        /// The LSN that triggered the stop condition
120        reached: Lsn,
121    },
122}
123
124/// Channel receiver type for replication events.
125pub type ReplicationEventReceiver =
126    mpsc::Receiver<std::result::Result<ReplicationEvent, PgWireError>>;
127
128/// Internal worker state.
129pub struct WorkerState {
130    cfg: ReplicationConfig,
131    progress: Arc<SharedProgress>,
132    stop_rx: watch::Receiver<bool>,
133    out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
134}
135
136impl WorkerState {
137    pub fn new(
138        cfg: ReplicationConfig,
139        progress: Arc<SharedProgress>,
140        stop_rx: watch::Receiver<bool>,
141        out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
142    ) -> Self {
143        Self {
144            cfg,
145            progress,
146            stop_rx,
147            out,
148        }
149    }
150
151    /// Run the replication protocol on the given stream.
152    pub async fn run_on_stream<S: AsyncRead + AsyncWrite + Unpin>(
153        &mut self,
154        stream: &mut S,
155    ) -> Result<()> {
156        self.startup(stream).await?;
157        self.authenticate(stream).await?;
158        self.start_replication(stream).await?;
159        self.stream_loop(stream).await
160    }
161
162    /// Send startup message with replication parameters.
163    async fn startup<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<()> {
164        let params = [
165            ("user", self.cfg.user.as_str()),
166            ("database", self.cfg.database.as_str()),
167            ("replication", "database"),
168            ("client_encoding", "UTF8"),
169            ("application_name", "pgwire-replication"),
170        ];
171        write_startup_message(stream, 196608, &params).await
172    }
173
174    /// Start the logical replication stream.
175    async fn start_replication<S: AsyncRead + AsyncWrite + Unpin>(
176        &self,
177        stream: &mut S,
178    ) -> Result<()> {
179        // Escape single quotes in publication name
180        let publication = self.cfg.publication.replace('\'', "''");
181        let sql = format!(
182            "START_REPLICATION SLOT {} LOGICAL {} \
183            (proto_version '1', publication_names '{}', messages 'true')",
184            self.cfg.slot, self.cfg.start_lsn, publication,
185        );
186        write_query(stream, &sql).await?;
187
188        // Wait for CopyBothResponse
189        loop {
190            let msg = read_backend_message(stream).await?;
191            match msg.tag {
192                b'W' => return Ok(()), // CopyBothResponse - ready to stream
193                b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
194                b'N' | b'S' | b'K' => continue, // Notice, ParameterStatus, BackendKeyData
195                _ => continue,
196            }
197        }
198    }
199
200    /// Main replication streaming loop.
201    async fn stream_loop<S: AsyncRead + AsyncWrite + Unpin>(
202        &mut self,
203        stream: &mut S,
204    ) -> Result<()> {
205        let mut last_status_sent = Instant::now() - self.cfg.status_interval;
206        let mut last_applied = self.progress.load_applied();
207
208        loop {
209            // Update applied LSN from client
210            let current_applied = self.progress.load_applied();
211            if current_applied != last_applied {
212                last_applied = current_applied;
213            }
214
215            // Send periodic status feedback
216            if last_status_sent.elapsed() >= self.cfg.status_interval {
217                self.send_feedback(stream, last_applied, false).await?;
218                last_status_sent = Instant::now();
219            }
220
221            // Use select! to check stop signal while waiting for messages.
222            // This makes stop immediately responsive instead of waiting up to
223            // idle_wakeup_interval.
224            let msg = tokio::select! {
225                biased; // Check stop first for immediate responsiveness
226
227                _ = self.stop_rx.changed() => {
228                    if *self.stop_rx.borrow() {
229                        let _ = write_copy_done(stream).await;
230                        return Ok(());
231                    }
232                    // Spurious wake or stop was reset to false; continue loop
233                    continue;
234                }
235
236                msg_result = tokio::time::timeout(
237                    self.cfg.idle_wakeup_interval,
238                    read_backend_message(stream),
239                ) => {
240                    match msg_result {
241                        Ok(res) => res?, // read_backend_message result
242                        Err(_) => {
243                            // No message received; keep the connection alive by sending feedback
244                            let applied = self.progress.load_applied();
245                            last_applied = applied;
246                            self.send_feedback(stream, applied, false).await?;
247                            last_status_sent = Instant::now();
248                            continue;
249                        }
250                    }
251                }
252            };
253
254            match msg.tag {
255                b'd' => {
256                    let should_stop = self
257                        .handle_copy_data(
258                            stream,
259                            msg.payload,
260                            &mut last_applied,
261                            &mut last_status_sent,
262                        )
263                        .await?;
264                    if should_stop {
265                        return Ok(());
266                    }
267                }
268                b'E' => {
269                    let err = PgWireError::Server(parse_error_response(&msg.payload));
270                    return Err(err);
271                }
272                _ => {
273                    // Unexpected in CopyBoth mode, but ignore gracefully
274                }
275            }
276        }
277    }
278
279    /// Handle a CopyData message. Returns true if we should stop.
280    async fn handle_copy_data<S: AsyncRead + AsyncWrite + Unpin>(
281        &mut self,
282        stream: &mut S,
283        payload: Bytes,
284        last_applied: &mut Lsn,
285        last_status_sent: &mut Instant,
286    ) -> Result<bool> {
287        let cd = parse_copy_data(payload)?;
288
289        match cd {
290            ReplicationCopyData::KeepAlive {
291                wal_end,
292                server_time_micros,
293                reply_requested,
294            } => {
295                // Respond immediately if server requests it
296                if reply_requested {
297                    let applied = self.progress.load_applied();
298                    *last_applied = applied;
299                    self.send_feedback(stream, applied, true).await?;
300                    *last_status_sent = Instant::now();
301                }
302
303                self.send_event(Ok(ReplicationEvent::KeepAlive {
304                    wal_end,
305                    reply_requested,
306                    server_time_micros,
307                }))
308                .await;
309
310                Ok(false)
311            }
312            ReplicationCopyData::XLogData {
313                wal_start,
314                wal_end,
315                server_time_micros,
316                data,
317            } => {
318                // If the payload is a pgoutput Begin/Commit message, emit only the boundary event.
319                if let Some(boundary_ev) = parse_pgoutput_boundary(&data)? {
320                    let reached_lsn = match boundary_ev {
321                        ReplicationEvent::Begin { final_lsn, .. } => final_lsn,
322                        ReplicationEvent::Commit { end_lsn, .. } => end_lsn,
323                        _ => wal_end, // should never happen if parser only returns Begin/Commit
324                    };
325
326                    self.send_event(Ok(boundary_ev)).await;
327
328                    // Stop condition (prefer boundary LSN semantics when available)
329                    if let Some(stop_lsn) = self.cfg.stop_at_lsn {
330                        if reached_lsn >= stop_lsn {
331                            self.send_event(Ok(ReplicationEvent::StoppedAt {
332                                reached: reached_lsn,
333                            }))
334                            .await;
335                            let _ = write_copy_done(stream).await;
336                            return Ok(true); // should stop.
337                        }
338                    }
339
340                    return Ok(false);
341                }
342                // Otherwise, emit raw payload
343                // Check stop condition
344                if let Some(stop_lsn) = self.cfg.stop_at_lsn {
345                    if wal_end >= stop_lsn {
346                        // Send final event, then stop signal
347                        self.send_event(Ok(ReplicationEvent::XLogData {
348                            wal_start,
349                            wal_end,
350                            server_time_micros,
351                            data,
352                        }))
353                        .await;
354
355                        self.send_event(Ok(ReplicationEvent::StoppedAt { reached: wal_end }))
356                            .await;
357
358                        let _ = write_copy_done(stream).await;
359                        return Ok(true);
360                    }
361                }
362
363                self.send_event(Ok(ReplicationEvent::XLogData {
364                    wal_start,
365                    wal_end,
366                    server_time_micros,
367                    data,
368                }))
369                .await;
370
371                Ok(false)
372            }
373        }
374    }
375
376    /// Send an event to the client channel.
377    ///
378    /// If the channel is full or closed, we log and continue - the client
379    /// may have stopped listening but we don't want to crash the worker.
380    async fn send_event(&self, event: std::result::Result<ReplicationEvent, PgWireError>) {
381        if self.out.send(event).await.is_err() {
382            tracing::debug!("event channel closed, client may have disconnected");
383        }
384    }
385
386    /// Handle PostgreSQL authentication exchange.
387    async fn authenticate<S: AsyncRead + AsyncWrite + Unpin>(
388        &mut self,
389        stream: &mut S,
390    ) -> Result<()> {
391        loop {
392            let msg = read_backend_message(stream).await?;
393            match msg.tag {
394                b'R' => {
395                    let (code, rest) = parse_auth_request(&msg.payload)?;
396                    self.handle_auth_request(stream, code, rest).await?;
397                }
398                b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
399                b'S' | b'K' => {}      // ParameterStatus, BackendKeyData - ignore
400                b'Z' => return Ok(()), // ReadyForQuery - auth complete
401                _ => {}
402            }
403        }
404    }
405
406    /// Handle a specific authentication request.
407    async fn handle_auth_request<S: AsyncRead + AsyncWrite + Unpin>(
408        &mut self,
409        stream: &mut S,
410        code: i32,
411        data: &[u8],
412    ) -> Result<()> {
413        match code {
414            0 => Ok(()), // AuthenticationOk
415            3 => {
416                // Cleartext password
417                let mut payload = Vec::from(self.cfg.password.as_bytes());
418                payload.push(0);
419                write_password_message(stream, &payload).await
420            }
421            10 => {
422                // SASL (SCRAM-SHA-256)
423                self.auth_scram(stream, data).await
424            }
425            #[cfg(feature = "md5")]
426            5 => {
427                // MD5 password
428                if data.len() != 4 {
429                    return Err(PgWireError::Protocol(
430                        "MD5 auth: expected 4-byte salt".into(),
431                    ));
432                }
433                let mut salt = [0u8; 4];
434                salt.copy_from_slice(&data[..4]);
435
436                let hash = postgres_md5(&self.cfg.password, &self.cfg.user, &salt);
437                let mut payload = hash.into_bytes();
438                payload.push(0);
439                write_password_message(stream, &payload).await
440            }
441            _ => Err(PgWireError::Auth(format!(
442                "unsupported auth method code: {code}"
443            ))),
444        }
445    }
446
447    /// Perform SCRAM-SHA-256 authentication.
448    async fn auth_scram<S: AsyncRead + AsyncWrite + Unpin>(
449        &mut self,
450        stream: &mut S,
451        mechanisms_data: &[u8],
452    ) -> Result<()> {
453        // Parse offered mechanisms
454        let mechanisms = parse_sasl_mechanisms(mechanisms_data);
455
456        if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
457            return Err(PgWireError::Auth(format!(
458                "server doesn't offer SCRAM-SHA-256, available: {mechanisms:?}"
459            )));
460        }
461
462        #[cfg(not(feature = "scram"))]
463        return Err(PgWireError::Auth(
464            "SCRAM authentication required but 'scram' feature not enabled".into(),
465        ));
466
467        #[cfg(feature = "scram")]
468        {
469            use crate::auth::scram::ScramClient;
470
471            let scram = ScramClient::new(&self.cfg.user);
472
473            // Send SASLInitialResponse
474            let mut init = Vec::new();
475            init.extend_from_slice(b"SCRAM-SHA-256\0");
476            init.extend_from_slice(&(scram.client_first.len() as i32).to_be_bytes());
477            init.extend_from_slice(scram.client_first.as_bytes());
478            write_password_message(stream, &init).await?;
479
480            // Receive AuthenticationSASLContinue (code 11)
481            let server_first = read_auth_data(stream, 11).await?;
482            let server_first_str = String::from_utf8_lossy(&server_first);
483
484            // Compute and send client-final
485            let (client_final, auth_message, salted_password) =
486                scram.client_final(&self.cfg.password, &server_first_str)?;
487            write_password_message(stream, client_final.as_bytes()).await?;
488
489            // Receive and verify AuthenticationSASLFinal (code 12)
490            let server_final = read_auth_data(stream, 12).await?;
491            let server_final_str = String::from_utf8_lossy(&server_final);
492            ScramClient::verify_server_final(&server_final_str, &salted_password, &auth_message)?;
493
494            Ok(())
495        }
496    }
497
498    /// Send standby status update to server.
499    async fn send_feedback<S: AsyncWrite + Unpin>(
500        &self,
501        stream: &mut S,
502        applied: Lsn,
503        reply_requested: bool,
504    ) -> Result<()> {
505        let client_time = current_pg_timestamp();
506        let payload = encode_standby_status_update(applied, client_time, reply_requested);
507        write_copy_data(stream, &payload).await
508    }
509}
510
/// Extracts the mechanism names from a SASL authentication request body.
///
/// The body is a sequence of NUL-terminated names closed by an empty name.
/// Parsing stops at that empty name; bytes without a terminating NUL are
/// ignored. Invalid UTF-8 is replaced lossily.
fn parse_sasl_mechanisms(data: &[u8]) -> Vec<String> {
    let mut segments = data.split(|&byte| byte == 0);

    // `split` always yields one final piece that is not followed by a NUL
    // (the whole input when there is no NUL at all); it is never a complete name.
    segments.next_back();

    segments
        .take_while(|name| !name.is_empty())
        .map(|name| String::from_utf8_lossy(name).into_owned())
        .collect()
}
530
531fn parse_pgoutput_boundary(data: &Bytes) -> Result<Option<ReplicationEvent>> {
532    if data.is_empty() {
533        return Ok(None);
534    }
535
536    let tag = data[0];
537    let mut p = &data[1..];
538
539    fn take_i8(p: &mut &[u8]) -> Result<i8> {
540        if p.is_empty() {
541            return Err(PgWireError::Protocol("pgoutput: truncated i8".into()));
542        }
543        let v = p[0] as i8;
544        *p = &p[1..];
545        Ok(v)
546    }
547
548    fn take_i32(p: &mut &[u8]) -> Result<i32> {
549        if p.len() < 4 {
550            return Err(PgWireError::Protocol("pgoutput: truncated i32".into()));
551        }
552        let (head, tail) = p.split_at(4);
553        *p = tail;
554        Ok(i32::from_be_bytes(head.try_into().unwrap()))
555    }
556
557    fn take_i64(p: &mut &[u8]) -> Result<i64> {
558        if p.len() < 8 {
559            return Err(PgWireError::Protocol("pgoutput: truncated i64".into()));
560        }
561        let (head, tail) = p.split_at(8);
562        *p = tail;
563        Ok(i64::from_be_bytes(head.try_into().unwrap()))
564    }
565
566    match tag {
567        b'B' => {
568            let final_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
569            let commit_time_micros = take_i64(&mut p)?;
570            let xid = take_i32(&mut p)? as u32;
571
572            Ok(Some(ReplicationEvent::Begin {
573                final_lsn,
574                commit_time_micros,
575                xid,
576            }))
577        }
578        b'C' => {
579            let _flags = take_i8(&mut p)?;
580            let lsn = Lsn::from_u64(take_i64(&mut p)? as u64); // should be safe
581            let end_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
582            let commit_time_micros = take_i64(&mut p)?;
583
584            Ok(Some(ReplicationEvent::Commit {
585                lsn,
586                end_lsn,
587                commit_time_micros,
588            }))
589        }
590        b'M' => {
591            // Logical decoding message (pg_logical_emit_message)
592            // Wire: flags(1) + lsn(8) + prefix(null-terminated) + content_len(4) + content(n)
593            let flags = take_i8(&mut p)?;
594            let transactional = (flags & 1) != 0;
595            let lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
596
597            // Read null-terminated prefix string
598            let prefix_end = p.iter().position(|&b| b == 0).ok_or_else(|| {
599                PgWireError::Protocol("pgoutput Message: missing null terminator for prefix".into())
600            })?;
601            let prefix = String::from_utf8_lossy(&p[..prefix_end]).into_owned();
602            p = &p[prefix_end + 1..]; // advance past null byte
603
604            let content_len = take_i32(&mut p)? as usize;
605            if p.len() < content_len {
606                return Err(PgWireError::Protocol(format!(
607                    "pgoutput Message: expected {} content bytes, got {}",
608                    content_len,
609                    p.len()
610                )));
611            }
612            let content = Bytes::copy_from_slice(&p[..content_len]);
613
614            Ok(Some(ReplicationEvent::Message {
615                transactional,
616                lsn,
617                prefix,
618                content,
619            }))
620        }
621        _ => Ok(None),
622    }
623}
624
625/// Read authentication response data for a specific auth code.
626async fn read_auth_data<S: AsyncRead + AsyncWrite + Unpin>(
627    stream: &mut S,
628    expected_code: i32,
629) -> Result<Vec<u8>> {
630    loop {
631        let msg = read_backend_message(stream).await?;
632        match msg.tag {
633            b'R' => {
634                let (code, data) = parse_auth_request(&msg.payload)?;
635                if code == expected_code {
636                    return Ok(data.to_vec());
637                }
638                return Err(PgWireError::Auth(format!(
639                    "unexpected auth code {code}, expected {expected_code}"
640                )));
641            }
642            b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
643            _ => {} // Skip other messages
644        }
645    }
646}
647
648/// Get current time as PostgreSQL timestamp (microseconds since 2000-01-01).
649fn current_pg_timestamp() -> i64 {
650    use std::time::{SystemTime, UNIX_EPOCH};
651
652    let now = SystemTime::now()
653        .duration_since(UNIX_EPOCH)
654        .unwrap_or_default();
655
656    let unix_micros = (now.as_secs() as i64) * 1_000_000 + (now.subsec_micros() as i64);
657    unix_micros - PG_EPOCH_MICROS
658}
659
/// Builds the response to a PostgreSQL MD5 password challenge:
/// `"md5" + md5_hex(md5_hex(password + user) + salt)`.
#[cfg(feature = "md5")]
fn postgres_md5(password: &str, user: &str, salt: &[u8; 4]) -> String {
    // Inner digest over password immediately followed by user name.
    let mut credentials = Vec::with_capacity(password.len() + user.len());
    credentials.extend_from_slice(password.as_bytes());
    credentials.extend_from_slice(user.as_bytes());
    let inner_hex = format!("{:x}", md5::compute(&credentials));

    // Outer digest over the inner hex text followed by the raw salt bytes.
    let mut salted = inner_hex.into_bytes();
    salted.extend_from_slice(salt);

    format!("md5{:x}", md5::compute(&salted))
}
676
#[cfg(test)]
mod tests {
    use super::*;

    /// One mechanism followed by the list terminator.
    #[test]
    fn parse_sasl_mechanisms_single() {
        let parsed = parse_sasl_mechanisms(b"SCRAM-SHA-256\0\0");
        assert_eq!(parsed, ["SCRAM-SHA-256"]);
    }

    /// Two mechanisms are returned in wire order.
    #[test]
    fn parse_sasl_mechanisms_multiple() {
        let parsed = parse_sasl_mechanisms(b"SCRAM-SHA-256\0SCRAM-SHA-256-PLUS\0\0");
        assert_eq!(parsed, ["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"]);
    }

    /// A lone terminator means "no mechanisms".
    #[test]
    fn parse_sasl_mechanisms_empty() {
        assert!(parse_sasl_mechanisms(b"\0").is_empty());
    }

    /// Shape check only: the result is "md5" followed by 32 hex characters.
    /// The value can be cross-checked with:
    /// SELECT 'md5' || md5(md5('md5_passmd5_user') || E'\\x01020304');
    #[test]
    #[cfg(feature = "md5")]
    fn postgres_md5_known_value() {
        let salt = [0x01, 0x02, 0x03, 0x04];
        let hash = postgres_md5("md5_pass", "md5_user", &salt);
        assert!(hash.starts_with("md5"));
        assert_eq!(hash.len(), 35);
    }

    /// Any moment after 2000-01-01 maps to a positive timestamp.
    #[test]
    fn current_pg_timestamp_is_positive() {
        assert!(current_pg_timestamp() > 0);
    }
}