Skip to main content

pgwire_replication/client/
worker.rs

1use bytes::Bytes;
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::sync::Arc;
4use tokio::io::{AsyncRead, AsyncWrite, BufReader};
5use tokio::sync::{mpsc, watch};
6use tokio::time::Instant;
7
8use crate::config::ReplicationConfig;
9use crate::error::{PgWireError, Result};
10use crate::lsn::Lsn;
11use crate::protocol::framing::{
12    read_backend_message, write_copy_data, write_copy_done, write_password_message, write_query,
13    write_startup_message, MessageReader,
14};
15use crate::protocol::messages::{parse_auth_request, parse_error_response};
16use crate::protocol::replication::{
17    encode_standby_status_update, parse_copy_data, ReplicationCopyData, PG_EPOCH_MICROS,
18};
19
20/// Shared replication progress updated by the consumer and read by the worker.
21///
22/// Stored as an AtomicU64 so progress updates are cheap and monotonic
23/// without async backpressure.
24pub struct SharedProgress {
25    applied: AtomicU64,
26}
27
28impl SharedProgress {
29    pub fn new(start: Lsn) -> Self {
30        Self {
31            applied: AtomicU64::new(start.as_u64()),
32        }
33    }
34
35    #[inline]
36    pub fn load_applied(&self) -> Lsn {
37        Lsn::from_u64(self.applied.load(Ordering::Acquire))
38    }
39
40    /// Monotonic update: if `lsn` is lower than the currently stored applied LSN,
41    /// this is a no-op.
42    #[inline]
43    pub fn update_applied(&self, lsn: Lsn) {
44        let new = lsn.as_u64();
45        let mut cur = self.applied.load(Ordering::Relaxed);
46
47        while new > cur {
48            match self
49                .applied
50                .compare_exchange_weak(cur, new, Ordering::Release, Ordering::Relaxed)
51            {
52                Ok(_) => break,
53                Err(observed) => cur = observed,
54            }
55        }
56    }
57}
58
/// Events emitted by the replication worker.
///
/// The worker sends these one at a time, in the order it reads the
/// corresponding messages from the replication stream.
#[derive(Debug, Clone)]
pub enum ReplicationEvent {
    /// Server heartbeat message.
    KeepAlive {
        /// Current server WAL end position
        wal_end: Lsn,
        /// Whether server requested a reply (already handled internally)
        reply_requested: bool,
        /// Server timestamp (microseconds since 2000-01-01)
        server_time_micros: i64,
    },

    /// Start of a transaction (pgoutput Begin message).
    ///
    /// When `stop_at_lsn` is configured, the worker compares `final_lsn`
    /// against it after emitting this event.
    Begin {
        /// LSN carried by the Begin message (the transaction's final LSN).
        final_lsn: Lsn,
        /// Transaction id carried by the Begin message.
        xid: u32,
        /// Commit timestamp from the Begin message. NOTE(review): per the
        /// pgoutput docs this is microseconds since 2000-01-01 — confirm.
        commit_time_micros: i64,
    },

    /// WAL data containing transaction changes.
    ///
    /// Payloads that decode as pgoutput Begin, Commit or Message are delivered
    /// as those variants instead, not as `XLogData`.
    XLogData {
        /// WAL position where this data starts
        wal_start: Lsn,
        /// WAL end position (may be 0 for mid-transaction messages)
        wal_end: Lsn,
        /// Server timestamp (microseconds since 2000-01-01)
        server_time_micros: i64,
        /// pgoutput-encoded change data
        data: Bytes,
    },

    /// End of a transaction (pgoutput Commit message).
    ///
    /// When `stop_at_lsn` is configured, the worker compares `end_lsn`
    /// against it after emitting this event.
    Commit {
        /// Commit LSN carried by the Commit message.
        lsn: Lsn,
        /// End LSN of the transaction carried by the Commit message.
        end_lsn: Lsn,
        /// Commit timestamp. NOTE(review): per the pgoutput docs this is
        /// microseconds since 2000-01-01 — confirm.
        commit_time_micros: i64,
    },

