1use async_channel::{Receiver, Sender};
2use futures_lite::io::{AsyncRead, AsyncWrite};
3use futures_lite::stream::Stream;
4use futures_timer::Delay;
5use std::collections::VecDeque;
6use std::convert::TryInto;
7use std::fmt;
8use std::future::Future;
9use std::io::{self, Error, ErrorKind, Result};
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use std::time::Duration;
13
14use crate::channels::{Channel, ChannelMap};
15use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME};
16use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult};
17use crate::message::{ChannelMessage, Frame, FrameType, Message};
18use crate::reader::ReadState;
19use crate::schema::*;
20use crate::util::{map_channel_err, pretty_hash};
21use crate::writer::WriteState;
22
/// Early-returns `Poll::Ready(Err(e))` from the enclosing poll function when
/// the given `Result` expression is an `Err`; an `Ok` value is discarded.
macro_rules! return_error {
    ($msg:expr) => {
        match $msg {
            Ok(_) => {}
            Err(e) => return Poll::Ready(Err(e)),
        }
    };
}
30
/// Capacity of the bounded command channel created in `Protocol::new`; once
/// this many commands are pending, further sends wait.
const CHANNEL_CAP: usize = 1000;
/// Keepalive timer period: `DEFAULT_KEEPALIVE` interpreted as seconds. The
/// keepalive timer is re-armed with this duration each time it fires.
const KEEPALIVE_DURATION: Duration = Duration::from_secs(DEFAULT_KEEPALIVE as u64);
33
/// Configuration for a [`Protocol`] session.
#[derive(Debug)]
pub(crate) struct Options {
    /// Whether this side initiates the Noise handshake.
    pub(crate) is_initiator: bool,
    /// Whether to run the Noise handshake at all; when `false` the protocol
    /// starts directly in the established state.
    pub(crate) noise: bool,
    /// Whether to upgrade to an encrypted secret stream once the handshake
    /// completes.
    pub(crate) encrypted: bool,
}

