1#![warn(clippy::missing_docs_in_private_items)]
3use std::{fmt, io};
4
5use crate::{wire, wire::varint, wire::varint::VarInt, wire::Message, Link, PROTOCOL_VERSION};
6
7pub const PROTOCOL_VERSION_STRING: Version = Version([b'r', b'a', b'd', PROTOCOL_VERSION]);
10
11const CONTROL_OPEN: u8 = 0;
13const CONTROL_CLOSE: u8 = 1;
15const CONTROL_EOF: u8 = 2;
17
/// Four-byte protocol version tag as it appears on the wire.
///
/// The last byte holds the numeric protocol version (see [`Version::number`]).
#[derive(PartialEq, Eq, Debug)]
pub struct Version([u8; 4]);
21
22impl Version {
23 pub fn number(&self) -> u8 {
25 self.0[3]
26 }
27}
28
29impl wire::Encode for Version {
30 fn encode<W: io::Write + ?Sized>(&self, writer: &mut W) -> Result<usize, io::Error> {
31 writer.write_all(&PROTOCOL_VERSION_STRING.0)?;
32
33 Ok(PROTOCOL_VERSION_STRING.0.len())
34 }
35}
36
37impl wire::Decode for Version {
38 fn decode<R: io::Read + ?Sized>(reader: &mut R) -> Result<Self, wire::Error> {
39 let mut version = [0u8; 4];
40 reader.read_exact(&mut version[..])?;
41
42 if version != PROTOCOL_VERSION_STRING.0 {
43 return Err(wire::Error::InvalidProtocolVersion(version));
44 }
45 Ok(Self(version))
46 }
47}
48
49#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
80pub struct StreamId(VarInt);
81
82impl StreamId {
83 pub fn link(&self) -> Link {
85 let n = *self.0;
86 if 0b1 & n == 0 {
87 Link::Outbound
88 } else {
89 Link::Inbound
90 }
91 }
92
93 pub fn kind(&self) -> Result<StreamKind, u8> {
95 let id = *self.0;
96 let kind = ((id >> 1) & 0b11) as u8;
97
98 StreamKind::try_from(kind)
99 }
100
101 pub fn control(link: Link) -> Self {
103 let link = if link.is_outbound() { 0 } else { 1 };
104 Self(VarInt::from((StreamKind::Control as u8) << 1 | link))
105 }
106
107 pub fn gossip(link: Link) -> Self {
109 let link = if link.is_outbound() { 0 } else { 1 };
110 Self(VarInt::from((StreamKind::Gossip as u8) << 1 | link))
111 }
112
113 pub fn git(link: Link) -> Self {
115 let link = if link.is_outbound() { 0 } else { 1 };
116 Self(VarInt::from((StreamKind::Git as u8) << 1 | link))
117 }
118
119 pub fn nth(self, n: u64) -> Result<Self, varint::BoundsExceeded> {
121 let id = *self.0 + (n << 3);
122 VarInt::new(id).map(Self)
123 }
124}
125
126impl From<StreamId> for u64 {
127 fn from(value: StreamId) -> Self {
128 *value.0
129 }
130}
131
132impl From<StreamId> for VarInt {
133 fn from(value: StreamId) -> Self {
134 value.0
135 }
136}
137
138impl fmt::Display for StreamId {
139 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140 write!(f, "{}", *self.0)
141 }
142}
143
144impl wire::Decode for StreamId {
145 fn decode<R: io::Read + ?Sized>(reader: &mut R) -> Result<Self, wire::Error> {
146 let id = VarInt::decode(reader)?;
147 Ok(Self(id))
148 }
149}
150
151impl wire::Encode for StreamId {
152 fn encode<W: io::Write + ?Sized>(&self, writer: &mut W) -> Result<usize, io::Error> {
153 self.0.encode(writer)
154 }
155}
156
/// Kind of a stream, stored in bits 1–2 of a `StreamId`.
///
/// The `repr(u8)` discriminants are the exact bit values used on the wire.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
#[repr(u8)]
pub enum StreamKind {
    /// Stream carrying `Control` commands.
    Control = 0b00,
    /// Stream carrying gossip messages.
    Gossip = 0b01,
    /// Stream carrying raw git bytes.
    Git = 0b10,
}
168
169impl TryFrom<u8> for StreamKind {
170 type Error = u8;
171
172 fn try_from(value: u8) -> Result<Self, Self::Error> {
173 match value {
174 0b00 => Ok(StreamKind::Control),
175 0b01 => Ok(StreamKind::Gossip),
176 0b10 => Ok(StreamKind::Git),
177 n => Err(n),
178 }
179 }
180}
181
182#[derive(Debug, PartialEq, Eq)]
194pub struct Frame<M = Message> {
195 pub version: Version,
197 pub stream: StreamId,
199 pub data: FrameData<M>,
201}
202
203impl<M> Frame<M> {
204 pub fn git(stream: StreamId, data: Vec<u8>) -> Self {
206 Self {
207 version: PROTOCOL_VERSION_STRING,
208 stream,
209 data: FrameData::Git(data),
210 }
211 }
212
213 pub fn control(link: Link, ctrl: Control) -> Self {
215 Self {
216 version: PROTOCOL_VERSION_STRING,
217 stream: StreamId::control(link),
218 data: FrameData::Control(ctrl),
219 }
220 }
221
222 pub fn gossip(link: Link, msg: M) -> Self {
224 Self {
225 version: PROTOCOL_VERSION_STRING,
226 stream: StreamId::gossip(link),
227 data: FrameData::Gossip(msg),
228 }
229 }
230}
231
232impl<M: wire::Encode> Frame<M> {
233 pub fn to_bytes(&self) -> Vec<u8> {
235 wire::serialize(self)
236 }
237}
238
239#[derive(Debug, PartialEq, Eq)]
241pub enum FrameData<M> {
242 Control(Control),
244 Gossip(M),
246 Git(Vec<u8>),
248}
249
250#[derive(Debug, PartialEq, Eq)]
252pub enum Control {
253 Open {
255 stream: StreamId,
257 },
258 Close {
260 stream: StreamId,
262 },
263 Eof {
267 stream: StreamId,
269 },
270}
271
272impl wire::Decode for Control {
273 fn decode<R: io::Read + ?Sized>(reader: &mut R) -> Result<Self, wire::Error> {
274 let command = u8::decode(reader)?;
275 match command {
276 CONTROL_OPEN => {
277 let stream = StreamId::decode(reader)?;
278 Ok(Control::Open { stream })
279 }
280 CONTROL_CLOSE => {
281 let stream = StreamId::decode(reader)?;
282 Ok(Control::Close { stream })
283 }
284 CONTROL_EOF => {
285 let stream = StreamId::decode(reader)?;
286 Ok(Control::Eof { stream })
287 }
288 other => Err(wire::Error::InvalidControlMessage(other)),
289 }
290 }
291}
292
293impl wire::Encode for Control {
294 fn encode<W: io::Write + ?Sized>(&self, writer: &mut W) -> Result<usize, io::Error> {
295 let mut n = 0;
296
297 match self {
298 Self::Open { stream: id } => {
299 n += CONTROL_OPEN.encode(writer)?;
300 n += id.encode(writer)?;
301 }
302 Self::Eof { stream: id } => {
303 n += CONTROL_EOF.encode(writer)?;
304 n += id.encode(writer)?;
305 }
306 Self::Close { stream: id } => {
307 n += CONTROL_CLOSE.encode(writer)?;
308 n += id.encode(writer)?;
309 }
310 }
311 Ok(n)
312 }
313}
314
315impl<M: wire::Decode> wire::Decode for Frame<M> {
316 fn decode<R: io::Read + ?Sized>(reader: &mut R) -> Result<Self, wire::Error> {
317 let version = Version::decode(reader)?;
318 if version.number() != PROTOCOL_VERSION {
319 return Err(wire::Error::WrongProtocolVersion(version.number()));
320 }
321 let stream = StreamId::decode(reader)?;
322
323 match stream.kind() {
324 Ok(StreamKind::Control) => {
325 let ctrl = Control::decode(reader)?;
326 let frame = Frame {
327 version,
328 stream,
329 data: FrameData::Control(ctrl),
330 };
331 Ok(frame)
332 }
333 Ok(StreamKind::Gossip) => {
334 let data = varint::payload::decode(reader)?;
335 let mut cursor = io::Cursor::new(data);
336 let msg = M::decode(&mut cursor)?;
337 let frame = Frame {
338 version,
339 stream,
340 data: FrameData::Gossip(msg),
341 };
342
343 Ok(frame)
347 }
348 Ok(StreamKind::Git { .. }) => {
349 let data = varint::payload::decode(reader)?;
350 Ok(Frame::git(stream, data))
351 }
352 Err(n) => Err(wire::Error::InvalidStreamKind(n)),
353 }
354 }
355}
356
357impl<M: wire::Encode> wire::Encode for Frame<M> {
358 fn encode<W: io::Write + ?Sized>(&self, writer: &mut W) -> Result<usize, io::Error> {
359 let mut n = 0;
360
361 n += self.version.encode(writer)?;
362 n += self.stream.encode(writer)?;
363 n += match &self.data {
364 FrameData::Control(ctrl) => ctrl.encode(writer)?,
365 FrameData::Git(data) => varint::payload::encode(data, writer)?,
366 FrameData::Gossip(msg) => varint::payload::encode(&wire::serialize(msg), writer)?,
367 };
368
369 Ok(n)
370 }
371}
372
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_stream_id() {
        // Bits 1-2 carry the stream kind.
        for (raw, kind) in [
            (0b000, StreamKind::Control),
            (0b010, StreamKind::Gossip),
            (0b100, StreamKind::Git),
        ] {
            assert_eq!(StreamId(VarInt(raw)).kind().unwrap(), kind);
        }

        // Bit 0 carries the link direction.
        for (raw, link) in [
            (0b001, Link::Inbound),
            (0b000, Link::Outbound),
            (0b101, Link::Inbound),
            (0b100, Link::Outbound),
        ] {
            assert_eq!(StreamId(VarInt(raw)).link(), link);
        }

        // Constructors set kind and link and leave the higher bits at zero.
        let constructors: [(fn(Link) -> StreamId, Link, StreamId); 6] = [
            (StreamId::git, Link::Outbound, StreamId(VarInt(0b100))),
            (StreamId::control, Link::Outbound, StreamId(VarInt(0b000))),
            (StreamId::gossip, Link::Outbound, StreamId(VarInt(0b010))),
            (StreamId::git, Link::Inbound, StreamId(VarInt(0b101))),
            (StreamId::control, Link::Inbound, StreamId(VarInt(0b001))),
            (StreamId::gossip, Link::Inbound, StreamId(VarInt(0b011))),
        ];
        for (make, link, expected) in constructors {
            assert_eq!(make(link), expected);
        }
    }
}