use std::net::SocketAddr;
use std::time::Duration;
use rand::Rng as _;
use tokio::sync::Mutex;
use tracing::{info, warn};
use super::error::MirrorError;
use super::handshake::{
MIRROR_HELLO_ERR_BAD_VERSION, MIRROR_HELLO_ERR_CLUSTER_ID, MIRROR_HELLO_ERR_OBSERVER_ONLY,
MIRROR_PROTOCOL_VERSION, MirrorHello, MirrorHelloAck, recv_ack, send_hello,
};
use super::throttle::SendThrottle;
/// Initial reconnect backoff delay, in milliseconds, before jitter is applied.
const RECONNECT_BASE_MS: u64 = 500;
/// Upper bound on the reconnect backoff delay, in milliseconds; the delay
/// doubles after each failed attempt until it reaches this cap.
const RECONNECT_MAX_MS: u64 = 30_000;
/// Fraction of the current backoff delay used as the maximum magnitude of the
/// random jitter (the jitter is drawn from `[-delay * f, +delay * f]`).
const JITTER_FRACTION: f64 = 0.25;
/// Connection state of a [`CrossClusterLink`], stored behind its `state` mutex.
#[derive(Debug)]
enum LinkState {
/// No usable connection; the initial state, and the state set at the start
/// of `schedule_reconnect`.
Disconnected,
/// A QUIC connection on which the mirror handshake has completed
/// successfully (the state is only set to this after `run_handshake`
/// returns `Ok`).
Connected(quinn::Connection),
}
/// A client-side QUIC link from this node to a source cluster.
///
/// The link dials `source_addr`, performs the mirror hello/ack handshake, and
/// then hands out bidirectional streams on the established connection.
pub struct CrossClusterLink {
// Expected id of the source cluster. Sent in the hello, passed as the
// server-name argument when dialing, and compared against the id in the ack.
source_cluster_id: String,
// Database id sent in the hello message.
source_database_id: String,
// Socket address of the source that is dialed.
source_addr: SocketAddr,
// QUIC client configuration; cloned for every dial attempt.
client_config: quinn::ClientConfig,
// QUIC endpoint used to initiate outgoing connections.
endpoint: quinn::Endpoint,
// Current connection state (disconnected, or the live connection handle).
state: Mutex<LinkState>,
/// Send throttle associated with this link; `schedule_reconnect` calls
/// `reset()` on it after marking the link disconnected.
pub throttle: SendThrottle,
}
impl CrossClusterLink {
    /// Creates a link to the source cluster reachable at `source_addr`.
    ///
    /// The link starts in the disconnected state. No network activity happens
    /// until [`connect`](Self::connect) or
    /// [`schedule_reconnect`](Self::schedule_reconnect) is called.
    pub fn new(
        source_cluster_id: String,
        source_database_id: String,
        source_addr: SocketAddr,
        endpoint: quinn::Endpoint,
        client_config: quinn::ClientConfig,
        throttle: SendThrottle,
    ) -> Self {
        Self {
            source_cluster_id,
            source_database_id,
            source_addr,
            client_config,
            endpoint,
            state: Mutex::new(LinkState::Disconnected),
            throttle,
        }
    }

    /// Returns the id of the source cluster this link was configured with.
    pub fn source_cluster_id(&self) -> &str {
        &self.source_cluster_id
    }

    /// Dials the source once, runs the mirror handshake, and on success
    /// stores the connection as the link's current connection.
    ///
    /// `last_applied_lsn` is forwarded to the source in the hello message.
    ///
    /// # Errors
    ///
    /// Returns the dial or handshake error. The stored link state is only
    /// modified after both steps have succeeded.
    pub async fn connect(&self, last_applied_lsn: u64) -> Result<MirrorHelloAck, MirrorError> {
        let conn = self.dial().await?;
        let ack = self.run_handshake(&conn, last_applied_lsn).await?;
        *self.state.lock().await = LinkState::Connected(conn);
        Ok(ack)
    }