impl Options {
    /// Builds options for the given role, with both the Noise handshake and
    /// encryption enabled.
    pub(crate) fn new(is_initiator: bool) -> Self {
        Options {
            encrypted: true,
            noise: true,
            is_initiator,
        }
    }
}
57
/// The remote peer's 32-byte public key, as carried by [`Event::Handshake`].
pub(crate) type RemotePublicKey = [u8; 32];
/// A 32-byte discovery key; channels are looked up and announced (in `Open`
/// messages) by this key.
pub type DiscoveryKey = [u8; 32];
/// A 32-byte key used to open a channel via [`Command::Open`]. The channel's
/// discovery key is obtained from the channel map after attaching this key
/// (presumably derived from it — confirm in `ChannelMap::attach_local`).
pub type Key = [u8; 32];
64
/// Events yielded by a [`Protocol`] stream.
#[non_exhaustive]
#[derive(PartialEq)]
pub enum Event {
    /// The handshake finished (after the secret-stream upgrade when
    /// encryption is enabled); carries the remote's public key.
    Handshake(RemotePublicKey),
    /// The remote opened a channel for this discovery key that has not been
    /// opened locally yet.
    DiscoveryKey(DiscoveryKey),
    /// A channel was opened on both sides and accepted.
    Channel(Channel),
    /// The channel with this discovery key was closed (locally or by the
    /// remote) and removed.
    Close(DiscoveryKey),
    /// A local signal `(name, data)`, raised via [`Command::SignalLocal`] or an
    /// outbound `LocalSignal` channel message; never sent to the remote.
    LocalSignal((String, Vec<u8>)),
}
82
/// Commands sent to a [`Protocol`] through a [`CommandTx`]. They are only
/// processed once the protocol has reached the established state.
#[derive(Debug)]
pub enum Command {
    /// Open (or accept) the channel for this key and announce it to the remote.
    Open(Key),
    /// Remove the channel with this discovery key locally and emit
    /// [`Event::Close`] if it existed.
    Close(DiscoveryKey),
    /// Re-emit `(name, data)` as [`Event::LocalSignal`].
    SignalLocal((String, Vec<u8>)),
}
93
94impl fmt::Debug for Event {
95 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96 match self {
97 Event::Handshake(remote_key) => {
98 write!(f, "Handshake(remote_key={})", &pretty_hash(remote_key))
99 }
100 Event::DiscoveryKey(discovery_key) => {
101 write!(f, "DiscoveryKey({})", &pretty_hash(discovery_key))
102 }
103 Event::Channel(channel) => {
104 write!(f, "Channel({})", &pretty_hash(channel.discovery_key()))
105 }
106 Event::Close(discovery_key) => write!(f, "Close({})", &pretty_hash(discovery_key)),
107 Event::LocalSignal((name, data)) => {
108 write!(f, "LocalSignal(name={},len={})", name, data.len())
109 }
110 }
111 }
112}
113
/// Connection state machine driven by `Protocol`.
///
/// Progression: `NotInitialized` → `Handshake` → (`SecretStream` when
/// encrypted) → `Established`; without noise, init goes straight to
/// `Established`. The payloads are wrapped in `Option` so handlers can
/// `take()` them out while processing a message.
// The `Handshake`/`SecretStream` variants carry the in-progress handshake and
// the pending encrypt cipher; presumably much larger than the unit variants.
// Boxing is skipped since a `Protocol` holds a single `State`.
// NOTE(review): confirm this size trade-off is intended.
#[allow(clippy::large_enum_variant)]
pub(crate) enum State {
    NotInitialized,
    Handshake(Option<Handshake>),
    SecretStream(Option<EncryptCipher>),
    Established,
}
124
125impl fmt::Debug for State {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 match self {
128 State::NotInitialized => write!(f, "NotInitialized"),
129 State::Handshake(_) => write!(f, "Handshaking"),
130 State::SecretStream(_) => write!(f, "SecretStream"),
131 State::Established => write!(f, "Established"),
132 }
133 }
134}
135
/// A protocol session multiplexing channels over a single `IO` stream.
///
/// Drive it by polling it as a [`Stream`] of [`Event`]s; send [`Command`]s
/// through the handle returned by [`Protocol::commands`].
pub struct Protocol<IO> {
    // Outbound frame queue/encoder, flushed to `io` by `poll_outbound_write`.
    write_state: WriteState,
    // Inbound frame reader/decoder over `io`.
    read_state: ReadState,
    // The underlying transport; returned by `release`.
    io: IO,
    // Handshake / secret-stream / established state machine.
    state: State,
    options: Options,
    // Set once the Noise handshake completes; `None` before that.
    handshake: Option<HandshakeResult>,
    // Channels opened locally and/or by the remote.
    channels: ChannelMap,
    // Receiving and sending ends of the bounded command channel.
    command_rx: Receiver<Command>,
    command_tx: CommandTx,
    // Batches of channel messages produced by accepted `Channel`s (each gets a
    // clone of `outbound_tx`), written out by `poll_outbound_write`.
    outbound_rx: Receiver<Vec<ChannelMessage>>,
    outbound_tx: Sender<Vec<ChannelMessage>>,
    // Timer for periodic keepalive frames.
    keepalive: Delay,
    // Events waiting to be returned from `poll_next`.
    queued_events: VecDeque<Event>,
}
152
153impl<IO> std::fmt::Debug for Protocol<IO> {
154 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155 f.debug_struct("Protocol")
156 .field("write_state", &self.write_state)
157 .field("read_state", &self.read_state)
158 .field("state", &self.state)
160 .field("options", &self.options)
161 .field("handshake", &self.handshake)
162 .field("channels", &self.channels)
163 .field("command_rx", &self.command_rx)
164 .field("command_tx", &self.command_tx)
165 .field("outbound_rx", &self.outbound_rx)
166 .field("outbound_tx", &self.outbound_tx)
167 .field("keepalive", &self.keepalive)
168 .field("queued_events", &self.queued_events)
169 .finish()
170 }
171}
172
173impl<IO> Protocol<IO>
174where
175 IO: AsyncWrite + AsyncRead + Send + Unpin + 'static,
176{
177 pub(crate) fn new(io: IO, options: Options) -> Self {
179 let (command_tx, command_rx) = async_channel::bounded(CHANNEL_CAP);
180 let (outbound_tx, outbound_rx): (
181 Sender<Vec<ChannelMessage>>,
182 Receiver<Vec<ChannelMessage>>,
183 ) = async_channel::bounded(1);
184 Protocol {
185 io,
186 read_state: ReadState::new(),
187 write_state: WriteState::new(),
188 options,
189 state: State::NotInitialized,
190 channels: ChannelMap::new(),
191 handshake: None,
192 command_rx,
193 command_tx: CommandTx(command_tx),
194 outbound_tx,
195 outbound_rx,
196 keepalive: Delay::new(Duration::from_secs(DEFAULT_KEEPALIVE as u64)),
197 queued_events: VecDeque::new(),
198 }
199 }
200
201 pub fn is_initiator(&self) -> bool {
203 self.options.is_initiator
204 }
205
206 pub fn public_key(&self) -> Option<&[u8]> {
210 match &self.handshake {
211 None => None,
212 Some(handshake) => Some(handshake.local_pubkey.as_slice()),
213 }
214 }
215
216 pub fn remote_public_key(&self) -> Option<&[u8]> {
220 match &self.handshake {
221 None => None,
222 Some(handshake) => Some(handshake.remote_pubkey.as_slice()),
223 }
224 }
225
226 pub fn commands(&self) -> CommandTx {
228 self.command_tx.clone()
229 }
230
231 pub async fn command(&mut self, command: Command) -> Result<()> {
233 self.command_tx.send(command).await
234 }
235
236 pub async fn open(&mut self, key: Key) -> Result<()> {
241 self.command_tx.open(key).await
242 }
243
244 pub fn channels(&self) -> impl Iterator<Item = &DiscoveryKey> {
246 self.channels.iter().map(|c| c.discovery_key())
247 }
248
249 pub fn release(self) -> IO {
251 self.io
252 }
253
    /// Drives the protocol one step and returns the next event, if any.
    ///
    /// On the first call this initializes the state machine. Events already
    /// queued are returned before any I/O is attempted. Otherwise it reads and
    /// processes inbound frames, handles pending commands (only once
    /// established), emits keepalives, and flushes outbound frames. It then
    /// returns the first event queued by that work, or `Pending`.
    ///
    /// # Errors
    /// Returns `Poll::Ready(Err(_))` as soon as initialization, reading,
    /// command handling or writing fails.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Event>> {
        let this = self.get_mut();

        // Lazily start the handshake (or go straight to Established) on first poll.
        if let State::NotInitialized = this.state {
            return_error!(this.init());
        }

        // Hand out events queued by earlier polls before doing new I/O.
        if let Some(event) = this.queued_events.pop_front() {
            return Poll::Ready(Ok(event));
        }

        // Read and process every inbound frame currently available.
        return_error!(this.poll_inbound_read(cx));

        // Commands are only acted on once the connection is established.
        if let State::Established = this.state {
            return_error!(this.poll_commands(cx));
        }

        this.poll_keepalive(cx);

        // Flush queued frames (handshake replies, keepalives, channel messages).
        return_error!(this.poll_outbound_write(cx));

        // Return the first event produced above, if any. Otherwise return
        // Pending; the reader, writer and timer polled above were given `cx`,
        // presumably registering its waker.
        if let Some(event) = this.queued_events.pop_front() {
            Poll::Ready(Ok(event))
        } else {
            Poll::Pending
        }
    }
