Skip to main content

gbp_node/
node.rs

1//! GBP-layer group node.
2//!
3//! Responsibilities of this layer (analogous to IP):
4//!
5//! * Decode incoming CBOR frames and validate `version`, `group_id`, `epoch`
6//!   and `transition_id` per the GBP spec.
7//! * Enforce a per-`(stream_type, stream_id)` replay window via
8//!   `sequence_no`.
9//! * Open the AEAD payload through the [`Sealer`] trait (typically backed by
10//!   `gbp-mls`).
11//! * Surface decoded payloads to sub-protocols as
12//!   [`Event::PayloadReceived`]; the sub-protocol layer is responsible for
13//!   message-level semantics.
14//! * Drive the control plane: handle `EXECUTE_TRANSITION`, request resync on
15//!   `EPOCH_MISMATCH`, etc.
16//!
17//! Out of scope: parsing GTP / GAP / GSP payloads, GTP idempotency, GAP
18//! `key_phase` validation and mute-list tracking. Those concerns belong to
19//! the per-sub-protocol clients in the `gtp-protocol`, `gap-protocol` and
20//! `gsp-protocol` crates.
21
22use gbp::{CodecError, ControlMessage, ErrorObject, GbpFrame};
23use gbp_core::{
24    ControlOpcode, ErrorClass, GbpFlags, GroupId, MemberId, NodeState, SequenceNo, StreamId,
25    StreamType, TransitionId, TransitionState, codes,
26    errors::ErrorSpec,
27};
28use gbp_mls::{MlsError, label_for};
29use std::collections::HashMap;
30
31/// Errors raised by [`GroupNode`].
32#[derive(Debug, thiserror::Error)]
33pub enum NodeError {
34    /// Codec error.
35    #[error("codec: {0}")]
36    Codec(#[from] CodecError),
37    /// MLS / AEAD error.
38    #[error("mls: {0}")]
39    Mls(#[from] MlsError),
40    /// The node is not in a state that allows the requested operation.
41    #[error("invalid state: {0}")]
42    InvalidState(String),
43}
44
45/// A wire-ready outbound frame: the recipient and its serialised CBOR bytes.
46pub struct OutboundFrame {
47    /// Target member id.
48    pub to: MemberId,
49    /// CBOR-encoded [`GbpFrame`] bytes.
50    pub wire: Vec<u8>,
51}
52
53/// Information about a payload delivered by GBP to a sub-protocol.
54#[derive(Debug, Clone)]
55pub struct DeliveredPayload {
56    /// Stream class on which the frame arrived.
57    pub stream_type: StreamType,
58    /// Stream id from the frame (preserved so receivers can demultiplex
59    /// multiple sub-streams).
60    pub stream_id: StreamId,
61    /// Sequence number after passing the replay window.
62    pub sequence_no: SequenceNo,
63    /// Frame flag bits, copied as-is.
64    pub flags: u16,
65    /// Decrypted plaintext bytes.
66    pub plaintext: Vec<u8>,
67}
68
69/// Events surfaced by the GBP layer.
70#[derive(Debug, Clone)]
71pub enum Event {
72    /// Node FSM changed state.
73    StateChanged {
74        /// Previous state.
75        from: NodeState,
76        /// New state.
77        to: NodeState,
78    },
79    /// Payload delivered to a sub-protocol (Text / Audio / Signal). Control
80    /// frames are handled internally and do not surface as
81    /// [`Event::PayloadReceived`].
82    PayloadReceived(DeliveredPayload),
83    /// A control plane message was received and parsed.
84    Control {
85        /// Sender member id.
86        from: MemberId,
87        /// Decoded opcode.
88        opcode: ControlOpcode,
89        /// `transition_id` carried by the message.
90        transition_id: TransitionId,
91        /// `request_id` echoed by ACK / NACK responders.
92        request_id: u32,
93        /// Opcode-specific args (CBOR or opaque bytes; e.g. the MLS Commit
94        /// embedded in `PREPARE_TRANSITION`).
95        args: Vec<u8>,
96    },
97    /// An error was raised.
98    Error {
99        /// Numeric error code.
100        code: u16,
101        /// Error class.
102        class: ErrorClass,
103        /// MAY be retried.
104        retryable: bool,
105        /// Fatal — the node is now in `FAILED`.
106        fatal: bool,
107        /// Human-readable reason.
108        reason: String,
109    },
110    /// Epoch transition has been applied locally.
111    EpochAdvanced {
112        /// New epoch.
113        epoch: u64,
114        /// `transition_id` that produced the new epoch.
115        transition_id: TransitionId,
116    },
117}
118
119/// GBP-layer node.
120///
121/// Owns the framing, AEAD, replay window, FSM and control plane.
122/// Sub-protocol semantics live in their own crates and use this type plus a
123/// [`Sealer`] for outbound traffic and `on_wire` + the resulting events for
124/// inbound traffic.
125pub struct GroupNode {
126    /// Application-level member id.
127    pub member_id: MemberId,
128    /// 16-byte group identifier.
129    pub group_id: GroupId,
130    /// Current epoch as observed by the GBP layer (the authoritative epoch
131    /// lives in the underlying MLS group).
132    pub current_epoch: u64,
133    /// Last applied `transition_id`.
134    pub last_transition_id: TransitionId,
135    /// Pending `transition_id` (set during PREPARE / READY).
136    pub pending_transition_id: TransitionId,
137    /// Node FSM.
138    pub state: NodeState,
139    /// Transition FSM.
140    pub transition_state: TransitionState,
141
142    out_seq: HashMap<(StreamType, StreamId), SequenceNo>,
143    in_hw: HashMap<(StreamType, StreamId), SequenceNo>,
144    events: Vec<Event>,
145}
146
147impl GroupNode {
148    /// Builds a fresh node in the `IDLE` state.
149    pub fn new(member_id: MemberId, group_id: GroupId) -> Self {
150        Self {
151            member_id,
152            group_id,
153            current_epoch: 0,
154            last_transition_id: 0,
155            pending_transition_id: 0,
156            state: NodeState::Idle,
157            transition_state: TransitionState::TIdle,
158            out_seq: HashMap::new(),
159            in_hw: HashMap::new(),
160            events: Vec::new(),
161        }
162    }
163
164    /// Drives the node from `IDLE` to `ACTIVE` as a creator.
165    pub fn bootstrap_as_creator(&mut self, epoch: u64) {
166        self.transition(NodeState::Connecting);
167        self.transition(NodeState::EstablishingGroup);
168        self.current_epoch = epoch;
169        self.transition(NodeState::Active);
170    }
171
172    /// Drives the node from `IDLE` to `ACTIVE` as a joiner.
173    ///
174    /// `expected_first_tid` lets the joiner pre-arm its pending transition
175    /// state so that the very next `EXECUTE_TRANSITION` (which will arrive
176    /// without a preceding PREPARE the joiner could decrypt — that PREPARE
177    /// was sealed under the pre-Welcome epoch) is accepted by
178    /// `handle_control`'s tid-validation matrix. Pass `0` if the joiner
179    /// recovered out-of-band and is already current.
180    pub fn bootstrap_as_joiner(&mut self, epoch: u64, expected_first_tid: u32) {
181        self.transition(NodeState::Connecting);
182        self.transition(NodeState::EstablishingGroup);
183        self.current_epoch = epoch;
184        if expected_first_tid > 0 {
185            self.pending_transition_id = expected_first_tid;
186            self.transition_state = TransitionState::TPrepared;
187        }
188        self.transition(NodeState::Active);
189    }
190
191    /// Drains and returns all queued events.
192    pub fn drain_events(&mut self) -> Vec<Event> {
193        std::mem::take(&mut self.events)
194    }
195
196    /// Returns a sender-unique `stream_id` within the given base class.
197    ///
198    /// This is used so that the receiver's replay window does not conflate
199    /// streams that originate from different members.
200    pub fn member_stream_id(&self, base: u32) -> StreamId {
201        debug_assert!(self.member_id < 1_000_000, "member_id overflow: {0}", self.member_id);
202        base + self.member_id * 100
203    }
204
205    /// Sends an opaque plaintext payload on the given stream.
206    ///
207    /// Used by the sub-protocol clients: each one CBOR-encodes its message
208    /// and forwards the resulting bytes here.
209    pub fn send_payload<S: Sealer>(
210        &mut self,
211        seal: &mut S,
212        target: MemberId,
213        stream_type: StreamType,
214        stream_id: StreamId,
215        flags: u16,
216        plaintext: &[u8],
217    ) -> Result<OutboundFrame, NodeError> {
218        self.assert_can_send()?;
219        let seq = self.next_seq(stream_type, stream_id);
220        let ciphertext = seal.seal(stream_type, seq, plaintext)?;
221        let frame = GbpFrame::new(
222            self.group_id,
223            self.current_epoch,
224            self.last_transition_id,
225            stream_type,
226            stream_id,
227            flags,
228            seq,
229            ciphertext,
230        );
231        Ok(OutboundFrame { to: target, wire: frame.to_cbor() })
232    }
233
234    /// Sends a control plane message on Stream 0. Wrapper around
235    /// [`GroupNode::send_payload`].
236    ///
237    /// Side effect: when the coordinator originates a `PREPARE_TRANSITION`,
238    /// it must locally adopt the same `pending_transition_id` so that the
239    /// inbound READY / EXECUTE validation matrix in `handle_control` lines
240    /// up. Without this, the coordinator never matches its own pending tid
241    /// against the remote READY frames it expects, and the handshake never
242    /// completes.
243    pub fn send_control<S: Sealer>(
244        &mut self,
245        seal: &mut S,
246        target: MemberId,
247        opcode: ControlOpcode,
248        transition_id: TransitionId,
249        request_id: u32,
250        args: Vec<u8>,
251    ) -> Result<OutboundFrame, NodeError> {
252        let ctl = ControlMessage::with_args(
253            opcode as u16,
254            request_id,
255            self.member_id,
256            transition_id,
257            args,
258        );
259        let mut flags = GbpFlags::ordered_reliable_system();
260        if matches!(
261            opcode,
262            ControlOpcode::PrepareTransition
263                | ControlOpcode::ReadyForTransition
264                | ControlOpcode::ExecuteTransition
265        ) {
266            flags |= GbpFlags::CRITICAL;
267        }
268        // Sender-side state mirroring (matches what `handle_control` does on
269        // the receiver side). We only update on PREPARE/EXECUTE/ABORT — READY
270        // is purely an ack carrying the existing pending tid.
271        match opcode {
272            ControlOpcode::PrepareTransition => {
273                self.pending_transition_id = transition_id;
274                self.transition_state = TransitionState::TPrepared;
275            }
276            ControlOpcode::AbortTransition => {
277                self.pending_transition_id = 0;
278                self.transition_state = TransitionState::TAborted;
279            }
280            _ => {}
281        }
282        let stream_id = self.member_stream_id(0);
283        self.send_payload(seal, target, StreamType::Control, stream_id, flags, &ctl.to_cbor())
284    }
285
286    /// Feeds wire bytes to the node.
287    ///
288    /// Performs the §6.2 validation pipeline (version → group_id → epoch →
289    /// payload_size → transition_id → replay), opens the AEAD payload and
290    /// either:
291    /// * dispatches the parsed control message internally (for
292    ///   `StreamType::Control`), or
293    /// * surfaces an [`Event::PayloadReceived`] (for application streams).
294    ///
295    /// Returns every event that was produced as a result.
296    pub fn on_wire<S: Sealer>(
297        &mut self,
298        seal: &mut S,
299        wire: &[u8],
300    ) -> Result<Vec<Event>, NodeError> {
301        // Decode without payload-size validation — we want a malformed v!=1
302        // frame to surface as `ERR_UNSUPPORTED_VERSION`, not as
303        // `ERR_PAYLOAD_SIZE_MISMATCH`. Validation runs in deliver_frame, in
304        // the order required by §6.2.
305        let frame = match GbpFrame::decode(wire) {
306            Ok(f) => f,
307            Err(e) => {
308                self.emit_err_spec(codes::STREAM_POLICY_VIOLATION, format!("frame decode: {e}"));
309                return Ok(self.drain_events());
310            }
311        };
312        self.deliver_frame(seal, frame)?;
313        Ok(self.drain_events())
314    }
315
316    fn deliver_frame<S: Sealer>(&mut self, seal: &mut S, frame: GbpFrame) -> Result<(), NodeError> {
317        // §6.2 order: version → group_id → epoch → payload_size →
318        // transition_id (when CRITICAL) → replay.
319        if frame.version != 1 {
320            self.emit_err_spec(codes::UNSUPPORTED_VERSION, "version != 1");
321            return Ok(());
322        }
323        if frame.group_id_array() != self.group_id {
324            self.emit_err_spec(codes::UNKNOWN_GROUP, "group_id");
325            return Ok(());
326        }
327        if frame.epoch != self.current_epoch {
328            self.emit_err_spec(
329                codes::EPOCH_MISMATCH,
330                format!("got {}, expected {}", frame.epoch, self.current_epoch),
331            );
332            self.trigger_resync();
333            return Ok(());
334        }
335        if let Err(e) = frame.validate_payload_size() {
336            self.emit_err_spec(codes::STREAM_POLICY_VIOLATION, format!("payload size: {e}"));
337            return Ok(());
338        }
339        let flags = GbpFlags::from_bits(frame.flags);
340        let st = match frame.stream_type_typed() {
341            Ok(st) => st,
342            Err(_) => {
343                self.emit_err_spec(codes::STREAM_POLICY_VIOLATION, "unknown stream_type");
344                return Ok(());
345            }
346        };
347
348        // §6.2 transition_id ordering: CRITICAL frames on application streams
349        // MUST equal `last_transition_id`. Control-stream frames are exempt
350        // from this check and validated per-opcode inside `handle_control`,
351        // because PREPARE_TRANSITION legitimately carries `last + 1` and
352        // EXECUTE / ACK carry `pending_transition_id`.
353        if st != StreamType::Control
354            && flags.has(GbpFlags::CRITICAL)
355            && frame.transition_id != self.last_transition_id
356        {
357            self.emit_err_spec(
358                codes::TRANSITION_MISMATCH,
359                format!("got tid={}, expected {}", frame.transition_id, self.last_transition_id),
360            );
361            return Ok(());
362        }
363
364        let key = (st, frame.stream_id);
365        let hw = self.in_hw.get(&key).copied().unwrap_or(0);
366        if frame.sequence_no <= hw {
367            self.emit_err_spec(
368                codes::REPLAY_DETECTED,
369                format!(
370                    "st={} sid={} seq={} hw={}",
371                    st, frame.stream_id, frame.sequence_no, hw
372                ),
373            );
374            return Ok(());
375        }
376        self.in_hw.insert(key, frame.sequence_no);
377
378        let plain = match seal.open(st, frame.sequence_no, &frame.encrypted_payload) {
379            Ok(p) => p,
380            Err(e) => {
381                // Non-fatal: a frame addressed under a different MLS epoch
382                // (e.g. PREPARE_TRANSITION reaching a fresh joiner that has
383                // already accepted the post-commit Welcome) cannot be
384                // decrypted, but that's expected and the node MUST keep
385                // running to receive the subsequent EXECUTE frame on the
386                // shared post-merge epoch.
387                self.emit_err_named(
388                    codes::DECRYPT_FAILED,
389                    ErrorClass::Crypto,
390                    true,   // retryable: caller may resync via digest
391                    false,  // non-fatal
392                    format!("aead open: {e}"),
393                );
394                return Ok(());
395            }
396        };
397
398        match st {
399            StreamType::Control => self.handle_control(plain),
400            other => self.events.push(Event::PayloadReceived(DeliveredPayload {
401                stream_type: other,
402                stream_id: frame.stream_id,
403                sequence_no: frame.sequence_no,
404                flags: frame.flags,
405                plaintext: plain,
406            })),
407        }
408        Ok(())
409    }
410
411    fn handle_control(&mut self, plain: Vec<u8>) {
412        let c = match ControlMessage::from_cbor(&plain) {
413            Ok(c) => c,
414            Err(_) => {
415                self.emit_err_spec(codes::STREAM_POLICY_VIOLATION, "control decode");
416                return;
417            }
418        };
419        let opcode = match ControlOpcode::try_from(c.opcode) {
420            Ok(op) => op,
421            Err(_) => {
422                self.emit_err_spec(codes::STREAM_POLICY_VIOLATION, "unknown opcode");
423                return;
424            }
425        };
426        // Per-opcode TransitionID validation (§5 of gbp-control-plane).
427        let tid_ok = match opcode {
428            // PREPARE introduces last+1; receiver simply records it as pending.
429            // Re-issuing a PREPARE for an already-pending tid is allowed; a
430            // smaller-or-equal tid that is not strictly newer is rejected.
431            ControlOpcode::PrepareTransition => {
432                c.transition_id > self.last_transition_id
433                    && (self.pending_transition_id == 0
434                        || self.pending_transition_id == c.transition_id)
435            }
436            // READY / EXECUTE / ABORT must reference the pending tid.
437            ControlOpcode::ReadyForTransition
438            | ControlOpcode::ExecuteTransition
439            | ControlOpcode::AbortTransition => {
440                self.pending_transition_id != 0
441                    && c.transition_id == self.pending_transition_id
442            }
443            // Digest / capability / ack / nack: tid is informational, no
444            // ordering constraint at the GBP layer.
445            _ => true,
446        };
447        if !tid_ok {
448            self.emit_err_spec(
449                codes::TRANSITION_MISMATCH,
450                format!(
451                    "control tid={} not valid for {:?} (last={}, pending={})",
452                    c.transition_id, opcode, self.last_transition_id, self.pending_transition_id
453                ),
454            );
455            return;
456        }
457        match opcode {
458            ControlOpcode::PrepareTransition => {
459                self.pending_transition_id = c.transition_id;
460                self.transition_state = TransitionState::TPrepared;
461            }
462            ControlOpcode::ReadyForTransition => {
463                self.transition_state = TransitionState::TReady;
464            }
465            ControlOpcode::ExecuteTransition => {
466                self.apply_transition(c.transition_id);
467            }
468            ControlOpcode::AbortTransition => {
469                self.transition_state = TransitionState::TAborted;
470                self.pending_transition_id = 0;
471            }
472            ControlOpcode::GroupStateDigestResponse => {
473                if self.state == NodeState::Resyncing {
474                    self.transition(NodeState::Active);
475                }
476            }
477            _ => {}
478        }
479        self.events.push(Event::Control {
480            from: c.sender_id,
481            opcode,
482            transition_id: c.transition_id,
483            request_id: c.request_id,
484            args: c.args.to_vec(),
485        });
486    }
487
488    /// Applies a new epoch (called by the coordinator after
489    /// `EXECUTE_TRANSITION`).
490    pub fn apply_transition(&mut self, tid: TransitionId) {
491        self.current_epoch += 1;
492        self.last_transition_id = tid;
493        self.pending_transition_id = 0;
494        self.transition_state = TransitionState::TExecuted;
495        self.out_seq.clear();
496        self.in_hw.clear();
497        self.events.push(Event::EpochAdvanced {
498            epoch: self.current_epoch,
499            transition_id: tid,
500        });
501    }
502
503    /// Forces the node into the `RESYNCING` state.
504    pub fn trigger_resync(&mut self) {
505        if self.state != NodeState::Resyncing {
506            self.transition(NodeState::Resyncing);
507        }
508    }
509
510    fn transition(&mut self, next: NodeState) {
511        if self.state == next {
512            return;
513        }
514        if !self.state.can_transition_to(next) {
515            let from = self.state;
516            self.state = NodeState::Failed;
517            self.events.push(Event::StateChanged { from, to: NodeState::Failed });
518            return;
519        }
520        let from = self.state;
521        self.state = next;
522        self.events.push(Event::StateChanged { from, to: next });
523    }
524
525    fn assert_can_send(&self) -> Result<(), NodeError> {
526        if matches!(
527            self.state,
528            NodeState::Active | NodeState::Resyncing | NodeState::EstablishingGroup
529        ) {
530            Ok(())
531        } else {
532            Err(NodeError::InvalidState(format!("cannot send in state {}", self.state)))
533        }
534    }
535
536    fn next_seq(&mut self, st: StreamType, sid: StreamId) -> SequenceNo {
537        let entry = self.out_seq.entry((st, sid)).or_insert(0);
538        *entry += 1;
539        *entry
540    }
541
542    fn emit_err_spec(&mut self, code: u16, reason: impl Into<String>) {
543        if let Some(spec) = ErrorSpec::lookup(code) {
544            self.emit_err_named(spec.code, spec.class, spec.retryable, spec.fatal, reason);
545        } else {
546            self.emit_err_named(code, ErrorClass::Policy, false, false, reason);
547        }
548    }
549
550    fn emit_err_named(
551        &mut self,
552        code: u16,
553        class: ErrorClass,
554        retryable: bool,
555        fatal: bool,
556        reason: impl Into<String>,
557    ) {
558        let reason = reason.into();
559        // ErrorSpec is authoritative for known codes — use its class/retryable/fatal
560        // so that the wire error always matches the registry.
561        let (class, retryable, fatal) = if let Some(spec) = ErrorSpec::lookup(code) {
562            (spec.class, spec.retryable, spec.fatal)
563        } else {
564            (class, retryable, fatal)
565        };
566        let _ = ErrorObject::new(code, class, retryable, fatal, reason.clone()).to_cbor();
567        self.events.push(Event::Error { code, class, retryable, fatal, reason });
568        if fatal {
569            let from = self.state;
570            self.state = NodeState::Failed;
571            self.events.push(Event::StateChanged { from, to: NodeState::Failed });
572        }
573    }
574}
575
576/// Trait abstracting the AEAD seal / open operations.
577///
578/// Implemented for [`gbp_mls::MlsContext`] in this crate; tests typically
579/// implement a no-op identity sealer to exercise the FSM without standing
580/// up an MLS group.
581pub trait Sealer {
582    /// Encrypts `pt` for the given stream and per-stream sequence number.
583    fn seal(&mut self, st: StreamType, seq: SequenceNo, pt: &[u8]) -> Result<Vec<u8>, MlsError>;
584    /// Decrypts `ct` for the given stream and per-stream sequence number.
585    fn open(&mut self, st: StreamType, seq: SequenceNo, ct: &[u8]) -> Result<Vec<u8>, MlsError>;
586}
587
588impl Sealer for gbp_mls::MlsContext {
589    fn seal(&mut self, st: StreamType, seq: SequenceNo, pt: &[u8]) -> Result<Vec<u8>, MlsError> {
590        gbp_mls::MlsContext::seal(self, label_for(st), seq, pt)
591    }
592    fn open(&mut self, st: StreamType, seq: SequenceNo, ct: &[u8]) -> Result<Vec<u8>, MlsError> {
593        gbp_mls::MlsContext::open(self, label_for(st), seq, ct)
594    }
595}
596
597#[cfg(test)]
598mod tests {
599    use super::*;
600
601    struct PlainSealer;
602    impl Sealer for PlainSealer {
603        fn seal(&mut self, _st: StreamType, _seq: SequenceNo, pt: &[u8]) -> Result<Vec<u8>, MlsError> {
604            Ok(pt.to_vec())
605        }
606        fn open(&mut self, _st: StreamType, _seq: SequenceNo, ct: &[u8]) -> Result<Vec<u8>, MlsError> {
607            Ok(ct.to_vec())
608        }
609    }
610
611    fn group_id() -> GroupId {
612        let mut g = [0u8; 16];
613        g[..3].copy_from_slice(b"GBP");
614        g
615    }
616
617    #[test]
618    fn replay_window_rejects_repeat() {
619        let mut alice = GroupNode::new(1, group_id());
620        let mut bob = GroupNode::new(2, group_id());
621        alice.bootstrap_as_creator(1);
622        bob.bootstrap_as_joiner(1, 0);
623        let mut s = PlainSealer;
624        let sid = alice.member_stream_id(2);
625        let f = alice
626            .send_payload(&mut s, 2, StreamType::Text, sid, GbpFlags::ordered_reliable_ack(), b"hi")
627            .unwrap();
628        let _ = bob.on_wire(&mut s, &f.wire).unwrap();
629        let evs = bob.on_wire(&mut s, &f.wire).unwrap();
630        assert!(evs.iter().any(|e| matches!(
631            e, Event::Error { code: codes::REPLAY_DETECTED, .. }
632        )));
633    }
634
635    #[test]
636    fn epoch_mismatch_triggers_resync() {
637        let mut alice = GroupNode::new(1, group_id());
638        let mut bob = GroupNode::new(2, group_id());
639        alice.bootstrap_as_creator(1);
640        bob.bootstrap_as_joiner(1, 0);
641        alice.current_epoch = 2;
642        let mut s = PlainSealer;
643        let sid = alice.member_stream_id(2);
644        let f = alice
645            .send_payload(&mut s, 2, StreamType::Text, sid, GbpFlags::ordered_reliable_ack(), b"x")
646            .unwrap();
647        let _ = bob.on_wire(&mut s, &f.wire).unwrap();
648        assert_eq!(bob.state, NodeState::Resyncing);
649    }
650
651    #[test]
652    fn payload_emits_received_event() {
653        let mut alice = GroupNode::new(1, group_id());
654        let mut bob = GroupNode::new(2, group_id());
655        alice.bootstrap_as_creator(1);
656        bob.bootstrap_as_joiner(1, 0);
657        let mut s = PlainSealer;
658        let sid = alice.member_stream_id(2);
659        let f = alice
660            .send_payload(&mut s, 2, StreamType::Text, sid, GbpFlags::ordered_reliable_ack(), b"payload")
661            .unwrap();
662        let evs = bob.on_wire(&mut s, &f.wire).unwrap();
663        let pr = evs.into_iter().find_map(|e| match e {
664            Event::PayloadReceived(p) => Some(p),
665            _ => None,
666        }).expect("payload");
667        assert_eq!(pr.stream_type, StreamType::Text);
668        assert_eq!(pr.plaintext, b"payload");
669    }
670
671    // ---- Control-plane handshake -----------------------------------------
672
673    fn drain_errs(events: &[Event]) -> Vec<u16> {
674        events.iter().filter_map(|e| match e {
675            Event::Error { code, .. } => Some(*code),
676            _ => None,
677        }).collect()
678    }
679
680    fn drain_controls(events: &[Event]) -> Vec<(ControlOpcode, TransitionId)> {
681        events.iter().filter_map(|e| match e {
682            Event::Control { opcode, transition_id, .. } => Some((*opcode, *transition_id)),
683            _ => None,
684        }).collect()
685    }
686
687    #[test]
688    fn prepare_transition_sets_pending_on_sender_and_receiver() {
689        let mut coord = GroupNode::new(1, group_id());
690        let mut peer = GroupNode::new(2, group_id());
691        coord.bootstrap_as_creator(0);
692        peer.bootstrap_as_joiner(0, 0);
693        let mut s = PlainSealer;
694        // Coordinator sends PREPARE for tid=1
695        let f = coord.send_control(&mut s, 0, ControlOpcode::PrepareTransition, 1, 100, b"commit-blob".to_vec()).unwrap();
696        assert_eq!(coord.pending_transition_id, 1, "sender mirrors pending");
697        assert_eq!(coord.transition_state, TransitionState::TPrepared);
698        let evs = peer.on_wire(&mut s, &f.wire).unwrap();
699        assert_eq!(peer.pending_transition_id, 1, "receiver records pending");
700        assert!(drain_errs(&evs).is_empty(), "no error: {:?}", drain_errs(&evs));
701        let ctls = drain_controls(&evs);
702        assert_eq!(ctls, vec![(ControlOpcode::PrepareTransition, 1)]);
703    }
704
705    #[test]
706    fn ready_with_wrong_tid_is_rejected() {
707        let mut coord = GroupNode::new(1, group_id());
708        let mut peer = GroupNode::new(2, group_id());
709        coord.bootstrap_as_creator(0);
710        peer.bootstrap_as_joiner(0, 0);
711        let mut s = PlainSealer;
712        let f = coord.send_control(&mut s, 0, ControlOpcode::PrepareTransition, 1, 1, vec![]).unwrap();
713        peer.on_wire(&mut s, &f.wire).unwrap();
714        // Peer fakes a READY for the wrong tid
715        let bogus = peer.send_control(&mut s, 1, ControlOpcode::ReadyForTransition, 7, 1, vec![]).unwrap();
716        let evs = coord.on_wire(&mut s, &bogus.wire).unwrap();
717        let errs = drain_errs(&evs);
718        assert!(errs.contains(&codes::TRANSITION_MISMATCH), "got {:?}", errs);
719    }
720
721    #[test]
722    fn execute_advances_epoch_and_clears_pending() {
723        let mut coord = GroupNode::new(1, group_id());
724        let mut peer = GroupNode::new(2, group_id());
725        coord.bootstrap_as_creator(0);
726        peer.bootstrap_as_joiner(0, 0);
727        let mut s = PlainSealer;
728        let prep = coord.send_control(&mut s, 0, ControlOpcode::PrepareTransition, 1, 1, vec![]).unwrap();
729        peer.on_wire(&mut s, &prep.wire).unwrap();
730        // Coordinator broadcasts EXECUTE; both sides apply (coord locally, peer via on_wire)
731        let exec = coord.send_control(&mut s, 0, ControlOpcode::ExecuteTransition, 1, 2, vec![]).unwrap();
732        coord.apply_transition(1);
733        let evs = peer.on_wire(&mut s, &exec.wire).unwrap();
734        assert_eq!(coord.last_transition_id, 1);
735        assert_eq!(coord.current_epoch, 1);
736        assert_eq!(peer.last_transition_id, 1);
737        assert_eq!(peer.current_epoch, 1);
738        assert_eq!(peer.pending_transition_id, 0);
739        assert!(evs.iter().any(|e| matches!(e, Event::EpochAdvanced { transition_id: 1, .. })));
740    }
741
742    #[test]
743    fn abort_clears_pending_no_advance() {
744        let mut coord = GroupNode::new(1, group_id());
745        let mut peer = GroupNode::new(2, group_id());
746        coord.bootstrap_as_creator(0);
747        peer.bootstrap_as_joiner(0, 0);
748        let mut s = PlainSealer;
749        let prep = coord.send_control(&mut s, 0, ControlOpcode::PrepareTransition, 1, 1, vec![]).unwrap();
750        peer.on_wire(&mut s, &prep.wire).unwrap();
751        let abort = coord.send_control(&mut s, 0, ControlOpcode::AbortTransition, 1, 2, vec![]).unwrap();
752        peer.on_wire(&mut s, &abort.wire).unwrap();
753        assert_eq!(peer.pending_transition_id, 0);
754        assert_eq!(peer.current_epoch, 0);
755        assert_eq!(peer.transition_state, TransitionState::TAborted);
756        assert_eq!(coord.transition_state, TransitionState::TAborted);
757    }
758
759    #[test]
760    fn bootstrap_as_joiner_with_expected_tid_accepts_first_execute() {
761        let mut coord = GroupNode::new(1, group_id());
762        // Joiner pre-arms expected_first_tid=1 — typical post-Welcome state.
763        let mut joiner = GroupNode::new(2, group_id());
764        coord.bootstrap_as_creator(0);
765        joiner.bootstrap_as_joiner(0, 1);
766        assert_eq!(joiner.pending_transition_id, 1);
767        let mut s = PlainSealer;
768        // Coordinator must mirror its pending too — simulate by sending PREPARE
769        let _ = coord.send_control(&mut s, 0, ControlOpcode::PrepareTransition, 1, 1, vec![]).unwrap();
770        // EXECUTE should be accepted by the joiner without ever seeing PREPARE
771        let exec = coord.send_control(&mut s, 0, ControlOpcode::ExecuteTransition, 1, 2, vec![]).unwrap();
772        let evs = joiner.on_wire(&mut s, &exec.wire).unwrap();
773        let errs = drain_errs(&evs);
774        assert!(errs.is_empty(), "expected clean apply, got errors {:?}", errs);
775        assert_eq!(joiner.last_transition_id, 1);
776        assert_eq!(joiner.current_epoch, 1);
777    }
778
779    #[test]
780    fn prepare_with_already_applied_tid_is_rejected() {
781        // After the coordinator has fully applied tid=1, a replay or
782        // late-coordinator PREPARE with the same tid must fail validation.
783        let mut coord = GroupNode::new(1, group_id());
784        coord.bootstrap_as_creator(0);
785        let mut s = PlainSealer;
786        let _ = coord.send_control(&mut s, 0, ControlOpcode::PrepareTransition, 1, 1, vec![]).unwrap();
787        coord.apply_transition(1);
788        assert_eq!(coord.last_transition_id, 1);
789        assert_eq!(coord.pending_transition_id, 0);
790        // Forge a PREPARE with the same already-applied tid (epoch matches
791        // because we synthesise it locally with a peer node on the same
792        // post-apply epoch).
793        let mut peer = GroupNode::new(2, group_id());
794        peer.bootstrap_as_joiner(coord.current_epoch, 0);
795        let stale = peer.send_control(&mut s, 1, ControlOpcode::PrepareTransition, 1, 9, vec![]).unwrap();
796        let evs = coord.on_wire(&mut s, &stale.wire).unwrap();
797        let errs = drain_errs(&evs);
798        assert!(errs.contains(&codes::TRANSITION_MISMATCH), "expected TRANSITION_MISMATCH, got {:?}", errs);
799    }
800
801    #[test]
802    fn decrypt_failed_is_non_fatal() {
803        // Simulate a frame our open() can't unlock: a sealer that fails on `open`.
804        struct OpenFailSealer;
805        impl Sealer for OpenFailSealer {
806            fn seal(&mut self, _: StreamType, _: SequenceNo, p: &[u8]) -> Result<Vec<u8>, MlsError> { Ok(p.to_vec()) }
807            fn open(&mut self, _: StreamType, _: SequenceNo, _: &[u8]) -> Result<Vec<u8>, MlsError> { Err(MlsError::Aead("simulated".into())) }
808        }
809        let mut alice = GroupNode::new(1, group_id());
810        let mut bob = GroupNode::new(2, group_id());
811        alice.bootstrap_as_creator(1);
812        bob.bootstrap_as_joiner(1, 0);
813        let mut s = PlainSealer;
814        let sid = alice.member_stream_id(2);
815        let f = alice.send_payload(&mut s, 2, StreamType::Text, sid, GbpFlags::ordered_reliable_ack(), b"x").unwrap();
816        let mut fail = OpenFailSealer;
817        let evs = bob.on_wire(&mut fail, &f.wire).unwrap();
818        let err = evs.iter().find_map(|e| match e {
819            Event::Error { code, fatal, retryable, .. } => Some((*code, *fatal, *retryable)),
820            _ => None,
821        }).expect("error event");
822        assert_eq!(err.0, codes::DECRYPT_FAILED);
823        assert!(!err.1, "must be non-fatal");
824        assert!(err.2, "must be retryable");
825        assert_eq!(bob.state, NodeState::Active, "bob stays Active");
826    }
827}