1use std::collections::HashMap;
30
31use fsqlite_types::CommitSeq;
32
/// Maximum number of databases that may be attached in addition to `main` and `temp`.
pub const SQLITE_MAX_ATTACHED: usize = 10;

/// Upper bound on 2PC participants: every attachable database plus `main` and `temp`.
pub const MAX_TOTAL_DATABASES: usize = 2 + SQLITE_MAX_ATTACHED;

/// Numeric identifier of a database taking part in a cross-database transaction.
pub type DatabaseId = u32;

/// Database id reserved for the `main` database.
pub const MAIN_DB_ID: DatabaseId = 0;
/// Database id reserved for the `temp` database.
pub const TEMP_DB_ID: DatabaseId = 1;
45
/// Position of a `TwoPhaseCoordinator` in the two-phase commit protocol.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum TwoPhaseState {
    /// No prepare result recorded yet; participants may still be registered.
    Idle,
    /// At least one participant's prepare result has been recorded.
    Preparing,
    /// Every participant reported a successful prepare.
    AllPrepared,
    /// The global commit marker has been built; the transaction can no longer abort.
    MarkerWritten,
    /// At least one participant's WAL index has been marked as updated.
    Committing,
    /// Every participant's WAL index has been marked as updated.
    Committed,
    /// The transaction was aborted before the commit marker was written.
    Aborted,
}
64
65#[derive(Debug, Clone, PartialEq, Eq)]
67pub enum TwoPhaseError {
68 InvalidState(TwoPhaseState),
70 PrepareFailed { db_id: DatabaseId, reason: String },
72 TooManyDatabases { count: usize, max: usize },
74 DetachWithActiveTransaction { db_id: DatabaseId },
76 UnknownDatabase(DatabaseId),
78 DuplicateDatabase(DatabaseId),
80 MarkerWriteError(String),
82 WalIndexUpdateError { db_id: DatabaseId, reason: String },
84 NotWalMode(DatabaseId),
86}
87
88impl std::fmt::Display for TwoPhaseError {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 match self {
91 Self::InvalidState(state) => write!(f, "2PC invalid state: {state:?}"),
92 Self::PrepareFailed { db_id, reason } => {
93 write!(f, "2PC prepare failed for db {db_id}: {reason}")
94 }
95 Self::TooManyDatabases { count, max } => {
96 write!(f, "too many databases: {count} exceeds max {max}")
97 }
98 Self::DetachWithActiveTransaction { db_id } => {
99 write!(f, "cannot detach db {db_id}: active transaction")
100 }
101 Self::UnknownDatabase(db_id) => write!(f, "unknown database: {db_id}"),
102 Self::DuplicateDatabase(db_id) => {
103 write!(f, "database {db_id} already registered for 2PC")
104 }
105 Self::MarkerWriteError(reason) => write!(f, "commit marker write error: {reason}"),
106 Self::WalIndexUpdateError { db_id, reason } => {
107 write!(f, "WAL-index update error for db {db_id}: {reason}")
108 }
109 Self::NotWalMode(db_id) => write!(f, "database {db_id} not in WAL mode"),
110 }
111 }
112}
113
114impl std::error::Error for TwoPhaseError {}
115
/// Outcome of phase one (prepare) for a single participant.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PrepareResult {
    /// The participant prepared successfully.
    Ok {
        /// WAL offset recorded for this participant in the commit marker.
        wal_offset: u64,
        /// Number of frames reported by the participant.
        frame_count: u32,
    },
    /// The participant could not prepare; the payload explains why.
    Failed(String),
}
129
130#[derive(Debug, Clone, PartialEq, Eq)]
132pub struct ParticipantState {
133 pub db_id: DatabaseId,
135 pub schema_name: String,
137 pub wal_mode: bool,
139 pub prepare_result: Option<PrepareResult>,
141 pub wal_index_updated: bool,
143}
144
145impl ParticipantState {
146 #[must_use]
148 pub fn new(db_id: DatabaseId, schema_name: String, wal_mode: bool) -> Self {
149 Self {
150 db_id,
151 schema_name,
152 wal_mode,
153 prepare_result: None,
154 wal_index_updated: false,
155 }
156 }
157
158 #[must_use]
160 pub fn is_prepared(&self) -> bool {
161 matches!(self.prepare_result, Some(PrepareResult::Ok { .. }))
162 }
163
164 #[must_use]
166 pub const fn is_committed(&self) -> bool {
167 self.wal_index_updated
168 }
169}
170
171#[derive(Debug, Clone, PartialEq, Eq)]
176pub struct GlobalCommitMarker {
177 pub magic: [u8; 4],
179 pub txn_id: u64,
181 pub commit_seq: CommitSeq,
183 pub participants: Vec<(DatabaseId, u64)>,
185 pub timestamp_ns: u64,
187}
188
189pub const COMMIT_MARKER_MAGIC: [u8; 4] = [b'2', b'P', b'C', b'M'];
191
192pub const COMMIT_MARKER_MIN_SIZE: usize = 32;
195
196impl GlobalCommitMarker {
197 #[must_use]
199 pub fn new(
200 txn_id: u64,
201 commit_seq: CommitSeq,
202 participants: Vec<(DatabaseId, u64)>,
203 timestamp_ns: u64,
204 ) -> Self {
205 Self {
206 magic: COMMIT_MARKER_MAGIC,
207 txn_id,
208 commit_seq,
209 participants,
210 timestamp_ns,
211 }
212 }
213
214 #[must_use]
216 pub fn to_bytes(&self) -> Vec<u8> {
217 let participant_count = u32::try_from(self.participants.len()).unwrap_or(u32::MAX);
218 let mut buf = Vec::with_capacity(COMMIT_MARKER_MIN_SIZE + self.participants.len() * 12);
219 buf.extend_from_slice(&self.magic);
220 buf.extend_from_slice(&self.txn_id.to_le_bytes());
221 buf.extend_from_slice(&self.commit_seq.get().to_le_bytes());
222 buf.extend_from_slice(&participant_count.to_le_bytes());
223 buf.extend_from_slice(&self.timestamp_ns.to_le_bytes());
224 for &(db_id, wal_offset) in &self.participants {
225 buf.extend_from_slice(&db_id.to_le_bytes());
226 buf.extend_from_slice(&wal_offset.to_le_bytes());
227 }
228 buf
229 }
230
231 pub fn from_bytes(data: &[u8]) -> Option<Self> {
235 if data.len() < COMMIT_MARKER_MIN_SIZE {
236 return None;
237 }
238 let magic: [u8; 4] = data[..4].try_into().ok()?;
239 if magic != COMMIT_MARKER_MAGIC {
240 return None;
241 }
242 let txn_id = u64::from_le_bytes(data[4..12].try_into().ok()?);
243 let commit_seq_raw = u64::from_le_bytes(data[12..20].try_into().ok()?);
244 let participant_count = u32::from_le_bytes(data[20..24].try_into().ok()?);
245 let timestamp_ns = u64::from_le_bytes(data[24..32].try_into().ok()?);
246
247 let count = usize::try_from(participant_count).ok()?;
248 let needed = count.checked_mul(12)?.checked_add(COMMIT_MARKER_MIN_SIZE)?;
249 if data.len() < needed {
250 return None;
251 }
252
253 let mut participants = Vec::with_capacity(count);
254 for i in 0..count {
255 let base = COMMIT_MARKER_MIN_SIZE + i * 12;
256 let db_id = u32::from_le_bytes(data[base..base + 4].try_into().ok()?);
257 let wal_offset = u64::from_le_bytes(data[base + 4..base + 12].try_into().ok()?);
258 participants.push((db_id, wal_offset));
259 }
260
261 Some(Self {
262 magic,
263 txn_id,
264 commit_seq: CommitSeq::new(commit_seq_raw),
265 participants,
266 timestamp_ns,
267 })
268 }
269
270 #[must_use]
272 pub const fn is_valid(&self) -> bool {
273 self.magic[0] == b'2'
274 && self.magic[1] == b'P'
275 && self.magic[2] == b'C'
276 && self.magic[3] == b'M'
277 }
278}
279
/// What crash recovery should do with an interrupted cross-database commit,
/// as chosen by `TwoPhaseCoordinator::determine_recovery`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RecoveryAction {
    /// Every participant's WAL index is already updated; nothing to do.
    NoAction,
    /// A commit marker exists but some WAL indices were not updated; finish the commit.
    RollForward,
    /// No commit marker and not all WAL indices updated; roll the transaction back.
    RollBack,
}
292
293#[derive(Debug)]
298pub struct TwoPhaseCoordinator {
299 state: TwoPhaseState,
301 participants: HashMap<DatabaseId, ParticipantState>,
303 commit_marker: Option<GlobalCommitMarker>,
305 txn_id: u64,
307}
308
309impl TwoPhaseCoordinator {
310 #[must_use]
312 pub fn new(txn_id: u64) -> Self {
313 Self {
314 state: TwoPhaseState::Idle,
315 participants: HashMap::new(),
316 commit_marker: None,
317 txn_id,
318 }
319 }
320
321 #[must_use]
323 pub const fn state(&self) -> TwoPhaseState {
324 self.state
325 }
326
327 #[must_use]
329 pub fn participant_count(&self) -> usize {
330 self.participants.len()
331 }
332
333 #[must_use]
335 pub const fn txn_id(&self) -> u64 {
336 self.txn_id
337 }
338
339 pub fn add_participant(
341 &mut self,
342 db_id: DatabaseId,
343 schema_name: String,
344 wal_mode: bool,
345 ) -> Result<(), TwoPhaseError> {
346 if self.state != TwoPhaseState::Idle {
347 return Err(TwoPhaseError::InvalidState(self.state));
348 }
349 if !wal_mode {
350 return Err(TwoPhaseError::NotWalMode(db_id));
351 }
352 if self.participants.contains_key(&db_id) {
353 return Err(TwoPhaseError::DuplicateDatabase(db_id));
354 }
355 if self.participants.len() >= MAX_TOTAL_DATABASES {
356 return Err(TwoPhaseError::TooManyDatabases {
357 count: self.participants.len() + 1,
358 max: MAX_TOTAL_DATABASES,
359 });
360 }
361 self.participants
362 .insert(db_id, ParticipantState::new(db_id, schema_name, wal_mode));
363 Ok(())
364 }
365
366 pub fn check_detach(&self, db_id: DatabaseId) -> Result<(), TwoPhaseError> {
370 if self.state != TwoPhaseState::Idle
371 && self.state != TwoPhaseState::Committed
372 && self.state != TwoPhaseState::Aborted
373 && self.participants.contains_key(&db_id)
374 {
375 return Err(TwoPhaseError::DetachWithActiveTransaction { db_id });
376 }
377 Ok(())
378 }
379
380 pub fn prepare_participant(
385 &mut self,
386 db_id: DatabaseId,
387 result: PrepareResult,
388 ) -> Result<(), TwoPhaseError> {
389 if self.state != TwoPhaseState::Idle && self.state != TwoPhaseState::Preparing {
390 return Err(TwoPhaseError::InvalidState(self.state));
391 }
392 let participant = self
393 .participants
394 .get_mut(&db_id)
395 .ok_or(TwoPhaseError::UnknownDatabase(db_id))?;
396 self.state = TwoPhaseState::Preparing;
397 participant.prepare_result = Some(result);
398 Ok(())
399 }
400
401 pub fn check_all_prepared(&mut self) -> Result<(), TwoPhaseError> {
405 if self.state != TwoPhaseState::Preparing {
406 return Err(TwoPhaseError::InvalidState(self.state));
407 }
408 for participant in self.participants.values() {
409 match &participant.prepare_result {
410 None => {
411 return Err(TwoPhaseError::PrepareFailed {
412 db_id: participant.db_id,
413 reason: "not yet prepared".to_owned(),
414 });
415 }
416 Some(PrepareResult::Failed(reason)) => {
417 return Err(TwoPhaseError::PrepareFailed {
418 db_id: participant.db_id,
419 reason: reason.clone(),
420 });
421 }
422 Some(PrepareResult::Ok { .. }) => {}
423 }
424 }
425 self.state = TwoPhaseState::AllPrepared;
426 Ok(())
427 }
428
429 pub fn write_commit_marker(
435 &mut self,
436 commit_seq: CommitSeq,
437 timestamp_ns: u64,
438 ) -> Result<GlobalCommitMarker, TwoPhaseError> {
439 if self.state != TwoPhaseState::AllPrepared {
440 return Err(TwoPhaseError::InvalidState(self.state));
441 }
442
443 let mut participants: Vec<(DatabaseId, u64)> = self
444 .participants
445 .values()
446 .filter_map(|p| {
447 if let Some(PrepareResult::Ok { wal_offset, .. }) = &p.prepare_result {
448 Some((p.db_id, *wal_offset))
449 } else {
450 None
451 }
452 })
453 .collect();
454 participants.sort_unstable_by_key(|&(db_id, _)| db_id);
455
456 let marker = GlobalCommitMarker::new(self.txn_id, commit_seq, participants, timestamp_ns);
457 self.commit_marker = Some(marker.clone());
458 self.state = TwoPhaseState::MarkerWritten;
459 Ok(marker)
460 }
461
462 pub fn commit_participant(&mut self, db_id: DatabaseId) -> Result<(), TwoPhaseError> {
466 if self.state != TwoPhaseState::MarkerWritten && self.state != TwoPhaseState::Committing {
467 return Err(TwoPhaseError::InvalidState(self.state));
468 }
469 let participant = self
470 .participants
471 .get_mut(&db_id)
472 .ok_or(TwoPhaseError::UnknownDatabase(db_id))?;
473 self.state = TwoPhaseState::Committing;
474 participant.wal_index_updated = true;
475 Ok(())
476 }
477
478 pub fn check_all_committed(&mut self) -> Result<(), TwoPhaseError> {
480 if self.state != TwoPhaseState::Committing {
481 return Err(TwoPhaseError::InvalidState(self.state));
482 }
483 for participant in self.participants.values() {
484 if !participant.wal_index_updated {
485 return Err(TwoPhaseError::WalIndexUpdateError {
486 db_id: participant.db_id,
487 reason: "WAL-index not yet updated".to_owned(),
488 });
489 }
490 }
491 self.state = TwoPhaseState::Committed;
492 Ok(())
493 }
494
495 pub fn abort(&mut self) -> Result<(), TwoPhaseError> {
501 if matches!(
502 self.state,
503 TwoPhaseState::MarkerWritten | TwoPhaseState::Committing | TwoPhaseState::Committed
504 ) {
505 return Err(TwoPhaseError::InvalidState(self.state));
506 }
507 self.state = TwoPhaseState::Aborted;
508 self.commit_marker = None;
509 Ok(())
510 }
511
512 #[must_use]
517 pub fn determine_recovery(marker_found: bool, all_wal_indices_updated: bool) -> RecoveryAction {
518 if !marker_found {
519 if all_wal_indices_updated {
520 RecoveryAction::NoAction
521 } else {
522 RecoveryAction::RollBack
523 }
524 } else if all_wal_indices_updated {
525 RecoveryAction::NoAction
526 } else {
527 RecoveryAction::RollForward
528 }
529 }
530
531 #[must_use]
533 pub fn commit_marker(&self) -> Option<&GlobalCommitMarker> {
534 self.commit_marker.as_ref()
535 }
536
537 #[must_use]
539 pub const fn is_committed(&self) -> bool {
540 matches!(self.state, TwoPhaseState::Committed)
541 }
542
543 #[must_use]
545 pub const fn is_aborted(&self) -> bool {
546 matches!(self.state, TwoPhaseState::Aborted)
547 }
548}
549
/// Unit tests for the two-phase commit coordinator and commit-marker encoding.
#[cfg(test)]
mod tests {
    use fsqlite_types::CommitSeq;

