1use crate::schema::*;
2use crate::util::{stat_uint24_le, write_uint24_le};
3use hypercore::encoding::{
4 CompactEncoding, EncodingError, EncodingErrorKind, HypercoreState, State,
5};
6use pretty_hash::fmt as pretty_fmt;
7use std::fmt;
8use std::io;
9
/// How the body of an incoming frame is interpreted when decoding.
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum FrameType {
    /// The body is kept as opaque bytes.
    Raw,
    /// The body is decoded into channel messages.
    Message,
}
16
17pub(crate) trait Encoder: Sized + fmt::Debug {
22 fn encoded_len(&mut self) -> Result<usize, EncodingError>;
24
25 fn encode(&mut self, buf: &mut [u8]) -> Result<usize, EncodingError>;
29}
30
31impl Encoder for &[u8] {
32 fn encoded_len(&mut self) -> Result<usize, EncodingError> {
33 Ok(self.len())
34 }
35
36 fn encode(&mut self, buf: &mut [u8]) -> Result<usize, EncodingError> {
37 let len = self.encoded_len()?;
38 if len > buf.len() {
39 return Err(EncodingError::new(
40 EncodingErrorKind::Overflow,
41 &format!("Length does not fit buffer, {} > {}", len, buf.len()),
42 ));
43 }
44 buf[..len].copy_from_slice(&self[..]);
45 Ok(len)
46 }
47}
48
49#[derive(Clone, PartialEq)]
51pub(crate) enum Frame {
52 RawBatch(Vec<Vec<u8>>),
54 MessageBatch(Vec<ChannelMessage>),
56}
57
58impl fmt::Debug for Frame {
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 match self {
61 Frame::RawBatch(batch) => write!(f, "Frame(RawBatch <{}>)", batch.len()),
62 Frame::MessageBatch(messages) => write!(f, "Frame({messages:?})"),
63 }
64 }
65}
66
67impl From<ChannelMessage> for Frame {
68 fn from(m: ChannelMessage) -> Self {
69 Self::MessageBatch(vec![m])
70 }
71}
72
73impl From<Vec<u8>> for Frame {
74 fn from(m: Vec<u8>) -> Self {
75 Self::RawBatch(vec![m])
76 }
77}
78
79impl Frame {
80 pub(crate) fn decode_multiple(buf: &[u8], frame_type: &FrameType) -> Result<Self, io::Error> {
82 match frame_type {
83 FrameType::Raw => {
84 let mut index = 0;
85 let mut raw_batch: Vec<Vec<u8>> = vec![];
86 while index < buf.len() {
87 if buf[index] == 0 {
90 index += 1;
91 continue;
92 }
93 let stat = stat_uint24_le(&buf[index..]);
94 if let Some((header_len, body_len)) = stat {
95 raw_batch.push(
96 buf[index + header_len..index + header_len + body_len as usize]
97 .to_vec(),
98 );
99 index += header_len + body_len as usize;
100 } else {
101 return Err(io::Error::new(
102 io::ErrorKind::InvalidData,
103 "received invalid data in raw batch",
104 ));
105 }
106 }
107 Ok(Frame::RawBatch(raw_batch))
108 }
109 FrameType::Message => {
110 let mut index = 0;
111 let mut combined_messages: Vec<ChannelMessage> = vec![];
112 while index < buf.len() {
113 if buf[index] == 0 {
116 index += 1;
117 continue;
118 }
119
120 let stat = stat_uint24_le(&buf[index..]);
121 if let Some((header_len, body_len)) = stat {
122 let (frame, length) = Self::decode_message(
123 &buf[index + header_len..index + header_len + body_len as usize],
124 )?;
125 if length != body_len as usize {
126 tracing::warn!(
127 "Did not know what to do with all the bytes, got {} but decoded {}. \
128 This may be because the peer implements a newer protocol version \
129 that has extra fields.",
130 body_len,
131 length
132 );
133 }
134 if let Frame::MessageBatch(messages) = frame {
135 for message in messages {
136 combined_messages.push(message);
137 }
138 } else {
139 unreachable!("Can not get Raw messages");
140 }
141 index += header_len + body_len as usize;
142 } else {
143 return Err(io::Error::new(
144 io::ErrorKind::InvalidData,
145 "received invalid data in multi-message chunk",
146 ));
147 }
148 }
149 Ok(Frame::MessageBatch(combined_messages))
150 }
151 }
152 }
153
154 pub(crate) fn decode(buf: &[u8], frame_type: &FrameType) -> Result<Self, io::Error> {
156 match frame_type {
157 FrameType::Raw => Ok(Frame::RawBatch(vec![buf.to_vec()])),
158 FrameType::Message => {
159 let (frame, _) = Self::decode_message(buf)?;
160 Ok(frame)
161 }
162 }
163 }
164
165 fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> {
166 if buf.len() >= 3 && buf[0] == 0x00 {
167 if buf[1] == 0x00 {
168 let mut messages: Vec<ChannelMessage> = vec![];
170 let mut state = State::new_with_start_and_end(2, buf.len());
171
172 let mut current_channel: u64 = state.decode(buf)?;
174 while state.start() < state.end() {
175 let channel_message_length: usize = state.decode(buf)?;
177 if state.start() + channel_message_length > state.end() {
178 return Err(io::Error::new(
179 io::ErrorKind::InvalidData,
180 format!(
181 "received invalid message length, {} + {} > {}",
182 state.start(),
183 channel_message_length,
184 state.end()
185 ),
186 ));
187 }
188 let (channel_message, _) = ChannelMessage::decode(
190 &buf[state.start()..state.start() + channel_message_length],
191 current_channel,
192 )?;
193 messages.push(channel_message);
194 state.add_start(channel_message_length)?;
195 if state.start() < state.end() && buf[state.start()] == 0x00 {
199 state.add_start(1)?;
200 current_channel = state.decode(buf)?;
201 }
202 }
203 Ok((Frame::MessageBatch(messages), state.start()))
204 } else if buf[1] == 0x01 {
205 let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?;
207 Ok((Frame::MessageBatch(vec![channel_message]), length + 2))
208 } else if buf[1] == 0x03 {
209 let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?;
211 Ok((Frame::MessageBatch(vec![channel_message]), length + 2))
212 } else {
213 Err(io::Error::new(
214 io::ErrorKind::InvalidData,
215 "received invalid special message",
216 ))
217 }
218 } else if buf.len() >= 2 {
219 let mut state = State::from_buffer(buf);
221 let channel: u64 = state.decode(buf)?;
222 let (channel_message, length) = ChannelMessage::decode(&buf[state.start()..], channel)?;
223 Ok((
224 Frame::MessageBatch(vec![channel_message]),
225 state.start() + length,
226 ))
227 } else {
228 Err(io::Error::new(
229 io::ErrorKind::InvalidData,
230 format!("received too short message, {buf:02X?}"),
231 ))
232 }
233 }
234
235 fn preencode(&mut self, state: &mut State) -> Result<usize, EncodingError> {
236 match self {
237 Self::RawBatch(raw_batch) => {
238 for raw in raw_batch {
239 state.add_end(raw.as_slice().encoded_len()?)?;
240 }
241 }
242 #[allow(clippy::comparison_chain)]
243 Self::MessageBatch(messages) => {
244 if messages.len() == 1 {
245 if let Message::Open(_) = &messages[0].message {
246 state.add_end(2 + &messages[0].encoded_len()?)?;
248 } else if let Message::Close(_) = &messages[0].message {
249 state.add_end(2 + &messages[0].encoded_len()?)?;
251 } else {
252 (*state).preencode(&messages[0].channel)?;
253 state.add_end(messages[0].encoded_len()?)?;
254 }
255 } else if messages.len() > 1 {
256 state.add_end(2)?;
258 let mut current_channel: u64 = messages[0].channel;
259 state.preencode(¤t_channel)?;
260 for message in messages.iter_mut() {
261 if message.channel != current_channel {
262 state.add_end(1)?;
265 state.preencode(&message.channel)?;
266 current_channel = message.channel;
267 }
268 let message_length = message.encoded_len()?;
269 state.preencode(&message_length)?;
270 state.add_end(message_length)?;
271 }
272 }
273 }
274 }
275 Ok(state.end())
276 }
277}
278
279impl Encoder for Frame {
280 fn encoded_len(&mut self) -> Result<usize, EncodingError> {
281 let body_len = self.preencode(&mut State::new())?;
282 match self {
283 Self::RawBatch(_) => Ok(body_len),
284 Self::MessageBatch(_) => Ok(3 + body_len),
285 }
286 }
287
288 fn encode(&mut self, buf: &mut [u8]) -> Result<usize, EncodingError> {
289 let mut state = State::new();
290 let header_len = if let Self::RawBatch(_) = self { 0 } else { 3 };
291 let body_len = self.preencode(&mut state)?;
292 let len = body_len + header_len;
293 if buf.len() < len {
294 return Err(EncodingError::new(
295 EncodingErrorKind::Overflow,
296 &format!("Length does not fit buffer, {} > {}", len, buf.len()),
297 ));
298 }
299 match self {
300 Self::RawBatch(ref raw_batch) => {
301 for raw in raw_batch {
302 raw.as_slice().encode(buf)?;
303 }
304 }
305 #[allow(clippy::comparison_chain)]
306 Self::MessageBatch(ref mut messages) => {
307 write_uint24_le(body_len, buf);
308 let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes");
309 if messages.len() == 1 {
310 if let Message::Open(_) = &messages[0].message {
311 state.encode(&(0_u8), buf)?;
313 state.encode(&(1_u8), buf)?;
314 state.add_start(messages[0].encode(&mut buf[state.start()..])?)?;
315 } else if let Message::Close(_) = &messages[0].message {
316 state.encode(&(0_u8), buf)?;
318 state.encode(&(3_u8), buf)?;
319 state.add_start(messages[0].encode(&mut buf[state.start()..])?)?;
320 } else {
321 state.encode(&messages[0].channel, buf)?;
322 state.add_start(messages[0].encode(&mut buf[state.start()..])?)?;
323 }
324 } else if messages.len() > 1 {
325 state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?;
327 let mut current_channel: u64 = messages[0].channel;
328 state.encode(¤t_channel, buf)?;
329 for message in messages.iter_mut() {
330 if message.channel != current_channel {
331 state.encode(&(0_u8), buf)?;
334 state.encode(&message.channel, buf)?;
335 current_channel = message.channel;
336 }
337 let message_length = message.encoded_len()?;
338 state.encode(&message_length, buf)?;
339 state.add_start(message.encode(&mut buf[state.start()..])?)?;
340 }
341 }
342 }
343 };
344 Ok(len)
345 }
346}
347
348#[derive(Debug, Clone, PartialEq)]
350#[allow(missing_docs)]
351pub enum Message {
352 Open(Open),
353 Close(Close),
354 Synchronize(Synchronize),
355 Request(Request),
356 Cancel(Cancel),
357 Data(Data),
358 NoData(NoData),
359 Want(Want),
360 Unwant(Unwant),
361 Bitfield(Bitfield),
362 Range(Range),
363 Extension(Extension),
364 LocalSignal((String, Vec<u8>)),
366}
367
368impl Message {
369 pub(crate) fn typ(&self) -> u64 {
371 match self {
372 Self::Synchronize(_) => 0,
373 Self::Request(_) => 1,
374 Self::Cancel(_) => 2,
375 Self::Data(_) => 3,
376 Self::NoData(_) => 4,
377 Self::Want(_) => 5,
378 Self::Unwant(_) => 6,
379 Self::Bitfield(_) => 7,
380 Self::Range(_) => 8,
381 Self::Extension(_) => 9,
382 value => unimplemented!("{} does not have a type", value),
383 }
384 }
385
386 pub(crate) fn decode(buf: &[u8], typ: u64) -> Result<(Self, usize), EncodingError> {
388 let mut state = HypercoreState::from_buffer(buf);
389 let message = match typ {
390 0 => Ok(Self::Synchronize((*state).decode(buf)?)),
391 1 => Ok(Self::Request(state.decode(buf)?)),
392 2 => Ok(Self::Cancel((*state).decode(buf)?)),
393 3 => Ok(Self::Data(state.decode(buf)?)),
394 4 => Ok(Self::NoData((*state).decode(buf)?)),
395 5 => Ok(Self::Want((*state).decode(buf)?)),
396 6 => Ok(Self::Unwant((*state).decode(buf)?)),
397 7 => Ok(Self::Bitfield((*state).decode(buf)?)),
398 8 => Ok(Self::Range((*state).decode(buf)?)),
399 9 => Ok(Self::Extension((*state).decode(buf)?)),
400 _ => Err(EncodingError::new(
401 EncodingErrorKind::InvalidData,
402 &format!("Invalid message type to decode: {typ}"),
403 )),
404 }?;
405 Ok((message, state.start()))
406 }
407
408 pub(crate) fn preencode(&self, state: &mut HypercoreState) -> Result<usize, EncodingError> {
410 match self {
411 Self::Open(ref message) => state.0.preencode(message)?,
412 Self::Close(ref message) => state.0.preencode(message)?,
413 Self::Synchronize(ref message) => state.0.preencode(message)?,
414 Self::Request(ref message) => state.preencode(message)?,
415 Self::Cancel(ref message) => state.0.preencode(message)?,
416 Self::Data(ref message) => state.preencode(message)?,
417 Self::NoData(ref message) => state.0.preencode(message)?,
418 Self::Want(ref message) => state.0.preencode(message)?,
419 Self::Unwant(ref message) => state.0.preencode(message)?,
420 Self::Bitfield(ref message) => state.0.preencode(message)?,
421 Self::Range(ref message) => state.0.preencode(message)?,
422 Self::Extension(ref message) => state.0.preencode(message)?,
423 Self::LocalSignal(_) => 0,
424 };
425 Ok(state.end())
426 }
427
428 pub(crate) fn encode(
430 &self,
431 state: &mut HypercoreState,
432 buf: &mut [u8],
433 ) -> Result<usize, EncodingError> {
434 match self {
435 Self::Open(ref message) => state.0.encode(message, buf)?,
436 Self::Close(ref message) => state.0.encode(message, buf)?,
437 Self::Synchronize(ref message) => state.0.encode(message, buf)?,
438 Self::Request(ref message) => state.encode(message, buf)?,
439 Self::Cancel(ref message) => state.0.encode(message, buf)?,
440 Self::Data(ref message) => state.encode(message, buf)?,
441 Self::NoData(ref message) => state.0.encode(message, buf)?,
442 Self::Want(ref message) => state.0.encode(message, buf)?,
443 Self::Unwant(ref message) => state.0.encode(message, buf)?,
444 Self::Bitfield(ref message) => state.0.encode(message, buf)?,
445 Self::Range(ref message) => state.0.encode(message, buf)?,
446 Self::Extension(ref message) => state.0.encode(message, buf)?,
447 Self::LocalSignal(_) => 0,
448 };
449 Ok(state.start())
450 }
451}
452
453impl fmt::Display for Message {
454 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
455 match self {
456 Self::Open(msg) => write!(
457 f,
458 "Open(discovery_key: {}, capability <{}>)",
459 pretty_fmt(&msg.discovery_key).unwrap(),
460 msg.capability.as_ref().map_or(0, |c| c.len())
461 ),
462 Self::Data(msg) => write!(
463 f,
464 "Data(request: {}, fork: {}, block: {}, hash: {}, seek: {}, upgrade: {})",
465 msg.request,
466 msg.fork,
467 msg.block.is_some(),
468 msg.hash.is_some(),
469 msg.seek.is_some(),
470 msg.upgrade.is_some(),
471 ),
472 _ => write!(f, "{:?}", &self),
473 }
474 }
475}
476
477#[derive(Clone)]
479pub(crate) struct ChannelMessage {
480 pub(crate) channel: u64,
481 pub(crate) message: Message,
482 state: Option<HypercoreState>,
483}
484
485impl PartialEq for ChannelMessage {
486 fn eq(&self, other: &Self) -> bool {
487 self.channel == other.channel && self.message == other.message
488 }
489}
490
491impl fmt::Debug for ChannelMessage {
492 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
493 write!(f, "ChannelMessage({}, {})", self.channel, self.message)
494 }
495}
496
497impl ChannelMessage {
498 pub(crate) fn new(channel: u64, message: Message) -> Self {
500 Self {
501 channel,
502 message,
503 state: None,
504 }
505 }
506
507 pub(crate) fn into_split(self) -> (u64, Message) {
509 (self.channel, self.message)
510 }
511
512 pub(crate) fn decode_open_message(buf: &[u8]) -> io::Result<(Self, usize)> {
517 if buf.len() <= 5 {
518 return Err(io::Error::new(
519 io::ErrorKind::UnexpectedEof,
520 "received too short Open message",
521 ));
522 }
523
524 let mut state = State::new_with_start_and_end(0, buf.len());
525 let open_msg: Open = state.decode(buf)?;
526 Ok((
527 Self {
528 channel: open_msg.channel,
529 message: Message::Open(open_msg),
530 state: None,
531 },
532 state.start(),
533 ))
534 }
535
536 pub(crate) fn decode_close_message(buf: &[u8]) -> io::Result<(Self, usize)> {
541 if buf.is_empty() {
542 return Err(io::Error::new(
543 io::ErrorKind::UnexpectedEof,
544 "received too short Close message",
545 ));
546 }
547 let mut state = State::new_with_start_and_end(0, buf.len());
548 let close_msg: Close = state.decode(buf)?;
549 Ok((
550 Self {
551 channel: close_msg.channel,
552 message: Message::Close(close_msg),
553 state: None,
554 },
555 state.start(),
556 ))
557 }
558
559 pub(crate) fn decode(buf: &[u8], channel: u64) -> io::Result<(Self, usize)> {
564 if buf.len() <= 1 {
565 return Err(io::Error::new(
566 io::ErrorKind::UnexpectedEof,
567 "received empty message",
568 ));
569 }
570 let mut state = State::from_buffer(buf);
571 let typ: u64 = state.decode(buf)?;
572 let (message, length) = Message::decode(&buf[state.start()..], typ)?;
573 Ok((
574 Self {
575 channel,
576 message,
577 state: None,
578 },
579 state.start() + length,
580 ))
581 }
582
583 fn prepare_state(&mut self) -> Result<(), EncodingError> {
586 if self.state.is_none() {
587 let state = if let Message::Open(_) = self.message {
588 let mut state = HypercoreState::new();
591 self.message.preencode(&mut state)?;
592 state
593 } else if let Message::Close(_) = self.message {
594 let mut state = HypercoreState::new();
597 self.message.preencode(&mut state)?;
598 state
599 } else {
600 let mut state = HypercoreState::new();
603 let typ = self.message.typ();
604 (*state).preencode(&typ)?;
605 self.message.preencode(&mut state)?;
606 state
607 };
608 self.state = Some(state);
609 }
610 Ok(())
611 }
612}
613
614impl Encoder for ChannelMessage {
615 fn encoded_len(&mut self) -> Result<usize, EncodingError> {
616 self.prepare_state()?;
617 Ok(self.state.as_ref().unwrap().end())
618 }
619
620 fn encode(&mut self, buf: &mut [u8]) -> Result<usize, EncodingError> {
621 self.prepare_state()?;
622 let state = self.state.as_mut().unwrap();
623 if let Message::Open(_) = self.message {
624 self.message.encode(state, buf)?;
626 } else if let Message::Close(_) = self.message {
627 self.message.encode(state, buf)?;
629 } else {
630 let typ = self.message.typ();
631 state.0.encode(&typ, buf)?;
632 self.message.encode(state, buf)?;
633 }
634 Ok(state.start())
635 }
636}
637
#[cfg(test)]
mod tests {
    use super::*;
    use hypercore::{
        DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, RequestSeek, RequestUpgrade,
    };