    /// Opens a new bidirectional stream on the current connection.
    ///
    /// # Errors
    ///
    /// Returns [`MirrorError::Transport`] if the link is disconnected or if
    /// the stream cannot be opened.
    pub async fn open_bidi_stream(
        &self,
    ) -> Result<(quinn::SendStream, quinn::RecvStream), MirrorError> {
        // Take a handle to the connection and release the state lock before
        // awaiting. `open_bi` can wait on the peer's stream limit; holding the
        // lock for that long would stall `connect` / `schedule_reconnect` and
        // serialize every concurrent stream open.
        let conn = {
            let state = self.state.lock().await;
            match &*state {
                LinkState::Disconnected => {
                    return Err(MirrorError::Transport {
                        detail: "cross-cluster link is disconnected".into(),
                    });
                }
                LinkState::Connected(conn) => conn.clone(),
            }
        };
        conn.open_bi().await.map_err(|e| MirrorError::Transport {
            detail: format!("open bidi stream to source: {e}"),
        })
    }

    /// Marks the link disconnected and retries dial + handshake with
    /// exponential backoff until it succeeds or hits a non-retryable error.
    ///
    /// The delay starts at `RECONNECT_BASE_MS`, doubles after every failed
    /// attempt up to `RECONNECT_MAX_MS`, and is perturbed by `jitter_for`
    /// before each sleep. Note that the first attempt is also preceded by a
    /// sleep. Dial failures and other handshake failures are logged and
    /// retried indefinitely.
    ///
    /// # Errors
    ///
    /// Returns immediately, without further retries, when the handshake fails
    /// with `ClusterIdMismatch`, `ObserverRoleViolation`,
    /// `ProtocolVersionMismatch` or `MirrorPromoted`.
    pub async fn schedule_reconnect(
        &self,
        last_applied_lsn: u64,
    ) -> Result<MirrorHelloAck, MirrorError> {
        // Publish the disconnect first so `open_bidi_stream` callers get an
        // error instead of a handle to the old connection.
        *self.state.lock().await = LinkState::Disconnected;
        self.throttle.reset();

        let mut delay_ms = RECONNECT_BASE_MS;
        loop {
            let sleep_ms = delay_ms.saturating_add_signed(jitter_for(delay_ms));
            info!(
                source_cluster = %self.source_cluster_id,
                source_addr = %self.source_addr,
                sleep_ms,
                "mirror link: reconnecting after disconnect"
            );
            tokio::time::sleep(Duration::from_millis(sleep_ms)).await;

            match self.dial().await {
                Err(e) => {
                    warn!(
                        source_cluster = %self.source_cluster_id,
                        error = %e,
                        "mirror link: dial failed, will retry"
                    );
                }
                Ok(conn) => match self.run_handshake(&conn, last_applied_lsn).await {
                    // These errors are returned to the caller instead of
                    // being retried.
                    Err(e @ MirrorError::ClusterIdMismatch { .. })
                    | Err(e @ MirrorError::ObserverRoleViolation { .. })
                    | Err(e @ MirrorError::ProtocolVersionMismatch { .. })
                    | Err(e @ MirrorError::MirrorPromoted { .. }) => {
                        return Err(e);
                    }
                    Err(e) => {
                        warn!(
                            source_cluster = %self.source_cluster_id,
                            error = %e,
                            "mirror link: handshake failed, will retry"
                        );
                    }
                    Ok(ack) => {
                        *self.state.lock().await = LinkState::Connected(conn);
                        return Ok(ack);
                    }
                },
            }

            delay_ms = delay_ms.saturating_mul(2).min(RECONNECT_MAX_MS);
        }
    }

    /// Establishes a QUIC connection to `source_addr`.
    ///
    /// A clone of the stored client configuration is used for the attempt and
    /// `source_cluster_id` is passed as the server-name argument.
    ///
    /// # Errors
    ///
    /// Returns [`MirrorError::Transport`] if the connection attempt cannot be
    /// started or if the QUIC handshake fails.
    async fn dial(&self) -> Result<quinn::Connection, MirrorError> {
        let connecting = self
            .endpoint
            .connect_with(
                self.client_config.clone(),
                self.source_addr,
                &self.source_cluster_id,
            )
            .map_err(|e| MirrorError::Transport {
                detail: format!("connect to source {}: {e}", self.source_addr),
            })?;
        connecting.await.map_err(|e| MirrorError::Transport {
            detail: format!("QUIC handshake with source {}: {e}", self.source_addr),
        })
    }

