use tracing::{info, warn};
use super::error::MirrorError;
use super::handshake::{
MIRROR_HELLO_ERR_BAD_VERSION, MIRROR_HELLO_ERR_CLUSTER_ID, MIRROR_PROTOCOL_VERSION,
MirrorHelloAck, recv_hello, send_ack,
};
pub struct SourceHandlerParams {
pub local_cluster_id: String,
pub latest_snapshot_lsn: u64,
pub snapshot_bytes_total: u64,
}
#[derive(Debug)]
pub struct HandshakeOutcome {
pub source_database_id: String,
pub mirror_last_applied_lsn: u64,
pub stream_from_lsn: u64,
}
pub async fn handle_mirror_connection(
send: &mut quinn::SendStream,
recv: &mut quinn::RecvStream,
params: &SourceHandlerParams,
) -> Result<HandshakeOutcome, MirrorError> {
let hello = recv_hello(recv).await?;
if hello.protocol_version != MIRROR_PROTOCOL_VERSION {
let ack = MirrorHelloAck {
accepted: false,
error_code: MIRROR_HELLO_ERR_BAD_VERSION,
error_detail: format!(
"unsupported mirror protocol version {}, require {MIRROR_PROTOCOL_VERSION}",
hello.protocol_version
),
source_cluster_id: params.local_cluster_id.clone(),
snapshot_lsn: 0,
snapshot_bytes_total: 0,
};
send_ack(send, &ack).await?;
return Err(MirrorError::HandshakeCodec {
detail: format!(
"mirror declared protocol_version={}, we require {MIRROR_PROTOCOL_VERSION}",
hello.protocol_version
),
});
}
if hello.source_cluster != params.local_cluster_id {
warn!(
declared = %hello.source_cluster,
ours = %params.local_cluster_id,
"mirror handshake rejected: cluster-id mismatch"
);
let ack = MirrorHelloAck {
accepted: false,
error_code: MIRROR_HELLO_ERR_CLUSTER_ID,
error_detail: format!(
"cluster-id mismatch: you declared {:?}, we are {:?}",
hello.source_cluster, params.local_cluster_id
),
source_cluster_id: params.local_cluster_id.clone(),
snapshot_lsn: 0,
snapshot_bytes_total: 0,
};
send_ack(send, &ack).await?;
return Err(MirrorError::ClusterIdMismatch {
declared: hello.source_cluster,
remote: params.local_cluster_id.clone(),
});
}
let (snapshot_lsn, snapshot_bytes_total) =
if hello.last_applied_lsn < params.latest_snapshot_lsn {
(params.latest_snapshot_lsn, params.snapshot_bytes_total)
} else {
(u64::MAX, 0)
};
let stream_from_lsn = if snapshot_lsn == u64::MAX {
hello.last_applied_lsn.saturating_add(1)
} else {
snapshot_lsn.saturating_add(1)
};
let ack = MirrorHelloAck {
accepted: true,
error_code: 0,
error_detail: String::new(),
source_cluster_id: params.local_cluster_id.clone(),
snapshot_lsn,
snapshot_bytes_total,
};
send_ack(send, &ack).await?;
info!(
source_cluster = %params.local_cluster_id,
database_id = %hello.source_database_id,
mirror_last_applied = hello.last_applied_lsn,
stream_from_lsn,
"mirror handshake accepted"
);
Ok(HandshakeOutcome {
source_database_id: hello.source_database_id,
mirror_last_applied_lsn: hello.last_applied_lsn,
stream_from_lsn,
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::mirror::handshake::{MIRROR_PROTOCOL_VERSION, MirrorHello, recv_ack, send_hello};
async fn exchange(
hello: MirrorHello,
params: SourceHandlerParams,
) -> (Result<HandshakeOutcome, MirrorError>, MirrorHelloAck) {
let mut mirror_out = Vec::<u8>::new();
let mut source_out = Vec::<u8>::new();
send_hello(&mut mirror_out, &hello).await.unwrap();
use crate::mirror::handshake::{recv_hello, send_ack};
let ack_result: Result<HandshakeOutcome, MirrorError> = async {
let mut hello_bytes = mirror_out.as_slice();
let hello = recv_hello(&mut hello_bytes).await?;
if hello.source_cluster != params.local_cluster_id {
let ack = MirrorHelloAck {
accepted: false,
error_code: MIRROR_HELLO_ERR_CLUSTER_ID,
error_detail: "cluster-id mismatch".into(),
source_cluster_id: params.local_cluster_id.clone(),
snapshot_lsn: 0,
snapshot_bytes_total: 0,
};
send_ack(&mut source_out, &ack).await?;
return Err(MirrorError::ClusterIdMismatch {
declared: hello.source_cluster,
remote: params.local_cluster_id.clone(),
});
}
let ack = MirrorHelloAck {
accepted: true,
error_code: 0,
error_detail: String::new(),
source_cluster_id: params.local_cluster_id.clone(),
snapshot_lsn: params.latest_snapshot_lsn,
snapshot_bytes_total: params.snapshot_bytes_total,
};
send_ack(&mut source_out, &ack).await?;
Ok(HandshakeOutcome {
source_database_id: hello.source_database_id,
mirror_last_applied_lsn: hello.last_applied_lsn,
stream_from_lsn: params.latest_snapshot_lsn.saturating_add(1),
})
}
.await;
let mut source_buf = source_out.as_slice();
let ack = recv_ack(&mut source_buf).await.unwrap();
(ack_result, ack)
}
#[tokio::test]
async fn valid_handshake_accepted() {
let hello = MirrorHello {
source_cluster: "prod-us".into(),
source_database_id: "db_01TEST".into(),
last_applied_lsn: 0,
protocol_version: MIRROR_PROTOCOL_VERSION,
};
let params = SourceHandlerParams {
local_cluster_id: "prod-us".into(),
latest_snapshot_lsn: 42,
snapshot_bytes_total: 1024,
};
let (outcome, ack) = exchange(hello, params).await;
assert!(ack.accepted, "ack should be accepted");
assert!(outcome.is_ok(), "outcome: {outcome:?}");
let o = outcome.unwrap();
assert_eq!(o.source_database_id, "db_01TEST");
}
#[tokio::test]
async fn mismatched_cluster_id_rejected() {
let hello = MirrorHello {
source_cluster: "wrong-cluster".into(),
source_database_id: "db_01TEST".into(),
last_applied_lsn: 0,
protocol_version: MIRROR_PROTOCOL_VERSION,
};
let params = SourceHandlerParams {
local_cluster_id: "prod-us".into(),
latest_snapshot_lsn: 0,
snapshot_bytes_total: 0,
};
let (outcome, ack) = exchange(hello, params).await;
assert!(!ack.accepted, "ack should be rejected");
assert_eq!(ack.error_code, MIRROR_HELLO_ERR_CLUSTER_ID);
assert!(
matches!(outcome, Err(MirrorError::ClusterIdMismatch { .. })),
"outcome: {outcome:?}"
);
}
}