1use crate::schema::*;
2use compact_encoding::{
3 CompactEncoding, EncodingError, EncodingErrorKind, VecEncodable, decode_usize, take_array,
4 write_array,
5};
6use pretty_hash::fmt as pretty_fmt;
7use std::{fmt, io};
8use tracing::{debug, instrument, trace, warn};
9
/// Frame prefix announcing a single `Open` message.
const OPEN_MESSAGE_PREFIX: [u8; 2] = [0x00, 0x01];
/// Frame prefix announcing a single `Close` message.
const CLOSE_MESSAGE_PREFIX: [u8; 2] = [0x00, 0x03];
/// Frame prefix announcing a batch of length-prefixed messages.
const MULTI_MESSAGE_PREFIX: [u8; 2] = [0x00, 0x00];
/// Byte written inside a batch, in place of a message length, to announce that a
/// new channel id follows.
const CHANNEL_CHANGE_SEPERATOR: [u8; 1] = [0x00];
14
15#[instrument(skip_all err)]
16pub(crate) fn decode_unframed_channel_messages(
17 buf: &[u8],
18) -> Result<(Vec<ChannelMessage>, usize), io::Error> {
19 let og_len = buf.len();
20 if og_len >= 3 && buf[0] == 0x00 {
21 if buf[1] == 0x00 {
23 let (_, mut buf) = take_array::<2>(buf)?;
24 let mut messages: Vec<ChannelMessage> = vec![];
26
27 let mut current_channel;
29 (current_channel, buf) = u64::decode(buf)?;
30 while !buf.is_empty() {
31 let channel_message_length;
33 (channel_message_length, buf) = decode_usize(buf)?;
34 if channel_message_length > buf.len() {
35 return Err(io::Error::new(
36 io::ErrorKind::InvalidData,
37 format!(
38 "received invalid message length: [{channel_message_length}]
39\tbut we have [{}] remaining bytes.
40\tInitial buffer size [{og_len}]",
41 buf.len()
42 ),
43 ));
44 }
45 let channel_message;
47 let bl = buf.len();
48 (channel_message, buf) = ChannelMessage::decode_with_channel(buf, current_channel)?;
49 trace!(
50 "Decoded ChannelMessage::{:?} using [{} bytes]",
51 channel_message.message,
52 bl - buf.len()
53 );
54 messages.push(channel_message);
55 if !buf.is_empty() && buf[0] == 0x00 {
59 (current_channel, buf) = u64::decode(buf)?;
60 }
61 }
62 Ok((messages, og_len - buf.len()))
63 } else if buf[1] == 0x01 {
64 let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?;
66 Ok((vec![channel_message], length + 2))
67 } else if buf[1] == 0x03 {
68 let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?;
70 Ok((vec![channel_message], length + 2))
71 } else {
72 Err(io::Error::new(
73 io::ErrorKind::InvalidData,
74 "received invalid special message",
75 ))
76 }
77 } else if buf.len() >= 2 {
78 trace!("Decoding single ChannelMessage");
79 let og_len = buf.len();
81 let (channel_message, buf) = ChannelMessage::decode_from_channel_and_message(buf)?;
82 Ok((vec![channel_message], og_len - buf.len()))
83 } else {
84 Err(io::Error::new(
85 io::ErrorKind::InvalidData,
86 format!("received too short message, {buf:?}"),
87 ))
88 }
89}
90
91fn vec_channel_messages_encoded_size(messages: &[ChannelMessage]) -> Result<usize, EncodingError> {
92 Ok(match messages {
93 [] => 0,
94 [msg] => match msg.message {
95 Message::Open(_) | Message::Close(_) => 2 + msg.encoded_size()?,
96 _ => msg.encoded_size()?,
97 },
98 msgs => {
99 let mut out = MULTI_MESSAGE_PREFIX.len();
100 let mut current_channel: u64 = messages[0].channel;
101 out += current_channel.encoded_size()?;
102 for message in msgs.iter() {
103 if message.channel != current_channel {
104 out += CHANNEL_CHANGE_SEPERATOR.len() + message.channel.encoded_size()?;
107 current_channel = message.channel;
108 }
109 let message_length = message.message.encoded_size()?;
110 out += message_length + (message_length as u64).encoded_size()?;
111 }
112 out
113 }
114 })
115}
116
/// A protocol message carried on a channel.
///
/// `Open` and `Close` are encoded without a type id. Every other wire variant is
/// preceded by the id returned from `Message::typ`. `LocalSignal` has no wire form:
/// encoding it panics.
#[derive(Debug, Clone, PartialEq)]
#[expect(missing_docs)]
pub enum Message {
    Open(Open),
    Close(Close),
    Synchronize(Synchronize),
    Request(Request),
    Cancel(Cancel),
    Data(Data),
    NoData(NoData),
    Want(Want),
    Unwant(Unwant),
    Bitfield(Bitfield),
    Range(Range),
    Extension(Extension),
    // Local-only payload; never encoded onto the wire (see `CompactEncoding::encode`).
    LocalSignal((String, Vec<u8>)),
}
136
137macro_rules! message_from {
138 ($($val:ident),+) => {
139 $(
140 impl From<$val> for Message {
141 fn from(value: $val) -> Self {
142 Message::$val(value)
143 }
144 }
145 )*
146 }
147}
148message_from!(
149 Open,
150 Close,
151 Synchronize,
152 Request,
153 Cancel,
154 Data,
155 NoData,
156 Want,
157 Unwant,
158 Bitfield,
159 Range,
160 Extension
161);
162
/// Decodes a `$type` payload from `$buf` and wraps it as a `Message`, evaluating to
/// `(Message, remaining_bytes)`. Decoding errors are propagated with `?` from the
/// enclosing function.
macro_rules! decode_message {
    ($type:ty, $buf:expr) => {{
        let (payload, remaining) = <$type>::decode($buf)?;
        let message: Message = payload.into();
        (message, remaining)
    }};
}
169
170impl CompactEncoding for Message {
171 fn encoded_size(&self) -> Result<usize, EncodingError> {
172 let typ_size = if let Self::Open(_) | Self::Close(_) = &self {
173 0
174 } else {
175 self.typ().encoded_size()?
176 };
177 let msg_size = match self {
178 Self::LocalSignal(_) => Ok(0),
179 Self::Open(x) => x.encoded_size(),
180 Self::Close(x) => x.encoded_size(),
181 Self::Synchronize(x) => x.encoded_size(),
182 Self::Request(x) => x.encoded_size(),
183 Self::Cancel(x) => x.encoded_size(),
184 Self::Data(x) => x.encoded_size(),
185 Self::NoData(x) => x.encoded_size(),
186 Self::Want(x) => x.encoded_size(),
187 Self::Unwant(x) => x.encoded_size(),
188 Self::Bitfield(x) => x.encoded_size(),
189 Self::Range(x) => x.encoded_size(),
190 Self::Extension(x) => x.encoded_size(),
191 }?;
192 Ok(typ_size + msg_size)
193 }
194
195 #[instrument(skip_all, fields(name = self.name()))]
196 fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> {
197 debug!("Encoding {self:?}");
198 let rest = if let Self::Open(_) | Self::Close(_) = &self {
199 buffer
200 } else {
201 self.typ().encode(buffer)?
202 };
203 match self {
204 Self::Open(x) => x.encode(rest),
205 Self::Close(x) => x.encode(rest),
206 Self::Synchronize(x) => x.encode(rest),
207 Self::Request(x) => x.encode(rest),
208 Self::Cancel(x) => x.encode(rest),
209 Self::Data(x) => x.encode(rest),
210 Self::NoData(x) => x.encode(rest),
211 Self::Want(x) => x.encode(rest),
212 Self::Unwant(x) => x.encode(rest),
213 Self::Bitfield(x) => x.encode(rest),
214 Self::Range(x) => x.encode(rest),
215 Self::Extension(x) => x.encode(rest),
216 Self::LocalSignal(_) => unimplemented!("do not encode LocalSignal"),
217 }
218 }
219
220 fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError>
221 where
222 Self: Sized,
223 {
224 let (typ, rest) = u64::decode(buffer)?;
225 Ok(match typ {
226 0 => decode_message!(Synchronize, rest),
227 1 => decode_message!(Request, rest),
228 2 => decode_message!(Cancel, rest),
229 3 => decode_message!(Data, rest),
230 4 => decode_message!(NoData, rest),
231 5 => decode_message!(Want, rest),
232 6 => decode_message!(Unwant, rest),
233 7 => decode_message!(Bitfield, rest),
234 8 => decode_message!(Range, rest),
235 9 => decode_message!(Extension, rest),
236 _ => {
237 return Err(EncodingError::new(
238 EncodingErrorKind::InvalidData,
239 &format!("Invalid message type to decode: {typ}"),
240 ));
241 }
242 })
243 }
244}
245impl Message {
246 pub(crate) fn typ(&self) -> u64 {
248 match self {
249 Self::Synchronize(_) => 0,
250 Self::Request(_) => 1,
251 Self::Cancel(_) => 2,
252 Self::Data(_) => 3,
253 Self::NoData(_) => 4,
254 Self::Want(_) => 5,
255 Self::Unwant(_) => 6,
256 Self::Bitfield(_) => 7,
257 Self::Range(_) => 8,
258 Self::Extension(_) => 9,
259 value => unimplemented!("{} does not have a type", value),
260 }
261 }
262 pub fn name(&self) -> &'static str {
264 match self {
265 Message::Open(_) => "Open",
266 Message::Close(_) => "Close",
267 Message::Synchronize(_) => "Synchronize",
268 Message::Request(_) => "Request",
269 Message::Cancel(_) => "Cancel",
270 Message::Data(_) => "Data",
271 Message::NoData(_) => "NoData",
272 Message::Want(_) => "Want",
273 Message::Unwant(_) => "Unwant",
274 Message::Bitfield(_) => "Bitfield",
275 Message::Range(_) => "Range",
276 Message::Extension(_) => "Extension",
277 Message::LocalSignal(_) => "LocalSignal",
278 }
279 }
280}
281
282impl fmt::Display for Message {
283 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
284 match self {
285 Self::Open(msg) => write!(
286 f,
287 "Open(discovery_key: {}, capability <{}>)",
288 pretty_fmt(&msg.discovery_key).unwrap(),
289 msg.capability.as_ref().map_or(0, |c| c.len())
290 ),
291 Self::Data(msg) => write!(
292 f,
293 "Data(request: {}, fork: {}, block: {}, hash: {}, seek: {}, upgrade: {})",
294 msg.request,
295 msg.fork,
296 msg.block.is_some(),
297 msg.hash.is_some(),
298 msg.seek.is_some(),
299 msg.upgrade.is_some(),
300 ),
301 _ => write!(f, "{:?}", &self),
302 }
303 }
304}
305
/// A `Message` paired with the id of the channel it belongs to.
#[derive(Clone)]
pub(crate) struct ChannelMessage {
    /// Channel id. For decoded `Open`/`Close` messages it is copied from the payload.
    pub(crate) channel: u64,
    /// The message itself.
    pub(crate) message: Message,
}
312
313impl PartialEq for ChannelMessage {
314 fn eq(&self, other: &Self) -> bool {
315 self.channel == other.channel && self.message == other.message
316 }
317}
318
319impl fmt::Debug for ChannelMessage {
320 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
321 write!(f, "ChannelMessage({}, {})", self.channel, self.message)
322 }
323}
324
325impl fmt::Display for ChannelMessage {
326 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
327 write!(
328 f,
329 "ChannelMessage {{ channel {}, message {} }}",
330 self.channel,
331 self.message.name()
332 )
333 }
334}
335
336impl ChannelMessage {
337 pub(crate) fn new(channel: u64, message: Message) -> Self {
339 Self { channel, message }
340 }
341
342 pub(crate) fn into_split(self) -> (u64, Message) {
344 (self.channel, self.message)
345 }
346
347 #[instrument(skip_all, err)]
352 pub(crate) fn decode_open_message(buf: &[u8]) -> io::Result<(Self, usize)> {
353 debug!("Decode ChannelMessage::Open");
354 let og_len = buf.len();
355 if og_len <= 5 {
356 return Err(io::Error::new(
357 io::ErrorKind::UnexpectedEof,
358 "received too short Open message",
359 ));
360 }
361
362 let (open_msg, buf) = Open::decode(buf)?;
363 Ok((
364 Self {
365 channel: open_msg.channel,
366 message: Message::Open(open_msg),
367 },
368 og_len - buf.len(),
369 ))
370 }
371
372 pub(crate) fn decode_close_message(buf: &[u8]) -> io::Result<(Self, usize)> {
377 debug!("Decode ChannelMessage::Close");
378 let og_len = buf.len();
379 if buf.is_empty() {
380 return Err(io::Error::new(
381 io::ErrorKind::UnexpectedEof,
382 "received too short Close message",
383 ));
384 }
385 let (close, buf) = Close::decode(buf)?;
386 Ok((
387 Self {
388 channel: close.channel,
389 message: Message::Close(close),
390 },
391 og_len - buf.len(),
392 ))
393 }
394
395 #[instrument(err, skip_all)]
396 pub(crate) fn decode_from_channel_and_message(
397 buf: &[u8],
398 ) -> Result<(Self, &[u8]), EncodingError> {
399 let (channel, buf) = u64::decode(buf)?;
401 let (message, buf) = <Message as CompactEncoding>::decode(buf)?;
402 debug!(
403 "Decode ChannelMessage{{ channel: {channel}, message: {} }}",
404 message.name()
405 );
406 Ok((Self { channel, message }, buf))
407 }
408 #[instrument(err, skip(buf))]
413 pub(crate) fn decode_with_channel(buf: &[u8], channel: u64) -> io::Result<(Self, &[u8])> {
414 if buf.len() <= 1 {
415 return Err(io::Error::new(
416 io::ErrorKind::UnexpectedEof,
417 format!("received empty message [{buf:?}]"),
418 ));
419 }
420 let (message, buf) = <Message as CompactEncoding>::decode(buf)?;
421 Ok((Self { channel, message }, buf))
422 }
423}
424
425impl CompactEncoding for ChannelMessage {
428 fn encoded_size(&self) -> Result<usize, EncodingError> {
429 let channel_size = if let Message::Open(_) | Message::Close(_) = &self.message {
430 0
431 } else {
432 self.channel.encoded_size()?
433 };
434
435 Ok(channel_size + self.message.encoded_size()?)
436 }
437
438 fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> {
439 let rest = if let Message::Open(_) | Message::Close(_) = &self.message {
440 buffer
441 } else {
442 self.channel.encode(buffer)?
443 };
444 <Message as CompactEncoding>::encode(&self.message, rest)
445 }
446
447 fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError>
448 where
449 Self: Sized,
450 {
451 ChannelMessage::decode_from_channel_and_message(buffer)
452 }
453}
454
455impl VecEncodable for ChannelMessage {
456 #[instrument(skip_all, ret)]
457 fn vec_encoded_size(vec: &[Self]) -> Result<usize, EncodingError>
458 where
459 Self: Sized,
460 {
461 vec_channel_messages_encoded_size(vec)
462 }
463
464 #[instrument(skip_all, err)]
465 fn vec_encode<'a>(vec: &[Self], buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError>
466 where
467 Self: Sized,
468 {
469 let in_buf_len = buffer.len();
470 trace!(
471 "Vec<ChannelMessage>::encode to buf.len() = [{}]",
472 buffer.len()
473 );
474 let mut rest = buffer;
475 match vec {
476 [] => Ok(rest),
477 [msg] => {
478 rest = match msg.message {
479 Message::Open(_) => write_array(&OPEN_MESSAGE_PREFIX, rest)?,
480 Message::Close(_) => write_array(&CLOSE_MESSAGE_PREFIX, rest)?,
481 _ => msg.channel.encode(rest)?,
482 };
483 msg.message.encode(rest)
484 }
485 msgs => {
486 rest = write_array(&MULTI_MESSAGE_PREFIX, rest)?;
487 let mut current_channel: u64 = msgs[0].channel;
488 rest = current_channel.encode(rest)?;
489 for msg in msgs {
490 if msg.channel != current_channel {
491 rest = write_array(&CHANNEL_CHANGE_SEPERATOR, rest)?;
492 rest = msg.channel.encode(rest)?;
493 current_channel = msg.channel;
494 }
495 let msg_len = msg.message.encoded_size()?;
496 rest = (msg_len as u64).encode(rest)?;
497 rest = msg.message.encode(rest)?;
498 }
499 trace!("wrote [{}] bytes to buffer", in_buf_len - rest.len());
500 Ok(rest)
501 }
502 }
503 }
504
505 fn vec_decode(buffer: &[u8]) -> Result<(Vec<Self>, &[u8]), EncodingError>
506 where
507 Self: Sized,
508 {
509 let mut combined_messages: Vec<ChannelMessage> = vec![];
510 let mut rest = buffer;
511 while !rest.is_empty() {
512 let (msgs, length) = decode_unframed_channel_messages(rest)
513 .map_err(|e| EncodingError::external(&format!("{e}")))?;
514 rest = &rest[length..];
515 combined_messages.extend(msgs);
516 }
517 Ok((combined_messages, rest))
518 }
519}
520
#[cfg(test)]
mod tests {
    use super::*;
    use hypercore_schema::{
        DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, RequestSeek, RequestUpgrade,
    };

