Skip to main content

hypercore_protocol/
message.rs

1use crate::schema::*;
2use compact_encoding::{
3    CompactEncoding, EncodingError, EncodingErrorKind, VecEncodable, decode_usize, take_array,
4    write_array,
5};
6use pretty_hash::fmt as pretty_fmt;
7use std::{fmt, io};
8use tracing::{debug, instrument, trace, warn};
9
/// Two-byte prefix marking a payload that holds a single `Open` message.
const OPEN_MESSAGE_PREFIX: [u8; 2] = [0x00, 0x01];
/// Two-byte prefix marking a payload that holds a single `Close` message.
const CLOSE_MESSAGE_PREFIX: [u8; 2] = [0x00, 0x03];
/// Two-byte prefix marking a batch of length-delimited channel messages.
const MULTI_MESSAGE_PREFIX: [u8; 2] = [0x00, 0x00];
/// Byte written inside a batch, just before a new channel id, when the channel
/// changes. (The identifier keeps its historical spelling; other code uses it.)
const CHANNEL_CHANGE_SEPERATOR: [u8; 1] = [0x00];
14
15#[instrument(skip_all err)]
16pub(crate) fn decode_unframed_channel_messages(
17    buf: &[u8],
18) -> Result<(Vec<ChannelMessage>, usize), io::Error> {
19    let og_len = buf.len();
20    if og_len >= 3 && buf[0] == 0x00 {
21        // batch of NOT open/close messages
22        if buf[1] == 0x00 {
23            let (_, mut buf) = take_array::<2>(buf)?;
24            // Batch of messages
25            let mut messages: Vec<ChannelMessage> = vec![];
26
27            // First, there is the original channel
28            let mut current_channel;
29            (current_channel, buf) = u64::decode(buf)?;
30            while !buf.is_empty() {
31                // Length of the message is inbetween here
32                let channel_message_length;
33                (channel_message_length, buf) = decode_usize(buf)?;
34                if channel_message_length > buf.len() {
35                    return Err(io::Error::new(
36                        io::ErrorKind::InvalidData,
37                        format!(
38                            "received invalid message length: [{channel_message_length}]
39\tbut we have [{}] remaining bytes.
40\tInitial buffer size [{og_len}]",
41                            buf.len()
42                        ),
43                    ));
44                }
45                // Then the actual message
46                let channel_message;
47                let bl = buf.len();
48                (channel_message, buf) = ChannelMessage::decode_with_channel(buf, current_channel)?;
49                trace!(
50                    "Decoded ChannelMessage::{:?} using [{} bytes]",
51                    channel_message.message,
52                    bl - buf.len()
53                );
54                messages.push(channel_message);
55                // After that, if there is an extra 0x00, that means the channel
56                // changed. This works because of LE encoding, and channels starting
57                // from the index 1.
58                if !buf.is_empty() && buf[0] == 0x00 {
59                    (current_channel, buf) = u64::decode(buf)?;
60                }
61            }
62            Ok((messages, og_len - buf.len()))
63        } else if buf[1] == 0x01 {
64            // Open message
65            let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?;
66            Ok((vec![channel_message], length + 2))
67        } else if buf[1] == 0x03 {
68            // Close message
69            let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?;
70            Ok((vec![channel_message], length + 2))
71        } else {
72            Err(io::Error::new(
73                io::ErrorKind::InvalidData,
74                "received invalid special message",
75            ))
76        }
77    } else if buf.len() >= 2 {
78        trace!("Decoding single ChannelMessage");
79        // Single message
80        let og_len = buf.len();
81        let (channel_message, buf) = ChannelMessage::decode_from_channel_and_message(buf)?;
82        Ok((vec![channel_message], og_len - buf.len()))
83    } else {
84        Err(io::Error::new(
85            io::ErrorKind::InvalidData,
86            format!("received too short message, {buf:?}"),
87        ))
88    }
89}
90
91fn vec_channel_messages_encoded_size(messages: &[ChannelMessage]) -> Result<usize, EncodingError> {
92    Ok(match messages {
93        [] => 0,
94        [msg] => match msg.message {
95            Message::Open(_) | Message::Close(_) => 2 + msg.encoded_size()?,
96            _ => msg.encoded_size()?,
97        },
98        msgs => {
99            let mut out = MULTI_MESSAGE_PREFIX.len();
100            let mut current_channel: u64 = messages[0].channel;
101            out += current_channel.encoded_size()?;
102            for message in msgs.iter() {
103                if message.channel != current_channel {
104                    // Channel changed, need to add a 0x00 in between and then the new
105                    // channel
106                    out += CHANNEL_CHANGE_SEPERATOR.len() + message.channel.encoded_size()?;
107                    current_channel = message.channel;
108                }
109                let message_length = message.message.encoded_size()?;
110                out += message_length + (message_length as u64).encoded_size()?;
111            }
112            out
113        }
114    })
115}
116
117/// A protocol message.
118#[derive(Debug, Clone, PartialEq)]
119#[expect(missing_docs)]
120pub enum Message {
121    Open(Open),
122    Close(Close),
123    Synchronize(Synchronize),
124    Request(Request),
125    Cancel(Cancel),
126    Data(Data),
127    NoData(NoData),
128    Want(Want),
129    Unwant(Unwant),
130    Bitfield(Bitfield),
131    Range(Range),
132    Extension(Extension),
133    /// A local signalling message never sent over the wire
134    LocalSignal((String, Vec<u8>)),
135}
136
137macro_rules! message_from {
138    ($($val:ident),+) => {
139        $(
140            impl From<$val> for Message {
141                fn from(value: $val) -> Self {
142                    Message::$val(value)
143                }
144            }
145        )*
146    }
147}
148message_from!(
149    Open,
150    Close,
151    Synchronize,
152    Request,
153    Cancel,
154    Data,
155    NoData,
156    Want,
157    Unwant,
158    Bitfield,
159    Range,
160    Extension
161);
162
163macro_rules! decode_message {
164    ($type:ty, $buf:expr) => {{
165        let (x, rest) = <$type>::decode($buf)?;
166        (Message::from(x), rest)
167    }};
168}
169
170impl CompactEncoding for Message {
171    fn encoded_size(&self) -> Result<usize, EncodingError> {
172        let typ_size = if let Self::Open(_) | Self::Close(_) = &self {
173            0
174        } else {
175            self.typ().encoded_size()?
176        };
177        let msg_size = match self {
178            Self::LocalSignal(_) => Ok(0),
179            Self::Open(x) => x.encoded_size(),
180            Self::Close(x) => x.encoded_size(),
181            Self::Synchronize(x) => x.encoded_size(),
182            Self::Request(x) => x.encoded_size(),
183            Self::Cancel(x) => x.encoded_size(),
184            Self::Data(x) => x.encoded_size(),
185            Self::NoData(x) => x.encoded_size(),
186            Self::Want(x) => x.encoded_size(),
187            Self::Unwant(x) => x.encoded_size(),
188            Self::Bitfield(x) => x.encoded_size(),
189            Self::Range(x) => x.encoded_size(),
190            Self::Extension(x) => x.encoded_size(),
191        }?;
192        Ok(typ_size + msg_size)
193    }
194
195    #[instrument(skip_all, fields(name = self.name()))]
196    fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> {
197        debug!("Encoding {self:?}");
198        let rest = if let Self::Open(_) | Self::Close(_) = &self {
199            buffer
200        } else {
201            self.typ().encode(buffer)?
202        };
203        match self {
204            Self::Open(x) => x.encode(rest),
205            Self::Close(x) => x.encode(rest),
206            Self::Synchronize(x) => x.encode(rest),
207            Self::Request(x) => x.encode(rest),
208            Self::Cancel(x) => x.encode(rest),
209            Self::Data(x) => x.encode(rest),
210            Self::NoData(x) => x.encode(rest),
211            Self::Want(x) => x.encode(rest),
212            Self::Unwant(x) => x.encode(rest),
213            Self::Bitfield(x) => x.encode(rest),
214            Self::Range(x) => x.encode(rest),
215            Self::Extension(x) => x.encode(rest),
216            Self::LocalSignal(_) => unimplemented!("do not encode LocalSignal"),
217        }
218    }
219
220    fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError>
221    where
222        Self: Sized,
223    {
224        let (typ, rest) = u64::decode(buffer)?;
225        Ok(match typ {
226            0 => decode_message!(Synchronize, rest),
227            1 => decode_message!(Request, rest),
228            2 => decode_message!(Cancel, rest),
229            3 => decode_message!(Data, rest),
230            4 => decode_message!(NoData, rest),
231            5 => decode_message!(Want, rest),
232            6 => decode_message!(Unwant, rest),
233            7 => decode_message!(Bitfield, rest),
234            8 => decode_message!(Range, rest),
235            9 => decode_message!(Extension, rest),
236            _ => {
237                return Err(EncodingError::new(
238                    EncodingErrorKind::InvalidData,
239                    &format!("Invalid message type to decode: {typ}"),
240                ));
241            }
242        })
243    }
244}
245impl Message {
246    /// Wire type of this message.
247    pub(crate) fn typ(&self) -> u64 {
248        match self {
249            Self::Synchronize(_) => 0,
250            Self::Request(_) => 1,
251            Self::Cancel(_) => 2,
252            Self::Data(_) => 3,
253            Self::NoData(_) => 4,
254            Self::Want(_) => 5,
255            Self::Unwant(_) => 6,
256            Self::Bitfield(_) => 7,
257            Self::Range(_) => 8,
258            Self::Extension(_) => 9,
259            value => unimplemented!("{} does not have a type", value),
260        }
261    }
262    /// Get the name of the message
263    pub fn name(&self) -> &'static str {
264        match self {
265            Message::Open(_) => "Open",
266            Message::Close(_) => "Close",
267            Message::Synchronize(_) => "Synchronize",
268            Message::Request(_) => "Request",
269            Message::Cancel(_) => "Cancel",
270            Message::Data(_) => "Data",
271            Message::NoData(_) => "NoData",
272            Message::Want(_) => "Want",
273            Message::Unwant(_) => "Unwant",
274            Message::Bitfield(_) => "Bitfield",
275            Message::Range(_) => "Range",
276            Message::Extension(_) => "Extension",
277            Message::LocalSignal(_) => "LocalSignal",
278        }
279    }
280}
281
282impl fmt::Display for Message {
283    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
284        match self {
285            Self::Open(msg) => write!(
286                f,
287                "Open(discovery_key: {}, capability <{}>)",
288                pretty_fmt(&msg.discovery_key).unwrap(),
289                msg.capability.as_ref().map_or(0, |c| c.len())
290            ),
291            Self::Data(msg) => write!(
292                f,
293                "Data(request: {}, fork: {}, block: {}, hash: {}, seek: {}, upgrade: {})",
294                msg.request,
295                msg.fork,
296                msg.block.is_some(),
297                msg.hash.is_some(),
298                msg.seek.is_some(),
299                msg.upgrade.is_some(),
300            ),
301            _ => write!(f, "{:?}", &self),
302        }
303    }
304}
305
306/// A message on a channel.
307#[derive(Clone)]
308pub(crate) struct ChannelMessage {
309    pub(crate) channel: u64,
310    pub(crate) message: Message,
311}
312
313impl PartialEq for ChannelMessage {
314    fn eq(&self, other: &Self) -> bool {
315        self.channel == other.channel && self.message == other.message
316    }
317}
318
319impl fmt::Debug for ChannelMessage {
320    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
321        write!(f, "ChannelMessage({}, {})", self.channel, self.message)
322    }
323}
324
325impl fmt::Display for ChannelMessage {
326    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
327        write!(
328            f,
329            "ChannelMessage {{ channel {}, message {} }}",
330            self.channel,
331            self.message.name()
332        )
333    }
334}
335
336impl ChannelMessage {
337    /// Create a new message.
338    pub(crate) fn new(channel: u64, message: Message) -> Self {
339        Self { channel, message }
340    }
341
342    /// Consume self and return (channel, Message).
343    pub(crate) fn into_split(self) -> (u64, Message) {
344        (self.channel, self.message)
345    }
346
347    /// Decodes an open message for a channel message from a buffer.
348    ///
349    /// Note: `buf` has to have a valid length, and without the 3 LE
350    /// bytes in it
351    #[instrument(skip_all, err)]
352    pub(crate) fn decode_open_message(buf: &[u8]) -> io::Result<(Self, usize)> {
353        debug!("Decode ChannelMessage::Open");
354        let og_len = buf.len();
355        if og_len <= 5 {
356            return Err(io::Error::new(
357                io::ErrorKind::UnexpectedEof,
358                "received too short Open message",
359            ));
360        }
361
362        let (open_msg, buf) = Open::decode(buf)?;
363        Ok((
364            Self {
365                channel: open_msg.channel,
366                message: Message::Open(open_msg),
367            },
368            og_len - buf.len(),
369        ))
370    }
371
372    /// Decodes a close message for a channel message from a buffer.
373    ///
374    /// Note: `buf` has to have a valid length, and without the 3 LE
375    /// bytes in it
376    pub(crate) fn decode_close_message(buf: &[u8]) -> io::Result<(Self, usize)> {
377        debug!("Decode ChannelMessage::Close");
378        let og_len = buf.len();
379        if buf.is_empty() {
380            return Err(io::Error::new(
381                io::ErrorKind::UnexpectedEof,
382                "received too short Close message",
383            ));
384        }
385        let (close, buf) = Close::decode(buf)?;
386        Ok((
387            Self {
388                channel: close.channel,
389                message: Message::Close(close),
390            },
391            og_len - buf.len(),
392        ))
393    }
394
395    #[instrument(err, skip_all)]
396    pub(crate) fn decode_from_channel_and_message(
397        buf: &[u8],
398    ) -> Result<(Self, &[u8]), EncodingError> {
399        //<ChannelMessage as CompactEncoding>::decode(buf)
400        let (channel, buf) = u64::decode(buf)?;
401        let (message, buf) = <Message as CompactEncoding>::decode(buf)?;
402        debug!(
403            "Decode ChannelMessage{{ channel: {channel}, message: {} }}",
404            message.name()
405        );
406        Ok((Self { channel, message }, buf))
407    }
408    /// Decode a normal channel message from a buffer.
409    ///
410    /// Note: `buf` has to have a valid length, and without the 3 LE
411    /// bytes in it
412    #[instrument(err, skip(buf))]
413    pub(crate) fn decode_with_channel(buf: &[u8], channel: u64) -> io::Result<(Self, &[u8])> {
414        if buf.len() <= 1 {
415            return Err(io::Error::new(
416                io::ErrorKind::UnexpectedEof,
417                format!("received empty message [{buf:?}]"),
418            ));
419        }
420        let (message, buf) = <Message as CompactEncoding>::decode(buf)?;
421        Ok((Self { channel, message }, buf))
422    }
423}
424
425/// NB: currently this is just for a standalone channel message. ChannelMessages in a vec decode &
426/// encode differently
427impl CompactEncoding for ChannelMessage {
428    fn encoded_size(&self) -> Result<usize, EncodingError> {
429        let channel_size = if let Message::Open(_) | Message::Close(_) = &self.message {
430            0
431        } else {
432            self.channel.encoded_size()?
433        };
434
435        Ok(channel_size + self.message.encoded_size()?)
436    }
437
438    fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> {
439        let rest = if let Message::Open(_) | Message::Close(_) = &self.message {
440            buffer
441        } else {
442            self.channel.encode(buffer)?
443        };
444        <Message as CompactEncoding>::encode(&self.message, rest)
445    }
446
447    fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError>
448    where
449        Self: Sized,
450    {
451        ChannelMessage::decode_from_channel_and_message(buffer)
452    }
453}
454
455impl VecEncodable for ChannelMessage {
456    #[instrument(skip_all, ret)]
457    fn vec_encoded_size(vec: &[Self]) -> Result<usize, EncodingError>
458    where
459        Self: Sized,
460    {
461        vec_channel_messages_encoded_size(vec)
462    }
463
464    #[instrument(skip_all, err)]
465    fn vec_encode<'a>(vec: &[Self], buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError>
466    where
467        Self: Sized,
468    {
469        let in_buf_len = buffer.len();
470        trace!(
471            "Vec<ChannelMessage>::encode to buf.len() = [{}]",
472            buffer.len()
473        );
474        let mut rest = buffer;
475        match vec {
476            [] => Ok(rest),
477            [msg] => {
478                rest = match msg.message {
479                    Message::Open(_) => write_array(&OPEN_MESSAGE_PREFIX, rest)?,
480                    Message::Close(_) => write_array(&CLOSE_MESSAGE_PREFIX, rest)?,
481                    _ => msg.channel.encode(rest)?,
482                };
483                msg.message.encode(rest)
484            }
485            msgs => {
486                rest = write_array(&MULTI_MESSAGE_PREFIX, rest)?;
487                let mut current_channel: u64 = msgs[0].channel;
488                rest = current_channel.encode(rest)?;
489                for msg in msgs {
490                    if msg.channel != current_channel {
491                        rest = write_array(&CHANNEL_CHANGE_SEPERATOR, rest)?;
492                        rest = msg.channel.encode(rest)?;
493                        current_channel = msg.channel;
494                    }
495                    let msg_len = msg.message.encoded_size()?;
496                    rest = (msg_len as u64).encode(rest)?;
497                    rest = msg.message.encode(rest)?;
498                }
499                trace!("wrote [{}] bytes to buffer", in_buf_len - rest.len());
500                Ok(rest)
501            }
502        }
503    }
504
505    fn vec_decode(buffer: &[u8]) -> Result<(Vec<Self>, &[u8]), EncodingError>
506    where
507        Self: Sized,
508    {
509        let mut combined_messages: Vec<ChannelMessage> = vec![];
510        let mut rest = buffer;
511        while !rest.is_empty() {
512            let (msgs, length) = decode_unframed_channel_messages(rest)
513                .map_err(|e| EncodingError::external(&format!("{e}")))?;
514            rest = &rest[length..];
515            combined_messages.extend(msgs);
516        }
517        Ok((combined_messages, rest))
518    }
519}
520
#[cfg(test)]
mod tests {
    use super::*;
    use hypercore_schema::{
        DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, RequestSeek, RequestUpgrade,
    };

