memberlist_quic/stream_layer/
quinn.rs

1use std::{cmp, io, marker::PhantomData, net::SocketAddr, sync::Arc, time::Duration};
2
3use agnostic::Runtime;
4use futures::AsyncWriteExt;
5use memberlist_core::proto::MediumVec;
6use quinn::{ClientConfig, Connection, Endpoint, RecvStream, SendStream, VarInt};
7use smol_str::SmolStr;
8
9mod options;
10pub use options::*;
11
12use super::{QuicAcceptor, QuicConnection, QuicConnector, QuicStream, StreamLayer};
13
14/// [`Quinn`] is an implementation of [`StreamLayer`] based on [`quinn`].
15pub struct Quinn<R> {
16  opts: QuinnOptions,
17  _m: PhantomData<R>,
18}
19
20impl<R> Quinn<R> {
21  /// Creates a new [`Quinn`] stream layer with the given options.
22  fn new_in(opts: Options) -> Self {
23    Self {
24      opts: opts.into(),
25      _m: PhantomData,
26    }
27  }
28}
29
30impl<R: Runtime> StreamLayer for Quinn<R> {
31  type Runtime = R;
32  type Acceptor = QuinnAcceptor;
33  type Connector = QuinnConnector<R>;
34  type Connection = QuinnConnection;
35  type Stream = QuinnStream;
36  type Options = Options;
37
38  fn max_stream_data(&self) -> usize {
39    self.opts.max_stream_data.min(self.opts.max_connection_data)
40  }
41
42  async fn new(opts: Self::Options) -> io::Result<Self> {
43    Ok(Self::new_in(opts))
44  }
45
46  async fn bind(
47    &self,
48    addr: SocketAddr,
49  ) -> io::Result<(SocketAddr, Self::Acceptor, Self::Connector)> {
50    let server_name = self.opts.server_name.clone();
51
52    let client_config = self.opts.client_config.clone();
53    let sock = std::net::UdpSocket::bind(addr)?;
54    let auto_port = addr.port() == 0;
55
56    let endpoint = Arc::new(Endpoint::new(
57      self.opts.endpoint_config.clone(),
58      Some(self.opts.server_config.clone()),
59      sock,
60      Arc::new(R::quinn()),
61    )?);
62
63    let local_addr = endpoint.local_addr()?;
64    if auto_port {
65      tracing::info!(
66        "memberlist_quic.endpoint: binding to dynamic addr {}",
67        local_addr
68      );
69    }
70
71    let acceptor = Self::Acceptor {
72      endpoint: endpoint.clone(),
73      local_addr,
74    };
75
76    let connector = Self::Connector {
77      server_name,
78      endpoint,
79      local_addr,
80      client_config,
81      connect_timeout: self.opts.connect_timeout,
82      _marker: PhantomData,
83    };
84    Ok((local_addr, acceptor, connector))
85  }
86}
87
88/// [`QuinnAcceptor`] is an implementation of [`QuicAcceptor`] based on [`quinn`].
89pub struct QuinnAcceptor {
90  endpoint: Arc<Endpoint>,
91  local_addr: SocketAddr,
92}
93
94impl Clone for QuinnAcceptor {
95  fn clone(&self) -> Self {
96    Self {
97      endpoint: self.endpoint.clone(),
98      local_addr: self.local_addr,
99    }
100  }
101}
102
103impl QuicAcceptor for QuinnAcceptor {
104  type Connection = QuinnConnection;
105
106  async fn accept(&mut self) -> io::Result<(Self::Connection, SocketAddr)> {
107    let conn = self
108      .endpoint
109      .accept()
110      .await
111      .ok_or(io::Error::other("endpoint closed"))?
112      .await?;
113    let remote_addr = conn.remote_address();
114
115    Ok((
116      QuinnConnection::new(conn, self.local_addr, remote_addr),
117      remote_addr,
118    ))
119  }
120
121  async fn close(&mut self) -> io::Result<()> {
122    Endpoint::close(&self.endpoint, VarInt::from(0u32), b"close acceptor");
123    Ok(())
124  }
125
126  fn local_addr(&self) -> SocketAddr {
127    self.local_addr
128  }
129}
130
131impl Drop for QuinnAcceptor {
132  fn drop(&mut self) {
133    Endpoint::close(&self.endpoint, VarInt::from(0u32), b"close acceptor");
134  }
135}
136
137/// [`QuinnConnector`] is an implementation of [`QuicConnector`] based on [`quinn`].
138pub struct QuinnConnector<R> {
139  server_name: SmolStr,
140  endpoint: Arc<Endpoint>,
141  client_config: ClientConfig,
142  connect_timeout: Duration,
143  local_addr: SocketAddr,
144  _marker: PhantomData<R>,
145}
146
147impl<R> QuicConnector for QuinnConnector<R>
148where
149  R: Runtime,
150{
151  type Connection = QuinnConnection;
152
153  async fn connect(&self, addr: SocketAddr) -> io::Result<Self::Connection> {
154    let connecting = self
155      .endpoint
156      .connect_with(self.client_config.clone(), addr, &self.server_name)
157      .map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?;
158    let conn = R::timeout(self.connect_timeout, connecting)
159      .await
160      .map_err(io::Error::from)??;
161    Ok(QuinnConnection::new(conn, self.local_addr, addr))
162  }
163
164  async fn close(&self) -> io::Result<()> {
165    Endpoint::close(&self.endpoint, VarInt::from(0u32), b"close connector");
166    Ok(())
167  }
168
169  async fn wait_idle(&self) -> io::Result<()> {
170    self.endpoint.wait_idle().await;
171    Ok(())
172  }
173
174  fn local_addr(&self) -> SocketAddr {
175    self.local_addr
176  }
177}
178
179impl<R> Drop for QuinnConnector<R> {
180  fn drop(&mut self) {
181    Endpoint::close(&self.endpoint, VarInt::from(0u32), b"close connector");
182  }
183}
184
185/// A [`ProtoReader`](memberlist_core::proto::ProtoReader) implementation for Quinn stream layer
186pub struct QuinnProtoReader {
187  stream: RecvStream,
188  peek_buf: MediumVec<u8>,
189}
190
191impl From<RecvStream> for QuinnProtoReader {
192  fn from(stream: RecvStream) -> Self {
193    Self {
194      stream,
195      peek_buf: MediumVec::new(),
196    }
197  }
198}
199
200impl memberlist_core::proto::ProtoReader for QuinnProtoReader {
201  async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
202    let dst_len = buf.len();
203    let peek_len = self.peek_buf.len();
204
205    match dst_len.cmp(&peek_len) {
206      cmp::Ordering::Less => {
207        buf.copy_from_slice(&self.peek_buf[..dst_len]);
208        Ok(dst_len)
209      }
210      cmp::Ordering::Equal => {
211        buf.copy_from_slice(&self.peek_buf);
212        Ok(peek_len)
213      }
214      cmp::Ordering::Greater => {
215        let want = dst_len - peek_len;
216        self.peek_buf.resize(dst_len, 0);
217        match self
218          .stream
219          .read(&mut self.peek_buf[peek_len..peek_len + want])
220          .await
221        {
222          Ok(Some(n)) => {
223            let has = peek_len + n;
224            if n < want {
225              self.peek_buf.truncate(has);
226            }
227            buf[..has].copy_from_slice(&self.peek_buf);
228            Ok(peek_len + n)
229          }
230          Ok(None) | Err(_) => {
231            self.peek_buf.truncate(peek_len);
232            buf[..peek_len].copy_from_slice(&self.peek_buf);
233            Ok(peek_len)
234          }
235        }
236      }
237    }
238  }
239
240  async fn peek_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
241    let dst_len = buf.len();
242    let peek_len = self.peek_buf.len();
243
244    match dst_len.cmp(&peek_len) {
245      cmp::Ordering::Less => {
246        buf.copy_from_slice(&self.peek_buf[..dst_len]);
247        Ok(())
248      }
249      cmp::Ordering::Equal => {
250        buf.copy_from_slice(&self.peek_buf);
251        Ok(())
252      }
253      cmp::Ordering::Greater => {
254        let want = dst_len - peek_len;
255        self.peek_buf.resize(dst_len, 0);
256        let readed = self
257          .stream
258          .read(&mut self.peek_buf[peek_len..peek_len + want])
259          .await?;
260
261        if let Some(n) = readed {
262          let has = peek_len + n;
263          if n < want {
264            return Err(std::io::Error::new(
265              std::io::ErrorKind::UnexpectedEof,
266              "unexpected eof",
267            ));
268          }
269          buf[..has].copy_from_slice(&self.peek_buf);
270          Ok(())
271        } else {
272          Err(std::io::Error::new(
273            std::io::ErrorKind::UnexpectedEof,
274            "unexpected eof",
275          ))
276        }
277      }
278    }
279  }
280
281  async fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
282    let dst_len = buf.len();
283    let peek_len = self.peek_buf.len();
284
285    if dst_len <= peek_len {
286      buf.copy_from_slice(&self.peek_buf[..dst_len]);
287      self.peek_buf.drain(..dst_len);
288      Ok(dst_len)
289    } else {
290      buf[..peek_len].copy_from_slice(&self.peek_buf);
291      self.peek_buf.clear();
292      self
293        .stream
294        .read(&mut buf[peek_len..])
295        .await
296        .map(|read| read.unwrap_or(0))
297        .map_err(Into::into)
298    }
299  }
300
301  async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
302    let dst_len = buf.len();
303    let peek_len = self.peek_buf.len();
304
305    if dst_len <= peek_len {
306      buf.copy_from_slice(&self.peek_buf[..dst_len]);
307      self.peek_buf.drain(..dst_len);
308      Ok(())
309    } else {
310      buf[..peek_len].copy_from_slice(&self.peek_buf);
311      self.peek_buf.clear();
312      self
313        .stream
314        .read(&mut buf[peek_len..])
315        .await
316        .map(|_| ())
317        .map_err(Into::into)
318    }
319  }
320}
321
322/// [`QuinnStream`] is an implementation of [`QuicStream`] based on [`quinn`].
323pub struct QuinnStream {
324  send: SendStream,
325  recv: QuinnProtoReader,
326}
327
328impl QuinnStream {
329  #[inline]
330  fn new(send: SendStream, recv: RecvStream) -> Self {
331    Self {
332      send,
333      recv: recv.into(),
334    }
335  }
336}
337
338impl memberlist_core::transport::Connection for QuinnStream {
339  type Reader = QuinnProtoReader;
340
341  type Writer = SendStream;
342
343  fn split(self) -> (Self::Reader, Self::Writer) {
344    (self.recv, self.send)
345  }
346
347  async fn close(&mut self) -> std::io::Result<()> {
348    self.send.close().await
349  }
350
351  async fn write_all(&mut self, payload: &[u8]) -> std::io::Result<()> {
352    self.send.write_all(payload).await.map_err(Into::into)
353  }
354
355  async fn flush(&mut self) -> std::io::Result<()> {
356    self.send.flush().await
357  }
358
359  async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
360    memberlist_core::proto::ProtoReader::peek(&mut self.recv, buf).await
361  }
362
363  async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
364    memberlist_core::proto::ProtoReader::read_exact(&mut self.recv, buf).await
365  }
366
367  async fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
368    memberlist_core::proto::ProtoReader::read(&mut self.recv, buf).await
369  }
370
371  async fn peek_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
372    memberlist_core::proto::ProtoReader::peek_exact(&mut self.recv, buf).await
373  }
374
375  fn consume_peek(&mut self) {
376    self.recv.peek_buf.clear();
377  }
378}
379
380impl QuicStream for QuinnStream {
381  type SendStream = SendStream;
382
383  async fn read_packet(&mut self) -> std::io::Result<bytes::Bytes> {
384    // TODO(al8n): make size limit configurable?
385    self
386      .recv
387      .stream
388      .read_to_end(u32::MAX as usize)
389      .await
390      .map(|data| {
391        if !self.recv.peek_buf.is_empty() {
392          let mut buf = bytes::BytesMut::with_capacity(self.recv.peek_buf.len() + data.len());
393          buf.extend_from_slice(&self.recv.peek_buf);
394          buf.extend_from_slice(&data);
395          buf.freeze()
396        } else {
397          data.into()
398        }
399      })
400      .map_err(|e| match e {
401        quinn::ReadToEndError::Read(e) => std::io::Error::from(e),
402        quinn::ReadToEndError::TooLong => {
403          std::io::Error::new(std::io::ErrorKind::InvalidData, "packet too large")
404        }
405      })
406  }
407}
408
409/// A connection based on [`quinn`].
410pub struct QuinnConnection {
411  conn: Connection,
412  local_addr: SocketAddr,
413  remote_addr: SocketAddr,
414}
415
416impl QuinnConnection {
417  #[inline]
418  fn new(conn: Connection, local_addr: SocketAddr, remote_addr: SocketAddr) -> Self {
419    Self {
420      conn,
421      local_addr,
422      remote_addr,
423    }
424  }
425}
426
427impl QuicConnection for QuinnConnection {
428  type Stream = QuinnStream;
429
430  async fn accept_bi(&self) -> io::Result<(Self::Stream, SocketAddr)> {
431    let (send, recv) = self.conn.accept_bi().await?;
432    Ok((QuinnStream::new(send, recv), self.remote_addr))
433  }
434
435  async fn open_bi(&self) -> io::Result<(Self::Stream, SocketAddr)> {
436    let (send, recv) = self.conn.open_bi().await?;
437    Ok((QuinnStream::new(send, recv), self.remote_addr))
438  }
439
440  async fn close(&self) -> io::Result<()> {
441    self.conn.close(0u32.into(), b"close connection");
442    Ok(())
443  }
444
445  async fn is_closed(&self) -> bool {
446    self.conn.close_reason().is_some()
447  }
448
449  fn local_addr(&self) -> SocketAddr {
450    self.local_addr
451  }
452}