    /// Puts `message` on a random channel, encodes it into an exactly-sized buffer
    /// and decodes it back. Asserts that every byte is written and read, and that
    /// the round trip is lossless.
    fn assert_roundtrip(message: Message) -> Result<(), EncodingError> {
        let channel = u64::from(rand::random::<u8>());
        let expected = ChannelMessage::new(channel, message);
        let mut encoded = vec![0u8; expected.encoded_size()?];
        let unwritten = <ChannelMessage as CompactEncoding>::encode(&expected, &mut encoded)?;
        assert!(unwritten.is_empty());
        let (decoded, unread) = <ChannelMessage as CompactEncoding>::decode(&encoded)?;
        assert!(unread.is_empty());
        assert_eq!(decoded, expected);
        Ok(())
    }

    #[test]
    fn message_encode_decode() -> Result<(), EncodingError> {
        let messages = [
            Message::Synchronize(Synchronize {
                fork: 0,
                can_upgrade: true,
                downloading: true,
                uploading: true,
                length: 5,
                remote_length: 0,
            }),
            Message::Request(Request {
                id: 1,
                fork: 1,
                block: Some(RequestBlock {
                    index: 5,
                    nodes: 10,
                }),
                hash: Some(RequestBlock {
                    index: 20,
                    nodes: 0,
                }),
                seek: Some(RequestSeek { bytes: 10 }),
                upgrade: Some(RequestUpgrade {
                    start: 0,
                    length: 10,
                }),
                manifest: false,
                priority: 0,
            }),
            Message::Cancel(Cancel { request: 1 }),
            Message::Data(Data {
                request: 1,
                fork: 5,
                block: Some(DataBlock {
                    index: 5,
                    nodes: vec![Node::new(1, vec![0x01; 32], 100)],
                    value: vec![0xFF; 10],
                }),
                hash: Some(DataHash {
                    index: 20,
                    nodes: vec![Node::new(2, vec![0x02; 32], 200)],
                }),
                seek: Some(DataSeek {
                    bytes: 10,
                    nodes: vec![Node::new(3, vec![0x03; 32], 300)],
                }),
                upgrade: Some(DataUpgrade {
                    start: 0,
                    length: 10,
                    nodes: vec![Node::new(4, vec![0x04; 32], 400)],
                    additional_nodes: vec![Node::new(5, vec![0x05; 32], 500)],
                    signature: vec![0xAB; 32],
                }),
            }),
            Message::NoData(NoData { request: 2 }),
            Message::Want(Want {
                start: 0,
                length: 100,
            }),
            Message::Unwant(Unwant {
                start: 10,
                length: 2,
            }),
            Message::Bitfield(Bitfield {
                start: 20,
                bitfield: vec![0x89ABCDEF, 0x00, 0xFFFFFFFF],
            }),
            Message::Range(Range {
                drop: true,
                start: 12345,
                length: 100000,
            }),
            Message::Extension(Extension {
                name: "custom_extension/v1/open".to_string(),
                message: vec![0x44, 20],
            }),
        ];
        for message in messages {
            assert_roundtrip(message)?;
        }
        Ok(())
    }

    #[test]
    fn enc_dec_vec_chan_message() -> Result<(), EncodingError> {
        let sync = Message::Synchronize(Synchronize {
            fork: 0,
            length: 4,
            remote_length: 0,
            downloading: true,
            uploading: true,
            can_upgrade: true,
        });
        let range = Message::Range(Range {
            drop: false,
            start: 0,
            length: 4,
        });
        let batch = vec![ChannelMessage::new(1, sync), ChannelMessage::new(1, range)];
        let encoded = batch.to_encoded_bytes()?;
        let (decoded, leftover) = <Vec<ChannelMessage> as CompactEncoding>::decode(&encoded)?;
        assert!(leftover.is_empty());
        assert_eq!(decoded, batch);
        Ok(())
    }
}