    // Round-trips each given `Message` through the standalone `ChannelMessage`
    // encoding on a random channel, checking that the buffer is fully used in
    // both directions and that the decoded value equals the input.
    macro_rules! message_enc_dec {
        ($( $msg:expr ),*) => {
            $(
                let channel = rand::random::<u8>() as u64;
                let channel_message = ChannelMessage::new(channel, $msg);
                let encoded_size = channel_message.encoded_size()?;
                let mut buf = vec![0u8; encoded_size];
                let rest = <ChannelMessage as CompactEncoding>::encode(&channel_message, &mut buf)?;
                assert!(rest.is_empty());
                let (decoded, rest) = <ChannelMessage as CompactEncoding>::decode(&buf)?;
                assert!(rest.is_empty());
                assert_eq!(decoded, channel_message);
            )*
        }
    }

    #[test]
    fn message_encode_decode() -> Result<(), EncodingError> {
        message_enc_dec! {
            Message::Synchronize(Synchronize{
                fork: 0,
                can_upgrade: true,
                downloading: true,
                uploading: true,
                length: 5,
                remote_length: 0,
            }),
            Message::Request(Request {
                id: 1,
                fork: 1,
                block: Some(RequestBlock {
                    index: 5,
                    nodes: 10,
                }),
                hash: Some(RequestBlock {
                    index: 20,
                    nodes: 0
                }),
                seek: Some(RequestSeek {
                    bytes: 10
                }),
                upgrade: Some(RequestUpgrade {
                    start: 0,
                    length: 10
                }),
                manifest: false,
                priority: 0
            }),
            Message::Cancel(Cancel {
                request: 1,
            }),
            Message::Data(Data{
                request: 1,
                fork: 5,
                block: Some(DataBlock {
                    index: 5,
                    nodes: vec![Node::new(1, vec![0x01; 32], 100)],
                    value: vec![0xFF; 10]
                }),
                hash: Some(DataHash {
                    index: 20,
                    nodes: vec![Node::new(2, vec![0x02; 32], 200)],
                }),
                seek: Some(DataSeek {
                    bytes: 10,
                    nodes: vec![Node::new(3, vec![0x03; 32], 300)],
                }),
                upgrade: Some(DataUpgrade {
                    start: 0,
                    length: 10,
                    nodes: vec![Node::new(4, vec![0x04; 32], 400)],
                    additional_nodes: vec![Node::new(5, vec![0x05; 32], 500)],
                    signature: vec![0xAB; 32]
                })
            }),
            Message::NoData(NoData {
                request: 2,
            }),
            Message::Want(Want {
                start: 0,
                length: 100,
            }),
            Message::Unwant(Unwant {
                start: 10,
                length: 2,
            }),
            Message::Bitfield(Bitfield {
                start: 20,
                bitfield: vec![0x89ABCDEF, 0x00, 0xFFFFFFFF],
            }),
            Message::Range(Range {
                drop: true,
                start: 12345,
                length: 100000
            }),
            Message::Extension(Extension {
                name: "custom_extension/v1/open".to_string(),
                message: vec![0x44, 20]
            })
        };
        Ok(())
    }

