1use std::future::poll_fn;
4use std::marker::PhantomData;
5#[cfg(not(feature = "feat-rate-limit"))]
6use std::marker::PhantomPinned;
7use std::os::fd::AsFd;
8use std::pin::{pin, Pin};
9use std::task::{ready, Context, Poll};
10use std::{io, ops};
11
12use crossbeam_utils::CachePadded;
13use tokio::fs::File;
14use tokio::io::{AsyncRead, AsyncWrite, Interest};
15use tokio::net::{TcpStream, UnixStream};
16#[cfg(feature = "feat-rate-limit")]
17use tokio::time::Sleep;
18
19use crate::context::SpliceIoCtx;
20use crate::rate::RATE_LIMITER_DISABLED;
21#[cfg(feature = "feat-rate-limit")]
22use crate::rate::{RateLimit, RateLimitResult, RateLimiter, RATE_LIMITER_ENABLED};
23use crate::traffic::TrafficResult;
24use crate::utils::Drained;
25
26#[pin_project::pin_project]
27#[derive(Debug)]
28pub struct SpliceIo<R, W, const RATE_LIMITER_IS_ENABLED: bool = RATE_LIMITER_DISABLED> {
34 ctx: CachePadded<SpliceIoCtx<R, W>>,
38
39 r: PhantomData<R>,
40 w: PhantomData<W>,
41
42 #[cfg(feature = "feat-rate-limit")]
43 rate_limiter: RateLimiter<RATE_LIMITER_IS_ENABLED>,
45
46 #[pin]
47 state: TransferState,
48}
49
50impl<R, W, const RATE_LIMITER_IS_ENABLED: bool> ops::Deref
51 for SpliceIo<R, W, RATE_LIMITER_IS_ENABLED>
52{
53 type Target = SpliceIoCtx<R, W>;
54
55 fn deref(&self) -> &Self::Target {
56 &self.ctx
57 }
58}
59
60impl<R, W> SpliceIo<R, W, RATE_LIMITER_DISABLED> {
61 #[must_use]
62 pub fn new(ctx: SpliceIoCtx<R, W>) -> Self {
64 SpliceIo {
65 ctx: CachePadded::new(ctx),
66 r: PhantomData,
67 w: PhantomData,
68 #[cfg(feature = "feat-rate-limit")]
69 rate_limiter: RateLimiter::empty(),
70 state: TransferState::Draining,
71 }
72 }
73
74 #[cfg(feature = "feat-rate-limit")]
75 pub fn with_rate_limit(self, limit: RateLimit) -> SpliceIo<R, W, RATE_LIMITER_ENABLED> {
79 SpliceIo {
80 ctx: self.ctx,
81 r: self.r,
82 w: self.w,
83 rate_limiter: RateLimiter::new(limit),
84 state: self.state,
85 }
86 }
87}
88
89#[derive(Debug)]
90#[pin_project::pin_project(project = TransferStateProj)]
91enum TransferState {
92 Draining,
94
95 #[cfg_attr(not(feature = "feat-rate-limit"), allow(dead_code))]
96 Throttled {
98 #[cfg(feature = "feat-rate-limit")]
99 #[pin]
100 sleep: Sleep,
101
102 #[cfg(not(feature = "feat-rate-limit"))]
103 #[doc(hidden)]
104 #[pin]
105 _pinned: PhantomPinned,
107 },
108
109 Pumping,
111
112 Flushing,
114
115 Terminating,
117
118 Faulted { error: Option<io::Error> },
120
121 Finished,
123}
124
125impl<R, W, const RATE_LIMITER_IS_ENABLED: bool> SpliceIo<R, W, RATE_LIMITER_IS_ENABLED>
126where
127 R: AsyncReadFd,
128 W: AsyncWriteFd,
129{
130 pub async fn execute(self, r: &mut R, w: &mut W) -> TrafficResult
136 where
137 R: Unpin,
138 W: Unpin,
139 {
140 let mut this = pin!(self);
141 let mut r = Pin::new(r);
142 let mut w = Pin::new(w);
143
144 let error = poll_fn(|cx| this.as_mut().poll_execute(cx, r.as_mut(), w.as_mut()))
145 .await
146 .err();
147
148 this.ctx.traffic_client_tx(error)
149 }
150
151 #[cfg_attr(
152 any(
153 feature = "feat-tracing-trace",
154 all(debug_assertions, feature = "feat-tracing")
155 ),
156 tracing::instrument(level = "TRACE", skip(self, cx, r, w), ret)
157 )]
158 #[allow(clippy::too_many_lines)]
159 pub fn poll_execute(
172 mut self: Pin<&mut Self>,
173 cx: &mut Context<'_>,
174 mut r: Pin<&mut R>,
175 mut w: Pin<&mut W>,
176 ) -> Poll<io::Result<()>> {
177 macro_rules! ready_or_cleanup {
178 ($e:expr, $state:expr) => {
179 match $e {
180 Poll::Ready(Ok(t)) => t,
181 Poll::Ready(Err(e)) => {
182 $state.set(TransferState::Faulted { error: Some(e) });
183 continue;
184 }
185 Poll::Pending => {
186 break Poll::Pending;
187 }
188 }
189 };
190 }
191
192 loop {
193 crate::enter_tracing_span!(
194 "loop",
195 ctx = ?self.ctx,
196 state = ?self.state,
197 );
198
199 let mut this = self.as_mut().project();
200
201 match this.state.as_mut().project() {
202 TransferStateProj::Draining => {
203 #[cfg(feature = "feat-rate-limit")]
204 let ideal_len = this.rate_limiter.ideal_len(this.ctx.pipe_size());
205
206 #[cfg(not(feature = "feat-rate-limit"))]
207 let ideal_len = None;
208
209 match ready_or_cleanup!(
210 this.ctx.poll_splice_drain(cx, r.as_mut(), ideal_len),
211 this.state.as_mut()
212 ) {
213 Drained::Some(_drained) => {
214 #[cfg(feature = "feat-rate-limit")]
215 {
216 #[allow(clippy::used_underscore_binding)]
217 match this.rate_limiter.check(_drained) {
218 RateLimitResult::Accepted => {}
219 RateLimitResult::Throttled { now, dur } => {
220 this.state.as_mut().set(TransferState::Throttled {
221 sleep: tokio::time::sleep_until(now + dur),
222 });
223 continue;
224 }
225 }
226 }
227 }
228 Drained::Done => {}
229 }
230
231 this.state.set(TransferState::Pumping);
232 }
233 #[cfg(feature = "feat-rate-limit")]
234 TransferStateProj::Throttled { sleep } => {
235 use std::future::Future;
236
237 ready!(sleep.poll(cx));
238
239 this.state.set(TransferState::Pumping);
241 }
242 #[cfg(not(feature = "feat-rate-limit"))]
243 TransferStateProj::Throttled { _pinned } => {
244 this.state.set(TransferState::Pumping);
246 }
247 TransferStateProj::Pumping => {
248 ready_or_cleanup!(
249 this.ctx.poll_splice_pump(cx, w.as_mut()),
250 this.state.as_mut()
251 );
252
253 if this.ctx.finished() {
254 this.state.set(TransferState::Terminating);
256 } else {
257 this.state.set(TransferState::Flushing);
259 }
260 }
261 TransferStateProj::Flushing => {
262 ready_or_cleanup!(w.as_mut().poll_flush(cx), this.state.as_mut());
263
264 this.state.set(TransferState::Draining);
265 }
266 TransferStateProj::Terminating => {
267 ready_or_cleanup!(w.as_mut().poll_shutdown(cx), this.state.as_mut());
268
269 this.state.set(TransferState::Finished);
270 }
271 TransferStateProj::Faulted { error } => {
272 if error.is_some() {
273 ready!(w.as_mut().poll_shutdown(cx))?;
275 } else {
276 #[cfg(feature = "feat-nightly")]
277 std::hint::cold_path();
278 }
279
280 let Some(error) = error.take() else {
281 #[cfg(feature = "feat-nightly")]
282 std::hint::cold_path();
283
284 break Poll::Ready(Err(io::Error::new(
285 io::ErrorKind::Other,
286 "`poll_execute()` called after error returned",
287 )));
288 };
289
290 break Poll::Ready(Err(error));
291 }
292 TransferStateProj::Finished => {
293 break Poll::Ready(Ok(()));
294 }
295 }
296 }
297 }
298}
299
300#[pin_project::pin_project]
301#[derive(Debug)]
302pub struct SpliceBidiIo<
304 SL,
305 SR,
306 const SL_RATE_LIMITER_IS_ENABLED: bool,
307 const SR_RATE_LIMITER_IS_ENABLED: bool,
308> {
309 #[pin]
310 pub io_sl2sr: SpliceIo<SL, SR, SL_RATE_LIMITER_IS_ENABLED>,
312
313 #[pin]
314 pub io_sr2sl: SpliceIo<SR, SL, SR_RATE_LIMITER_IS_ENABLED>,
316}
317
318impl<SL, SR, const SL_RATE_LIMITER_IS_ENABLED: bool, const SR_RATE_LIMITER_IS_ENABLED: bool>
319 SpliceBidiIo<SL, SR, SL_RATE_LIMITER_IS_ENABLED, SR_RATE_LIMITER_IS_ENABLED>
320where
321 SL: AsyncReadFd + AsyncWriteFd + IsNotFile,
322 SR: AsyncReadFd + AsyncWriteFd + IsNotFile,
323{
324 pub async fn execute(self, sl: &mut SL, sr: &mut SR) -> TrafficResult
330 where
331 SL: Unpin,
332 SR: Unpin,
333 {
334 let mut this = pin!(self);
335 let mut sl = Pin::new(sl);
336 let mut sr = Pin::new(sr);
337
338 let error = poll_fn(|cx| this.as_mut().poll_execute(cx, sl.as_mut(), sr.as_mut()))
339 .await
340 .err();
341
342 this.io_sl2sr
344 .ctx
345 .traffic_client_tx(error)
346 .merge(this.io_sr2sl.ctx.traffic_client_rx(None))
347 }
348
349 #[cfg_attr(
350 any(
351 feature = "feat-tracing-trace",
352 all(debug_assertions, feature = "feat-tracing")
353 ),
354 tracing::instrument(
355 level = "TRACE",
356 name = "SpliceBidiIo::poll_execute",
357 skip(self, cx, sl, sr),
358 ret
359 )
360 )]
361 pub fn poll_execute(
374 self: Pin<&mut Self>,
375 cx: &mut Context<'_>,
376 mut sl: Pin<&mut SL>,
377 mut sr: Pin<&mut SR>,
378 ) -> Poll<io::Result<()>> {
379 let mut this = self.project();
380
381 let io_sl2sr_ret = this
382 .io_sl2sr
383 .as_mut()
384 .poll_execute(cx, sl.as_mut(), sr.as_mut());
385 let io_sr2sl_ret = this
386 .io_sr2sl
387 .as_mut()
388 .poll_execute(cx, sr.as_mut(), sl.as_mut());
389
390 #[cfg(not(feature = "feat-brutal-shutdown"))]
391 {
392 match (io_sl2sr_ret, io_sr2sl_ret) {
393 (Poll::Pending, _) | (_, Poll::Pending) => Poll::Pending,
394 (Poll::Ready(Ok(())), Poll::Ready(Ok(()))) => Poll::Ready(Ok(())),
395 (Poll::Ready(Err(e)), _) | (_, Poll::Ready(Err(e))) => Poll::Ready(Err(e)),
396 }
397 }
398
399 #[cfg(feature = "feat-brutal-shutdown")]
400 {
401 match (io_sl2sr_ret, io_sr2sl_ret) {
402 (Poll::Pending, Poll::Pending) => Poll::Pending,
403 (Poll::Ready(Err(e)), _) | (_, Poll::Ready(Err(e))) => Poll::Ready(Err(e)),
404 (Poll::Ready(Ok(())), _) | (_, Poll::Ready(Ok(()))) => Poll::Ready(Ok(())),
406 }
407 }
408 }
409}
410
/// Marker trait (no methods) for I/O types that are file-backed.
pub trait IsFile {}

