1#[allow(clippy::disallowed_types)] use std::collections::HashMap;
20use std::sync::Arc;
21use std::time::Duration;
22
23use tracing::{info, warn};
24
25use super::checkpointer::{verify_integrity, Checkpointer, CheckpointerError};
26use super::layout::CheckpointId;
27
/// An operator whose state can be rebuilt from checkpoint artifacts.
///
/// [`RecoveryManager`] pairs each implementor with a manifest entry via
/// [`operator_id`](Restorable::operator_id) and then hands it the bytes of
/// every artifact recorded for that operator, non-delta artifacts first.
pub trait Restorable: Send + Sync {
    /// Loads operator state from the bytes of a non-delta (snapshot) artifact.
    ///
    /// # Errors
    ///
    /// Returns a [`RecoveryError`] when `data` cannot be applied; the manager
    /// reports it as [`RecoveryError::RestoreFailed`].
    fn restore(&mut self, data: &[u8]) -> Result<(), RecoveryError>;

    /// Applies the bytes of a delta artifact to the operator's current state.
    ///
    /// # Errors
    ///
    /// Returns a [`RecoveryError`] when `delta` cannot be applied; the manager
    /// reports it as [`RecoveryError::RestoreFailed`].
    fn apply_delta(&mut self, delta: &[u8]) -> Result<(), RecoveryError>;

    /// Key under which this operator is looked up in the checkpoint manifest.
    fn operator_id(&self) -> &str;
}
49
/// A source that can be repositioned from string-keyed offsets stored in a
/// checkpoint manifest.
pub trait Seekable: Send + Sync {
    /// Repositions the source to `offsets`, the offset map recorded in the
    /// manifest under this source's [`source_id`](Seekable::source_id).
    ///
    /// # Errors
    ///
    /// Returns a [`RecoveryError`] when the source cannot be repositioned; the
    /// manager reports it as [`RecoveryError::SeekFailed`].
    fn seek(&mut self, offsets: &HashMap<String, String>) -> Result<(), RecoveryError>;

    /// Key under which this source's offsets are looked up in the manifest.
    fn source_id(&self) -> &str;
}
63
/// A source that is repositioned from a typed `SourcePosition` rather than a
/// raw string map.
///
/// Used by [`RecoveryManager::recover_typed`], which parses the manifest's
/// offset entry into a `SourcePosition` before calling into this trait.
pub trait TypedSeekable: Send + Sync {
    /// Repositions the source to `position`.
    ///
    /// # Errors
    ///
    /// Returns a [`RecoveryError`] when the source cannot be repositioned; the
    /// manager reports it as [`RecoveryError::SeekFailed`].
    fn seek_typed(
        &mut self,
        position: &crate::checkpoint::source_offsets::SourcePosition,
    ) -> Result<(), RecoveryError>;

    /// Reports whether `position` can be reached by this source.
    ///
    /// The default implementation ignores `position` and always returns
    /// `true`. Typed recovery calls this right before
    /// [`seek_typed`](TypedSeekable::seek_typed) and fails the attempt with
    /// [`RecoveryError::SeekFailed`] when it returns `false`.
    fn can_seek_to(&self, position: &crate::checkpoint::source_offsets::SourcePosition) -> bool {
        // Explicitly discard the argument; the default has nothing to check.
        let _ = position;
        true
    }

    /// Key under which this source's offsets are looked up in the manifest.
    fn source_id(&self) -> &str;
}
92
/// Tunables for [`RecoveryManager`].
#[derive(Debug, Clone)]
pub struct RecoveryConfig {
    /// Upper bound on how many candidate checkpoints are tried in total
    /// (the first attempt counts). Capped by the number of candidates found.
    pub max_fallback_attempts: usize,
    /// When `true`, artifact bytes are checked against the SHA-256 recorded in
    /// the manifest; partitions without a recorded digest are not checked.
    pub verify_integrity: bool,
    /// Intended overall time budget for recovery.
    ///
    /// NOTE(review): nothing visible in this file reads this field — confirm
    /// whether the timeout is enforced by a caller.
    pub recovery_timeout: Duration,
}
103
104impl Default for RecoveryConfig {
105 fn default() -> Self {
106 Self {
107 max_fallback_attempts: 3,
108 verify_integrity: true,
109 recovery_timeout: Duration::from_secs(300),
110 }
111 }
112}
113
/// Summary of a successful recovery.
#[derive(Debug)]
pub struct RecoveryResult {
    /// Identifier of the checkpoint that was restored, as stored in its manifest.
    pub checkpoint_id: CheckpointId,
    /// Epoch recorded in the restored manifest.
    pub epoch: u64,
    /// Watermark recorded in the restored manifest, if any.
    pub watermark: Option<i64>,
    /// Number of operators that had a manifest entry and were restored.
    pub operators_restored: usize,
    /// Number of sources that had a manifest offset entry and were seeked.
    pub sources_seeked: usize,
}
128
/// Errors produced while recovering from a checkpoint.
#[derive(Debug, thiserror::Error)]
pub enum RecoveryError {
    /// No candidate checkpoint could be discovered (or none was attempted).
    #[error("no checkpoint available")]
    NoCheckpointAvailable,