    /// Runs the mirror hello/ack exchange on a fresh bidirectional stream of
    /// `conn`.
    ///
    /// Sends a [`MirrorHello`] carrying this link's cluster id, database id,
    /// `last_applied_lsn` and `MIRROR_PROTOCOL_VERSION`, then reads the ack.
    ///
    /// # Errors
    ///
    /// * [`MirrorError::Transport`] if the stream cannot be opened or the
    ///   source rejects the hello with an unrecognised error code.
    /// * Whatever `send_hello` / `recv_ack` return on failure.
    /// * `ClusterIdMismatch`, `ObserverRoleViolation` or
    ///   `ProtocolVersionMismatch` when the source rejects the hello with the
    ///   corresponding code.
    /// * `ClusterIdMismatch` when the hello is accepted but the ack reports a
    ///   different cluster id than the one configured.
    async fn run_handshake(
        &self,
        conn: &quinn::Connection,
        last_applied_lsn: u64,
    ) -> Result<MirrorHelloAck, MirrorError> {
        let (mut send, mut recv) = conn.open_bi().await.map_err(|e| MirrorError::Transport {
            detail: format!("open handshake stream: {e}"),
        })?;

        let hello = MirrorHello {
            source_cluster: self.source_cluster_id.clone(),
            source_database_id: self.source_database_id.clone(),
            last_applied_lsn,
            protocol_version: MIRROR_PROTOCOL_VERSION,
        };
        send_hello(&mut send, &hello).await?;
        // Best-effort close of our sending side; the result is deliberately
        // ignored. NOTE(review): assumes `finish` is a synchronous call in the
        // quinn version in use -- confirm, since an un-awaited future would do
        // nothing here.
        let _ = send.finish();

        let ack = recv_ack(&mut recv).await?;
        if !ack.accepted {
            return Err(self.rejection_error(ack));
        }
        // Even an accepted hello must come from the cluster we expect.
        if ack.source_cluster_id != self.source_cluster_id {
            return Err(MirrorError::ClusterIdMismatch {
                declared: self.source_cluster_id.clone(),
                remote: ack.source_cluster_id,
            });
        }
        Ok(ack)
    }

    /// Converts a rejected handshake ack into the matching [`MirrorError`].
    fn rejection_error(&self, ack: MirrorHelloAck) -> MirrorError {
        match ack.error_code {
            MIRROR_HELLO_ERR_CLUSTER_ID => MirrorError::ClusterIdMismatch {
                declared: self.source_cluster_id.clone(),
                remote: ack.source_cluster_id,
            },
            MIRROR_HELLO_ERR_OBSERVER_ONLY => MirrorError::ObserverRoleViolation {
                detail: ack.error_detail,
            },
            MIRROR_HELLO_ERR_BAD_VERSION => MirrorError::ProtocolVersionMismatch {
                local: MIRROR_PROTOCOL_VERSION,
                detail: ack.error_detail,
            },
            other => MirrorError::Transport {
                detail: format!(
                    "source rejected mirror handshake: code={other:#04x} {}",
                    ack.error_detail
                ),
            },
        }
    }
}
/// Picks a random signed offset, in milliseconds, to add to a backoff delay.
///
/// The offset is drawn from the inclusive range `[-bound, bound]`, where
/// `bound` is `delay_ms * JITTER_FRACTION` truncated to an integer. When the
/// bound truncates to zero, no randomness is drawn and `0` is returned.
fn jitter_for(delay_ms: u64) -> i64 {
    let bound = (delay_ms as f64 * JITTER_FRACTION) as i64;
    match bound {
        0 => 0,
        _ => rand::rng().random_range(-bound..=bound),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every sampled jitter must stay within ±(delay × `JITTER_FRACTION`).
    #[test]
    fn jitter_bounds() {
        const SAMPLES_PER_DELAY: usize = 200;
        for delay in [500u64, 1000, 5000, 30_000] {
            let max = (delay as f64 * JITTER_FRACTION) as i64;
            for j in std::iter::repeat_with(|| jitter_for(delay)).take(SAMPLES_PER_DELAY) {
                assert!(
                    j.abs() <= max,
                    "jitter {j} out of bounds ±{max} for delay {delay}"
                );
            }
        }
    }

    /// Repeated doubling from the base delay must settle at the cap.
    #[test]
    fn backoff_capped_at_max() {
        let d = (0..30).fold(RECONNECT_BASE_MS, |d, _| (d * 2).min(RECONNECT_MAX_MS));
        assert_eq!(d, RECONNECT_MAX_MS);
    }
}