1#![warn(clippy::missing_docs_in_private_items)]
3use std::{fmt, io};
4
5use bytes::{Buf, BufMut};
6use radicle::node::Link;
7
8use crate::service::Message;
9use crate::{wire, wire::varint, wire::varint::VarInt, PROTOCOL_VERSION};
10
/// Version header stamped on every frame built by this module: the ASCII
/// bytes `rad` followed by the numeric [`PROTOCOL_VERSION`].
pub const PROTOCOL_VERSION_STRING: Version = Version([b'r', b'a', b'd', PROTOCOL_VERSION]);

/// Wire tag of a [`Control::Open`] command.
const CONTROL_OPEN: u8 = 0;
/// Wire tag of a [`Control::Close`] command.
const CONTROL_CLOSE: u8 = 1;
/// Wire tag of a [`Control::Eof`] command.
const CONTROL_EOF: u8 = 2;
21
/// Four-byte protocol version header.
///
/// The header is a three-byte magic prefix followed by a one-byte version
/// number (see `PROTOCOL_VERSION_STRING`).
#[derive(Debug, PartialEq, Eq)]
pub struct Version([u8; 4]);

impl Version {
    /// The numeric protocol version, i.e. the last byte of the header.
    pub fn number(&self) -> u8 {
        let [_, _, _, number] = self.0;
        number
    }
}
32
33impl wire::Encode for Version {
34 fn encode(&self, buf: &mut impl BufMut) {
35 buf.put_slice(&PROTOCOL_VERSION_STRING.0);
36 }
37}
38
39impl wire::Decode for Version {
40 fn decode(buf: &mut impl Buf) -> Result<Self, wire::Error> {
41 let mut version = [0u8; 4];
42
43 buf.try_copy_to_slice(&mut version[..])?;
44
45 if version != PROTOCOL_VERSION_STRING.0 {
46 return Err(wire::Error::InvalidProtocolVersion(version));
47 }
48 Ok(Self(version))
49 }
50}
51
52#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
83pub struct StreamId(VarInt);
84
85impl StreamId {
86 pub fn link(&self) -> Link {
88 let n = *self.0;
89 if 0b1 & n == 0 {
90 Link::Outbound
91 } else {
92 Link::Inbound
93 }
94 }
95
96 pub fn kind(&self) -> Result<StreamKind, u8> {
98 let id = *self.0;
99 let kind = ((id >> 1) & 0b11) as u8;
100
101 StreamKind::try_from(kind)
102 }
103
104 pub fn control(link: Link) -> Self {
106 let link = if link.is_outbound() { 0 } else { 1 };
107 Self(VarInt::from(((StreamKind::Control as u8) << 1) | link))
108 }
109
110 pub fn gossip(link: Link) -> Self {
112 let link = if link.is_outbound() { 0 } else { 1 };
113 Self(VarInt::from(((StreamKind::Gossip as u8) << 1) | link))
114 }
115
116 pub fn git(link: Link) -> Self {
118 let link = if link.is_outbound() { 0 } else { 1 };
119 Self(VarInt::from(((StreamKind::Git as u8) << 1) | link))
120 }
121
122 pub fn nth(self, n: u64) -> Result<Self, varint::BoundsExceeded> {
124 let id = *self.0 + (n << 3);
125 VarInt::new(id).map(Self)
126 }
127}
128
129impl From<StreamId> for u64 {
130 fn from(value: StreamId) -> Self {
131 *value.0
132 }
133}
134
135impl From<StreamId> for VarInt {
136 fn from(value: StreamId) -> Self {
137 value.0
138 }
139}
140
141impl fmt::Display for StreamId {
142 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143 write!(f, "{}", *self.0)
144 }
145}
146
147impl wire::Decode for StreamId {
148 fn decode(buf: &mut impl Buf) -> Result<Self, wire::Error> {
149 let id = VarInt::decode(buf)?;
150 Ok(Self(id))
151 }
152}
153
154impl wire::Encode for StreamId {
155 fn encode(&self, buf: &mut impl BufMut) {
156 self.0.encode(buf)
157 }
158}
159
/// Kind of a stream, as stored in bits 1-2 of a `StreamId`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum StreamKind {
    /// Control stream.
    Control = 0b00,
    /// Gossip stream.
    Gossip = 0b01,
    /// Git stream.
    Git = 0b10,
}

impl TryFrom<u8> for StreamKind {
    type Error = u8;

    /// Maps a two-bit kind value back to its [`StreamKind`]. A value that
    /// matches no discriminant is handed back unchanged as the error.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        // The enum discriminants are the single source of truth for the mapping.
        const ALL: [StreamKind; 3] = [StreamKind::Control, StreamKind::Gossip, StreamKind::Git];