    /// Every attempted candidate failed; the payload is the number of
    /// attempts. Returned for any per-checkpoint failure, not only corruption.
    #[error("all checkpoints corrupt after {0} attempts")]
    AllCheckpointsCorrupt(usize),

    /// At least one operator being restored has no entry in the manifest.
    #[error("operator mismatch: expected {expected:?}, found {found:?}")]
    OperatorMismatch {
        /// Operator ids of the restorables passed to recovery.
        expected: Vec<String>,
        /// Operator ids present in the manifest.
        found: Vec<String>,
    },

    /// An operator rejected a snapshot or delta artifact.
    #[error("restore failed for operator `{operator}`: {reason}")]
    RestoreFailed {
        /// Id of the operator that failed.
        operator: String,
        /// Display text of the underlying error.
        reason: String,
    },

    /// A source could not be repositioned to the checkpointed offsets.
    #[error("seek failed for source `{source_id}`: {reason}")]
    SeekFailed {
        /// Id of the source that failed.
        source_id: String,
        /// Display text of the underlying error or the reason for rejection.
        reason: String,
    },

    /// A checkpointer error surfaced during recovery. Because of `#[from]`,
    /// any `CheckpointerError` propagated with `?` lands here, not only
    /// digest mismatches.
    #[error("integrity check failed: {0}")]
    IntegrityFailed(#[from] CheckpointerError),

    /// Recovery exceeded its time budget.
    ///
    /// NOTE(review): not constructed anywhere in the code visible in this file.
    #[error("recovery timed out after {0:?}")]
    Timeout(Duration),
}
175
/// Drives recovery: discovers candidate checkpoints, then restores operators
/// and seeks sources from the first candidate that succeeds.
pub struct RecoveryManager {
    /// Backend used to read the latest pointer, list checkpoints, and load
    /// manifests and artifacts.
    checkpointer: Arc<dyn Checkpointer>,
    /// Attempt limit and integrity-verification settings.
    config: RecoveryConfig,
}
184
185impl RecoveryManager {
186 pub fn new(checkpointer: Arc<dyn Checkpointer>, config: RecoveryConfig) -> Self {
188 Self {
189 checkpointer,
190 config,
191 }
192 }
193
194 pub async fn recover(
206 &self,
207 restorables: &mut [&mut dyn Restorable],
208 seekables: &mut [&mut dyn Seekable],
209 ) -> Result<RecoveryResult, RecoveryError> {
210 let candidates = self.discover_candidates().await?;
212
213 let max_attempts = self.config.max_fallback_attempts.min(candidates.len());
215 let mut last_error = None;
216
217 for (attempt, id) in candidates.iter().take(max_attempts).enumerate() {
218 info!(
219 checkpoint_id = %id,
220 attempt = attempt + 1,
221 "attempting recovery from checkpoint"
222 );
223
224 match self.try_recover_from(id, restorables, seekables).await {
225 Ok(result) => return Ok(result),
226 Err(e) => {
227 warn!(
228 checkpoint_id = %id,
229 error = %e,
230 "checkpoint recovery failed, trying fallback"
231 );
232 last_error = Some(e);
233 }
234 }
235 }
236
237 match last_error {
239 Some(e) => {
240 warn!(
241 attempts = max_attempts,
242 last_error = %e,
243 "all checkpoint recovery attempts failed"
244 );
245 Err(RecoveryError::AllCheckpointsCorrupt(max_attempts))
246 }
247 None => Err(RecoveryError::NoCheckpointAvailable),
248 }
249 }
250
251 pub async fn recover_typed(
260 &self,
261 restorables: &mut [&mut dyn Restorable],
262 typed_seekables: &mut [&mut dyn TypedSeekable],
263 ) -> Result<RecoveryResult, RecoveryError> {
264 let candidates = self.discover_candidates().await?;
265
266 let max_attempts = self.config.max_fallback_attempts.min(candidates.len());
267 let mut last_error = None;
268
269 for (attempt, id) in candidates.iter().take(max_attempts).enumerate() {
270 info!(
271 checkpoint_id = %id,
272 attempt = attempt + 1,
273 "attempting typed recovery from checkpoint"
274 );
275
276 match self
277 .try_recover_typed_from(id, restorables, typed_seekables)
278 .await
279 {
280 Ok(result) => return Ok(result),
281 Err(e) => {
282 warn!(
283 checkpoint_id = %id,
284 error = %e,
285 "typed checkpoint recovery failed, trying fallback"
286 );
287 last_error = Some(e);
288 }
289 }
290 }
291
292 match last_error {
293 Some(e) => {
294 warn!(
295 attempts = max_attempts,
296 last_error = %e,
297 "all typed checkpoint recovery attempts failed"
298 );
299 Err(RecoveryError::AllCheckpointsCorrupt(max_attempts))
300 }
301 None => Err(RecoveryError::NoCheckpointAvailable),
302 }
303 }
304
305 async fn discover_candidates(&self) -> Result<Vec<CheckpointId>, RecoveryError> {
311 let mut candidates = Vec::new();
312
313 match self.checkpointer.read_latest().await {
315 Ok(Some(id)) => {
316 info!(checkpoint_id = %id, "found latest checkpoint pointer");
317 candidates.push(id);
318 }
319 Ok(None) => {
320 info!("no latest checkpoint pointer found");
321 }
322 Err(e) => {
323 warn!(error = %e, "failed to read latest pointer, trying list");
324 }
325 }
326
327 match self.checkpointer.list_checkpoints().await {
329 Ok(mut ids) => {
330 ids.sort();
331 ids.reverse(); #[allow(clippy::disallowed_types)] let mut seen: std::collections::HashSet<CheckpointId> =
334 candidates.iter().copied().collect();
335 for id in ids {
336 if seen.insert(id) {
337 candidates.push(id);
338 }
339 }
340 }
341 Err(e) => {
342 warn!(error = %e, "failed to list checkpoints");
343 }
344 }
345
346 if candidates.is_empty() {
347 return Err(RecoveryError::NoCheckpointAvailable);
348 }
349
350 Ok(candidates)
351 }
352
353 async fn try_recover_from(
355 &self,
356 id: &CheckpointId,
357 restorables: &mut [&mut dyn Restorable],
358 seekables: &mut [&mut dyn Seekable],
359 ) -> Result<RecoveryResult, RecoveryError> {
360 let manifest = self.checkpointer.load_manifest(id).await?;
362
363 let expected: Vec<String> = restorables
365 .iter()
366 .map(|r| r.operator_id().to_string())
367 .collect();
368 let found: Vec<String> = manifest.operators.keys().cloned().collect();
369
370 for op_id in &expected {
374 if !manifest.operators.contains_key(op_id) {
375 return Err(RecoveryError::OperatorMismatch {
376 expected: expected.clone(),
377 found,
378 });
379 }
380 }
381
382 let mut operators_restored = 0;
384 for restorable in restorables.iter_mut() {
385 let op_id = restorable.operator_id().to_string();
386 if let Some(op_entry) = manifest.operators.get(&op_id) {
387 let mut partitions = op_entry.partitions.clone();
390 partitions.sort_by_key(|p| p.is_delta);
391 for partition in &partitions {
392 let data = self.checkpointer.load_artifact(&partition.path).await?;
394
395 if self.config.verify_integrity {
397 if let Some(expected_sha) = &partition.sha256 {
398 verify_integrity(&partition.path, &data, expected_sha)?;
399 }
400 }
401
402 if partition.is_delta {
404 restorable.apply_delta(&data).map_err(|e| {
405 RecoveryError::RestoreFailed {
406 operator: op_id.clone(),
407 reason: e.to_string(),
408 }
409 })?;
410 } else {
411 restorable
412 .restore(&data)
413 .map_err(|e| RecoveryError::RestoreFailed {
414 operator: op_id.clone(),
415 reason: e.to_string(),
416 })?;
417 }
418 }
419 operators_restored += 1;
420 }
421 }
422
423 let mut sources_seeked = 0;
425 for seekable in seekables.iter_mut() {
426 let src_id = seekable.source_id().to_string();
427 if let Some(offset_entry) = manifest.source_offsets.get(&src_id) {
428 seekable
429 .seek(&offset_entry.offsets)
430 .map_err(|e| RecoveryError::SeekFailed {
431 source_id: src_id,
432 reason: e.to_string(),
433 })?;
434 sources_seeked += 1;
435 }
436 }
437
438 Ok(RecoveryResult {
439 checkpoint_id: manifest.checkpoint_id,
440 epoch: manifest.epoch,
441 watermark: manifest.watermark,
442 operators_restored,
443 sources_seeked,
444 })
445 }
446
447 async fn try_recover_typed_from(
452 &self,
453 id: &CheckpointId,
454 restorables: &mut [&mut dyn Restorable],
455 typed_seekables: &mut [&mut dyn TypedSeekable],
456 ) -> Result<RecoveryResult, RecoveryError> {
457 use crate::checkpoint::source_offsets::SourcePosition;
458
459 let manifest = self.checkpointer.load_manifest(id).await?;
460
461 let expected: Vec<String> = restorables
463 .iter()
464 .map(|r| r.operator_id().to_string())
465 .collect();
466 let found: Vec<String> = manifest.operators.keys().cloned().collect();
467 for op_id in &expected {
468 if !manifest.operators.contains_key(op_id) {
469 return Err(RecoveryError::OperatorMismatch {
470 expected: expected.clone(),
471 found,
472 });
473 }
474 }
475
476 let mut operators_restored = 0;
478 for restorable in restorables.iter_mut() {
479 let op_id = restorable.operator_id().to_string();
480 if let Some(op_entry) = manifest.operators.get(&op_id) {
481 let mut partitions = op_entry.partitions.clone();
483 partitions.sort_by_key(|p| p.is_delta);
484 for partition in &partitions {
485 let data = self.checkpointer.load_artifact(&partition.path).await?;
486
487 if self.config.verify_integrity {
488 if let Some(expected_sha) = &partition.sha256 {
489 verify_integrity(&partition.path, &data, expected_sha)?;
490 }
491 }
492
493 if partition.is_delta {
494 restorable.apply_delta(&data).map_err(|e| {
495 RecoveryError::RestoreFailed {
496 operator: op_id.clone(),
497 reason: e.to_string(),
498 }
499 })?;
500 } else {
501 restorable
502 .restore(&data)
503 .map_err(|e| RecoveryError::RestoreFailed {
504 operator: op_id.clone(),
505 reason: e.to_string(),
506 })?;
507 }
508 }
509 operators_restored += 1;
510 }
511 }
512
513 let mut sources_seeked = 0;
515 for seekable in typed_seekables.iter_mut() {
516 let src_id = seekable.source_id().to_string();
517 if let Some(offset_entry) = manifest.source_offsets.get(&src_id) {
518 if let Some(position) = SourcePosition::from_offset_entry(offset_entry) {
519 if !seekable.can_seek_to(&position) {
521 return Err(RecoveryError::SeekFailed {
522 source_id: src_id,
523 reason: "source reports position is unreachable".into(),
524 });
525 }
526 seekable
527 .seek_typed(&position)
528 .map_err(|e| RecoveryError::SeekFailed {
529 source_id: src_id,
530 reason: e.to_string(),
531 })?;
532 sources_seeked += 1;
533 } else {
534 warn!(
538 source_id = %src_id,
539 source_type = %offset_entry.source_type,
540 "could not parse typed position from offset entry"
541 );
542 return Err(RecoveryError::SeekFailed {
543 source_id: src_id,
544 reason: format!(
545 "could not parse typed position for source_type '{}'",
546 offset_entry.source_type
547 ),
548 });
549 }
550 }
551 }
552
553 Ok(RecoveryResult {
554 checkpoint_id: manifest.checkpoint_id,
555 epoch: manifest.epoch,
556 watermark: manifest.watermark,
557 operators_restored,
558 sources_seeked,
559 })
560 }
561}
562
#[cfg(test)]
mod tests {
    use super::*;
    use crate::checkpoint::checkpointer::ObjectStoreCheckpointer;
    use crate::checkpoint::layout::{
        CheckpointManifestV2, CheckpointPaths, OperatorSnapshotEntry, PartitionSnapshotEntry,
        SourceOffsetEntry,
    };
    use bytes::Bytes;
    use object_store::memory::InMemory;
    use sha2::{Digest, Sha256};

