1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
24
25use ant_libp2p_core as libp2p_core;
26
27use std::{
28 collections::VecDeque,
29 io,
30 io::{IoSlice, IoSliceMut},
31 iter,
32 pin::Pin,
33 task::{Context, Poll, Waker},
34};
35
36use either::Either;
37use futures::{prelude::*, ready};
38use libp2p_core::{
39 muxing::{StreamMuxer, StreamMuxerEvent},
40 upgrade::{InboundConnectionUpgrade, OutboundConnectionUpgrade, UpgradeInfo},
41};
42use thiserror::Error;
43
/// A stream muxer backed by either a yamux 0.12 or a yamux 0.13 connection.
#[derive(Debug)]
pub struct Muxer<C> {
    /// The underlying yamux connection; `Left` is yamux 0.12, `Right` is yamux 0.13.
    connection: Either<yamux012::Connection<C>, yamux013::Connection<C>>,
    /// Inbound streams that `StreamMuxer::poll` received from the connection and that
    /// have not yet been handed out by `StreamMuxer::poll_inbound`.
    inbound_stream_buffer: VecDeque<Stream>,
    /// Waker stored by `poll_inbound` when it returns `Poll::Pending`; it is woken by
    /// `poll` once a new stream has been pushed into `inbound_stream_buffer`.
    inbound_stream_waker: Option<Waker>,
}
63
/// Upper bound on the number of inbound streams kept in `Muxer::inbound_stream_buffer`.
///
/// Once the buffer holds this many streams, `StreamMuxer::poll` drops any further
/// inbound stream (with a warning) instead of buffering it.
const MAX_BUFFERED_INBOUND_STREAMS: usize = 256;
70
71impl<C> Muxer<C>
72where
73 C: AsyncRead + AsyncWrite + Send + Unpin + 'static,
74{
75 fn new(connection: Either<yamux012::Connection<C>, yamux013::Connection<C>>) -> Self {
77 Muxer {
78 connection,
79 inbound_stream_buffer: VecDeque::default(),
80 inbound_stream_waker: None,
81 }
82 }
83}
84
85impl<C> StreamMuxer for Muxer<C>
86where
87 C: AsyncRead + AsyncWrite + Unpin + 'static,
88{
89 type Substream = Stream;
90 type Error = Error;
91
92 #[tracing::instrument(level = "trace", name = "StreamMuxer::poll_inbound", skip(self, cx))]
93 fn poll_inbound(
94 mut self: Pin<&mut Self>,
95 cx: &mut Context<'_>,
96 ) -> Poll<Result<Self::Substream, Self::Error>> {
97 if let Some(stream) = self.inbound_stream_buffer.pop_front() {
98 return Poll::Ready(Ok(stream));
99 }
100
101 if let Poll::Ready(res) = self.poll_inner(cx) {
102 return Poll::Ready(res);
103 }
104
105 self.inbound_stream_waker = Some(cx.waker().clone());
106 Poll::Pending
107 }
108
109 #[tracing::instrument(level = "trace", name = "StreamMuxer::poll_outbound", skip(self, cx))]
110 fn poll_outbound(
111 mut self: Pin<&mut Self>,
112 cx: &mut Context<'_>,
113 ) -> Poll<Result<Self::Substream, Self::Error>> {
114 let stream = match self.connection.as_mut() {
115 Either::Left(c) => ready!(c.poll_new_outbound(cx))
116 .map_err(|e| Error(Either::Left(e)))
117 .map(|s| Stream(Either::Left(s))),
118 Either::Right(c) => ready!(c.poll_new_outbound(cx))
119 .map_err(|e| Error(Either::Right(e)))
120 .map(|s| Stream(Either::Right(s))),
121 }?;
122 Poll::Ready(Ok(stream))
123 }
124
125 #[tracing::instrument(level = "trace", name = "StreamMuxer::poll_close", skip(self, cx))]
126 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
127 match self.connection.as_mut() {
128 Either::Left(c) => c.poll_close(cx).map_err(|e| Error(Either::Left(e))),
129 Either::Right(c) => c.poll_close(cx).map_err(|e| Error(Either::Right(e))),
130 }
131 }
132
133 #[tracing::instrument(level = "trace", name = "StreamMuxer::poll", skip(self, cx))]
134 fn poll(
135 self: Pin<&mut Self>,
136 cx: &mut Context<'_>,
137 ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
138 let this = self.get_mut();
139
140 let inbound_stream = ready!(this.poll_inner(cx))?;
141
142 if this.inbound_stream_buffer.len() >= MAX_BUFFERED_INBOUND_STREAMS {
143 tracing::warn!(
144 stream=%inbound_stream.0,
145 "dropping stream because buffer is full"
146 );
147 drop(inbound_stream);
148 } else {
149 this.inbound_stream_buffer.push_back(inbound_stream);
150
151 if let Some(waker) = this.inbound_stream_waker.take() {
152 waker.wake()
153 }
154 }
155
156 cx.waker().wake_by_ref();
158 Poll::Pending
159 }
160}
161
/// A substream of a [`Muxer`], wrapping either a yamux 0.12 or a yamux 0.13 stream.
#[derive(Debug)]
pub struct Stream(Either<yamux012::Stream, yamux013::Stream>);
165
166impl AsyncRead for Stream {
167 fn poll_read(
168 mut self: Pin<&mut Self>,
169 cx: &mut Context<'_>,
170 buf: &mut [u8],
171 ) -> Poll<io::Result<usize>> {
172 either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_read(cx, buf))
173 }
174
175 fn poll_read_vectored(
176 mut self: Pin<&mut Self>,
177 cx: &mut Context<'_>,
178 bufs: &mut [IoSliceMut<'_>],
179 ) -> Poll<io::Result<usize>> {
180 either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_read_vectored(cx, bufs))
181 }
182}
183
184impl AsyncWrite for Stream {
185 fn poll_write(
186 mut self: Pin<&mut Self>,
187 cx: &mut Context<'_>,
188 buf: &[u8],
189 ) -> Poll<io::Result<usize>> {
190 either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_write(cx, buf))
191 }
192
193 fn poll_write_vectored(
194 mut self: Pin<&mut Self>,
195 cx: &mut Context<'_>,
196 bufs: &[IoSlice<'_>],
197 ) -> Poll<io::Result<usize>> {
198 either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_write_vectored(cx, bufs))
199 }
200
201 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
202 either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_flush(cx))
203 }
204
205 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
206 either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_close(cx))
207 }
208}
209
210impl<C> Muxer<C>
211where
212 C: AsyncRead + AsyncWrite + Unpin + 'static,
213{
214 fn poll_inner(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream, Error>> {
215 let stream = match self.connection.as_mut() {
216 Either::Left(c) => ready!(c.poll_next_inbound(cx))
217 .ok_or(Error(Either::Left(yamux012::ConnectionError::Closed)))?
218 .map_err(|e| Error(Either::Left(e)))
219 .map(|s| Stream(Either::Left(s)))?,
220 Either::Right(c) => ready!(c.poll_next_inbound(cx))
221 .ok_or(Error(Either::Right(yamux013::ConnectionError::Closed)))?
222 .map_err(|e| Error(Either::Right(e)))
223 .map(|s| Stream(Either::Right(s)))?,
224 };
225
226 Poll::Ready(Ok(stream))
227 }
228}
229
/// The yamux configuration, holding either a yamux 0.12 (`Left`) or a yamux 0.13
/// (`Right`) configuration.
#[derive(Debug, Clone)]
pub struct Config(Either<Config012, Config013>);
233
234impl Default for Config {
235 fn default() -> Self {
236 Self(Either::Right(Config013::default()))
237 }
238}
239
/// Configuration used when the muxer runs on yamux 0.12.
#[derive(Debug, Clone)]
struct Config012 {
    /// The yamux 0.12 configuration handed to `yamux012::Connection::new`.
    inner: yamux012::Config,
    /// Explicit connection mode. When `None`, the mode is derived from the upgrade
    /// direction (server for inbound, client for outbound).
    mode: Option<yamux012::Mode>,
}
245
246impl Default for Config012 {
247 fn default() -> Self {
248 let mut inner = yamux012::Config::default();
249 inner.set_read_after_close(false);
252 Self { inner, mode: None }
253 }
254}
255
256pub struct WindowUpdateMode(yamux012::WindowUpdateMode);
259
260impl WindowUpdateMode {
261 #[deprecated(note = "Use `WindowUpdateMode::on_read` instead.")]
274 pub fn on_receive() -> Self {
275 #[allow(deprecated)]
276 WindowUpdateMode(yamux012::WindowUpdateMode::OnReceive)
277 }
278
279 pub fn on_read() -> Self {
294 WindowUpdateMode(yamux012::WindowUpdateMode::OnRead)
295 }
296}
297
298impl Config {
299 #[deprecated(note = "Will be removed with the next breaking release.")]
302 pub fn client() -> Self {
303 Self(Either::Left(Config012 {
304 mode: Some(yamux012::Mode::Client),
305 ..Default::default()
306 }))
307 }
308
309 #[deprecated(note = "Will be removed with the next breaking release.")]
312 pub fn server() -> Self {
313 Self(Either::Left(Config012 {
314 mode: Some(yamux012::Mode::Server),
315 ..Default::default()
316 }))
317 }
318
319 #[deprecated(
321 note = "Will be replaced in the next breaking release with a connection receive window size limit."
322 )]
323 pub fn set_receive_window_size(&mut self, num_bytes: u32) -> &mut Self {
324 self.set(|cfg| cfg.set_receive_window(num_bytes))
325 }
326
327 #[deprecated(note = "Will be removed with the next breaking release.")]
329 pub fn set_max_buffer_size(&mut self, num_bytes: usize) -> &mut Self {
330 self.set(|cfg| cfg.set_max_buffer_size(num_bytes))
331 }
332
333 pub fn set_max_num_streams(&mut self, num_streams: usize) -> &mut Self {
335 self.set(|cfg| cfg.set_max_num_streams(num_streams))
336 }
337
338 #[deprecated(
341 note = "`WindowUpdate::OnRead` is the default. `WindowUpdate::OnReceive` breaks backpressure, is thus not recommended, and will be removed in the next breaking release. Thus this method becomes obsolete and will be removed with the next breaking release."
342 )]
343 pub fn set_window_update_mode(&mut self, mode: WindowUpdateMode) -> &mut Self {
344 self.set(|cfg| cfg.set_window_update_mode(mode.0))
345 }
346
347 fn set(&mut self, f: impl FnOnce(&mut yamux012::Config) -> &mut yamux012::Config) -> &mut Self {
348 let cfg012 = match self.0.as_mut() {
349 Either::Left(c) => &mut c.inner,
350 Either::Right(_) => {
351 self.0 = Either::Left(Config012::default());
352 &mut self.0.as_mut().unwrap_left().inner
353 }
354 };
355
356 f(cfg012);
357
358 self
359 }
360}
361
362impl UpgradeInfo for Config {
363 type Info = &'static str;
364 type InfoIter = iter::Once<Self::Info>;
365
366 fn protocol_info(&self) -> Self::InfoIter {
367 iter::once("/yamux/1.0.0")
368 }
369}
370
371impl<C> InboundConnectionUpgrade<C> for Config
372where
373 C: AsyncRead + AsyncWrite + Send + Unpin + 'static,
374{
375 type Output = Muxer<C>;
376 type Error = io::Error;
377 type Future = future::Ready<Result<Self::Output, Self::Error>>;
378
379 fn upgrade_inbound(self, io: C, _: Self::Info) -> Self::Future {
380 let connection = match self.0 {
381 Either::Left(Config012 { inner, mode }) => Either::Left(yamux012::Connection::new(
382 io,
383 inner,
384 mode.unwrap_or(yamux012::Mode::Server),
385 )),
386 Either::Right(Config013(cfg)) => {
387 Either::Right(yamux013::Connection::new(io, cfg, yamux013::Mode::Server))
388 }
389 };
390
391 future::ready(Ok(Muxer::new(connection)))
392 }
393}
394
395impl<C> OutboundConnectionUpgrade<C> for Config
396where
397 C: AsyncRead + AsyncWrite + Send + Unpin + 'static,
398{
399 type Output = Muxer<C>;
400 type Error = io::Error;
401 type Future = future::Ready<Result<Self::Output, Self::Error>>;
402
403 fn upgrade_outbound(self, io: C, _: Self::Info) -> Self::Future {
404 let connection = match self.0 {
405 Either::Left(Config012 { inner, mode }) => Either::Left(yamux012::Connection::new(
406 io,
407 inner,
408 mode.unwrap_or(yamux012::Mode::Client),
409 )),
410 Either::Right(Config013(cfg)) => {
411 Either::Right(yamux013::Connection::new(io, cfg, yamux013::Mode::Client))
412 }
413 };
414
415 future::ready(Ok(Muxer::new(connection)))
416 }
417}
418
/// Configuration used when the muxer runs on yamux 0.13; wraps the yamux 0.13 config
/// handed to `yamux013::Connection::new`.
#[derive(Debug, Clone)]
struct Config013(yamux013::Config);
421
422impl Default for Config013 {
423 fn default() -> Self {
424 let mut cfg = yamux013::Config::default();
425 cfg.set_read_after_close(false);
428 Self(cfg)
429 }
430}
431
/// The yamux error type, wrapping the connection error of either yamux 0.12 (`Left`)
/// or yamux 0.13 (`Right`).
///
/// `#[error(transparent)]` delegates the error's display and source to the wrapped value.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct Error(Either<yamux012::ConnectionError, yamux013::ConnectionError>);
436
437impl From<Error> for io::Error {
438 fn from(err: Error) -> Self {
439 match err.0 {
440 Either::Left(err) => match err {
441 yamux012::ConnectionError::Io(e) => e,
442 e => io::Error::new(io::ErrorKind::Other, e),
443 },
444 Either::Right(err) => match err {
445 yamux013::ConnectionError::Io(e) => e,
446 e => io::Error::new(io::ErrorKind::Other, e),
447 },
448 }
449 }
450}
451
#[cfg(test)]
mod test {
    use super::*;

    /// Calling a setter on a default configuration must move it from yamux 0.13 to
    /// yamux 0.12.
    #[test]
    fn config_set_switches_to_v012() {
        let mut cfg = Config::default();

        // A fresh configuration targets yamux 0.13.
        assert!(cfg.0.is_right());

        cfg.set_max_num_streams(42);

        // Any setter forces the switch to the yamux 0.12 configuration.
        assert!(cfg.0.is_left());
    }
}