// hypercore_protocol/channels.rs

1use crate::message::ChannelMessage;
2use crate::schema::*;
3use crate::util::{map_channel_err, pretty_hash};
4use crate::Message;
5use crate::{discovery_key, DiscoveryKey, Key};
6use async_channel::{Receiver, Sender, TrySendError};
7use futures_lite::ready;
8use futures_lite::stream::Stream;
9use std::collections::HashMap;
10use std::fmt;
11use std::io::{Error, ErrorKind, Result};
12use std::pin::Pin;
13use std::sync::atomic::{AtomicBool, Ordering};
14use std::sync::Arc;
15use std::task::Poll;
16use tracing::debug;
17
18/// A protocol channel.
19///
20/// This is the handle that can be sent to other threads.
21#[derive(Clone)]
22pub struct Channel {
23    inbound_rx: Option<Receiver<Message>>,
24    direct_inbound_tx: Sender<Message>,
25    outbound_tx: Sender<Vec<ChannelMessage>>,
26    key: Key,
27    discovery_key: DiscoveryKey,
28    local_id: usize,
29    closed: Arc<AtomicBool>,
30}
31
32impl PartialEq for Channel {
33    fn eq(&self, other: &Self) -> bool {
34        self.key == other.key
35            && self.discovery_key == other.discovery_key
36            && self.local_id == other.local_id
37    }
38}
39
40impl fmt::Debug for Channel {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        f.debug_struct("Channel")
43            .field("discovery_key", &pretty_hash(&self.discovery_key))
44            .finish()
45    }
46}
47
48impl Channel {
49    fn new(
50        inbound_rx: Option<Receiver<Message>>,
51        direct_inbound_tx: Sender<Message>,
52        outbound_tx: Sender<Vec<ChannelMessage>>,
53        discovery_key: DiscoveryKey,
54        key: Key,
55        local_id: usize,
56        closed: Arc<AtomicBool>,
57    ) -> Self {
58        Self {
59            inbound_rx,
60            direct_inbound_tx,
61            outbound_tx,
62            key,
63            discovery_key,
64            local_id,
65            closed,
66        }
67    }
68    /// Get the discovery key of this channel.
69    pub fn discovery_key(&self) -> &[u8; 32] {
70        &self.discovery_key
71    }
72
73    /// Get the key of this channel.
74    pub fn key(&self) -> &[u8; 32] {
75        &self.key
76    }
77
78    /// Get the local wire ID of this channel.
79    pub fn id(&self) -> usize {
80        self.local_id
81    }
82
83    /// Check if the channel is closed.
84    pub fn closed(&self) -> bool {
85        self.closed.load(Ordering::SeqCst)
86    }
87
88    /// Send a message over the channel.
89    pub async fn send(&mut self, message: Message) -> Result<()> {
90        if self.closed() {
91            return Err(Error::new(
92                ErrorKind::ConnectionAborted,
93                "Channel is closed",
94            ));
95        }
96        debug!("TX:\n{message:?}\n");
97        let message = ChannelMessage::new(self.local_id as u64, message);
98        self.outbound_tx
99            .send(vec![message])
100            .await
101            .map_err(map_channel_err)
102    }
103
104    /// Send a batch of messages over the channel.
105    pub async fn send_batch(&mut self, messages: &[Message]) -> Result<()> {
106        // In javascript this is cork()/uncork(), e.g.:
107        //
108        // https://github.com/holepunchto/hypercore/blob/c338b9aaa4442d35bc9d283d2c242b86a46de6d4/lib/replicator.js#L402-L418
109        //
110        // at the protomux level, where there can be messages from multiple channels in a single
111        // stream write:
112        //
113        // https://github.com/holepunchto/protomux/blob/d3d6f8f55e52c2fbe5cd56f5d067ac43ca13c27d/index.js#L368-L389
114        //
115        // Batching messages across channels like protomux is capable of doing is not (yet) implemented.
116        if self.closed() {
117            return Err(Error::new(
118                ErrorKind::ConnectionAborted,
119                "Channel is closed",
120            ));
121        }
122
123        let messages = messages
124            .iter()
125            .map(|message| {
126                debug!("TX:\n{message:?}\n");
127                ChannelMessage::new(self.local_id as u64, message.clone())
128            })
129            .collect();
130        self.outbound_tx
131            .send(messages)
132            .await
133            .map_err(map_channel_err)
134    }
135
136    /// Take the receiving part out of the channel.
137    ///
138    /// After taking the receiver, this Channel will not emit messages when
139    /// polled as a stream. The returned receiver will.
140    pub fn take_receiver(&mut self) -> Option<Receiver<Message>> {
141        self.inbound_rx.take()
142    }
143
144    /// Clone the local sending part of the channel receiver. Useful
145    /// for direct local communication to the channel listener. Typically
146    /// you will only want to send a LocalSignal message with this sender to make
147    /// it clear what event came from the remote peer and what was local
148    /// signaling.
149    pub fn local_sender(&self) -> Sender<Message> {
150        self.direct_inbound_tx.clone()
151    }
152
153    /// Send a close message and close this channel.
154    pub async fn close(&mut self) -> Result<()> {
155        if self.closed() {
156            return Ok(());
157        }
158        let close = Close {
159            channel: self.local_id as u64,
160        };
161        self.send(Message::Close(close)).await?;
162        self.closed.store(true, Ordering::SeqCst);
163        Ok(())
164    }
165
166    /// Signal the protocol to produce Event::LocalSignal. If you want to send a message
167    /// to the channel level, see take_receiver() and local_sender().
168    pub async fn signal_local_protocol(&mut self, name: &str, data: Vec<u8>) -> Result<()> {
169        self.send(Message::LocalSignal((name.to_string(), data)))
170            .await?;
171        Ok(())
172    }
173}
174
175impl Stream for Channel {
176    type Item = Message;
177    fn poll_next(
178        self: Pin<&mut Self>,
179        cx: &mut std::task::Context<'_>,
180    ) -> std::task::Poll<Option<Self::Item>> {
181        let this = self.get_mut();
182        match this.inbound_rx.as_mut() {
183            None => Poll::Ready(None),
184            Some(ref mut inbound_rx) => {
185                let message = ready!(Pin::new(inbound_rx).poll_next(cx));
186                Poll::Ready(message)
187            }
188        }
189    }
190}
191
192/// The handle for a channel that lives with the main Protocol.
193#[derive(Clone, Debug)]
194pub(crate) struct ChannelHandle {
195    discovery_key: DiscoveryKey,
196    local_state: Option<LocalState>,
197    remote_state: Option<RemoteState>,
198    inbound_tx: Option<Sender<Message>>,
199    closed: Arc<AtomicBool>,
200}
201
202#[derive(Clone, Debug)]
203struct LocalState {
204    key: Key,
205    local_id: usize,
206}
207
/// The remotely attached half of a channel.
#[derive(Debug, Clone)]
struct RemoteState {
    /// Wire id used by the remote side.
    remote_id: usize,
    /// Capability supplied when the remote side attached, if any;
    /// returned by `prepare_to_verify`.
    remote_capability: Option<Vec<u8>>,
}
213
214impl ChannelHandle {
215    fn new(discovery_key: DiscoveryKey) -> Self {
216        Self {
217            discovery_key,
218            local_state: None,
219            remote_state: None,
220            inbound_tx: None,
221            closed: Arc::new(AtomicBool::new(false)),
222        }
223    }
224    fn new_local(local_id: usize, discovery_key: DiscoveryKey, key: Key) -> Self {
225        let mut this = Self::new(discovery_key);
226        this.attach_local(local_id, key);
227        this
228    }
229
230    fn new_remote(
231        remote_id: usize,
232        discovery_key: DiscoveryKey,
233        remote_capability: Option<Vec<u8>>,
234    ) -> Self {
235        let mut this = Self::new(discovery_key);
236        this.attach_remote(remote_id, remote_capability);
237        this
238    }
239
240    pub(crate) fn discovery_key(&self) -> &[u8; 32] {
241        &self.discovery_key
242    }
243
244    pub(crate) fn local_id(&self) -> Option<usize> {
245        self.local_state.as_ref().map(|s| s.local_id)
246    }
247
248    pub(crate) fn remote_id(&self) -> Option<usize> {
249        self.remote_state.as_ref().map(|s| s.remote_id)
250    }
251
252    pub(crate) fn attach_local(&mut self, local_id: usize, key: Key) {
253        let local_state = LocalState { local_id, key };
254        self.local_state = Some(local_state);
255    }
256
257    pub(crate) fn attach_remote(&mut self, remote_id: usize, remote_capability: Option<Vec<u8>>) {
258        let remote_state = RemoteState {
259            remote_id,
260            remote_capability,
261        };
262        self.remote_state = Some(remote_state);
263    }
264
265    pub(crate) fn is_connected(&self) -> bool {
266        self.local_state.is_some() && self.remote_state.is_some()
267    }
268
269    pub(crate) fn prepare_to_verify(&self) -> Result<(&Key, Option<&Vec<u8>>)> {
270        if !self.is_connected() {
271            return Err(error("Channel is not opened from both local and remote"));
272        }
273        // Safe because of the is_connected() check above.
274        let local_state = self.local_state.as_ref().unwrap();
275        let remote_state = self.remote_state.as_ref().unwrap();
276        Ok((&local_state.key, remote_state.remote_capability.as_ref()))
277    }
278
279    pub(crate) fn open(&mut self, outbound_tx: Sender<Vec<ChannelMessage>>) -> Channel {
280        let local_state = self
281            .local_state
282            .as_ref()
283            .expect("May not open channel that is not locally attached");
284
285        let (inbound_tx, inbound_rx) = async_channel::unbounded();
286        let channel = Channel::new(
287            Some(inbound_rx),
288            inbound_tx.clone(),
289            outbound_tx,
290            self.discovery_key,
291            local_state.key,
292            local_state.local_id,
293            self.closed.clone(),
294        );
295
296        self.inbound_tx = Some(inbound_tx);
297        channel
298    }
299
300    pub(crate) fn try_send_inbound(&mut self, message: Message) -> std::io::Result<()> {
301        if let Some(inbound_tx) = self.inbound_tx.as_mut() {
302            inbound_tx
303                .try_send(message)
304                .map_err(|e| error(format!("Sending to channel failed: {e}").as_str()))
305        } else {
306            Err(error("Channel is not open"))
307        }
308    }
309
310    pub(crate) fn try_send_inbound_tolerate_closed(
311        &mut self,
312        message: Message,
313    ) -> std::io::Result<()> {
314        if let Some(inbound_tx) = self.inbound_tx.as_mut() {
315            if let Err(err) = inbound_tx.try_send(message) {
316                match err {
317                    TrySendError::Full(e) => {
318                        return Err(error(format!("Sending to channel failed: {e}").as_str()))
319                    }
320                    TrySendError::Closed(_) => {}
321                }
322            }
323        }
324        Ok(())
325    }
326}
327
328impl Drop for ChannelHandle {
329    fn drop(&mut self) {
330        self.closed.store(true, Ordering::SeqCst);
331    }
332}
333
334/// The ChannelMap maintains a list of open channels and their local (tx) and remote (rx) channel IDs.
335#[derive(Debug)]
336pub(crate) struct ChannelMap {
337    channels: HashMap<String, ChannelHandle>,
338    local_id: Vec<Option<String>>,
339    remote_id: Vec<Option<String>>,
340}
341
342impl ChannelMap {
343    pub(crate) fn new() -> Self {
344        Self {
345            channels: HashMap::new(),
346            // Add a first None value to local_id to start ids at 1.
347            // This makes sure that 0 may be used for stream-level extensions.
348            local_id: vec![None],
349            remote_id: vec![],
350        }
351    }
352
353    pub(crate) fn attach_local(&mut self, key: Key) -> &ChannelHandle {
354        let discovery_key = discovery_key(&key);
355        let hdkey = hex::encode(discovery_key);
356        let local_id = self.alloc_local();
357
358        self.channels
359            .entry(hdkey.clone())
360            .and_modify(|channel| channel.attach_local(local_id, key))
361            .or_insert_with(|| ChannelHandle::new_local(local_id, discovery_key, key));
362
363        self.local_id[local_id] = Some(hdkey.clone());
364        self.channels.get(&hdkey).unwrap()
365    }
366
367    pub(crate) fn attach_remote(
368        &mut self,
369        discovery_key: DiscoveryKey,
370        remote_id: usize,
371        remote_capability: Option<Vec<u8>>,
372    ) -> &ChannelHandle {
373        let hdkey = hex::encode(discovery_key);
374        self.alloc_remote(remote_id);
375        self.channels
376            .entry(hdkey.clone())
377            .and_modify(|channel| channel.attach_remote(remote_id, remote_capability.clone()))
378            .or_insert_with(|| {
379                ChannelHandle::new_remote(remote_id, discovery_key, remote_capability)
380            });
381        self.remote_id[remote_id] = Some(hdkey.clone());
382        self.channels.get(&hdkey).unwrap()
383    }
384
385    pub(crate) fn get_remote_mut(&mut self, remote_id: usize) -> Option<&mut ChannelHandle> {
386        if let Some(Some(hdkey)) = self.remote_id.get(remote_id).as_ref() {
387            self.channels.get_mut(hdkey)
388        } else {
389            None
390        }
391    }
392
393    pub(crate) fn get_remote(&self, remote_id: usize) -> Option<&ChannelHandle> {
394        if let Some(Some(hdkey)) = self.remote_id.get(remote_id).as_ref() {
395            self.channels.get(hdkey)
396        } else {
397            None
398        }
399    }
400
401    pub(crate) fn get_local_mut(&mut self, local_id: usize) -> Option<&mut ChannelHandle> {
402        if let Some(Some(hdkey)) = self.local_id.get(local_id).as_ref() {
403            self.channels.get_mut(hdkey)
404        } else {
405            None
406        }
407    }
408
409    pub(crate) fn get_local(&self, local_id: usize) -> Option<&ChannelHandle> {
410        if let Some(Some(hdkey)) = self.local_id.get(local_id).as_ref() {
411            self.channels.get(hdkey)
412        } else {
413            None
414        }
415    }
416
417    pub(crate) fn has_channel(&self, discovery_key: &[u8]) -> bool {
418        let hdkey = hex::encode(discovery_key);
419        self.channels.contains_key(&hdkey)
420    }
421
422    pub(crate) fn remove(&mut self, discovery_key: &[u8]) {
423        let hdkey = hex::encode(discovery_key);
424        let channel = self.channels.get(&hdkey);
425        if let Some(channel) = channel {
426            if let Some(local_id) = channel.local_id() {
427                self.local_id[local_id] = None;
428            }
429            if let Some(remote_id) = channel.remote_id() {
430                self.remote_id[remote_id] = None;
431            }
432        }
433        self.channels.remove(&hdkey);
434    }
435
436    pub(crate) fn prepare_to_verify(&self, local_id: usize) -> Result<(&Key, Option<&Vec<u8>>)> {
437        let channel_handle = self
438            .get_local(local_id)
439            .ok_or_else(|| error("Channel not found"))?;
440        channel_handle.prepare_to_verify()
441    }
442
443    pub(crate) fn accept(
444        &mut self,
445        local_id: usize,
446        outbound_tx: Sender<Vec<ChannelMessage>>,
447    ) -> Result<Channel> {
448        let channel_handle = self
449            .get_local_mut(local_id)
450            .ok_or_else(|| error("Channel not found"))?;
451        if !channel_handle.is_connected() {
452            return Err(error("Channel is not opened from remote"));
453        }
454        let channel = channel_handle.open(outbound_tx);
455        Ok(channel)
456    }
457
458    pub(crate) fn forward_inbound_message(
459        &mut self,
460        remote_id: usize,
461        message: Message,
462    ) -> Result<()> {
463        if let Some(channel_handle) = self.get_remote_mut(remote_id) {
464            channel_handle.try_send_inbound(message)?;
465        }
466        Ok(())
467    }
468
469    pub(crate) fn forward_inbound_message_tolerate_closed(
470        &mut self,
471        remote_id: usize,
472        message: Message,
473    ) -> Result<()> {
474        if let Some(channel_handle) = self.get_remote_mut(remote_id) {
475            channel_handle.try_send_inbound_tolerate_closed(message)?;
476        }
477        Ok(())
478    }
479
480    fn alloc_local(&mut self) -> usize {
481        let empty_id = self
482            .local_id
483            .iter()
484            .skip(1)
485            .position(|x| x.is_none())
486            .map(|position| position + 1);
487        match empty_id {
488            Some(empty_id) => empty_id,
489            None => {
490                self.local_id.push(None);
491                self.local_id.len() - 1
492            }
493        }
494    }
495
496    fn alloc_remote(&mut self, id: usize) {
497        if self.remote_id.len() > id {
498            self.remote_id[id] = None;
499        } else {
500            self.remote_id.resize(id + 1, None)
501        }
502    }
503
504    pub(crate) fn iter(&self) -> impl Iterator<Item = &ChannelHandle> {
505        self.channels.values()
506    }
507}
508
/// Build an `ErrorKind::Other` I/O error whose message is `message`.
fn error(message: &str) -> Error {
    let text = message.to_owned();
    Error::new(ErrorKind::Other, text)
}