use crate::{GspSignal, args::validate_args};
use gbp::CodecError;
use gbp_core::{BoundedSeen, GbpFlags, MemberId, PayloadCodec, SignalType, StreamType};
use gbp_node::{GroupNode, NodeError, OutboundFrame, Sealer};
use std::collections::HashSet;
/// Errors returned by [`GspClient`] when sending or accepting signals.
#[derive(Debug, thiserror::Error)]
pub enum GspError {
    /// The plaintext could not be decoded with the requested payload codec.
    #[error("decode: {0}")]
    Decode(#[from] CodecError),
    /// The decoded `signal_type` value does not map to a known `SignalType`.
    #[error("unknown signal_type: {0}")]
    UnknownSignal(u32),
    /// The `request_id` was already recorded by the client's dedup window.
    #[error("duplicate request_id: {0}")]
    DuplicateRequest(u32),
    /// The signal's args were rejected; the payload is a static description.
    #[error("bad args schema: {0}")]
    BadSchema(&'static str),
    /// The underlying group node reported an error.
    #[error("node: {0}")]
    Node(#[from] NodeError),
}
/// Summary of a signal that passed decoding, validation and duplicate checks
/// in [`GspClient::accept`].
///
/// Derives `PartialEq` so callers and tests can compare results directly.
#[derive(Debug, Clone, PartialEq)]
pub struct GspAccept {
    /// The decoded signal kind.
    pub signal: SignalType,
    /// Member id carried in the signal's `sender_id` field.
    pub sender_id: MemberId,
    /// The role claim carried in the signal, passed through unchanged.
    pub role_claim: u32,
    /// The request id carried in the signal.
    pub request_id: u32,
}
/// Capacity handed to the request-id dedup window (`BoundedSeen::new`).
const GSP_SEEN_CAP: usize = 10_000;

/// Client-side state for the group signalling protocol.
///
/// Tracks which members are present and muted, plus a per-epoch window of
/// request ids used to reject duplicate signals.
pub struct GspClient {
    /// Member ids added by `Join` signals and removed by `Leave` signals.
    pub members: HashSet<MemberId>,
    /// Member ids added by `Mute` signals and removed by `Unmute` or `Leave`.
    pub muted: HashSet<MemberId>,
    /// Request ids already accepted in the current epoch.
    seen_requests: BoundedSeen<u32>,
    /// Epoch the dedup window belongs to; `None` until first synchronised.
    current_epoch: Option<u64>,
}
impl GspClient {
pub fn new() -> Self {
Self {
seen_requests: BoundedSeen::new(GSP_SEEN_CAP),
muted: HashSet::new(),
members: HashSet::new(),
current_epoch: None,
}
}
pub fn send<S: Sealer>(
&mut self,
node: &mut GroupNode,
seal: &mut S,
target: MemberId,
signal: SignalType,
role_claim: u32,
request_id: u32,
codec: PayloadCodec,
) -> Result<OutboundFrame, GspError> {
self.send_with_args(node, seal, target, signal, role_claim, request_id, &[], codec)
}
pub fn send_with_args<S: Sealer>(
&mut self,
node: &mut GroupNode,
seal: &mut S,
target: MemberId,
signal: SignalType,
role_claim: u32,
request_id: u32,
args: &[u8],
codec: PayloadCodec,
) -> Result<OutboundFrame, GspError> {
self.sync_epoch(node.current_epoch);
let mut sig = GspSignal::bare(signal as u32, request_id, node.member_id);
sig.role_claim = role_claim;
sig.args = serde_bytes::ByteBuf::from(args.to_vec());
sig.args_length = args.len() as u32;
let stream_id = node.member_stream_id(3);
Ok(node.send_payload(
seal,
target,
StreamType::Signal,
stream_id,
GbpFlags::ordered_reliable_ack(),
&sig.to_bytes(codec),
codec,
)?)
}
pub fn accept(
&mut self,
plaintext: &[u8],
current_epoch: u64,
codec: PayloadCodec,
) -> Result<GspAccept, GspError> {
self.sync_epoch(current_epoch);
let s = GspSignal::from_bytes(plaintext, codec)?;
let signal = SignalType::try_from(s.signal_type).map_err(GspError::UnknownSignal)?;
validate_args(signal, &s.args).map_err(GspError::BadSchema)?;
if !self.seen_requests.insert(s.request_id) {
return Err(GspError::DuplicateRequest(s.request_id));
}
match signal {
SignalType::Join => {
self.members.insert(s.sender_id);
}
SignalType::Leave => {
self.members.remove(&s.sender_id);
self.muted.remove(&s.sender_id);
}
SignalType::Mute => {
self.muted.insert(s.sender_id);
}
SignalType::Unmute => {
self.muted.remove(&s.sender_id);
}
_ => {}
}
Ok(GspAccept {
signal,
sender_id: s.sender_id,
role_claim: s.role_claim,
request_id: s.request_id,
})
}
pub fn sync_epoch(&mut self, epoch: u64) {
if Some(epoch) != self.current_epoch {
self.seen_requests.clear();
self.current_epoch = Some(epoch);
}
}
pub fn reset(&mut self) {
self.seen_requests.clear();
self.current_epoch = None;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::GspSignal;

    /// Builds a CBOR-encoded signal with no args.
    fn encode_bare(signal: SignalType, request_id: u32, sender_id: u32) -> Vec<u8> {
        GspSignal::bare(signal as u32, request_id, sender_id).to_cbor()
    }

    #[test]
    fn join_adds_sender_to_members() {
        let mut c = GspClient::new();
        let payload = encode_bare(SignalType::Join, 1, 42);
        let accept = c.accept(&payload, 0, PayloadCodec::Cbor).unwrap();
        assert_eq!(accept.signal, SignalType::Join);
        assert_eq!(accept.sender_id, 42);
        assert_eq!(accept.request_id, 1);
        assert!(c.members.contains(&42));
    }

    #[test]
    fn leave_removes_sender_from_members() {
        let mut c = GspClient::new();
        c.accept(&encode_bare(SignalType::Join, 1, 7), 0, PayloadCodec::Cbor).unwrap();
        c.accept(&encode_bare(SignalType::Leave, 2, 7), 0, PayloadCodec::Cbor).unwrap();
        assert!(!c.members.contains(&7));
    }

    #[test]
    fn leave_also_removes_from_muted() {
        let mut c = GspClient::new();
        c.accept(&encode_bare(SignalType::Join, 1, 5), 0, PayloadCodec::Cbor).unwrap();
        c.muted.insert(5);
        c.accept(&encode_bare(SignalType::Leave, 2, 5), 0, PayloadCodec::Cbor).unwrap();
        assert!(!c.muted.contains(&5));
    }

    #[test]
    fn duplicate_request_id_is_rejected() {
        let mut c = GspClient::new();
        c.accept(&encode_bare(SignalType::Join, 99, 1), 0, PayloadCodec::Cbor).unwrap();
        let result = c.accept(&encode_bare(SignalType::Leave, 99, 1), 0, PayloadCodec::Cbor);
        assert!(matches!(result, Err(GspError::DuplicateRequest(99))));
        // The rejected Leave must not have been applied.
        assert!(c.members.contains(&1));
    }

    #[test]
    fn same_epoch_keeps_request_seen_set() {
        let mut c = GspClient::new();
        c.accept(&encode_bare(SignalType::Join, 1, 10), 5, PayloadCodec::Cbor).unwrap();
        let result = c.accept(&encode_bare(SignalType::Join, 1, 10), 5, PayloadCodec::Cbor);
        assert!(matches!(result, Err(GspError::DuplicateRequest(1))));
    }

    #[test]
    fn epoch_advance_clears_request_seen_set() {
        let mut c = GspClient::new();
        let payload = encode_bare(SignalType::Join, 1, 10);
        c.accept(&payload, 0, PayloadCodec::Cbor).unwrap();
        let result = c.accept(&encode_bare(SignalType::Leave, 1, 10), 1, PayloadCodec::Cbor);
        assert!(result.is_ok());
        assert!(!c.members.contains(&10));
    }

    #[test]
    fn reset_clears_state() {
        let mut c = GspClient::new();
        c.accept(&encode_bare(SignalType::Join, 1, 3), 0, PayloadCodec::Cbor).unwrap();
        c.reset();
        // Same request id in the same epoch is accepted only because reset
        // forgot the previously seen id.
        let result = c.accept(&encode_bare(SignalType::Join, 1, 4), 0, PayloadCodec::Cbor);
        assert!(result.is_ok());
    }

    #[test]
    fn unknown_signal_type_rejected() {
        let mut c = GspClient::new();
        let bad = GspSignal::bare(999, 1, 1).to_cbor();
        assert!(matches!(
            c.accept(&bad, 0, PayloadCodec::Cbor),
            Err(GspError::UnknownSignal(999))
        ));
    }

    #[test]
    fn invalid_cbor_returns_decode_error() {
        let mut c = GspClient::new();
        assert!(matches!(
            c.accept(b"\xFF\xFF", 0, PayloadCodec::Cbor),
            Err(GspError::Decode(_))
        ));
    }

    #[test]
    fn multiple_members_join_independently() {
        let mut c = GspClient::new();
        c.accept(&encode_bare(SignalType::Join, 1, 10), 0, PayloadCodec::Cbor).unwrap();
        c.accept(&encode_bare(SignalType::Join, 2, 20), 0, PayloadCodec::Cbor).unwrap();
        c.accept(&encode_bare(SignalType::Join, 3, 30), 0, PayloadCodec::Cbor).unwrap();
        assert_eq!(c.members.len(), 3);
        assert!(c.members.contains(&10));
        assert!(c.members.contains(&20));
        assert!(c.members.contains(&30));
    }
}