1use async_channel::{Receiver, Sender};
2use futures_lite::stream::Stream;
3use hypercore_handshake::{CipherTrait, state_machine::PUBLIC_KEYLEN};
4use std::{
5 collections::VecDeque,
6 convert::TryInto,
7 fmt,
8 io::{self, Result},
9 pin::Pin,
10 task::{Context, Poll},
11};
12use tracing::{error, instrument};
13
14use crate::{
15 channels::{Channel, ChannelMap},
16 constants::PROTOCOL_NAME,
17 crypto::HandshakeResult,
18 message::{ChannelMessage, Message},
19 mqueue::MessageIo,
20 schema::*,
21 util::{map_channel_err, pretty_hash},
22};
23
/// Inside a `poll_*` function: if `$msg` evaluates to `Err(e)`, immediately return
/// `Poll::Ready(Err(e))`; an `Ok` value is discarded and execution continues.
macro_rules! return_error {
    ($msg:expr) => {{
        let outcome = $msg;
        if let Err(err) = outcome {
            return Poll::Ready(Err(err));
        }
    }};
}
31
/// Capacity of the bounded command channel and the bounded outbound-message channel that
/// `Protocol::new` creates.
const CHANNEL_CAP: usize = 1_000;
33
/// Public key of the remote peer, carried by `Event::Handshake`.
pub(crate) type RemotePublicKey = [u8; 32];
/// Discovery key identifying a channel; it is sent in `Open` messages and used to look up
/// channels locally.
pub type DiscoveryKey = [u8; 32];
/// Key passed to `Command::Open`; the capability attached to the outgoing `Open` message is
/// derived from it.
pub type Key = [u8; 32];
40
41#[non_exhaustive]
43#[derive(PartialEq)]
44pub enum Event {
45 Handshake(RemotePublicKey),
48 DiscoveryKey(DiscoveryKey),
50 Channel(Channel),
52 Close(DiscoveryKey),
54 LocalSignal((String, Vec<u8>)),
57}
58
59#[derive(Debug)]
61pub enum Command {
62 Open(Key),
64 Close(DiscoveryKey),
66 SignalLocal((String, Vec<u8>)),
68}
69
70impl fmt::Debug for Event {
71 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72 match self {
73 Event::Handshake(remote_key) => {
74 write!(f, "Handshake(remote_key={})", &pretty_hash(remote_key))
75 }
76 Event::DiscoveryKey(discovery_key) => {
77 write!(f, "DiscoveryKey({})", &pretty_hash(discovery_key))
78 }
79 Event::Channel(channel) => {
80 write!(f, "Channel({})", &pretty_hash(channel.discovery_key()))
81 }
82 Event::Close(discovery_key) => write!(f, "Close({})", &pretty_hash(discovery_key)),
83 Event::LocalSignal((name, data)) => {
84 write!(f, "LocalSignal(name={},len={})", name, data.len())
85 }
86 }
87 }
88}
89
90pub struct Protocol {
96 io: MessageIo,
97 is_initiator: bool,
98 channels: ChannelMap,
99 command_rx: Receiver<Command>,
100 command_tx: CommandTx,
101 outbound_rx: Receiver<Vec<ChannelMessage>>,
102 outbound_tx: Sender<Vec<ChannelMessage>>,
103 queued_events: VecDeque<Event>,
104 handshake_emitted: bool,
105}
106
107impl std::fmt::Debug for Protocol {
108 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109 f.debug_struct("Protocol")
110 .field("is_initiator", &self.is_initiator)
111 .field("channels", &self.channels)
112 .field("handshake_emitted", &self.handshake_emitted)
113 .field("queued_events", &self.queued_events)
114 .finish()
115 }
116}
117
118impl Protocol {
119 pub fn new(stream: Box<dyn CipherTrait>) -> Self {
124 let (command_tx, command_rx) = async_channel::bounded(CHANNEL_CAP);
125 let (outbound_tx, outbound_rx): (
126 Sender<Vec<ChannelMessage>>,
127 Receiver<Vec<ChannelMessage>>,
128 ) = async_channel::bounded(CHANNEL_CAP);
129
130 let is_initiator = stream.is_initiator();
131
132 Protocol {
133 io: MessageIo::new(stream),
134 is_initiator,
135 channels: ChannelMap::new(),
136 command_rx,
137 command_tx: CommandTx(command_tx),
138 outbound_tx,
139 outbound_rx,
140 queued_events: VecDeque::new(),
141 handshake_emitted: false,
142 }
143 }
144
145 pub fn is_initiator(&self) -> bool {
147 self.is_initiator
148 }
149
150 pub fn public_key(&self) -> [u8; PUBLIC_KEYLEN] {
152 self.io.local_public_key()
153 }
154
155 pub fn remote_public_key(&self) -> Option<[u8; PUBLIC_KEYLEN]> {
157 self.io.remote_public_key()
158 }
159
160 pub fn commands(&self) -> CommandTx {
162 self.command_tx.clone()
163 }
164
165 pub async fn command(&self, command: Command) -> Result<()> {
167 self.command_tx.send(command).await
168 }
169
170 pub fn open(&self, key: Key) -> impl Future<Output = Result<()>> + use<> {
175 self.command_tx.open(key)
176 }
177
178 pub fn channels(&self) -> impl Iterator<Item = &DiscoveryKey> {
180 self.channels.iter().map(|c| c.discovery_key())
181 }
182
183 #[instrument(skip_all, fields(initiator = ?self.is_initiator()))]
184 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Event>> {
185 let this = self.get_mut();
186
187 if this.is_initiator && this.io.handshake_hash().is_none() {
189 return_error!(this.poll_outbound_write(cx));
190 return_error!(this.poll_inbound_read(cx));
191 if this.io.handshake_hash().is_none() {
192 return Poll::Pending;
193 }
194 }
195 if !this.handshake_emitted {
197 if let Some(remote_pubkey) = this.io.remote_public_key() {
198 this.handshake_emitted = true;
199 return Poll::Ready(Ok(Event::Handshake(remote_pubkey)));
200 }
201 }
202
203 if let Some(event) = this.queued_events.pop_front() {
205 return Poll::Ready(Ok(event));
206 }
207
208 return_error!(this.poll_inbound_read(cx));
210
211 return_error!(this.poll_commands(cx));
213
214 return_error!(this.poll_outbound_write(cx));
216
217 if let Some(event) = this.queued_events.pop_front() {
219 Poll::Ready(Ok(event))
220 } else {
221 Poll::Pending
222 }
223 }
224
225 fn poll_commands(&mut self, cx: &mut Context<'_>) -> Result<()> {
227 while let Poll::Ready(Some(command)) = Pin::new(&mut self.command_rx).poll_next(cx) {
228 if let Err(e) = self.on_command(command) {
229 error!(error = ?e, "Error handling command");
230 return Err(e);
231 }
232 }
233 Ok(())
234 }
235
236 fn on_outbound_message(&mut self, message: &ChannelMessage) -> bool {
238 if let ChannelMessage {
240 channel,
241 message: Message::Close(_),
242 ..
243 } = message
244 {
245 self.close_local(*channel);
246 } else if let ChannelMessage {
249 message: Message::LocalSignal((name, data)),
250 ..
251 } = message
252 {
253 self.queue_event(Event::LocalSignal((name.to_string(), data.to_vec())));
254 return false;
255 }
256 true
257 }
258
259 #[instrument(skip_all, err)]
261 fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> {
262 loop {
263 match self.io.poll_inbound(cx) {
264 Poll::Ready(Some(result)) => {
265 let messages = result?;
266 self.on_inbound_channel_messages(messages)?;
267 }
268 Poll::Ready(None) => return Ok(()),
269 Poll::Pending => return Ok(()),
270 }
271 }
272 }
273
274 #[instrument(skip_all)]
276 fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> {
277 loop {
278 if let Poll::Ready(Err(e)) = self.io.poll_outbound(cx) {
280 error!(err = ?e, "error from poll_outbound");
281 return Err(e);
282 }
283 match Pin::new(&mut self.outbound_rx).poll_next(cx) {
285 Poll::Ready(Some(mut messages)) => {
286 if !messages.is_empty() {
287 messages.retain(|message| self.on_outbound_message(message));
288 for msg in messages {
289 self.io.enqueue(msg);
290 }
291 }
292 }
293 Poll::Ready(None) => unreachable!("Channel closed before end"),
294 Poll::Pending => return Ok(()),
295 }
296 }
297 }
298
299 #[instrument(skip_all)]
300 fn on_inbound_channel_messages(&mut self, channel_messages: Vec<ChannelMessage>) -> Result<()> {
301 for channel_message in channel_messages {
302 self.on_inbound_message(channel_message)?
303 }
304 Ok(())
305 }
306
307 #[instrument(skip_all)]
308 fn on_inbound_message(&mut self, channel_message: ChannelMessage) -> Result<()> {
309 let (remote_id, message) = channel_message.into_split();
310 match message {
311 Message::Open(msg) => self.on_open(remote_id, msg)?,
312 Message::Close(msg) => self.on_close(remote_id, msg)?,
313 _ => self
314 .channels
315 .forward_inbound_message(remote_id as usize, message)?,
316 }
317 Ok(())
318 }
319
320 #[instrument(skip(self))]
321 fn on_command(&mut self, command: Command) -> Result<()> {
322 match command {
323 Command::Open(key) => self.command_open(key),
324 Command::Close(discovery_key) => self.command_close(discovery_key),
325 Command::SignalLocal((name, data)) => self.command_signal_local(name, data),
326 }
327 }
328
329 #[instrument(skip_all)]
331 fn command_open(&mut self, key: Key) -> Result<()> {
332 let channel_handle = self.channels.attach_local(key);
334 let local_id = channel_handle.local_id().unwrap();
336 let discovery_key = *channel_handle.discovery_key();
337
338 if channel_handle.is_connected() {
341 self.accept_channel(local_id)?;
342 }
343
344 let capability = self.capability(&key);
346 let channel = local_id as u64;
347 let message = Message::Open(Open {
348 channel,
349 protocol: PROTOCOL_NAME.to_string(),
350 discovery_key: discovery_key.to_vec(),
351 capability,
352 });
353 let channel_message = ChannelMessage::new(channel, message);
354 self.io.enqueue(channel_message);
355 Ok(())
356 }
357
358 fn command_close(&mut self, discovery_key: DiscoveryKey) -> Result<()> {
359 if self.channels.has_channel(&discovery_key) {
360 self.channels.remove(&discovery_key);
361 self.queue_event(Event::Close(discovery_key));
362 }
363 Ok(())
364 }
365
366 fn command_signal_local(&mut self, name: String, data: Vec<u8>) -> Result<()> {
367 self.queue_event(Event::LocalSignal((name, data)));
368 Ok(())
369 }
370
371 #[instrument(skip(self))]
372 fn on_open(&mut self, ch: u64, msg: Open) -> Result<()> {
373 let discovery_key: DiscoveryKey = parse_key(&msg.discovery_key)?;
374 let channel_handle =
375 self.channels
376 .attach_remote(discovery_key, ch as usize, msg.capability);
377
378 if channel_handle.is_connected() {
379 let local_id = channel_handle.local_id().unwrap();
380 self.accept_channel(local_id)?;
381 } else {
382 self.queue_event(Event::DiscoveryKey(discovery_key));
383 }
384
385 Ok(())
386 }
387
388 #[instrument(skip(self))]
389 fn queue_event(&mut self, event: Event) {
390 self.queued_events.push_back(event);
391 }
392
393 #[instrument(skip(self))]
394 fn accept_channel(&mut self, local_id: usize) -> Result<()> {
395 let (key, remote_capability) = self.channels.prepare_to_verify(local_id)?;
396 self.verify_remote_capability(remote_capability.cloned(), key)
397 .expect("TODO channel can only be accepted after first message")?;
398 let channel = self.channels.accept(local_id, self.outbound_tx.clone())?;
399 self.queue_event(Event::Channel(channel));
400 Ok(())
401 }
402
403 fn close_local(&mut self, local_id: u64) {
404 if let Some(channel) = self.channels.get_local(local_id as usize) {
405 let discovery_key = *channel.discovery_key();
406 self.channels.remove(&discovery_key);
407 self.queue_event(Event::Close(discovery_key));
408 }
409 }
410
411 fn on_close(&mut self, remote_id: u64, msg: Close) -> Result<()> {
412 if let Some(channel_handle) = self.channels.get_remote(remote_id as usize) {
413 let discovery_key = *channel_handle.discovery_key();
414 self.channels
417 .forward_inbound_message_tolerate_closed(remote_id as usize, Message::Close(msg))?;
418 self.channels.remove(&discovery_key);
419 self.queue_event(Event::Close(discovery_key));
420 }
421 Ok(())
422 }
423
424 #[instrument(skip_all)]
425 fn capability(&self, key: &[u8]) -> Option<Vec<u8>> {
426 let is_initiator = self.is_initiator;
427 let remote_pubkey = self.remote_public_key()?;
428 let local_pubkey = self.public_key();
429 let handshake_hash = self.io.handshake_hash()?;
430 HandshakeResult::from_pre_encrypted(
431 is_initiator,
432 local_pubkey,
433 remote_pubkey,
434 handshake_hash.to_vec(),
435 )
436 .capability(key)
437 }
438
439 #[instrument(skip_all)]
440 fn verify_remote_capability(
441 &self,
442 capability: Option<Vec<u8>>,
443 key: &[u8],
444 ) -> Option<Result<()>> {
445 let is_initiator = self.is_initiator;
446 let remote_pubkey = self.remote_public_key()?;
447 let local_pubkey = self.public_key();
448 let handshake_hash = self.io.handshake_hash()?;
449 Some(
450 HandshakeResult::from_pre_encrypted(
451 is_initiator,
452 local_pubkey,
453 remote_pubkey,
454 handshake_hash.to_vec(),
455 )
456 .verify_remote_capability(capability, key),
457 )
458 }
459}
460
461impl Stream for Protocol {
462 type Item = Result<Event>;
463 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
464 match Protocol::poll_next(self, cx) {
465 Poll::Ready(Ok(e)) => Poll::Ready(Some(Ok(e))),
466 Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
467 Poll::Pending => Poll::Pending,
468 }
469 }
470}
471
472#[derive(Clone, Debug)]
474pub struct CommandTx(Sender<Command>);
475
476impl CommandTx {
477 pub fn send(&self, command: Command) -> impl Future<Output = Result<()>> + use<> {
479 let sender = self.0.clone();
480 async move { sender.send(command).await.map_err(map_channel_err) }
481 }
482 pub fn open(&self, key: Key) -> impl Future<Output = Result<()>> + use<> {
486 self.send(Command::Open(key))
487 }
488
489 pub async fn close(&self, discovery_key: DiscoveryKey) -> Result<()> {
491 self.send(Command::Close(discovery_key)).await
492 }
493
494 pub async fn signal_local(&self, name: &str, data: Vec<u8>) -> Result<()> {
496 self.send(Command::SignalLocal((name.to_string(), data)))
497 .await
498 }
499}
500
/// Convert a wire-format key into a fixed 32-byte array.
///
/// # Errors
/// Returns `InvalidInput` when `key` is not exactly 32 bytes long.
fn parse_key(key: &[u8]) -> io::Result<[u8; 32]> {
    let Ok(parsed) = <[u8; 32]>::try_from(key) else {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "Key must be 32 bytes long",
        ));
    };
    Ok(parsed)
}