1use tracing::{info, warn};
17
18use super::error::MirrorError;
19use super::handshake::{
20 MIRROR_HELLO_ERR_BAD_VERSION, MIRROR_HELLO_ERR_CLUSTER_ID, MIRROR_PROTOCOL_VERSION,
21 MirrorHelloAck, recv_hello, send_ack,
22};
23
/// Source-side inputs for answering a mirror's handshake.
///
/// The handler only reads these values; it never mutates them.
#[derive(Debug, Clone)]
pub struct SourceHandlerParams {
    /// Cluster id of this source. The mirror's declared `source_cluster` must
    /// equal it or the handshake is rejected; it is also echoed back in every
    /// ack as `source_cluster_id`.
    pub local_cluster_id: String,
    /// LSN of the latest snapshot this source can offer. A mirror whose
    /// `last_applied_lsn` is strictly below this value is pointed at the
    /// snapshot in the ack.
    pub latest_snapshot_lsn: u64,
    /// Value reported as `snapshot_bytes_total` in the ack when the snapshot
    /// is offered (presumably the snapshot's size in bytes — confirm with the
    /// producer of this value).
    pub snapshot_bytes_total: u64,
}
37
/// Result of a handshake that the source accepted.
///
/// `PartialEq`/`Eq`/`Clone` are derived so callers and tests can compare and
/// copy outcomes directly instead of checking field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandshakeOutcome {
    /// Database id the mirror declared in its hello.
    pub source_database_id: String,
    /// `last_applied_lsn` the mirror declared in its hello.
    pub mirror_last_applied_lsn: u64,
    /// First LSN the source should stream after the handshake: one past the
    /// offered snapshot LSN when a snapshot is offered, otherwise one past the
    /// mirror's last applied LSN (both saturating at `u64::MAX`).
    pub stream_from_lsn: u64,
}
50
51pub async fn handle_mirror_connection(
56 send: &mut quinn::SendStream,
57 recv: &mut quinn::RecvStream,
58 params: &SourceHandlerParams,
59) -> Result<HandshakeOutcome, MirrorError> {
60 let hello = recv_hello(recv).await?;
61
62 if hello.protocol_version != MIRROR_PROTOCOL_VERSION {
64 let ack = MirrorHelloAck {
65 accepted: false,
66 error_code: MIRROR_HELLO_ERR_BAD_VERSION,
67 error_detail: format!(
68 "unsupported mirror protocol version {}, require {MIRROR_PROTOCOL_VERSION}",
69 hello.protocol_version
70 ),
71 source_cluster_id: params.local_cluster_id.clone(),
72 snapshot_lsn: 0,
73 snapshot_bytes_total: 0,
74 };
75 send_ack(send, &ack).await?;
76 return Err(MirrorError::HandshakeCodec {
77 detail: format!(
78 "mirror declared protocol_version={}, we require {MIRROR_PROTOCOL_VERSION}",
79 hello.protocol_version
80 ),
81 });
82 }
83
84 if hello.source_cluster != params.local_cluster_id {
87 warn!(
88 declared = %hello.source_cluster,
89 ours = %params.local_cluster_id,
90 "mirror handshake rejected: cluster-id mismatch"
91 );
92 let ack = MirrorHelloAck {
93 accepted: false,
94 error_code: MIRROR_HELLO_ERR_CLUSTER_ID,
95 error_detail: format!(
96 "cluster-id mismatch: you declared {:?}, we are {:?}",
97 hello.source_cluster, params.local_cluster_id
98 ),
99 source_cluster_id: params.local_cluster_id.clone(),
100 snapshot_lsn: 0,
101 snapshot_bytes_total: 0,
102 };
103 send_ack(send, &ack).await?;
104 return Err(MirrorError::ClusterIdMismatch {
105 declared: hello.source_cluster,
106 remote: params.local_cluster_id.clone(),
107 });
108 }
109
110 let (snapshot_lsn, snapshot_bytes_total) =
112 if hello.last_applied_lsn < params.latest_snapshot_lsn {
113 (params.latest_snapshot_lsn, params.snapshot_bytes_total)
114 } else {
115 (u64::MAX, 0)
117 };
118
119 let stream_from_lsn = if snapshot_lsn == u64::MAX {
120 hello.last_applied_lsn.saturating_add(1)
121 } else {
122 snapshot_lsn.saturating_add(1)
123 };
124
125 let ack = MirrorHelloAck {
126 accepted: true,
127 error_code: 0,
128 error_detail: String::new(),
129 source_cluster_id: params.local_cluster_id.clone(),
130 snapshot_lsn,
131 snapshot_bytes_total,
132 };
133 send_ack(send, &ack).await?;
134
135 info!(
136 source_cluster = %params.local_cluster_id,
137 database_id = %hello.source_database_id,
138 mirror_last_applied = hello.last_applied_lsn,
139 stream_from_lsn,
140 "mirror handshake accepted"
141 );
142
143 Ok(HandshakeOutcome {
144 source_database_id: hello.source_database_id,
145 mirror_last_applied_lsn: hello.last_applied_lsn,
146 stream_from_lsn,
147 })
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mirror::handshake::{
        MIRROR_PROTOCOL_VERSION, MirrorHello, recv_ack, recv_hello, send_ack, send_hello,
    };

    /// Drives one hello/ack round trip through in-memory byte buffers.
    ///
    /// The hello is encoded with `send_hello`, decoded with `recv_hello`, and
    /// the resulting ack is encoded with `send_ack` and decoded with
    /// `recv_ack`, so the codec is exercised in both directions.
    ///
    /// NOTE(review): this harness does **not** call
    /// `handle_mirror_connection` (which takes quinn streams). It repeats the
    /// cluster-id check and the accept path inline, so:
    /// * the protocol-version rejection is not covered, and
    /// * the accept path always advertises `params.latest_snapshot_lsn`,
    ///   unlike the handler, which skips the snapshot for a caught-up mirror.
    ///
    /// Keep this in sync with the handler by hand, or replace it once the
    /// handler can run over generic streams.
    ///
    /// Returns the source-side result together with the ack the mirror would
    /// have decoded. Panics if no ack was written (for example when the hello
    /// fails to decode).
    async fn exchange(
        hello: MirrorHello,
        params: SourceHandlerParams,
    ) -> (Result<HandshakeOutcome, MirrorError>, MirrorHelloAck) {
        // Bytes "sent" by the mirror and by the source respectively.
        let mut mirror_out = Vec::<u8>::new();
        let mut source_out = Vec::<u8>::new();

        send_hello(&mut mirror_out, &hello).await.unwrap();

        let ack_result: Result<HandshakeOutcome, MirrorError> = async {
            let mut hello_bytes = mirror_out.as_slice();
            let hello = recv_hello(&mut hello_bytes).await?;
            if hello.source_cluster != params.local_cluster_id {
                let ack = MirrorHelloAck {
                    accepted: false,
                    error_code: MIRROR_HELLO_ERR_CLUSTER_ID,
                    error_detail: "cluster-id mismatch".into(),
                    source_cluster_id: params.local_cluster_id.clone(),
                    snapshot_lsn: 0,
                    snapshot_bytes_total: 0,
                };
                send_ack(&mut source_out, &ack).await?;
                return Err(MirrorError::ClusterIdMismatch {
                    declared: hello.source_cluster,
                    remote: params.local_cluster_id.clone(),
                });
            }
            let ack = MirrorHelloAck {
                accepted: true,
                error_code: 0,
                error_detail: String::new(),
                source_cluster_id: params.local_cluster_id.clone(),
                snapshot_lsn: params.latest_snapshot_lsn,
                snapshot_bytes_total: params.snapshot_bytes_total,
            };
            send_ack(&mut source_out, &ack).await?;
            Ok(HandshakeOutcome {
                source_database_id: hello.source_database_id,
                mirror_last_applied_lsn: hello.last_applied_lsn,
                stream_from_lsn: params.latest_snapshot_lsn.saturating_add(1),
            })
        }
        .await;

        let mut source_buf = source_out.as_slice();
        let ack = recv_ack(&mut source_buf).await.unwrap();
        (ack_result, ack)
    }

    /// A hello with the right cluster id is accepted, and both the ack and the
    /// outcome carry the snapshot position.
    #[tokio::test]
    async fn valid_handshake_accepted() {
        let hello = MirrorHello {
            source_cluster: "prod-us".into(),
            source_database_id: "db_01TEST".into(),
            last_applied_lsn: 0,
            protocol_version: MIRROR_PROTOCOL_VERSION,
        };
        let params = SourceHandlerParams {
            local_cluster_id: "prod-us".into(),
            latest_snapshot_lsn: 42,
            snapshot_bytes_total: 1024,
        };
        let (outcome, ack) = exchange(hello, params).await;
        assert!(ack.accepted, "ack should be accepted");
        // The ack must survive the encode/decode round trip intact.
        assert_eq!(ack.snapshot_lsn, 42);
        assert_eq!(ack.snapshot_bytes_total, 1024);
        assert!(outcome.is_ok(), "outcome: {outcome:?}");
        let o = outcome.unwrap();
        assert_eq!(o.source_database_id, "db_01TEST");
        assert_eq!(o.mirror_last_applied_lsn, 0);
        // Streaming resumes one past the advertised snapshot.
        assert_eq!(o.stream_from_lsn, 43);
    }

    /// A hello naming a different cluster is rejected with the cluster-id
    /// error code, and the error reports both sides' ids.
    #[tokio::test]
    async fn mismatched_cluster_id_rejected() {
        let hello = MirrorHello {
            source_cluster: "wrong-cluster".into(),
            source_database_id: "db_01TEST".into(),
            last_applied_lsn: 0,
            protocol_version: MIRROR_PROTOCOL_VERSION,
        };
        let params = SourceHandlerParams {
            local_cluster_id: "prod-us".into(),
            latest_snapshot_lsn: 0,
            snapshot_bytes_total: 0,
        };
        let (outcome, ack) = exchange(hello, params).await;
        assert!(!ack.accepted, "ack should be rejected");
        assert_eq!(ack.error_code, MIRROR_HELLO_ERR_CLUSTER_ID);
        match outcome {
            Err(MirrorError::ClusterIdMismatch { declared, remote }) => {
                assert_eq!(declared, "wrong-cluster");
                assert_eq!(remote, "prod-us");
            }
            other => panic!("outcome: {other:?}"),
        }
    }
}