1use std::net::SocketAddr;
30use std::time::Duration;
31
32use rand::Rng as _;
33use tokio::sync::Mutex;
34use tracing::{info, warn};
35
36use super::error::MirrorError;
37use super::handshake::{
38 MIRROR_HELLO_ERR_BAD_VERSION, MIRROR_HELLO_ERR_CLUSTER_ID, MIRROR_HELLO_ERR_OBSERVER_ONLY,
39 MIRROR_PROTOCOL_VERSION, MirrorHello, MirrorHelloAck, recv_ack, send_hello,
40};
41use super::throttle::SendThrottle;
42
/// Initial reconnect backoff delay, in milliseconds.
const RECONNECT_BASE_MS: u64 = 500;
/// Upper bound, in milliseconds, on the backoff delay before jitter is applied.
const RECONNECT_MAX_MS: u64 = 30_000;
/// Largest jitter, as a fraction of the current delay, added to or subtracted from a sleep.
const JITTER_FRACTION: f64 = 0.25;
49
/// Connection state held by a `CrossClusterLink`.
#[derive(Debug)]
enum LinkState {
    /// No connection is currently stored; streams cannot be opened.
    Disconnected,
    /// A QUIC connection on which the mirror handshake has been accepted.
    Connected(quinn::Connection),
}
58
/// Client side of a mirror link to a source cluster, carried over QUIC.
pub struct CrossClusterLink {
    /// Identifier of the source cluster: sent in the hello, compared with the
    /// id echoed in the ack, and passed as the server name when dialing.
    source_cluster_id: String,
    /// Source database identifier sent in the hello.
    source_database_id: String,
    /// Socket address dialed to reach the source.
    source_addr: SocketAddr,
    /// QUIC client configuration, cloned for every dial attempt.
    client_config: quinn::ClientConfig,
    /// QUIC endpoint used to originate connections.
    endpoint: quinn::Endpoint,
    /// Current connection state, guarded by an async mutex.
    state: Mutex<LinkState>,
    /// Send throttle for this link; `reset()` is called on it at the start of
    /// every reconnect.
    // NOTE(review): the throttling policy lives in `super::throttle` and is
    // not visible from this file.
    pub throttle: SendThrottle,
}
81
82impl CrossClusterLink {
83 pub fn new(
87 source_cluster_id: String,
88 source_database_id: String,
89 source_addr: SocketAddr,
90 endpoint: quinn::Endpoint,
91 client_config: quinn::ClientConfig,
92 throttle: SendThrottle,
93 ) -> Self {
94 Self {
95 source_cluster_id,
96 source_database_id,
97 source_addr,
98 client_config,
99 endpoint,
100 state: Mutex::new(LinkState::Disconnected),
101 throttle,
102 }
103 }
104
105 pub fn source_cluster_id(&self) -> &str {
107 &self.source_cluster_id
108 }
109
110 pub async fn connect(&self, last_applied_lsn: u64) -> Result<MirrorHelloAck, MirrorError> {
118 let conn = self.dial().await?;
119 let ack = self.run_handshake(&conn, last_applied_lsn).await?;
120 let mut state = self.state.lock().await;
121 *state = LinkState::Connected(conn);
122 Ok(ack)
123 }
124
125 pub async fn open_bidi_stream(
130 &self,
131 ) -> Result<(quinn::SendStream, quinn::RecvStream), MirrorError> {
132 let state = self.state.lock().await;
133 match &*state {
134 LinkState::Disconnected => Err(MirrorError::Transport {
135 detail: "cross-cluster link is disconnected".into(),
136 }),
137 LinkState::Connected(conn) => {
138 conn.open_bi().await.map_err(|e| MirrorError::Transport {
139 detail: format!("open bidi stream to source: {e}"),
140 })
141 }
142 }
143 }
144
145 pub async fn schedule_reconnect(
155 &self,
156 last_applied_lsn: u64,
157 ) -> Result<MirrorHelloAck, MirrorError> {
158 {
159 let mut state = self.state.lock().await;
160 *state = LinkState::Disconnected;
161 }
162 self.throttle.reset();
163
164 let mut delay_ms = RECONNECT_BASE_MS;
165
166 loop {
167 let jitter = jitter_for(delay_ms);
168 let sleep_ms = delay_ms.saturating_add_signed(jitter);
169 info!(
170 source_cluster = %self.source_cluster_id,
171 source_addr = %self.source_addr,
172 sleep_ms,
173 "mirror link: reconnecting after disconnect"
174 );
175 tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
176
177 match self.dial().await {
178 Err(e) => {
179 warn!(
180 source_cluster = %self.source_cluster_id,
181 error = %e,
182 "mirror link: dial failed, will retry"
183 );
184 }
185 Ok(conn) => match self.run_handshake(&conn, last_applied_lsn).await {
186 Err(e @ MirrorError::ClusterIdMismatch { .. })
187 | Err(e @ MirrorError::ObserverRoleViolation { .. })
188 | Err(e @ MirrorError::ProtocolVersionMismatch { .. })
189 | Err(e @ MirrorError::MirrorPromoted { .. }) => {
190 return Err(e);
192 }
193 Err(e) => {
194 warn!(
195 source_cluster = %self.source_cluster_id,
196 error = %e,
197 "mirror link: handshake failed, will retry"
198 );
199 }
200 Ok(ack) => {
201 let mut state = self.state.lock().await;
202 *state = LinkState::Connected(conn);
203 return Ok(ack);
204 }
205 },
206 }
207
208 delay_ms = (delay_ms * 2).min(RECONNECT_MAX_MS);
209 }
210 }
211
212 async fn dial(&self) -> Result<quinn::Connection, MirrorError> {
214 self.endpoint
215 .connect_with(
216 self.client_config.clone(),
217 self.source_addr,
218 &self.source_cluster_id,
219 )
220 .map_err(|e| MirrorError::Transport {
221 detail: format!("connect to source {}: {e}", self.source_addr),
222 })?
223 .await
224 .map_err(|e| MirrorError::Transport {
225 detail: format!("QUIC handshake with source {}: {e}", self.source_addr),
226 })
227 }
228
229 async fn run_handshake(
231 &self,
232 conn: &quinn::Connection,
233 last_applied_lsn: u64,
234 ) -> Result<MirrorHelloAck, MirrorError> {
235 let (mut send, mut recv) = conn.open_bi().await.map_err(|e| MirrorError::Transport {
236 detail: format!("open handshake stream: {e}"),
237 })?;
238
239 let hello = MirrorHello {
240 source_cluster: self.source_cluster_id.clone(),
241 source_database_id: self.source_database_id.clone(),
242 last_applied_lsn,
243 protocol_version: MIRROR_PROTOCOL_VERSION,
244 };
245 send_hello(&mut send, &hello).await?;
246 let _ = send.finish();
247
248 let ack = recv_ack(&mut recv).await?;
249
250 if !ack.accepted {
251 return Err(match ack.error_code {
252 MIRROR_HELLO_ERR_CLUSTER_ID => MirrorError::ClusterIdMismatch {
253 declared: self.source_cluster_id.clone(),
254 remote: ack.source_cluster_id,
255 },
256 MIRROR_HELLO_ERR_OBSERVER_ONLY => MirrorError::ObserverRoleViolation {
257 detail: ack.error_detail,
258 },
259 MIRROR_HELLO_ERR_BAD_VERSION => MirrorError::ProtocolVersionMismatch {
260 local: MIRROR_PROTOCOL_VERSION,
261 detail: ack.error_detail,
262 },
263 other => MirrorError::Transport {
264 detail: format!(
265 "source rejected mirror handshake: code={other:#04x} {}",
266 ack.error_detail
267 ),
268 },
269 });
270 }
271
272 if ack.source_cluster_id != self.source_cluster_id {
274 return Err(MirrorError::ClusterIdMismatch {
275 declared: self.source_cluster_id.clone(),
276 remote: ack.source_cluster_id,
277 });
278 }
279
280 Ok(ack)
281 }
282}
283
284fn jitter_for(delay_ms: u64) -> i64 {
287 let max = (delay_ms as f64 * JITTER_FRACTION) as i64;
288 if max == 0 {
289 return 0;
290 }
291 rand::rng().random_range(-max..=max)
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Sampled jitter stays within `JITTER_FRACTION` of the delay in both
    /// directions.
    #[test]
    fn jitter_bounds() {
        for delay in [500u64, 1000, 5000, 30_000] {
            let max = (delay as f64 * JITTER_FRACTION) as i64;
            for _ in 0..200 {
                let j = jitter_for(delay);
                assert!(
                    (-max..=max).contains(&j),
                    "jitter {j} out of bounds ±{max} for delay {delay}"
                );
            }
        }
    }

    /// Repeated doubling from the base delay settles at the configured cap.
    #[test]
    fn backoff_capped_at_max() {
        let d = (0..30).fold(RECONNECT_BASE_MS, |d, _| (d * 2).min(RECONNECT_MAX_MS));
        assert_eq!(d, RECONNECT_MAX_MS);
    }
}