    use super::{
        COMMIT_MARKER_MAGIC, GlobalCommitMarker, MAIN_DB_ID, MAX_TOTAL_DATABASES, PrepareResult,
        RecoveryAction, SQLITE_MAX_ATTACHED, TEMP_DB_ID, TwoPhaseCoordinator, TwoPhaseError,
        TwoPhaseState,
    };

    /// Database id used for the single attached `aux` database in most tests.
    const AUX_DB_ID: u32 = 2;

    /// Shorthand for a successful prepare outcome.
    fn prepared(wal_offset: u64, frame_count: u32) -> PrepareResult {
        PrepareResult::Ok {
            wal_offset,
            frame_count,
        }
    }

    /// Coordinator with `main` and `aux` registered, nothing prepared yet.
    fn main_and_aux(txn_id: u64) -> TwoPhaseCoordinator {
        let mut coord = TwoPhaseCoordinator::new(txn_id);
        coord
            .add_participant(MAIN_DB_ID, "main".to_owned(), true)
            .expect("add main");
        coord
            .add_participant(AUX_DB_ID, "aux".to_owned(), true)
            .expect("add aux");
        coord
    }

    /// Coordinator with `main` (WAL offset 4096) and `aux` (WAL offset 8192)
    /// both prepared, and `check_all_prepared` already passed.
    fn main_and_aux_all_prepared(
        txn_id: u64,
        main_frames: u32,
        aux_frames: u32,
    ) -> TwoPhaseCoordinator {
        let mut coord = main_and_aux(txn_id);
        coord
            .prepare_participant(MAIN_DB_ID, prepared(4096, main_frames))
            .expect("prepare main");
        coord
            .prepare_participant(AUX_DB_ID, prepared(8192, aux_frames))
            .expect("prepare aux");
        coord.check_all_prepared().expect("all prepared");
        coord
    }

