hypercore_protocol/
protocol.rs

1use async_channel::{Receiver, Sender};
2use futures_lite::io::{AsyncRead, AsyncWrite};
3use futures_lite::stream::Stream;
4use futures_timer::Delay;
5use std::collections::VecDeque;
6use std::convert::TryInto;
7use std::fmt;
8use std::future::Future;
9use std::io::{self, Error, ErrorKind, Result};
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use std::time::Duration;
13
14use crate::channels::{Channel, ChannelMap};
15use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME};
16use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult};
17use crate::message::{ChannelMessage, Frame, FrameType, Message};
18use crate::reader::ReadState;
19use crate::schema::*;
20use crate::util::{map_channel_err, pretty_hash};
21use crate::writer::WriteState;
22
// Early-return `Poll::Ready(Err(e))` from the enclosing `poll_*` function when the
// given `Result` expression is an `Err`; an `Ok` value is simply discarded.
macro_rules! return_error {
    ($msg:expr) => {
        match $msg {
            Err(err) => return Poll::Ready(Err(err)),
            Ok(_) => {}
        }
    };
}
30
31const CHANNEL_CAP: usize = 1000;
32const KEEPALIVE_DURATION: Duration = Duration::from_secs(DEFAULT_KEEPALIVE as u64);
33
/// Options for a Protocol instance.
#[derive(Debug)]
pub(crate) struct Options {
    /// Whether this peer initiated the IO connection for this protocol.
    pub(crate) is_initiator: bool,
    /// Enable or disable the Noise handshake.
    ///
    /// Don't disable this unless you are 100% sure you want that. Without a
    /// handshake there is no handshake result to verify channel capabilities
    /// against: `Protocol::verify_remote_capability` then fails with
    /// `PermissionDenied`, so channels cannot be accepted.
    pub(crate) noise: bool,
    /// Enable or disable transport encryption.
    ///
    /// Only consulted after a completed handshake, so it has no effect when
    /// `noise` is disabled.
    pub(crate) encrypted: bool,
}