    /// Restorable test double that records what it was given.
    struct TestRestorable {
        id: String,
        /// Snapshot bytes, with every applied delta appended.
        state: Vec<u8>,
        /// Each delta received, in application order.
        deltas: Vec<Vec<u8>>,
    }

    impl TestRestorable {
        fn new(id: &str) -> Self {
            Self {
                id: id.to_string(),
                state: Vec::new(),
                deltas: Vec::new(),
            }
        }
    }

    impl Restorable for TestRestorable {
        fn restore(&mut self, data: &[u8]) -> Result<(), RecoveryError> {
            self.state = data.to_vec();
            Ok(())
        }

        fn apply_delta(&mut self, delta: &[u8]) -> Result<(), RecoveryError> {
            self.deltas.push(delta.to_vec());
            self.state.extend_from_slice(delta);
            Ok(())
        }

        fn operator_id(&self) -> &str {
            &self.id
        }
    }

    /// Seekable test double that remembers the last offsets it was seeked to.
    struct TestSeekable {
        id: String,
        offsets: Option<HashMap<String, String>>,
    }

    impl TestSeekable {
        fn new(id: &str) -> Self {
            Self {
                id: id.to_string(),
                offsets: None,
            }
        }
    }

    impl Seekable for TestSeekable {
        fn seek(&mut self, offsets: &HashMap<String, String>) -> Result<(), RecoveryError> {
            self.offsets = Some(offsets.clone());
            Ok(())
        }