    /// Coordinator with only `main`, prepared and with the commit marker written.
    fn main_prepared_with_marker(txn_id: u64, wal_offset: u64, frame_count: u32) -> TwoPhaseCoordinator {
        let mut coord = TwoPhaseCoordinator::new(txn_id);
        coord
            .add_participant(MAIN_DB_ID, "main".to_owned(), true)
            .unwrap();
        coord
            .prepare_participant(MAIN_DB_ID, prepared(wal_offset, frame_count))
            .unwrap();
        coord.check_all_prepared().unwrap();
        coord.write_commit_marker(CommitSeq::new(1), 0).unwrap();
        coord
    }

    /// Registers `SQLITE_MAX_ATTACHED` aux databases with ids 2, 3, ...
    fn attach_max_aux(coord: &mut TwoPhaseCoordinator) {
        for (i, db_id) in (2u32..).take(SQLITE_MAX_ATTACHED).enumerate() {
            coord
                .add_participant(db_id, format!("aux{i}"), true)
                .unwrap();
        }
    }

    #[test]
    fn test_cross_database_two_phase_commit() {
        let mut coord = main_and_aux_all_prepared(1, 2, 3);

        let marker = coord
            .write_commit_marker(CommitSeq::new(100), 1_000_000)
            .expect("marker");
        assert!(marker.is_valid());
        assert_eq!(marker.participants.len(), 2);

        for db_id in [MAIN_DB_ID, AUX_DB_ID] {
            coord.commit_participant(db_id).expect("commit participant");
        }
        coord.check_all_committed().expect("all committed");

        assert!(coord.is_committed());
    }