        ALL.iter()
            .copied()
            .find(|kind| *kind as u8 == value)
            .ok_or(value)
    }
}
184
185#[derive(Debug, PartialEq, Eq)]
199pub struct Frame<M = Message> {
200 pub version: Version,
202 pub stream: StreamId,
204 pub data: FrameData<M>,
206}
207
208impl<M> Frame<M> {
209 pub fn git(stream: StreamId, data: Vec<u8>) -> Self {
211 Self {
212 version: PROTOCOL_VERSION_STRING,
213 stream,
214 data: FrameData::Git(data),
215 }
216 }
217
218 pub fn control(link: Link, ctrl: Control) -> Self {
220 Self {
221 version: PROTOCOL_VERSION_STRING,
222 stream: StreamId::control(link),
223 data: FrameData::Control(ctrl),
224 }
225 }
226
227 pub fn gossip(link: Link, msg: M) -> Self {
229 Self {
230 version: PROTOCOL_VERSION_STRING,
231 stream: StreamId::gossip(link),
232 data: FrameData::Gossip(msg),
233 }
234 }
235}
236
237impl<M: wire::Encode> Frame<M> {
238 pub fn to_bytes(&self) -> Vec<u8> {
240 wire::serialize(self)
241 }
242}
243
/// Payload of a frame.
///
/// When decoding, the variant is chosen from the kind of the frame's stream.
#[derive(Debug, PartialEq, Eq)]
pub enum FrameData<M> {
    /// A control command, encoded inline after the stream id.
    Control(Control),
    /// A gossip message, serialized and wrapped with `varint::payload` on the wire.
    Gossip(M),
    /// Opaque git bytes, wrapped with `varint::payload` on the wire.
    Git(Vec<u8>),
}
254
/// Control command carried in a control frame.
///
/// On the wire, each command is a one-byte tag followed by the target stream id.
// NOTE(review): the runtime meaning of each command is defined by the stream
// handling code outside this module; the descriptions below follow the names.
#[derive(Debug, PartialEq, Eq)]
pub enum Control {
    /// Open command (tag `CONTROL_OPEN`).
    Open {
        /// Stream the command refers to.
        stream: StreamId,
    },
    /// Close command (tag `CONTROL_CLOSE`).
    Close {
        /// Stream the command refers to.
        stream: StreamId,
    },
    /// End-of-file command (tag `CONTROL_EOF`); presumably signals that no more
    /// data will be sent on `stream`. Confirm against the stream handling code.
    Eof {
        /// Stream the command refers to.
        stream: StreamId,
    },
}
276
277impl wire::Decode for Control {
278 fn decode(buf: &mut impl Buf) -> Result<Self, wire::Error> {
279 let command = u8::decode(buf)?;
280 match command {
281 CONTROL_OPEN => {
282 let stream = StreamId::decode(buf)?;
283 Ok(Control::Open { stream })
284 }
285 CONTROL_CLOSE => {
286 let stream = StreamId::decode(buf)?;
287 Ok(Control::Close { stream })
288 }
289 CONTROL_EOF => {
290 let stream = StreamId::decode(buf)?;
291 Ok(Control::Eof { stream })
292 }
293 other => Err(wire::Error::InvalidControlMessage(other)),
294 }
295 }
296}
297
298impl wire::Encode for Control {
299 fn encode(&self, buf: &mut impl BufMut) {
300 match self {
301 Self::Open { stream: id } => {
302 CONTROL_OPEN.encode(buf);
303 id.encode(buf);
304 }
305 Self::Eof { stream: id } => {
306 CONTROL_EOF.encode(buf);
307 id.encode(buf);
308 }
309 Self::Close { stream: id } => {
310 CONTROL_CLOSE.encode(buf);
311 id.encode(buf);
312 }
313 }
314 }
315}
316
317impl<M: wire::Decode> wire::Decode for Frame<M> {
318 fn decode(buf: &mut impl Buf) -> Result<Self, wire::Error> {
319 let version = Version::decode(buf)?;
320 if version.number() != PROTOCOL_VERSION {
321 return Err(wire::Error::WrongProtocolVersion(version.number()));
322 }
323 let stream = StreamId::decode(buf)?;
324
325 match stream.kind() {
326 Ok(StreamKind::Control) => {
327 let ctrl = Control::decode(buf)?;
328 let frame = Frame {
329 version,
330 stream,
331 data: FrameData::Control(ctrl),
332 };
333 Ok(frame)
334 }
335 Ok(StreamKind::Gossip) => {
336 let data = varint::payload::decode(buf)?;
337 let mut cursor = io::Cursor::new(data);
338 let msg = M::decode(&mut cursor)?;
339 let frame = Frame {
340 version,
341 stream,
342 data: FrameData::Gossip(msg),
343 };
344
345 Ok(frame)
349 }
350 Ok(StreamKind::Git) => {
351 let data = varint::payload::decode(buf)?;
352 Ok(Frame::git(stream, data))
353 }
354 Err(n) => Err(wire::Error::InvalidStreamKind(n)),
355 }
356 }
357}
358
359impl<M: wire::Encode> wire::Encode for Frame<M> {
360 fn encode(&self, buf: &mut impl BufMut) {
361 self.version.encode(buf);
362 self.stream.encode(buf);
363 match &self.data {
364 FrameData::Control(ctrl) => ctrl.encode(buf),
365 FrameData::Git(data) => varint::payload::encode(data, buf),
366 FrameData::Gossip(msg) => varint::payload::encode(&wire::serialize(msg), buf),
367 }
368 }
369}
370
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_stream_id() {
        // Bits 1-2 carry the stream kind.
        for (raw, kind) in [
            (0b000, StreamKind::Control),
            (0b010, StreamKind::Gossip),
            (0b100, StreamKind::Git),
        ] {
            assert_eq!(StreamId(VarInt(raw)).kind().unwrap(), kind);
        }

        // Bit 0 carries the link direction.
        for (raw, link) in [
            (0b001, Link::Inbound),
            (0b000, Link::Outbound),
            (0b101, Link::Inbound),
            (0b100, Link::Outbound),
        ] {
            assert_eq!(StreamId(VarInt(raw)).link(), link);
        }

        // Constructors set the kind and link bits and leave the index at zero.
        let constructors: [(fn(Link) -> StreamId, Link, _); 6] = [
            (StreamId::git, Link::Outbound, 0b100),
            (StreamId::control, Link::Outbound, 0b000),
            (StreamId::gossip, Link::Outbound, 0b010),
            (StreamId::git, Link::Inbound, 0b101),
            (StreamId::control, Link::Inbound, 0b001),
            (StreamId::gossip, Link::Inbound, 0b011),
        ];
        for (make, link, raw) in constructors {
            assert_eq!(make(link), StreamId(VarInt(raw)));
        }
    }
}