        fn source_id(&self) -> &str {
            &self.id
        }
    }

    /// Lower-case hex SHA-256 of `data`.
    fn sha256_hex(data: &[u8]) -> String {
        let mut hasher = Sha256::new();
        hasher.update(data);
        format!("{:x}", hasher.finalize())
    }

    /// Checkpointer backed by an in-memory object store with default paths.
    fn setup_checkpointer() -> Arc<ObjectStoreCheckpointer> {
        let store = Arc::new(InMemory::new());
        let paths = CheckpointPaths::default();
        Arc::new(ObjectStoreCheckpointer::new(store, paths, 4))
    }

    /// Same path layout the checkpointer from `setup_checkpointer` uses.
    fn make_paths() -> CheckpointPaths {
        CheckpointPaths::default()
    }

    /// Writes a complete checkpoint: one artifact per operator (partition 0,
    /// snapshot or delta per the flag), the source offsets, the manifest, and
    /// finally the latest pointer.
    async fn save_checkpoint(
        ckpt: &ObjectStoreCheckpointer,
        id: &CheckpointId,
        epoch: u64,
        operators: Vec<(&str, &[u8], bool)>,
        sources: Vec<(&str, HashMap<String, String>)>,
        watermark: Option<i64>,
    ) {
        let paths = make_paths();
        let mut manifest = CheckpointManifestV2::new(*id, epoch);
        manifest.watermark = watermark;

        for (op_name, data, is_delta) in &operators {
            let (artifact_path, digest) = if *is_delta {
                let d = ckpt
                    .save_delta(id, op_name, 0, Bytes::from(data.to_vec()))
                    .await
                    .unwrap();
                (paths.delta(id, op_name, 0), d)
            } else {
                let d = ckpt
                    .save_snapshot(id, op_name, 0, Bytes::from(data.to_vec()))
                    .await
                    .unwrap();
                (paths.snapshot(id, op_name, 0), d)
            };

            manifest.operators.insert(
                op_name.to_string(),
                OperatorSnapshotEntry {
                    partitions: vec![PartitionSnapshotEntry {
                        partition_id: 0,
                        is_delta: *is_delta,
                        path: artifact_path,
                        size_bytes: data.len() as u64,
                        sha256: Some(digest),
                    }],
                    total_bytes: data.len() as u64,
                },
            );
        }

        for (src_name, offsets) in sources {
            manifest.source_offsets.insert(
                src_name.to_string(),
                SourceOffsetEntry {
                    source_type: "test".into(),
                    offsets,
                    epoch,
                },
            );
        }

        ckpt.save_manifest(&manifest).await.unwrap();
        ckpt.update_latest(id).await.unwrap();
    }

