1use gbp::{CodecError, ControlMessage, ErrorObject, GbpFrame};
23use gbp_core::{
24 ControlOpcode, ErrorClass, GbpFlags, GroupId, MemberId, NodeState, PayloadCodec, SequenceNo,
25 StreamId, StreamType, TransitionId, TransitionState, codes, errors::ErrorSpec, timeouts,
26};
27use gbp_mls::{MlsError, label_for};
28use std::collections::HashMap;
29use std::time::Duration;
30#[cfg(not(target_arch = "wasm32"))]
31use std::time::Instant;
32#[cfg(target_arch = "wasm32")]
33use web_time::Instant;
34
35#[derive(Debug, thiserror::Error)]
37pub enum NodeError {
38 #[error("codec: {0}")]
40 Codec(#[from] CodecError),
41 #[error("mls: {0}")]
43 Mls(#[from] MlsError),
44 #[error("invalid state: {0}")]
46 InvalidState(String),
47}
48
49pub struct OutboundFrame {
51 pub to: MemberId,
53 pub wire: Vec<u8>,
55}
56
57#[derive(Debug, Clone)]
59pub struct DeliveredPayload {
60 pub stream_type: StreamType,
62 pub stream_id: StreamId,
65 pub sequence_no: SequenceNo,
67 pub flags: u16,
69 pub plaintext: Vec<u8>,
71 pub codec: PayloadCodec,
73}
74
75#[derive(Debug, Clone)]
77pub enum Event {
78 StateChanged {
80 from: NodeState,
82 to: NodeState,
84 },
85 PayloadReceived(DeliveredPayload),
89 Control {
91 from: MemberId,
93 opcode: ControlOpcode,
95 transition_id: TransitionId,
97 request_id: u32,
99 args: Vec<u8>,
102 },
103 Error {
105 code: u16,
107 class: ErrorClass,
109 retryable: bool,
111 fatal: bool,
113 reason: String,
115 },
116 EpochAdvanced {
118 epoch: u64,
120 transition_id: TransitionId,
122 },
123 CoordinatorElectionNeeded,
127 BecameCoordinator,
130 CoordinatorClaim {
132 claimant: MemberId,
134 },
135}
136
137pub struct GroupNode {
144 pub member_id: MemberId,
146 pub is_coordinator: bool,
148 pub group_id: GroupId,
150 pub current_epoch: u64,
153 pub last_transition_id: TransitionId,
155 pub pending_transition_id: TransitionId,
157 pub state: NodeState,
159 pub transition_state: TransitionState,
161
162 out_seq: HashMap<(StreamType, StreamId), SequenceNo>,
163 in_hw: HashMap<(StreamType, StreamId), SequenceNo>,
164 events: Vec<Event>,
165
166 pending_commit_sender: Option<MemberId>,
170 prepare_deadline: Option<Instant>,
173 execute_deadline: Option<Instant>,
176 coordinator_last_seen: Option<Instant>,
179}
180
181impl GroupNode {
182 pub fn new(member_id: MemberId, group_id: GroupId) -> Self {
184 Self {
185 member_id,
186 group_id,
187 is_coordinator: false,
188 current_epoch: 0,
189 last_transition_id: 0,
190 pending_transition_id: 0,
191 state: NodeState::Idle,
192 transition_state: TransitionState::TIdle,
193 out_seq: HashMap::new(),
194 in_hw: HashMap::new(),
195 events: Vec::new(),
196 pending_commit_sender: None,
197 prepare_deadline: None,
198 execute_deadline: None,
199 coordinator_last_seen: None,
200 }
201 }
202
203 pub fn bootstrap_as_creator(&mut self, epoch: u64) {
205 self.transition(NodeState::Connecting);
206 self.transition(NodeState::EstablishingGroup);
207 self.current_epoch = epoch;
208 self.transition(NodeState::Active);
209 }
210
211 pub fn bootstrap_as_joiner(&mut self, epoch: u64, expected_first_tid: u32) {
220 self.transition(NodeState::Connecting);
221 self.transition(NodeState::EstablishingGroup);
222 self.current_epoch = epoch;
223 if expected_first_tid > 0 {
224 self.pending_transition_id = expected_first_tid;
225 self.transition_state = TransitionState::TPrepared;
226 }
227 self.transition(NodeState::Active);
228 }
229
230 pub fn drain_events(&mut self) -> Vec<Event> {
232 std::mem::take(&mut self.events)
233 }
234
235 pub fn member_stream_id(&self, base: u32) -> StreamId {
240 debug_assert!(
241 self.member_id < 1_000_000,
242 "member_id overflow: {0}",
243 self.member_id
244 );
245 base + self.member_id * 100
246 }
247
248 pub fn export_out_seq(&self) -> Vec<u8> {
259 let mut out = Vec::with_capacity(4 + self.out_seq.len() * 9);
260 out.extend_from_slice(&(self.out_seq.len() as u32).to_le_bytes());
261 for ((st, sid), seq) in &self.out_seq {
262 out.push(*st as u8);
263 out.extend_from_slice(&sid.to_le_bytes());
264 out.extend_from_slice(&seq.to_le_bytes());
265 }
266 out
267 }
268
269 pub fn restore_out_seq(&mut self, bytes: &[u8]) {
272 if bytes.len() < 4 {
273 return;
274 }
275 let n = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
276 let mut cur = &bytes[4..];
277 for _ in 0..n {
278 if cur.len() < 9 {
279 break;
280 }
281 let st = match StreamType::try_from(cur[0]) {
282 Ok(s) => s,
283 Err(_) => break,
284 };
285 let sid = u32::from_le_bytes([cur[1], cur[2], cur[3], cur[4]]);
286 let seq = u32::from_le_bytes([cur[5], cur[6], cur[7], cur[8]]);
287 self.out_seq.insert((st, sid), seq);
288 cur = &cur[9..];
289 }
290 }
291
292 pub fn send_payload<S: Sealer>(
299 &mut self,
300 seal: &mut S,
301 target: MemberId,
302 stream_type: StreamType,
303 stream_id: StreamId,
304 flags: u16,
305 plaintext: &[u8],
306 codec: PayloadCodec,
307 ) -> Result<OutboundFrame, NodeError> {
308 self.assert_can_send()?;
309 let seq = self.next_seq(stream_type, stream_id);
310 let ciphertext = seal.seal(stream_type, seq, plaintext)?;
311 let frame = GbpFrame::new(
312 self.group_id,
313 self.current_epoch,
314 self.last_transition_id,
315 stream_type,
316 stream_id,
317 flags,
318 seq,
319 ciphertext,
320 codec.as_u8(),
321 );
322 Ok(OutboundFrame {
323 to: target,
324 wire: frame.to_cbor(),
325 })
326 }
327
328 pub fn send_control<S: Sealer>(
338 &mut self,
339 seal: &mut S,
340 target: MemberId,
341 opcode: ControlOpcode,
342 transition_id: TransitionId,
343 request_id: u32,
344 args: Vec<u8>,
345 ) -> Result<OutboundFrame, NodeError> {
346 let ctl = ControlMessage::with_args(
347 opcode as u16,
348 request_id,
349 self.member_id,
350 transition_id,
351 args,
352 );
353 let mut flags = GbpFlags::ordered_reliable_system();
354 if matches!(
355 opcode,
356 ControlOpcode::PrepareTransition
357 | ControlOpcode::ReadyForTransition
358 | ControlOpcode::ExecuteTransition
359 ) {
360 flags |= GbpFlags::CRITICAL;
361 }
362 match opcode {
366 ControlOpcode::PrepareTransition => {
367 self.pending_transition_id = transition_id;
368 self.transition_state = TransitionState::TPrepared;
369 self.prepare_deadline =
370 Some(Instant::now() + Duration::from_millis(timeouts::T_PREPARE_MAX_MS));
371 self.execute_deadline = None;
372 }
373 ControlOpcode::ReadyForTransition => {
374 self.execute_deadline =
375 Some(Instant::now() + Duration::from_millis(timeouts::T_EXECUTE_MAX_MS));
376 }
377 ControlOpcode::ExecuteTransition | ControlOpcode::AbortTransition => {
378 self.prepare_deadline = None;
379 self.execute_deadline = None;
380 if opcode == ControlOpcode::AbortTransition {
381 self.pending_transition_id = 0;
382 self.transition_state = TransitionState::TAborted;
383 }
384 }
385 _ => {}
386 }
387 let stream_id = self.member_stream_id(0);
388 self.send_payload(
389 seal,
390 target,
391 StreamType::Control,
392 stream_id,
393 flags,
394 &ctl.to_cbor(),
395 PayloadCodec::Cbor,
396 )
397 }
398
399 pub fn on_wire<S: Sealer>(
410 &mut self,
411 seal: &mut S,
412 wire: &[u8],
413 ) -> Result<Vec<Event>, NodeError> {
414 let frame = match GbpFrame::decode(wire) {
419 Ok(f) => f,
420 Err(e) => {
421 self.emit_err_spec(codes::STREAM_POLICY_VIOLATION, format!("frame decode: {e}"));
422 return Ok(self.drain_events());
423 }
424 };
425 self.deliver_frame(seal, frame)?;
426 Ok(self.drain_events())
427 }
428
429 fn deliver_frame<S: Sealer>(&mut self, seal: &mut S, frame: GbpFrame) -> Result<(), NodeError> {
430 if frame.version != 1 {
433 self.emit_err_spec(codes::UNSUPPORTED_VERSION, "version != 1");
434 return Ok(());
435 }
436 if frame.group_id_array() != self.group_id {
437 self.emit_err_spec(codes::UNKNOWN_GROUP, "group_id");
438 return Ok(());
439 }
440 if frame.epoch != self.current_epoch {
441 self.emit_err_spec(
442 codes::EPOCH_MISMATCH,
443 format!("got {}, expected {}", frame.epoch, self.current_epoch),
444 );
445 self.trigger_resync();
446 return Ok(());
447 }
448 if let Err(e) = frame.validate_payload_size() {
449 self.emit_err_spec(codes::STREAM_POLICY_VIOLATION, format!("payload size: {e}"));
450 return Ok(());
451 }
452 let flags = GbpFlags::from_bits(frame.flags);
453 let st = match frame.stream_type_typed() {
454 Ok(st) => st,
455 Err(_) => {
456 self.emit_err_spec(codes::STREAM_POLICY_VIOLATION, "unknown stream_type");
457 return Ok(());
458 }
459 };
460
461 if st != StreamType::Control
467 && flags.has(GbpFlags::CRITICAL)
468 && frame.transition_id != self.last_transition_id
469 {
470 self.emit_err_spec(
471 codes::TRANSITION_MISMATCH,
472 format!(
473 "got tid={}, expected {}",
474 frame.transition_id, self.last_transition_id
475 ),
476 );
477 return Ok(());
478 }
479
480 let key = (st, frame.stream_id);
481 let hw = self.in_hw.get(&key).copied().unwrap_or(0);
482 if frame.sequence_no <= hw {
483 self.emit_err_spec(
484 codes::REPLAY_DETECTED,
485 format!(
486 "st={} sid={} seq={} hw={}",
487 st, frame.stream_id, frame.sequence_no, hw
488 ),
489 );
490 return Ok(());
491 }
492 self.in_hw.insert(key, frame.sequence_no);
493
494 let plain = match seal.open(st, frame.sequence_no, &frame.encrypted_payload) {
495 Ok(p) => p,
496 Err(e) => {
497 self.emit_err_named(
504 codes::DECRYPT_FAILED,
505 ErrorClass::Crypto,
506 true, false, format!("aead open: {e}"),
509 );
510 return Ok(());
511 }
512 };
513
514 match st {
515 StreamType::Control => self.handle_control(plain),
516 other => self.events.push(Event::PayloadReceived(DeliveredPayload {
517 stream_type: other,
518 stream_id: frame.stream_id,
519 sequence_no: frame.sequence_no,
520 flags: frame.flags,
521 plaintext: plain,
522 codec: frame.payload_codec(),
523 })),
524 }
525 Ok(())
526 }
527
528 fn handle_control(&mut self, plain: Vec<u8>) {
529 let c = match ControlMessage::from_cbor(&plain) {
530 Ok(c) => c,
531 Err(_) => {
532 self.emit_err_spec(codes::STREAM_POLICY_VIOLATION, "control decode");
533 return;
534 }
535 };
536 let opcode = match ControlOpcode::try_from(c.opcode) {
537 Ok(op) => op,
538 Err(_) => {
539 self.emit_err_spec(codes::STREAM_POLICY_VIOLATION, "unknown opcode");
540 return;
541 }
542 };
543 let tid_ok = match opcode {
545 ControlOpcode::PrepareTransition => {
549 c.transition_id > self.last_transition_id
550 && (self.pending_transition_id == 0
551 || self.pending_transition_id == c.transition_id)
552 }
553 ControlOpcode::ReadyForTransition
555 | ControlOpcode::ExecuteTransition
556 | ControlOpcode::AbortTransition => {
557 self.pending_transition_id != 0 && c.transition_id == self.pending_transition_id
558 }
559 _ => true,
562 };
563 if !tid_ok {
564 self.emit_err_spec(
565 codes::TRANSITION_MISMATCH,
566 format!(
567 "control tid={} not valid for {:?} (last={}, pending={})",
568 c.transition_id, opcode, self.last_transition_id, self.pending_transition_id
569 ),
570 );
571 return;
572 }
573 match opcode {
574 ControlOpcode::PrepareTransition => {
575 if self.pending_transition_id == c.transition_id {
579 let current_winner = self.pending_commit_sender.unwrap_or(MemberId::MAX);
580 if c.sender_id >= current_winner {
581 self.events.push(Event::Control {
585 from: c.sender_id,
586 opcode,
587 transition_id: c.transition_id,
588 request_id: c.request_id,
589 args: c.args.to_vec(),
590 });
591 return;
592 }
593 }
595 self.pending_transition_id = c.transition_id;
596 self.pending_commit_sender = Some(c.sender_id);
597 self.transition_state = TransitionState::TPrepared;
598 self.note_coordinator_activity();
600 self.execute_deadline =
602 Some(Instant::now() + Duration::from_millis(timeouts::T_EXECUTE_MAX_MS));
603 }
604 ControlOpcode::ReadyForTransition => {
605 self.transition_state = TransitionState::TReady;
606 self.prepare_deadline = None;
608 }
609 ControlOpcode::ExecuteTransition => {
610 self.execute_deadline = None;
611 self.pending_commit_sender = None;
612 self.apply_transition(c.transition_id);
613 self.note_coordinator_activity();
614 }
615 ControlOpcode::AbortTransition => {
616 self.prepare_deadline = None;
617 self.execute_deadline = None;
618 self.pending_commit_sender = None;
619 self.transition_state = TransitionState::TAborted;
620 self.pending_transition_id = 0;
621 }
622 ControlOpcode::GroupStateDigestResponse => {
623 if self.state == NodeState::Resyncing {
624 self.transition(NodeState::Active);
625 }
626 }
627 ControlOpcode::CapabilitiesAdvertise => {
628 if Self::is_coordinator_claim(&c.args) {
629 self.note_coordinator_activity();
631 if self.is_coordinator && c.sender_id < self.member_id {
635 self.is_coordinator = false;
636 }
637 self.events.push(Event::CoordinatorClaim {
638 claimant: c.sender_id,
639 });
640 }
641 }
642 _ => {}
643 }
644 self.events.push(Event::Control {
645 from: c.sender_id,
646 opcode,
647 transition_id: c.transition_id,
648 request_id: c.request_id,
649 args: c.args.to_vec(),
650 });
651 }
652
653 pub fn apply_transition(&mut self, tid: TransitionId) {
656 self.current_epoch += 1;
657 self.last_transition_id = tid;
658 self.pending_transition_id = 0;
659 self.pending_commit_sender = None;
660 self.transition_state = TransitionState::TExecuted;
661 self.out_seq.clear();
662 self.in_hw.clear();
663 self.events.push(Event::EpochAdvanced {
664 epoch: self.current_epoch,
665 transition_id: tid,
666 });
667 }
668
669 pub fn trigger_resync(&mut self) {
671 if self.state != NodeState::Resyncing {
672 self.transition(NodeState::Resyncing);
673 }
674 }
675
676 pub fn check_timeouts(&mut self) -> Vec<Event> {
683 let now = Instant::now();
684
685 if self.prepare_deadline.is_some_and(|d| now >= d) {
686 self.prepare_deadline = None;
687 self.execute_deadline = None;
688 self.pending_transition_id = 0;
689 self.transition_state = TransitionState::TAborted;
690 self.emit_err_spec(codes::PREPARE_TIMEOUT, "T_prepare_max exceeded");
691 }
692
693 if self.execute_deadline.is_some_and(|d| now >= d) {
694 self.execute_deadline = None;
695 self.emit_err_spec(codes::EXECUTE_TIMEOUT, "T_execute_max exceeded");
696 }
697
698 if self.coordinator_last_seen.is_some_and(|t| {
699 now.duration_since(t).as_millis() as u64 >= timeouts::T_COORDINATOR_GRACE_MS
700 }) {
701 self.coordinator_last_seen = None;
702 self.is_coordinator = false;
703 self.emit_err_spec(
704 codes::COORDINATOR_GONE,
705 "coordinator silence exceeded T_coordinator_grace",
706 );
707 self.events.push(Event::CoordinatorElectionNeeded);
708 }
709
710 self.drain_events()
711 }
712
713 pub fn note_coordinator_activity(&mut self) {
720 self.coordinator_last_seen = Some(Instant::now());
721 }
722
723 pub fn claim_coordinator<S: Sealer>(
734 &mut self,
735 seal: &mut S,
736 target: MemberId,
737 ) -> Result<OutboundFrame, NodeError> {
738 let args = vec![0xA1u8, 0x00, 0xF5];
740 self.is_coordinator = true;
741 self.coordinator_last_seen = Some(Instant::now());
742 self.events.push(Event::BecameCoordinator);
743 self.send_control(
744 seal,
745 target,
746 ControlOpcode::CapabilitiesAdvertise,
747 self.last_transition_id,
748 0,
749 args,
750 )
751 }
752
753 fn is_coordinator_claim(args: &[u8]) -> bool {
756 if args == [0xA1, 0x00, 0xF5] {
760 return true;
761 }
762 args.windows(2).any(|w| w == [0x00, 0xF5])
766 }
767
768 fn transition(&mut self, next: NodeState) {
769 if self.state == next {
770 return;
771 }
772 if !self.state.can_transition_to(next) {
773 let from = self.state;
774 self.state = NodeState::Failed;
775 self.events.push(Event::StateChanged {
776 from,
777 to: NodeState::Failed,
778 });
779 return;
780 }
781 let from = self.state;
782 self.state = next;
783 self.events.push(Event::StateChanged { from, to: next });
784 }
785
786 fn assert_can_send(&self) -> Result<(), NodeError> {
787 if matches!(
788 self.state,
789 NodeState::Active | NodeState::Resyncing | NodeState::EstablishingGroup
790 ) {
791 Ok(())
792 } else {
793 Err(NodeError::InvalidState(format!(
794 "cannot send in state {}",
795 self.state
796 )))
797 }
798 }
799
800 fn next_seq(&mut self, st: StreamType, sid: StreamId) -> SequenceNo {
801 let entry = self.out_seq.entry((st, sid)).or_insert(0);
802 *entry += 1;
803 *entry
804 }
805
806 fn emit_err_spec(&mut self, code: u16, reason: impl Into<String>) {
807 if let Some(spec) = ErrorSpec::lookup(code) {
808 self.emit_err_named(spec.code, spec.class, spec.retryable, spec.fatal, reason);
809 } else {
810 self.emit_err_named(code, ErrorClass::Policy, false, false, reason);
811 }
812 }
813
814 fn emit_err_named(
815 &mut self,
816 code: u16,
817 class: ErrorClass,
818 retryable: bool,
819 fatal: bool,
820 reason: impl Into<String>,
821 ) {
822 let reason = reason.into();
823 let (class, retryable, fatal) = if let Some(spec) = ErrorSpec::lookup(code) {
826 (spec.class, spec.retryable, spec.fatal)
827 } else {
828 (class, retryable, fatal)
829 };
830 let _ = ErrorObject::new(code, class, retryable, fatal, reason.clone()).to_cbor();
831 self.events.push(Event::Error {
832 code,
833 class,
834 retryable,
835 fatal,
836 reason,
837 });
838 if fatal {
839 let from = self.state;
840 self.state = NodeState::Failed;
841 self.events.push(Event::StateChanged {
842 from,
843 to: NodeState::Failed,
844 });
845 }
846 }
847}
848
849pub trait Sealer {
855 fn seal(&mut self, st: StreamType, seq: SequenceNo, pt: &[u8]) -> Result<Vec<u8>, MlsError>;
857 fn open(&mut self, st: StreamType, seq: SequenceNo, ct: &[u8]) -> Result<Vec<u8>, MlsError>;
859}
860
861impl Sealer for gbp_mls::MlsContext {
862 fn seal(&mut self, st: StreamType, seq: SequenceNo, pt: &[u8]) -> Result<Vec<u8>, MlsError> {
863 gbp_mls::MlsContext::seal(self, label_for(st), seq, pt)
864 }
865 fn open(&mut self, st: StreamType, seq: SequenceNo, ct: &[u8]) -> Result<Vec<u8>, MlsError> {
866 gbp_mls::MlsContext::open(self, label_for(st), seq, ct)
867 }
868}
869
870#[cfg(test)]
871mod tests {
872 use super::*;
873
874 struct PlainSealer;
875 impl Sealer for PlainSealer {
876 fn seal(
877 &mut self,
878 _st: StreamType,
879 _seq: SequenceNo,
880 pt: &[u8],
881 ) -> Result<Vec<u8>, MlsError> {
882 Ok(pt.to_vec())
883 }
884 fn open(
885 &mut self,
886 _st: StreamType,
887 _seq: SequenceNo,
888 ct: &[u8],
889 ) -> Result<Vec<u8>, MlsError> {
890 Ok(ct.to_vec())
891 }
892 }
893
894 fn group_id() -> GroupId {
895 let mut g = [0u8; 16];
896 g[..3].copy_from_slice(b"GBP");
897 g
898 }
899
900 #[test]
901 fn replay_window_rejects_repeat() {
902 let mut alice = GroupNode::new(1, group_id());
903 let mut bob = GroupNode::new(2, group_id());
904 alice.bootstrap_as_creator(1);
905 bob.bootstrap_as_joiner(1, 0);
906 let mut s = PlainSealer;
907 let sid = alice.member_stream_id(2);
908 let f = alice
909 .send_payload(
910 &mut s,
911 2,
912 StreamType::Text,
913 sid,
914 GbpFlags::ordered_reliable_ack(),
915 b"hi",
916 PayloadCodec::Cbor,
917 )
918 .unwrap();
919 let _ = bob.on_wire(&mut s, &f.wire).unwrap();
920 let evs = bob.on_wire(&mut s, &f.wire).unwrap();
921 assert!(evs.iter().any(|e| matches!(
922 e,
923 Event::Error {
924 code: codes::REPLAY_DETECTED,
925 ..
926 }
927 )));
928 }
929
930 #[test]
931 fn epoch_mismatch_triggers_resync() {
932 let mut alice = GroupNode::new(1, group_id());
933 let mut bob = GroupNode::new(2, group_id());
934 alice.bootstrap_as_creator(1);
935 bob.bootstrap_as_joiner(1, 0);
936 alice.current_epoch = 2;
937 let mut s = PlainSealer;
938 let sid = alice.member_stream_id(2);
939 let f = alice
940 .send_payload(
941 &mut s,
942 2,
943 StreamType::Text,
944 sid,
945 GbpFlags::ordered_reliable_ack(),
946 b"x",
947 PayloadCodec::Cbor,
948 )
949 .unwrap();
950 let _ = bob.on_wire(&mut s, &f.wire).unwrap();
951 assert_eq!(bob.state, NodeState::Resyncing);
952 }
953
954 #[test]
955 fn payload_emits_received_event() {
956 let mut alice = GroupNode::new(1, group_id());
957 let mut bob = GroupNode::new(2, group_id());
958 alice.bootstrap_as_creator(1);
959 bob.bootstrap_as_joiner(1, 0);
960 let mut s = PlainSealer;
961 let sid = alice.member_stream_id(2);
962 let f = alice
963 .send_payload(
964 &mut s,
965 2,
966 StreamType::Text,
967 sid,
968 GbpFlags::ordered_reliable_ack(),
969 b"payload",
970 PayloadCodec::Cbor,
971 )
972 .unwrap();
973 let evs = bob.on_wire(&mut s, &f.wire).unwrap();
974 let pr = evs
975 .into_iter()
976 .find_map(|e| match e {
977 Event::PayloadReceived(p) => Some(p),
978 _ => None,
979 })
980 .expect("payload");
981 assert_eq!(pr.stream_type, StreamType::Text);
982 assert_eq!(pr.plaintext, b"payload");
983 }
984
985 fn drain_errs(events: &[Event]) -> Vec<u16> {
988 events
989 .iter()
990 .filter_map(|e| match e {
991 Event::Error { code, .. } => Some(*code),
992 _ => None,
993 })
994 .collect()
995 }
996
997 fn drain_controls(events: &[Event]) -> Vec<(ControlOpcode, TransitionId)> {
998 events
999 .iter()
1000 .filter_map(|e| match e {
1001 Event::Control {
1002 opcode,
1003 transition_id,
1004 ..
1005 } => Some((*opcode, *transition_id)),
1006 _ => None,
1007 })
1008 .collect()
1009 }
1010
1011 #[test]
1012 fn prepare_transition_sets_pending_on_sender_and_receiver() {
1013 let mut coord = GroupNode::new(1, group_id());
1014 let mut peer = GroupNode::new(2, group_id());
1015 coord.bootstrap_as_creator(0);
1016 peer.bootstrap_as_joiner(0, 0);
1017 let mut s = PlainSealer;
1018 let f = coord
1020 .send_control(
1021 &mut s,
1022 0,
1023 ControlOpcode::PrepareTransition,
1024 1,
1025 100,
1026 b"commit-blob".to_vec(),
1027 )
1028 .unwrap();
1029 assert_eq!(coord.pending_transition_id, 1, "sender mirrors pending");
1030 assert_eq!(coord.transition_state, TransitionState::TPrepared);
1031 let evs = peer.on_wire(&mut s, &f.wire).unwrap();
1032 assert_eq!(peer.pending_transition_id, 1, "receiver records pending");
1033 assert!(
1034 drain_errs(&evs).is_empty(),
1035 "no error: {:?}",
1036 drain_errs(&evs)
1037 );
1038 let ctls = drain_controls(&evs);
1039 assert_eq!(ctls, vec![(ControlOpcode::PrepareTransition, 1)]);
1040 }
1041
1042 #[test]
1043 fn ready_with_wrong_tid_is_rejected() {
1044 let mut coord = GroupNode::new(1, group_id());
1045 let mut peer = GroupNode::new(2, group_id());
1046 coord.bootstrap_as_creator(0);
1047 peer.bootstrap_as_joiner(0, 0);
1048 let mut s = PlainSealer;
1049 let f = coord
1050 .send_control(&mut s, 0, ControlOpcode::PrepareTransition, 1, 1, vec![])
1051 .unwrap();
1052 peer.on_wire(&mut s, &f.wire).unwrap();
1053 let bogus = peer
1055 .send_control(&mut s, 1, ControlOpcode::ReadyForTransition, 7, 1, vec![])
1056 .unwrap();
1057 let evs = coord.on_wire(&mut s, &bogus.wire).unwrap();
1058 let errs = drain_errs(&evs);
1059 assert!(errs.contains(&codes::TRANSITION_MISMATCH), "got {:?}", errs);
1060 }
1061
1062 #[test]
1063 fn execute_advances_epoch_and_clears_pending() {
1064 let mut coord = GroupNode::new(1, group_id());
1065 let mut peer = GroupNode::new(2, group_id());
1066 coord.bootstrap_as_creator(0);
1067 peer.bootstrap_as_joiner(0, 0);
1068 let mut s = PlainSealer;
1069 let prep = coord
1070 .send_control(&mut s, 0, ControlOpcode::PrepareTransition, 1, 1, vec![])
1071 .unwrap();
1072 peer.on_wire(&mut s, &prep.wire).unwrap();
1073 let exec = coord
1075 .send_control(&mut s, 0, ControlOpcode::ExecuteTransition, 1, 2, vec![])
1076 .unwrap();
1077 coord.apply_transition(1);
1078 let evs = peer.on_wire(&mut s, &exec.wire).unwrap();
1079 assert_eq!(coord.last_transition_id, 1);
1080 assert_eq!(coord.current_epoch, 1);
1081 assert_eq!(peer.last_transition_id, 1);
1082 assert_eq!(peer.current_epoch, 1);
1083 assert_eq!(peer.pending_transition_id, 0);
1084 assert!(evs.iter().any(|e| matches!(
1085 e,
1086 Event::EpochAdvanced {
1087 transition_id: 1,
1088 ..
1089 }
1090 )));
1091 }
1092
1093 #[test]
1094 fn abort_clears_pending_no_advance() {
1095 let mut coord = GroupNode::new(1, group_id());
1096 let mut peer = GroupNode::new(2, group_id());
1097 coord.bootstrap_as_creator(0);
1098 peer.bootstrap_as_joiner(0, 0);
1099 let mut s = PlainSealer;
1100 let prep = coord
1101 .send_control(&mut s, 0, ControlOpcode::PrepareTransition, 1, 1, vec![])
1102 .unwrap();
1103 peer.on_wire(&mut s, &prep.wire).unwrap();
1104 let abort = coord
1105 .send_control(&mut s, 0, ControlOpcode::AbortTransition, 1, 2, vec![])
1106 .unwrap();
1107 peer.on_wire(&mut s, &abort.wire).unwrap();
1108 assert_eq!(peer.pending_transition_id, 0);
1109 assert_eq!(peer.current_epoch, 0);
1110 assert_eq!(peer.transition_state, TransitionState::TAborted);
1111 assert_eq!(coord.transition_state, TransitionState::TAborted);
1112 }
1113
1114 #[test]
1115 fn bootstrap_as_joiner_with_expected_tid_accepts_first_execute() {
1116 let mut coord = GroupNode::new(1, group_id());
1117 let mut joiner = GroupNode::new(2, group_id());
1119 coord.bootstrap_as_creator(0);
1120 joiner.bootstrap_as_joiner(0, 1);
1121 assert_eq!(joiner.pending_transition_id, 1);
1122 let mut s = PlainSealer;
1123 let _ = coord
1125 .send_control(&mut s, 0, ControlOpcode::PrepareTransition, 1, 1, vec![])
1126 .unwrap();
1127 let exec = coord
1129 .send_control(&mut s, 0, ControlOpcode::ExecuteTransition, 1, 2, vec![])
1130 .unwrap();
1131 let evs = joiner.on_wire(&mut s, &exec.wire).unwrap();
1132 let errs = drain_errs(&evs);
1133 assert!(
1134 errs.is_empty(),
1135 "expected clean apply, got errors {:?}",
1136 errs
1137 );
1138 assert_eq!(joiner.last_transition_id, 1);
1139 assert_eq!(joiner.current_epoch, 1);
1140 }
1141
1142 #[test]
1145 fn claim_coordinator_sets_flag_and_emits_event() {
1146 let mut node = GroupNode::new(1, group_id());
1147 node.bootstrap_as_creator(0);
1148 node.drain_events();
1149 let mut s = PlainSealer;
1150 let _ = node.claim_coordinator(&mut s, 0).unwrap();
1151 assert!(node.is_coordinator);
1152 let evs = node.drain_events();
1153 assert!(evs.iter().any(|e| matches!(e, Event::BecameCoordinator)));
1154 }
1155
1156 #[test]
1157 fn coordinator_gone_emits_election_needed() {
1158 let mut member = GroupNode::new(2, group_id());
1159 member.bootstrap_as_joiner(0, 0);
1160 member.coordinator_last_seen = Some(Instant::now() - Duration::from_millis(11_000));
1161 let evs = member.check_timeouts();
1162 assert!(
1163 evs.iter()
1164 .any(|e| matches!(e, Event::CoordinatorElectionNeeded))
1165 );
1166 assert!(!member.is_coordinator, "flag cleared on silence");
1167 }
1168
1169 #[test]
1170 fn capabilities_advertise_with_claim_resets_silence_timer() {
1171 let mut member = GroupNode::new(2, group_id());
1172 let mut coord = GroupNode::new(1, group_id());
1173 member.bootstrap_as_joiner(0, 0);
1174 coord.bootstrap_as_creator(0);
1175 let mut s = PlainSealer;
1176 let f = coord.claim_coordinator(&mut s, 2).unwrap();
1178 let evs = member.on_wire(&mut s, &f.wire).unwrap();
1180 assert!(
1181 member.coordinator_last_seen.is_some(),
1182 "silence timer reset"
1183 );
1184 assert!(
1185 evs.iter()
1186 .any(|e| matches!(e, Event::CoordinatorClaim { claimant: 1 }))
1187 );
1188 }
1189
1190 #[test]
1191 fn higher_id_yields_to_lower_claimant() {
1192 let mut node5 = GroupNode::new(5, group_id());
1194 let mut node2 = GroupNode::new(2, group_id());
1195 node5.bootstrap_as_joiner(0, 0);
1196 node2.bootstrap_as_creator(0);
1197 let mut s = PlainSealer;
1198 node5.is_coordinator = true;
1200 let f = node2.claim_coordinator(&mut s, 5).unwrap();
1202 node5.on_wire(&mut s, &f.wire).unwrap();
1203 assert!(!node5.is_coordinator, "node5 yielded to node2");
1204 }
1205
1206 #[test]
1207 fn lower_id_keeps_coordinator_against_higher_claimant() {
1208 let mut node1 = GroupNode::new(1, group_id());
1209 let mut node5 = GroupNode::new(5, group_id());
1210 node1.bootstrap_as_creator(0);
1211 node5.bootstrap_as_joiner(0, 0);
1212 let mut s = PlainSealer;
1213 node1.is_coordinator = true;
1214 let f = node5.claim_coordinator(&mut s, 1).unwrap();
1215 node1.on_wire(&mut s, &f.wire).unwrap();
1216 assert!(node1.is_coordinator, "node1 keeps role — it has lower id");
1217 }
1218
1219 #[test]
1222 fn competing_prepare_lower_member_id_wins() {
1223 let mut node = GroupNode::new(10, group_id());
1226 node.bootstrap_as_joiner(0, 0);
1227 let mut s = PlainSealer;
1228
1229 let mut sender1 = GroupNode::new(1, group_id());
1231 sender1.bootstrap_as_creator(0);
1232 let f1 = sender1
1233 .send_control(
1234 &mut s,
1235 10,
1236 ControlOpcode::PrepareTransition,
1237 1,
1238 1,
1239 b"commit-A".to_vec(),
1240 )
1241 .unwrap();
1242 node.on_wire(&mut s, &f1.wire).unwrap();
1243 assert_eq!(
1244 node.pending_commit_sender,
1245 Some(1),
1246 "member 1 is initial winner"
1247 );
1248
1249 let mut sender3 = GroupNode::new(3, group_id());
1251 sender3.bootstrap_as_creator(0);
1252 let f3 = sender3
1253 .send_control(
1254 &mut s,
1255 10,
1256 ControlOpcode::PrepareTransition,
1257 1,
1258 2,
1259 b"commit-B".to_vec(),
1260 )
1261 .unwrap();
1262 node.on_wire(&mut s, &f3.wire).unwrap();
1263 assert_eq!(node.pending_commit_sender, Some(1), "member 1 still wins");
1265 assert_eq!(node.pending_transition_id, 1);
1266 }
1267
1268 #[test]
1269 fn competing_prepare_later_lower_id_displaces_winner() {
1270 let mut node = GroupNode::new(10, group_id());
1272 node.bootstrap_as_joiner(0, 0);
1273 let mut s = PlainSealer;
1274
1275 let mut sender5 = GroupNode::new(5, group_id());
1276 sender5.bootstrap_as_creator(0);
1277 let f5 = sender5
1278 .send_control(
1279 &mut s,
1280 10,
1281 ControlOpcode::PrepareTransition,
1282 1,
1283 1,
1284 b"commit-X".to_vec(),
1285 )
1286 .unwrap();
1287 node.on_wire(&mut s, &f5.wire).unwrap();
1288 assert_eq!(node.pending_commit_sender, Some(5));
1289
1290 let mut sender2 = GroupNode::new(2, group_id());
1291 sender2.bootstrap_as_creator(0);
1292 let f2 = sender2
1293 .send_control(
1294 &mut s,
1295 10,
1296 ControlOpcode::PrepareTransition,
1297 1,
1298 2,
1299 b"commit-Y".to_vec(),
1300 )
1301 .unwrap();
1302 node.on_wire(&mut s, &f2.wire).unwrap();
1303 assert_eq!(
1304 node.pending_commit_sender,
1305 Some(2),
1306 "member 2 displaces member 5"
1307 );
1308 }
1309
1310 #[test]
1311 fn apply_transition_clears_commit_sender() {
1312 let mut coord = GroupNode::new(1, group_id());
1313 coord.bootstrap_as_creator(0);
1314 let mut s = PlainSealer;
1315 coord
1316 .send_control(&mut s, 0, ControlOpcode::PrepareTransition, 1, 1, vec![])
1317 .unwrap();
1318 coord.apply_transition(1);
1319 assert_eq!(coord.pending_commit_sender, None);
1320 }
1321
1322 #[test]
1325 fn prepare_timeout_fires_when_deadline_exceeded() {
1326 let mut coord = GroupNode::new(1, group_id());
1327 coord.bootstrap_as_creator(0);
1328 let mut s = PlainSealer;
1329 coord
1330 .send_control(&mut s, 0, ControlOpcode::PrepareTransition, 1, 1, vec![])
1331 .unwrap();
1332 coord.prepare_deadline = Some(Instant::now() - Duration::from_millis(1));
1334 let evs = coord.check_timeouts();
1335 assert!(
1336 evs.iter().any(|e| matches!(
1337 e,
1338 Event::Error {
1339 code: codes::PREPARE_TIMEOUT,
1340 ..
1341 }
1342 )),
1343 "expected PREPARE_TIMEOUT, got {:?}",
1344 evs
1345 );
1346 assert_eq!(
1347 coord.transition_state,
1348 TransitionState::TAborted,
1349 "transition aborted"
1350 );
1351 assert_eq!(coord.prepare_deadline, None, "deadline cleared");
1352 }
1353
1354 #[test]
1355 fn execute_timeout_fires_when_deadline_exceeded() {
1356 let mut member = GroupNode::new(2, group_id());
1357 member.bootstrap_as_joiner(0, 0);
1358 let mut s = PlainSealer;
1359 member.pending_transition_id = 1;
1361 member.transition_state = TransitionState::TPrepared;
1362 member
1363 .send_control(&mut s, 1, ControlOpcode::ReadyForTransition, 1, 1, vec![])
1364 .unwrap();
1365 member.execute_deadline = Some(Instant::now() - Duration::from_millis(1));
1367 let evs = member.check_timeouts();
1368 assert!(
1369 evs.iter().any(|e| matches!(
1370 e,
1371 Event::Error {
1372 code: codes::EXECUTE_TIMEOUT,
1373 ..
1374 }
1375 )),
1376 "expected EXECUTE_TIMEOUT, got {:?}",
1377 evs
1378 );
1379 assert_eq!(member.execute_deadline, None, "deadline cleared");
1380 }
1381
/// When the coordinator was last heard from long ago, `check_timeouts` must
/// report `COORDINATOR_GONE` and reset `coordinator_last_seen` to `None`.
#[test]
fn coordinator_gone_fires_after_silence() {
    let mut node = GroupNode::new(2, group_id());
    node.bootstrap_as_joiner(0, 0);

    // 11 s of silence; assumed to exceed the coordinator-gone threshold
    // configured in `timeouts` (NOTE(review): confirm the constant).
    let silent_since = Instant::now() - Duration::from_millis(11_000);
    node.coordinator_last_seen = Some(silent_since);

    let events = node.check_timeouts();
    let gone = events.iter().any(|ev| {
        matches!(ev, Event::Error { code, .. } if *code == codes::COORDINATOR_GONE)
    });
    assert!(gone, "expected COORDINATOR_GONE, got {:?}", events);
    assert_eq!(node.coordinator_last_seen, None, "timer cleared");
}
1402
/// A stale `coordinator_last_seen` followed by `note_coordinator_activity`
/// must not produce `COORDINATOR_GONE` on the next `check_timeouts`.
#[test]
fn note_coordinator_activity_resets_silence_timer() {
    let mut node = GroupNode::new(2, group_id());
    node.bootstrap_as_joiner(0, 0);

    // Start from a timestamp old enough to trip the silence check on its own.
    let silent_since = Instant::now() - Duration::from_millis(11_000);
    node.coordinator_last_seen = Some(silent_since);

    // Fresh traffic from the coordinator should refresh the timer.
    node.note_coordinator_activity();

    let events = node.check_timeouts();
    let gone = events.iter().any(|ev| {
        matches!(ev, Event::Error { code, .. } if *code == codes::COORDINATOR_GONE)
    });
    assert!(!gone, "should NOT fire after reset");
}
1423
/// Sending PREPARE arms the coordinator's `prepare_deadline`; sending EXECUTE
/// for the same transition clears it, and `execute_deadline` is left unset too.
#[test]
fn execute_clears_prepare_deadline() {
    let mut coordinator = GroupNode::new(1, group_id());
    coordinator.bootstrap_as_creator(0);
    let mut sealer = PlainSealer;

    // Phase 1: PREPARE for transition 1 (request id 1).
    coordinator
        .send_control(&mut sealer, 0, ControlOpcode::PrepareTransition, 1, 1, vec![])
        .unwrap();
    assert!(coordinator.prepare_deadline.is_some(), "deadline armed");

    // Phase 2: EXECUTE for the same transition (request id 2).
    coordinator
        .send_control(&mut sealer, 0, ControlOpcode::ExecuteTransition, 1, 2, vec![])
        .unwrap();
    assert_eq!(
        coordinator.prepare_deadline, None,
        "deadline cleared on EXECUTE"
    );
    assert_eq!(
        coordinator.execute_deadline, None,
        "execute_deadline also cleared"
    );
}
1442
/// A member that receives the coordinator's PREPARE over the wire must arm its
/// own `execute_deadline`.
#[test]
fn receive_prepare_arms_execute_deadline() {
    let mut coordinator = GroupNode::new(1, group_id());
    let mut joiner = GroupNode::new(2, group_id());
    coordinator.bootstrap_as_creator(0);
    joiner.bootstrap_as_joiner(0, 0);
    let mut sealer = PlainSealer;

    let prepare = coordinator
        .send_control(&mut sealer, 0, ControlOpcode::PrepareTransition, 1, 1, vec![])
        .unwrap();
    // Events produced by delivery are irrelevant here; only the deadline matters.
    joiner.on_wire(&mut sealer, &prepare.wire).unwrap();

    let armed = joiner.execute_deadline.is_some();
    assert!(armed, "execute_deadline armed on receiving PREPARE");
}
1459
/// Delivering PREPARE and then EXECUTE for the same transition to a member
/// leaves its `execute_deadline` cleared.
#[test]
fn receive_execute_clears_execute_deadline() {
    let mut coordinator = GroupNode::new(1, group_id());
    let mut joiner = GroupNode::new(2, group_id());
    coordinator.bootstrap_as_creator(0);
    joiner.bootstrap_as_joiner(0, 0);
    let mut sealer = PlainSealer;

    // Strictly ordered exchange: each frame is sent by the coordinator and
    // immediately delivered to the member before the next one is produced.
    let exchange = [
        (ControlOpcode::PrepareTransition, 1),
        (ControlOpcode::ExecuteTransition, 2),
    ];
    for (opcode, request_id) in exchange {
        let frame = coordinator
            .send_control(&mut sealer, 0, opcode, 1, request_id, vec![])
            .unwrap();
        joiner.on_wire(&mut sealer, &frame.wire).unwrap();
    }

    assert_eq!(joiner.execute_deadline, None, "cleared on EXECUTE");
}
1477
/// A freshly bootstrapped creator that has sent no control messages must get
/// nothing back from `check_timeouts`.
#[test]
fn no_timeout_when_deadlines_not_set() {
    let mut node = GroupNode::new(1, group_id());
    node.bootstrap_as_creator(0);
    // Discard anything bootstrapping queued (presumably a StateChanged event;
    // verify) so the assertion concerns only what check_timeouts returns.
    node.drain_events();

    let evs = node.check_timeouts();
    assert!(evs.is_empty(), "no events without armed deadlines");
}
1486
/// After transition 1 has been applied locally, a PREPARE that reuses
/// transition id 1 must be rejected with `TRANSITION_MISMATCH`.
#[test]
fn prepare_with_already_applied_tid_is_rejected() {
    let mut coord = GroupNode::new(1, group_id());
    coord.bootstrap_as_creator(0);
    let mut s = PlainSealer;

    // Run transition 1 to completion on the coordinator.
    coord
        .send_control(&mut s, 0, ControlOpcode::PrepareTransition, 1, 1, vec![])
        .unwrap();
    coord.apply_transition(1);
    // Pin the precondition: tid 1 is recorded as applied and nothing is pending.
    assert_eq!(coord.last_transition_id, 1);
    assert_eq!(coord.pending_transition_id, 0);

    // A peer at the coordinator's current epoch replays PREPARE for tid 1
    // (request id 9) towards member 1.
    let mut peer = GroupNode::new(2, group_id());
    peer.bootstrap_as_joiner(coord.current_epoch, 0);
    let stale = peer
        .send_control(&mut s, 1, ControlOpcode::PrepareTransition, 1, 9, vec![])
        .unwrap();

    let evs = coord.on_wire(&mut s, &stale.wire).unwrap();
    let errs = drain_errs(&evs);
    assert!(
        errs.contains(&codes::TRANSITION_MISMATCH),
        "expected TRANSITION_MISMATCH, got {:?}",
        errs
    );
}
1516
/// When the receiver's sealer cannot open a payload, the first `Event::Error`
/// must be `DECRYPT_FAILED`, flagged non-fatal and retryable, and the receiver
/// must remain `Active`.
#[test]
fn decrypt_failed_is_non_fatal() {
    /// Test sealer: `seal` passes the plaintext through unchanged, while `open`
    /// always fails with a simulated AEAD error.
    struct OpenFailSealer;

    impl Sealer for OpenFailSealer {
        fn seal(
            &mut self,
            _: StreamType,
            _: SequenceNo,
            p: &[u8],
        ) -> Result<Vec<u8>, MlsError> {
            Ok(p.to_vec())
        }
        fn open(
            &mut self,
            _: StreamType,
            _: SequenceNo,
            _: &[u8],
        ) -> Result<Vec<u8>, MlsError> {
            Err(MlsError::Aead("simulated".into()))
        }
    }

    let mut alice = GroupNode::new(1, group_id());
    let mut bob = GroupNode::new(2, group_id());
    alice.bootstrap_as_creator(1);
    bob.bootstrap_as_joiner(1, 0);

    // Alice seals normally with the pass-through sealer.
    let mut s = PlainSealer;
    let sid = alice.member_stream_id(2);
    let f = alice
        .send_payload(
            &mut s,
            2,
            StreamType::Text,
            sid,
            GbpFlags::ordered_reliable_ack(),
            b"x",
            PayloadCodec::Cbor,
        )
        .unwrap();

    // Bob receives with a sealer whose `open` always fails.
    let mut fail = OpenFailSealer;
    let evs = bob.on_wire(&mut fail, &f.wire).unwrap();

    // Name the fields instead of indexing an anonymous tuple.
    let (code, fatal, retryable) = evs
        .iter()
        .find_map(|e| match e {
            Event::Error {
                code,
                fatal,
                retryable,
                ..
            } => Some((*code, *fatal, *retryable)),
            _ => None,
        })
        .expect("error event");
    assert_eq!(code, codes::DECRYPT_FAILED);
    assert!(!fatal, "must be non-fatal");
    assert!(retryable, "must be retryable");
    assert_eq!(bob.state, NodeState::Active, "bob stays Active");
}
1575}