1use crate::{GspSignal, args::validate_args};
4use gbp::CodecError;
5use gbp_core::{BoundedSeen, GbpFlags, MemberId, PayloadCodec, SignalType, StreamType};
6use gbp_node::{GroupNode, NodeError, OutboundFrame, Sealer};
7use std::collections::HashSet;
8
9#[derive(Debug, thiserror::Error)]
11pub enum GspError {
12 #[error("decode: {0}")]
14 Decode(#[from] CodecError),
15 #[error("unknown signal_type: {0}")]
17 UnknownSignal(u32),
18 #[error("duplicate request_id: {0}")]
20 DuplicateRequest(u32),
21 #[error("bad args schema: {0}")]
23 BadSchema(&'static str),
24 #[error("node: {0}")]
26 Node(#[from] NodeError),
27}
28
29#[derive(Debug, Clone)]
32pub struct GspAccept {
33 pub signal: SignalType,
35 pub sender_id: MemberId,
37 pub role_claim: u32,
39 pub request_id: u32,
41}
42
/// Capacity of the request-id duplicate-suppression window
/// (passed to `BoundedSeen::new`).
const GSP_SEEN_CAP: usize = 10 * 1_000;
45
46pub struct GspClient {
58 seen_requests: BoundedSeen<u32>,
59 pub muted: HashSet<MemberId>,
61 pub members: HashSet<MemberId>,
63 current_epoch: Option<u64>,
64}
65
66impl GspClient {
67 pub fn new() -> Self {
69 Self {
70 seen_requests: BoundedSeen::new(GSP_SEEN_CAP),
71 muted: HashSet::new(),
72 members: HashSet::new(),
73 current_epoch: None,
74 }
75 }
76
77 pub fn send<S: Sealer>(
81 &mut self,
82 node: &mut GroupNode,
83 seal: &mut S,
84 target: MemberId,
85 signal: SignalType,
86 role_claim: u32,
87 request_id: u32,
88 codec: PayloadCodec,
89 ) -> Result<OutboundFrame, GspError> {
90 self.send_with_args(node, seal, target, signal, role_claim, request_id, &[], codec)
91 }
92
93 pub fn send_with_args<S: Sealer>(
99 &mut self,
100 node: &mut GroupNode,
101 seal: &mut S,
102 target: MemberId,
103 signal: SignalType,
104 role_claim: u32,
105 request_id: u32,
106 args: &[u8],
107 codec: PayloadCodec,
108 ) -> Result<OutboundFrame, GspError> {
109 self.sync_epoch(node.current_epoch);
110 let mut sig = GspSignal::bare(signal as u32, request_id, node.member_id);
111 sig.role_claim = role_claim;
112 sig.args = serde_bytes::ByteBuf::from(args.to_vec());
113 sig.args_length = args.len() as u32;
114 let stream_id = node.member_stream_id(3);
115 Ok(node.send_payload(
116 seal,
117 target,
118 StreamType::Signal,
119 stream_id,
120 GbpFlags::ordered_reliable_ack(),
121 &sig.to_bytes(codec),
122 codec,
123 )?)
124 }
125
126 pub fn accept(
133 &mut self,
134 plaintext: &[u8],
135 current_epoch: u64,
136 codec: PayloadCodec,
137 ) -> Result<GspAccept, GspError> {
138 self.sync_epoch(current_epoch);
139 let s = GspSignal::from_bytes(plaintext, codec)?;
140 let signal = SignalType::try_from(s.signal_type).map_err(GspError::UnknownSignal)?;
141 validate_args(signal, &s.args).map_err(GspError::BadSchema)?;
143 if !self.seen_requests.insert(s.request_id) {
144 return Err(GspError::DuplicateRequest(s.request_id));
145 }
146 match signal {
147 SignalType::Join => {
148 self.members.insert(s.sender_id);
149 }
150 SignalType::Leave => {
151 self.members.remove(&s.sender_id);
152 self.muted.remove(&s.sender_id);
153 }
154 SignalType::Mute => {
155 self.muted.insert(s.sender_id);
156 }
157 SignalType::Unmute => {
158 self.muted.remove(&s.sender_id);
159 }
160 _ => {}
161 }
162 Ok(GspAccept {
163 signal,
164 sender_id: s.sender_id,
165 role_claim: s.role_claim,
166 request_id: s.request_id,
167 })
168 }
169
170 pub fn sync_epoch(&mut self, epoch: u64) {
174 if Some(epoch) != self.current_epoch {
175 self.seen_requests.clear();
176 self.current_epoch = Some(epoch);
177 }
178 }
179
180 pub fn reset(&mut self) {
182 self.seen_requests.clear();
183 self.current_epoch = None;
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::GspSignal;

    /// CBOR-encodes a signal with empty args.
    fn encode_bare(signal: SignalType, request_id: u32, sender_id: u32) -> Vec<u8> {
        GspSignal::bare(signal as u32, request_id, sender_id).to_cbor()
    }

    /// Runs a bare CBOR-encoded signal through `GspClient::accept` at `epoch`.
    fn accept_bare(
        c: &mut GspClient,
        signal: SignalType,
        request_id: u32,
        sender_id: u32,
        epoch: u64,
    ) -> Result<GspAccept, GspError> {
        c.accept(&encode_bare(signal, request_id, sender_id), epoch, PayloadCodec::Cbor)
    }

    #[test]
    fn join_adds_sender_to_members() {
        let mut c = GspClient::new();
        let accept = accept_bare(&mut c, SignalType::Join, 1, 42, 0).unwrap();
        assert_eq!(accept.signal, SignalType::Join);
        assert_eq!(accept.sender_id, 42);
        assert_eq!(accept.request_id, 1);
        assert!(c.members.contains(&42));
    }

    #[test]
    fn leave_removes_sender_from_members() {
        let mut c = GspClient::new();
        accept_bare(&mut c, SignalType::Join, 1, 7, 0).unwrap();
        accept_bare(&mut c, SignalType::Leave, 2, 7, 0).unwrap();
        assert!(!c.members.contains(&7));
    }

    #[test]
    fn leave_also_removes_from_muted() {
        let mut c = GspClient::new();
        accept_bare(&mut c, SignalType::Join, 1, 5, 0).unwrap();
        c.muted.insert(5);
        accept_bare(&mut c, SignalType::Leave, 2, 5, 0).unwrap();
        assert!(!c.muted.contains(&5));
    }

    #[test]
    fn duplicate_request_id_is_rejected() {
        let mut c = GspClient::new();
        accept_bare(&mut c, SignalType::Join, 99, 1, 0).unwrap();
        let result = accept_bare(&mut c, SignalType::Leave, 99, 1, 0);
        assert!(matches!(result, Err(GspError::DuplicateRequest(99))));
    }

    #[test]
    fn duplicate_request_does_not_change_membership() {
        let mut c = GspClient::new();
        accept_bare(&mut c, SignalType::Join, 99, 1, 0).unwrap();
        assert!(accept_bare(&mut c, SignalType::Leave, 99, 1, 0).is_err());
        // The rejected Leave must not have been applied.
        assert!(c.members.contains(&1));
    }

    #[test]
    fn syncing_same_epoch_keeps_request_seen_set() {
        let mut c = GspClient::new();
        accept_bare(&mut c, SignalType::Join, 5, 1, 7).unwrap();
        c.sync_epoch(7);
        assert!(matches!(
            accept_bare(&mut c, SignalType::Join, 5, 1, 7),
            Err(GspError::DuplicateRequest(5))
        ));
    }

    #[test]
    fn epoch_advance_clears_request_seen_set() {
        let mut c = GspClient::new();
        accept_bare(&mut c, SignalType::Join, 1, 10, 0).unwrap();
        let result = accept_bare(&mut c, SignalType::Leave, 1, 10, 1);
        assert!(result.is_ok());
    }

    #[test]
    fn reset_clears_state() {
        let mut c = GspClient::new();
        accept_bare(&mut c, SignalType::Join, 1, 3, 0).unwrap();
        c.reset();
        // Request id 1 is accepted again because the dedup window was cleared.
        let accept = accept_bare(&mut c, SignalType::Join, 1, 4, 0).unwrap();
        assert_eq!(accept.request_id, 1);
        assert!(c.members.contains(&4));
    }

    #[test]
    fn unknown_signal_type_rejected() {
        let mut c = GspClient::new();
        let bad = GspSignal::bare(999, 1, 1).to_cbor();
        assert!(matches!(
            c.accept(&bad, 0, PayloadCodec::Cbor),
            Err(GspError::UnknownSignal(999))
        ));
    }

    #[test]
    fn rejected_signal_does_not_consume_request_id() {
        let mut c = GspClient::new();
        let bad = GspSignal::bare(999, 1, 1).to_cbor();
        assert!(c.accept(&bad, 0, PayloadCodec::Cbor).is_err());
        // Same request id, now valid: must not be treated as a duplicate.
        accept_bare(&mut c, SignalType::Join, 1, 1, 0).unwrap();
        assert!(c.members.contains(&1));
    }

    #[test]
    fn invalid_cbor_returns_decode_error() {
        let mut c = GspClient::new();
        assert!(matches!(
            c.accept(b"\xFF\xFF", 0, PayloadCodec::Cbor),
            Err(GspError::Decode(_))
        ));
    }

    #[test]
    fn multiple_members_join_independently() {
        let mut c = GspClient::new();
        accept_bare(&mut c, SignalType::Join, 1, 10, 0).unwrap();
        accept_bare(&mut c, SignalType::Join, 2, 20, 0).unwrap();
        accept_bare(&mut c, SignalType::Join, 3, 30, 0).unwrap();
        assert_eq!(c.members.len(), 3);
        assert!(c.members.contains(&10));
        assert!(c.members.contains(&20));
        assert!(c.members.contains(&30));
    }
}