    #[test]
    fn test_cross_db_2pc_one_db_fails_prepare() {
        let mut coord = main_and_aux(2);
        coord
            .prepare_participant(MAIN_DB_ID, prepared(4096, 1))
            .unwrap();
        coord
            .prepare_participant(AUX_DB_ID, PrepareResult::Failed("disk full".to_owned()))
            .unwrap();

        let err = coord
            .check_all_prepared()
            .expect_err("aux failed to prepare");
        assert!(
            matches!(err, TwoPhaseError::PrepareFailed { db_id: 2, .. }),
            "expected prepare failure for db 2: {err:?}"
        );

        coord.abort().expect("abort");
        assert!(coord.is_aborted());
    }

    #[test]
    fn test_attach_detach_limit() {
        let mut coord = TwoPhaseCoordinator::new(3);
        coord
            .add_participant(MAIN_DB_ID, "main".to_owned(), true)
            .unwrap();
        coord
            .add_participant(TEMP_DB_ID, "temp".to_owned(), true)
            .unwrap();
        attach_max_aux(&mut coord);
        assert_eq!(coord.participant_count(), MAX_TOTAL_DATABASES);

        let result = coord.add_participant(99, "overflow".to_owned(), true);
        assert!(
            matches!(result, Err(TwoPhaseError::TooManyDatabases { .. })),
            "expected too many databases: {result:?}"
        );
    }

