1use serde::{Deserialize, Serialize};
31use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
32
33use super::error::MirrorError;
34
/// Frame discriminant that tags a [`MirrorHello`] payload.
pub const MIRROR_HELLO: u8 = 0x01;
/// Frame discriminant that tags a [`MirrorHelloAck`] payload.
pub const MIRROR_HELLO_ACK: u8 = 0x02;

// Rejection codes carried in `MirrorHelloAck::error_code`. They live in a
// different field from the frame discriminants above, so the overlapping
// numeric values do not collide.
/// Rejection code for a hello whose cluster identity was not accepted.
pub const MIRROR_HELLO_ERR_CLUSTER_ID: u8 = 0x01;
/// Rejection code for an observer-only peer.
/// NOTE(review): exact condition is decided by the accepting side — confirm.
pub const MIRROR_HELLO_ERR_OBSERVER_ONLY: u8 = 0x02;
/// Rejection code presumably used when `MirrorHello::protocol_version` is not
/// supported — confirm against the accepting side.
pub const MIRROR_HELLO_ERR_BAD_VERSION: u8 = 0x03;

/// Upper bound, in bytes, on a handshake frame's payload; `read_framed`
/// rejects larger length prefixes before allocating a buffer.
const MAX_HANDSHAKE_PAYLOAD: usize = 4096;
53
54#[derive(
56 Debug,
57 Clone,
58 PartialEq,
59 Eq,
60 Serialize,
61 Deserialize,
62 zerompk::ToMessagePack,
63 zerompk::FromMessagePack,
64)]
65#[msgpack(map)]
66pub struct MirrorHello {
67 pub source_cluster: String,
72 pub source_database_id: String,
74 pub last_applied_lsn: u64,
77 pub protocol_version: u16,
79}
80
81#[derive(
83 Debug,
84 Clone,
85 PartialEq,
86 Eq,
87 Serialize,
88 Deserialize,
89 zerompk::ToMessagePack,
90 zerompk::FromMessagePack,
91)]
92#[msgpack(map)]
93pub struct MirrorHelloAck {
94 pub accepted: bool,
96 pub error_code: u8,
99 pub error_detail: String,
101 pub source_cluster_id: String,
104 pub snapshot_lsn: u64,
108 pub snapshot_bytes_total: u64,
110}
111
/// Current version of the mirror handshake protocol, carried in
/// `MirrorHello::protocol_version`.
pub const MIRROR_PROTOCOL_VERSION: u16 = 1;
114
115pub async fn send_hello<W: AsyncWrite + Unpin>(
117 writer: &mut W,
118 hello: &MirrorHello,
119) -> Result<(), MirrorError> {
120 let payload = zerompk::to_msgpack_vec(hello).map_err(|e| MirrorError::HandshakeCodec {
121 detail: format!("encode MirrorHello: {e}"),
122 })?;
123 write_framed(writer, MIRROR_HELLO, &payload).await
124}
125
126pub async fn recv_hello<R: AsyncRead + Unpin>(reader: &mut R) -> Result<MirrorHello, MirrorError> {
128 let (discriminant, payload) = read_framed(reader).await?;
129 if discriminant != MIRROR_HELLO {
130 return Err(MirrorError::HandshakeCodec {
131 detail: format!(
132 "expected MirrorHello discriminant {MIRROR_HELLO:#04x}, got {discriminant:#04x}"
133 ),
134 });
135 }
136 zerompk::from_msgpack(&payload).map_err(|e| MirrorError::HandshakeCodec {
137 detail: format!("decode MirrorHello: {e}"),
138 })
139}
140
141pub async fn send_ack<W: AsyncWrite + Unpin>(
143 writer: &mut W,
144 ack: &MirrorHelloAck,
145) -> Result<(), MirrorError> {
146 let payload = zerompk::to_msgpack_vec(ack).map_err(|e| MirrorError::HandshakeCodec {
147 detail: format!("encode MirrorHelloAck: {e}"),
148 })?;
149 write_framed(writer, MIRROR_HELLO_ACK, &payload).await
150}
151
152pub async fn recv_ack<R: AsyncRead + Unpin>(reader: &mut R) -> Result<MirrorHelloAck, MirrorError> {
154 let (discriminant, payload) = read_framed(reader).await?;
155 if discriminant != MIRROR_HELLO_ACK {
156 return Err(MirrorError::HandshakeCodec {
157 detail: format!(
158 "expected MirrorHelloAck discriminant {MIRROR_HELLO_ACK:#04x}, \
159 got {discriminant:#04x}"
160 ),
161 });
162 }
163 zerompk::from_msgpack(&payload).map_err(|e| MirrorError::HandshakeCodec {
164 detail: format!("decode MirrorHelloAck: {e}"),
165 })
166}
167
168async fn write_framed<W: AsyncWrite + Unpin>(
170 writer: &mut W,
171 discriminant: u8,
172 payload: &[u8],
173) -> Result<(), MirrorError> {
174 let len = payload.len() as u32;
175 let header = [
176 discriminant,
177 (len >> 24) as u8,
178 (len >> 16) as u8,
179 (len >> 8) as u8,
180 len as u8,
181 ];
182 writer
183 .write_all(&header)
184 .await
185 .map_err(|e| MirrorError::Transport {
186 detail: format!("write framed header: {e}"),
187 })?;
188 writer
189 .write_all(payload)
190 .await
191 .map_err(|e| MirrorError::Transport {
192 detail: format!("write framed payload: {e}"),
193 })?;
194 Ok(())
195}
196
197async fn read_framed<R: AsyncRead + Unpin>(reader: &mut R) -> Result<(u8, Vec<u8>), MirrorError> {
199 let mut header = [0u8; 5];
200 reader
201 .read_exact(&mut header)
202 .await
203 .map_err(|e| MirrorError::Transport {
204 detail: format!("read framed header: {e}"),
205 })?;
206 let discriminant = header[0];
207 let len = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
208
209 if len > MAX_HANDSHAKE_PAYLOAD {
210 return Err(MirrorError::HandshakeCodec {
211 detail: format!("handshake payload {len} bytes exceeds max {MAX_HANDSHAKE_PAYLOAD}"),
212 });
213 }
214
215 let mut payload = vec![0u8; len];
216 reader
217 .read_exact(&mut payload)
218 .await
219 .map_err(|e| MirrorError::Transport {
220 detail: format!("read framed payload: {e}"),
221 })?;
222 Ok((discriminant, payload))
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// A representative hello used by several tests.
    fn sample_hello() -> MirrorHello {
        MirrorHello {
            source_cluster: "prod-us".into(),
            source_database_id: "db_01JTEST".into(),
            last_applied_lsn: 12345,
            protocol_version: MIRROR_PROTOCOL_VERSION,
        }
    }

    #[tokio::test]
    async fn hello_roundtrip() {
        let hello = sample_hello();
        let mut buf = Vec::<u8>::new();
        send_hello(&mut buf, &hello).await.unwrap();
        let decoded = recv_hello(&mut buf.as_slice()).await.unwrap();
        assert_eq!(decoded, hello);
    }

    #[tokio::test]
    async fn ack_roundtrip() {
        let ack = MirrorHelloAck {
            accepted: true,
            error_code: 0,
            error_detail: String::new(),
            source_cluster_id: "prod-us".into(),
            snapshot_lsn: 42,
            snapshot_bytes_total: 1024 * 1024,
        };
        let mut buf = Vec::<u8>::new();
        send_ack(&mut buf, &ack).await.unwrap();
        let decoded = recv_ack(&mut buf.as_slice()).await.unwrap();
        assert_eq!(decoded, ack);
    }

    #[tokio::test]
    async fn wrong_discriminant_rejected() {
        let ack = MirrorHelloAck {
            accepted: false,
            error_code: MIRROR_HELLO_ERR_CLUSTER_ID,
            error_detail: "bad cluster".into(),
            source_cluster_id: "wrong".into(),
            snapshot_lsn: 0,
            snapshot_bytes_total: 0,
        };
        let mut buf = Vec::<u8>::new();
        send_ack(&mut buf, &ack).await.unwrap();
        // An ack frame must not be accepted where a hello is expected.
        let err = recv_hello(&mut buf.as_slice()).await.unwrap_err();
        assert!(
            matches!(err, MirrorError::HandshakeCodec { .. }),
            "expected HandshakeCodec, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn empty_payload_frame_roundtrip() {
        let mut buf = Vec::<u8>::new();
        write_framed(&mut buf, 0x7f, &[]).await.unwrap();
        assert_eq!(buf, [0x7f, 0, 0, 0, 0]);
        let (discriminant, payload) = read_framed(&mut buf.as_slice()).await.unwrap();
        assert_eq!(discriminant, 0x7f);
        assert!(payload.is_empty());
    }

    #[tokio::test]
    async fn oversized_length_prefix_rejected() {
        // Only a header is supplied: the length check must fire before any
        // attempt to read (or allocate for) the payload.
        let too_big = u32::try_from(MAX_HANDSHAKE_PAYLOAD + 1).unwrap();
        let mut frame = vec![MIRROR_HELLO];
        frame.extend_from_slice(&too_big.to_be_bytes());
        let err = recv_hello(&mut frame.as_slice()).await.unwrap_err();
        assert!(
            matches!(err, MirrorError::HandshakeCodec { .. }),
            "expected HandshakeCodec, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn truncated_payload_is_transport_error() {
        let mut buf = Vec::<u8>::new();
        send_hello(&mut buf, &sample_hello()).await.unwrap();
        buf.pop();
        let err = recv_hello(&mut buf.as_slice()).await.unwrap_err();
        assert!(
            matches!(err, MirrorError::Transport { .. }),
            "expected Transport, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn empty_input_is_transport_error() {
        let mut empty: &[u8] = &[];
        let err = recv_ack(&mut empty).await.unwrap_err();
        assert!(
            matches!(err, MirrorError::Transport { .. }),
            "expected Transport, got: {err:?}"
        );
    }
}