// The marker carries over to mutable borrows and pinned mutable borrows.
impl<T: IsFile> IsFile for &mut T {}
impl<T: IsFile> IsFile for Pin<&mut T> {}
422
/// Marker trait (no methods) for I/O types that are not file-backed.
///
/// `SpliceBidiIo` requires it on both of its endpoints.
pub trait IsNotFile {}

// The marker carries over to mutable borrows and pinned mutable borrows.
impl<T: IsNotFile> IsNotFile for &mut T {}
impl<T: IsNotFile> IsNotFile for Pin<&mut T> {}
431
432pub trait AsyncReadFd: AsyncRead + AsFd {
437 #[doc(hidden)]
438 fn poll_read_ready(&self, _cx: &mut Context<'_>) -> Poll<io::Result<()>>;
439
440 #[doc(hidden)]
441 fn try_io_read<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R>;
442}
443
444impl<T: AsyncReadFd + Unpin> AsyncReadFd for &mut T {
445 #[inline]
446 fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
447 (**self).poll_read_ready(cx)
448 }
449
450 #[inline]
451 fn try_io_read<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
452 (**self).try_io_read(f)
453 }
454}
455
456pub trait AsyncWriteFd: AsyncWrite + AsFd {
461 #[doc(hidden)]
462 fn poll_write_ready(&self, _cx: &mut Context<'_>) -> Poll<io::Result<()>>;
463
464 #[doc(hidden)]
465 fn try_io_write<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R>;
466}
467
468impl<T: AsyncWriteFd + Unpin> AsyncWriteFd for &mut T {
469 #[inline]
470 fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
471 (**self).poll_write_ready(cx)
472 }
473
474 #[inline]
475 fn try_io_write<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
476 (**self).try_io_write(f)
477 }
478}
479
// Implements `AsyncReadFd` / `AsyncWriteFd` plus the matching marker trait.
//
// * `impl_async_fd!(FILE: A, B)`: readiness polls resolve immediately with
//   `Ok(())`, the I/O closure is invoked directly, and the type is marked
//   `IsFile`.
// * `impl_async_fd!(A, B)`: everything is forwarded to the type's own
//   `poll_read_ready` / `poll_write_ready` / `try_io` methods, and the type
//   is marked `IsNotFile`.
macro_rules! impl_async_fd {
    (FILE: $($ty:ty),+) => {
        $(
            impl AsyncReadFd for $ty {
                #[inline]
                fn poll_read_ready(&self, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                    Poll::Ready(Ok(()))
                }

                #[inline]
                fn try_io_read<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
                    f()
                }
            }

            impl AsyncWriteFd for $ty {
                #[inline]
                fn poll_write_ready(&self, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                    Poll::Ready(Ok(()))
                }

                #[inline]
                fn try_io_write<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
                    f()
                }
            }

            impl IsFile for $ty {}
        )+
    };
    ($($ty:ty),+) => {
        $(
            impl AsyncReadFd for $ty {
                #[inline]
                fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                    // Resolves to the type's inherent method, not this trait method.
                    <$ty>::poll_read_ready(self, cx)
                }

                #[inline]
                fn try_io_read<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
                    <$ty>::try_io(self, Interest::READABLE, f)
                }
            }

            impl AsyncWriteFd for $ty {
                #[inline]
                fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                    // Resolves to the type's inherent method, not this trait method.
                    <$ty>::poll_write_ready(self, cx)
                }

                #[inline]
                fn try_io_write<R>(&self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> {
                    <$ty>::try_io(self, Interest::WRITABLE, f)
                }
            }

            impl IsNotFile for $ty {}
        )+
    };
}
540
541impl_async_fd!(TcpStream, UnixStream);
542impl_async_fd!(FILE: File);