    /// With an empty store there is nothing to recover from.
    #[tokio::test]
    async fn test_recovery_fresh_start() {
        let ckpt = setup_checkpointer();
        let rm = RecoveryManager::new(ckpt, RecoveryConfig::default());

        let mut op = TestRestorable::new("op1");
        let result = rm.recover(&mut [&mut op], &mut []).await;

        assert!(matches!(
            result.unwrap_err(),
            RecoveryError::NoCheckpointAvailable
        ));
    }

    /// A single full snapshot restores the operator and seeks the source.
    #[tokio::test]
    async fn test_recovery_full_snapshot() {
        let ckpt = setup_checkpointer();
        let id = CheckpointId::now();

        save_checkpoint(
            &ckpt,
            &id,
            5,
            vec![("op1", b"full_state", false)],
            vec![("kafka", HashMap::from([("p0".into(), "100".into())]))],
            Some(9999),
        )
        .await;

        let rm = RecoveryManager::new(ckpt, RecoveryConfig::default());
        let mut op = TestRestorable::new("op1");
        let mut src = TestSeekable::new("kafka");

        let result = rm.recover(&mut [&mut op], &mut [&mut src]).await.unwrap();

        assert_eq!(result.checkpoint_id, id);
        assert_eq!(result.epoch, 5);
        assert_eq!(result.watermark, Some(9999));
        assert_eq!(result.operators_restored, 1);
        assert_eq!(result.sources_seeked, 1);
        assert_eq!(op.state, b"full_state");
        assert_eq!(
            src.offsets.as_ref().unwrap().get("p0"),
            Some(&"100".to_string())
        );
    }

