use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use super::error::MirrorError;
/// Frame discriminant for a [`MirrorHello`] message: written by `send_hello`
/// and required by `recv_hello`.
pub const MIRROR_HELLO: u8 = 0x01;
/// Frame discriminant for a [`MirrorHelloAck`] message: written by `send_ack`
/// and required by `recv_ack`.
pub const MIRROR_HELLO_ACK: u8 = 0x02;
// NOTE(review): the `MIRROR_HELLO_ERR_*` codes below appear to be values for
// `MirrorHelloAck::error_code` (the tests use `MIRROR_HELLO_ERR_CLUSTER_ID` there),
// i.e. a separate namespace from the frame discriminants above, so the numeric
// overlap with `MIRROR_HELLO` / `MIRROR_HELLO_ACK` is not a collision — confirm.
/// Ack error code, presumably for a source-cluster-id mismatch — confirm in the acceptor.
pub const MIRROR_HELLO_ERR_CLUSTER_ID: u8 = 0x01;
/// Ack error code, presumably for a peer that is only allowed to observe — confirm.
pub const MIRROR_HELLO_ERR_OBSERVER_ONLY: u8 = 0x02;
/// Ack error code, presumably for an unsupported `protocol_version` — confirm.
pub const MIRROR_HELLO_ERR_BAD_VERSION: u8 = 0x03;
/// Largest handshake payload, in bytes, that `read_framed` accepts. A larger
/// advertised length is rejected before the payload buffer is allocated, which
/// bounds memory use for an untrusted header.
const MAX_HANDSHAKE_PAYLOAD: usize = 4096;
/// Handshake message opening a mirror session.
///
/// Encoded as MessagePack by `send_hello` and framed with the [`MIRROR_HELLO`]
/// discriminant; decoded by `recv_hello`.
///
/// NOTE(review): which peer sends this (the mirror or the source) is not
/// visible in this file — confirm against the caller.
#[derive(
Debug,
Clone,
PartialEq,
Eq,
Serialize,
Deserialize,
zerompk::ToMessagePack,
zerompk::FromMessagePack,
)]
// zerompk `map` mode: presumably encodes fields keyed by name rather than by
// position — confirm against the zerompk docs before reordering/renaming fields.
#[msgpack(map)]
pub struct MirrorHello {
/// Source cluster the sender expects to mirror (presumably its name/id — confirm).
pub source_cluster: String,
/// Identifier of the database on the source cluster.
pub source_database_id: String,
/// Presumably the last LSN the sender has already applied, so replication can
/// resume after it — confirm.
pub last_applied_lsn: u64,
/// Handshake protocol version; the tests populate it with [`MIRROR_PROTOCOL_VERSION`].
pub protocol_version: u16,
}
/// Reply to a [`MirrorHello`].
///
/// Encoded as MessagePack by `send_ack` and framed with the
/// [`MIRROR_HELLO_ACK`] discriminant; decoded by `recv_ack`.
#[derive(
Debug,
Clone,
PartialEq,
Eq,
Serialize,
Deserialize,
zerompk::ToMessagePack,
zerompk::FromMessagePack,
)]
// zerompk `map` mode: presumably encodes fields keyed by name — confirm.
#[msgpack(map)]
pub struct MirrorHelloAck {
/// Whether the hello was accepted.
pub accepted: bool,
/// Rejection reason, presumably one of the `MIRROR_HELLO_ERR_*` constants;
/// the tests use `0` together with `accepted: true`.
pub error_code: u8,
/// Human-readable detail accompanying `error_code`; empty in the accepted test case.
pub error_detail: String,
/// Presumably the responding cluster's id, so the requester can verify it — confirm.
pub source_cluster_id: String,
/// Presumably the LSN the initial snapshot corresponds to — confirm.
pub snapshot_lsn: u64,
/// Presumably the total size, in bytes, of the snapshot to be streamed — confirm.
pub snapshot_bytes_total: u64,
}
/// Handshake protocol version spoken by this implementation, intended for
/// [`MirrorHello::protocol_version`]. Nothing in this file checks it; presumably
/// the acceptor rejects a mismatch with `MIRROR_HELLO_ERR_BAD_VERSION` — confirm.
pub const MIRROR_PROTOCOL_VERSION: u16 = 1;
/// Serializes `hello` to MessagePack and writes it as one [`MIRROR_HELLO`] frame.
///
/// # Errors
///
/// * [`MirrorError::HandshakeCodec`] if encoding fails.
/// * Any error from the frame writer (transport failures).
pub async fn send_hello<W: AsyncWrite + Unpin>(
    writer: &mut W,
    hello: &MirrorHello,
) -> Result<(), MirrorError> {
    let encoded = zerompk::to_msgpack_vec(hello).map_err(|err| {
        let detail = format!("encode MirrorHello: {err}");
        MirrorError::HandshakeCodec { detail }
    })?;
    write_framed(writer, MIRROR_HELLO, &encoded).await
}
/// Reads one frame and decodes it as a [`MirrorHello`].
///
/// # Errors
///
/// * [`MirrorError::HandshakeCodec`] if the frame's discriminant is not
///   [`MIRROR_HELLO`] or its payload does not decode.
/// * Any error from the frame reader (transport failures, oversized payload).
pub async fn recv_hello<R: AsyncRead + Unpin>(reader: &mut R) -> Result<MirrorHello, MirrorError> {
    let (discriminant, body) = read_framed(reader).await?;
    if discriminant == MIRROR_HELLO {
        return zerompk::from_msgpack(&body).map_err(|err| {
            let detail = format!("decode MirrorHello: {err}");
            MirrorError::HandshakeCodec { detail }
        });
    }
    // Any other discriminant means the peer sent the wrong message type.
    Err(MirrorError::HandshakeCodec {
        detail: format!(
            "expected MirrorHello discriminant {MIRROR_HELLO:#04x}, got {discriminant:#04x}"
        ),
    })
}
/// Serializes `ack` to MessagePack and writes it as one [`MIRROR_HELLO_ACK`] frame.
///
/// # Errors
///
/// * [`MirrorError::HandshakeCodec`] if encoding fails.
/// * Any error from the frame writer (transport failures).
pub async fn send_ack<W: AsyncWrite + Unpin>(
    writer: &mut W,
    ack: &MirrorHelloAck,
) -> Result<(), MirrorError> {
    let encoded = zerompk::to_msgpack_vec(ack).map_err(|err| {
        let detail = format!("encode MirrorHelloAck: {err}");
        MirrorError::HandshakeCodec { detail }
    })?;
    write_framed(writer, MIRROR_HELLO_ACK, &encoded).await
}
/// Reads one frame and decodes it as a [`MirrorHelloAck`].
///
/// # Errors
///
/// * [`MirrorError::HandshakeCodec`] if the frame's discriminant is not
///   [`MIRROR_HELLO_ACK`] or its payload does not decode.
/// * Any error from the frame reader (transport failures, oversized payload).
pub async fn recv_ack<R: AsyncRead + Unpin>(reader: &mut R) -> Result<MirrorHelloAck, MirrorError> {
    let (discriminant, body) = read_framed(reader).await?;
    if discriminant == MIRROR_HELLO_ACK {
        zerompk::from_msgpack(&body).map_err(|err| {
            let detail = format!("decode MirrorHelloAck: {err}");
            MirrorError::HandshakeCodec { detail }
        })
    } else {
        Err(MirrorError::HandshakeCodec {
            detail: format!(
                "expected MirrorHelloAck discriminant {MIRROR_HELLO_ACK:#04x}, \
                 got {discriminant:#04x}"
            ),
        })
    }
}
async fn write_framed<W: AsyncWrite + Unpin>(
writer: &mut W,
discriminant: u8,
payload: &[u8],
) -> Result<(), MirrorError> {
let len = payload.len() as u32;
let header = [
discriminant,
(len >> 24) as u8,
(len >> 16) as u8,
(len >> 8) as u8,
len as u8,
];
writer
.write_all(&header)
.await
.map_err(|e| MirrorError::Transport {
detail: format!("write framed header: {e}"),
})?;
writer
.write_all(payload)
.await
.map_err(|e| MirrorError::Transport {
detail: format!("write framed payload: {e}"),
})?;
Ok(())
}
/// Reads one handshake frame and returns its discriminant and payload bytes.
///
/// The 5-byte header is a discriminant byte followed by the payload length as a
/// big-endian `u32`. The advertised length is checked against
/// `MAX_HANDSHAKE_PAYLOAD` *before* the payload buffer is allocated, so an
/// untrusted header cannot force a large allocation.
///
/// # Errors
///
/// * [`MirrorError::Transport`] if the header or payload cannot be read in full.
/// * [`MirrorError::HandshakeCodec`] if the advertised length exceeds the limit.
async fn read_framed<R: AsyncRead + Unpin>(reader: &mut R) -> Result<(u8, Vec<u8>), MirrorError> {
    let mut head = [0u8; 5];
    if let Err(e) = reader.read_exact(&mut head).await {
        return Err(MirrorError::Transport {
            detail: format!("read framed header: {e}"),
        });
    }

    let [tag, len_0, len_1, len_2, len_3] = head;
    let len = u32::from_be_bytes([len_0, len_1, len_2, len_3]) as usize;
    if len > MAX_HANDSHAKE_PAYLOAD {
        return Err(MirrorError::HandshakeCodec {
            detail: format!("handshake payload {len} bytes exceeds max {MAX_HANDSHAKE_PAYLOAD}"),
        });
    }

    let mut body = vec![0u8; len];
    match reader.read_exact(&mut body).await {
        Ok(_) => Ok((tag, body)),
        Err(e) => Err(MirrorError::Transport {
            detail: format!("read framed payload: {e}"),
        }),
    }
}
#[cfg(test)]
mod tests {
    //! Round-trip and rejection tests for the handshake framing and codec.
    use super::*;