    /// Several messages on one channel use the batch layout.
    #[test]
    fn enc_dec_vec_chan_message() -> Result<(), EncodingError> {
        let one = Message::Synchronize(Synchronize {
            fork: 0,
            length: 4,
            remote_length: 0,
            downloading: true,
            uploading: true,
            can_upgrade: true,
        });
        let two = Message::Range(Range {
            drop: false,
            start: 0,
            length: 4,
        });
        let msgs = vec![ChannelMessage::new(1, one), ChannelMessage::new(1, two)];
        let buff = msgs.to_encoded_bytes()?;
        let (result, rest) = <Vec<ChannelMessage> as CompactEncoding>::decode(&buff)?;
        assert!(rest.is_empty());
        assert_eq!(result, msgs);
        Ok(())
    }

    /// A one-element vec uses the unprefixed `channel ++ message` layout.
    #[test]
    fn enc_dec_vec_single_chan_message() -> Result<(), EncodingError> {
        let msgs = vec![ChannelMessage::new(
            1,
            Message::Want(Want {
                start: 0,
                length: 100,
            }),
        )];
        let buff = msgs.to_encoded_bytes()?;
        assert_eq!(buff.len(), ChannelMessage::vec_encoded_size(&msgs)?);
        let (result, rest) = <Vec<ChannelMessage> as CompactEncoding>::decode(&buff)?;
        assert!(rest.is_empty());
        assert_eq!(result, msgs);
        Ok(())
    }

    /// An empty vec encodes to no bytes and decodes back to an empty vec.
    #[test]
    fn enc_dec_empty_vec_chan_message() -> Result<(), EncodingError> {
        let msgs: Vec<ChannelMessage> = vec![];
        let buff = msgs.to_encoded_bytes()?;
        assert!(buff.is_empty());
        let (result, rest) = <Vec<ChannelMessage> as CompactEncoding>::decode(&buff)?;
        assert!(rest.is_empty());
        assert!(result.is_empty());
        Ok(())
    }

    /// `Display` names the message variant rather than printing its contents.
    #[test]
    fn channel_message_display_names_the_message() {
        let msg = ChannelMessage::new(3, Message::Cancel(Cancel { request: 1 }));
        assert_eq!(
            msg.to_string(),
            "ChannelMessage { channel 3, message Cancel }"
        );
    }
}