1use std::{
6 net::SocketAddr,
7 num::NonZeroU32,
8 pin::Pin,
9 sync::Arc,
10 task::{Context, Poll},
11 time::Duration,
12};
13
14use anyhow::{anyhow, bail, ensure, Context as _, Result};
15use bytes::Bytes;
16use futures_lite::Stream;
17use futures_sink::Sink;
18use futures_util::{
19 stream::{SplitSink, SplitStream, StreamExt},
20 SinkExt,
21};
22use tokio::sync::mpsc;
23use tokio_tungstenite_wasm::WebSocketStream;
24use tokio_util::{
25 codec::{FramedRead, FramedWrite},
26 task::AbortOnDropHandle,
27};
28use tracing::{debug, info_span, trace, Instrument};
29
30use crate::{
31 defaults::timeouts::relay::CLIENT_RECV_TIMEOUT,
32 key::{PublicKey, SecretKey},
33 relay::{
34 client::streams::{MaybeTlsStreamReader, MaybeTlsStreamWriter},
35 codec::{
36 write_frame, ClientInfo, DerpCodec, Frame, MAX_PACKET_SIZE,
37 PER_CLIENT_READ_QUEUE_DEPTH, PER_CLIENT_SEND_QUEUE_DEPTH, PROTOCOL_VERSION,
38 },
39 },
40};
41
42impl PartialEq for Conn {
43 fn eq(&self, other: &Self) -> bool {
44 Arc::ptr_eq(&self.inner, &other.inner)
45 }
46}
47
48impl Eq for Conn {}
49
/// A cheaply clonable handle to a relay connection.
///
/// Every clone shares the same [`ConnTasks`] through an [`Arc`]; the methods on this type
/// queue work for the writer task or inspect the state of the background tasks.
#[derive(Debug, Clone)]
pub struct Conn {
    /// Shared state: the writer queue plus the handles of the reader and writer tasks.
    inner: Arc<ConnTasks>,
}
58
/// The receiving side of a connection, returned by [`ConnBuilder::build`] together with
/// the [`Conn`] handle.
#[derive(Debug)]
pub struct ConnReceiver {
    /// Receiving end of the queue that the reader task spawned in [`ConnBuilder::build`]
    /// pushes incoming messages into.
    reader_channel: mpsc::Receiver<Result<ReceivedMessage>>,
}
69
70impl ConnReceiver {
71 pub async fn recv(&mut self) -> Result<ReceivedMessage> {
75 let msg = self
76 .reader_channel
77 .recv()
78 .await
79 .ok_or(anyhow!("shut down"))??;
80 Ok(msg)
81 }
82}
83
/// State shared by every clone of a [`Conn`].
///
/// Both task handles are [`AbortOnDropHandle`]s, so the background tasks are aborted when
/// this value is dropped.
#[derive(derive_more::Debug)]
pub struct ConnTasks {
    /// Local socket address handed to [`ConnBuilder::new`], returned by [`Conn::local_addr`].
    local_addr: Option<SocketAddr>,
    /// Sending end of the queue drained by the writer task.
    writer_channel: mpsc::Sender<ConnWriterMessage>,
    /// Handle of the writer task; [`Conn::is_closed`] reports whether it has finished.
    writer_task: AbortOnDropHandle<Result<()>>,
    /// Handle of the reader task; aborted by [`Conn::close`].
    reader_task: AbortOnDropHandle<()>,
}
97
98impl Conn {
99 pub async fn send(&self, dstkey: PublicKey, packet: Bytes) -> Result<()> {
103 trace!(%dstkey, len = packet.len(), "[RELAY] send");
104
105 self.inner
106 .writer_channel
107 .send(ConnWriterMessage::Packet((dstkey, packet)))
108 .await?;
109 Ok(())
110 }
111
112 pub async fn send_ping(&self, data: [u8; 8]) -> Result<()> {
114 self.inner
115 .writer_channel
116 .send(ConnWriterMessage::Ping(data))
117 .await?;
118 Ok(())
119 }
120
121 pub async fn send_pong(&self, data: [u8; 8]) -> Result<()> {
124 self.inner
125 .writer_channel
126 .send(ConnWriterMessage::Pong(data))
127 .await?;
128 Ok(())
129 }
130
131 pub async fn note_preferred(&self, preferred: bool) -> Result<()> {
135 self.inner
136 .writer_channel
137 .send(ConnWriterMessage::NotePreferred(preferred))
138 .await?;
139 Ok(())
140 }
141
142 pub fn local_addr(&self) -> Option<SocketAddr> {
146 self.inner.local_addr
147 }
148
149 pub fn is_closed(&self) -> bool {
153 self.inner.writer_task.is_finished()
154 }
155
156 pub async fn close(&self) {
161 if self.inner.writer_task.is_finished() && self.inner.reader_task.is_finished() {
162 return;
163 }
164
165 self.inner
166 .writer_channel
167 .send(ConnWriterMessage::Shutdown)
168 .await
169 .ok();
170 self.inner.reader_task.abort();
171 }
172}
173
174fn process_incoming_frame(frame: Frame) -> Result<ReceivedMessage> {
175 match frame {
176 Frame::KeepAlive => {
177 Ok(ReceivedMessage::KeepAlive)
180 }
181 Frame::PeerGone { peer } => Ok(ReceivedMessage::PeerGone(peer)),
182 Frame::RecvPacket { src_key, content } => {
183 let packet = ReceivedMessage::ReceivedPacket {
184 source: src_key,
185 data: content,
186 };
187 Ok(packet)
188 }
189 Frame::Ping { data } => Ok(ReceivedMessage::Ping(data)),
190 Frame::Pong { data } => Ok(ReceivedMessage::Pong(data)),
191 Frame::Health { problem } => {
192 let problem = std::str::from_utf8(&problem)?.to_owned();
193 let problem = Some(problem);
194 Ok(ReceivedMessage::Health { problem })
195 }
196 Frame::Restarting {
197 reconnect_in,
198 try_for,
199 } => {
200 let reconnect_in = Duration::from_millis(reconnect_in as u64);
201 let try_for = Duration::from_millis(try_for as u64);
202 Ok(ReceivedMessage::ServerRestarting {
203 reconnect_in,
204 try_for,
205 })
206 }
207 _ => bail!("unexpected packet: {:?}", frame.typ()),
208 }
209}
210
/// Commands consumed by the writer task ([`ConnWriterTasks::run`]).
#[derive(Debug)]
enum ConnWriterMessage {
    /// Send a packet (second element) to the peer with the given public key (first element).
    Packet((PublicKey, Bytes)),
    /// Write a `Pong` frame carrying these 8 bytes.
    Pong([u8; 8]),
    /// Write a `Ping` frame carrying these 8 bytes.
    Ping([u8; 8]),
    /// Write a `NotePreferred` frame with this flag.
    NotePreferred(bool),
    /// Stop the writer task; it returns `Ok(())` when it sees this message.
    Shutdown,
}
225
/// Everything the writer task owns: its command queue, the transport writer and the
/// optional rate limiter applied to outgoing packets.
struct ConnWriterTasks {
    /// Receiving end of the command queue fed by [`Conn`] and by the reader task.
    recv_msgs: mpsc::Receiver<ConnWriterMessage>,
    /// Sink that frames are written to.
    writer: ConnWriter,
    /// When `Some`, packets that exceed the limit are dropped by [`send_packet`].
    rate_limiter: Option<RateLimiter>,
}
236
237impl ConnWriterTasks {
238 async fn run(mut self) -> Result<()> {
239 while let Some(msg) = self.recv_msgs.recv().await {
240 match msg {
241 ConnWriterMessage::Packet((key, bytes)) => {
242 send_packet(&mut self.writer, &self.rate_limiter, key, bytes).await?;
243 }
244 ConnWriterMessage::Pong(data) => {
245 write_frame(&mut self.writer, Frame::Pong { data }, None).await?;
246 self.writer.flush().await?;
247 }
248 ConnWriterMessage::Ping(data) => {
249 write_frame(&mut self.writer, Frame::Ping { data }, None).await?;
250 self.writer.flush().await?;
251 }
252 ConnWriterMessage::NotePreferred(preferred) => {
253 write_frame(&mut self.writer, Frame::NotePreferred { preferred }, None).await?;
254 self.writer.flush().await?;
255 }
256 ConnWriterMessage::Shutdown => {
257 return Ok(());
258 }
259 }
260 }
261
262 bail!("channel unexpectedly closed");
263 }
264}
265
/// Collects the pieces needed to set up a [`Conn`]; consumed by [`ConnBuilder::build`].
pub struct ConnBuilder {
    /// Secret key passed to `send_client_key` during the handshake.
    secret_key: SecretKey,
    /// Source of incoming frames; moved into the reader task.
    reader: ConnReader,
    /// Destination for outgoing frames; moved into the writer task.
    writer: ConnWriter,
    /// Local socket address exposed later through [`Conn::local_addr`].
    local_addr: Option<SocketAddr>,
}
274
/// The reading half of the transport, over one of the two supported framings.
pub(crate) enum ConnReader {
    /// Frames decoded with [`DerpCodec`] from a [`MaybeTlsStreamReader`].
    Derp(FramedRead<MaybeTlsStreamReader, DerpCodec>),
    /// The stream half of a split [`WebSocketStream`]; frames arrive as binary messages.
    Ws(SplitStream<WebSocketStream>),
}
279
/// The writing half of the transport, over one of the two supported framings.
pub(crate) enum ConnWriter {
    /// Frames encoded with [`DerpCodec`] into a [`MaybeTlsStreamWriter`].
    Derp(FramedWrite<MaybeTlsStreamWriter, DerpCodec>),
    /// The sink half of a split [`WebSocketStream`]; frames are sent as binary messages.
    Ws(SplitSink<WebSocketStream, tokio_tungstenite_wasm::Message>),
}
284
285fn tung_wasm_to_io_err(e: tokio_tungstenite_wasm::Error) -> std::io::Error {
286 match e {
287 tokio_tungstenite_wasm::Error::Io(io_err) => io_err,
288 _ => std::io::Error::new(std::io::ErrorKind::Other, e.to_string()),
289 }
290}
291
292impl Stream for ConnReader {
293 type Item = Result<Frame>;
294
295 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
296 match *self {
297 Self::Derp(ref mut ws) => Pin::new(ws).poll_next(cx),
298 Self::Ws(ref mut ws) => match Pin::new(ws).poll_next(cx) {
299 Poll::Ready(Some(Ok(tokio_tungstenite_wasm::Message::Binary(vec)))) => {
300 Poll::Ready(Some(Frame::decode_from_ws_msg(vec)))
301 }
302 Poll::Ready(Some(Ok(msg))) => {
303 tracing::warn!(?msg, "Got websocket message of unsupported type, skipping.");
304 Poll::Pending
305 }
306 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
307 Poll::Ready(None) => Poll::Ready(None),
308 Poll::Pending => Poll::Pending,
309 },
310 }
311 }
312}
313
314impl Sink<Frame> for ConnWriter {
315 type Error = std::io::Error;
316
317 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
318 match *self {
319 Self::Derp(ref mut ws) => Pin::new(ws).poll_ready(cx),
320 Self::Ws(ref mut ws) => Pin::new(ws).poll_ready(cx).map_err(tung_wasm_to_io_err),
321 }
322 }
323
324 fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> {
325 match *self {
326 Self::Derp(ref mut ws) => Pin::new(ws).start_send(item),
327 Self::Ws(ref mut ws) => Pin::new(ws)
328 .start_send(tokio_tungstenite_wasm::Message::binary(
329 item.encode_for_ws_msg(),
330 ))
331 .map_err(tung_wasm_to_io_err),
332 }
333 }
334
335 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
336 match *self {
337 Self::Derp(ref mut ws) => Pin::new(ws).poll_flush(cx),
338 Self::Ws(ref mut ws) => Pin::new(ws).poll_flush(cx).map_err(tung_wasm_to_io_err),
339 }
340 }
341
342 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
343 match *self {
344 Self::Derp(ref mut ws) => Pin::new(ws).poll_close(cx),
345 Self::Ws(ref mut ws) => Pin::new(ws).poll_close(cx).map_err(tung_wasm_to_io_err),
346 }
347 }
348}
349
350impl ConnBuilder {
351 pub fn new(
352 secret_key: SecretKey,
353 local_addr: Option<SocketAddr>,
354 reader: ConnReader,
355 writer: ConnWriter,
356 ) -> Self {
357 Self {
358 secret_key,
359 reader,
360 writer,
361 local_addr,
362 }
363 }
364
365 async fn server_handshake(&mut self) -> Result<Option<RateLimiter>> {
366 debug!("server_handshake: started");
367 let client_info = ClientInfo {
368 version: PROTOCOL_VERSION,
369 };
370 debug!("server_handshake: sending client_key: {:?}", &client_info);
371 crate::relay::codec::send_client_key(&mut self.writer, &self.secret_key, &client_info)
372 .await?;
373
374 let rate_limiter = RateLimiter::new(0, 0)?;
376
377 debug!("server_handshake: done");
378 Ok(rate_limiter)
379 }
380
381 pub async fn build(mut self) -> Result<(Conn, ConnReceiver)> {
382 let rate_limiter = self.server_handshake().await?;
384
385 let (writer_sender, writer_recv) = mpsc::channel(PER_CLIENT_SEND_QUEUE_DEPTH);
387 let writer_task = tokio::task::spawn(
388 ConnWriterTasks {
389 rate_limiter,
390 writer: self.writer,
391 recv_msgs: writer_recv,
392 }
393 .run()
394 .instrument(info_span!("conn.writer")),
395 );
396
397 let (reader_sender, reader_recv) = mpsc::channel(PER_CLIENT_READ_QUEUE_DEPTH);
398 let reader_task = tokio::task::spawn({
399 let writer_sender = writer_sender.clone();
400 async move {
401 loop {
402 let frame = tokio::time::timeout(CLIENT_RECV_TIMEOUT, self.reader.next()).await;
403 let res = match frame {
404 Ok(Some(Ok(frame))) => process_incoming_frame(frame),
405 Ok(Some(Err(err))) => {
406 Err(err)
408 }
409 Ok(None) => {
410 Err(anyhow::anyhow!("EOF: reader stream ended"))
412 }
413 Err(err) => {
414 Err(err.into())
416 }
417 };
418 if res.is_err() {
419 writer_sender.send(ConnWriterMessage::Shutdown).await.ok();
421 break;
422 }
423 if reader_sender.send(res).await.is_err() {
424 writer_sender.send(ConnWriterMessage::Shutdown).await.ok();
426 break;
427 }
428 }
429 }
430 .instrument(info_span!("conn.reader"))
431 });
432
433 let conn = Conn {
434 inner: Arc::new(ConnTasks {
435 local_addr: self.local_addr,
436 writer_channel: writer_sender,
437 writer_task: AbortOnDropHandle::new(writer_task),
438 reader_task: AbortOnDropHandle::new(reader_task),
439 }),
440 };
441
442 let conn_receiver = ConnReceiver {
443 reader_channel: reader_recv,
444 };
445
446 Ok((conn, conn_receiver))
447 }
448}
449
/// A message received from the server, as produced by `process_incoming_frame`.
#[derive(derive_more::Debug, Clone)]
pub enum ReceivedMessage {
    /// A packet received from another peer (built from a `RecvPacket` frame).
    ReceivedPacket {
        /// Public key the frame names as the sender.
        source: PublicKey,
        /// The packet payload; omitted from `Debug` output.
        #[debug(skip)]
        data: Bytes,
    },
    /// The server sent a `PeerGone` frame for the peer with this public key.
    PeerGone(PublicKey),
    /// A `Ping` frame carrying these 8 bytes.
    Ping([u8; 8]),
    /// A `Pong` frame carrying these 8 bytes.
    Pong([u8; 8]),
    /// A keep-alive frame.
    KeepAlive,
    /// A `Health` frame.
    Health {
        /// The frame's payload decoded as UTF-8. `process_incoming_frame` always sets this
        /// to `Some`, even for an empty payload.
        problem: Option<String>,
    },
    /// A `Restarting` frame announcing that the server is restarting.
    ServerRestarting {
        /// The frame's `reconnect_in` value, interpreted as milliseconds.
        reconnect_in: Duration,
        /// The frame's `try_for` value, interpreted as milliseconds.
        try_for: Duration,
    },
}
496
497pub(crate) async fn send_packet<S: Sink<Frame, Error = std::io::Error> + Unpin>(
498 mut writer: S,
499 rate_limiter: &Option<RateLimiter>,
500 dst_key: PublicKey,
501 packet: Bytes,
502) -> Result<()> {
503 ensure!(
504 packet.len() <= MAX_PACKET_SIZE,
505 "packet too big: {}",
506 packet.len()
507 );
508
509 let frame = Frame::SendPacket { dst_key, packet };
510 if let Some(rate_limiter) = rate_limiter {
511 if rate_limiter.check_n(frame.len()).is_err() {
512 tracing::warn!("dropping send: rate limit reached");
513 return Ok(());
514 }
515 }
516 writer.send(frame).await?;
517 writer.flush().await?;
518
519 Ok(())
520}
521
/// A byte-based rate limiter backed by a direct (not keyed) `governor` rate limiter.
///
/// Built through [`RateLimiter::new`], which takes its quota in bytes per second and burst bytes.
pub(crate) struct RateLimiter {
    /// The underlying limiter: not keyed, in-memory state, default clock, no middleware.
    inner: governor::RateLimiter<
        governor::state::direct::NotKeyed,
        governor::state::InMemoryState,
        governor::clock::DefaultClock,
        governor::middleware::NoOpMiddleware,
    >,
}
530
531impl RateLimiter {
532 pub(crate) fn new(bytes_per_second: usize, bytes_burst: usize) -> Result<Option<Self>> {
533 if bytes_per_second == 0 || bytes_burst == 0 {
534 return Ok(None);
535 }
536 let bytes_per_second = NonZeroU32::new(u32::try_from(bytes_per_second)?)
537 .context("bytes_per_second not non-zero")?;
538 let bytes_burst =
539 NonZeroU32::new(u32::try_from(bytes_burst)?).context("bytes_burst not non-zero")?;
540 Ok(Some(Self {
541 inner: governor::RateLimiter::direct(
542 governor::Quota::per_second(bytes_per_second).allow_burst(bytes_burst),
543 ),
544 }))
545 }
546
547 pub(crate) fn check_n(&self, n: usize) -> Result<()> {
548 let n = NonZeroU32::new(u32::try_from(n)?).context("n not non-zero")?;
549 match self.inner.check_n(n) {
550 Ok(_) => Ok(()),
551 Err(_) => bail!("batch cannot go through"),
552 }
553 }
554}