Skip to main content

pgwire_replication/client/
worker.rs

1use bytes::{Bytes, BytesMut};
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::sync::Arc;
4use tokio::io::{AsyncRead, AsyncWrite, BufReader};
5use tokio::sync::{mpsc, watch};
6use tokio::time::Instant;
7
8use crate::config::ReplicationConfig;
9use crate::error::{PgWireError, Result};
10use crate::lsn::Lsn;
11use crate::protocol::framing::{
12    read_backend_message, read_backend_message_into, write_copy_data, write_copy_done,
13    write_password_message, write_query, write_startup_message,
14};
15use crate::protocol::messages::{parse_auth_request, parse_error_response};
16use crate::protocol::replication::{
17    encode_standby_status_update, parse_copy_data, ReplicationCopyData, PG_EPOCH_MICROS,
18};
19
20/// Shared replication progress updated by the consumer and read by the worker.
21///
22/// Stored as an AtomicU64 so progress updates are cheap and monotonic
23/// without async backpressure.
24pub struct SharedProgress {
25    applied: AtomicU64,
26}
27
28impl SharedProgress {
29    pub fn new(start: Lsn) -> Self {
30        Self {
31            applied: AtomicU64::new(start.as_u64()),
32        }
33    }
34
35    #[inline]
36    pub fn load_applied(&self) -> Lsn {
37        Lsn::from_u64(self.applied.load(Ordering::Acquire))
38    }
39
40    /// Monotonic update: if `lsn` is lower than the currently stored applied LSN,
41    /// this is a no-op.
42    #[inline]
43    pub fn update_applied(&self, lsn: Lsn) {
44        let new = lsn.as_u64();
45        let mut cur = self.applied.load(Ordering::Relaxed);
46
47        while new > cur {
48            match self
49                .applied
50                .compare_exchange_weak(cur, new, Ordering::Release, Ordering::Relaxed)
51            {
52                Ok(_) => break,
53                Err(observed) => cur = observed,
54            }
55        }
56    }
57}
58
59/// Events emitted by the replication worker.
60#[derive(Debug, Clone)]
61pub enum ReplicationEvent {
62    /// Server heartbeat message.
63    KeepAlive {
64        /// Current server WAL end position
65        wal_end: Lsn,
66        /// Whether server requested a reply (already handled internally)
67        reply_requested: bool,
68        /// Server timestamp (microseconds since 2000-01-01)
69        server_time_micros: i64,
70    },
71
72    /// Start of a transaction (pgoutput Begin message).
73    Begin {
74        final_lsn: Lsn,
75        xid: u32,
76        commit_time_micros: i64,
77    },
78
79    /// WAL data containing transaction changes.
80    XLogData {
81        /// WAL position where this data starts
82        wal_start: Lsn,
83        /// WAL end position (may be 0 for mid-transaction messages)
84        wal_end: Lsn,
85        /// Server timestamp (microseconds since 2000-01-01)
86        server_time_micros: i64,
87        /// pgoutput-encoded change data
88        data: Bytes,
89    },
90
91    /// End of a transaction (pgoutput Commit message).
92    Commit {
93        lsn: Lsn,
94        end_lsn: Lsn,
95        commit_time_micros: i64,
96    },
97
98    /// Logical decoding message emitted via `pg_logical_emit_message()`.
99    ///
100    /// Transactional messages are delivered only after the enclosing
101    /// transaction commits. Non-transactional messages are delivered
102    /// immediately.
103    Message {
104        /// Whether the message was emitted inside a transaction.
105        transactional: bool,
106        /// LSN of the message in the WAL.
107        lsn: Lsn,
108        /// Application-defined message prefix (e.g. `"myapp.checkpoint"`).
109        prefix: String,
110        /// Raw message content bytes.
111        content: Bytes,
112    },
113
114    /// Emitted when `stop_at_lsn` has been reached.
115    ///
116    /// After this event, no more events will be emitted and the
117    /// replication stream will be closed.
118    StoppedAt {
119        /// The LSN that triggered the stop condition
120        reached: Lsn,
121    },
122}
123
124/// Channel receiver type for replication events.
125pub type ReplicationEventReceiver =
126    mpsc::Receiver<std::result::Result<ReplicationEvent, PgWireError>>;
127
128/// Internal worker state.
129pub struct WorkerState {
130    cfg: ReplicationConfig,
131    progress: Arc<SharedProgress>,
132    stop_rx: watch::Receiver<bool>,
133    out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
134}
135
136impl WorkerState {
137    pub fn new(
138        cfg: ReplicationConfig,
139        progress: Arc<SharedProgress>,
140        stop_rx: watch::Receiver<bool>,
141        out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
142    ) -> Self {
143        Self {
144            cfg,
145            progress,
146            stop_rx,
147            out,
148        }
149    }
150
151    /// Run the replication protocol on the given stream.
152    pub async fn run_on_stream<S: AsyncRead + AsyncWrite + Unpin>(
153        &mut self,
154        stream: &mut S,
155    ) -> Result<()> {
156        // Wrap in a 128KB read buffer to batch multiple WAL messages into fewer
157        // recv() syscalls. BufReader delegates AsyncWrite to the inner stream,
158        // so writes (standby status replies, etc.) are unaffected.
159        let mut stream = BufReader::with_capacity(128 * 1024, stream);
160        self.startup(&mut stream).await?;
161        self.authenticate(&mut stream).await?;
162        self.start_replication(&mut stream).await?;
163        self.stream_loop(&mut stream).await
164    }
165
166    /// Send startup message with replication parameters.
167    async fn startup<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<()> {
168        let params = [
169            ("user", self.cfg.user.as_str()),
170            ("database", self.cfg.database.as_str()),
171            ("replication", "database"),
172            ("client_encoding", "UTF8"),
173            ("application_name", "pgwire-replication"),
174        ];
175        write_startup_message(stream, 196608, &params).await
176    }
177
178    /// Start the logical replication stream.
179    async fn start_replication<S: AsyncRead + AsyncWrite + Unpin>(
180        &self,
181        stream: &mut S,
182    ) -> Result<()> {
183        // Escape single quotes in publication name
184        let publication = self.cfg.publication.replace('\'', "''");
185        let sql = format!(
186            "START_REPLICATION SLOT {} LOGICAL {} \
187            (proto_version '1', publication_names '{}', messages 'true')",
188            self.cfg.slot, self.cfg.start_lsn, publication,
189        );
190        write_query(stream, &sql).await?;
191
192        // Wait for CopyBothResponse
193        loop {
194            let msg = read_backend_message(stream).await?;
195            match msg.tag {
196                b'W' => return Ok(()), // CopyBothResponse - ready to stream
197                b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
198                b'N' | b'S' | b'K' => continue, // Notice, ParameterStatus, BackendKeyData
199                _ => continue,
200            }
201        }
202    }
203
204    /// Main replication streaming loop.
205    ///
206    /// Uses a two-phase approach for throughput:
207    /// 1. **Drain phase**: while the BufReader has buffered data, read messages
208    ///    in a tight loop without `select!` or timeout overhead.
209    /// 2. **Wait phase**: when the buffer is empty, fall back to `select!` with
210    ///    timeout + stop signal to handle idle keepalives and graceful shutdown.
211    async fn stream_loop<S: AsyncRead + AsyncWrite + Unpin>(
212        &mut self,
213        stream: &mut BufReader<S>,
214    ) -> Result<()> {
215        let mut last_status_sent = Instant::now() - self.cfg.status_interval;
216        let mut last_applied = self.progress.load_applied();
217        // Reusable read buffer — avoids per-message allocation.
218        let mut read_buf = BytesMut::with_capacity(4096);
219        // How many messages to process in the tight loop before checking
220        // stop signal and sending periodic status feedback.
221        const DRAIN_BATCH: usize = 256;
222
223        loop {
224            // Update applied LSN from client
225            let current_applied = self.progress.load_applied();
226            if current_applied != last_applied {
227                last_applied = current_applied;
228            }
229
230            // Send periodic status feedback
231            if last_status_sent.elapsed() >= self.cfg.status_interval {
232                self.send_feedback(stream, last_applied, false).await?;
233                last_status_sent = Instant::now();
234            }
235
236            // ── Drain phase: tight loop while BufReader has buffered data ──
237            // The BufReader has a 128KB internal buffer. When the kernel delivers
238            // a large TCP segment, many WAL messages are available without syscalls.
239            // Read them in a tight loop to avoid select!/timeout overhead per message.
240            let mut drained = 0usize;
241            while stream.buffer().len() >= 5 && drained < DRAIN_BATCH {
242                let msg = read_backend_message_into(stream, &mut read_buf).await?;
243                drained += 1;
244                match msg.tag {
245                    b'd' => {
246                        if self
247                            .handle_copy_data(
248                                stream,
249                                msg.payload,
250                                &mut last_applied,
251                                &mut last_status_sent,
252                            )
253                            .await?
254                        {
255                            return Ok(());
256                        }
257                    }
258                    b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
259                    _ => {}
260                }
261            }
262
263            // If we drained messages, loop back to check stop/status before
264            // potentially blocking on the next read.
265            if drained > 0 {
266                // Check stop signal without blocking
267                if self.stop_rx.has_changed().unwrap_or(false) && *self.stop_rx.borrow() {
268                    let _ = write_copy_done(stream).await;
269                    return Ok(());
270                }
271                continue;
272            }
273
274            // ── Wait phase: buffer empty, need to wait for socket data ──
275            let msg = tokio::select! {
276                biased;
277
278                _ = self.stop_rx.changed() => {
279                    if *self.stop_rx.borrow() {
280                        let _ = write_copy_done(stream).await;
281                        return Ok(());
282                    }
283                    continue;
284                }
285
286                msg_result = tokio::time::timeout(
287                    self.cfg.idle_wakeup_interval,
288                    read_backend_message_into(stream, &mut read_buf),
289                ) => {
290                    match msg_result {
291                        Ok(res) => res?,
292                        Err(_) => {
293                            let applied = self.progress.load_applied();
294                            last_applied = applied;
295                            self.send_feedback(stream, applied, false).await?;
296                            last_status_sent = Instant::now();
297                            continue;
298                        }
299                    }
300                }
301            };
302
303            match msg.tag {
304                b'd' => {
305                    if self
306                        .handle_copy_data(
307                            stream,
308                            msg.payload,
309                            &mut last_applied,
310                            &mut last_status_sent,
311                        )
312                        .await?
313                    {
314                        return Ok(());
315                    }
316                }
317                b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
318                _ => {}
319            }
320        }
321    }
322
323    /// Handle a CopyData message. Returns true if we should stop.
324    async fn handle_copy_data<S: AsyncRead + AsyncWrite + Unpin>(
325        &mut self,
326        stream: &mut BufReader<S>,
327        payload: Bytes,
328        last_applied: &mut Lsn,
329        last_status_sent: &mut Instant,
330    ) -> Result<bool> {
331        let cd = parse_copy_data(payload)?;
332
333        match cd {
334            ReplicationCopyData::KeepAlive {
335                wal_end,
336                server_time_micros,
337                reply_requested,
338            } => {
339                // Respond immediately if server requests it
340                if reply_requested {
341                    let applied = self.progress.load_applied();
342                    *last_applied = applied;
343                    self.send_feedback(stream, applied, true).await?;
344                    *last_status_sent = Instant::now();
345                }
346
347                self.send_event(Ok(ReplicationEvent::KeepAlive {
348                    wal_end,
349                    reply_requested,
350                    server_time_micros,
351                }))
352                .await;
353
354                Ok(false)
355            }
356            ReplicationCopyData::XLogData {
357                wal_start,
358                wal_end,
359                server_time_micros,
360                data,
361            } => {
362                // If the payload is a pgoutput Begin/Commit message, emit only the boundary event.
363                if let Some(boundary_ev) = parse_pgoutput_boundary(&data)? {
364                    let reached_lsn = match boundary_ev {
365                        ReplicationEvent::Begin { final_lsn, .. } => final_lsn,
366                        ReplicationEvent::Commit { end_lsn, .. } => end_lsn,
367                        _ => wal_end, // should never happen if parser only returns Begin/Commit
368                    };
369
370                    self.send_event(Ok(boundary_ev)).await;
371
372                    // Stop condition (prefer boundary LSN semantics when available)
373                    if let Some(stop_lsn) = self.cfg.stop_at_lsn {
374                        if reached_lsn >= stop_lsn {
375                            self.send_event(Ok(ReplicationEvent::StoppedAt {
376                                reached: reached_lsn,
377                            }))
378                            .await;
379                            let _ = write_copy_done(stream).await;
380                            return Ok(true); // should stop.
381                        }
382                    }
383
384                    return Ok(false);
385                }
386                // Otherwise, emit raw payload
387                // Check stop condition
388                if let Some(stop_lsn) = self.cfg.stop_at_lsn {
389                    if wal_end >= stop_lsn {
390                        // Send final event, then stop signal
391                        self.send_event(Ok(ReplicationEvent::XLogData {
392                            wal_start,
393                            wal_end,
394                            server_time_micros,
395                            data,
396                        }))
397                        .await;
398
399                        self.send_event(Ok(ReplicationEvent::StoppedAt { reached: wal_end }))
400                            .await;
401
402                        let _ = write_copy_done(stream).await;
403                        return Ok(true);
404                    }
405                }
406
407                self.send_event(Ok(ReplicationEvent::XLogData {
408                    wal_start,
409                    wal_end,
410                    server_time_micros,
411                    data,
412                }))
413                .await;
414
415                Ok(false)
416            }
417        }
418    }
419
420    /// Send an event to the client channel.
421    ///
422    /// If the channel is full or closed, we log and continue - the client
423    /// may have stopped listening but we don't want to crash the worker.
424    async fn send_event(&self, event: std::result::Result<ReplicationEvent, PgWireError>) {
425        if self.out.send(event).await.is_err() {
426            tracing::debug!("event channel closed, client may have disconnected");
427        }
428    }
429
430    /// Handle PostgreSQL authentication exchange.
431    async fn authenticate<S: AsyncRead + AsyncWrite + Unpin>(
432        &mut self,
433        stream: &mut S,
434    ) -> Result<()> {
435        loop {
436            let msg = read_backend_message(stream).await?;
437            match msg.tag {
438                b'R' => {
439                    let (code, rest) = parse_auth_request(&msg.payload)?;
440                    self.handle_auth_request(stream, code, rest).await?;
441                }
442                b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
443                b'S' | b'K' => {}      // ParameterStatus, BackendKeyData - ignore
444                b'Z' => return Ok(()), // ReadyForQuery - auth complete
445                _ => {}
446            }
447        }
448    }
449
450    /// Handle a specific authentication request.
451    async fn handle_auth_request<S: AsyncRead + AsyncWrite + Unpin>(
452        &mut self,
453        stream: &mut S,
454        code: i32,
455        data: &[u8],
456    ) -> Result<()> {
457        match code {
458            0 => Ok(()), // AuthenticationOk
459            3 => {
460                // Cleartext password
461                let mut payload = Vec::from(self.cfg.password.as_bytes());
462                payload.push(0);
463                write_password_message(stream, &payload).await
464            }
465            10 => {
466                // SASL (SCRAM-SHA-256)
467                self.auth_scram(stream, data).await
468            }
469            #[cfg(feature = "md5")]
470            5 => {
471                // MD5 password
472                if data.len() != 4 {
473                    return Err(PgWireError::Protocol(
474                        "MD5 auth: expected 4-byte salt".into(),
475                    ));
476                }
477                let mut salt = [0u8; 4];
478                salt.copy_from_slice(&data[..4]);
479
480                let hash = postgres_md5(&self.cfg.password, &self.cfg.user, &salt);
481                let mut payload = hash.into_bytes();
482                payload.push(0);
483                write_password_message(stream, &payload).await
484            }
485            _ => Err(PgWireError::Auth(format!(
486                "unsupported auth method code: {code}"
487            ))),
488        }
489    }
490
491    /// Perform SCRAM-SHA-256 authentication.
492    async fn auth_scram<S: AsyncRead + AsyncWrite + Unpin>(
493        &mut self,
494        stream: &mut S,
495        mechanisms_data: &[u8],
496    ) -> Result<()> {
497        // Parse offered mechanisms
498        let mechanisms = parse_sasl_mechanisms(mechanisms_data);
499
500        if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
501            return Err(PgWireError::Auth(format!(
502                "server doesn't offer SCRAM-SHA-256, available: {mechanisms:?}"
503            )));
504        }
505
506        #[cfg(not(feature = "scram"))]
507        return Err(PgWireError::Auth(
508            "SCRAM authentication required but 'scram' feature not enabled".into(),
509        ));
510
511        #[cfg(feature = "scram")]
512        {
513            use crate::auth::scram::ScramClient;
514
515            let scram = ScramClient::new(&self.cfg.user);
516
517            // Send SASLInitialResponse
518            let mut init = Vec::new();
519            init.extend_from_slice(b"SCRAM-SHA-256\0");
520            init.extend_from_slice(&(scram.client_first.len() as i32).to_be_bytes());
521            init.extend_from_slice(scram.client_first.as_bytes());
522            write_password_message(stream, &init).await?;
523
524            // Receive AuthenticationSASLContinue (code 11)
525            let server_first = read_auth_data(stream, 11).await?;
526            let server_first_str = String::from_utf8_lossy(&server_first);
527
528            // Compute and send client-final
529            let (client_final, auth_message, salted_password) =
530                scram.client_final(&self.cfg.password, &server_first_str)?;
531            write_password_message(stream, client_final.as_bytes()).await?;
532
533            // Receive and verify AuthenticationSASLFinal (code 12)
534            let server_final = read_auth_data(stream, 12).await?;
535            let server_final_str = String::from_utf8_lossy(&server_final);
536            ScramClient::verify_server_final(&server_final_str, &salted_password, &auth_message)?;
537
538            Ok(())
539        }
540    }
541
542    /// Send standby status update to server.
543    async fn send_feedback<S: AsyncWrite + Unpin>(
544        &self,
545        stream: &mut S,
546        applied: Lsn,
547        reply_requested: bool,
548    ) -> Result<()> {
549        let client_time = current_pg_timestamp();
550        let payload = encode_standby_status_update(applied, client_time, reply_requested);
551        write_copy_data(stream, &payload).await
552    }
553}
554
555/// Parse SASL mechanism list from auth data.
556fn parse_sasl_mechanisms(data: &[u8]) -> Vec<String> {
557    let mut mechanisms = Vec::new();
558    let mut remaining = data;
559
560    while !remaining.is_empty() {
561        if let Some(pos) = remaining.iter().position(|&x| x == 0) {
562            if pos == 0 {
563                break; // Empty string terminates list
564            }
565            mechanisms.push(String::from_utf8_lossy(&remaining[..pos]).to_string());
566            remaining = &remaining[pos + 1..];
567        } else {
568            break;
569        }
570    }
571
572    mechanisms
573}
574
575fn parse_pgoutput_boundary(data: &Bytes) -> Result<Option<ReplicationEvent>> {
576    if data.is_empty() {
577        return Ok(None);
578    }
579
580    let tag = data[0];
581    let mut p = &data[1..];
582
583    fn take_i8(p: &mut &[u8]) -> Result<i8> {
584        if p.is_empty() {
585            return Err(PgWireError::Protocol("pgoutput: truncated i8".into()));
586        }
587        let v = p[0] as i8;
588        *p = &p[1..];
589        Ok(v)
590    }
591
592    fn take_i32(p: &mut &[u8]) -> Result<i32> {
593        if p.len() < 4 {
594            return Err(PgWireError::Protocol("pgoutput: truncated i32".into()));
595        }
596        let (head, tail) = p.split_at(4);
597        *p = tail;
598        Ok(i32::from_be_bytes(head.try_into().unwrap()))
599    }
600
601    fn take_i64(p: &mut &[u8]) -> Result<i64> {
602        if p.len() < 8 {
603            return Err(PgWireError::Protocol("pgoutput: truncated i64".into()));
604        }
605        let (head, tail) = p.split_at(8);
606        *p = tail;
607        Ok(i64::from_be_bytes(head.try_into().unwrap()))
608    }
609
610    match tag {
611        b'B' => {
612            let final_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
613            let commit_time_micros = take_i64(&mut p)?;
614            let xid = take_i32(&mut p)? as u32;
615
616            Ok(Some(ReplicationEvent::Begin {
617                final_lsn,
618                commit_time_micros,
619                xid,
620            }))
621        }
622        b'C' => {
623            let _flags = take_i8(&mut p)?;
624            let lsn = Lsn::from_u64(take_i64(&mut p)? as u64); // should be safe
625            let end_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
626            let commit_time_micros = take_i64(&mut p)?;
627
628            Ok(Some(ReplicationEvent::Commit {
629                lsn,
630                end_lsn,
631                commit_time_micros,
632            }))
633        }
634        b'M' => {
635            // Logical decoding message (pg_logical_emit_message)
636            // Wire: flags(1) + lsn(8) + prefix(null-terminated) + content_len(4) + content(n)
637            let flags = take_i8(&mut p)?;
638            let transactional = (flags & 1) != 0;
639            let lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
640
641            // Read null-terminated prefix string
642            let prefix_end = p.iter().position(|&b| b == 0).ok_or_else(|| {
643                PgWireError::Protocol("pgoutput Message: missing null terminator for prefix".into())
644            })?;
645            let prefix = String::from_utf8_lossy(&p[..prefix_end]).into_owned();
646            p = &p[prefix_end + 1..]; // advance past null byte
647
648            let content_len = take_i32(&mut p)? as usize;
649            if p.len() < content_len {
650                return Err(PgWireError::Protocol(format!(
651                    "pgoutput Message: expected {} content bytes, got {}",
652                    content_len,
653                    p.len()
654                )));
655            }
656            let content = Bytes::copy_from_slice(&p[..content_len]);
657
658            Ok(Some(ReplicationEvent::Message {
659                transactional,
660                lsn,
661                prefix,
662                content,
663            }))
664        }
665        _ => Ok(None),
666    }
667}
668
669/// Read authentication response data for a specific auth code.
670async fn read_auth_data<S: AsyncRead + AsyncWrite + Unpin>(
671    stream: &mut S,
672    expected_code: i32,
673) -> Result<Vec<u8>> {
674    loop {
675        let msg = read_backend_message(stream).await?;
676        match msg.tag {
677            b'R' => {
678                let (code, data) = parse_auth_request(&msg.payload)?;
679                if code == expected_code {
680                    return Ok(data.to_vec());
681                }
682                return Err(PgWireError::Auth(format!(
683                    "unexpected auth code {code}, expected {expected_code}"
684                )));
685            }
686            b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
687            _ => {} // Skip other messages
688        }
689    }
690}
691
692/// Get current time as PostgreSQL timestamp (microseconds since 2000-01-01).
693fn current_pg_timestamp() -> i64 {
694    use std::time::{SystemTime, UNIX_EPOCH};
695
696    let now = SystemTime::now()
697        .duration_since(UNIX_EPOCH)
698        .unwrap_or_default();
699
700    let unix_micros = (now.as_secs() as i64) * 1_000_000 + (now.subsec_micros() as i64);
701    unix_micros - PG_EPOCH_MICROS
702}
703
704/// Compute PostgreSQL MD5 password hash.
705#[cfg(feature = "md5")]
706fn postgres_md5(password: &str, user: &str, salt: &[u8; 4]) -> String {
707    fn md5_hex(data: &[u8]) -> String {
708        format!("{:x}", md5::compute(data))
709    }
710
711    // First hash: md5(password + username)
712    let inner = md5_hex(format!("{password}{user}").as_bytes());
713
714    // Second hash: md5(inner_hash + salt)
715    let mut outer_input = inner.into_bytes();
716    outer_input.extend_from_slice(salt);
717
718    format!("md5{}", md5_hex(&outer_input))
719}
720
721#[cfg(test)]
722mod tests {
723    use super::*;
724
725    #[test]
726    fn parse_sasl_mechanisms_single() {
727        let data = b"SCRAM-SHA-256\0\0";
728        let mechs = parse_sasl_mechanisms(data);
729        assert_eq!(mechs, vec!["SCRAM-SHA-256"]);
730    }
731
732    #[test]
733    fn parse_sasl_mechanisms_multiple() {
734        let data = b"SCRAM-SHA-256\0SCRAM-SHA-256-PLUS\0\0";
735        let mechs = parse_sasl_mechanisms(data);
736        assert_eq!(mechs, vec!["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"]);
737    }
738
739    #[test]
740    fn parse_sasl_mechanisms_empty() {
741        let mechs = parse_sasl_mechanisms(b"\0");
742        assert!(mechs.is_empty());
743    }
744
745    #[test]
746    #[cfg(feature = "md5")]
747    fn postgres_md5_known_value() {
748        // Test vector: user="md5_user", password="md5_pass", salt=[0x01, 0x02, 0x03, 0x04]
749        // Can verify with: SELECT 'md5' || md5(md5('md5_passmd5_user') || E'\\x01020304');
750        let hash = postgres_md5("md5_pass", "md5_user", &[0x01, 0x02, 0x03, 0x04]);
751        assert!(hash.starts_with("md5"));
752        assert_eq!(hash.len(), 35); // "md5" + 32 hex chars
753    }
754
755    #[test]
756    fn current_pg_timestamp_is_positive() {
757        // Any time after 2000-01-01 should be positive
758        let ts = current_pg_timestamp();
759        assert!(ts > 0);
760    }
761}