impl Options {
    /// Create options with both the handshake and transport encryption enabled.
    pub(crate) fn new(is_initiator: bool) -> Self {
        Self {
            is_initiator,
            noise: true,
            encrypted: true,
        }
    }
}
56}
57
/// Remote public key (32 bytes): the remote peer's Noise public key, as carried
/// by `Event::Handshake`.
pub(crate) type RemotePublicKey = [u8; 32];
/// Discovery key (32 bytes) identifying a channel on the wire.
pub type DiscoveryKey = [u8; 32];
/// Key (32 bytes) used to open a channel; the handshake derives the channel
/// capability from it.
pub type Key = [u8; 32];
64
/// A protocol event, yielded by the `Protocol` stream.
#[non_exhaustive]
#[derive(PartialEq)]
pub enum Event {
    /// Emitted after the handshake with the remote peer is complete; carries the
    /// remote peer's public key.
    /// This is the first event (if the handshake is not disabled).
    Handshake(RemotePublicKey),
    /// Emitted when the remote peer opens a channel that we did not yet open.
    /// Carries the discovery key announced by the remote.
    DiscoveryKey(DiscoveryKey),
    /// Emitted when a channel is established: both sides opened it and the
    /// remote's capability was verified.
    Channel(Channel),
    /// Emitted when a channel is closed, either locally or by the remote peer.
    Close(DiscoveryKey),
    /// Convenience event to make it possible to signal the protocol from a channel.
    /// Carries `(name, data)`.
    /// See channel.signal_local() and protocol.commands().signal_local().
    LocalSignal((String, Vec<u8>)),
}
82
/// A protocol command, sent through a `CommandTx`.
/// The protocol only processes commands once the connection is established.
#[derive(Debug)]
pub enum Command {
    /// Open a channel for the given key. The channel is emitted as
    /// `Event::Channel` once the remote has opened it too and its capability
    /// checks out.
    Open(Key),
    /// Close a channel by discovery key: the channel is removed locally and
    /// `Event::Close` is emitted.
    // NOTE(review): handling this command does not send a Close message to the
    // remote peer — confirm this is intended.
    Close(DiscoveryKey),
    /// Signal locally to protocol; re-emitted as `Event::LocalSignal`.
    SignalLocal((String, Vec<u8>)),
}
93
94impl fmt::Debug for Event {
95    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96        match self {
97            Event::Handshake(remote_key) => {
98                write!(f, "Handshake(remote_key={})", &pretty_hash(remote_key))
99            }
100            Event::DiscoveryKey(discovery_key) => {
101                write!(f, "DiscoveryKey({})", &pretty_hash(discovery_key))
102            }
103            Event::Channel(channel) => {
104                write!(f, "Channel({})", &pretty_hash(channel.discovery_key()))
105            }
106            Event::Close(discovery_key) => write!(f, "Close({})", &pretty_hash(discovery_key)),
107            Event::LocalSignal((name, data)) => {
108                write!(f, "LocalSignal(name={},len={})", name, data.len())
109            }
110        }
111    }
112}
113
/// Protocol state.
///
/// Transitions: `NotInitialized` -> `Handshake` -> `SecretStream` (only when
/// `options.encrypted`) -> `Established`. With the handshake disabled
/// (`options.noise == false`), `NotInitialized` goes straight to `Established`.
// NOTE(review): the lint is silenced rather than boxing the largest variant;
// presumably `Handshake` is the large one — confirm this is deliberate.
#[allow(clippy::large_enum_variant)]
pub(crate) enum State {
    /// Created, but `init` has not run yet.
    NotInitialized,
    /// Noise handshake in progress.
    // The Handshake struct sits behind an option only so that we can .take()
    // it out, it's never actually empty when in State::Handshake.
    Handshake(Option<Handshake>),
    /// Handshake finished; waiting for the peer's secret stream init message.
    /// Holds the encrypt cipher that is handed to the writer once it arrives.
    SecretStream(Option<EncryptCipher>),
    /// Channel messages can be exchanged.
    Established,
}
124
125impl fmt::Debug for State {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        match self {
128            State::NotInitialized => write!(f, "NotInitialized"),
129            State::Handshake(_) => write!(f, "Handshaking"),
130            State::SecretStream(_) => write!(f, "SecretStream"),
131            State::Established => write!(f, "Established"),
132        }
133    }
134}
135
/// A Protocol stream.
///
/// Poll it as a `Stream` of `Result<Event>` to drive the handshake, channel
/// traffic and keepalives over the wrapped `IO`.
pub struct Protocol<IO> {
    /// Outbound frame queue and writer (optionally encrypting after the handshake).
    write_state: WriteState,
    /// Inbound frame reader; its frame type and decryption change as the state advances.
    read_state: ReadState,
    /// The underlying transport.
    io: IO,
    /// Current position in the protocol state machine.
    state: State,
    /// Options this protocol was created with.
    options: Options,
    /// Result of the completed Noise handshake; `None` before that (or when disabled).
    handshake: Option<HandshakeResult>,
    /// Local and remote channel bookkeeping.
    channels: ChannelMap,
    /// Receiving end of the command channel; polled only once established.
    command_rx: Receiver<Command>,
    /// Sending end of the command channel, handed out via `commands()`.
    command_tx: CommandTx,
    /// Outbound message batches from accepted channels.
    outbound_rx: Receiver<Vec<ChannelMessage>>,
    /// Cloned into every accepted channel so it can send message batches.
    outbound_tx: Sender<Vec<ChannelMessage>>,
    /// Timer that triggers an empty keepalive frame while established.
    keepalive: Delay,
    /// Events waiting to be yielded by the stream.
    queued_events: VecDeque<Event>,
}
152
153impl<IO> std::fmt::Debug for Protocol<IO> {
154    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155        f.debug_struct("Protocol")
156            .field("write_state", &self.write_state)
157            .field("read_state", &self.read_state)
158            //.field("io", &self.io)
159            .field("state", &self.state)
160            .field("options", &self.options)
161            .field("handshake", &self.handshake)
162            .field("channels", &self.channels)
163            .field("command_rx", &self.command_rx)
164            .field("command_tx", &self.command_tx)
165            .field("outbound_rx", &self.outbound_rx)
166            .field("outbound_tx", &self.outbound_tx)
167            .field("keepalive", &self.keepalive)
168            .field("queued_events", &self.queued_events)
169            .finish()
170    }
171}
172
173impl<IO> Protocol<IO>
174where
175    IO: AsyncWrite + AsyncRead + Send + Unpin + 'static,
176{
177    /// Create a new protocol instance.
178    pub(crate) fn new(io: IO, options: Options) -> Self {
179        let (command_tx, command_rx) = async_channel::bounded(CHANNEL_CAP);
180        let (outbound_tx, outbound_rx): (
181            Sender<Vec<ChannelMessage>>,
182            Receiver<Vec<ChannelMessage>>,
183        ) = async_channel::bounded(1);
184        Protocol {
185            io,
186            read_state: ReadState::new(),
187            write_state: WriteState::new(),
188            options,
189            state: State::NotInitialized,
190            channels: ChannelMap::new(),
191            handshake: None,
192            command_rx,
193            command_tx: CommandTx(command_tx),
194            outbound_tx,
195            outbound_rx,
196            keepalive: Delay::new(Duration::from_secs(DEFAULT_KEEPALIVE as u64)),
197            queued_events: VecDeque::new(),
198        }
199    }
200
201    /// Whether this protocol stream initiated the underlying IO connection.
202    pub fn is_initiator(&self) -> bool {
203        self.options.is_initiator
204    }
205
206    /// Get your own Noise public key.
207    ///
208    /// Empty before the handshake completed.
209    pub fn public_key(&self) -> Option<&[u8]> {
210        match &self.handshake {
211            None => None,
212            Some(handshake) => Some(handshake.local_pubkey.as_slice()),
213        }
214    }
215
216    /// Get the remote's Noise public key.
217    ///
218    /// Empty before the handshake completed.
219    pub fn remote_public_key(&self) -> Option<&[u8]> {
220        match &self.handshake {
221            None => None,
222            Some(handshake) => Some(handshake.remote_pubkey.as_slice()),
223        }
224    }
225
226    /// Get a sender to send commands.
227    pub fn commands(&self) -> CommandTx {
228        self.command_tx.clone()
229    }
230
231    /// Give a command to the protocol.
232    pub async fn command(&mut self, command: Command) -> Result<()> {
233        self.command_tx.send(command).await
234    }
235
236    /// Open a new protocol channel.
237    ///
238    /// Once the other side proofed that it also knows the `key`, the channel is emitted as
239    /// `Event::Channel` on the protocol event stream.
240    pub async fn open(&mut self, key: Key) -> Result<()> {
241        self.command_tx.open(key).await
242    }
243
244    /// Iterator of all currently opened channels.
245    pub fn channels(&self) -> impl Iterator<Item = &DiscoveryKey> {
246        self.channels.iter().map(|c| c.discovery_key())
247    }
248
249    /// Stop the protocol and return the inner reader and writer.
250    pub fn release(self) -> IO {
251        self.io
252    }
253
254    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Event>> {
255        let this = self.get_mut();
256
257        if let State::NotInitialized = this.state {
258            return_error!(this.init());
259        }
260
261        // Drain queued events first.
262        if let Some(event) = this.queued_events.pop_front() {
263            return Poll::Ready(Ok(event));
264        }
265
266        // Read and process incoming messages.
267        return_error!(this.poll_inbound_read(cx));
268
269        if let State::Established = this.state {
270            // Check for commands, but only once the connection is established.
271            return_error!(this.poll_commands(cx));
272        }
273
274        // Poll the keepalive timer.
275        this.poll_keepalive(cx);
276
277        // Write everything we can write.
278        return_error!(this.poll_outbound_write(cx));
279
280        // Check if any events are enqueued.
281        if let Some(event) = this.queued_events.pop_front() {
282            Poll::Ready(Ok(event))
283        } else {
284            Poll::Pending
285        }
286    }
287
288    fn init(&mut self) -> Result<()> {
289        tracing::debug!(
290            "protocol init, state {:?}, options {:?}",
291            self.state,
292            self.options
293        );
294        match self.state {
295            State::NotInitialized => {}
296            _ => return Ok(()),
297        };
298
299        self.state = if self.options.noise {
300            let mut handshake = Handshake::new(self.options.is_initiator)?;
301            // If the handshake start returns a buffer, send it now.
302            if let Some(buf) = handshake.start()? {
303                self.queue_frame_direct(buf.to_vec()).unwrap();
304            }
305            self.read_state.set_frame_type(FrameType::Raw);
306            State::Handshake(Some(handshake))
307        } else {
308            self.read_state.set_frame_type(FrameType::Message);
309            State::Established
310        };
311
312        Ok(())
313    }
314
315    /// Poll commands.
316    fn poll_commands(&mut self, cx: &mut Context<'_>) -> Result<()> {
317        while let Poll::Ready(Some(command)) = Pin::new(&mut self.command_rx).poll_next(cx) {
318            self.on_command(command)?;
319        }
320        Ok(())
321    }
322
323    /// Poll the keepalive timer and queue a ping message if needed.
324    fn poll_keepalive(&mut self, cx: &mut Context<'_>) {
325        if Pin::new(&mut self.keepalive).poll(cx).is_ready() {
326            if let State::Established = self.state {
327                // 24 bit header for the empty message, hence the 3
328                self.write_state
329                    .queue_frame(Frame::RawBatch(vec![vec![0u8; 3]]));
330            }
331            self.keepalive.reset(KEEPALIVE_DURATION);
332        }
333    }
334
335    fn on_outbound_message(&mut self, message: &ChannelMessage) -> bool {
336        // If message is close, close the local channel.
337        if let ChannelMessage {
338            channel,
339            message: Message::Close(_),
340            ..
341        } = message
342        {
343            self.close_local(*channel);
344        // If message is a LocalSignal, emit an event and return false to indicate
345        // this message should be filtered out.
346        } else if let ChannelMessage {
347            message: Message::LocalSignal((name, data)),
348            ..
349        } = message
350        {
351            self.queue_event(Event::LocalSignal((name.to_string(), data.to_vec())));
352            return false;
353        }
354        true
355    }
356
357    /// Poll for inbound messages and processs them.
358    fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> {
359        loop {
360            let msg = self.read_state.poll_reader(cx, &mut self.io);
361            match msg {
362                Poll::Ready(Ok(message)) => {
363                    self.on_inbound_frame(message)?;
364                }
365                Poll::Ready(Err(e)) => return Err(e),
366                Poll::Pending => return Ok(()),
367            }
368        }
369    }
370
371    /// Poll for outbound messages and write them.
372    fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> {
373        loop {
374            if let Poll::Ready(Err(e)) = self.write_state.poll_send(cx, &mut self.io) {
375                return Err(e);
376            }
377            if !self.write_state.can_park_frame() || !matches!(self.state, State::Established) {
378                return Ok(());
379            }
380
381            match Pin::new(&mut self.outbound_rx).poll_next(cx) {
382                Poll::Ready(Some(mut messages)) => {
383                    if !messages.is_empty() {
384                        messages.retain(|message| self.on_outbound_message(message));
385                        if !messages.is_empty() {
386                            let frame = Frame::MessageBatch(messages);
387                            self.write_state.park_frame(frame);
388                        }
389                    }
390                }
391                Poll::Ready(None) => unreachable!("Channel closed before end"),
392                Poll::Pending => return Ok(()),
393            }
394        }
395    }
396
397    fn on_inbound_frame(&mut self, frame: Frame) -> Result<()> {
398        match frame {
399            Frame::RawBatch(raw_batch) => {
400                let mut processed_state: Option<String> = None;
401                for buf in raw_batch {
402                    let state_name: String = format!("{:?}", self.state);
403                    match self.state {
404                        State::Handshake(_) => self.on_handshake_message(buf)?,
405                        State::SecretStream(_) => self.on_secret_stream_message(buf)?,
406                        State::Established => {
407                            if let Some(processed_state) = processed_state.as_ref() {
408                                let previous_state = if self.options.encrypted {
409                                    State::SecretStream(None)
410                                } else {
411                                    State::Handshake(None)
412                                };
413                                if processed_state == &format!("{previous_state:?}") {
414                                    // This is the unlucky case where the batch had two or more messages where
415                                    // the first one was correctly identified as Raw but everything
416                                    // after that should have been (decrypted and) a MessageBatch. Correct the mistake
417                                    // here post-hoc.
418                                    let buf = self.read_state.decrypt_buf(&buf)?;
419                                    let frame = Frame::decode(&buf, &FrameType::Message)?;
420                                    self.on_inbound_frame(frame)?;
421                                    continue;
422                                }
423                            }
424                            unreachable!(
425                                "May not receive raw frames in Established state"
426                            )
427                        }
428                        _ => unreachable!(
429                            "May not receive raw frames outside of handshake or secretstream state, was {:?}",
430                            self.state
431                        ),
432                    };
433                    if processed_state.is_none() {
434                        processed_state = Some(state_name)
435                    }
436                }
437                Ok(())
438            }
439            Frame::MessageBatch(channel_messages) => match self.state {
440                State::Established => {
441                    for channel_message in channel_messages {
442                        self.on_inbound_message(channel_message)?
443                    }
444                    Ok(())
445                }
446                _ => unreachable!("May not receive message batch frames when not established"),
447            },
448        }
449    }
450
451    fn on_handshake_message(&mut self, buf: Vec<u8>) -> Result<()> {
452        let mut handshake = match &mut self.state {
453            State::Handshake(handshake) => handshake.take().unwrap(),
454            _ => unreachable!("May not call on_handshake_message when not in Handshake state"),
455        };
456
457        if let Some(response_buf) = handshake.read(&buf)? {
458            self.queue_frame_direct(response_buf.to_vec()).unwrap();
459        }
460
461        if !handshake.complete() {
462            self.state = State::Handshake(Some(handshake));
463        } else {
464            let handshake_result = handshake.into_result()?;
465
466            if self.options.encrypted {
467                // The cipher will be put to use to the writer only after the peer's answer has come
468                let (cipher, init_msg) = EncryptCipher::from_handshake_tx(&handshake_result)?;
469                self.state = State::SecretStream(Some(cipher));
470
471                // Send the secret stream init message header to the other side
472                self.queue_frame_direct(init_msg).unwrap();
473            } else {
474                // Skip secret stream and go straight to Established, then notify about
475                // handshake
476                self.read_state.set_frame_type(FrameType::Message);
477                let remote_public_key = parse_key(&handshake_result.remote_pubkey)?;
478                self.queue_event(Event::Handshake(remote_public_key));
479                self.state = State::Established;
480            }
481            // Store handshake result
482            self.handshake = Some(handshake_result);
483        }
484        Ok(())
485    }
486
487    fn on_secret_stream_message(&mut self, buf: Vec<u8>) -> Result<()> {
488        let encrypt_cipher = match &mut self.state {
489            State::SecretStream(encrypt_cipher) => encrypt_cipher.take().unwrap(),
490            _ => {
491                unreachable!("May not call on_secret_stream_message when not in SecretStream state")
492            }
493        };
494        let handshake_result = &self
495            .handshake
496            .as_ref()
497            .expect("Handshake result must be set before secret stream");
498        let decrypt_cipher = DecryptCipher::from_handshake_rx_and_init_msg(handshake_result, &buf)?;
499        self.read_state.upgrade_with_decrypt_cipher(decrypt_cipher);
500        self.write_state.upgrade_with_encrypt_cipher(encrypt_cipher);
501        self.read_state.set_frame_type(FrameType::Message);
502
503        // Lastly notify that handshake is ready and set state to established
504        let remote_public_key = parse_key(&handshake_result.remote_pubkey)?;
505        self.queue_event(Event::Handshake(remote_public_key));
506        self.state = State::Established;
507        Ok(())
508    }
509
510    fn on_inbound_message(&mut self, channel_message: ChannelMessage) -> Result<()> {
511        // let channel_message = ChannelMessage::decode(buf)?;
512        let (remote_id, message) = channel_message.into_split();
513        match message {
514            Message::Open(msg) => self.on_open(remote_id, msg)?,
515            Message::Close(msg) => self.on_close(remote_id, msg)?,
516            _ => self
517                .channels
518                .forward_inbound_message(remote_id as usize, message)?,
519        }
520        Ok(())
521    }
522
523    fn on_command(&mut self, command: Command) -> Result<()> {
524        match command {
525            Command::Open(key) => self.command_open(key),
526            Command::Close(discovery_key) => self.command_close(discovery_key),
527            Command::SignalLocal((name, data)) => self.command_signal_local(name, data),
528        }
529    }
530
531    /// Open a Channel with the given key. Adding it to our channel map
532    fn command_open(&mut self, key: Key) -> Result<()> {
533        // Create a new channel.
534        let channel_handle = self.channels.attach_local(key);
535        // Safe because attach_local always puts Some(local_id)
536        let local_id = channel_handle.local_id().unwrap();
537        let discovery_key = *channel_handle.discovery_key();
538
539        // If the channel was already opened from the remote end, verify, and if
540        // verification is ok, push a channel open event.
541        if channel_handle.is_connected() {
542            self.accept_channel(local_id)?;
543        }
544
545        // Tell the remote end about the new channel.
546        let capability = self.capability(&key);
547        let channel = local_id as u64;
548        let message = Message::Open(Open {
549            channel,
550            protocol: PROTOCOL_NAME.to_string(),
551            discovery_key: discovery_key.to_vec(),
552            capability,
553        });
554        let channel_message = ChannelMessage::new(channel, message);
555        self.write_state
556            .queue_frame(Frame::MessageBatch(vec![channel_message]));
557        Ok(())
558    }
559
560    fn command_close(&mut self, discovery_key: DiscoveryKey) -> Result<()> {
561        if self.channels.has_channel(&discovery_key) {
562            self.channels.remove(&discovery_key);
563            self.queue_event(Event::Close(discovery_key));
564        }
565        Ok(())
566    }
567
568    fn command_signal_local(&mut self, name: String, data: Vec<u8>) -> Result<()> {
569        self.queue_event(Event::LocalSignal((name, data)));
570        Ok(())
571    }
572
573    fn on_open(&mut self, ch: u64, msg: Open) -> Result<()> {
574        let discovery_key: DiscoveryKey = parse_key(&msg.discovery_key)?;
575        let channel_handle =
576            self.channels
577                .attach_remote(discovery_key, ch as usize, msg.capability);
578
579        if channel_handle.is_connected() {
580            let local_id = channel_handle.local_id().unwrap();
581            self.accept_channel(local_id)?;
582        } else {
583            self.queue_event(Event::DiscoveryKey(discovery_key));
584        }
585
586        Ok(())
587    }
588
589    fn queue_event(&mut self, event: Event) {
590        self.queued_events.push_back(event);
591    }
592
593    fn queue_frame_direct(&mut self, body: Vec<u8>) -> Result<bool> {
594        let mut frame = Frame::RawBatch(vec![body]);
595        self.write_state.try_queue_direct(&mut frame)
596    }
597
598    fn accept_channel(&mut self, local_id: usize) -> Result<()> {
599        let (key, remote_capability) = self.channels.prepare_to_verify(local_id)?;
600        self.verify_remote_capability(remote_capability.cloned(), key)?;
601        let channel = self.channels.accept(local_id, self.outbound_tx.clone())?;
602        self.queue_event(Event::Channel(channel));
603        Ok(())
604    }
605
606    fn close_local(&mut self, local_id: u64) {
607        if let Some(channel) = self.channels.get_local(local_id as usize) {
608            let discovery_key = *channel.discovery_key();
609            self.channels.remove(&discovery_key);
610            self.queue_event(Event::Close(discovery_key));
611        }
612    }
613
614    fn on_close(&mut self, remote_id: u64, msg: Close) -> Result<()> {
615        if let Some(channel_handle) = self.channels.get_remote(remote_id as usize) {
616            let discovery_key = *channel_handle.discovery_key();
617            // There is a possibility both sides will close at the same time, so
618            // the channel could be closed already, let's tolerate that.
619            self.channels
620                .forward_inbound_message_tolerate_closed(remote_id as usize, Message::Close(msg))?;
621            self.channels.remove(&discovery_key);
622            self.queue_event(Event::Close(discovery_key));
623        }
624        Ok(())
625    }
626
627    fn capability(&self, key: &[u8]) -> Option<Vec<u8>> {
628        match self.handshake.as_ref() {
629            Some(handshake) => handshake.capability(key),
630            None => None,
631        }
632    }
633
634    fn verify_remote_capability(&self, capability: Option<Vec<u8>>, key: &[u8]) -> Result<()> {
635        match self.handshake.as_ref() {
636            Some(handshake) => handshake.verify_remote_capability(capability, key),
637            None => Err(Error::new(
638                ErrorKind::PermissionDenied,
639                "Missing handshake state for capability verification",
640            )),
641        }
642    }
643}
644
645impl<IO> Stream for Protocol<IO>
646where
647    IO: AsyncRead + AsyncWrite + Send + Unpin + 'static,
648{
649    type Item = Result<Event>;
650    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
651        Protocol::poll_next(self, cx).map(Some)
652    }
653}
654
655/// Send [Command](Command)s to the [Protocol](Protocol).
656#[derive(Clone, Debug)]
657pub struct CommandTx(Sender<Command>);
658
659impl CommandTx {
660    /// Send a protocol command
661    pub async fn send(&mut self, command: Command) -> Result<()> {
662        self.0.send(command).await.map_err(map_channel_err)
663    }
664    /// Open a protocol channel.
665    ///
666    /// The channel will be emitted on the main protocol.
667    pub async fn open(&mut self, key: Key) -> Result<()> {
668        self.send(Command::Open(key)).await
669    }
670
671    /// Close a protocol channel.
672    pub async fn close(&mut self, discovery_key: DiscoveryKey) -> Result<()> {
673        self.send(Command::Close(discovery_key)).await
674    }
675
676    /// Send a local signal event to the protocol.
677    pub async fn signal_local(&mut self, name: &str, data: Vec<u8>) -> Result<()> {
678        self.send(Command::SignalLocal((name.to_string(), data)))
679            .await
680    }
681}
682
/// Copy a byte slice into a fixed 32-byte key.
///
/// # Errors
/// `InvalidInput` when `key` is not exactly 32 bytes long.
fn parse_key(key: &[u8]) -> io::Result<[u8; 32]> {
    match TryInto::<[u8; 32]>::try_into(key) {
        Ok(array) => Ok(array),
        Err(_) => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "Key must be 32 bytes long",
        )),
    }
}