    /// Logical decoding message emitted via `pg_logical_emit_message()`.
    ///
    /// Transactional messages are delivered only after the enclosing
    /// transaction commits. Non-transactional messages are delivered
    /// immediately.
    Message {
        /// Whether the message was emitted inside a transaction.
        transactional: bool,
        /// LSN of the message in the WAL.
        lsn: Lsn,
        /// Application-defined message prefix (e.g. `"myapp.checkpoint"`).
        prefix: String,
        /// Raw message content bytes.
        content: Bytes,
    },

    /// Emitted when `stop_at_lsn` has been reached.
    ///
    /// After this event, no more events will be emitted and the
    /// replication stream will be closed.
    StoppedAt {
        /// The LSN that triggered the stop condition: a Begin's `final_lsn`,
        /// a Commit's `end_lsn`, or the `wal_end` of the frame that carried a
        /// Message or raw XLogData.
        reached: Lsn,
    },
}
123
/// Receiving half of the worker's event channel.
///
/// Each item is either a [`ReplicationEvent`] or a [`PgWireError`].
pub type ReplicationEventReceiver =
    mpsc::Receiver<std::result::Result<ReplicationEvent, PgWireError>>;
127
/// Internal worker state.
///
/// Holds everything the replication worker needs: its configuration, the
/// applied-LSN tracker shared with the consumer, the stop signal, and the
/// channel events are delivered on.
pub struct WorkerState {
    /// Credentials, slot/publication, start/stop LSNs and timing intervals.
    cfg: ReplicationConfig,
    /// Applied LSN published by the consumer; reported in standby status updates.
    progress: Arc<SharedProgress>,
    /// Graceful-shutdown signal; `true` asks the worker to end the stream.
    stop_rx: watch::Receiver<bool>,
    /// Destination for events delivered to the consumer.
    out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
}
135
136impl WorkerState {
137    pub fn new(
138        cfg: ReplicationConfig,
139        progress: Arc<SharedProgress>,
140        stop_rx: watch::Receiver<bool>,
141        out: mpsc::Sender<std::result::Result<ReplicationEvent, PgWireError>>,
142    ) -> Self {
143        Self {
144            cfg,
145            progress,
146            stop_rx,
147            out,
148        }
149    }
150
151    /// Run the replication protocol on the given stream.
152    pub async fn run_on_stream<S: AsyncRead + AsyncWrite + Unpin>(
153        &mut self,
154        stream: &mut S,
155    ) -> Result<()> {
156        // Wrap in a 128KB read buffer to batch multiple WAL messages into fewer
157        // recv() syscalls. BufReader delegates AsyncWrite to the inner stream,
158        // so writes (standby status replies, etc.) are unaffected.
159        let mut stream = BufReader::with_capacity(128 * 1024, stream);
160        self.startup(&mut stream).await?;
161        self.authenticate(&mut stream).await?;
162        self.start_replication(&mut stream).await?;
163        self.stream_loop(&mut stream).await
164    }
165
166    /// Send startup message with replication parameters.
167    async fn startup<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<()> {
168        let params = [
169            ("user", self.cfg.user.as_str()),
170            ("database", self.cfg.database.as_str()),
171            ("replication", "database"),
172            ("client_encoding", "UTF8"),
173            ("application_name", "pgwire-replication"),
174        ];
175        write_startup_message(stream, 196608, &params).await
176    }
177
178    /// Start the logical replication stream.
179    async fn start_replication<S: AsyncRead + AsyncWrite + Unpin>(
180        &self,
181        stream: &mut S,
182    ) -> Result<()> {
183        // Escape single quotes in publication name
184        let publication = self.cfg.publication.replace('\'', "''");
185        let sql = format!(
186            "START_REPLICATION SLOT {} LOGICAL {} \
187            (proto_version '1', publication_names '{}', messages 'true')",
188            self.cfg.slot, self.cfg.start_lsn, publication,
189        );
190        write_query(stream, &sql).await?;
191
192        // Wait for CopyBothResponse
193        loop {
194            let msg = read_backend_message(stream).await?;
195            match msg.tag {
196                b'W' => return Ok(()), // CopyBothResponse - ready to stream
197                b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
198                b'N' | b'S' | b'K' => continue, // Notice, ParameterStatus, BackendKeyData
199                _ => continue,
200            }
201        }
202    }
203
204    /// Main replication streaming loop.
205    ///
206    /// Uses a two-phase approach for throughput:
207    /// 1. **Drain phase**: while the BufReader has buffered data, read messages
208    ///    in a tight loop without `select!` or timeout overhead.
209    /// 2. **Wait phase**: when the buffer is empty, fall back to `select!` with
210    ///    timeout + stop signal to handle idle keepalives and graceful shutdown.
211    ///
212    /// Reads use [`MessageReader`], which preserves partial-read state across
213    /// dropped futures so the wait-phase `select!` is cancellation-safe.
214    async fn stream_loop<S: AsyncRead + AsyncWrite + Unpin>(
215        &mut self,
216        stream: &mut BufReader<S>,
217    ) -> Result<()> {
218        let mut last_status_sent = Instant::now() - self.cfg.status_interval;
219        let mut last_applied = self.progress.load_applied();
220        // Cancellation-safe message reader, partial reads survive dropped futures.
221        let mut reader = MessageReader::new();
222        // How many messages to process in the tight loop before checking
223        // stop signal and sending periodic status feedback.
224        const DRAIN_BATCH: usize = 256;
225
226        loop {
227            // Update applied LSN from client
228            let current_applied = self.progress.load_applied();
229            if current_applied != last_applied {
230                last_applied = current_applied;
231            }
232
233            // Send periodic status feedback
234            if last_status_sent.elapsed() >= self.cfg.status_interval {
235                self.send_feedback(stream, last_applied, false).await?;
236                last_status_sent = Instant::now();
237            }
238
239            // ── Drain phase: tight loop while BufReader has buffered data ──
240            // The BufReader has a 128KB internal buffer. When the kernel delivers
241            // a large TCP segment, many WAL messages are available without syscalls.
242            // Read them in a tight loop to avoid select!/timeout overhead per message.
243            let mut drained = 0usize;
244            while stream.buffer().len() >= 5 && drained < DRAIN_BATCH {
245                let msg = reader.read(stream).await?;
246                drained += 1;
247                if msg.tag == b'E' {
248                    return Err(PgWireError::Server(parse_error_response(&msg.payload)));
249                }
250                if msg.tag == b'd'
251                    && self
252                        .handle_copy_data(
253                            stream,
254                            msg.payload,
255                            &mut last_applied,
256                            &mut last_status_sent,
257                        )
258                        .await?
259                {
260                    return Ok(());
261                }
262            }
263
264            // If we drained messages, loop back to check stop/status before
265            // potentially blocking on the next read.
266            if drained > 0 {
267                // Check stop signal without blocking
268                if self.stop_rx.has_changed().unwrap_or(false) && *self.stop_rx.borrow() {
269                    let _ = write_copy_done(stream).await;
270                    return Ok(());
271                }
272                continue;
273            }
274
275            // ── Wait phase: buffer empty, need to wait for socket data ──
276            //
277            // Both `stop_rx.changed()` and the timeout can drop the read future
278            // mid-message. `MessageReader::read` is cancellation-safe — partial
279            // header/payload state lives on `reader` and is preserved across the
280            // drop, so the next iteration resumes the read without losing bytes.
281            let msg = tokio::select! {
282                biased;
283
284                _ = self.stop_rx.changed() => {
285                    if *self.stop_rx.borrow() {
286                        let _ = write_copy_done(stream).await;
287                        return Ok(());
288                    }
289                    continue;
290                }
291
292                msg_result = tokio::time::timeout(
293                    self.cfg.idle_wakeup_interval,
294                    reader.read(stream),
295                ) => {
296                    match msg_result {
297                        Ok(res) => res?,
298                        Err(_) => {
299                            let applied = self.progress.load_applied();
300                            last_applied = applied;
301                            self.send_feedback(stream, applied, false).await?;
302                            last_status_sent = Instant::now();
303                            continue;
304                        }
305                    }
306                }
307            };
308
309            if msg.tag == b'E' {
310                return Err(PgWireError::Server(parse_error_response(&msg.payload)));
311            }
312            if msg.tag == b'd'
313                && self
314                    .handle_copy_data(
315                        stream,
316                        msg.payload,
317                        &mut last_applied,
318                        &mut last_status_sent,
319                    )
320                    .await?
321            {
322                return Ok(());
323            }
324        }
325    }
326
327    /// Handle a CopyData message. Returns true if we should stop.
328    async fn handle_copy_data<S: AsyncRead + AsyncWrite + Unpin>(
329        &mut self,
330        stream: &mut BufReader<S>,
331        payload: Bytes,
332        last_applied: &mut Lsn,
333        last_status_sent: &mut Instant,
334    ) -> Result<bool> {
335        let cd = parse_copy_data(payload)?;
336
337        match cd {
338            ReplicationCopyData::KeepAlive {
339                wal_end,
340                server_time_micros,
341                reply_requested,
342            } => {
343                // Respond immediately if server requests it
344                if reply_requested {
345                    let applied = self.progress.load_applied();
346                    *last_applied = applied;
347                    self.send_feedback(stream, applied, true).await?;
348                    *last_status_sent = Instant::now();
349                }
350
351                self.send_event(Ok(ReplicationEvent::KeepAlive {
352                    wal_end,
353                    reply_requested,
354                    server_time_micros,
355                }))
356                .await;
357
358                Ok(false)
359            }
360            ReplicationCopyData::XLogData {
361                wal_start,
362                wal_end,
363                server_time_micros,
364                data,
365            } => {
366                // If the payload is a pgoutput Begin/Commit message, emit only the boundary event.
367                if let Some(boundary_ev) = parse_pgoutput_boundary(&data)? {
368                    let reached_lsn = match boundary_ev {
369                        ReplicationEvent::Begin { final_lsn, .. } => final_lsn,
370                        ReplicationEvent::Commit { end_lsn, .. } => end_lsn,
371                        _ => wal_end, // should never happen if parser only returns Begin/Commit
372                    };
373
374                    self.send_event(Ok(boundary_ev)).await;
375
376                    // Stop condition (prefer boundary LSN semantics when available)
377                    if let Some(stop_lsn) = self.cfg.stop_at_lsn {
378                        if reached_lsn >= stop_lsn {
379                            self.send_event(Ok(ReplicationEvent::StoppedAt {
380                                reached: reached_lsn,
381                            }))
382                            .await;
383                            let _ = write_copy_done(stream).await;
384                            return Ok(true); // should stop.
385                        }
386                    }
387
388                    return Ok(false);
389                }
390                // Otherwise, emit raw payload
391                // Check stop condition
392                if let Some(stop_lsn) = self.cfg.stop_at_lsn {
393                    if wal_end >= stop_lsn {
394                        // Send final event, then stop signal
395                        self.send_event(Ok(ReplicationEvent::XLogData {
396                            wal_start,
397                            wal_end,
398                            server_time_micros,
399                            data,
400                        }))
401                        .await;
402
403                        self.send_event(Ok(ReplicationEvent::StoppedAt { reached: wal_end }))
404                            .await;
405
406                        let _ = write_copy_done(stream).await;
407                        return Ok(true);
408                    }
409                }
410
411                self.send_event(Ok(ReplicationEvent::XLogData {
412                    wal_start,
413                    wal_end,
414                    server_time_micros,
415                    data,
416                }))
417                .await;
418
419                Ok(false)
420            }
421        }
422    }
423
424    /// Send an event to the client channel.
425    ///
426    /// If the channel is full or closed, we log and continue - the client
427    /// may have stopped listening but we don't want to crash the worker.
428    async fn send_event(&self, event: std::result::Result<ReplicationEvent, PgWireError>) {
429        if self.out.send(event).await.is_err() {
430            tracing::debug!("event channel closed, client may have disconnected");
431        }
432    }
433
434    /// Handle PostgreSQL authentication exchange.
435    async fn authenticate<S: AsyncRead + AsyncWrite + Unpin>(
436        &mut self,
437        stream: &mut S,
438    ) -> Result<()> {
439        loop {
440            let msg = read_backend_message(stream).await?;
441            match msg.tag {
442                b'R' => {
443                    let (code, rest) = parse_auth_request(&msg.payload)?;
444                    self.handle_auth_request(stream, code, rest).await?;
445                }
446                b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
447                b'S' | b'K' => {}      // ParameterStatus, BackendKeyData - ignore
448                b'Z' => return Ok(()), // ReadyForQuery - auth complete
449                _ => {}
450            }
451        }
452    }
453
454    /// Handle a specific authentication request.
455    async fn handle_auth_request<S: AsyncRead + AsyncWrite + Unpin>(
456        &mut self,
457        stream: &mut S,
458        code: i32,
459        data: &[u8],
460    ) -> Result<()> {
461        match code {
462            0 => Ok(()), // AuthenticationOk
463            3 => {
464                // Cleartext password
465                let mut payload = Vec::from(self.cfg.password.as_bytes());
466                payload.push(0);
467                write_password_message(stream, &payload).await
468            }
469            10 => {
470                // SASL (SCRAM-SHA-256)
471                self.auth_scram(stream, data).await
472            }
473            #[cfg(feature = "md5")]
474            5 => {
475                // MD5 password
476                if data.len() != 4 {
477                    return Err(PgWireError::Protocol(
478                        "MD5 auth: expected 4-byte salt".into(),
479                    ));
480                }
481                let mut salt = [0u8; 4];
482                salt.copy_from_slice(&data[..4]);
483
484                let hash = postgres_md5(&self.cfg.password, &self.cfg.user, &salt);
485                let mut payload = hash.into_bytes();
486                payload.push(0);
487                write_password_message(stream, &payload).await
488            }
489            _ => Err(PgWireError::Auth(format!(
490                "unsupported auth method code: {code}"
491            ))),
492        }
493    }
494
495    /// Perform SCRAM-SHA-256 authentication.
496    async fn auth_scram<S: AsyncRead + AsyncWrite + Unpin>(
497        &mut self,
498        stream: &mut S,
499        mechanisms_data: &[u8],
500    ) -> Result<()> {
501        // Parse offered mechanisms
502        let mechanisms = parse_sasl_mechanisms(mechanisms_data);
503
504        if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
505            return Err(PgWireError::Auth(format!(
506                "server doesn't offer SCRAM-SHA-256, available: {mechanisms:?}"
507            )));
508        }
509
510        #[cfg(not(feature = "scram"))]
511        return Err(PgWireError::Auth(
512            "SCRAM authentication required but 'scram' feature not enabled".into(),
513        ));
514
515        #[cfg(feature = "scram")]
516        {
517            use crate::auth::scram::ScramClient;
518
519            let scram = ScramClient::new(&self.cfg.user);
520
521            // Send SASLInitialResponse
522            let mut init = Vec::new();
523            init.extend_from_slice(b"SCRAM-SHA-256\0");
524            init.extend_from_slice(&(scram.client_first.len() as i32).to_be_bytes());
525            init.extend_from_slice(scram.client_first.as_bytes());
526            write_password_message(stream, &init).await?;
527
528            // Receive AuthenticationSASLContinue (code 11)
529            let server_first = read_auth_data(stream, 11).await?;
530            let server_first_str = String::from_utf8_lossy(&server_first);
531
532            // Compute and send client-final
533            let (client_final, auth_message, salted_password) =
534                scram.client_final(&self.cfg.password, &server_first_str)?;
535            write_password_message(stream, client_final.as_bytes()).await?;
536
537            // Receive and verify AuthenticationSASLFinal (code 12)
538            let server_final = read_auth_data(stream, 12).await?;
539            let server_final_str = String::from_utf8_lossy(&server_final);
540            ScramClient::verify_server_final(&server_final_str, &salted_password, &auth_message)?;
541
542            Ok(())
543        }
544    }
545
546    /// Send standby status update to server.
547    async fn send_feedback<S: AsyncWrite + Unpin>(
548        &self,
549        stream: &mut S,
550        applied: Lsn,
551        reply_requested: bool,
552    ) -> Result<()> {
553        let client_time = current_pg_timestamp();
554        let payload = encode_standby_status_update(applied, client_time, reply_requested);
555        write_copy_data(stream, &payload).await
556    }
557}
558
/// Parse the SASL mechanism list carried by an AuthenticationSASL request.
///
/// The list is a sequence of NUL-terminated names ended by an empty name.
/// Parsing stops at the first empty name, and a trailing fragment without a
/// NUL terminator is dropped. Names are decoded lossily as UTF-8.
fn parse_sasl_mechanisms(data: &[u8]) -> Vec<String> {
    // Only the bytes before the last NUL can belong to terminated names.
    let terminated: &[u8] = match data.iter().rposition(|&byte| byte == 0) {
        Some(last_nul) => &data[..last_nul],
        None => &[],
    };

    terminated
        .split(|&byte| byte == 0)
        .take_while(|name| !name.is_empty())
        .map(|name| String::from_utf8_lossy(name).into_owned())
        .collect()
}
578
579fn parse_pgoutput_boundary(data: &Bytes) -> Result<Option<ReplicationEvent>> {
580    if data.is_empty() {
581        return Ok(None);
582    }
583
584    let tag = data[0];
585    let mut p = &data[1..];
586
587    fn take_i8(p: &mut &[u8]) -> Result<i8> {
588        if p.is_empty() {
589            return Err(PgWireError::Protocol("pgoutput: truncated i8".into()));
590        }
591        let v = p[0] as i8;
592        *p = &p[1..];
593        Ok(v)
594    }
595
596    fn take_i32(p: &mut &[u8]) -> Result<i32> {
597        if p.len() < 4 {
598            return Err(PgWireError::Protocol("pgoutput: truncated i32".into()));
599        }
600        let (head, tail) = p.split_at(4);
601        *p = tail;
602        Ok(i32::from_be_bytes(head.try_into().unwrap()))
603    }
604
605    fn take_i64(p: &mut &[u8]) -> Result<i64> {
606        if p.len() < 8 {
607            return Err(PgWireError::Protocol("pgoutput: truncated i64".into()));
608        }
609        let (head, tail) = p.split_at(8);
610        *p = tail;
611        Ok(i64::from_be_bytes(head.try_into().unwrap()))
612    }
613
614    match tag {
615        b'B' => {
616            let final_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
617            let commit_time_micros = take_i64(&mut p)?;
618            let xid = take_i32(&mut p)? as u32;
619
620            Ok(Some(ReplicationEvent::Begin {
621                final_lsn,
622                commit_time_micros,
623                xid,
624            }))
625        }
626        b'C' => {
627            let _flags = take_i8(&mut p)?;
628            let lsn = Lsn::from_u64(take_i64(&mut p)? as u64); // should be safe
629            let end_lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
630            let commit_time_micros = take_i64(&mut p)?;
631
632            Ok(Some(ReplicationEvent::Commit {
633                lsn,
634                end_lsn,
635                commit_time_micros,
636            }))
637        }
638        b'M' => {
639            // Logical decoding message (pg_logical_emit_message)
640            // Wire: flags(1) + lsn(8) + prefix(null-terminated) + content_len(4) + content(n)
641            let flags = take_i8(&mut p)?;
642            let transactional = (flags & 1) != 0;
643            let lsn = Lsn::from_u64(take_i64(&mut p)? as u64);
644
645            // Read null-terminated prefix string
646            let prefix_end = p.iter().position(|&b| b == 0).ok_or_else(|| {
647                PgWireError::Protocol("pgoutput Message: missing null terminator for prefix".into())
648            })?;
649            let prefix = String::from_utf8_lossy(&p[..prefix_end]).into_owned();
650            p = &p[prefix_end + 1..]; // advance past null byte
651
652            let content_len = take_i32(&mut p)? as usize;
653            if p.len() < content_len {
654                return Err(PgWireError::Protocol(format!(
655                    "pgoutput Message: expected {} content bytes, got {}",
656                    content_len,
657                    p.len()
658                )));
659            }
660            let content = Bytes::copy_from_slice(&p[..content_len]);
661
662            Ok(Some(ReplicationEvent::Message {
663                transactional,
664                lsn,
665                prefix,
666                content,
667            }))
668        }
669        _ => Ok(None),
670    }
671}
672
673/// Read authentication response data for a specific auth code.
674async fn read_auth_data<S: AsyncRead + AsyncWrite + Unpin>(
675    stream: &mut S,
676    expected_code: i32,
677) -> Result<Vec<u8>> {
678    loop {
679        let msg = read_backend_message(stream).await?;
680        match msg.tag {
681            b'R' => {
682                let (code, data) = parse_auth_request(&msg.payload)?;
683                if code == expected_code {
684                    return Ok(data.to_vec());
685                }
686                return Err(PgWireError::Auth(format!(
687                    "unexpected auth code {code}, expected {expected_code}"
688                )));
689            }
690            b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
691            _ => {} // Skip other messages
692        }
693    }
694}
695
696/// Get current time as PostgreSQL timestamp (microseconds since 2000-01-01).
697fn current_pg_timestamp() -> i64 {
698    use std::time::{SystemTime, UNIX_EPOCH};
699
700    let now = SystemTime::now()
701        .duration_since(UNIX_EPOCH)
702        .unwrap_or_default();
703
704    let unix_micros = (now.as_secs() as i64) * 1_000_000 + (now.subsec_micros() as i64);
705    unix_micros - PG_EPOCH_MICROS
706}
707
/// Compute the PostgreSQL MD5 password response:
/// `"md5" + hex(md5(hex(md5(password + user)) + salt))`, using lower-case hex.
#[cfg(feature = "md5")]
fn postgres_md5(password: &str, user: &str, salt: &[u8; 4]) -> String {
    // Lower-case hex rendering of an MD5 digest.
    let hex_digest = |bytes: &[u8]| format!("{:x}", md5::compute(bytes));

    // Stage one: md5(password || user).
    let mut credentials = String::with_capacity(password.len() + user.len());
    credentials.push_str(password);
    credentials.push_str(user);
    let inner_hex = hex_digest(credentials.as_bytes());

    // Stage two: md5(inner_hex || salt).
    let mut salted = Vec::with_capacity(inner_hex.len() + salt.len());
    salted.extend_from_slice(inner_hex.as_bytes());
    salted.extend_from_slice(salt);

    let mut response = String::from("md5");
    response.push_str(&hex_digest(&salted));
    response
}
724
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_sasl_mechanisms_single() {
        let data = b"SCRAM-SHA-256\0\0";
        let mechs = parse_sasl_mechanisms(data);
        assert_eq!(mechs, vec!["SCRAM-SHA-256"]);
    }

    #[test]
    fn parse_sasl_mechanisms_multiple() {
        let data = b"SCRAM-SHA-256\0SCRAM-SHA-256-PLUS\0\0";
        let mechs = parse_sasl_mechanisms(data);
        assert_eq!(mechs, vec!["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"]);
    }

    #[test]
    fn parse_sasl_mechanisms_empty() {
        let mechs = parse_sasl_mechanisms(b"\0");
        assert!(mechs.is_empty());
    }

    #[test]
    fn parse_sasl_mechanisms_ignores_unterminated_tail() {
        let mechs = parse_sasl_mechanisms(b"SCRAM-SHA-256\0PLAIN");
        assert_eq!(mechs, vec!["SCRAM-SHA-256"]);
    }

    #[test]
    #[cfg(feature = "md5")]
    fn postgres_md5_format_and_salt_sensitivity() {
        // To pin an exact value, compute it in PostgreSQL:
        //   SELECT 'md5' || md5(md5('md5_passmd5_user') || E'\\x01020304');
        let hash = postgres_md5("md5_pass", "md5_user", &[0x01, 0x02, 0x03, 0x04]);
        assert!(hash.starts_with("md5"));
        assert_eq!(hash.len(), 35); // "md5" + 32 hex chars
        assert!(hash[3..]
            .bytes()
            .all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f')));
        // Deterministic, and the salt must influence the result.
        assert_eq!(
            hash,
            postgres_md5("md5_pass", "md5_user", &[0x01, 0x02, 0x03, 0x04])
        );
        assert_ne!(
            hash,
            postgres_md5("md5_pass", "md5_user", &[0x04, 0x03, 0x02, 0x01])
        );
    }

    #[test]
    fn current_pg_timestamp_is_positive() {
        // Any time after 2000-01-01 should be positive
        let ts = current_pg_timestamp();
        assert!(ts > 0);
    }

    #[test]
    fn parse_pgoutput_begin() {
        let mut raw = vec![b'B'];
        raw.extend_from_slice(&0x0000_0001_0000_0010_u64.to_be_bytes());
        raw.extend_from_slice(&1_234_i64.to_be_bytes());
        raw.extend_from_slice(&42_i32.to_be_bytes());

        match parse_pgoutput_boundary(&Bytes::from(raw)) {
            Ok(Some(ReplicationEvent::Begin {
                final_lsn,
                xid,
                commit_time_micros,
            })) => {
                assert_eq!(final_lsn.as_u64(), 0x0000_0001_0000_0010);
                assert_eq!(xid, 42);
                assert_eq!(commit_time_micros, 1_234);
            }
            _ => panic!("expected a Begin event"),
        }
    }

    #[test]
    fn parse_pgoutput_commit() {
        let mut raw = vec![b'C', 0];
        raw.extend_from_slice(&0x100_u64.to_be_bytes());
        raw.extend_from_slice(&0x180_u64.to_be_bytes());
        raw.extend_from_slice(&99_i64.to_be_bytes());

        match parse_pgoutput_boundary(&Bytes::from(raw)) {
            Ok(Some(ReplicationEvent::Commit {
                lsn,
                end_lsn,
                commit_time_micros,
            })) => {
                assert_eq!(lsn.as_u64(), 0x100);
                assert_eq!(end_lsn.as_u64(), 0x180);
                assert_eq!(commit_time_micros, 99);
            }
            _ => panic!("expected a Commit event"),
        }
    }

    #[test]
    fn parse_pgoutput_message() {
        let mut raw = vec![b'M', 1];
        raw.extend_from_slice(&0x200_u64.to_be_bytes());
        raw.extend_from_slice(b"app.checkpoint\0");
        raw.extend_from_slice(&3_i32.to_be_bytes());
        raw.extend_from_slice(b"xyz");

        match parse_pgoutput_boundary(&Bytes::from(raw)) {
            Ok(Some(ReplicationEvent::Message {
                transactional,
                lsn,
                prefix,
                content,
            })) => {
                assert!(transactional);
                assert_eq!(lsn.as_u64(), 0x200);
                assert_eq!(prefix, "app.checkpoint");
                assert_eq!(content, Bytes::from_static(b"xyz"));
            }
            _ => panic!("expected a Message event"),
        }
    }

    #[test]
    fn parse_pgoutput_rejects_malformed_input() {
        // Begin with its xid cut off.
        let mut begin = vec![b'B'];
        begin.extend_from_slice(&[0u8; 16]);
        assert!(parse_pgoutput_boundary(&Bytes::from(begin)).is_err());

        // Message whose prefix has no NUL terminator.
        let mut no_nul = vec![b'M', 0];
        no_nul.extend_from_slice(&[0u8; 8]);
        no_nul.extend_from_slice(b"prefix");
        assert!(parse_pgoutput_boundary(&Bytes::from(no_nul)).is_err());

        // Message with a negative content length.
        let mut negative = vec![b'M', 0];
        negative.extend_from_slice(&[0u8; 8]);
        negative.extend_from_slice(b"p\0");
        negative.extend_from_slice(&(-1_i32).to_be_bytes());
        assert!(parse_pgoutput_boundary(&Bytes::from(negative)).is_err());

        // Message claiming more content than is present.
        let mut short = vec![b'M', 0];
        short.extend_from_slice(&[0u8; 8]);
        short.extend_from_slice(b"p\0");
        short.extend_from_slice(&10_i32.to_be_bytes());
        short.extend_from_slice(b"abc");
        assert!(parse_pgoutput_boundary(&Bytes::from(short)).is_err());
    }

    #[test]
    fn parse_pgoutput_ignores_other_messages() {
        assert!(matches!(parse_pgoutput_boundary(&Bytes::new()), Ok(None)));
        assert!(matches!(
            parse_pgoutput_boundary(&Bytes::from_static(b"I\0\0\0\x01")),
            Ok(None)
        ));
    }
}