    /// Encodes `msg` on a random channel, decodes it back, and checks that the
    /// channel and message both survive the round trip.
    fn assert_message_round_trip(msg: Message) {
        let channel = rand::random::<u8>() as u64;
        let mut channel_message = ChannelMessage::new(channel, msg.clone());
        let encoded_len = channel_message
            .encoded_len()
            .expect("Failed to get encoded length");
        let mut buf = vec![0u8; encoded_len];
        let n = channel_message
            .encode(&mut buf[..])
            .expect("Failed to encode message");
        let (decoded_channel, decoded_message) = ChannelMessage::decode(&buf[..n], channel)
            .expect("Failed to decode message")
            .0
            .into_split();
        assert_eq!(channel, decoded_channel);
        assert_eq!(msg, decoded_message);
    }

    #[test]
    fn message_encode_decode() {
        let messages = vec![
            Message::Synchronize(Synchronize {
                fork: 0,
                can_upgrade: true,
                downloading: true,
                uploading: true,
                length: 5,
                remote_length: 0,
            }),
            Message::Request(Request {
                id: 1,
                fork: 1,
                block: Some(RequestBlock {
                    index: 5,
                    nodes: 10,
                }),
                hash: Some(RequestBlock {
                    index: 20,
                    nodes: 0,
                }),
                seek: Some(RequestSeek { bytes: 10 }),
                upgrade: Some(RequestUpgrade {
                    start: 0,
                    length: 10,
                }),
            }),
            Message::Cancel(Cancel { request: 1 }),
            Message::Data(Data {
                request: 1,
                fork: 5,
                block: Some(DataBlock {
                    index: 5,
                    nodes: vec![Node::new(1, vec![0x01; 32], 100)],
                    value: vec![0xFF; 10],
                }),
                hash: Some(DataHash {
                    index: 20,
                    nodes: vec![Node::new(2, vec![0x02; 32], 200)],
                }),
                seek: Some(DataSeek {
                    bytes: 10,
                    nodes: vec![Node::new(3, vec![0x03; 32], 300)],
                }),
                upgrade: Some(DataUpgrade {
                    start: 0,
                    length: 10,
                    nodes: vec![Node::new(4, vec![0x04; 32], 400)],
                    additional_nodes: vec![Node::new(5, vec![0x05; 32], 500)],
                    signature: vec![0xAB; 32],
                }),
            }),
            Message::NoData(NoData { request: 2 }),
            Message::Want(Want {
                start: 0,
                length: 100,
            }),
            Message::Unwant(Unwant {
                start: 10,
                length: 2,
            }),
            Message::Bitfield(Bitfield {
                start: 20,
                bitfield: vec![0x89ABCDEF, 0x00, 0xFFFFFFFF],
            }),
            Message::Range(Range {
                drop: true,
                start: 12345,
                length: 100000,
            }),
            Message::Extension(Extension {
                name: "custom_extension/v1/open".to_string(),
                message: vec![0x44, 20],
            }),
        ];
        for msg in messages {
            assert_message_round_trip(msg);
        }
    }

    /// A multi-message batch that switches channels must survive
    /// `Frame::encode` followed by `Frame::decode_multiple`.
    #[test]
    fn message_batch_frame_round_trip() {
        let mut frame = Frame::MessageBatch(vec![
            ChannelMessage::new(
                1,
                Message::Want(Want {
                    start: 0,
                    length: 100,
                }),
            ),
            ChannelMessage::new(
                1,
                Message::Unwant(Unwant {
                    start: 10,
                    length: 2,
                }),
            ),
            ChannelMessage::new(2, Message::Cancel(Cancel { request: 7 })),
        ]);
        let expected = frame.clone();
        let len = frame.encoded_len().expect("Failed to get frame length");
        let mut buf = vec![0u8; len];
        let written = frame.encode(&mut buf).expect("Failed to encode frame");
        assert_eq!(len, written);
        let decoded =
            Frame::decode_multiple(&buf, &FrameType::Message).expect("Failed to decode frame");
        assert_eq!(expected, decoded);
    }
}