    /// A representative hello used by several tests.
    fn sample_hello() -> MirrorHello {
        MirrorHello {
            source_cluster: "prod-us".into(),
            source_database_id: "db_01JTEST".into(),
            last_applied_lsn: 12345,
            protocol_version: MIRROR_PROTOCOL_VERSION,
        }
    }

    #[tokio::test]
    async fn hello_roundtrip() {
        let hello = sample_hello();
        let mut buf = Vec::<u8>::new();
        send_hello(&mut buf, &hello).await.unwrap();
        let decoded = recv_hello(&mut buf.as_slice()).await.unwrap();
        assert_eq!(decoded, hello);
    }

    #[tokio::test]
    async fn ack_roundtrip() {
        let ack = MirrorHelloAck {
            accepted: true,
            error_code: 0,
            error_detail: String::new(),
            source_cluster_id: "prod-us".into(),
            snapshot_lsn: 42,
            snapshot_bytes_total: 1024 * 1024,
        };
        let mut buf = Vec::<u8>::new();
        send_ack(&mut buf, &ack).await.unwrap();
        let decoded = recv_ack(&mut buf.as_slice()).await.unwrap();
        assert_eq!(decoded, ack);
    }

    #[tokio::test]
    async fn wrong_discriminant_rejected() {
        let ack = MirrorHelloAck {
            accepted: false,
            error_code: MIRROR_HELLO_ERR_CLUSTER_ID,
            error_detail: "bad cluster".into(),
            source_cluster_id: "wrong".into(),
            snapshot_lsn: 0,
            snapshot_bytes_total: 0,
        };
        let mut buf = Vec::<u8>::new();
        send_ack(&mut buf, &ack).await.unwrap();
        let err = recv_hello(&mut buf.as_slice()).await.unwrap_err();
        assert!(
            matches!(err, MirrorError::HandshakeCodec { .. }),
            "expected HandshakeCodec, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn ack_reader_rejects_hello_frame() {
        let mut buf = Vec::<u8>::new();
        send_hello(&mut buf, &sample_hello()).await.unwrap();
        let err = recv_ack(&mut buf.as_slice()).await.unwrap_err();
        assert!(
            matches!(err, MirrorError::HandshakeCodec { .. }),
            "expected HandshakeCodec, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn oversized_payload_rejected_before_reading_body() {
        // Header advertises one byte more than the limit and no payload follows:
        // if the size check were skipped, the reader would instead fail with a
        // Transport (EOF) error while reading the body.
        let advertised = u32::try_from(MAX_HANDSHAKE_PAYLOAD + 1).unwrap();
        let mut frame = vec![MIRROR_HELLO];
        frame.extend_from_slice(&advertised.to_be_bytes());
        let err = recv_hello(&mut frame.as_slice()).await.unwrap_err();
        assert!(
            matches!(err, MirrorError::HandshakeCodec { .. }),
            "expected HandshakeCodec, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn truncated_payload_is_transport_error() {
        let mut buf = Vec::<u8>::new();
        send_hello(&mut buf, &sample_hello()).await.unwrap();
        buf.pop();
        let err = recv_hello(&mut buf.as_slice()).await.unwrap_err();
        assert!(
            matches!(err, MirrorError::Transport { .. }),
            "expected Transport, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn truncated_header_is_transport_error() {
        let frame = [MIRROR_HELLO, 0, 0];
        let err = recv_hello(&mut &frame[..]).await.unwrap_err();
        assert!(
            matches!(err, MirrorError::Transport { .. }),
            "expected Transport, got: {err:?}"
        );
    }
}