1use std::io::Write as _;
6use std::path::{Path, PathBuf};
7use std::sync::{Mutex, MutexGuard, PoisonError};
8
9use crc32c::{crc32c, crc32c_append};
10use nodedb_types::{Lsn, MirrorStatus};
11use tracing::{debug, info};
12
13use super::envelope::{
14 BootstrapChunkOutcome, CrossClusterSnapshotEnvelope, PROGRESS_REPORT_CHUNK_BYTES,
15 ProgressCallback,
16};
17use crate::mirror::error::MirrorError;
18
/// Bookkeeping for the single snapshot transfer currently being received.
///
/// Created when a chunk with offset 0 arrives and consumed when the final
/// chunk is processed (or when the transfer is aborted).
struct PartialState {
    /// Cluster that sent the snapshot, taken from the first chunk.
    source_cluster_id: String,
    /// Source database id, taken from the first chunk; also names the partial file.
    database_id: String,
    /// Snapshot LSN announced by the first chunk.
    snapshot_lsn: u64,
    /// Total size declared by the sender in the first chunk.
    total_bytes: u64,
    /// CRC32C of the whole snapshot as declared by the sender in the first chunk.
    declared_crc32c: u32,
    /// Number of payload bytes written so far.
    bytes_done: u64,
    /// Offset the next chunk must carry to be accepted.
    next_expected_offset: u64,
    /// CRC32C accumulated over all bytes written so far.
    running_crc: u32,
    /// Whether `running_crc` has been seeded by a non-empty chunk yet.
    crc_initialized: bool,
    /// Open handle on the partial file; `None` while a chunk write holds it.
    partial_file: Option<std::fs::File>,
    /// Location of the partial file on disk.
    partial_path: PathBuf,
    /// Bytes received since the last `Bootstrapping` progress report.
    since_last_report: u64,
}
37
/// Receives a cross-cluster snapshot as an ordered stream of chunks and stages
/// it on disk until the final chunk has been verified.
pub struct MirrorBootstrapReceiver {
    /// The in-flight transfer, or `None` when no transfer is in progress.
    state: Mutex<Option<PartialState>>,
    /// Root data directory; snapshots are staged in its `recv_snapshots` subdirectory.
    data_dir: PathBuf,
    /// Invoked with `MirrorStatus` updates as the transfer progresses and completes.
    on_progress: ProgressCallback,
}
47
48impl MirrorBootstrapReceiver {
49 pub fn new(data_dir: PathBuf, on_progress: ProgressCallback) -> Self {
53 Self {
54 state: Mutex::new(None),
55 data_dir,
56 on_progress,
57 }
58 }
59
60 fn lock_state(&self) -> Result<MutexGuard<'_, Option<PartialState>>, MirrorError> {
69 self.state
70 .lock()
71 .map_err(|e: PoisonError<_>| MirrorError::Transport {
72 detail: format!(
73 "mirror bootstrap state poisoned (panic in previous chunk handler): {e}"
74 ),
75 })
76 }
77
78 pub async fn handle_chunk(
84 &self,
85 envelope: CrossClusterSnapshotEnvelope,
86 ) -> Result<BootstrapChunkOutcome, MirrorError> {
87 let database_id = envelope.source_database_id.clone();
88 let recv_dir = self.data_dir.join("recv_snapshots");
89
90 spawn_blocking_io(
91 {
92 let d = recv_dir.clone();
93 move || std::fs::create_dir_all(&d)
94 },
95 "create recv_snapshots dir",
96 )
97 .await?;
98
99 if envelope.offset == 0 {
100 let partial_path = partial_path_for(&recv_dir, &database_id);
102 let file = spawn_blocking_io(
103 {
104 let p = partial_path.clone();
105 move || {
106 std::fs::OpenOptions::new()
107 .write(true)
108 .create(true)
109 .truncate(true)
110 .open(&p)
111 }
112 },
113 "open partial file",
114 )
115 .await?;
116
117 let ps = PartialState {
118 source_cluster_id: envelope.source_cluster_id.clone(),
119 database_id: database_id.clone(),
120 snapshot_lsn: envelope.snapshot_lsn,
121 total_bytes: envelope.total_bytes,
122 declared_crc32c: envelope.total_crc32c,
123 bytes_done: 0,
124 next_expected_offset: 0,
125 running_crc: 0,
126 crc_initialized: false,
127 partial_file: Some(file),
128 partial_path,
129 since_last_report: 0,
130 };
131
132 let mut guard = self.lock_state()?;
133 *guard = Some(ps);
134 } else {
135 let guard = self.lock_state()?;
137 match guard.as_ref() {
138 None => {
139 return Err(MirrorError::SnapshotOffsetRegression {
140 database_id,
141 expected: 0,
142 actual: envelope.offset,
143 });
144 }
145 Some(ps) if ps.next_expected_offset != envelope.offset => {
146 let expected = ps.next_expected_offset;
147 drop(guard);
148 return Err(MirrorError::SnapshotOffsetRegression {
149 database_id,
150 expected,
151 actual: envelope.offset,
152 });
153 }
154 Some(_) => {}
155 }
156 }
157
158 let chunk_bytes = envelope.data.clone();
159 let written_len = chunk_bytes.len() as u64;
160
161 let file = {
163 let file = {
164 let mut guard = self.lock_state()?;
165 let ps = guard.as_mut().ok_or_else(|| MirrorError::Transport {
166 detail: "partial state disappeared during write".into(),
167 })?;
168 ps.partial_file
169 .take()
170 .ok_or_else(|| MirrorError::Transport {
171 detail: "partial file already taken".into(),
172 })?
173 };
174 spawn_blocking_io(
175 {
176 let bytes = chunk_bytes.clone();
177 move || -> std::io::Result<std::fs::File> {
178 let mut f = file;
179 f.write_all(&bytes)?;
180 f.flush()?;
181 Ok(f)
182 }
183 },
184 "write partial chunk",
185 )
186 .await?
187 };
188
189 let (bytes_done, total_bytes, should_report) = {
193 let mut guard = self.lock_state()?;
194 let ps = guard.as_mut().ok_or_else(|| MirrorError::Transport {
195 detail: "partial state disappeared after write".into(),
196 })?;
197 if written_len > 0 {
198 if !ps.crc_initialized {
199 ps.running_crc = crc32c(&chunk_bytes);
200 ps.crc_initialized = true;
201 } else {
202 ps.running_crc = crc32c_append(ps.running_crc, &chunk_bytes);
203 }
204 }
205 ps.next_expected_offset += written_len;
206 ps.bytes_done += written_len;
207 ps.since_last_report += written_len;
208 ps.partial_file = Some(file);
209 let should_report = ps.since_last_report >= PROGRESS_REPORT_CHUNK_BYTES;
210 if should_report {
211 ps.since_last_report = 0;
212 }
213 (ps.bytes_done, ps.total_bytes, should_report)
214 };
215
216 if should_report {
217 (self.on_progress)(MirrorStatus::Bootstrapping {
218 bytes_done,
219 bytes_total: total_bytes,
220 });
221 }
222
223 if !envelope.done {
224 return Ok(BootstrapChunkOutcome::Pending { bytes_done });
225 }
226
227 let ps = {
229 let mut guard = self.lock_state()?;
230 guard.take().ok_or_else(|| MirrorError::Transport {
231 detail: "partial state disappeared before finalization".into(),
232 })?
233 };
234
235 let computed_crc = if ps.crc_initialized {
239 ps.running_crc
240 } else {
241 0
242 };
243 if computed_crc != ps.declared_crc32c {
244 drop(ps.partial_file);
247 let path = ps.partial_path.clone();
248 let _ = tokio::task::spawn_blocking(move || std::fs::remove_file(&path)).await;
249 return Err(MirrorError::SnapshotCrcMismatch {
250 database_id: ps.database_id,
251 stored: ps.declared_crc32c,
252 computed: computed_crc,
253 });
254 }
255
256 drop(ps.partial_file);
258
259 let snapshot_lsn = Lsn::new(ps.snapshot_lsn);
260 let snapshot_path = ps.partial_path.clone();
261
262 let final_path = snapshot_path.with_extension("snapshot");
264 spawn_blocking_io(
265 {
266 let src = snapshot_path.clone();
267 let dst = final_path.clone();
268 move || std::fs::rename(&src, &dst)
269 },
270 "rename partial to snapshot",
271 )
272 .await?;
273
274 info!(
275 database_id = %ps.database_id,
276 source_cluster = %ps.source_cluster_id,
277 snapshot_lsn = ps.snapshot_lsn,
278 total_bytes = ps.bytes_done,
279 crc32c = format!("{:#010x}", computed_crc),
280 "cross-cluster snapshot transfer complete"
281 );
282
283 (self.on_progress)(MirrorStatus::Following);
285
286 Ok(BootstrapChunkOutcome::Committed {
287 snapshot_lsn,
288 snapshot_path: final_path,
289 })
290 }
291
292 pub async fn abort(&self) {
298 let ps = match self.state.lock() {
299 Ok(mut g) => g.take(),
300 Err(p) => p.into_inner().take(),
301 };
302 if let Some(mut ps) = ps {
303 drop(ps.partial_file.take());
304 let path = ps.partial_path.clone();
305 let _ = tokio::task::spawn_blocking(move || std::fs::remove_file(&path)).await;
306 debug!(database_id = %ps.database_id, "aborted in-progress bootstrap");
307 }
308 }
309}
310
/// Staging location for an incoming snapshot: `<recv_dir>/<database_id>.partial`.
fn partial_path_for(recv_dir: &Path, database_id: &str) -> PathBuf {
    const SUFFIX: &str = ".partial";
    let mut file_name = String::with_capacity(database_id.len() + SUFFIX.len());
    file_name.push_str(database_id);
    file_name.push_str(SUFFIX);
    recv_dir.join(file_name)
}
314
315async fn spawn_blocking_io<F, T>(f: F, op: &'static str) -> Result<T, MirrorError>
319where
320 F: FnOnce() -> std::io::Result<T> + Send + 'static,
321 T: Send + 'static,
322{
323 match tokio::task::spawn_blocking(f).await {
324 Ok(Ok(v)) => Ok(v),
325 Ok(Err(e)) => Err(MirrorError::Transport {
326 detail: format!("{op}: {e}"),
327 }),
328 Err(join_err) => Err(MirrorError::Transport {
329 detail: format!("{op}: blocking task panicked or was cancelled: {join_err}"),
330 }),
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};
    use tempfile::TempDir;

    /// CRC32C of the concatenation of `chunks`, computed incrementally: the
    /// first chunk with `crc32c`, the rest with `crc32c_append`. Returns 0 when
    /// `chunks` is empty.
    fn crc_of_chunks(chunks: &[&[u8]]) -> u32 {
        let mut state: Option<u32> = None;
        for c in chunks {
            state = Some(match state {
                None => crc32c(c),
                Some(prev) => crc32c_append(prev, c),
            });
        }
        state.unwrap_or(0)
    }

    /// Builds a chunk of the fixed test transfer: cluster `prod-us`, database
    /// `db_01TEST`, snapshot LSN 42.
    fn make_envelope(
        offset: u64,
        data: Vec<u8>,
        done: bool,
        total_bytes: u64,
        total_crc32c: u32,
    ) -> CrossClusterSnapshotEnvelope {
        CrossClusterSnapshotEnvelope {
            source_cluster_id: "prod-us".into(),
            source_database_id: "db_01TEST".into(),
            snapshot_lsn: 42,
            total_bytes,
            total_crc32c,
            offset,
            data,
            done,
        }
    }

    /// Two chunks with a correct CRC: the first yields `Pending { bytes_done: 3 }`,
    /// the second commits at LSN 42 and reports `MirrorStatus::Following`.
    #[tokio::test]
    async fn bootstrap_streams_full_snapshot_and_transitions() {
        let tmp = TempDir::new().unwrap();
        // Records every status passed to the progress callback.
        let status_log: Arc<Mutex<Vec<MirrorStatus>>> = Arc::new(Mutex::new(Vec::new()));
        let log2 = Arc::clone(&status_log);
        let cb: ProgressCallback = Arc::new(move |s| {
            log2.lock().unwrap().push(s);
        });

        let receiver = MirrorBootstrapReceiver::new(tmp.path().to_path_buf(), cb);

        let crc = crc_of_chunks(&[b"hel", b"lo!"]);
        let c1 = make_envelope(0, b"hel".to_vec(), false, 6, crc);
        let c2 = make_envelope(3, b"lo!".to_vec(), true, 6, crc);

        let r1 = receiver.handle_chunk(c1).await.unwrap();
        assert!(matches!(
            r1,
            BootstrapChunkOutcome::Pending { bytes_done: 3 }
        ));

        let r2 = receiver.handle_chunk(c2).await.unwrap();
        assert!(
            matches!(r2, BootstrapChunkOutcome::Committed { snapshot_lsn, .. }
                if snapshot_lsn == Lsn::new(42)),
            "unexpected outcome: {r2:?}"
        );

        let log = status_log.lock().unwrap();
        assert!(
            log.contains(&MirrorStatus::Following),
            "Following status not reported; log: {log:?}"
        );
    }

    /// A declared CRC that differs from the data's CRC is rejected on the final
    /// chunk, and the error carries both the stored and the computed value.
    #[tokio::test]
    async fn crc_mismatch_rejects_final_chunk() {
        let tmp = TempDir::new().unwrap();
        let cb: ProgressCallback = Arc::new(|_| {});
        let receiver = MirrorBootstrapReceiver::new(tmp.path().to_path_buf(), cb);

        // Off by one from the real CRC, so the mismatch is guaranteed.
        let bad_crc = crc_of_chunks(&[b"hel", b"lo!"]).wrapping_add(1);
        let c1 = make_envelope(0, b"hel".to_vec(), false, 6, bad_crc);
        let c2 = make_envelope(3, b"lo!".to_vec(), true, 6, bad_crc);

        receiver.handle_chunk(c1).await.unwrap();
        let err = receiver.handle_chunk(c2).await.unwrap_err();
        assert!(
            matches!(err, MirrorError::SnapshotCrcMismatch { stored, computed, .. }
                if stored == bad_crc && computed == crc_of_chunks(&[b"hel", b"lo!"])),
            "unexpected error: {err:?}"
        );
    }

    /// After a CRC mismatch, the `.partial` staging file no longer exists.
    #[tokio::test]
    async fn crc_mismatch_removes_partial_file() {
        let tmp = TempDir::new().unwrap();
        let cb: ProgressCallback = Arc::new(|_| {});
        let receiver = MirrorBootstrapReceiver::new(tmp.path().to_path_buf(), cb);

        let bad_crc = 0xDEAD_BEEF;
        let c1 = make_envelope(0, b"hello".to_vec(), true, 5, bad_crc);
        // The error itself is covered by `crc_mismatch_rejects_final_chunk`.
        let _ = receiver.handle_chunk(c1).await;

        let partial = tmp.path().join("recv_snapshots").join("db_01TEST.partial");
        assert!(
            !partial.exists(),
            "partial file should be removed after CRC mismatch"
        );
    }

    /// A single empty, final chunk with CRC 0 commits successfully.
    #[tokio::test]
    async fn empty_snapshot_with_zero_crc_commits() {
        let tmp = TempDir::new().unwrap();
        let cb: ProgressCallback = Arc::new(|_| {});
        let receiver = MirrorBootstrapReceiver::new(tmp.path().to_path_buf(), cb);

        let c1 = make_envelope(0, vec![], true, 0, 0);
        let r = receiver.handle_chunk(c1).await.unwrap();
        assert!(matches!(r, BootstrapChunkOutcome::Committed { .. }));
    }

    /// Three chunks of exactly `PROGRESS_REPORT_CHUNK_BYTES` each must produce at
    /// least two `Bootstrapping` progress reports.
    #[tokio::test]
    async fn progress_reported_every_mib() {
        let tmp = TempDir::new().unwrap();
        let report_count = Arc::new(AtomicU64::new(0));
        let count2 = Arc::clone(&report_count);
        let cb: ProgressCallback = Arc::new(move |s| {
            if matches!(s, MirrorStatus::Bootstrapping { .. }) {
                count2.fetch_add(1, Ordering::Relaxed);
            }
        });

        let receiver = MirrorBootstrapReceiver::new(tmp.path().to_path_buf(), cb);
        let mib = PROGRESS_REPORT_CHUNK_BYTES as usize;
        let total = (mib * 3) as u64;
        let chunks: Vec<Vec<u8>> = (0..3).map(|_| vec![0u8; mib]).collect();
        let crc = crc_of_chunks(&chunks.iter().map(|v| v.as_slice()).collect::<Vec<_>>());
        let c1 = make_envelope(0, chunks[0].clone(), false, total, crc);
        let c2 = make_envelope(mib as u64, chunks[1].clone(), false, total, crc);
        let c3 = make_envelope((mib * 2) as u64, chunks[2].clone(), true, total, crc);
        receiver.handle_chunk(c1).await.unwrap();
        receiver.handle_chunk(c2).await.unwrap();
        receiver.handle_chunk(c3).await.unwrap();

        let count = report_count.load(Ordering::Relaxed);
        assert!(count >= 2, "expected ≥2 Bootstrapping reports, got {count}");
    }

    /// A chunk whose offset skips ahead of the bytes received so far
    /// (5 instead of 3) yields `SnapshotOffsetRegression`.
    #[tokio::test]
    async fn offset_regression_returns_error() {
        let tmp = TempDir::new().unwrap();
        let cb: ProgressCallback = Arc::new(|_| {});
        let receiver = MirrorBootstrapReceiver::new(tmp.path().to_path_buf(), cb);

        let crc = crc_of_chunks(&[b"abc"]);
        let c1 = make_envelope(0, b"abc".to_vec(), false, 6, crc);
        receiver.handle_chunk(c1).await.unwrap();

        let bad = make_envelope(5, b"xx".to_vec(), false, 6, crc);
        let err = receiver.handle_chunk(bad).await.unwrap_err();
        assert!(
            matches!(
                err,
                MirrorError::SnapshotOffsetRegression {
                    expected: 3,
                    actual: 5,
                    ..
                }
            ),
            "unexpected error: {err:?}"
        );
    }

    /// `bytes_done` reported by successive `Pending` outcomes strictly increases.
    #[tokio::test]
    async fn bytes_done_is_monotonic() {
        let tmp = TempDir::new().unwrap();
        let cb: ProgressCallback = Arc::new(|_| {});
        let receiver = MirrorBootstrapReceiver::new(tmp.path().to_path_buf(), cb);

        // Four 16-byte chunks, each filled with its own index.
        let chunks: Vec<Vec<u8>> = (0u64..4).map(|i| vec![i as u8; 16]).collect();
        let crc = crc_of_chunks(&chunks.iter().map(|v| v.as_slice()).collect::<Vec<_>>());

        let mut prev = 0u64;
        for (i, c) in chunks.into_iter().enumerate() {
            let offset = (i as u64) * 16;
            let done = i == 3;
            let chunk = make_envelope(offset, c, done, 64, crc);
            match receiver.handle_chunk(chunk).await.unwrap() {
                BootstrapChunkOutcome::Pending { bytes_done } => {
                    assert!(bytes_done > prev, "not monotonic at step {i}");
                    prev = bytes_done;
                }
                BootstrapChunkOutcome::Committed { .. } => {}
            }
        }
    }

    /// Poisons the state mutex by panicking on another thread while holding it.
    /// `handle_chunk` must then return a `Transport` error mentioning "poisoned"
    /// rather than panicking.
    #[tokio::test]
    async fn poisoned_state_returns_transport_error() {
        let tmp = TempDir::new().unwrap();
        let cb: ProgressCallback = Arc::new(|_| {});
        let receiver = Arc::new(MirrorBootstrapReceiver::new(tmp.path().to_path_buf(), cb));

        let r2 = Arc::clone(&receiver);
        // The join result is the expected panic; ignore it.
        let _ = std::thread::spawn(move || {
            let _g = r2.state.lock().unwrap();
            panic!("intentional panic to poison the mutex");
        })
        .join();

        let crc = crc_of_chunks(&[b"x"]);
        let c1 = make_envelope(0, b"x".to_vec(), true, 1, crc);
        let err = receiver.handle_chunk(c1).await.unwrap_err();
        assert!(
            matches!(&err, MirrorError::Transport { detail } if detail.contains("poisoned")),
            "expected Transport(poisoned), got: {err:?}"
        );
    }
}