    #[test]
    fn test_cross_db_2pc_max_attached() {
        let mut coord = TwoPhaseCoordinator::new(4);
        coord
            .add_participant(MAIN_DB_ID, "main".to_owned(), true)
            .unwrap();
        attach_max_aux(&mut coord);

        let db_ids: Vec<u32> = coord.participants.keys().copied().collect();
        for &db_id in &db_ids {
            coord
                .prepare_participant(db_id, prepared(u64::from(db_id) * 4096, 1))
                .unwrap();
        }
        coord.check_all_prepared().unwrap();

        coord
            .write_commit_marker(CommitSeq::new(200), 2_000_000)
            .unwrap();
        for &db_id in &db_ids {
            coord.commit_participant(db_id).unwrap();
        }
        coord.check_all_committed().unwrap();
        assert!(coord.is_committed());
    }

    #[test]
    fn test_cross_db_2pc_wal_mode_required() {
        let mut coord = TwoPhaseCoordinator::new(5);
        let result = coord.add_participant(MAIN_DB_ID, "main".to_owned(), false);
        assert!(
            matches!(result, Err(TwoPhaseError::NotWalMode(MAIN_DB_ID))),
            "expected NotWalMode error: {result:?}"
        );
    }

    #[test]
    fn test_add_participant_rejects_duplicate_database() {
        let mut coord = TwoPhaseCoordinator::new(55);
        coord
            .add_participant(MAIN_DB_ID, "main".to_owned(), true)
            .unwrap();

        let result = coord.add_participant(MAIN_DB_ID, "main_shadow".to_owned(), true);
        assert!(
            matches!(result, Err(TwoPhaseError::DuplicateDatabase(MAIN_DB_ID))),
            "expected duplicate database error: {result:?}"
        );
        // The original registration must be untouched.
        assert_eq!(coord.participant_count(), 1);
        assert_eq!(coord.participants[&MAIN_DB_ID].schema_name, "main");
    }

