1#![warn(clippy::missing_docs_in_private_items)]
3use std::{fmt, io};
4
5use bytes::{Buf, BufMut};
6use radicle::node::Link;
7
8use crate::service::Message;
9use crate::{wire, wire::varint, wire::varint::VarInt, PROTOCOL_VERSION};
10
11pub const PROTOCOL_VERSION_STRING: Version = Version([b'r', b'a', b'd', PROTOCOL_VERSION]);
14
15#[derive(Debug, PartialEq, Eq)]
17pub struct Version([u8; 4]);
18
19impl Version {
20 pub fn number(&self) -> u8 {
22 self.0[3]
23 }
24}
25
26impl wire::Encode for Version {
27 fn encode(&self, buf: &mut impl BufMut) {
28 buf.put_slice(&PROTOCOL_VERSION_STRING.0);
29 }
30}
31
32impl wire::Decode for Version {
33 fn decode(buf: &mut impl Buf) -> Result<Self, wire::Error> {
34 let mut version = [0u8; 4];
35
36 buf.try_copy_to_slice(&mut version[..])?;
37
38 if version != PROTOCOL_VERSION_STRING.0 {
39 return Err(wire::Invalid::ProtocolVersion { actual: version }.into());
40 }
41 Ok(Self(version))
42 }
43}
44
45#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
76pub struct StreamId(VarInt);
77
78impl StreamId {
79 pub fn link(&self) -> Link {
81 let n = *self.0;
82 if 0b1 & n == 0 {
83 Link::Outbound
84 } else {
85 Link::Inbound
86 }
87 }
88
89 pub fn kind(&self) -> Result<StreamType, u8> {
91 let id = *self.0;
92 let kind = ((id >> 1) & 0b11) as u8;
93
94 StreamType::try_from(kind)
95 }
96
97 pub fn control(link: Link) -> Self {
99 let link = if link.is_outbound() { 0 } else { 1 };
100 Self(VarInt::from(((u8::from(StreamType::Control)) << 1) | link))
101 }
102
103 pub fn gossip(link: Link) -> Self {
105 let link = if link.is_outbound() { 0 } else { 1 };
106 Self(VarInt::from((u8::from(StreamType::Gossip) << 1) | link))
107 }
108
109 pub fn git(link: Link) -> Self {
111 let link = if link.is_outbound() { 0 } else { 1 };
112 Self(VarInt::from((u8::from(StreamType::Git) << 1) | link))
113 }
114
115 pub fn nth(self, n: u64) -> Result<Self, varint::BoundsExceeded> {
117 let id = *self.0 + (n << 3);
118 VarInt::new(id).map(Self)
119 }
120}
121
122impl From<StreamId> for u64 {
123 fn from(value: StreamId) -> Self {
124 *value.0
125 }
126}
127
128impl From<StreamId> for VarInt {
129 fn from(value: StreamId) -> Self {
130 value.0
131 }
132}
133
134impl fmt::Display for StreamId {
135 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136 write!(f, "{}", *self.0)
137 }
138}
139
140impl wire::Decode for StreamId {
141 fn decode(buf: &mut impl Buf) -> Result<Self, wire::Error> {
142 let id = VarInt::decode(buf)?;
143 Ok(Self(id))
144 }
145}
146
147impl wire::Encode for StreamId {
148 fn encode(&self, buf: &mut impl BufMut) {
149 self.0.encode(buf)
150 }
151}
152
/// Kind of stream, stored in bits 1 and 2 of a `StreamId`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum StreamType {
    /// Stream whose frames carry `Control` messages.
    Control = 0b00,
    /// Stream whose frames carry gossip messages.
    Gossip = 0b01,
    /// Stream whose frames carry raw git bytes.
    Git = 0b10,
}

impl TryFrom<u8> for StreamType {
    type Error = u8;

    /// Maps a two-bit type value back to its variant. An unknown value is
    /// returned unchanged as the error.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        // Look the value up against the declared discriminants so the two can
        // never drift apart.
        [Self::Control, Self::Gossip, Self::Git]
            .iter()
            .copied()
            .find(|kind| u8::from(*kind) == value)
            .ok_or(value)
    }
}

impl From<StreamType> for u8 {
    /// Returns the variant's declared discriminant.
    fn from(value: StreamType) -> Self {
        value as Self
    }
}
183
184#[derive(Debug, PartialEq, Eq)]
198pub struct Frame<M = Message> {
199 pub version: Version,
201 pub stream: StreamId,
203 pub data: FrameData<M>,
205}
206
207impl<M> Frame<M> {
208 pub fn git(stream: StreamId, data: Vec<u8>) -> Self {
210 Self {
211 version: PROTOCOL_VERSION_STRING,
212 stream,
213 data: FrameData::Git(data),
214 }
215 }
216
217 pub fn control(link: Link, ctrl: Control) -> Self {
219 Self {
220 version: PROTOCOL_VERSION_STRING,
221 stream: StreamId::control(link),
222 data: FrameData::Control(ctrl),
223 }
224 }
225
226 pub fn gossip(link: Link, msg: M) -> Self {
228 Self {
229 version: PROTOCOL_VERSION_STRING,
230 stream: StreamId::gossip(link),
231 data: FrameData::Gossip(msg),
232 }
233 }
234}
235
236#[derive(Debug, PartialEq, Eq)]
238pub enum FrameData<M> {
239 Control(Control),
241 Gossip(M),
243 Git(Vec<u8>),
245}
246
247#[derive(Debug, PartialEq, Eq)]
249pub enum Control {
250 Open {
252 stream: StreamId,
254 },
255 Close {
257 stream: StreamId,
259 },
260 Eof {
264 stream: StreamId,
266 },
267}
268
/// One-byte wire tag identifying the kind of a `Control` message.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum ControlType {
    /// Tag of `Control::Open`.
    Open = 0,
    /// Tag of `Control::Close`.
    Close = 1,
    /// Tag of `Control::Eof`.
    Eof = 2,
}

