Skip to main content

pgwire_replication/client/
worker.rs

use bytes::Bytes;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncWrite, BufReader};
use tokio::sync::{mpsc, watch};
use tokio::time::Instant;

use crate::config::ReplicationConfig;
use crate::error::{PgWireError, Result};
use crate::lsn::Lsn;
use crate::protocol::framing::{
    read_backend_message, write_copy_data, write_copy_done, write_password_message, write_query,
    write_startup_message,
};
use crate::protocol::messages::{parse_auth_request, parse_error_response};
use crate::protocol::replication::{
    encode_standby_status_update, parse_copy_data, ReplicationCopyData, PG_EPOCH_MICROS,
};
19
20/// Shared replication progress updated by the consumer and read by the worker.
21///
22/// Stored as an AtomicU64 so progress updates are cheap and monotonic
23/// without async backpressure.
24pub struct SharedProgress {
25    applied: AtomicU64,
26}
27
28impl SharedProgress {
29    pub fn new(start: Lsn) -> Self {
30        Self {
31            applied: AtomicU64::new(start.as_u64()),
32        }
33    }
34
35    #[inline]
36    pub fn load_applied(&self) -> Lsn {
37        Lsn::from_u64(self.applied.load(Ordering::Acquire))
38    }
39
40    /// Monotonic update: if `lsn` is lower than the currently stored applied LSN,
41    /// this is a no-op.
42    #[inline]
43    pub fn update_applied(&self, lsn: Lsn) {
44        let new = lsn.as_u64();
45        let mut cur = self.applied.load(Ordering::Relaxed);
46
47        while new > cur {
48            match self
49                .applied
50                .compare_exchange_weak(cur, new, Ordering::Release, Ordering::Relaxed)
51            {
52                Ok(_) => break,
53                Err(observed) => cur = observed,
54            }
55        }
56    }
57}
58
/// Events emitted by the replication worker.
///
/// When a WAL payload is a pgoutput Begin, Commit, or logical decoding Message,
/// the worker emits the matching structured variant instead of a raw
/// [`ReplicationEvent::XLogData`]. All other WAL payloads are forwarded as
/// `XLogData`.
#[derive(Debug, Clone)]
pub enum ReplicationEvent {
    /// Server heartbeat message.
    KeepAlive {
        /// Current server WAL end position
        wal_end: Lsn,
        /// Whether server requested a reply (already handled internally)
        reply_requested: bool,
        /// Server timestamp (microseconds since 2000-01-01)
        server_time_micros: i64,
    },

    /// Start of a transaction (pgoutput Begin message).
    Begin {
        /// Final LSN of the transaction, as carried in the Begin message.
        /// Compared against `stop_at_lsn` when deciding whether to stop.
        final_lsn: Lsn,
        /// Transaction ID.
        xid: u32,
        /// Commit timestamp from the Begin message (microseconds since
        /// 2000-01-01, per the pgoutput protocol).
        commit_time_micros: i64,
    },

    /// WAL data containing transaction changes.
    XLogData {
        /// WAL position where this data starts
        wal_start: Lsn,
        /// WAL end position (may be 0 for mid-transaction messages)
        wal_end: Lsn,
        /// Server timestamp (microseconds since 2000-01-01)
        server_time_micros: i64,
        /// pgoutput-encoded change data
        data: Bytes,
    },

    /// End of a transaction (pgoutput Commit message).
    Commit {
        /// LSN of the commit, as carried in the Commit message.
        lsn: Lsn,
        /// End LSN of the transaction. Compared against `stop_at_lsn` when
        /// deciding whether to stop.
        end_lsn: Lsn,
        /// Commit timestamp (microseconds since 2000-01-01, per the pgoutput
        /// protocol).
        commit_time_micros: i64,
    },

    /// Logical decoding message emitted via `pg_logical_emit_message()`.
    ///
    /// Transactional messages are delivered only after the enclosing
    /// transaction commits. Non-transactional messages are delivered
    /// immediately.
    Message {
        /// Whether the message was emitted inside a transaction.
        transactional: bool,
        /// LSN of the message in the WAL.
        lsn: Lsn,
        /// Application-defined message prefix (e.g. `"myapp.checkpoint"`).
        prefix: String,
        /// Raw message content bytes.
        content: Bytes,
    },