    #[test]
    fn test_detach_with_active_transaction() {
        let mut coord = main_and_aux(6);
        coord
            .prepare_participant(MAIN_DB_ID, prepared(4096, 1))
            .unwrap();

        // aux is a participant of an in-flight transaction.
        assert!(matches!(
            coord.check_detach(AUX_DB_ID),
            Err(TwoPhaseError::DetachWithActiveTransaction { db_id: 2 })
        ));
        // A database that is not a participant may always be detached.
        assert!(coord.check_detach(99).is_ok());
    }

    #[test]
    fn test_commit_marker_roundtrip() {
        let marker = GlobalCommitMarker::new(
            42,
            CommitSeq::new(100),
            vec![(MAIN_DB_ID, 4096), (2, 8192), (3, 12288)],
            1_000_000_000,
        );
        assert!(marker.is_valid());

        let decoded =
            GlobalCommitMarker::from_bytes(&marker.to_bytes()).expect("decode should succeed");
        assert_eq!(decoded.magic, COMMIT_MARKER_MAGIC);
        assert_eq!(decoded.txn_id, 42);
        assert_eq!(decoded.commit_seq, CommitSeq::new(100));
        assert_eq!(decoded.timestamp_ns, 1_000_000_000);
        assert_eq!(
            decoded.participants,
            vec![(MAIN_DB_ID, 4096), (2, 8192), (3, 12288)]
        );
    }

    #[test]
    fn test_commit_marker_invalid() {
        // Shorter than the fixed header.
        assert!(GlobalCommitMarker::from_bytes(&[0; 10]).is_none());

        // Full header but wrong magic.
        let mut wrong_magic = [0u8; 32];
        wrong_magic[0] = b'X';
        assert!(GlobalCommitMarker::from_bytes(&wrong_magic).is_none());

        // Participant entry cut short.
        let bytes = GlobalCommitMarker::new(1, CommitSeq::new(1), vec![(0, 100)], 0).to_bytes();
        assert!(GlobalCommitMarker::from_bytes(&bytes[..bytes.len() - 4]).is_none());
    }

    #[test]
    fn test_recovery_actions() {
        let cases = [
            (false, true, RecoveryAction::NoAction),
            (false, false, RecoveryAction::RollBack),
            (true, false, RecoveryAction::RollForward),
            (true, true, RecoveryAction::NoAction),
        ];
        for (marker_found, all_updated, expected) in cases {
            assert_eq!(
                TwoPhaseCoordinator::determine_recovery(marker_found, all_updated),
                expected,
                "marker_found={marker_found}, all_updated={all_updated}"
            );
        }
    }

    #[test]
    fn test_state_machine_invalid_transitions() {
        let mut idle = TwoPhaseCoordinator::new(10);
        assert!(matches!(
            idle.check_all_prepared(),
            Err(TwoPhaseError::InvalidState(TwoPhaseState::Idle))
        ));
        assert!(matches!(
            idle.write_commit_marker(CommitSeq::new(1), 0),
            Err(TwoPhaseError::InvalidState(TwoPhaseState::Idle))
        ));
        assert!(matches!(
            idle.commit_participant(MAIN_DB_ID),
            Err(TwoPhaseError::InvalidState(TwoPhaseState::Idle))
        ));

        let mut done = main_prepared_with_marker(11, 0, 0);
        done.commit_participant(MAIN_DB_ID).unwrap();
        done.check_all_committed().unwrap();
        assert!(matches!(
            done.abort(),
            Err(TwoPhaseError::InvalidState(TwoPhaseState::Committed))
        ));
    }