    /// A snapshot plus a delta for the same operator are applied in that order.
    #[tokio::test]
    async fn test_recovery_incremental() {
        let ckpt = setup_checkpointer();
        let id = CheckpointId::now();
        let paths = make_paths();

        let full_data = b"base_state";
        let full_digest = ckpt
            .save_snapshot(&id, "op1", 0, Bytes::from_static(full_data))
            .await
            .unwrap();

        let delta_data = b"_delta";
        let delta_digest = ckpt
            .save_delta(&id, "op1", 1, Bytes::from_static(delta_data))
            .await
            .unwrap();

        let mut manifest = CheckpointManifestV2::new(id, 10);
        manifest.operators.insert(
            "op1".into(),
            OperatorSnapshotEntry {
                partitions: vec![
                    PartitionSnapshotEntry {
                        partition_id: 0,
                        is_delta: false,
                        path: paths.snapshot(&id, "op1", 0),
                        size_bytes: full_data.len() as u64,
                        sha256: Some(full_digest),
                    },
                    PartitionSnapshotEntry {
                        partition_id: 1,
                        is_delta: true,
                        path: paths.delta(&id, "op1", 1),
                        size_bytes: delta_data.len() as u64,
                        sha256: Some(delta_digest),
                    },
                ],
                total_bytes: (full_data.len() + delta_data.len()) as u64,
            },
        );
        ckpt.save_manifest(&manifest).await.unwrap();
        ckpt.update_latest(&id).await.unwrap();

        let rm = RecoveryManager::new(ckpt, RecoveryConfig::default());
        let mut op = TestRestorable::new("op1");

        let result = rm.recover(&mut [&mut op], &mut []).await.unwrap();
        assert_eq!(result.epoch, 10);
        // The test double appends deltas to the snapshot bytes.
        assert_eq!(op.state, b"base_state_delta");
        assert_eq!(op.deltas.len(), 1);
    }

