Skip to main content

nodedb_cluster/mirror/
handshake.rs

1// SPDX-License-Identifier: BUSL-1.1
2
3//! Cross-cluster mirror handshake wire protocol.
4//!
5//! When a mirror cluster opens a QUIC connection to the source cluster it
6//! sends a [`MirrorHello`] on the first bidi stream.  The source replies
7//! with a [`MirrorHelloAck`].  Only after this exchange succeeds does the
8//! link transition to entry-streaming / snapshot-transfer mode.
9//!
10//! # Authentication
11//!
12//! The `source_cluster` field in [`MirrorHello`] is the cluster-id string
13//! the mirror declares.  The source verifies it matches its own cluster-id
14//! and rejects the connection otherwise (error code
15//! [`MIRROR_HELLO_ERR_CLUSTER_ID`]).
16//!
17//! # Observer-role enforcement
18//!
19//! The source tracks this connection as `PeerRole::Observer`.  Any attempt
20//! by the mirror to send a voter-class RPC (RequestVote, ConfChange) over
21//! the same connection is rejected with [`MIRROR_HELLO_ERR_OBSERVER_ONLY`].
22//!
23//! # Wire format
24//!
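//! Both messages use zerompk (MessagePack) encoding.  Each frame is laid out
//! as `[discriminant u8][len u32 BE][payload]`: a one-byte discriminant, then
//! the payload length as a 4-byte big-endian integer (the same length-prefix
//! style as the existing rpc_codec framing), then the MessagePack payload.
//! Because the discriminant sits in the fixed 5-byte header, a reader knows
//! which message it is receiving before decoding the payload, and payloads
//! larger than `MAX_HANDSHAKE_PAYLOAD` are rejected from the header alone.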
29
30use serde::{Deserialize, Serialize};
31use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
32
33use super::error::MirrorError;
34
/// Frame discriminant identifying a [`MirrorHello`] (sent by the mirror).
pub const MIRROR_HELLO: u8 = 0x01;
/// Frame discriminant identifying a [`MirrorHelloAck`] (sent by the source).
pub const MIRROR_HELLO_ACK: u8 = 0x02;

/// [`MirrorHelloAck::error_code`] value: source cluster-id mismatch.
pub const MIRROR_HELLO_ERR_CLUSTER_ID: u8 = 0x01;
/// [`MirrorHelloAck::error_code`] value: the peer attempted a voter operation;
/// only observer RPCs are allowed.
pub const MIRROR_HELLO_ERR_OBSERVER_ONLY: u8 = 0x02;
/// [`MirrorHelloAck::error_code`] value: the mirror declared a wire protocol
/// version this source does not implement.
pub const MIRROR_HELLO_ERR_BAD_VERSION: u8 = 0x03;

/// Maximum size (bytes) of a [`MirrorHello`] or [`MirrorHelloAck`] payload.
///
/// The frame reader checks the declared length against this bound before
/// allocating the payload buffer.  The check applies in both directions:
/// - The source reading a hello is protected from a misbehaving mirror.
/// - The mirror reading an ack is protected from a misbehaving source.
///
/// Either way, a peer cannot force an unbounded allocation.
const MAX_HANDSHAKE_PAYLOAD: usize = 4096;
53
54/// Opening handshake sent by the mirror to the source.
55#[derive(
56    Debug,
57    Clone,
58    PartialEq,
59    Eq,
60    Serialize,
61    Deserialize,
62    zerompk::ToMessagePack,
63    zerompk::FromMessagePack,
64)]
65#[msgpack(map)]
66pub struct MirrorHello {
67    /// Cluster-id the mirror is connecting to (i.e., the source cluster-id).
68    ///
69    /// The source verifies this matches its own id and rejects the connection
70    /// on mismatch.
71    pub source_cluster: String,
72    /// The database id on the *source* cluster being mirrored.
73    pub source_database_id: String,
74    /// The WAL LSN the mirror last applied.  Drives the source's decision on
75    /// whether to start from the last snapshot or stream log from this LSN.
76    pub last_applied_lsn: u64,
77    /// Wire protocol version for this cross-cluster link.
78    pub protocol_version: u16,
79}
80
81/// Acknowledgement sent by the source to the mirror.
82#[derive(
83    Debug,
84    Clone,
85    PartialEq,
86    Eq,
87    Serialize,
88    Deserialize,
89    zerompk::ToMessagePack,
90    zerompk::FromMessagePack,
91)]
92#[msgpack(map)]
93pub struct MirrorHelloAck {
94    /// Whether the source accepted the connection.
95    pub accepted: bool,
96    /// On `accepted = false`, an error code from the `MIRROR_HELLO_ERR_*`
97    /// constants.
98    pub error_code: u8,
99    /// Human-readable explanation (empty string on success).
100    pub error_detail: String,
101    /// The source's own cluster-id, included so the mirror can verify it
102    /// has connected to the right cluster.
103    pub source_cluster_id: String,
104    /// LSN of the snapshot the source is about to send, or `u64::MAX` if the
105    /// source will stream from `last_applied_lsn + 1` without a fresh
106    /// snapshot.
107    pub snapshot_lsn: u64,
108    /// Total snapshot size in bytes (0 if no snapshot will be sent).
109    pub snapshot_bytes_total: u64,
110}
111
/// Cross-cluster wire protocol version implemented by this build; the value a
/// mirror puts in [`MirrorHello::protocol_version`].
pub const MIRROR_PROTOCOL_VERSION: u16 = 0x0001;
114
115/// Send a [`MirrorHello`] frame to `writer`.
116pub async fn send_hello<W: AsyncWrite + Unpin>(
117    writer: &mut W,
118    hello: &MirrorHello,
119) -> Result<(), MirrorError> {
120    let payload = zerompk::to_msgpack_vec(hello).map_err(|e| MirrorError::HandshakeCodec {
121        detail: format!("encode MirrorHello: {e}"),
122    })?;
123    write_framed(writer, MIRROR_HELLO, &payload).await
124}
125
126/// Read a [`MirrorHello`] frame from `reader`.
127pub async fn recv_hello<R: AsyncRead + Unpin>(reader: &mut R) -> Result<MirrorHello, MirrorError> {
128    let (discriminant, payload) = read_framed(reader).await?;
129    if discriminant != MIRROR_HELLO {
130        return Err(MirrorError::HandshakeCodec {
131            detail: format!(
132                "expected MirrorHello discriminant {MIRROR_HELLO:#04x}, got {discriminant:#04x}"
133            ),
134        });
135    }
136    zerompk::from_msgpack(&payload).map_err(|e| MirrorError::HandshakeCodec {
137        detail: format!("decode MirrorHello: {e}"),
138    })
139}
140
141/// Send a [`MirrorHelloAck`] frame to `writer`.
142pub async fn send_ack<W: AsyncWrite + Unpin>(
143    writer: &mut W,
144    ack: &MirrorHelloAck,
145) -> Result<(), MirrorError> {
146    let payload = zerompk::to_msgpack_vec(ack).map_err(|e| MirrorError::HandshakeCodec {
147        detail: format!("encode MirrorHelloAck: {e}"),
148    })?;
149    write_framed(writer, MIRROR_HELLO_ACK, &payload).await
150}
151
152/// Read a [`MirrorHelloAck`] frame from `reader`.
153pub async fn recv_ack<R: AsyncRead + Unpin>(reader: &mut R) -> Result<MirrorHelloAck, MirrorError> {
154    let (discriminant, payload) = read_framed(reader).await?;
155    if discriminant != MIRROR_HELLO_ACK {
156        return Err(MirrorError::HandshakeCodec {
157            detail: format!(
158                "expected MirrorHelloAck discriminant {MIRROR_HELLO_ACK:#04x}, \
159                 got {discriminant:#04x}"
160            ),
161        });
162    }
163    zerompk::from_msgpack(&payload).map_err(|e| MirrorError::HandshakeCodec {
164        detail: format!("decode MirrorHelloAck: {e}"),
165    })
166}
167
168/// Write a framed message: `[discriminant u8][len u32 BE][payload bytes]`.
169async fn write_framed<W: AsyncWrite + Unpin>(
170    writer: &mut W,
171    discriminant: u8,
172    payload: &[u8],
173) -> Result<(), MirrorError> {
174    let len = payload.len() as u32;
175    let header = [
176        discriminant,
177        (len >> 24) as u8,
178        (len >> 16) as u8,
179        (len >> 8) as u8,
180        len as u8,
181    ];
182    writer
183        .write_all(&header)
184        .await
185        .map_err(|e| MirrorError::Transport {
186            detail: format!("write framed header: {e}"),
187        })?;
188    writer
189        .write_all(payload)
190        .await
191        .map_err(|e| MirrorError::Transport {
192            detail: format!("write framed payload: {e}"),
193        })?;
194    Ok(())
195}
196
197/// Read a framed message: `[discriminant u8][len u32 BE][payload bytes]`.
198async fn read_framed<R: AsyncRead + Unpin>(reader: &mut R) -> Result<(u8, Vec<u8>), MirrorError> {
199    let mut header = [0u8; 5];
200    reader
201        .read_exact(&mut header)
202        .await
203        .map_err(|e| MirrorError::Transport {
204            detail: format!("read framed header: {e}"),
205        })?;
206    let discriminant = header[0];
207    let len = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
208
209    if len > MAX_HANDSHAKE_PAYLOAD {
210        return Err(MirrorError::HandshakeCodec {
211            detail: format!("handshake payload {len} bytes exceeds max {MAX_HANDSHAKE_PAYLOAD}"),
212        });
213    }
214
215    let mut payload = vec![0u8; len];
216    reader
217        .read_exact(&mut payload)
218        .await
219        .map_err(|e| MirrorError::Transport {
220            detail: format!("read framed payload: {e}"),
221        })?;
222    Ok((discriminant, payload))
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn hello_roundtrip() {
        let hello = MirrorHello {
            source_cluster: "prod-us".into(),
            source_database_id: "db_01JTEST".into(),
            last_applied_lsn: 12345,
            protocol_version: MIRROR_PROTOCOL_VERSION,
        };
        let mut buf = Vec::<u8>::new();
        send_hello(&mut buf, &hello).await.unwrap();
        let decoded = recv_hello(&mut buf.as_slice()).await.unwrap();
        assert_eq!(decoded, hello);
    }

    #[tokio::test]
    async fn ack_roundtrip() {
        let ack = MirrorHelloAck {
            accepted: true,
            error_code: 0,
            error_detail: String::new(),
            source_cluster_id: "prod-us".into(),
            snapshot_lsn: 42,
            snapshot_bytes_total: 1024 * 1024,
        };
        let mut buf = Vec::<u8>::new();
        send_ack(&mut buf, &ack).await.unwrap();
        let decoded = recv_ack(&mut buf.as_slice()).await.unwrap();
        assert_eq!(decoded, ack);
    }

    #[tokio::test]
    async fn wrong_discriminant_rejected() {
        let ack = MirrorHelloAck {
            accepted: false,
            error_code: MIRROR_HELLO_ERR_CLUSTER_ID,
            error_detail: "bad cluster".into(),
            source_cluster_id: "wrong".into(),
            snapshot_lsn: 0,
            snapshot_bytes_total: 0,
        };
        let mut buf = Vec::<u8>::new();
        // encode as Ack but try to read as Hello
        send_ack(&mut buf, &ack).await.unwrap();
        let err = recv_hello(&mut buf.as_slice()).await.unwrap_err();
        assert!(
            matches!(err, MirrorError::HandshakeCodec { .. }),
            "expected HandshakeCodec, got: {err:?}"
        );
    }

    /// The reverse direction: a Hello frame must not be accepted as an Ack.
    #[tokio::test]
    async fn hello_read_as_ack_rejected() {
        let hello = MirrorHello {
            source_cluster: "prod-us".into(),
            source_database_id: "db_01JTEST".into(),
            last_applied_lsn: 0,
            protocol_version: MIRROR_PROTOCOL_VERSION,
        };
        let mut buf = Vec::<u8>::new();
        send_hello(&mut buf, &hello).await.unwrap();
        let err = recv_ack(&mut buf.as_slice()).await.unwrap_err();
        assert!(
            matches!(err, MirrorError::HandshakeCodec { .. }),
            "expected HandshakeCodec, got: {err:?}"
        );
    }

    /// A declared length above the cap is rejected from the header alone:
    /// no payload bytes follow, so success here proves the check runs before
    /// any payload read or allocation.
    #[tokio::test]
    async fn oversized_length_rejected() {
        let declared = u32::try_from(MAX_HANDSHAKE_PAYLOAD + 1).unwrap();
        let mut frame = vec![MIRROR_HELLO];
        frame.extend_from_slice(&declared.to_be_bytes());
        let err = recv_hello(&mut frame.as_slice()).await.unwrap_err();
        assert!(
            matches!(err, MirrorError::HandshakeCodec { .. }),
            "expected HandshakeCodec, got: {err:?}"
        );
    }

    /// A frame cut short mid-payload surfaces as a transport error rather
    /// than being handed to the MessagePack decoder.
    #[tokio::test]
    async fn truncated_payload_is_transport_error() {
        let mut frame = vec![MIRROR_HELLO_ACK];
        frame.extend_from_slice(&16u32.to_be_bytes());
        frame.extend_from_slice(&[0u8; 4]);
        let err = recv_ack(&mut frame.as_slice()).await.unwrap_err();
        assert!(
            matches!(err, MirrorError::Transport { .. }),
            "expected Transport, got: {err:?}"
        );
    }

    /// A stream that ends inside the 5-byte header is also a transport error.
    #[tokio::test]
    async fn truncated_header_is_transport_error() {
        let frame = [MIRROR_HELLO, 0x00];
        let err = recv_hello(&mut &frame[..]).await.unwrap_err();
        assert!(
            matches!(err, MirrorError::Transport { .. }),
            "expected Transport, got: {err:?}"
        );
    }
}