    #[test]
    fn test_cross_db_2pc_both_committed() {
        let mut coord = main_and_aux_all_prepared(11, 5, 3);
        assert!(coord.participants[&MAIN_DB_ID].is_prepared());
        assert!(coord.participants[&AUX_DB_ID].is_prepared());

        let marker = coord
            .write_commit_marker(CommitSeq::new(50), 500_000)
            .unwrap();
        assert_eq!(marker.participants.len(), 2);

        coord.commit_participant(MAIN_DB_ID).unwrap();
        coord.commit_participant(AUX_DB_ID).unwrap();
        coord.check_all_committed().unwrap();

        assert!(coord.participants.values().all(|p| p.is_committed()));
        assert!(coord.is_committed());
    }

    #[test]
    fn test_cross_db_2pc_crash_after_prepare() {
        let mut coord = main_and_aux_all_prepared(12, 2, 2);

        // Crash before the marker: no marker, no WAL-index updates.
        let recovery = TwoPhaseCoordinator::determine_recovery(false, false);
        assert!(matches!(
            recovery,
            RecoveryAction::RollBack | RecoveryAction::RollForward
        ));

        match recovery {
            RecoveryAction::RollBack => {
                coord.abort().unwrap();
                assert!(coord.is_aborted());
                assert!(!coord.is_committed());
            }
            RecoveryAction::RollForward => {
                coord
                    .write_commit_marker(CommitSeq::new(320), 3_200_000)
                    .unwrap();
                for db_id in [MAIN_DB_ID, AUX_DB_ID] {
                    coord.commit_participant(db_id).unwrap();
                }
                coord.check_all_committed().unwrap();
                assert!(coord.is_committed());
            }
            RecoveryAction::NoAction => panic!("recovery cannot be NoAction after crash"),
        }
    }

    #[test]
    fn test_cross_db_2pc_crash_during_phase2() {
        let mut coord = main_and_aux_all_prepared(13, 1, 1);
        coord
            .write_commit_marker(CommitSeq::new(330), 3_300_000)
            .unwrap();

        // Only main is committed when the simulated crash happens.
        coord.commit_participant(MAIN_DB_ID).unwrap();
        assert_eq!(
            TwoPhaseCoordinator::determine_recovery(true, false),
            RecoveryAction::RollForward
        );

        // Recovery finishes the remaining participant.
        coord.commit_participant(AUX_DB_ID).unwrap();
        coord.check_all_committed().unwrap();
        assert!(coord.is_committed());
    }

    #[test]
    fn test_2pc_abort_before_marker() {
        let mut coord = main_and_aux(14);
        coord
            .prepare_participant(MAIN_DB_ID, prepared(4096, 1))
            .unwrap();

        coord.abort().expect("abort should succeed");
        assert!(coord.is_aborted());
        assert!(coord.commit_marker().is_none());
    }

    #[test]
    fn test_2pc_abort_after_marker_rejected() {
        let mut coord = main_prepared_with_marker(15, 4096, 1);
        assert!(matches!(
            coord.abort(),
            Err(TwoPhaseError::InvalidState(TwoPhaseState::MarkerWritten))
        ));
    }

    #[test]
    fn test_prepare_unknown_database_does_not_advance_state() {
        let mut coord = TwoPhaseCoordinator::new(16);
        coord
            .add_participant(MAIN_DB_ID, "main".to_owned(), true)
            .unwrap();

        let err = coord.prepare_participant(99, prepared(4096, 1));
        assert!(matches!(err, Err(TwoPhaseError::UnknownDatabase(99))));
        assert_eq!(coord.state(), TwoPhaseState::Idle);
        assert_eq!(coord.participants[&MAIN_DB_ID].prepare_result, None);
    }

    #[test]
    fn test_commit_unknown_database_does_not_advance_state() {
        let mut coord = main_prepared_with_marker(17, 4096, 1);

        let err = coord.commit_participant(99);
        assert!(matches!(err, Err(TwoPhaseError::UnknownDatabase(99))));
        assert_eq!(coord.state(), TwoPhaseState::MarkerWritten);
        assert!(!coord.participants[&MAIN_DB_ID].wal_index_updated);
    }
}