287
288 fn init(&mut self) -> Result<()> {
289 tracing::debug!(
290 "protocol init, state {:?}, options {:?}",
291 self.state,
292 self.options
293 );
294 match self.state {
295 State::NotInitialized => {}
296 _ => return Ok(()),
297 };
298
299 self.state = if self.options.noise {
300 let mut handshake = Handshake::new(self.options.is_initiator)?;
301 if let Some(buf) = handshake.start()? {
303 self.queue_frame_direct(buf.to_vec()).unwrap();
304 }
305 self.read_state.set_frame_type(FrameType::Raw);
306 State::Handshake(Some(handshake))
307 } else {
308 self.read_state.set_frame_type(FrameType::Message);
309 State::Established
310 };
311
312 Ok(())
313 }
314
315 fn poll_commands(&mut self, cx: &mut Context<'_>) -> Result<()> {
317 while let Poll::Ready(Some(command)) = Pin::new(&mut self.command_rx).poll_next(cx) {
318 self.on_command(command)?;
319 }
320 Ok(())
321 }
322
323 fn poll_keepalive(&mut self, cx: &mut Context<'_>) {
325 if Pin::new(&mut self.keepalive).poll(cx).is_ready() {
326 if let State::Established = self.state {
327 self.write_state
329 .queue_frame(Frame::RawBatch(vec![vec![0u8; 3]]));
330 }
331 self.keepalive.reset(KEEPALIVE_DURATION);
332 }
333 }
334
335 fn on_outbound_message(&mut self, message: &ChannelMessage) -> bool {
336 if let ChannelMessage {
338 channel,
339 message: Message::Close(_),
340 ..
341 } = message
342 {
343 self.close_local(*channel);
344 } else if let ChannelMessage {
347 message: Message::LocalSignal((name, data)),
348 ..
349 } = message
350 {
351 self.queue_event(Event::LocalSignal((name.to_string(), data.to_vec())));
352 return false;
353 }
354 true
355 }
356
357 fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> {
359 loop {
360 let msg = self.read_state.poll_reader(cx, &mut self.io);
361 match msg {
362 Poll::Ready(Ok(message)) => {
363 self.on_inbound_frame(message)?;
364 }
365 Poll::Ready(Err(e)) => return Err(e),
366 Poll::Pending => return Ok(()),
367 }
368 }
369 }
370
371 fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> {
373 loop {
374 if let Poll::Ready(Err(e)) = self.write_state.poll_send(cx, &mut self.io) {
375 return Err(e);
376 }
377 if !self.write_state.can_park_frame() || !matches!(self.state, State::Established) {
378 return Ok(());
379 }
380
381 match Pin::new(&mut self.outbound_rx).poll_next(cx) {
382 Poll::Ready(Some(mut messages)) => {
383 if !messages.is_empty() {
384 messages.retain(|message| self.on_outbound_message(message));
385 if !messages.is_empty() {
386 let frame = Frame::MessageBatch(messages);
387 self.write_state.park_frame(frame);
388 }
389 }
390 }
391 Poll::Ready(None) => unreachable!("Channel closed before end"),
392 Poll::Pending => return Ok(()),
393 }
394 }
395 }
396
397 fn on_inbound_frame(&mut self, frame: Frame) -> Result<()> {
398 match frame {
399 Frame::RawBatch(raw_batch) => {
400 let mut processed_state: Option<String> = None;
401 for buf in raw_batch {
402 let state_name: String = format!("{:?}", self.state);
403 match self.state {
404 State::Handshake(_) => self.on_handshake_message(buf)?,
405 State::SecretStream(_) => self.on_secret_stream_message(buf)?,
406 State::Established => {
407 if let Some(processed_state) = processed_state.as_ref() {
408 let previous_state = if self.options.encrypted {
409 State::SecretStream(None)
410 } else {
411 State::Handshake(None)
412 };
413 if processed_state == &format!("{previous_state:?}") {
414 let buf = self.read_state.decrypt_buf(&buf)?;
419 let frame = Frame::decode(&buf, &FrameType::Message)?;
420 self.on_inbound_frame(frame)?;
421 continue;
422 }
423 }
424 unreachable!(
425 "May not receive raw frames in Established state"
426 )
427 }
428 _ => unreachable!(
429 "May not receive raw frames outside of handshake or secretstream state, was {:?}",
430 self.state
431 ),
432 };
433 if processed_state.is_none() {
434 processed_state = Some(state_name)
435 }
436 }
437 Ok(())
438 }
439 Frame::MessageBatch(channel_messages) => match self.state {
440 State::Established => {
441 for channel_message in channel_messages {
442 self.on_inbound_message(channel_message)?
443 }
444 Ok(())
445 }
446 _ => unreachable!("May not receive message batch frames when not established"),
447 },
448 }
449 }
450
451 fn on_handshake_message(&mut self, buf: Vec<u8>) -> Result<()> {
452 let mut handshake = match &mut self.state {
453 State::Handshake(handshake) => handshake.take().unwrap(),
454 _ => unreachable!("May not call on_handshake_message when not in Handshake state"),
455 };
456
457 if let Some(response_buf) = handshake.read(&buf)? {
458 self.queue_frame_direct(response_buf.to_vec()).unwrap();
459 }
460
461 if !handshake.complete() {
462 self.state = State::Handshake(Some(handshake));
463 } else {
464 let handshake_result = handshake.into_result()?;
465
466 if self.options.encrypted {
467 let (cipher, init_msg) = EncryptCipher::from_handshake_tx(&handshake_result)?;
469 self.state = State::SecretStream(Some(cipher));
470
471 self.queue_frame_direct(init_msg).unwrap();
473 } else {
474 self.read_state.set_frame_type(FrameType::Message);
477 let remote_public_key = parse_key(&handshake_result.remote_pubkey)?;
478 self.queue_event(Event::Handshake(remote_public_key));
479 self.state = State::Established;
480 }
481 self.handshake = Some(handshake_result);
483 }
484 Ok(())
485 }
486
487 fn on_secret_stream_message(&mut self, buf: Vec<u8>) -> Result<()> {
488 let encrypt_cipher = match &mut self.state {
489 State::SecretStream(encrypt_cipher) => encrypt_cipher.take().unwrap(),
490 _ => {
491 unreachable!("May not call on_secret_stream_message when not in SecretStream state")
492 }
493 };
494 let handshake_result = &self
495 .handshake
496 .as_ref()
497 .expect("Handshake result must be set before secret stream");
498 let decrypt_cipher = DecryptCipher::from_handshake_rx_and_init_msg(handshake_result, &buf)?;
499 self.read_state.upgrade_with_decrypt_cipher(decrypt_cipher);
500 self.write_state.upgrade_with_encrypt_cipher(encrypt_cipher);
501 self.read_state.set_frame_type(FrameType::Message);
502
503 let remote_public_key = parse_key(&handshake_result.remote_pubkey)?;
505 self.queue_event(Event::Handshake(remote_public_key));
506 self.state = State::Established;
507 Ok(())
508 }
509
510 fn on_inbound_message(&mut self, channel_message: ChannelMessage) -> Result<()> {
511 let (remote_id, message) = channel_message.into_split();
513 match message {
514 Message::Open(msg) => self.on_open(remote_id, msg)?,
515 Message::Close(msg) => self.on_close(remote_id, msg)?,
516 _ => self
517 .channels
518 .forward_inbound_message(remote_id as usize, message)?,
519 }
520 Ok(())
521 }
522
523 fn on_command(&mut self, command: Command) -> Result<()> {
524 match command {
525 Command::Open(key) => self.command_open(key),
526 Command::Close(discovery_key) => self.command_close(discovery_key),
527 Command::SignalLocal((name, data)) => self.command_signal_local(name, data),
528 }
529 }
530
531 fn command_open(&mut self, key: Key) -> Result<()> {
533 let channel_handle = self.channels.attach_local(key);
535 let local_id = channel_handle.local_id().unwrap();
537 let discovery_key = *channel_handle.discovery_key();
538
539 if channel_handle.is_connected() {
542 self.accept_channel(local_id)?;
543 }
544
545 let capability = self.capability(&key);
547 let channel = local_id as u64;
548 let message = Message::Open(Open {
549 channel,
550 protocol: PROTOCOL_NAME.to_string(),
551 discovery_key: discovery_key.to_vec(),
552 capability,
553 });
554 let channel_message = ChannelMessage::new(channel, message);
555 self.write_state
556 .queue_frame(Frame::MessageBatch(vec![channel_message]));
557 Ok(())
558 }
559
560 fn command_close(&mut self, discovery_key: DiscoveryKey) -> Result<()> {
561 if self.channels.has_channel(&discovery_key) {
562 self.channels.remove(&discovery_key);
563 self.queue_event(Event::Close(discovery_key));
564 }
565 Ok(())
566 }
567
568 fn command_signal_local(&mut self, name: String, data: Vec<u8>) -> Result<()> {
569 self.queue_event(Event::LocalSignal((name, data)));
570 Ok(())
571 }
572
573 fn on_open(&mut self, ch: u64, msg: Open) -> Result<()> {
574 let discovery_key: DiscoveryKey = parse_key(&msg.discovery_key)?;
575 let channel_handle =
576 self.channels
577 .attach_remote(discovery_key, ch as usize, msg.capability);
578
579 if channel_handle.is_connected() {
580 let local_id = channel_handle.local_id().unwrap();
581 self.accept_channel(local_id)?;
582 } else {
583 self.queue_event(Event::DiscoveryKey(discovery_key));
584 }
585
586 Ok(())
587 }
588
589 fn queue_event(&mut self, event: Event) {
590 self.queued_events.push_back(event);
591 }
592
593 fn queue_frame_direct(&mut self, body: Vec<u8>) -> Result<bool> {
594 let mut frame = Frame::RawBatch(vec![body]);
595 self.write_state.try_queue_direct(&mut frame)
596 }
597
598 fn accept_channel(&mut self, local_id: usize) -> Result<()> {
599 let (key, remote_capability) = self.channels.prepare_to_verify(local_id)?;
600 self.verify_remote_capability(remote_capability.cloned(), key)?;
601 let channel = self.channels.accept(local_id, self.outbound_tx.clone())?;
602 self.queue_event(Event::Channel(channel));
603 Ok(())
604 }
605
606 fn close_local(&mut self, local_id: u64) {
607 if let Some(channel) = self.channels.get_local(local_id as usize) {
608 let discovery_key = *channel.discovery_key();
609 self.channels.remove(&discovery_key);
610 self.queue_event(Event::Close(discovery_key));
611 }
612 }
613
614 fn on_close(&mut self, remote_id: u64, msg: Close) -> Result<()> {
615 if let Some(channel_handle) = self.channels.get_remote(remote_id as usize) {
616 let discovery_key = *channel_handle.discovery_key();
617 self.channels
620 .forward_inbound_message_tolerate_closed(remote_id as usize, Message::Close(msg))?;
621 self.channels.remove(&discovery_key);
622 self.queue_event(Event::Close(discovery_key));
623 }
624 Ok(())
625 }
626
627 fn capability(&self, key: &[u8]) -> Option<Vec<u8>> {
628 match self.handshake.as_ref() {
629 Some(handshake) => handshake.capability(key),
630 None => None,
631 }
632 }
633
634 fn verify_remote_capability(&self, capability: Option<Vec<u8>>, key: &[u8]) -> Result<()> {
635 match self.handshake.as_ref() {
636 Some(handshake) => handshake.verify_remote_capability(capability, key),
637 None => Err(Error::new(
638 ErrorKind::PermissionDenied,
639 "Missing handshake state for capability verification",
640 )),
641 }
642 }
643}
644
645impl<IO> Stream for Protocol<IO>
646where
647 IO: AsyncRead + AsyncWrite + Send + Unpin + 'static,
648{
649 type Item = Result<Event>;
650 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
651 Protocol::poll_next(self, cx).map(Some)
652 }
653}
654
655#[derive(Clone, Debug)]
657pub struct CommandTx(Sender<Command>);
658
659impl CommandTx {
660 pub async fn send(&mut self, command: Command) -> Result<()> {
662 self.0.send(command).await.map_err(map_channel_err)
663 }
664 pub async fn open(&mut self, key: Key) -> Result<()> {
668 self.send(Command::Open(key)).await
669 }
670
671 pub async fn close(&mut self, discovery_key: DiscoveryKey) -> Result<()> {
673 self.send(Command::Close(discovery_key)).await
674 }
675
676 pub async fn signal_local(&mut self, name: &str, data: Vec<u8>) -> Result<()> {
678 self.send(Command::SignalLocal((name.to_string(), data)))
679 .await
680 }
681}
682
/// Converts a byte slice into a fixed 32-byte key.
///
/// # Errors
/// `InvalidInput` if `key` is not exactly 32 bytes long.
fn parse_key(key: &[u8]) -> io::Result<[u8; 32]> {
    match TryInto::<[u8; 32]>::try_into(key) {
        Ok(fixed) => Ok(fixed),
        Err(_) => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "Key must be 32 bytes long",
        )),
    }
}