Skip to main content

hypercore_protocol/
protocol.rs

1use async_channel::{Receiver, Sender};
2use futures_lite::stream::Stream;
3use hypercore_handshake::{CipherTrait, state_machine::PUBLIC_KEYLEN};
4use std::{
5    collections::VecDeque,
6    convert::TryInto,
7    fmt,
8    io::{self, Result},
9    pin::Pin,
10    task::{Context, Poll},
11};
12use tracing::{error, instrument};
13
14use crate::{
15    channels::{Channel, ChannelMap},
16    constants::PROTOCOL_NAME,
17    crypto::HandshakeResult,
18    message::{ChannelMessage, Message},
19    mqueue::MessageIo,
20    schema::*,
21    util::{map_channel_err, pretty_hash},
22};
23
/// Early-return helper for `poll_*` functions returning `Poll<Result<_>>`.
///
/// Evaluates `$msg` (a `Result`); on `Err(e)` the enclosing function returns
/// `Poll::Ready(Err(e))`, on `Ok(_)` execution simply continues.
macro_rules! return_error {
    ($msg:expr) => {
        match $msg {
            Err(e) => return Poll::Ready(Err(e)),
            Ok(_) => {}
        }
    };
}
31
/// Capacity of the bounded command and outbound-message channels.
const CHANNEL_CAP: usize = 1_000;

