use gbp::{CodecError, ControlMessage, ErrorObject, GbpFrame};
use gbp_core::{
ControlOpcode, ErrorClass, GbpFlags, GroupId, MemberId, NodeState, SequenceNo, StreamId,
StreamType, TransitionId, TransitionState, codes,
errors::ErrorSpec,
};
use gbp_mls::{MlsError, label_for};
use std::collections::HashMap;
/// Errors returned directly by [`GroupNode`] operations.
///
/// Problems with *inbound* frames are not reported through this type: `GroupNode::on_wire`
/// queues them as [`Event::Error`] and still returns `Ok`.
#[derive(Debug, thiserror::Error)]
pub enum NodeError {
/// A `gbp` codec error (converted via `From`).
#[error("codec: {0}")]
Codec(#[from] CodecError),
/// An MLS error, e.g. a failing `Sealer::seal` in `GroupNode::send_payload`.
#[error("mls: {0}")]
Mls(#[from] MlsError),
/// The node's current `NodeState` does not permit the operation (see `assert_can_send`).
#[error("invalid state: {0}")]
InvalidState(String),
}
pub struct OutboundFrame {
pub to: MemberId,
pub wire: Vec<u8>,
}
/// A non-control payload that passed all inbound checks and was opened by the `Sealer`.
#[derive(Debug, Clone)]
pub struct DeliveredPayload {
/// Stream type from the frame header (never `StreamType::Control`).
pub stream_type: StreamType,
/// Stream id from the frame header.
pub stream_id: StreamId,
/// Sequence number from the frame header (strictly above the previous one on this stream).
pub sequence_no: SequenceNo,
/// Raw flag bits copied from the frame header.
pub flags: u16,
/// Output of `Sealer::open` for the frame's encrypted payload.
pub plaintext: Vec<u8>,
}
/// Notifications queued by [`GroupNode`], returned from `GroupNode::on_wire` and
/// `GroupNode::drain_events`.
#[derive(Debug, Clone)]
pub enum Event {
/// The node's lifecycle state changed from `from` to `to`.
StateChanged {
from: NodeState,
to: NodeState,
},
/// A non-control frame was accepted and opened.
PayloadReceived(DeliveredPayload),
/// A control message passed the transition-id checks. Emitted after its local effect
/// (pending/transition state, epoch advance, leaving `Resyncing`) has been applied.
Control {
/// `sender_id` carried inside the control message; not cross-checked by the node.
from: MemberId,
opcode: ControlOpcode,
transition_id: TransitionId,
request_id: u32,
/// Raw `args` bytes of the control message; the node does not interpret them.
args: Vec<u8>,
},
/// A protocol error detected while processing inbound data.
Error {
/// Numeric error code (the node emits `gbp_core::codes` constants).
code: u16,
/// For codes known to `ErrorSpec::lookup`, `class`/`retryable`/`fatal` come from the spec.
class: ErrorClass,
retryable: bool,
/// When `true` the node has also been moved to `NodeState::Failed`.
fatal: bool,
/// Human-readable detail.
reason: String,
},
/// A transition was applied: `epoch` is the new epoch, `transition_id` the applied id.
EpochAdvanced {
epoch: u64,
transition_id: TransitionId,
},
}
/// One member's view of a GBP group: lifecycle state, epoch/transition bookkeeping,
/// per-stream sequence numbers and a queue of pending [`Event`]s.
pub struct GroupNode {
/// This node's member id; also stamped as `sender_id` on outgoing control messages.
pub member_id: MemberId,
/// Group this node belongs to; inbound frames for other groups are rejected.
pub group_id: GroupId,
/// Epoch stamped on outgoing frames; inbound frames with a different epoch are rejected
/// and trigger a resync.
pub current_epoch: u64,
/// Id of the most recently applied transition; stamped on outgoing frames.
pub last_transition_id: TransitionId,
/// Id of the transition currently in progress; `0` means none is pending.
pub pending_transition_id: TransitionId,
/// Lifecycle state.
pub state: NodeState,
/// Progress of the current/last transition.
pub transition_state: TransitionState,
/// Last sequence number issued per outbound (stream type, stream id); first issued is 1.
out_seq: HashMap<(StreamType, StreamId), SequenceNo>,
/// Highest sequence number accepted per inbound (stream type, stream id); used for replay
/// rejection (anything `<=` this is rejected).
in_hw: HashMap<(StreamType, StreamId), SequenceNo>,
/// Events queued since the last drain.
events: Vec<Event>,
}
impl GroupNode {
pub fn new(member_id: MemberId, group_id: GroupId) -> Self {
Self {
member_id,
group_id,
current_epoch: 0,
last_transition_id: 0,
pending_transition_id: 0,
state: NodeState::Idle,
transition_state: TransitionState::TIdle,
out_seq: HashMap::new(),
in_hw: HashMap::new(),
events: Vec::new(),
}
}
pub fn bootstrap_as_creator(&mut self, epoch: u64) {
self.transition(NodeState::Connecting);
self.transition(NodeState::EstablishingGroup);
self.current_epoch = epoch;
self.transition(NodeState::Active);
}
pub fn bootstrap_as_joiner(&mut self, epoch: u64, expected_first_tid: u32) {
self.transition(NodeState::Connecting);
self.transition(NodeState::EstablishingGroup);
self.current_epoch = epoch;
if expected_first_tid > 0 {
self.pending_transition_id = expected_first_tid;
self.transition_state = TransitionState::TPrepared;
}
self.transition(NodeState::Active);
}
pub fn drain_events(&mut self) -> Vec<Event> {
std::mem::take(&mut self.events)
}
pub fn member_stream_id(&self, base: u32) -> StreamId {
debug_assert!(self.member_id < 1_000_000, "member_id overflow: {0}", self.member_id);
base + self.member_id * 100
}
pub fn send_payload<S: Sealer>(
&mut self,
seal: &mut S,
target: MemberId,
stream_type: StreamType,
stream_id: StreamId,
flags: u16,
plaintext: &[u8],
) -> Result<OutboundFrame, NodeError> {
self.assert_can_send()?;
let seq = self.next_seq(stream_type, stream_id);
let ciphertext = seal.seal(stream_type, seq, plaintext)?;
let frame = GbpFrame::new(
self.group_id,
self.current_epoch,
self.last_transition_id,
stream_type,
stream_id,
flags,
seq,
ciphertext,
);
Ok(OutboundFrame { to: target, wire: frame.to_cbor() })
}
pub fn send_control<S: Sealer>(
&mut self,
seal: &mut S,
target: MemberId,
opcode: ControlOpcode,
transition_id: TransitionId,
request_id: u32,
args: Vec<u8>,
) -> Result<OutboundFrame, NodeError> {
let ctl = ControlMessage::with_args(
opcode as u16,
request_id,
self.member_id,
transition_id,
args,
);
let mut flags = GbpFlags::ordered_reliable_system();
if matches!(
opcode,
ControlOpcode::PrepareTransition
| ControlOpcode::ReadyForTransition
| ControlOpcode::ExecuteTransition
) {
flags |= GbpFlags::CRITICAL;
}
match opcode {
ControlOpcode::PrepareTransition => {
self.pending_transition_id = transition_id;
self.transition_state = TransitionState::TPrepared;
}
ControlOpcode::AbortTransition => {
self.pending_transition_id = 0;
self.transition_state = TransitionState::TAborted;
}
_ => {}
}
let stream_id = self.member_stream_id(0);
self.send_payload(seal, target, StreamType::Control, stream_id, flags, &ctl.to_cbor())
}
pub fn on_wire<S: Sealer>(
&mut self,
seal: &mut S,
wire: &[u8],
) -> Result<Vec<Event>, NodeError> {
let frame = match GbpFrame::decode(wire) {
Ok(f) => f,
Err(e) => {
self.emit_err_spec(codes::STREAM_POLICY_VIOLATION, format!("frame decode: {e}"));
return Ok(self.drain_events());
}
};
self.deliver_frame(seal, frame)?;
Ok(self.drain_events())
}
fn deliver_frame<S: Sealer>(&mut self, seal: &mut S, frame: GbpFrame) -> Result<(), NodeError> {
if frame.version != 1 {
self.emit_err_spec(codes::UNSUPPORTED_VERSION, "version != 1");
return Ok(());
}
if frame.group_id_array() != self.group_id {
self.emit_err_spec(codes::UNKNOWN_GROUP, "group_id");
return Ok(());
}
if frame.epoch != self.current_epoch {
self.emit_err_spec(
codes::EPOCH_MISMATCH,
format!("got {}, expected {}", frame.epoch, self.current_epoch),
);
self.trigger_resync();
return Ok(());
}
if let Err(e) = frame.validate_payload_size() {
self.emit_err_spec(codes::STREAM_POLICY_VIOLATION, format!("payload size: {e}"));
return Ok(());
}
let flags = GbpFlags::from_bits(frame.flags);
let st = match frame.stream_type_typed() {
Ok(st) => st,
Err(_) => {
self.emit_err_spec(codes::STREAM_POLICY_VIOLATION, "unknown stream_type");
return Ok(());
}
};
if st != StreamType::Control
&& flags.has(GbpFlags::CRITICAL)
&& frame.transition_id != self.last_transition_id
{
self.emit_err_spec(
codes::TRANSITION_MISMATCH,
format!("got tid={}, expected {}", frame.transition_id, self.last_transition_id),
);
return Ok(());
}
let key = (st, frame.stream_id);
let hw = self.in_hw.get(&key).copied().unwrap_or(0);
if frame.sequence_no <= hw {
self.emit_err_spec(
codes::REPLAY_DETECTED,
format!(
"st={} sid={} seq={} hw={}",
st, frame.stream_id, frame.sequence_no, hw
),
);
return Ok(());
}
self.in_hw.insert(key, frame.sequence_no);
let plain = match seal.open(st, frame.sequence_no, &frame.encrypted_payload) {
Ok(p) => p,
Err(e) => {
self.emit_err_named(
codes::DECRYPT_FAILED,
ErrorClass::Crypto,
true, false, format!("aead open: {e}"),
);
return Ok(());
}
};
match st {
StreamType::Control => self.handle_control(plain),
other => self.events.push(Event::PayloadReceived(DeliveredPayload {
stream_type: other,
stream_id: frame.stream_id,
sequence_no: frame.sequence_no,
flags: frame.flags,
plaintext: plain,
})),
}
Ok(())
}
fn handle_control(&mut self, plain: Vec<u8>) {
let c = match ControlMessage::from_cbor(&plain) {
Ok(c) => c,
Err(_) => {
self.emit_err_spec(codes::STREAM_POLICY_VIOLATION, "control decode");
return;
}
};
let opcode = match ControlOpcode::try_from(c.opcode) {
Ok(op) => op,
Err(_) => {
self.emit_err_spec(codes::STREAM_POLICY_VIOLATION, "unknown opcode");
return;
}
};
let tid_ok = match opcode {
ControlOpcode::PrepareTransition => {
c.transition_id > self.last_transition_id
&& (self.pending_transition_id == 0
|| self.pending_transition_id == c.transition_id)
}
ControlOpcode::ReadyForTransition
| ControlOpcode::ExecuteTransition
| ControlOpcode::AbortTransition => {
self.pending_transition_id != 0
&& c.transition_id == self.pending_transition_id
}
_ => true,
};
if !tid_ok {
self.emit_err_spec(
codes::TRANSITION_MISMATCH,
format!(
"control tid={} not valid for {:?} (last={}, pending={})",
c.transition_id, opcode, self.last_transition_id, self.pending_transition_id
),
);
return;
}
match opcode {
ControlOpcode::PrepareTransition => {
self.pending_transition_id = c.transition_id;
self.transition_state = TransitionState::TPrepared;
}
ControlOpcode::ReadyForTransition => {
self.transition_state = TransitionState::TReady;
}
ControlOpcode::ExecuteTransition => {
self.apply_transition(c.transition_id);
}
ControlOpcode::AbortTransition => {
self.transition_state = TransitionState::TAborted;
self.pending_transition_id = 0;
}
ControlOpcode::GroupStateDigestResponse => {
if self.state == NodeState::Resyncing {
self.transition(NodeState::Active);
}
}
_ => {}
}
self.events.push(Event::Control {
from: c.sender_id,
opcode,
transition_id: c.transition_id,
request_id: c.request_id,
args: c.args.to_vec(),
});
}
pub fn apply_transition(&mut self, tid: TransitionId) {
self.current_epoch += 1;
self.last_transition_id = tid;
self.pending_transition_id = 0;
self.transition_state = TransitionState::TExecuted;
self.out_seq.clear();
self.in_hw.clear();
self.events.push(Event::EpochAdvanced {
epoch: self.current_epoch,
transition_id: tid,
});
}
pub fn trigger_resync(&mut self) {
if self.state != NodeState::Resyncing {
self.transition(NodeState::Resyncing);
}
}
fn transition(&mut self, next: NodeState) {
if self.state == next {
return;
}
if !self.state.can_transition_to(next) {
let from = self.state;
self.state = NodeState::Failed;
self.events.push(Event::StateChanged { from, to: NodeState::Failed });
return;
}
let from = self.state;
self.state = next;
self.events.push(Event::StateChanged { from, to: next });
}
fn assert_can_send(&self) -> Result<(), NodeError> {
if matches!(
self.state,
NodeState::Active | NodeState::Resyncing | NodeState::EstablishingGroup
) {
Ok(())
} else {
Err(NodeError::InvalidState(format!("cannot send in state {}", self.state)))
}
}
fn next_seq(&mut self, st: StreamType, sid: StreamId) -> SequenceNo {
let entry = self.out_seq.entry((st, sid)).or_insert(0);
*entry += 1;
*entry
}
fn emit_err_spec(&mut self, code: u16, reason: impl Into<String>) {
if let Some(spec) = ErrorSpec::lookup(code) {
self.emit_err_named(spec.code, spec.class, spec.retryable, spec.fatal, reason);
} else {
self.emit_err_named(code, ErrorClass::Policy, false, false, reason);
}
}
fn emit_err_named(
&mut self,
code: u16,
class: ErrorClass,
retryable: bool,
fatal: bool,
reason: impl Into<String>,
) {
let reason = reason.into();
let (class, retryable, fatal) = if let Some(spec) = ErrorSpec::lookup(code) {
(spec.class, spec.retryable, spec.fatal)
} else {
(class, retryable, fatal)
};
let _ = ErrorObject::new(code, class, retryable, fatal, reason.clone()).to_cbor();
self.events.push(Event::Error { code, class, retryable, fatal, reason });
if fatal {
let from = self.state;
self.state = NodeState::Failed;
self.events.push(Event::StateChanged { from, to: NodeState::Failed });
}
}
}
/// Per-frame payload protection used by [`GroupNode`].
///
/// `seal` is applied to outbound payloads in `GroupNode::send_payload`; `open` to inbound
/// payloads in `GroupNode::on_wire`. Both receive the frame's stream type and sequence
/// number. An `MlsError` from `seal` surfaces as `NodeError::Mls`. An `MlsError` from `open`
/// is reported as a non-fatal `DECRYPT_FAILED` event ("aead open: ...").
pub trait Sealer {
/// Protects `pt` for stream type `st` at sequence number `seq`
/// (presumably AEAD encryption — the node labels open failures "aead open").
fn seal(&mut self, st: StreamType, seq: SequenceNo, pt: &[u8]) -> Result<Vec<u8>, MlsError>;
/// Reverses [`Sealer::seal`] for the same `st` and `seq`, returning the plaintext.
fn open(&mut self, st: StreamType, seq: SequenceNo, ct: &[u8]) -> Result<Vec<u8>, MlsError>;
}
/// Adapts [`gbp_mls::MlsContext`] to [`Sealer`].
///
/// The MLS label is derived from the stream type via [`label_for`], and the call is
/// forwarded to the context's inherent `seal` / `open`. The fully-qualified call path
/// selects the inherent method rather than recursing into this trait impl.
impl Sealer for gbp_mls::MlsContext {
    fn seal(
        &mut self,
        stream_type: StreamType,
        sequence_no: SequenceNo,
        plaintext: &[u8],
    ) -> Result<Vec<u8>, MlsError> {
        let label = label_for(stream_type);
        gbp_mls::MlsContext::seal(self, label, sequence_no, plaintext)
    }

    fn open(
        &mut self,
        stream_type: StreamType,
        sequence_no: SequenceNo,
        ciphertext: &[u8],
    ) -> Result<Vec<u8>, MlsError> {
        let label = label_for(stream_type);
        gbp_mls::MlsContext::open(self, label, sequence_no, ciphertext)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pass-through sealer: "ciphertext" equals plaintext, so payloads stay inspectable.
    struct PlainSealer;

    impl Sealer for PlainSealer {
        fn seal(&mut self, _st: StreamType, _seq: SequenceNo, pt: &[u8]) -> Result<Vec<u8>, MlsError> {
            Ok(pt.to_vec())
        }

        fn open(&mut self, _st: StreamType, _seq: SequenceNo, ct: &[u8]) -> Result<Vec<u8>, MlsError> {
            Ok(ct.to_vec())
        }
    }

    /// Fixed test group id: `b"GBP"` followed by zero bytes.
    fn group_id() -> GroupId {
        let mut id = [0u8; 16];
        id[..3].copy_from_slice(b"GBP");
        id
    }

    /// Member 1 bootstrapped as creator and member 2 as joiner, both at `epoch`.
    fn bootstrapped_pair(epoch: u64, joiner_first_tid: u32) -> (GroupNode, GroupNode) {
        let mut creator = GroupNode::new(1, group_id());
        let mut joiner = GroupNode::new(2, group_id());
        creator.bootstrap_as_creator(epoch);
        joiner.bootstrap_as_joiner(epoch, joiner_first_tid);
        (creator, joiner)
    }

    /// Sends `body` from `sender` to member 2 on the sender's Text stream (base 2).
    fn send_text(sender: &mut GroupNode, seal: &mut PlainSealer, body: &[u8]) -> OutboundFrame {
        let stream_id = sender.member_stream_id(2);
        sender
            .send_payload(seal, 2, StreamType::Text, stream_id, GbpFlags::ordered_reliable_ack(), body)
            .unwrap()
    }

    /// Codes of every `Event::Error` in `events`, in order.
    fn error_codes(events: &[Event]) -> Vec<u16> {
        let mut seen = Vec::new();
        for event in events {
            if let Event::Error { code, .. } = event {
                seen.push(*code);
            }
        }
        seen
    }

    /// `(opcode, transition_id)` of every `Event::Control` in `events`, in order.
    fn control_ops(events: &[Event]) -> Vec<(ControlOpcode, TransitionId)> {
        events
            .iter()
            .filter_map(|event| match event {
                Event::Control { opcode, transition_id, .. } => Some((*opcode, *transition_id)),
                _ => None,
            })
            .collect()
    }

    #[test]
    fn replay_window_rejects_repeat() {
        let (mut alice, mut bob) = bootstrapped_pair(1, 0);
        let mut sealer = PlainSealer;
        let frame = send_text(&mut alice, &mut sealer, b"hi");
        let _ = bob.on_wire(&mut sealer, &frame.wire).unwrap();
        let replayed = bob.on_wire(&mut sealer, &frame.wire).unwrap();
        assert!(error_codes(&replayed).contains(&codes::REPLAY_DETECTED));
    }

    #[test]
    fn epoch_mismatch_triggers_resync() {
        let (mut alice, mut bob) = bootstrapped_pair(1, 0);
        alice.current_epoch = 2;
        let mut sealer = PlainSealer;
        let frame = send_text(&mut alice, &mut sealer, b"x");
        let _ = bob.on_wire(&mut sealer, &frame.wire).unwrap();
        assert_eq!(bob.state, NodeState::Resyncing);
    }

    #[test]
    fn payload_emits_received_event() {
        let (mut alice, mut bob) = bootstrapped_pair(1, 0);
        let mut sealer = PlainSealer;
        let frame = send_text(&mut alice, &mut sealer, b"payload");
        let events = bob.on_wire(&mut sealer, &frame.wire).unwrap();
        let received = events
            .into_iter()
            .find_map(|event| {
                if let Event::PayloadReceived(p) = event {
                    Some(p)
                } else {
                    None
                }
            })
            .expect("payload");
        assert_eq!(received.stream_type, StreamType::Text);
        assert_eq!(received.plaintext, b"payload");
    }

    #[test]
    fn prepare_transition_sets_pending_on_sender_and_receiver() {
        let (mut coord, mut peer) = bootstrapped_pair(0, 0);
        let mut sealer = PlainSealer;
        let prepare = coord
            .send_control(&mut sealer, 0, ControlOpcode::PrepareTransition, 1, 100, b"commit-blob".to_vec())
            .unwrap();
        assert_eq!(coord.pending_transition_id, 1, "sender mirrors pending");
        assert_eq!(coord.transition_state, TransitionState::TPrepared);

        let events = peer.on_wire(&mut sealer, &prepare.wire).unwrap();
        assert_eq!(peer.pending_transition_id, 1, "receiver records pending");
        let errors = error_codes(&events);
        assert!(errors.is_empty(), "no error: {:?}", errors);
        assert_eq!(control_ops(&events), vec![(ControlOpcode::PrepareTransition, 1)]);
    }

    #[test]
    fn ready_with_wrong_tid_is_rejected() {
        let (mut coord, mut peer) = bootstrapped_pair(0, 0);
        let mut sealer = PlainSealer;
        let prepare = coord
            .send_control(&mut sealer, 0, ControlOpcode::PrepareTransition, 1, 1, vec![])
            .unwrap();
        peer.on_wire(&mut sealer, &prepare.wire).unwrap();

        let bogus = peer
            .send_control(&mut sealer, 1, ControlOpcode::ReadyForTransition, 7, 1, vec![])
            .unwrap();
        let errors = error_codes(&coord.on_wire(&mut sealer, &bogus.wire).unwrap());
        assert!(errors.contains(&codes::TRANSITION_MISMATCH), "got {:?}", errors);
    }

    #[test]
    fn execute_advances_epoch_and_clears_pending() {
        let (mut coord, mut peer) = bootstrapped_pair(0, 0);
        let mut sealer = PlainSealer;
        let prepare = coord
            .send_control(&mut sealer, 0, ControlOpcode::PrepareTransition, 1, 1, vec![])
            .unwrap();
        peer.on_wire(&mut sealer, &prepare.wire).unwrap();

        let execute = coord
            .send_control(&mut sealer, 0, ControlOpcode::ExecuteTransition, 1, 2, vec![])
            .unwrap();
        coord.apply_transition(1);
        let events = peer.on_wire(&mut sealer, &execute.wire).unwrap();

        assert_eq!(coord.last_transition_id, 1);
        assert_eq!(coord.current_epoch, 1);
        assert_eq!(peer.last_transition_id, 1);
        assert_eq!(peer.current_epoch, 1);
        assert_eq!(peer.pending_transition_id, 0);
        assert!(events
            .iter()
            .any(|event| matches!(event, Event::EpochAdvanced { transition_id: 1, .. })));
    }

    #[test]
    fn abort_clears_pending_no_advance() {
        let (mut coord, mut peer) = bootstrapped_pair(0, 0);
        let mut sealer = PlainSealer;
        let prepare = coord
            .send_control(&mut sealer, 0, ControlOpcode::PrepareTransition, 1, 1, vec![])
            .unwrap();
        peer.on_wire(&mut sealer, &prepare.wire).unwrap();

        let abort = coord
            .send_control(&mut sealer, 0, ControlOpcode::AbortTransition, 1, 2, vec![])
            .unwrap();
        peer.on_wire(&mut sealer, &abort.wire).unwrap();

        assert_eq!(peer.pending_transition_id, 0);
        assert_eq!(peer.current_epoch, 0);
        assert_eq!(peer.transition_state, TransitionState::TAborted);
        assert_eq!(coord.transition_state, TransitionState::TAborted);
    }

    #[test]
    fn bootstrap_as_joiner_with_expected_tid_accepts_first_execute() {
        let (mut coord, mut joiner) = bootstrapped_pair(0, 1);
        assert_eq!(joiner.pending_transition_id, 1);
        let mut sealer = PlainSealer;
        // The joiner never sees this Prepare; it relies on the bootstrap-provided tid.
        let _ = coord
            .send_control(&mut sealer, 0, ControlOpcode::PrepareTransition, 1, 1, vec![])
            .unwrap();
        let execute = coord
            .send_control(&mut sealer, 0, ControlOpcode::ExecuteTransition, 1, 2, vec![])
            .unwrap();

        let errors = error_codes(&joiner.on_wire(&mut sealer, &execute.wire).unwrap());
        assert!(errors.is_empty(), "expected clean apply, got errors {:?}", errors);
        assert_eq!(joiner.last_transition_id, 1);
        assert_eq!(joiner.current_epoch, 1);
    }

    #[test]
    fn prepare_with_already_applied_tid_is_rejected() {
        let mut coord = GroupNode::new(1, group_id());
        coord.bootstrap_as_creator(0);
        let mut sealer = PlainSealer;
        let _ = coord
            .send_control(&mut sealer, 0, ControlOpcode::PrepareTransition, 1, 1, vec![])
            .unwrap();
        coord.apply_transition(1);
        assert_eq!(coord.last_transition_id, 1);
        assert_eq!(coord.pending_transition_id, 0);

        let mut peer = GroupNode::new(2, group_id());
        peer.bootstrap_as_joiner(coord.current_epoch, 0);
        let stale = peer
            .send_control(&mut sealer, 1, ControlOpcode::PrepareTransition, 1, 9, vec![])
            .unwrap();
        let errors = error_codes(&coord.on_wire(&mut sealer, &stale.wire).unwrap());
        assert!(errors.contains(&codes::TRANSITION_MISMATCH), "expected TRANSITION_MISMATCH, got {:?}", errors);
    }

    #[test]
    fn decrypt_failed_is_non_fatal() {
        /// Sealer whose `open` always fails, simulating an AEAD authentication failure.
        struct OpenFailSealer;

        impl Sealer for OpenFailSealer {
            fn seal(&mut self, _: StreamType, _: SequenceNo, p: &[u8]) -> Result<Vec<u8>, MlsError> {
                Ok(p.to_vec())
            }

            fn open(&mut self, _: StreamType, _: SequenceNo, _: &[u8]) -> Result<Vec<u8>, MlsError> {
                Err(MlsError::Aead("simulated".into()))
            }
        }

        let (mut alice, mut bob) = bootstrapped_pair(1, 0);
        let mut sealer = PlainSealer;
        let frame = send_text(&mut alice, &mut sealer, b"x");

        let mut failing = OpenFailSealer;
        let events = bob.on_wire(&mut failing, &frame.wire).unwrap();
        let (code, fatal, retryable) = events
            .iter()
            .find_map(|event| match event {
                Event::Error { code, fatal, retryable, .. } => Some((*code, *fatal, *retryable)),
                _ => None,
            })
            .expect("error event");
        assert_eq!(code, codes::DECRYPT_FAILED);
        assert!(!fatal, "must be non-fatal");
        assert!(retryable, "must be retryable");
        assert_eq!(bob.state, NodeState::Active, "bob stays Active");
    }
}