// hypercore_protocol/protocol.rs

1use async_channel::{Receiver, Sender};
2use futures_lite::stream::Stream;
3use hypercore_handshake::{CipherTrait, state_machine::PUBLIC_KEYLEN};
4use std::{
5    collections::VecDeque,
6    convert::TryInto,
7    fmt,
8    io::{self, Result},
9    pin::Pin,
10    task::{Context, Poll},
11};
12use tracing::{error, instrument};
13
14use crate::{
15    channels::{Channel, ChannelMap},
16    constants::PROTOCOL_NAME,
17    crypto::HandshakeResult,
18    message::{ChannelMessage, Message},
19    mqueue::MessageIo,
20    schema::*,
21    util::{map_channel_err, pretty_hash},
22};
23
/// Early-returns `Poll::Ready(Err(err))` from the enclosing poll function when
/// the given expression evaluates to `Err(err)`; an `Ok` value is discarded.
macro_rules! return_error {
    ($res:expr) => {
        match $res {
            Ok(_) => {}
            Err(err) => return Poll::Ready(Err(err)),
        }
    };
}
31
/// Capacity of the bounded command and outbound-message channels.
const CHANNEL_CAP: usize = 1_000;

/// The remote peer's 32-byte public key.
pub(crate) type RemotePublicKey = [u8; 32];
/// A 32-byte discovery key identifying a channel.
pub type DiscoveryKey = [u8; 32];
/// A 32-byte key used to open a channel.
pub type Key = [u8; 32];
40
41/// A protocol event.
42#[non_exhaustive]
43#[derive(PartialEq)]
44pub enum Event {
45    /// Emitted after the handshake with the remote peer is complete.
46    /// This is the first event.
47    Handshake(RemotePublicKey),
48    /// Emitted when the remote peer opens a channel that we did not yet open.
49    DiscoveryKey(DiscoveryKey),
50    /// Emitted when a channel is established.
51    Channel(Channel),
52    /// Emitted when a channel is closed.
53    Close(DiscoveryKey),
54    /// Convenience event to make it possible to signal the protocol from a channel.
55    /// See channel.signal_local() and protocol.commands().signal_local().
56    LocalSignal((String, Vec<u8>)),
57}
58
59/// A protocol command.
60#[derive(Debug)]
61pub enum Command {
62    /// Open a channel
63    Open(Key),
64    /// Close a channel by discovery key
65    Close(DiscoveryKey),
66    /// Signal locally to protocol
67    SignalLocal((String, Vec<u8>)),
68}
69
70impl fmt::Debug for Event {
71    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72        match self {
73            Event::Handshake(remote_key) => {
74                write!(f, "Handshake(remote_key={})", &pretty_hash(remote_key))
75            }
76            Event::DiscoveryKey(discovery_key) => {
77                write!(f, "DiscoveryKey({})", &pretty_hash(discovery_key))
78            }
79            Event::Channel(channel) => {
80                write!(f, "Channel({})", &pretty_hash(channel.discovery_key()))
81            }
82            Event::Close(discovery_key) => write!(f, "Close({})", &pretty_hash(discovery_key)),
83            Event::LocalSignal((name, data)) => {
84                write!(f, "LocalSignal(name={},len={})", name, data.len())
85            }
86        }
87    }
88}
89
90/// A Protocol stream for replicating hypercores over an encrypted connection.
91///
92/// The protocol expects an already-encrypted, message-framed connection
93/// (e.g., from hyperswarm). The `HandshakeResult` provides the handshake hash
94/// and public keys needed for capability verification.
95pub struct Protocol {
96    io: MessageIo,
97    is_initiator: bool,
98    channels: ChannelMap,
99    command_rx: Receiver<Command>,
100    command_tx: CommandTx,
101    outbound_rx: Receiver<Vec<ChannelMessage>>,
102    outbound_tx: Sender<Vec<ChannelMessage>>,
103    queued_events: VecDeque<Event>,
104    handshake_emitted: bool,
105}
106
107impl std::fmt::Debug for Protocol {
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109        f.debug_struct("Protocol")
110            .field("is_initiator", &self.is_initiator)
111            .field("channels", &self.channels)
112            .field("handshake_emitted", &self.handshake_emitted)
113            .field("queued_events", &self.queued_events)
114            .finish()
115    }
116}
117
118impl Protocol {
119    /// Create a new protocol instance.
120    ///
121    /// # Arguments
122    /// * `stream` - An already-encrypted, message-framed connection (e.g., hyperswarm `Connection`)
123    pub fn new(stream: Box<dyn CipherTrait>) -> Self {
124        let (command_tx, command_rx) = async_channel::bounded(CHANNEL_CAP);
125        let (outbound_tx, outbound_rx): (
126            Sender<Vec<ChannelMessage>>,
127            Receiver<Vec<ChannelMessage>>,
128        ) = async_channel::bounded(CHANNEL_CAP);
129
130        let is_initiator = stream.is_initiator();
131
132        Protocol {
133            io: MessageIo::new(stream),
134            is_initiator,
135            channels: ChannelMap::new(),
136            command_rx,
137            command_tx: CommandTx(command_tx),
138            outbound_tx,
139            outbound_rx,
140            queued_events: VecDeque::new(),
141            handshake_emitted: false,
142        }
143    }
144
145    /// Whether this protocol stream initiated the underlying IO connection.
146    pub fn is_initiator(&self) -> bool {
147        self.is_initiator
148    }
149
150    /// Get your own Noise public key.
151    pub fn public_key(&self) -> [u8; PUBLIC_KEYLEN] {
152        self.io.local_public_key()
153    }
154
155    /// Get the remote's Noise public key.
156    pub fn remote_public_key(&self) -> Option<[u8; PUBLIC_KEYLEN]> {
157        self.io.remote_public_key()
158    }
159
160    /// Get a sender to send commands.
161    pub fn commands(&self) -> CommandTx {
162        self.command_tx.clone()
163    }
164
165    /// Give a command to the protocol.
166    pub async fn command(&self, command: Command) -> Result<()> {
167        self.command_tx.send(command).await
168    }
169
170    /// Open a new protocol channel.
171    ///
172    /// Once the other side proofed that it also knows the `key`, the channel is emitted as
173    /// `Event::Channel` on the protocol event stream.
174    pub async fn open(&self, key: Key) -> Result<()> {
175        self.command_tx.open(key).await
176    }
177
178    /// Iterator of all currently opened channels.
179    pub fn channels(&self) -> impl Iterator<Item = &DiscoveryKey> {
180        self.channels.iter().map(|c| c.discovery_key())
181    }
182
183    #[instrument(skip_all, fields(initiator = ?self.is_initiator()))]
184    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Event>> {
185        let this = self.get_mut();
186
187        // Initiator needs to send and receive a message before proceeding
188        if this.is_initiator && this.io.handshake_hash().is_none() {
189            return_error!(this.poll_outbound_write(cx));
190            return_error!(this.poll_inbound_read(cx));
191            if this.io.handshake_hash().is_none() {
192                cx.waker().wake_by_ref();
193                return Poll::Pending;
194            }
195        }
196        // Emit handshake event on first poll
197        if !this.handshake_emitted {
198            if let Some(remote_pubkey) = this.io.remote_public_key() {
199                this.handshake_emitted = true;
200                return Poll::Ready(Ok(Event::Handshake(remote_pubkey)));
201            } else {
202                cx.waker().wake_by_ref();
203            }
204        }
205
206        // Drain queued events first.
207        if let Some(event) = this.queued_events.pop_front() {
208            return Poll::Ready(Ok(event));
209        }
210
211        // Read and process incoming messages.
212        return_error!(this.poll_inbound_read(cx));
213
214        // Check for commands.
215        return_error!(this.poll_commands(cx));
216
217        // Write everything we can write.
218        return_error!(this.poll_outbound_write(cx));
219
220        // Check if any events are enqueued.
221        if let Some(event) = this.queued_events.pop_front() {
222            Poll::Ready(Ok(event))
223        } else {
224            Poll::Pending
225        }
226    }
227
228    /// Poll commands.
229    fn poll_commands(&mut self, cx: &mut Context<'_>) -> Result<()> {
230        while let Poll::Ready(Some(command)) = Pin::new(&mut self.command_rx).poll_next(cx) {
231            if let Err(e) = self.on_command(command) {
232                error!(error = ?e, "Error handling command");
233                return Err(e);
234            }
235        }
236        Ok(())
237    }
238
239    // just handles Close and LocalSignal
240    fn on_outbound_message(&mut self, message: &ChannelMessage) -> bool {
241        // If message is close, close the local channel.
242        if let ChannelMessage {
243            channel,
244            message: Message::Close(_),
245            ..
246        } = message
247        {
248            self.close_local(*channel);
249        // If message is a LocalSignal, emit an event and return false to indicate
250        // this message should be filtered out.
251        } else if let ChannelMessage {
252            message: Message::LocalSignal((name, data)),
253            ..
254        } = message
255        {
256            self.queue_event(Event::LocalSignal((name.to_string(), data.to_vec())));
257            return false;
258        }
259        true
260    }
261
262    /// Poll for inbound messages and process them.
263    #[instrument(skip_all, err)]
264    fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> {
265        loop {
266            match self.io.poll_inbound(cx) {
267                Poll::Ready(Some(result)) => {
268                    let messages = result?;
269                    self.on_inbound_channel_messages(messages)?;
270                }
271                Poll::Ready(None) => return Ok(()),
272                Poll::Pending => return Ok(()),
273            }
274        }
275    }
276
277    /// Poll for outbound messages and write them.
278    #[instrument(skip_all)]
279    fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> {
280        loop {
281            // Drive outbound IO
282            if let Poll::Ready(Err(e)) = self.io.poll_outbound(cx) {
283                error!(err = ?e, "error from poll_outbound");
284                return Err(e);
285            }
286            // Send messages from outbound_rx
287            match Pin::new(&mut self.outbound_rx).poll_next(cx) {
288                Poll::Ready(Some(mut messages)) => {
289                    if !messages.is_empty() {
290                        messages.retain(|message| self.on_outbound_message(message));
291                        for msg in messages {
292                            self.io.enqueue(msg);
293                        }
294                    }
295                }
296                Poll::Ready(None) => unreachable!("Channel closed before end"),
297                Poll::Pending => return Ok(()),
298            }
299        }
300    }
301
302    #[instrument(skip_all)]
303    fn on_inbound_channel_messages(&mut self, channel_messages: Vec<ChannelMessage>) -> Result<()> {
304        for channel_message in channel_messages {
305            self.on_inbound_message(channel_message)?
306        }
307        Ok(())
308    }
309
310    #[instrument(skip_all)]
311    fn on_inbound_message(&mut self, channel_message: ChannelMessage) -> Result<()> {
312        let (remote_id, message) = channel_message.into_split();
313        match message {
314            Message::Open(msg) => self.on_open(remote_id, msg)?,
315            Message::Close(msg) => self.on_close(remote_id, msg)?,
316            _ => self
317                .channels
318                .forward_inbound_message(remote_id as usize, message)?,
319        }
320        Ok(())
321    }
322
323    #[instrument(skip(self))]
324    fn on_command(&mut self, command: Command) -> Result<()> {
325        match command {
326            Command::Open(key) => self.command_open(key),
327            Command::Close(discovery_key) => self.command_close(discovery_key),
328            Command::SignalLocal((name, data)) => self.command_signal_local(name, data),
329        }
330    }
331
332    /// Open a Channel with the given key. Adding it to our channel map
333    #[instrument(skip_all)]
334    fn command_open(&mut self, key: Key) -> Result<()> {
335        // Create a new channel.
336        let channel_handle = self.channels.attach_local(key);
337        // Safe because attach_local always puts Some(local_id)
338        let local_id = channel_handle.local_id().unwrap();
339        let discovery_key = *channel_handle.discovery_key();
340
341        // If the channel was already opened from the remote end, verify, and if
342        // verification is ok, push a channel open event.
343        if channel_handle.is_connected() {
344            self.accept_channel(local_id)?;
345        }
346
347        // Tell the remote end about the new channel.
348        let capability = self.capability(&key);
349        let channel = local_id as u64;
350        let message = Message::Open(Open {
351            channel,
352            protocol: PROTOCOL_NAME.to_string(),
353            discovery_key: discovery_key.to_vec(),
354            capability,
355        });
356        let channel_message = ChannelMessage::new(channel, message);
357        self.io.enqueue(channel_message);
358        Ok(())
359    }
360
361    fn command_close(&mut self, discovery_key: DiscoveryKey) -> Result<()> {
362        if self.channels.has_channel(&discovery_key) {
363            self.channels.remove(&discovery_key);
364            self.queue_event(Event::Close(discovery_key));
365        }
366        Ok(())
367    }
368
369    fn command_signal_local(&mut self, name: String, data: Vec<u8>) -> Result<()> {
370        self.queue_event(Event::LocalSignal((name, data)));
371        Ok(())
372    }
373
374    #[instrument(skip(self))]
375    fn on_open(&mut self, ch: u64, msg: Open) -> Result<()> {
376        let discovery_key: DiscoveryKey = parse_key(&msg.discovery_key)?;
377        let channel_handle =
378            self.channels
379                .attach_remote(discovery_key, ch as usize, msg.capability);
380
381        if channel_handle.is_connected() {
382            let local_id = channel_handle.local_id().unwrap();
383            self.accept_channel(local_id)?;
384        } else {
385            self.queue_event(Event::DiscoveryKey(discovery_key));
386        }
387
388        Ok(())
389    }
390
391    #[instrument(skip(self))]
392    fn queue_event(&mut self, event: Event) {
393        self.queued_events.push_back(event);
394    }
395
396    #[instrument(skip(self))]
397    fn accept_channel(&mut self, local_id: usize) -> Result<()> {
398        let (key, remote_capability) = self.channels.prepare_to_verify(local_id)?;
399        self.verify_remote_capability(remote_capability.cloned(), key)
400            .expect("TODO channel can only be accepted after first message")?;
401        let channel = self.channels.accept(local_id, self.outbound_tx.clone())?;
402        self.queue_event(Event::Channel(channel));
403        Ok(())
404    }
405
406    fn close_local(&mut self, local_id: u64) {
407        if let Some(channel) = self.channels.get_local(local_id as usize) {
408            let discovery_key = *channel.discovery_key();
409            self.channels.remove(&discovery_key);
410            self.queue_event(Event::Close(discovery_key));
411        }
412    }
413
414    fn on_close(&mut self, remote_id: u64, msg: Close) -> Result<()> {
415        if let Some(channel_handle) = self.channels.get_remote(remote_id as usize) {
416            let discovery_key = *channel_handle.discovery_key();
417            // There is a possibility both sides will close at the same time, so
418            // the channel could be closed already, let's tolerate that.
419            self.channels
420                .forward_inbound_message_tolerate_closed(remote_id as usize, Message::Close(msg))?;
421            self.channels.remove(&discovery_key);
422            self.queue_event(Event::Close(discovery_key));
423        }
424        Ok(())
425    }
426
427    #[instrument(skip_all)]
428    fn capability(&self, key: &[u8]) -> Option<Vec<u8>> {
429        let is_initiator = self.is_initiator;
430        let remote_pubkey = self.remote_public_key()?;
431        let local_pubkey = self.public_key();
432        let handshake_hash = self.io.handshake_hash()?;
433        HandshakeResult::from_pre_encrypted(
434            is_initiator,
435            local_pubkey,
436            remote_pubkey,
437            handshake_hash.to_vec(),
438        )
439        .capability(key)
440    }
441
442    #[instrument(skip_all)]
443    fn verify_remote_capability(
444        &self,
445        capability: Option<Vec<u8>>,
446        key: &[u8],
447    ) -> Option<Result<()>> {
448        let is_initiator = self.is_initiator;
449        let remote_pubkey = self.remote_public_key()?;
450        let local_pubkey = self.public_key();
451        let handshake_hash = self.io.handshake_hash()?;
452        Some(
453            HandshakeResult::from_pre_encrypted(
454                is_initiator,
455                local_pubkey,
456                remote_pubkey,
457                handshake_hash.to_vec(),
458            )
459            .verify_remote_capability(capability, key),
460        )
461    }
462}
463
464impl Stream for Protocol {
465    type Item = Result<Event>;
466    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
467        match Protocol::poll_next(self, cx) {
468            Poll::Ready(Ok(e)) => Poll::Ready(Some(Ok(e))),
469            Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
470            Poll::Pending => Poll::Pending,
471        }
472    }
473}
474
475/// Send [`Command`]s to the [`Protocol`].
476#[derive(Clone, Debug)]
477pub struct CommandTx(Sender<Command>);
478
479impl CommandTx {
480    /// Send a protocol command
481    pub async fn send(&self, command: Command) -> Result<()> {
482        self.0.send(command).await.map_err(map_channel_err)
483    }
484    /// Open a protocol channel.
485    ///
486    /// The channel will be emitted on the main protocol.
487    pub async fn open(&self, key: Key) -> Result<()> {
488        self.send(Command::Open(key)).await
489    }
490
491    /// Close a protocol channel.
492    pub async fn close(&self, discovery_key: DiscoveryKey) -> Result<()> {
493        self.send(Command::Close(discovery_key)).await
494    }
495
496    /// Send a local signal event to the protocol.
497    pub async fn signal_local(&self, name: &str, data: Vec<u8>) -> Result<()> {
498        self.send(Command::SignalLocal((name.to_string(), data)))
499            .await
500    }
501}
502
/// Convert a byte slice into a fixed 32-byte key.
///
/// # Errors
/// Returns `InvalidInput` when `key` is not exactly 32 bytes long.
fn parse_key(key: &[u8]) -> io::Result<[u8; 32]> {
    let converted: std::result::Result<[u8; 32], _> = key.try_into();
    converted.map_err(|_| {
        io::Error::new(io::ErrorKind::InvalidInput, "Key must be 32 bytes long")
    })
}