/// Remote public key (32 bytes).
pub(crate) type RemotePublicKey = [u8; 32];
/// Discovery key (32 bytes); channels are looked up and closed by this key.
pub type DiscoveryKey = [u8; 32];
/// Key (32 bytes) used to open a channel.
pub type Key = [u8; 32];
40
41/// A protocol event.
42#[non_exhaustive]
43#[derive(PartialEq)]
44pub enum Event {
45    /// Emitted after the handshake with the remote peer is complete.
46    /// This is the first event.
47    Handshake(RemotePublicKey),
48    /// Emitted when the remote peer opens a channel that we did not yet open.
49    DiscoveryKey(DiscoveryKey),
50    /// Emitted when a channel is established.
51    Channel(Channel),
52    /// Emitted when a channel is closed.
53    Close(DiscoveryKey),
54    /// Convenience event to make it possible to signal the protocol from a channel.
55    /// See channel.signal_local() and protocol.commands().signal_local().
56    LocalSignal((String, Vec<u8>)),
57}
58
59/// A protocol command.
60#[derive(Debug)]
61pub enum Command {
62    /// Open a channel
63    Open(Key),
64    /// Close a channel by discovery key
65    Close(DiscoveryKey),
66    /// Signal locally to protocol
67    SignalLocal((String, Vec<u8>)),
68}
69
70impl fmt::Debug for Event {
71    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72        match self {
73            Event::Handshake(remote_key) => {
74                write!(f, "Handshake(remote_key={})", &pretty_hash(remote_key))
75            }
76            Event::DiscoveryKey(discovery_key) => {
77                write!(f, "DiscoveryKey({})", &pretty_hash(discovery_key))
78            }
79            Event::Channel(channel) => {
80                write!(f, "Channel({})", &pretty_hash(channel.discovery_key()))
81            }
82            Event::Close(discovery_key) => write!(f, "Close({})", &pretty_hash(discovery_key)),
83            Event::LocalSignal((name, data)) => {
84                write!(f, "LocalSignal(name={},len={})", name, data.len())
85            }
86        }
87    }
88}
89
90/// A Protocol stream for replicating hypercores over an encrypted connection.
91///
92/// The protocol expects an already-encrypted, message-framed connection
93/// (e.g., from hyperswarm). The `HandshakeResult` provides the handshake hash
94/// and public keys needed for capability verification.
95pub struct Protocol {
96    io: MessageIo,
97    is_initiator: bool,
98    channels: ChannelMap,
99    command_rx: Receiver<Command>,
100    command_tx: CommandTx,
101    outbound_rx: Receiver<Vec<ChannelMessage>>,
102    outbound_tx: Sender<Vec<ChannelMessage>>,
103    queued_events: VecDeque<Event>,
104    handshake_emitted: bool,
105}
106
107impl std::fmt::Debug for Protocol {
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109        f.debug_struct("Protocol")
110            .field("is_initiator", &self.is_initiator)
111            .field("channels", &self.channels)
112            .field("handshake_emitted", &self.handshake_emitted)
113            .field("queued_events", &self.queued_events)
114            .finish()
115    }
116}
117
118impl Protocol {
119    /// Create a new protocol instance.
120    ///
121    /// # Arguments
122    /// * `stream` - An already-encrypted, message-framed connection (e.g., hyperswarm `Connection`)
123    pub fn new(stream: Box<dyn CipherTrait>) -> Self {
124        let (command_tx, command_rx) = async_channel::bounded(CHANNEL_CAP);
125        let (outbound_tx, outbound_rx): (
126            Sender<Vec<ChannelMessage>>,
127            Receiver<Vec<ChannelMessage>>,
128        ) = async_channel::bounded(CHANNEL_CAP);
129
130        let is_initiator = stream.is_initiator();
131
132        Protocol {
133            io: MessageIo::new(stream),
134            is_initiator,
135            channels: ChannelMap::new(),
136            command_rx,
137            command_tx: CommandTx(command_tx),
138            outbound_tx,
139            outbound_rx,
140            queued_events: VecDeque::new(),
141            handshake_emitted: false,
142        }
143    }
144
145    /// Whether this protocol stream initiated the underlying IO connection.
146    pub fn is_initiator(&self) -> bool {
147        self.is_initiator
148    }
149
150    /// Get your own Noise public key.
151    pub fn public_key(&self) -> [u8; PUBLIC_KEYLEN] {
152        self.io.local_public_key()
153    }
154
155    /// Get the remote's Noise public key.
156    pub fn remote_public_key(&self) -> Option<[u8; PUBLIC_KEYLEN]> {
157        self.io.remote_public_key()
158    }
159
160    /// Get a sender to send commands.
161    pub fn commands(&self) -> CommandTx {
162        self.command_tx.clone()
163    }
164
165    /// Give a command to the protocol.
166    pub async fn command(&self, command: Command) -> Result<()> {
167        self.command_tx.send(command).await
168    }
169
170    /// Open a new protocol channel.
171    ///
172    /// Once the other side proofed that it also knows the `key`, the channel is emitted as
173    /// `Event::Channel` on the protocol event stream.
174    pub fn open(&self, key: Key) -> impl Future<Output = Result<()>> + use<> {
175        self.command_tx.open(key)
176    }
177
178    /// Iterator of all currently opened channels.
179    pub fn channels(&self) -> impl Iterator<Item = &DiscoveryKey> {
180        self.channels.iter().map(|c| c.discovery_key())
181    }
182
183    #[instrument(skip_all, fields(initiator = ?self.is_initiator()))]
184    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Event>> {
185        let this = self.get_mut();
186
187        // Initiator needs to send and receive a message before proceeding
188        if this.is_initiator && this.io.handshake_hash().is_none() {
189            return_error!(this.poll_outbound_write(cx));
190            return_error!(this.poll_inbound_read(cx));
191            if this.io.handshake_hash().is_none() {
192                return Poll::Pending;
193            }
194        }
195        // Emit handshake event on first poll
196        if !this.handshake_emitted {
197            if let Some(remote_pubkey) = this.io.remote_public_key() {
198                this.handshake_emitted = true;
199                return Poll::Ready(Ok(Event::Handshake(remote_pubkey)));
200            }
201        }
202
203        // Drain queued events first.
204        if let Some(event) = this.queued_events.pop_front() {
205            return Poll::Ready(Ok(event));
206        }
207
208        // Read and process incoming messages.
209        return_error!(this.poll_inbound_read(cx));
210
211        // Check for commands.
212        return_error!(this.poll_commands(cx));
213
214        // Write everything we can write.
215        return_error!(this.poll_outbound_write(cx));
216
217        // Check if any events are enqueued.
218        if let Some(event) = this.queued_events.pop_front() {
219            Poll::Ready(Ok(event))
220        } else {
221            Poll::Pending
222        }
223    }
224
225    /// Poll commands.
226    fn poll_commands(&mut self, cx: &mut Context<'_>) -> Result<()> {
227        while let Poll::Ready(Some(command)) = Pin::new(&mut self.command_rx).poll_next(cx) {
228            if let Err(e) = self.on_command(command) {
229                error!(error = ?e, "Error handling command");
230                return Err(e);
231            }
232        }
233        Ok(())
234    }
235
236    // just handles Close and LocalSignal
237    fn on_outbound_message(&mut self, message: &ChannelMessage) -> bool {
238        // If message is close, close the local channel.
239        if let ChannelMessage {
240            channel,
241            message: Message::Close(_),
242            ..
243        } = message
244        {
245            self.close_local(*channel);
246        // If message is a LocalSignal, emit an event and return false to indicate
247        // this message should be filtered out.
248        } else if let ChannelMessage {
249            message: Message::LocalSignal((name, data)),
250            ..
251        } = message
252        {
253            self.queue_event(Event::LocalSignal((name.to_string(), data.to_vec())));
254            return false;
255        }
256        true
257    }
258
259    /// Poll for inbound messages and process them.
260    #[instrument(skip_all, err)]
261    fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> {
262        loop {
263            match self.io.poll_inbound(cx) {
264                Poll::Ready(Some(result)) => {
265                    let messages = result?;
266                    self.on_inbound_channel_messages(messages)?;
267                }
268                Poll::Ready(None) => return Ok(()),
269                Poll::Pending => return Ok(()),
270            }
271        }
272    }
273
274    /// Poll for outbound messages and write them.
275    #[instrument(skip_all)]
276    fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> {
277        loop {
278            // Drive outbound IO
279            if let Poll::Ready(Err(e)) = self.io.poll_outbound(cx) {
280                error!(err = ?e, "error from poll_outbound");
281                return Err(e);
282            }
283            // Send messages from outbound_rx
284            match Pin::new(&mut self.outbound_rx).poll_next(cx) {
285                Poll::Ready(Some(mut messages)) => {
286                    if !messages.is_empty() {
287                        messages.retain(|message| self.on_outbound_message(message));
288                        for msg in messages {
289                            self.io.enqueue(msg);
290                        }
291                    }
292                }
293                Poll::Ready(None) => unreachable!("Channel closed before end"),
294                Poll::Pending => return Ok(()),
295            }
296        }
297    }
298
299    #[instrument(skip_all)]
300    fn on_inbound_channel_messages(&mut self, channel_messages: Vec<ChannelMessage>) -> Result<()> {
301        for channel_message in channel_messages {
302            self.on_inbound_message(channel_message)?
303        }
304        Ok(())
305    }
306
307    #[instrument(skip_all)]
308    fn on_inbound_message(&mut self, channel_message: ChannelMessage) -> Result<()> {
309        let (remote_id, message) = channel_message.into_split();
310        match message {
311            Message::Open(msg) => self.on_open(remote_id, msg)?,
312            Message::Close(msg) => self.on_close(remote_id, msg)?,
313            _ => self
314                .channels
315                .forward_inbound_message(remote_id as usize, message)?,
316        }
317        Ok(())
318    }
319
320    #[instrument(skip(self))]
321    fn on_command(&mut self, command: Command) -> Result<()> {
322        match command {
323            Command::Open(key) => self.command_open(key),
324            Command::Close(discovery_key) => self.command_close(discovery_key),
325            Command::SignalLocal((name, data)) => self.command_signal_local(name, data),
326        }
327    }
328
329    /// Open a Channel with the given key. Adding it to our channel map
330    #[instrument(skip_all)]
331    fn command_open(&mut self, key: Key) -> Result<()> {
332        // Create a new channel.
333        let channel_handle = self.channels.attach_local(key);
334        // Safe because attach_local always puts Some(local_id)
335        let local_id = channel_handle.local_id().unwrap();
336        let discovery_key = *channel_handle.discovery_key();
337
338        // If the channel was already opened from the remote end, verify, and if
339        // verification is ok, push a channel open event.
340        if channel_handle.is_connected() {
341            self.accept_channel(local_id)?;
342        }
343
344        // Tell the remote end about the new channel.
345        let capability = self.capability(&key);
346        let channel = local_id as u64;
347        let message = Message::Open(Open {
348            channel,
349            protocol: PROTOCOL_NAME.to_string(),
350            discovery_key: discovery_key.to_vec(),
351            capability,
352        });
353        let channel_message = ChannelMessage::new(channel, message);
354        self.io.enqueue(channel_message);
355        Ok(())
356    }
357
358    fn command_close(&mut self, discovery_key: DiscoveryKey) -> Result<()> {
359        if self.channels.has_channel(&discovery_key) {
360            self.channels.remove(&discovery_key);
361            self.queue_event(Event::Close(discovery_key));
362        }
363        Ok(())
364    }
365
366    fn command_signal_local(&mut self, name: String, data: Vec<u8>) -> Result<()> {
367        self.queue_event(Event::LocalSignal((name, data)));
368        Ok(())
369    }
370
371    #[instrument(skip(self))]
372    fn on_open(&mut self, ch: u64, msg: Open) -> Result<()> {
373        let discovery_key: DiscoveryKey = parse_key(&msg.discovery_key)?;
374        let channel_handle =
375            self.channels
376                .attach_remote(discovery_key, ch as usize, msg.capability);
377
378        if channel_handle.is_connected() {
379            let local_id = channel_handle.local_id().unwrap();
380            self.accept_channel(local_id)?;
381        } else {
382            self.queue_event(Event::DiscoveryKey(discovery_key));
383        }
384
385        Ok(())
386    }
387
388    #[instrument(skip(self))]
389    fn queue_event(&mut self, event: Event) {
390        self.queued_events.push_back(event);
391    }
392
393    #[instrument(skip(self))]
394    fn accept_channel(&mut self, local_id: usize) -> Result<()> {
395        let (key, remote_capability) = self.channels.prepare_to_verify(local_id)?;
396        self.verify_remote_capability(remote_capability.cloned(), key)
397            .expect("TODO channel can only be accepted after first message")?;
398        let channel = self.channels.accept(local_id, self.outbound_tx.clone())?;
399        self.queue_event(Event::Channel(channel));
400        Ok(())
401    }
402
403    fn close_local(&mut self, local_id: u64) {
404        if let Some(channel) = self.channels.get_local(local_id as usize) {
405            let discovery_key = *channel.discovery_key();
406            self.channels.remove(&discovery_key);
407            self.queue_event(Event::Close(discovery_key));
408        }
409    }
410
411    fn on_close(&mut self, remote_id: u64, msg: Close) -> Result<()> {
412        if let Some(channel_handle) = self.channels.get_remote(remote_id as usize) {
413            let discovery_key = *channel_handle.discovery_key();
414            // There is a possibility both sides will close at the same time, so
415            // the channel could be closed already, let's tolerate that.
416            self.channels
417                .forward_inbound_message_tolerate_closed(remote_id as usize, Message::Close(msg))?;
418            self.channels.remove(&discovery_key);
419            self.queue_event(Event::Close(discovery_key));
420        }
421        Ok(())
422    }
423
424    #[instrument(skip_all)]
425    fn capability(&self, key: &[u8]) -> Option<Vec<u8>> {
426        let is_initiator = self.is_initiator;
427        let remote_pubkey = self.remote_public_key()?;
428        let local_pubkey = self.public_key();
429        let handshake_hash = self.io.handshake_hash()?;
430        HandshakeResult::from_pre_encrypted(
431            is_initiator,
432            local_pubkey,
433            remote_pubkey,
434            handshake_hash.to_vec(),
435        )
436        .capability(key)
437    }
438
439    #[instrument(skip_all)]
440    fn verify_remote_capability(
441        &self,
442        capability: Option<Vec<u8>>,
443        key: &[u8],
444    ) -> Option<Result<()>> {
445        let is_initiator = self.is_initiator;
446        let remote_pubkey = self.remote_public_key()?;
447        let local_pubkey = self.public_key();
448        let handshake_hash = self.io.handshake_hash()?;
449        Some(
450            HandshakeResult::from_pre_encrypted(
451                is_initiator,
452                local_pubkey,
453                remote_pubkey,
454                handshake_hash.to_vec(),
455            )
456            .verify_remote_capability(capability, key),
457        )
458    }
459}
460
461impl Stream for Protocol {
462    type Item = Result<Event>;
463    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
464        match Protocol::poll_next(self, cx) {
465            Poll::Ready(Ok(e)) => Poll::Ready(Some(Ok(e))),
466            Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
467            Poll::Pending => Poll::Pending,
468        }
469    }
470}
471
472/// Send [`Command`]s to the [`Protocol`].
473#[derive(Clone, Debug)]
474pub struct CommandTx(Sender<Command>);
475
476impl CommandTx {
477    /// Send a protocol command
478    pub fn send(&self, command: Command) -> impl Future<Output = Result<()>> + use<> {
479        let sender = self.0.clone();
480        async move { sender.send(command).await.map_err(map_channel_err) }
481    }
482    /// Open a protocol channel.
483    ///
484    /// The channel will be emitted on the main protocol.
485    pub fn open(&self, key: Key) -> impl Future<Output = Result<()>> + use<> {
486        self.send(Command::Open(key))
487    }
488
489    /// Close a protocol channel.
490    pub async fn close(&self, discovery_key: DiscoveryKey) -> Result<()> {
491        self.send(Command::Close(discovery_key)).await
492    }
493
494    /// Send a local signal event to the protocol.
495    pub async fn signal_local(&self, name: &str, data: Vec<u8>) -> Result<()> {
496        self.send(Command::SignalLocal((name.to_string(), data)))
497            .await
498    }
499}
500
/// Convert a byte slice into a fixed-size 32-byte key.
///
/// # Errors
/// Returns an `InvalidInput` error unless `key` is exactly 32 bytes long.
fn parse_key(key: &[u8]) -> io::Result<[u8; 32]> {
    let fixed: Option<[u8; 32]> = key.try_into().ok();
    fixed.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "Key must be 32 bytes long"))
}