Skip to main content

hypercore_protocol/
channels.rs

1use crate::{
2    DiscoveryKey, Key, Message, discovery_key,
3    message::ChannelMessage,
4    schema::*,
5    util::{map_channel_err, pretty_hash},
6};
7use async_channel::{Receiver, Sender, TrySendError};
8use futures_lite::{ready, stream::Stream};
9use std::{
10    collections::HashMap,
11    fmt,
12    io::{Error, ErrorKind, Result},
13    pin::Pin,
14    sync::{
15        Arc,
16        atomic::{AtomicBool, Ordering},
17    },
18    task::Poll,
19};
20use tracing::instrument;
21
22/// A protocol channel.
23///
24/// This is the handle that can be sent to other threads.
25#[derive(Clone)]
26pub struct Channel {
27    inbound_rx: Option<Receiver<Message>>,
28    direct_inbound_tx: Sender<Message>,
29    outbound_tx: Sender<Vec<ChannelMessage>>,
30    key: Key,
31    discovery_key: DiscoveryKey,
32    local_id: usize,
33    closed: Arc<AtomicBool>,
34}
35
36impl PartialEq for Channel {
37    fn eq(&self, other: &Self) -> bool {
38        self.key == other.key
39            && self.discovery_key == other.discovery_key
40            && self.local_id == other.local_id
41    }
42}
43
44impl fmt::Debug for Channel {
45    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46        f.debug_struct("Channel")
47            .field("discovery_key", &pretty_hash(&self.discovery_key))
48            .finish()
49    }
50}
51
52impl Channel {
53    fn new(
54        inbound_rx: Option<Receiver<Message>>,
55        direct_inbound_tx: Sender<Message>,
56        outbound_tx: Sender<Vec<ChannelMessage>>,
57        discovery_key: DiscoveryKey,
58        key: Key,
59        local_id: usize,
60        closed: Arc<AtomicBool>,
61    ) -> Self {
62        Self {
63            inbound_rx,
64            direct_inbound_tx,
65            outbound_tx,
66            key,
67            discovery_key,
68            local_id,
69            closed,
70        }
71    }
72    /// Get the discovery key of this channel.
73    pub fn discovery_key(&self) -> &[u8; 32] {
74        &self.discovery_key
75    }
76
77    /// Get the key of this channel.
78    pub fn key(&self) -> &[u8; 32] {
79        &self.key
80    }
81
82    /// Get the local wire ID of this channel.
83    pub fn id(&self) -> usize {
84        self.local_id
85    }
86
87    /// Check if the channel is closed.
88    pub fn closed(&self) -> bool {
89        self.closed.load(Ordering::SeqCst)
90    }
91
92    /// Send a message over the channel.
93    pub async fn send(&self, message: Message) -> Result<()> {
94        if self.closed() {
95            return Err(Error::new(
96                ErrorKind::ConnectionAborted,
97                "Channel is closed",
98            ));
99        }
100        let message = ChannelMessage::new(self.local_id as u64, message);
101        self.outbound_tx
102            .send(vec![message])
103            .await
104            .map_err(map_channel_err)
105    }
106
107    /// Send a batch of messages over the channel.
108    pub fn send_batch(&self, messages: &[Message]) -> impl Future<Output = Result<()>> + use<> {
109        // In javascript this is cork()/uncork(), e.g.:
110        //
111        // https://github.com/holepunchto/hypercore/blob/c338b9aaa4442d35bc9d283d2c242b86a46de6d4/lib/replicator.js#L402-L418
112        //
113        // at the protomux level, where there can be messages from multiple channels in a single
114        // stream write:
115        //
116        // https://github.com/holepunchto/protomux/blob/d3d6f8f55e52c2fbe5cd56f5d067ac43ca13c27d/index.js#L368-L389
117        //
118        // Batching messages across channels like protomux is capable of doing is not (yet) implemented.
119
120        let closed = self.closed();
121
122        // we do this to avoid having the future capture &[Messages]
123        let messages = if !closed {
124            messages
125                .iter()
126                .map(|message| ChannelMessage::new(self.local_id as u64, message.clone()))
127                .collect()
128        } else {
129            vec![]
130        };
131
132        let outbound_tx = self.outbound_tx.clone();
133        async move {
134            if closed {
135                return Err(Error::new(
136                    ErrorKind::ConnectionAborted,
137                    "Channel is closed",
138                ));
139            }
140
141            outbound_tx.send(messages).await.map_err(map_channel_err)
142        }
143    }
144
145    /// Take the receiving part out of the channel.
146    ///
147    /// After taking the receiver, this Channel will not emit messages when
148    /// polled as a stream. The returned receiver will.
149    pub fn take_receiver(&mut self) -> Option<Receiver<Message>> {
150        self.inbound_rx.take()
151    }
152
153    /// Clone the local sending part of the channel receiver. Useful
154    /// for direct local communication to the channel listener. Typically
155    /// you will only want to send a LocalSignal message with this sender to make
156    /// it clear what event came from the remote peer and what was local
157    /// signaling.
158    pub fn local_sender(&self) -> Sender<Message> {
159        self.direct_inbound_tx.clone()
160    }
161
162    /// Send a close message and close this channel.
163    pub async fn close(&self) -> Result<()> {
164        if self.closed() {
165            return Ok(());
166        }
167        let close = Close {
168            channel: self.local_id as u64,
169        };
170        self.send(Message::Close(close)).await?;
171        self.closed.store(true, Ordering::SeqCst);
172        Ok(())
173    }
174
175    /// Signal the protocol to produce Event::LocalSignal. If you want to send a message
176    /// to the channel level, see take_receiver() and local_sender().
177    pub async fn signal_local_protocol(&self, name: &str, data: Vec<u8>) -> Result<()> {
178        self.send(Message::LocalSignal((name.to_string(), data)))
179            .await?;
180        Ok(())
181    }
182}
183
184impl Stream for Channel {
185    type Item = Message;
186    fn poll_next(
187        self: Pin<&mut Self>,
188        cx: &mut std::task::Context<'_>,
189    ) -> std::task::Poll<Option<Self::Item>> {
190        let this = self.get_mut();
191        match this.inbound_rx.as_mut() {
192            None => Poll::Ready(None),
193            Some(ref mut inbound_rx) => {
194                let message = ready!(Pin::new(inbound_rx).poll_next(cx));
195                Poll::Ready(message)
196            }
197        }
198    }
199}
200
201/// The handle for a channel that lives with the main Protocol.
202#[derive(Clone, Debug)]
203pub(crate) struct ChannelHandle {
204    discovery_key: DiscoveryKey,
205    local_state: Option<LocalState>,
206    remote_state: Option<RemoteState>,
207    inbound_tx: Option<Sender<Message>>,
208    closed: Arc<AtomicBool>,
209}
210
211#[derive(Clone, Debug)]
212struct LocalState {
213    key: Key,
214    local_id: usize,
215}
216
/// State recorded when the remote peer attaches to a channel.
#[derive(Clone, Debug)]
struct RemoteState {
    /// Wire id chosen by the remote peer.
    remote_id: usize,
    /// Capability bytes supplied with the remote attach, if any.
    remote_capability: Option<Vec<u8>>,
}
222
223impl ChannelHandle {
224    fn new(discovery_key: DiscoveryKey) -> Self {
225        Self {
226            discovery_key,
227            local_state: None,
228            remote_state: None,
229            inbound_tx: None,
230            closed: Arc::new(AtomicBool::new(false)),
231        }
232    }
233    fn new_local(local_id: usize, discovery_key: DiscoveryKey, key: Key) -> Self {
234        let mut this = Self::new(discovery_key);
235        this.attach_local(local_id, key);
236        this
237    }
238
239    fn new_remote(
240        remote_id: usize,
241        discovery_key: DiscoveryKey,
242        remote_capability: Option<Vec<u8>>,
243    ) -> Self {
244        let mut this = Self::new(discovery_key);
245        this.attach_remote(remote_id, remote_capability);
246        this
247    }
248
249    pub(crate) fn discovery_key(&self) -> &[u8; 32] {
250        &self.discovery_key
251    }
252
253    pub(crate) fn local_id(&self) -> Option<usize> {
254        self.local_state.as_ref().map(|s| s.local_id)
255    }
256
257    pub(crate) fn remote_id(&self) -> Option<usize> {
258        self.remote_state.as_ref().map(|s| s.remote_id)
259    }
260
261    #[instrument(skip_all, fields(local_id = local_id))]
262    pub(crate) fn attach_local(&mut self, local_id: usize, key: Key) {
263        let local_state = LocalState { local_id, key };
264        self.local_state = Some(local_state);
265    }
266
267    pub(crate) fn attach_remote(&mut self, remote_id: usize, remote_capability: Option<Vec<u8>>) {
268        let remote_state = RemoteState {
269            remote_id,
270            remote_capability,
271        };
272        self.remote_state = Some(remote_state);
273    }
274
275    pub(crate) fn is_connected(&self) -> bool {
276        self.local_state.is_some() && self.remote_state.is_some()
277    }
278
279    pub(crate) fn prepare_to_verify(&self) -> Result<(&Key, Option<&Vec<u8>>)> {
280        if !self.is_connected() {
281            return Err(error("Channel is not opened from both local and remote"));
282        }
283        // Safe because of the is_connected() check above.
284        let local_state = self.local_state.as_ref().unwrap();
285        let remote_state = self.remote_state.as_ref().unwrap();
286        Ok((&local_state.key, remote_state.remote_capability.as_ref()))
287    }
288
289    #[instrument(skip_all)]
290    pub(crate) fn open(&mut self, outbound_tx: Sender<Vec<ChannelMessage>>) -> Channel {
291        let local_state = self
292            .local_state
293            .as_ref()
294            .expect("May not open channel that is not locally attached");
295
296        let (inbound_tx, inbound_rx) = async_channel::unbounded();
297        let channel = Channel::new(
298            Some(inbound_rx),
299            inbound_tx.clone(),
300            outbound_tx,
301            self.discovery_key,
302            local_state.key,
303            local_state.local_id,
304            self.closed.clone(),
305        );
306
307        self.inbound_tx = Some(inbound_tx);
308        channel
309    }
310
311    pub(crate) fn try_send_inbound(&mut self, message: Message) -> std::io::Result<()> {
312        if let Some(inbound_tx) = self.inbound_tx.as_mut() {
313            inbound_tx
314                .try_send(message)
315                .map_err(|e| error(format!("Sending to channel failed: {e}").as_str()))
316        } else {
317            Err(error("Channel is not open"))
318        }
319    }
320
321    pub(crate) fn try_send_inbound_tolerate_closed(
322        &mut self,
323        message: Message,
324    ) -> std::io::Result<()> {
325        if let Some(inbound_tx) = self.inbound_tx.as_mut()
326            && let Err(err) = inbound_tx.try_send(message)
327        {
328            match err {
329                TrySendError::Full(e) => {
330                    return Err(error(format!("Sending to channel failed: {e}").as_str()));
331                }
332                TrySendError::Closed(_) => {}
333            }
334        }
335        Ok(())
336    }
337}
338
339impl Drop for ChannelHandle {
340    fn drop(&mut self) {
341        self.closed.store(true, Ordering::SeqCst);
342    }
343}
344
345/// The ChannelMap maintains a list of open channels and their local (tx) and remote (rx) channel IDs.
346#[derive(Debug)]
347pub(crate) struct ChannelMap {
348    channels: HashMap<String, ChannelHandle>,
349    local_id: Vec<Option<String>>,
350    remote_id: Vec<Option<String>>,
351}
352
353impl ChannelMap {
354    pub(crate) fn new() -> Self {
355        Self {
356            channels: HashMap::new(),
357            // Add a first None value to local_id to start ids at 1.
358            // This makes sure that 0 may be used for stream-level extensions.
359            local_id: vec![None],
360            remote_id: vec![],
361        }
362    }
363
364    pub(crate) fn attach_local(&mut self, key: Key) -> &ChannelHandle {
365        let discovery_key = discovery_key(&key);
366        let hdkey = hex::encode(discovery_key);
367        let local_id = self.alloc_local();
368
369        self.channels
370            .entry(hdkey.clone())
371            .and_modify(|channel| channel.attach_local(local_id, key))
372            .or_insert_with(|| ChannelHandle::new_local(local_id, discovery_key, key));
373
374        self.local_id[local_id] = Some(hdkey.clone());
375        self.channels.get(&hdkey).unwrap()
376    }
377
378    pub(crate) fn attach_remote(
379        &mut self,
380        discovery_key: DiscoveryKey,
381        remote_id: usize,
382        remote_capability: Option<Vec<u8>>,
383    ) -> &ChannelHandle {
384        let hdkey = hex::encode(discovery_key);
385        self.alloc_remote(remote_id);
386        self.channels
387            .entry(hdkey.clone())
388            .and_modify(|channel| channel.attach_remote(remote_id, remote_capability.clone()))
389            .or_insert_with(|| {
390                ChannelHandle::new_remote(remote_id, discovery_key, remote_capability)
391            });
392        self.remote_id[remote_id] = Some(hdkey.clone());
393        self.channels.get(&hdkey).unwrap()
394    }
395
396    pub(crate) fn get_remote_mut(&mut self, remote_id: usize) -> Option<&mut ChannelHandle> {
397        if let Some(Some(hdkey)) = self.remote_id.get(remote_id).as_ref() {
398            self.channels.get_mut(hdkey)
399        } else {
400            None
401        }
402    }
403
404    pub(crate) fn get_remote(&self, remote_id: usize) -> Option<&ChannelHandle> {
405        if let Some(Some(hdkey)) = self.remote_id.get(remote_id).as_ref() {
406            self.channels.get(hdkey)
407        } else {
408            None
409        }
410    }
411
412    pub(crate) fn get_local_mut(&mut self, local_id: usize) -> Option<&mut ChannelHandle> {
413        if let Some(Some(hdkey)) = self.local_id.get(local_id).as_ref() {
414            self.channels.get_mut(hdkey)
415        } else {
416            None
417        }
418    }
419
420    pub(crate) fn get_local(&self, local_id: usize) -> Option<&ChannelHandle> {
421        if let Some(Some(hdkey)) = self.local_id.get(local_id).as_ref() {
422            self.channels.get(hdkey)
423        } else {
424            None
425        }
426    }
427
428    pub(crate) fn has_channel(&self, discovery_key: &[u8]) -> bool {
429        let hdkey = hex::encode(discovery_key);
430        self.channels.contains_key(&hdkey)
431    }
432
433    pub(crate) fn remove(&mut self, discovery_key: &[u8]) {
434        let hdkey = hex::encode(discovery_key);
435        let channel = self.channels.get(&hdkey);
436        if let Some(channel) = channel {
437            if let Some(local_id) = channel.local_id() {
438                self.local_id[local_id] = None;
439            }
440            if let Some(remote_id) = channel.remote_id() {
441                self.remote_id[remote_id] = None;
442            }
443        }
444        self.channels.remove(&hdkey);
445    }
446
447    #[instrument(skip(self))]
448    pub(crate) fn prepare_to_verify(&self, local_id: usize) -> Result<(&Key, Option<&Vec<u8>>)> {
449        let channel_handle = self
450            .get_local(local_id)
451            .ok_or_else(|| error("Channel not found"))?;
452        channel_handle.prepare_to_verify()
453    }
454
455    pub(crate) fn accept(
456        &mut self,
457        local_id: usize,
458        outbound_tx: Sender<Vec<ChannelMessage>>,
459    ) -> Result<Channel> {
460        let channel_handle = self
461            .get_local_mut(local_id)
462            .ok_or_else(|| error("Channel not found"))?;
463        if !channel_handle.is_connected() {
464            return Err(error("Channel is not opened from remote"));
465        }
466        let channel = channel_handle.open(outbound_tx);
467        Ok(channel)
468    }
469
470    pub(crate) fn forward_inbound_message(
471        &mut self,
472        remote_id: usize,
473        message: Message,
474    ) -> Result<()> {
475        if let Some(channel_handle) = self.get_remote_mut(remote_id) {
476            channel_handle.try_send_inbound(message)?;
477        }
478        Ok(())
479    }
480
481    pub(crate) fn forward_inbound_message_tolerate_closed(
482        &mut self,
483        remote_id: usize,
484        message: Message,
485    ) -> Result<()> {
486        if let Some(channel_handle) = self.get_remote_mut(remote_id) {
487            channel_handle.try_send_inbound_tolerate_closed(message)?;
488        }
489        Ok(())
490    }
491
492    fn alloc_local(&mut self) -> usize {
493        let empty_id = self
494            .local_id
495            .iter()
496            .skip(1)
497            .position(|x| x.is_none())
498            .map(|position| position + 1);
499        match empty_id {
500            Some(empty_id) => empty_id,
501            None => {
502                self.local_id.push(None);
503                self.local_id.len() - 1
504            }
505        }
506    }
507
508    fn alloc_remote(&mut self, id: usize) {
509        if self.remote_id.len() > id {
510            self.remote_id[id] = None;
511        } else {
512            self.remote_id.resize(id + 1, None)
513        }
514    }
515
516    pub(crate) fn iter(&self) -> impl Iterator<Item = &ChannelHandle> {
517        self.channels.values()
518    }
519}
520
/// Build an `io::Error` of kind `ErrorKind::Other` whose message is `message`.
fn error(message: &str) -> Error {
    let text: String = message.to_owned();
    Error::other(text)
}