1use crate::{GspSignal, args::validate_args};
4use gbp::CodecError;
5use gbp_core::{BoundedSeen, GbpFlags, MemberId, SignalType, StreamType};
6use gbp_node::{GroupNode, NodeError, OutboundFrame, Sealer};
7use std::collections::HashSet;
8
9#[derive(Debug, thiserror::Error)]
11pub enum GspError {
12 #[error("decode: {0}")]
14 Decode(#[from] CodecError),
15 #[error("unknown signal_type: {0}")]
17 UnknownSignal(u32),
18 #[error("duplicate request_id: {0}")]
20 DuplicateRequest(u32),
21 #[error("bad args schema: {0}")]
23 BadSchema(&'static str),
24 #[error("node: {0}")]
26 Node(#[from] NodeError),
27}
28
29#[derive(Debug, Clone)]
32pub struct GspAccept {
33 pub signal: SignalType,
35 pub sender_id: MemberId,
37 pub role_claim: u32,
39 pub request_id: u32,
41}
42
/// Capacity of the `BoundedSeen` request-id window created by `GspClient::new`.
const GSP_SEEN_CAP: usize = 10000;
45
46pub struct GspClient {
58 seen_requests: BoundedSeen<u32>,
59 pub muted: HashSet<MemberId>,
61 pub members: HashSet<MemberId>,
63 current_epoch: Option<u64>,
64}
65
66impl GspClient {
67 pub fn new() -> Self {
69 Self {
70 seen_requests: BoundedSeen::new(GSP_SEEN_CAP),
71 muted: HashSet::new(),
72 members: HashSet::new(),
73 current_epoch: None,
74 }
75 }
76
77 pub fn send<S: Sealer>(
79 &mut self,
80 node: &mut GroupNode,
81 seal: &mut S,
82 target: MemberId,
83 signal: SignalType,
84 role_claim: u32,
85 request_id: u32,
86 ) -> Result<OutboundFrame, GspError> {
87 self.send_with_args(node, seal, target, signal, role_claim, request_id, &[])
88 }
89
90 pub fn send_with_args<S: Sealer>(
94 &mut self,
95 node: &mut GroupNode,
96 seal: &mut S,
97 target: MemberId,
98 signal: SignalType,
99 role_claim: u32,
100 request_id: u32,
101 args: &[u8],
102 ) -> Result<OutboundFrame, GspError> {
103 self.sync_epoch(node.current_epoch);
104 let mut sig = GspSignal::bare(signal as u32, request_id, node.member_id);
105 sig.role_claim = role_claim;
106 sig.args = serde_bytes::ByteBuf::from(args.to_vec());
107 sig.args_length = args.len() as u32;
108 let stream_id = node.member_stream_id(3);
109 Ok(node.send_payload(
110 seal,
111 target,
112 StreamType::Signal,
113 stream_id,
114 GbpFlags::ordered_reliable_ack(),
115 &sig.to_cbor(),
116 )?)
117 }
118
119 pub fn accept(&mut self, plaintext: &[u8], current_epoch: u64) -> Result<GspAccept, GspError> {
126 self.sync_epoch(current_epoch);
127 let s = GspSignal::from_cbor(plaintext)?;
128 let signal = SignalType::try_from(s.signal_type).map_err(GspError::UnknownSignal)?;
129 validate_args(signal, &s.args).map_err(GspError::BadSchema)?;
131 if !self.seen_requests.insert(s.request_id) {
132 return Err(GspError::DuplicateRequest(s.request_id));
133 }
134 match signal {
135 SignalType::Join => {
136 self.members.insert(s.sender_id);
137 }
138 SignalType::Leave => {
139 self.members.remove(&s.sender_id);
140 self.muted.remove(&s.sender_id);
141 }
142 SignalType::Mute => {
143 self.muted.insert(s.sender_id);
144 }
145 SignalType::Unmute => {
146 self.muted.remove(&s.sender_id);
147 }
148 _ => {}
149 }
150 Ok(GspAccept {
151 signal,
152 sender_id: s.sender_id,
153 role_claim: s.role_claim,
154 request_id: s.request_id,
155 })
156 }
157
158 pub fn sync_epoch(&mut self, epoch: u64) {
162 if Some(epoch) != self.current_epoch {
163 self.seen_requests.clear();
164 self.current_epoch = Some(epoch);
165 }
166 }
167
168 pub fn reset(&mut self) {
170 self.seen_requests.clear();
171 self.current_epoch = None;
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use crate::GspSignal;

    /// Encodes an argument-less signal as CBOR, ready to feed to `GspClient::accept`.
    fn encode_bare(signal: SignalType, request_id: u32, sender_id: u32) -> Vec<u8> {
        GspSignal::bare(signal as u32, request_id, sender_id).to_cbor()
    }

    #[test]
    fn join_adds_sender_to_members() {
        let mut c = GspClient::new();
        let payload = encode_bare(SignalType::Join, 1, 42);
        let accept = c.accept(&payload, 0).unwrap();
        assert_eq!(accept.signal, SignalType::Join);
        assert_eq!(accept.sender_id, 42);
        assert_eq!(accept.request_id, 1);
        assert!(c.members.contains(&42));
    }

    #[test]
    fn leave_removes_sender_from_members() {
        let mut c = GspClient::new();
        c.accept(&encode_bare(SignalType::Join, 1, 7), 0).unwrap();
        c.accept(&encode_bare(SignalType::Leave, 2, 7), 0).unwrap();
        assert!(!c.members.contains(&7));
    }

    #[test]
    fn leave_also_removes_from_muted() {
        let mut c = GspClient::new();
        c.accept(&encode_bare(SignalType::Join, 1, 5), 0).unwrap();
        c.muted.insert(5);
        c.accept(&encode_bare(SignalType::Leave, 2, 5), 0).unwrap();
        assert!(!c.muted.contains(&5));
    }

    #[test]
    fn duplicate_request_id_is_rejected() {
        let mut c = GspClient::new();
        c.accept(&encode_bare(SignalType::Join, 99, 1), 0).unwrap();
        let result = c.accept(&encode_bare(SignalType::Leave, 99, 1), 0);
        assert!(matches!(result, Err(GspError::DuplicateRequest(99))));
        // A rejected duplicate must not apply its side effects.
        assert!(c.members.contains(&1));
    }

    #[test]
    fn same_epoch_keeps_request_seen_set() {
        let mut c = GspClient::new();
        c.accept(&encode_bare(SignalType::Join, 1, 10), 5).unwrap();
        let result = c.accept(&encode_bare(SignalType::Leave, 1, 10), 5);
        assert!(matches!(result, Err(GspError::DuplicateRequest(1))));
    }

    #[test]
    fn epoch_advance_clears_request_seen_set() {
        let mut c = GspClient::new();
        let payload = encode_bare(SignalType::Join, 1, 10);
        c.accept(&payload, 0).unwrap();
        let result = c.accept(&encode_bare(SignalType::Leave, 1, 10), 1);
        assert!(result.is_ok());
    }

    #[test]
    fn reset_clears_state() {
        let mut c = GspClient::new();
        c.accept(&encode_bare(SignalType::Join, 1, 3), 0).unwrap();
        c.reset();
        // Same request id in the same epoch is accepted again after a reset.
        let accept = c.accept(&encode_bare(SignalType::Join, 1, 4), 0).unwrap();
        assert_eq!(accept.request_id, 1);
        assert!(c.members.contains(&4));
    }

    #[test]
    fn unknown_signal_type_rejected() {
        let mut c = GspClient::new();
        let bad = GspSignal::bare(999, 1, 1).to_cbor();
        assert!(matches!(
            c.accept(&bad, 0),
            Err(GspError::UnknownSignal(999))
        ));
    }

    #[test]
    fn rejected_signal_does_not_consume_request_id() {
        let mut c = GspClient::new();
        let bad = GspSignal::bare(999, 7, 1).to_cbor();
        assert!(matches!(c.accept(&bad, 0), Err(GspError::UnknownSignal(999))));
        assert!(c.accept(&encode_bare(SignalType::Join, 7, 1), 0).is_ok());
    }

    #[test]
    fn invalid_cbor_returns_decode_error() {
        let mut c = GspClient::new();
        assert!(matches!(c.accept(b"\xFF\xFF", 0), Err(GspError::Decode(_))));
    }

    #[test]
    fn multiple_members_join_independently() {
        let mut c = GspClient::new();
        c.accept(&encode_bare(SignalType::Join, 1, 10), 0).unwrap();
        c.accept(&encode_bare(SignalType::Join, 2, 20), 0).unwrap();
        c.accept(&encode_bare(SignalType::Join, 3, 30), 0).unwrap();
        assert_eq!(c.members.len(), 3);
        assert!(c.members.contains(&10));
        assert!(c.members.contains(&20));
        assert!(c.members.contains(&30));
    }
}