    /// A manifest digest that does not match the artifact fails recovery.
    #[tokio::test]
    async fn test_recovery_integrity_check() {
        let ckpt = setup_checkpointer();
        let id = CheckpointId::now();

        let data = b"real_data";
        let path_str = {
            let paths = CheckpointPaths::default();
            paths.snapshot(&id, "op1", 0)
        };
        ckpt.save_snapshot(&id, "op1", 0, Bytes::from_static(data))
            .await
            .unwrap();

        let mut manifest = CheckpointManifestV2::new(id, 1);
        manifest.operators.insert(
            "op1".into(),
            OperatorSnapshotEntry {
                partitions: vec![PartitionSnapshotEntry {
                    partition_id: 0,
                    is_delta: false,
                    path: path_str,
                    size_bytes: data.len() as u64,
                    // Deliberately wrong digest.
                    sha256: Some("bad_hash_value".into()),
                }],
                total_bytes: data.len() as u64,
            },
        );
        ckpt.save_manifest(&manifest).await.unwrap();
        ckpt.update_latest(&id).await.unwrap();

        let rm = RecoveryManager::new(
            ckpt,
            RecoveryConfig {
                max_fallback_attempts: 1,
                ..Default::default()
            },
        );
        let mut op = TestRestorable::new("op1");

        let result = rm.recover(&mut [&mut op], &mut []).await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            RecoveryError::AllCheckpointsCorrupt(1)
        ));
    }

    /// When the latest pointer names a corrupt checkpoint, recovery falls back
    /// to another listed checkpoint.
    #[tokio::test]
    async fn test_recovery_fallback() {
        let ckpt = setup_checkpointer();

        // First checkpoint: artifact is fine but the manifest digest is bogus.
        let id1 = CheckpointId::now();
        {
            let path_str = {
                let paths = CheckpointPaths::default();
                paths.snapshot(&id1, "op1", 0)
            };
            ckpt.save_snapshot(&id1, "op1", 0, Bytes::from_static(b"data1"))
                .await
                .unwrap();
            let mut manifest = CheckpointManifestV2::new(id1, 1);
            manifest.operators.insert(
                "op1".into(),
                OperatorSnapshotEntry {
                    partitions: vec![PartitionSnapshotEntry {
                        partition_id: 0,
                        is_delta: false,
                        path: path_str,
                        size_bytes: 5,
                        sha256: Some("corrupted".into()),
                    }],
                    total_bytes: 5,
                },
            );
            ckpt.save_manifest(&manifest).await.unwrap();
        }

        // Short pause before generating the second id so the two ids differ.
        tokio::time::sleep(Duration::from_millis(2)).await;

        // Second checkpoint: fully valid.
        let id2 = CheckpointId::now();
        save_checkpoint(
            &ckpt,
            &id2,
            2,
            vec![("op1", b"good_state", false)],
            vec![],
            None,
        )
        .await;

        // Point "latest" back at the corrupt checkpoint to force a fallback.
        ckpt.update_latest(&id1).await.unwrap();

        let rm = RecoveryManager::new(ckpt, RecoveryConfig::default());
        let mut op = TestRestorable::new("op1");

        let result = rm.recover(&mut [&mut op], &mut []).await.unwrap();
        assert_eq!(result.checkpoint_id, id2);
        assert_eq!(op.state, b"good_state");
    }

    /// A checkpoint without a latest pointer is still found through listing.
    #[tokio::test]
    async fn test_recovery_dual_source_discovery() {
        let ckpt = setup_checkpointer();
        let id = CheckpointId::now();

        let data = b"discovered";
        let digest = sha256_hex(data);
        let path_str = {
            let paths = CheckpointPaths::default();
            paths.snapshot(&id, "op1", 0)
        };
        ckpt.save_snapshot(&id, "op1", 0, Bytes::from_static(data))
            .await
            .unwrap();

        let mut manifest = CheckpointManifestV2::new(id, 7);
        manifest.operators.insert(
            "op1".into(),
            OperatorSnapshotEntry {
                partitions: vec![PartitionSnapshotEntry {
                    partition_id: 0,
                    is_delta: false,
                    path: path_str,
                    size_bytes: data.len() as u64,
                    sha256: Some(digest),
                }],
                total_bytes: data.len() as u64,
            },
        );
        // Note: `update_latest` is intentionally not called here.
        ckpt.save_manifest(&manifest).await.unwrap();
        let rm = RecoveryManager::new(ckpt, RecoveryConfig::default());
        let mut op = TestRestorable::new("op1");

        let result = rm.recover(&mut [&mut op], &mut []).await.unwrap();
        assert_eq!(result.checkpoint_id, id);
        assert_eq!(op.state, b"discovered");
    }

    /// A restorable whose id is absent from the manifest fails the attempt.
    #[tokio::test]
    async fn test_recovery_operator_mismatch() {
        let ckpt = setup_checkpointer();
        let id = CheckpointId::now();

        save_checkpoint(&ckpt, &id, 1, vec![("op1", b"state", false)], vec![], None).await;

        let rm = RecoveryManager::new(
            ckpt,
            RecoveryConfig {
                max_fallback_attempts: 1,
                ..Default::default()
            },
        );
        let mut op = TestRestorable::new("op2");

        let result = rm.recover(&mut [&mut op], &mut []).await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            RecoveryError::AllCheckpointsCorrupt(1)
        ));
    }

    /// TypedSeekable test double with a configurable `can_seek_to` answer.
    struct TestTypedSeekable {
        id: String,
        /// Last position passed to `seek_typed`, if any.
        position: Option<crate::checkpoint::source_offsets::SourcePosition>,
        /// Value returned by `can_seek_to`.
        reachable: bool,
    }

    impl TestTypedSeekable {
        fn new(id: &str) -> Self {
            Self {
                id: id.to_string(),
                position: None,
                reachable: true,
            }
        }

        fn unreachable(id: &str) -> Self {
            Self {
                id: id.to_string(),
                position: None,
                reachable: false,
            }
        }
    }

    impl TypedSeekable for TestTypedSeekable {
        fn seek_typed(
            &mut self,
            position: &crate::checkpoint::source_offsets::SourcePosition,
        ) -> Result<(), RecoveryError> {
            self.position = Some(position.clone());
            Ok(())
        }

        fn can_seek_to(
            &self,
            _position: &crate::checkpoint::source_offsets::SourcePosition,
        ) -> bool {
            self.reachable
        }

        fn source_id(&self) -> &str {
            &self.id
        }
    }

    /// Typed recovery restores the operator and seeks the typed source.
    #[tokio::test]
    async fn test_typed_recovery() {
        let ckpt = setup_checkpointer();
        let id = CheckpointId::now();

        save_checkpoint(
            &ckpt,
            &id,
            5,
            vec![("op1", b"state", false)],
            vec![(
                "kafka-src",
                HashMap::from([
                    ("group_id".into(), "g1".into()),
                    ("events-0".into(), "100".into()),
                ]),
            )],
            Some(8000),
        )
        .await;

        {
            // `save_checkpoint` writes source_type "test"; rewrite it to
            // "kafka" before recovery.
            let mut manifest = ckpt.load_manifest(&id).await.unwrap();
            manifest
                .source_offsets
                .get_mut("kafka-src")
                .unwrap()
                .source_type = "kafka".into();
            ckpt.save_manifest(&manifest).await.unwrap();
        }

        let rm = RecoveryManager::new(ckpt, RecoveryConfig::default());
        let mut op = TestRestorable::new("op1");
        let mut src = TestTypedSeekable::new("kafka-src");

        let result = rm
            .recover_typed(&mut [&mut op], &mut [&mut src])
            .await
            .unwrap();

        assert_eq!(result.epoch, 5);
        assert_eq!(result.operators_restored, 1);
        assert_eq!(result.sources_seeked, 1);
        assert!(src.position.is_some());
    }

    /// A source that reports the position as unreachable fails the attempt and
    /// is never seeked.
    #[tokio::test]
    async fn test_typed_recovery_unreachable_position() {
        let ckpt = setup_checkpointer();
        let id = CheckpointId::now();

        save_checkpoint(
            &ckpt,
            &id,
            5,
            vec![("op1", b"state", false)],
            vec![(
                "kafka-src",
                HashMap::from([
                    ("group_id".into(), "g1".into()),
                    ("events-0".into(), "100".into()),
                ]),
            )],
            None,
        )
        .await;

        {
            // `save_checkpoint` writes source_type "test"; rewrite it to
            // "kafka" before recovery.
            let mut manifest = ckpt.load_manifest(&id).await.unwrap();
            manifest
                .source_offsets
                .get_mut("kafka-src")
                .unwrap()
                .source_type = "kafka".into();
            ckpt.save_manifest(&manifest).await.unwrap();
        }

        let rm = RecoveryManager::new(
            ckpt,
            RecoveryConfig {
                max_fallback_attempts: 1,
                ..Default::default()
            },
        );
        let mut op = TestRestorable::new("op1");
        let mut src = TestTypedSeekable::unreachable("kafka-src");

        let result = rm.recover_typed(&mut [&mut op], &mut [&mut src]).await;
        // Exactly one candidate was tried and it failed.
        assert!(matches!(
            result.unwrap_err(),
            RecoveryError::AllCheckpointsCorrupt(1)
        ));
        // The rejected source must not have been repositioned.
        assert!(src.position.is_none());
    }
}