memberlist_quic/stream_layer/
quinn.rs1use std::{cmp, io, marker::PhantomData, net::SocketAddr, sync::Arc, time::Duration};
2
3use agnostic::Runtime;
4use futures::AsyncWriteExt;
5use memberlist_core::proto::MediumVec;
6use quinn::{ClientConfig, Connection, Endpoint, RecvStream, SendStream, VarInt};
7use smol_str::SmolStr;
8
9mod options;
10pub use options::*;
11
12use super::{QuicAcceptor, QuicConnection, QuicConnector, QuicStream, StreamLayer};
13
/// QUIC stream layer implemented on top of `quinn`.
///
/// `R` is the async runtime; it is exposed as `StreamLayer::Runtime` and is
/// used to obtain the quinn runtime adapter and connect timeouts.
pub struct Quinn<R> {
  // Resolved quinn configuration (endpoint/server/client configs, limits, timeouts).
  opts: QuinnOptions,
  // Ties the runtime type parameter to the struct without storing a value.
  _m: PhantomData<R>,
}
19
20impl<R> Quinn<R> {
21 fn new_in(opts: Options) -> Self {
23 Self {
24 opts: opts.into(),
25 _m: PhantomData,
26 }
27 }
28}
29
30impl<R: Runtime> StreamLayer for Quinn<R> {
31 type Runtime = R;
32 type Acceptor = QuinnAcceptor;
33 type Connector = QuinnConnector<R>;
34 type Connection = QuinnConnection;
35 type Stream = QuinnStream;
36 type Options = Options;
37
38 fn max_stream_data(&self) -> usize {
39 self.opts.max_stream_data.min(self.opts.max_connection_data)
40 }
41
42 async fn new(opts: Self::Options) -> io::Result<Self> {
43 Ok(Self::new_in(opts))
44 }
45
46 async fn bind(
47 &self,
48 addr: SocketAddr,
49 ) -> io::Result<(SocketAddr, Self::Acceptor, Self::Connector)> {
50 let server_name = self.opts.server_name.clone();
51
52 let client_config = self.opts.client_config.clone();
53 let sock = std::net::UdpSocket::bind(addr)?;
54 let auto_port = addr.port() == 0;
55
56 let endpoint = Arc::new(Endpoint::new(
57 self.opts.endpoint_config.clone(),
58 Some(self.opts.server_config.clone()),
59 sock,
60 Arc::new(R::quinn()),
61 )?);
62
63 let local_addr = endpoint.local_addr()?;
64 if auto_port {
65 tracing::info!(
66 "memberlist_quic.endpoint: binding to dynamic addr {}",
67 local_addr
68 );
69 }
70
71 let acceptor = Self::Acceptor {
72 endpoint: endpoint.clone(),
73 local_addr,
74 };
75
76 let connector = Self::Connector {
77 server_name,
78 endpoint,
79 local_addr,
80 client_config,
81 connect_timeout: self.opts.connect_timeout,
82 _marker: PhantomData,
83 };
84 Ok((local_addr, acceptor, connector))
85 }
86}
87
/// Accepts incoming QUIC connections on an endpoint created by `Quinn::bind`.
pub struct QuinnAcceptor {
  // Endpoint shared with the `QuinnConnector` produced by the same `bind` call.
  endpoint: Arc<Endpoint>,
  // Local address the endpoint is bound to.
  local_addr: SocketAddr,
}
93
94impl Clone for QuinnAcceptor {
95 fn clone(&self) -> Self {
96 Self {
97 endpoint: self.endpoint.clone(),
98 local_addr: self.local_addr,
99 }
100 }
101}
102
103impl QuicAcceptor for QuinnAcceptor {
104 type Connection = QuinnConnection;
105
106 async fn accept(&mut self) -> io::Result<(Self::Connection, SocketAddr)> {
107 let conn = self
108 .endpoint
109 .accept()
110 .await
111 .ok_or(io::Error::other("endpoint closed"))?
112 .await?;
113 let remote_addr = conn.remote_address();
114
115 Ok((
116 QuinnConnection::new(conn, self.local_addr, remote_addr),
117 remote_addr,
118 ))
119 }
120
121 async fn close(&mut self) -> io::Result<()> {
122 Endpoint::close(&self.endpoint, VarInt::from(0u32), b"close acceptor");
123 Ok(())
124 }
125
126 fn local_addr(&self) -> SocketAddr {
127 self.local_addr
128 }
129}
130
131impl Drop for QuinnAcceptor {
132 fn drop(&mut self) {
133 Endpoint::close(&self.endpoint, VarInt::from(0u32), b"close acceptor");
134 }
135}
136
/// Opens outgoing QUIC connections from an endpoint created by `Quinn::bind`.
pub struct QuinnConnector<R> {
  // Server name passed to `Endpoint::connect_with` for every outgoing connection.
  server_name: SmolStr,
  // Endpoint shared with the `QuinnAcceptor` produced by the same `bind` call.
  endpoint: Arc<Endpoint>,
  // Client configuration cloned into each `connect_with` call.
  client_config: ClientConfig,
  // Upper bound on the handshake duration, enforced via `R::timeout`.
  connect_timeout: Duration,
  // Local address the endpoint is bound to.
  local_addr: SocketAddr,
  // Ties the runtime type parameter (used for `R::timeout`) to the struct.
  _marker: PhantomData<R>,
}
146
147impl<R> QuicConnector for QuinnConnector<R>
148where
149 R: Runtime,
150{
151 type Connection = QuinnConnection;
152
153 async fn connect(&self, addr: SocketAddr) -> io::Result<Self::Connection> {
154 let connecting = self
155 .endpoint
156 .connect_with(self.client_config.clone(), addr, &self.server_name)
157 .map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?;
158 let conn = R::timeout(self.connect_timeout, connecting)
159 .await
160 .map_err(io::Error::from)??;
161 Ok(QuinnConnection::new(conn, self.local_addr, addr))
162 }
163
164 async fn close(&self) -> io::Result<()> {
165 Endpoint::close(&self.endpoint, VarInt::from(0u32), b"close connector");
166 Ok(())
167 }
168
169 async fn wait_idle(&self) -> io::Result<()> {
170 self.endpoint.wait_idle().await;
171 Ok(())
172 }
173
174 fn local_addr(&self) -> SocketAddr {
175 self.local_addr
176 }
177}
178
179impl<R> Drop for QuinnConnector<R> {
180 fn drop(&mut self) {
181 Endpoint::close(&self.endpoint, VarInt::from(0u32), b"close connector");
182 }
183}
184
/// Receive half of a quinn stream, with support for peeking.
pub struct QuinnProtoReader {
  // Underlying quinn receive stream.
  stream: RecvStream,
  // Bytes pulled from `stream` by `peek`/`peek_exact` that have not yet been
  // consumed by `read`/`read_exact`; they are served before the stream is read again.
  peek_buf: MediumVec<u8>,
}
190
191impl From<RecvStream> for QuinnProtoReader {
192 fn from(stream: RecvStream) -> Self {
193 Self {
194 stream,
195 peek_buf: MediumVec::new(),
196 }
197 }
198}
199
200impl memberlist_core::proto::ProtoReader for QuinnProtoReader {
201 async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
202 let dst_len = buf.len();
203 let peek_len = self.peek_buf.len();
204
205 match dst_len.cmp(&peek_len) {
206 cmp::Ordering::Less => {
207 buf.copy_from_slice(&self.peek_buf[..dst_len]);
208 Ok(dst_len)
209 }
210 cmp::Ordering::Equal => {
211 buf.copy_from_slice(&self.peek_buf);
212 Ok(peek_len)
213 }
214 cmp::Ordering::Greater => {
215 let want = dst_len - peek_len;
216 self.peek_buf.resize(dst_len, 0);
217 match self
218 .stream
219 .read(&mut self.peek_buf[peek_len..peek_len + want])
220 .await
221 {
222 Ok(Some(n)) => {
223 let has = peek_len + n;
224 if n < want {
225 self.peek_buf.truncate(has);
226 }
227 buf[..has].copy_from_slice(&self.peek_buf);
228 Ok(peek_len + n)
229 }
230 Ok(None) | Err(_) => {
231 self.peek_buf.truncate(peek_len);
232 buf[..peek_len].copy_from_slice(&self.peek_buf);
233 Ok(peek_len)
234 }
235 }
236 }
237 }
238 }
239
240 async fn peek_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
241 let dst_len = buf.len();
242 let peek_len = self.peek_buf.len();
243
244 match dst_len.cmp(&peek_len) {
245 cmp::Ordering::Less => {
246 buf.copy_from_slice(&self.peek_buf[..dst_len]);
247 Ok(())
248 }
249 cmp::Ordering::Equal => {
250 buf.copy_from_slice(&self.peek_buf);
251 Ok(())
252 }
253 cmp::Ordering::Greater => {
254 let want = dst_len - peek_len;
255 self.peek_buf.resize(dst_len, 0);
256 let readed = self
257 .stream
258 .read(&mut self.peek_buf[peek_len..peek_len + want])
259 .await?;
260
261 if let Some(n) = readed {
262 let has = peek_len + n;
263 if n < want {
264 return Err(std::io::Error::new(
265 std::io::ErrorKind::UnexpectedEof,
266 "unexpected eof",
267 ));
268 }
269 buf[..has].copy_from_slice(&self.peek_buf);
270 Ok(())
271 } else {
272 Err(std::io::Error::new(
273 std::io::ErrorKind::UnexpectedEof,
274 "unexpected eof",
275 ))
276 }
277 }
278 }
279 }
280
281 async fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
282 let dst_len = buf.len();
283 let peek_len = self.peek_buf.len();
284
285 if dst_len <= peek_len {
286 buf.copy_from_slice(&self.peek_buf[..dst_len]);
287 self.peek_buf.drain(..dst_len);
288 Ok(dst_len)
289 } else {
290 buf[..peek_len].copy_from_slice(&self.peek_buf);
291 self.peek_buf.clear();
292 self
293 .stream
294 .read(&mut buf[peek_len..])
295 .await
296 .map(|read| read.unwrap_or(0))
297 .map_err(Into::into)
298 }
299 }
300
301 async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
302 let dst_len = buf.len();
303 let peek_len = self.peek_buf.len();
304
305 if dst_len <= peek_len {
306 buf.copy_from_slice(&self.peek_buf[..dst_len]);
307 self.peek_buf.drain(..dst_len);
308 Ok(())
309 } else {
310 buf[..peek_len].copy_from_slice(&self.peek_buf);
311 self.peek_buf.clear();
312 self
313 .stream
314 .read(&mut buf[peek_len..])
315 .await
316 .map(|_| ())
317 .map_err(Into::into)
318 }
319 }
320}
321
/// A bidirectional quinn stream: a send half plus a peekable receive half.
pub struct QuinnStream {
  // Send half of the stream.
  send: SendStream,
  // Receive half, wrapped to support peeking.
  recv: QuinnProtoReader,
}
327
328impl QuinnStream {
329 #[inline]
330 fn new(send: SendStream, recv: RecvStream) -> Self {
331 Self {
332 send,
333 recv: recv.into(),
334 }
335 }
336}
337
338impl memberlist_core::transport::Connection for QuinnStream {
339 type Reader = QuinnProtoReader;
340
341 type Writer = SendStream;
342
343 fn split(self) -> (Self::Reader, Self::Writer) {
344 (self.recv, self.send)
345 }
346
347 async fn close(&mut self) -> std::io::Result<()> {
348 self.send.close().await
349 }
350
351 async fn write_all(&mut self, payload: &[u8]) -> std::io::Result<()> {
352 self.send.write_all(payload).await.map_err(Into::into)
353 }
354
355 async fn flush(&mut self) -> std::io::Result<()> {
356 self.send.flush().await
357 }
358
359 async fn peek(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
360 memberlist_core::proto::ProtoReader::peek(&mut self.recv, buf).await
361 }
362
363 async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
364 memberlist_core::proto::ProtoReader::read_exact(&mut self.recv, buf).await
365 }
366
367 async fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
368 memberlist_core::proto::ProtoReader::read(&mut self.recv, buf).await
369 }
370
371 async fn peek_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
372 memberlist_core::proto::ProtoReader::peek_exact(&mut self.recv, buf).await
373 }
374
375 fn consume_peek(&mut self) {
376 self.recv.peek_buf.clear();
377 }
378}
379
380impl QuicStream for QuinnStream {
381 type SendStream = SendStream;
382
383 async fn read_packet(&mut self) -> std::io::Result<bytes::Bytes> {
384 self
386 .recv
387 .stream
388 .read_to_end(u32::MAX as usize)
389 .await
390 .map(|data| {
391 if !self.recv.peek_buf.is_empty() {
392 let mut buf = bytes::BytesMut::with_capacity(self.recv.peek_buf.len() + data.len());
393 buf.extend_from_slice(&self.recv.peek_buf);
394 buf.extend_from_slice(&data);
395 buf.freeze()
396 } else {
397 data.into()
398 }
399 })
400 .map_err(|e| match e {
401 quinn::ReadToEndError::Read(e) => std::io::Error::from(e),
402 quinn::ReadToEndError::TooLong => {
403 std::io::Error::new(std::io::ErrorKind::InvalidData, "packet too large")
404 }
405 })
406 }
407}
408
/// An established quinn connection, plus the addresses of both ends.
pub struct QuinnConnection {
  // Underlying quinn connection handle.
  conn: Connection,
  // Address of the local endpoint this connection belongs to.
  local_addr: SocketAddr,
  // Peer address, returned alongside every stream opened or accepted.
  remote_addr: SocketAddr,
}
415
416impl QuinnConnection {
417 #[inline]
418 fn new(conn: Connection, local_addr: SocketAddr, remote_addr: SocketAddr) -> Self {
419 Self {
420 conn,
421 local_addr,
422 remote_addr,
423 }
424 }
425}
426
427impl QuicConnection for QuinnConnection {
428 type Stream = QuinnStream;
429
430 async fn accept_bi(&self) -> io::Result<(Self::Stream, SocketAddr)> {
431 let (send, recv) = self.conn.accept_bi().await?;
432 Ok((QuinnStream::new(send, recv), self.remote_addr))
433 }
434
435 async fn open_bi(&self) -> io::Result<(Self::Stream, SocketAddr)> {
436 let (send, recv) = self.conn.open_bi().await?;
437 Ok((QuinnStream::new(send, recv), self.remote_addr))
438 }
439
440 async fn close(&self) -> io::Result<()> {
441 self.conn.close(0u32.into(), b"close connection");
442 Ok(())
443 }
444
445 async fn is_closed(&self) -> bool {
446 self.conn.close_reason().is_some()
447 }
448
449 fn local_addr(&self) -> SocketAddr {
450 self.local_addr
451 }
452}