    /// Emitted when `stop_at_lsn` has been reached.
    ///
    /// After this event, no more events will be emitted. The worker sends
    /// CopyDone and stops streaming.
    StoppedAt {
        /// The LSN that triggered the stop condition: a Begin `final_lsn`, a
        /// Commit `end_lsn`, or otherwise the XLogData `wal_end`.
        reached: Lsn,
    },
}

/// Channel receiver type for replication events.
///
/// Errors travel on the same channel as `Err` values.
pub type ReplicationEventReceiver =
    mpsc::Receiver<std::result::Result<ReplicationEvent, PgWireError>>;
127
128/// Internal worker state.
129pub struct WorkerState {
130    cfg: ReplicationConfig,
131    progress: Arc<SharedProgress>,
132    stop_rx: watch::Receiver<bool>,
133    out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
134}
135
136impl WorkerState {
137    pub fn new(
138        cfg: ReplicationConfig,
139        progress: Arc<SharedProgress>,
140        stop_rx: watch::Receiver<bool>,
141        out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
142    ) -> Self {
143        Self {
144            cfg,
145            progress,
146            stop_rx,
147            out,
148        }
149    }
150
151    /// Run the replication protocol on the given stream.
152    pub async fn run_on_stream<S: AsyncRead + AsyncWrite + Unpin>(
153        &mut self,
154        stream: &mut S,
155    ) -> Result<()> {
156        // Wrap in a 128KB read buffer to batch multiple WAL messages into fewer
157        // recv() syscalls. BufReader delegates AsyncWrite to the inner stream,
158        // so writes (standby status replies, etc.) are unaffected.
159        let mut stream = BufReader::with_capacity(128 * 1024, stream);
160        self.startup(&mut stream).await?;
161        self.authenticate(&mut stream).await?;
162        self.start_replication(&mut stream).await?;
163        self.stream_loop(&mut stream).await
164    }
165
166    /// Send startup message with replication parameters.
167    async fn startup<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<()> {
168        let params = [
169            ("user", self.cfg.user.as_str()),
170            ("database", self.cfg.database.as_str()),
171            ("replication", "database"),
172            ("client_encoding", "UTF8"),
173            ("application_name", "pgwire-replication"),
174        ];
175        write_startup_message(stream, 196608, &params).await
176    }
177
178    /// Start the logical replication stream.
179    async fn start_replication<S: AsyncRead + AsyncWrite + Unpin>(
180        &self,
181        stream: &mut S,
182    ) -> Result<()> {
183        // Escape single quotes in publication name
184        let publication = self.cfg.publication.replace('\'', "''");
185        let sql = format!(
186            "START_REPLICATION SLOT {} LOGICAL {} \
187            (proto_version '1', publication_names '{}', messages 'true')",
188            self.cfg.slot, self.cfg.start_lsn, publication,
189        );
190        write_query(stream, &sql).await?;
191
192        // Wait for CopyBothResponse
193        loop {
194            let msg = read_backend_message(stream).await?;
195            match msg.tag {
196                b'W' => return Ok(()), // CopyBothResponse - ready to stream
197                b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
198                b'N' | b'S' | b'K' => continue, // Notice, ParameterStatus, BackendKeyData
199                _ => continue,
200            }
201        }
202    }
203
204    /// Main replication streaming loop.
205    async fn stream_loop<S: AsyncRead + AsyncWrite + Unpin>(
206        &mut self,
207        stream: &mut S,
208    ) -> Result<()> {
209        let mut last_status_sent = Instant::now() - self.cfg.status_interval;
210        let mut last_applied = self.progress.load_applied();
211
212        loop {
213            // Update applied LSN from client
214            let current_applied = self.progress.load_applied();
215            if current_applied != last_applied {
216                last_applied = current_applied;
217            }
218
219            // Send periodic status feedback
220            if last_status_sent.elapsed() >= self.cfg.status_interval {
221                self.send_feedback(stream, last_applied, false).await?;
222                last_status_sent = Instant::now();
223            }
224
225            // Use select! to check stop signal while waiting for messages.
226            // This makes stop immediately responsive instead of waiting up to
227            // idle_wakeup_interval.
228            let msg = tokio::select! {
229                biased; // Check stop first for immediate responsiveness
230
231                _ = self.stop_rx.changed() => {
232                    if *self.stop_rx.borrow() {
233                        let _ = write_copy_done(stream).await;
234                        return Ok(());
235                    }
236                    // Spurious wake or stop was reset to false; continue loop
237                    continue;
238                }
239
240                msg_result = tokio::time::timeout(
241                    self.cfg.idle_wakeup_interval,
242                    read_backend_message(stream),
243                ) => {
244                    match msg_result {
245                        Ok(res) => res?, // read_backend_message result
246                        Err(_) => {
247                            // No message received; keep the connection alive by sending feedback
248                            let applied = self.progress.load_applied();
249                            last_applied = applied;
250                            self.send_feedback(stream, applied, false).await?;
251                            last_status_sent = Instant::now();
252                            continue;
253                        }
254                    }
255                }
256            };
257
258            match msg.tag {
259                b'd' => {
260                    let should_stop = self
261                        .handle_copy_data(
262                            stream,
263                            msg.payload,
264                            &mut last_applied,
265                            &mut last_status_sent,
266                        )
267                        .await?;
268                    if should_stop {
269                        return Ok(());
270                    }
271                }
272                b'E' => {
273                    let err = PgWireError::Server(parse_error_response(&msg.payload));
274                    return Err(err);
275                }
276                _ => {
277                    // Unexpected in CopyBoth mode, but ignore gracefully
278                }
279            }
280        }
281    }
282
283    /// Handle a CopyData message. Returns true if we should stop.
284    async fn handle_copy_data<S: AsyncRead + AsyncWrite + Unpin>(
285        &mut self,
286        stream: &mut S,
287        payload: Bytes,
288        last_applied: &mut Lsn,
289        last_status_sent: &mut Instant,
290    ) -> Result<bool> {
291        let cd = parse_copy_data(payload)?;
292
293        match cd {
294            ReplicationCopyData::KeepAlive {
295                wal_end,
296                server_time_micros,
297                reply_requested,
298            } => {
299                // Respond immediately if server requests it
300                if reply_requested {
301                    let applied = self.progress.load_applied();
302                    *last_applied = applied;
303                    self.send_feedback(stream, applied, true).await?;
304                    *last_status_sent = Instant::now();
305                }
306
307                self.send_event(Ok(ReplicationEvent::KeepAlive {
308                    wal_end,
309                    reply_requested,
310                    server_time_micros,
311                }))
312                .await;
313
314                Ok(false)
315            }
316            ReplicationCopyData::XLogData {
317                wal_start,
318                wal_end,
319                server_time_micros,
320                data,
321            } => {
322                // If the payload is a pgoutput Begin/Commit message, emit only the boundary event.
323                if let Some(boundary_ev) = parse_pgoutput_boundary(&data)? {
324                    let reached_lsn = match boundary_ev {
325                        ReplicationEvent::Begin { final_lsn, .. } => final_lsn,
326                        ReplicationEvent::Commit { end_lsn, .. } => end_lsn,
327                        _ => wal_end, // should never happen if parser only returns Begin/Commit
328                    };
329
330                    self.send_event(Ok(boundary_ev)).await;
331
332                    // Stop condition (prefer boundary LSN semantics when available)
333                    if let Some(stop_lsn) = self.cfg.stop_at_lsn {
334                        if reached_lsn >= stop_lsn {
335                            self.send_event(Ok(ReplicationEvent::StoppedAt {
336                                reached: reached_lsn,
337                            }))
338                            .await;
339                            let _ = write_copy_done(stream).await;
340                            return Ok(true); // should stop.
341                        }
342                    }
343
344                    return Ok(false);
345                }
346                // Otherwise, emit raw payload
347                // Check stop condition
348                if let Some(stop_lsn) = self.cfg.stop_at_lsn {
349                    if wal_end >= stop_lsn {
350                        // Send final event, then stop signal
351                        self.send_event(Ok(ReplicationEvent::XLogData {
352                            wal_start,
353                            wal_end,
354                            server_time_micros,
355                            data,
356                        }))
357                        .await;
358
359                        self.send_event(Ok(ReplicationEvent::StoppedAt { reached: wal_end }))
360                            .await;
361
362                        let _ = write_copy_done(stream).await;
363                        return Ok(true);
364                    }
365                }
366
367                self.send_event(Ok(ReplicationEvent::XLogData {
368                    wal_start,
369                    wal_end,
370                    server_time_micros,
371                    data,
372                }))
373                .await;
374
375                Ok(false)
376            }
377        }
378    }
379
380    /// Send an event to the client channel.
381    ///
382    /// If the channel is full or closed, we log and continue - the client
383    /// may have stopped listening but we don't want to crash the worker.
384    async fn send_event(&self, event: std::result::Result<ReplicationEvent, PgWireError>) {
385        if self.out.send(event).await.is_err() {
386            tracing::debug!("event channel closed, client may have disconnected");
387        }
388    }
389
390    /// Handle PostgreSQL authentication exchange.
391    async fn authenticate<S: AsyncRead + AsyncWrite + Unpin>(
392        &mut self,
393        stream: &mut S,
394    ) -> Result<()> {
395        loop {
396            let msg = read_backend_message(stream).await?;
397            match msg.tag {
398                b'R' => {
399                    let (code, rest) = parse_auth_request(&msg.payload)?;
400                    self.handle_auth_request(stream, code, rest).await?;
401                }
402                b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
403                b'S' | b'K' => {}      // ParameterStatus, BackendKeyData - ignore
404                b'Z' => return Ok(()), // ReadyForQuery - auth complete
405                _ => {}
406            }
407        }
408    }
409
410    /// Handle a specific authentication request.
411    async fn handle_auth_request<S: AsyncRead + AsyncWrite + Unpin>(
412        &mut self,
413        stream: &mut S,
414        code: i32,
415        data: &[u8],
416    ) -> Result<()> {
417        match code {
418            0 => Ok(()), // AuthenticationOk
419            3 => {
420                // Cleartext password
421                let mut payload = Vec::from(self.cfg.password.as_bytes());
422                payload.push(0);
423                write_password_message(stream, &payload).await
424            }
425            10 => {
426                // SASL (SCRAM-SHA-256)
427                self.auth_scram(stream, data).await
428            }
429            #[cfg(feature = "md5")]
430            5 => {
431                // MD5 password
432                if data.len() != 4 {
433                    return Err(PgWireError::Protocol(
434                        "MD5 auth: expected 4-byte salt".into(),
435                    ));
436                }
437                let mut salt = [0u8; 4];
438                salt.copy_from_slice(&data[..4]);
439
440                let hash = postgres_md5(&self.cfg.password, &self.cfg.user, &salt);
441                let mut payload = hash.into_bytes();
442                payload.push(0);
443                write_password_message(stream, &payload).await
444            }
445            _ => Err(PgWireError::Auth(format!(
446                "unsupported auth method code: {code}"
447            ))),
448        }
449    }
450
451    /// Perform SCRAM-SHA-256 authentication.
452    async fn auth_scram<S: AsyncRead + AsyncWrite + Unpin>(
453        &mut self,
454        stream: &mut S,
455        mechanisms_data: &[u8],
456    ) -> Result<()> {
457        // Parse offered mechanisms
458        let mechanisms = parse_sasl_mechanisms(mechanisms_data);
459
460        if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
461            return Err(PgWireError::Auth(format!(
462                "server doesn't offer SCRAM-SHA-256, available: {mechanisms:?}"
463            )));
464        }
465
466        #[cfg(not(feature = "scram"))]
467        return Err(PgWireError::Auth(
468            "SCRAM authentication required but 'scram' feature not enabled".into(),
469        ));
470
471        #[cfg(feature = "scram")]
472        {
473            use crate::auth::scram::ScramClient;
474
475            let scram = ScramClient::new(&self.cfg.user);
476
477            // Send SASLInitialResponse
478            let mut init = Vec::new();
479            init.extend_from_slice(b"SCRAM-SHA-256\0");
480            init.extend_from_slice(&(scram.client_first.len() as i32).to_be_bytes());
481            init.extend_from_slice(scram.client_first.as_bytes());
482            write_password_message(stream, &init).await?;
483
484            // Receive AuthenticationSASLContinue (code 11)
485            let server_first = read_auth_data(stream, 11).await?;
486            let server_first_str = String::from_utf8_lossy(&server_first);
487
488            // Compute and send client-final
489            let (client_final, auth_message, salted_password) =
490                scram.client_final(&self.cfg.password, &server_first_str)?;
491            write_password_message(stream, client_final.as_bytes()).await?;
492
493            // Receive and verify AuthenticationSASLFinal (code 12)
494            let server_final = read_auth_data(stream, 12).await?;
495            let server_final_str = String::from_utf8_lossy(&server_final);
496            ScramClient::verify_server_final(&server_final_str, &salted_password, &auth_message)?;
497
498            Ok(())
499        }
500    }
501
502    /// Send standby status update to server.
503    async fn send_feedback<S: AsyncWrite + Unpin>(
504        &self,
505        stream: &mut S,
506        applied: Lsn,
507        reply_requested: bool,
508    ) -> Result<()> {
509        let client_time = current_pg_timestamp();
510        let payload = encode_standby_status_update(applied, client_time, reply_requested);
511        write_copy_data(stream, &payload).await
512    }
513}
514
/// Parse the SASL mechanism list carried by an AuthenticationSASL request.
///
/// The list is a sequence of NUL-terminated names ended by an empty name.
/// Parsing stops at the first empty name, or at trailing bytes that lack a
/// terminator (such an unterminated tail is ignored).
fn parse_sasl_mechanisms(data: &[u8]) -> Vec<String> {
    data.split_inclusive(|&byte| byte == 0)
        // Keep only NUL-terminated segments, stripping the terminator.
        .map_while(|segment| match segment.split_last() {
            Some((&0, name)) => Some(name),
            _ => None,
        })
        .take_while(|name| !name.is_empty())
        .map(|name| String::from_utf8_lossy(name).into_owned())
        .collect()
}
534
535fn parse_pgoutput_boundary(data: &Bytes) -> Result<Option<ReplicationEvent>> {
536    if data.is_empty() {
537        return Ok(None);
538    }
539
540    let tag = data[0];
541    let mut p = &data[1..];
542
543    fn take_i8(p: &mut &[u8]) -> Result<i8> {
544        if p.is_empty() {
545            return Err(PgWireError::Protocol("pgoutput: truncated i8".into()));
546        }
547        let v = p[0] as i8;
548        *p = &p[1..];
549        Ok(v)
550    }
551
552    fn take_i32(p: &mut &[u8]) -> Result<i32> {
553        if p.len() < 4 {
554            return Err(PgWireError::Protocol("pgoutput: truncated i32".into()));
555        }
556        let (head, tail) = p.split_at(4);
557        *p = tail;
558        Ok(i32::from_be_bytes(head.try_into().unwrap()))
559    }
560
561    fn take_i64(p: &mut &[u8]) -> Result<i64> {
562        if p.len() < 8 {
563            return Err(PgWireError::Protocol("pgoutput: truncated i64".into()));
564        }
565        let (head, tail) = p.split_at(8);
566        *p = tail;
567        Ok(i64::from_be_bytes(head.try_into().unwrap()))
568    }
569
570    match tag {
571        b'B' => {
572            let final_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
573            let commit_time_micros = take_i64(&mut p)?;
574            let xid = take_i32(&mut p)? as u32;
575
576            Ok(Some(ReplicationEvent::Begin {
577                final_lsn,
578                commit_time_micros,
579                xid,
580            }))
581        }
582        b'C' => {
583            let _flags = take_i8(&mut p)?;
584            let lsn = Lsn::from_u64(take_i64(&mut p)? as u64); // should be safe
585            let end_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
586            let commit_time_micros = take_i64(&mut p)?;
587
588            Ok(Some(ReplicationEvent::Commit {
589                lsn,
590                end_lsn,
591                commit_time_micros,
592            }))
593        }
594        b'M' => {
595            // Logical decoding message (pg_logical_emit_message)
596            // Wire: flags(1) + lsn(8) + prefix(null-terminated) + content_len(4) + content(n)
597            let flags = take_i8(&mut p)?;
598            let transactional = (flags & 1) != 0;
599            let lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
600
601            // Read null-terminated prefix string
602            let prefix_end = p.iter().position(|&b| b == 0).ok_or_else(|| {
603                PgWireError::Protocol("pgoutput Message: missing null terminator for prefix".into())
604            })?;
605            let prefix = String::from_utf8_lossy(&p[..prefix_end]).into_owned();
606            p = &p[prefix_end + 1..]; // advance past null byte
607
608            let content_len = take_i32(&mut p)? as usize;
609            if p.len() < content_len {
610                return Err(PgWireError::Protocol(format!(
611                    "pgoutput Message: expected {} content bytes, got {}",
612                    content_len,
613                    p.len()
614                )));
615            }
616            let content = Bytes::copy_from_slice(&p[..content_len]);
617
618            Ok(Some(ReplicationEvent::Message {
619                transactional,
620                lsn,
621                prefix,
622                content,
623            }))
624        }
625        _ => Ok(None),
626    }
627}
628
629/// Read authentication response data for a specific auth code.
630async fn read_auth_data<S: AsyncRead + AsyncWrite + Unpin>(
631    stream: &mut S,
632    expected_code: i32,
633) -> Result<Vec<u8>> {
634    loop {
635        let msg = read_backend_message(stream).await?;
636        match msg.tag {
637            b'R' => {
638                let (code, data) = parse_auth_request(&msg.payload)?;
639                if code == expected_code {
640                    return Ok(data.to_vec());
641                }
642                return Err(PgWireError::Auth(format!(
643                    "unexpected auth code {code}, expected {expected_code}"
644                )));
645            }
646            b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
647            _ => {} // Skip other messages
648        }
649    }
650}
651
652/// Get current time as PostgreSQL timestamp (microseconds since 2000-01-01).
653fn current_pg_timestamp() -> i64 {
654    use std::time::{SystemTime, UNIX_EPOCH};
655
656    let now = SystemTime::now()
657        .duration_since(UNIX_EPOCH)
658        .unwrap_or_default();
659
660    let unix_micros = (now.as_secs() as i64) * 1_000_000 + (now.subsec_micros() as i64);
661    unix_micros - PG_EPOCH_MICROS
662}
663
/// Compute the PostgreSQL MD5 password response:
/// `"md5" + hex(md5(hex(md5(password + user)) + salt))`.
#[cfg(feature = "md5")]
fn postgres_md5(password: &str, user: &str, salt: &[u8; 4]) -> String {
    // Stage 1: hex digest of the password concatenated with the user name.
    let mut credentials = String::with_capacity(password.len() + user.len());
    credentials.push_str(password);
    credentials.push_str(user);
    let inner_hex = format!("{:x}", md5::compute(credentials.as_bytes()));

    // Stage 2: hex digest of the stage-1 hex string followed by the raw salt.
    let mut salted = Vec::with_capacity(inner_hex.len() + salt.len());
    salted.extend_from_slice(inner_hex.as_bytes());
    salted.extend_from_slice(salt);
    let outer_hex = format!("{:x}", md5::compute(salted.as_slice()));

    format!("md5{outer_hex}")
}
680
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_sasl_mechanisms_single() {
        let data = b"SCRAM-SHA-256\0\0";
        let mechs = parse_sasl_mechanisms(data);
        assert_eq!(mechs, vec!["SCRAM-SHA-256"]);
    }

    #[test]
    fn parse_sasl_mechanisms_multiple() {
        let data = b"SCRAM-SHA-256\0SCRAM-SHA-256-PLUS\0\0";
        let mechs = parse_sasl_mechanisms(data);
        assert_eq!(mechs, vec!["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"]);
    }

    #[test]
    fn parse_sasl_mechanisms_empty() {
        let mechs = parse_sasl_mechanisms(b"\0");
        assert!(mechs.is_empty());
    }

    #[test]
    fn parse_sasl_mechanisms_ignores_unterminated_tail() {
        assert_eq!(parse_sasl_mechanisms(b"SCRAM-SHA-256\0TRUNC"), vec!["SCRAM-SHA-256"]);
        assert!(parse_sasl_mechanisms(b"").is_empty());
    }

    #[test]
    #[cfg(feature = "md5")]
    fn postgres_md5_known_value() {
        // Test vector: user="md5_user", password="md5_pass", salt=[0x01, 0x02, 0x03, 0x04]
        // Can verify with: SELECT 'md5' || md5(md5('md5_passmd5_user') || E'\\x01020304');
        let hash = postgres_md5("md5_pass", "md5_user", &[0x01, 0x02, 0x03, 0x04]);
        assert!(hash.starts_with("md5"));
        assert_eq!(hash.len(), 35); // "md5" + 32 hex chars
    }

    #[test]
    fn current_pg_timestamp_is_positive() {
        // Any time after 2000-01-01 should be positive
        let ts = current_pg_timestamp();
        assert!(ts > 0);
    }

    #[test]
    fn shared_progress_only_moves_forward() {
        let progress = SharedProgress::new(Lsn::from_u64(100));
        progress.update_applied(Lsn::from_u64(40));
        assert_eq!(progress.load_applied().as_u64(), 100);
        progress.update_applied(Lsn::from_u64(250));
        assert_eq!(progress.load_applied().as_u64(), 250);
    }

    #[test]
    fn parse_pgoutput_begin() {
        let mut buf = vec![b'B'];
        buf.extend_from_slice(&0x0000_0001_0000_0010u64.to_be_bytes());
        buf.extend_from_slice(&1_234_567i64.to_be_bytes());
        buf.extend_from_slice(&42u32.to_be_bytes());

        match parse_pgoutput_boundary(&Bytes::from(buf)) {
            Ok(Some(ReplicationEvent::Begin {
                final_lsn,
                xid,
                commit_time_micros,
            })) => {
                assert_eq!(final_lsn.as_u64(), 0x0000_0001_0000_0010);
                assert_eq!(xid, 42);
                assert_eq!(commit_time_micros, 1_234_567);
            }
            _ => panic!("expected a Begin event"),
        }
    }

    #[test]
    fn parse_pgoutput_commit() {
        let mut buf = vec![b'C', 0];
        buf.extend_from_slice(&0x100u64.to_be_bytes());
        buf.extend_from_slice(&0x180u64.to_be_bytes());
        buf.extend_from_slice(&99i64.to_be_bytes());

        match parse_pgoutput_boundary(&Bytes::from(buf)) {
            Ok(Some(ReplicationEvent::Commit {
                lsn,
                end_lsn,
                commit_time_micros,
            })) => {
                assert_eq!(lsn.as_u64(), 0x100);
                assert_eq!(end_lsn.as_u64(), 0x180);
                assert_eq!(commit_time_micros, 99);
            }
            _ => panic!("expected a Commit event"),
        }
    }

    /// Build a pgoutput Message payload with an explicit declared length.
    fn message_payload(declared_len: i32, content: &[u8]) -> Bytes {
        let mut buf = vec![b'M', 1];
        buf.extend_from_slice(&0x200u64.to_be_bytes());
        buf.extend_from_slice(b"app.ckpt\0");
        buf.extend_from_slice(&declared_len.to_be_bytes());
        buf.extend_from_slice(content);
        Bytes::from(buf)
    }

    #[test]
    fn parse_pgoutput_message() {
        match parse_pgoutput_boundary(&message_payload(3, b"xyz")) {
            Ok(Some(ReplicationEvent::Message {
                transactional,
                lsn,
                prefix,
                content,
            })) => {
                assert!(transactional);
                assert_eq!(lsn.as_u64(), 0x200);
                assert_eq!(prefix, "app.ckpt");
                assert_eq!(&content[..], &b"xyz"[..]);
            }
            _ => panic!("expected a Message event"),
        }
    }

    #[test]
    fn parse_pgoutput_message_rejects_bad_lengths() {
        assert!(parse_pgoutput_boundary(&message_payload(10, b"xyz")).is_err());
        assert!(parse_pgoutput_boundary(&message_payload(-1, b"xyz")).is_err());
    }

    #[test]
    fn parse_pgoutput_truncated_begin_is_error() {
        assert!(parse_pgoutput_boundary(&Bytes::from_static(b"B\x00\x01")).is_err());
    }

    #[test]
    fn parse_pgoutput_other_payloads_pass_through() {
        assert!(matches!(parse_pgoutput_boundary(&Bytes::new()), Ok(None)));
        assert!(matches!(
            parse_pgoutput_boundary(&Bytes::from_static(b"I\x00\x00\x00\x01N")),
            Ok(None)
        ));
    }
}