1use async_channel::{Receiver, Sender};
2use futures_lite::stream::Stream;
3use hypercore_handshake::{CipherTrait, state_machine::PUBLIC_KEYLEN};
4use std::{
5 collections::VecDeque,
6 convert::TryInto,
7 fmt,
8 io::{self, Result},
9 pin::Pin,
10 task::{Context, Poll},
11};
12use tracing::{error, instrument};
13
14use crate::{
15 channels::{Channel, ChannelMap},
16 constants::PROTOCOL_NAME,
17 crypto::HandshakeResult,
18 message::{ChannelMessage, Message},
19 mqueue::MessageIo,
20 schema::*,
21 util::{map_channel_err, pretty_hash},
22};
23
/// Early-return `Poll::Ready(Err(e))` from a poll function when `$result`
/// is an `Err`; an `Ok` value is discarded and execution continues.
macro_rules! return_error {
    ($result:expr) => {
        match $result {
            Ok(_) => {}
            Err(err) => return Poll::Ready(Err(err)),
        }
    };
}
31
/// Capacity of the bounded command and outbound-message channels created in `Protocol::new`.
const CHANNEL_CAP: usize = 1_000;
33
/// Public key of the remote peer (carried by [`Event::Handshake`]).
pub(crate) type RemotePublicKey = [u8; 32];
/// Discovery key; channels are looked up and removed by this key.
pub type DiscoveryKey = [u8; 32];
/// Key used to open a channel (see [`Command::Open`]).
pub type Key = [u8; 32];
40
/// Events yielded by polling a [`Protocol`] as a stream.
#[non_exhaustive]
#[derive(PartialEq)]
pub enum Event {
    /// Emitted once, when the remote peer's public key becomes available.
    Handshake(RemotePublicKey),
    /// The remote sent an `Open` for a channel that is not (yet) connected locally.
    DiscoveryKey(DiscoveryKey),
    /// A channel whose remote capability verified and which was accepted.
    Channel(Channel),
    /// A channel was closed, by a local command, an outbound `Close`, or the remote.
    Close(DiscoveryKey),
    /// A local signal `(name, data)`, from [`Command::SignalLocal`] or an outbound
    /// `LocalSignal` message (which is not sent to the remote).
    LocalSignal((String, Vec<u8>)),
}
58
/// Commands that can be sent to a running [`Protocol`] through a [`CommandTx`].
#[derive(Debug)]
pub enum Command {
    /// Open a channel for the given key.
    Open(Key),
    /// Close the channel with the given discovery key, if it exists.
    Close(DiscoveryKey),
    /// Queue an [`Event::LocalSignal`] with the given `(name, data)`.
    SignalLocal((String, Vec<u8>)),
}
69
70impl fmt::Debug for Event {
71 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72 match self {
73 Event::Handshake(remote_key) => {
74 write!(f, "Handshake(remote_key={})", &pretty_hash(remote_key))
75 }
76 Event::DiscoveryKey(discovery_key) => {
77 write!(f, "DiscoveryKey({})", &pretty_hash(discovery_key))
78 }
79 Event::Channel(channel) => {
80 write!(f, "Channel({})", &pretty_hash(channel.discovery_key()))
81 }
82 Event::Close(discovery_key) => write!(f, "Close({})", &pretty_hash(discovery_key)),
83 Event::LocalSignal((name, data)) => {
84 write!(f, "LocalSignal(name={},len={})", name, data.len())
85 }
86 }
87 }
88}
89
/// A protocol session over a single cipher stream, multiplexing channels.
///
/// Poll it as a `Stream` to drive I/O and receive [`Event`]s. Control it with
/// [`Command`]s sent through a [`CommandTx`].
pub struct Protocol {
    /// Message transport built from the cipher stream passed to `Protocol::new`.
    io: MessageIo,
    /// Whether this side is the initiator (taken from the cipher stream).
    is_initiator: bool,
    /// Local and remote channel state.
    channels: ChannelMap,
    /// Receiving end of the command channel.
    command_rx: Receiver<Command>,
    /// Sending end of the command channel, handed out via `commands()`.
    command_tx: CommandTx,
    /// Batches of outbound messages sent by accepted channels.
    outbound_rx: Receiver<Vec<ChannelMessage>>,
    /// Cloned into each accepted channel. Holding it here keeps `outbound_rx` open.
    outbound_tx: Sender<Vec<ChannelMessage>>,
    /// Events waiting to be yielded from `poll_next`.
    queued_events: VecDeque<Event>,
    /// Set once `Event::Handshake` has been yielded.
    handshake_emitted: bool,
}
106
107impl std::fmt::Debug for Protocol {
108 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109 f.debug_struct("Protocol")
110 .field("is_initiator", &self.is_initiator)
111 .field("channels", &self.channels)
112 .field("handshake_emitted", &self.handshake_emitted)
113 .field("queued_events", &self.queued_events)
114 .finish()
115 }
116}
117
118impl Protocol {
119 pub fn new(stream: Box<dyn CipherTrait>) -> Self {
124 let (command_tx, command_rx) = async_channel::bounded(CHANNEL_CAP);
125 let (outbound_tx, outbound_rx): (
126 Sender<Vec<ChannelMessage>>,
127 Receiver<Vec<ChannelMessage>>,
128 ) = async_channel::bounded(CHANNEL_CAP);
129
130 let is_initiator = stream.is_initiator();
131
132 Protocol {
133 io: MessageIo::new(stream),
134 is_initiator,
135 channels: ChannelMap::new(),
136 command_rx,
137 command_tx: CommandTx(command_tx),
138 outbound_tx,
139 outbound_rx,
140 queued_events: VecDeque::new(),
141 handshake_emitted: false,
142 }
143 }
144
145 pub fn is_initiator(&self) -> bool {
147 self.is_initiator
148 }
149
150 pub fn public_key(&self) -> [u8; PUBLIC_KEYLEN] {
152 self.io.local_public_key()
153 }
154
155 pub fn remote_public_key(&self) -> Option<[u8; PUBLIC_KEYLEN]> {
157 self.io.remote_public_key()
158 }
159
160 pub fn commands(&self) -> CommandTx {
162 self.command_tx.clone()
163 }
164
165 pub async fn command(&self, command: Command) -> Result<()> {
167 self.command_tx.send(command).await
168 }
169
170 pub async fn open(&self, key: Key) -> Result<()> {
175 self.command_tx.open(key).await
176 }
177
178 pub fn channels(&self) -> impl Iterator<Item = &DiscoveryKey> {
180 self.channels.iter().map(|c| c.discovery_key())
181 }
182
183 #[instrument(skip_all, fields(initiator = ?self.is_initiator()))]
184 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Event>> {
185 let this = self.get_mut();
186
187 if this.is_initiator && this.io.handshake_hash().is_none() {
189 return_error!(this.poll_outbound_write(cx));
190 return_error!(this.poll_inbound_read(cx));
191 if this.io.handshake_hash().is_none() {
192 cx.waker().wake_by_ref();
193 return Poll::Pending;
194 }
195 }
196 if !this.handshake_emitted {
198 if let Some(remote_pubkey) = this.io.remote_public_key() {
199 this.handshake_emitted = true;
200 return Poll::Ready(Ok(Event::Handshake(remote_pubkey)));
201 } else {
202 cx.waker().wake_by_ref();
203 }
204 }
205
206 if let Some(event) = this.queued_events.pop_front() {
208 return Poll::Ready(Ok(event));
209 }
210
211 return_error!(this.poll_inbound_read(cx));
213
214 return_error!(this.poll_commands(cx));
216
217 return_error!(this.poll_outbound_write(cx));
219
220 if let Some(event) = this.queued_events.pop_front() {
222 Poll::Ready(Ok(event))
223 } else {
224 Poll::Pending
225 }
226 }
227
228 fn poll_commands(&mut self, cx: &mut Context<'_>) -> Result<()> {
230 while let Poll::Ready(Some(command)) = Pin::new(&mut self.command_rx).poll_next(cx) {
231 if let Err(e) = self.on_command(command) {
232 error!(error = ?e, "Error handling command");
233 return Err(e);
234 }
235 }
236 Ok(())
237 }
238
239 fn on_outbound_message(&mut self, message: &ChannelMessage) -> bool {
241 if let ChannelMessage {
243 channel,
244 message: Message::Close(_),
245 ..
246 } = message
247 {
248 self.close_local(*channel);
249 } else if let ChannelMessage {
252 message: Message::LocalSignal((name, data)),
253 ..
254 } = message
255 {
256 self.queue_event(Event::LocalSignal((name.to_string(), data.to_vec())));
257 return false;
258 }
259 true
260 }
261
262 #[instrument(skip_all, err)]
264 fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> {
265 loop {
266 match self.io.poll_inbound(cx) {
267 Poll::Ready(Some(result)) => {
268 let messages = result?;
269 self.on_inbound_channel_messages(messages)?;
270 }
271 Poll::Ready(None) => return Ok(()),
272 Poll::Pending => return Ok(()),
273 }
274 }
275 }
276
277 #[instrument(skip_all)]
279 fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> {
280 loop {
281 if let Poll::Ready(Err(e)) = self.io.poll_outbound(cx) {
283 error!(err = ?e, "error from poll_outbound");
284 return Err(e);
285 }
286 match Pin::new(&mut self.outbound_rx).poll_next(cx) {
288 Poll::Ready(Some(mut messages)) => {
289 if !messages.is_empty() {
290 messages.retain(|message| self.on_outbound_message(message));
291 for msg in messages {
292 self.io.enqueue(msg);
293 }
294 }
295 }
296 Poll::Ready(None) => unreachable!("Channel closed before end"),
297 Poll::Pending => return Ok(()),
298 }
299 }
300 }
301
302 #[instrument(skip_all)]
303 fn on_inbound_channel_messages(&mut self, channel_messages: Vec<ChannelMessage>) -> Result<()> {
304 for channel_message in channel_messages {
305 self.on_inbound_message(channel_message)?
306 }
307 Ok(())
308 }
309
310 #[instrument(skip_all)]
311 fn on_inbound_message(&mut self, channel_message: ChannelMessage) -> Result<()> {
312 let (remote_id, message) = channel_message.into_split();
313 match message {
314 Message::Open(msg) => self.on_open(remote_id, msg)?,
315 Message::Close(msg) => self.on_close(remote_id, msg)?,
316 _ => self
317 .channels
318 .forward_inbound_message(remote_id as usize, message)?,
319 }
320 Ok(())
321 }
322
323 #[instrument(skip(self))]
324 fn on_command(&mut self, command: Command) -> Result<()> {
325 match command {
326 Command::Open(key) => self.command_open(key),
327 Command::Close(discovery_key) => self.command_close(discovery_key),
328 Command::SignalLocal((name, data)) => self.command_signal_local(name, data),
329 }
330 }
331
332 #[instrument(skip_all)]
334 fn command_open(&mut self, key: Key) -> Result<()> {
335 let channel_handle = self.channels.attach_local(key);
337 let local_id = channel_handle.local_id().unwrap();
339 let discovery_key = *channel_handle.discovery_key();
340
341 if channel_handle.is_connected() {
344 self.accept_channel(local_id)?;
345 }
346
347 let capability = self.capability(&key);
349 let channel = local_id as u64;
350 let message = Message::Open(Open {
351 channel,
352 protocol: PROTOCOL_NAME.to_string(),
353 discovery_key: discovery_key.to_vec(),
354 capability,
355 });
356 let channel_message = ChannelMessage::new(channel, message);
357 self.io.enqueue(channel_message);
358 Ok(())
359 }
360
361 fn command_close(&mut self, discovery_key: DiscoveryKey) -> Result<()> {
362 if self.channels.has_channel(&discovery_key) {
363 self.channels.remove(&discovery_key);
364 self.queue_event(Event::Close(discovery_key));
365 }
366 Ok(())
367 }
368
369 fn command_signal_local(&mut self, name: String, data: Vec<u8>) -> Result<()> {
370 self.queue_event(Event::LocalSignal((name, data)));
371 Ok(())
372 }
373
374 #[instrument(skip(self))]
375 fn on_open(&mut self, ch: u64, msg: Open) -> Result<()> {
376 let discovery_key: DiscoveryKey = parse_key(&msg.discovery_key)?;
377 let channel_handle =
378 self.channels
379 .attach_remote(discovery_key, ch as usize, msg.capability);
380
381 if channel_handle.is_connected() {
382 let local_id = channel_handle.local_id().unwrap();
383 self.accept_channel(local_id)?;
384 } else {
385 self.queue_event(Event::DiscoveryKey(discovery_key));
386 }
387
388 Ok(())
389 }
390
391 #[instrument(skip(self))]
392 fn queue_event(&mut self, event: Event) {
393 self.queued_events.push_back(event);
394 }
395
396 #[instrument(skip(self))]
397 fn accept_channel(&mut self, local_id: usize) -> Result<()> {
398 let (key, remote_capability) = self.channels.prepare_to_verify(local_id)?;
399 self.verify_remote_capability(remote_capability.cloned(), key)
400 .expect("TODO channel can only be accepted after first message")?;
401 let channel = self.channels.accept(local_id, self.outbound_tx.clone())?;
402 self.queue_event(Event::Channel(channel));
403 Ok(())
404 }
405
406 fn close_local(&mut self, local_id: u64) {
407 if let Some(channel) = self.channels.get_local(local_id as usize) {
408 let discovery_key = *channel.discovery_key();
409 self.channels.remove(&discovery_key);
410 self.queue_event(Event::Close(discovery_key));
411 }
412 }
413
414 fn on_close(&mut self, remote_id: u64, msg: Close) -> Result<()> {
415 if let Some(channel_handle) = self.channels.get_remote(remote_id as usize) {
416 let discovery_key = *channel_handle.discovery_key();
417 self.channels
420 .forward_inbound_message_tolerate_closed(remote_id as usize, Message::Close(msg))?;
421 self.channels.remove(&discovery_key);
422 self.queue_event(Event::Close(discovery_key));
423 }
424 Ok(())
425 }
426
427 #[instrument(skip_all)]
428 fn capability(&self, key: &[u8]) -> Option<Vec<u8>> {
429 let is_initiator = self.is_initiator;
430 let remote_pubkey = self.remote_public_key()?;
431 let local_pubkey = self.public_key();
432 let handshake_hash = self.io.handshake_hash()?;
433 HandshakeResult::from_pre_encrypted(
434 is_initiator,
435 local_pubkey,
436 remote_pubkey,
437 handshake_hash.to_vec(),
438 )
439 .capability(key)
440 }
441
442 #[instrument(skip_all)]
443 fn verify_remote_capability(
444 &self,
445 capability: Option<Vec<u8>>,
446 key: &[u8],
447 ) -> Option<Result<()>> {
448 let is_initiator = self.is_initiator;
449 let remote_pubkey = self.remote_public_key()?;
450 let local_pubkey = self.public_key();
451 let handshake_hash = self.io.handshake_hash()?;
452 Some(
453 HandshakeResult::from_pre_encrypted(
454 is_initiator,
455 local_pubkey,
456 remote_pubkey,
457 handshake_hash.to_vec(),
458 )
459 .verify_remote_capability(capability, key),
460 )
461 }
462}
463
464impl Stream for Protocol {
465 type Item = Result<Event>;
466 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
467 match Protocol::poll_next(self, cx) {
468 Poll::Ready(Ok(e)) => Poll::Ready(Some(Ok(e))),
469 Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
470 Poll::Pending => Poll::Pending,
471 }
472 }
473}
474
/// Cloneable handle for sending [`Command`]s to a [`Protocol`].
#[derive(Clone, Debug)]
pub struct CommandTx(Sender<Command>);
478
479impl CommandTx {
480 pub async fn send(&self, command: Command) -> Result<()> {
482 self.0.send(command).await.map_err(map_channel_err)
483 }
484 pub async fn open(&self, key: Key) -> Result<()> {
488 self.send(Command::Open(key)).await
489 }
490
491 pub async fn close(&self, discovery_key: DiscoveryKey) -> Result<()> {
493 self.send(Command::Close(discovery_key)).await
494 }
495
496 pub async fn signal_local(&self, name: &str, data: Vec<u8>) -> Result<()> {
498 self.send(Command::SignalLocal((name.to_string(), data)))
499 .await
500 }
501}
502
/// Convert a byte slice into a fixed 32-byte key.
///
/// # Errors
/// Returns `InvalidInput` unless `key` is exactly 32 bytes long.
fn parse_key(key: &[u8]) -> io::Result<[u8; 32]> {
    match key.try_into() {
        Ok(parsed) => Ok(parsed),
        Err(_) => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "Key must be 32 bytes long",
        )),
    }
}