impl TryFrom<u8> for ControlType {
    type Error = u8;

    /// Maps a wire tag back to its variant. An unknown tag is returned
    /// unchanged as the error.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        // Compare against the declared discriminants rather than repeating
        // the numbers, so the enum stays the single source of truth.
        [Self::Open, Self::Close, Self::Eof]
            .iter()
            .copied()
            .find(|kind| u8::from(*kind) == value)
            .ok_or(value)
    }
}

impl From<ControlType> for u8 {
    /// Returns the variant's declared discriminant.
    fn from(value: ControlType) -> Self {
        value as Self
    }
}
299
300impl wire::Decode for Control {
301 fn decode(buf: &mut impl Buf) -> Result<Self, wire::Error> {
302 match ControlType::try_from(u8::decode(buf)?) {
303 Ok(ControlType::Open) => Ok(Control::Open {
304 stream: StreamId::decode(buf)?,
305 }),
306 Ok(ControlType::Close) => Ok(Control::Close {
307 stream: StreamId::decode(buf)?,
308 }),
309 Ok(ControlType::Eof) => Ok(Control::Eof {
310 stream: StreamId::decode(buf)?,
311 }),
312 Err(other) => Err(wire::Invalid::ControlType { actual: other }.into()),
313 }
314 }
315}
316
317impl wire::Encode for Control {
318 fn encode(&self, buf: &mut impl BufMut) {
319 match self {
320 Self::Open { stream: id } => {
321 u8::from(ControlType::Open).encode(buf);
322 id.encode(buf);
323 }
324 Self::Eof { stream: id } => {
325 u8::from(ControlType::Eof).encode(buf);
326 id.encode(buf);
327 }
328 Self::Close { stream: id } => {
329 u8::from(ControlType::Close).encode(buf);
330 id.encode(buf);
331 }
332 }
333 }
334}
335
336impl<M: wire::Decode> wire::Decode for Frame<M> {
337 fn decode(buf: &mut impl Buf) -> Result<Self, wire::Error> {
338 let version = Version::decode(buf)?;
339 if version.number() != PROTOCOL_VERSION {
340 return Err(wire::Invalid::ProtocolVersionUnsupported {
341 actual: version.number(),
342 }
343 .into());
344 }
345 let stream = StreamId::decode(buf)?;
346
347 match stream.kind() {
348 Ok(StreamType::Control) => {
349 let ctrl = Control::decode(buf)?;
350 let frame = Frame {
351 version,
352 stream,
353 data: FrameData::Control(ctrl),
354 };
355 Ok(frame)
356 }
357 Ok(StreamType::Gossip) => {
358 let data = varint::payload::decode(buf)?;
359 let mut cursor = io::Cursor::new(data);
360 let msg = M::decode(&mut cursor)?;
361 let frame = Frame {
362 version,
363 stream,
364 data: FrameData::Gossip(msg),
365 };
366
367 Ok(frame)
371 }
372 Ok(StreamType::Git) => {
373 let data = varint::payload::decode(buf)?;
374 Ok(Frame::git(stream, data))
375 }
376 Err(n) => Err(wire::Invalid::StreamType { actual: n }.into()),
377 }
378 }
379}
380
381impl<M: wire::Encode> wire::Encode for Frame<M> {
382 fn encode(&self, buf: &mut impl BufMut) {
383 self.version.encode(buf);
384 self.stream.encode(buf);
385 match &self.data {
386 FrameData::Control(ctrl) => ctrl.encode(buf),
387 FrameData::Git(data) => varint::payload::encode(data, buf),
388 FrameData::Gossip(msg) => varint::payload::encode(&msg.encode_to_vec(), buf),
389 }
390 }
391}
392
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_stream_id() {
        // Stream type is read from bits 1..=2.
        let kinds = [
            (0b000, StreamType::Control),
            (0b010, StreamType::Gossip),
            (0b100, StreamType::Git),
        ];
        for (raw, expected) in kinds {
            assert_eq!(StreamId(VarInt(raw)).kind().unwrap(), expected);
        }

        // Link direction is read from bit 0.
        let links = [
            (0b001, Link::Inbound),
            (0b000, Link::Outbound),
            (0b101, Link::Inbound),
            (0b100, Link::Outbound),
        ];
        for (raw, expected) in links {
            assert_eq!(StreamId(VarInt(raw)).link(), expected);
        }

        // Constructors pack type and direction into the expected raw ids.
        let constructed = [
            (StreamId::git(Link::Outbound), 0b100),
            (StreamId::control(Link::Outbound), 0b000),
            (StreamId::gossip(Link::Outbound), 0b010),
            (StreamId::git(Link::Inbound), 0b101),
            (StreamId::control(Link::Inbound), 0b001),
            (StreamId::gossip(Link::Inbound), 0b011),
        ];
        for (id, raw) in constructed {
            assert_eq!(id, StreamId(VarInt(raw)));
        }
    }

    #[test]
    fn test_encode_git_large() {
        use wire::Encode as _;

        let size_max = wire::Size::MAX as usize;
        let size = 3 * usize::from(u16::MAX);
        assert!(
            size > size_max * 2,
            "we want to test sizes that are way larger than any gossip message"
        );

        let frame: Frame<Message> = Frame::git(StreamId(0u8.into()), vec![0u8; size]);
        let bytes = frame.encode_to_vec();

        assert!(
            bytes.len() > size_max * 2,
            "just making sure that whatever was encoded is still quite large"
        );
    }
}