spider/utils/
connect.rs

1use pin_project_lite::pin_project;
2use std::{
3    future::Future,
4    pin::Pin,
5    sync::atomic::AtomicBool,
6    task::{Context, Poll},
7};
8use tokio::{
9    select,
10    sync::{mpsc::error::SendError, OnceCell},
11};
12use tower::{BoxError, Layer, Service};
13
14/// A threadpool dedicated for connecting to services.
15static CONNECT_THREAD_POOL: OnceCell<
16    tokio::sync::mpsc::UnboundedSender<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
17> = OnceCell::const_new();
18
19/// Is the background thread connect enabled.
20static BACKGROUND_THREAD_CONNECT_ENABLED: AtomicBool = AtomicBool::new(true);
21
22/// Is the background thread inited.
23pub(crate) fn background_connect_threading() -> bool {
24    BACKGROUND_THREAD_CONNECT_ENABLED.load(std::sync::atomic::Ordering::Relaxed)
25}
26
27/// Init a background thread for request connect handling.
28#[cfg(all(target_os = "linux", feature = "io_uring"))]
29pub fn init_background_runtime() {
30    let _ = CONNECT_THREAD_POOL.set({
31        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
32        let builder = std::thread::Builder::new();
33
34        if let Err(_) = builder.spawn(move || {
35            tokio_uring::builder().start(async {
36                while let Some(work) = rx.recv().await {
37                    tokio_uring::spawn(work);
38                }
39            })
40        }) {
41            let _ = tx.downgrade();
42            BACKGROUND_THREAD_CONNECT_ENABLED.store(false, std::sync::atomic::Ordering::Relaxed);
43        };
44
45        tx
46    });
47}
48
49/// Init a background thread for request connect handling.
50#[cfg(any(not(target_os = "linux"), not(feature = "io_uring")))]
51pub fn init_background_runtime() {
52    let _ = CONNECT_THREAD_POOL.set({
53        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
54        let builder = std::thread::Builder::new();
55
56        if let Err(_) = builder.spawn(move || {
57            match tokio::runtime::Builder::new_multi_thread()
58                .thread_name("connect-background-pool-thread")
59                .worker_threads(num_cpus::get() as usize)
60                .on_thread_start(move || {
61                    #[cfg(target_os = "linux")]
62                    unsafe {
63                        if libc::nice(10) == -1 && *libc::__errno_location() != 0 {
64                            let error = std::io::Error::last_os_error();
65                            log::error!("failed to set threadpool niceness: {}", error);
66                        }
67                    }
68                })
69                .enable_all()
70                .build()
71            {
72                Ok(rt) => {
73                    rt.block_on(async move {
74                        while let Some(work) = rx.recv().await {
75                            tokio::task::spawn(work);
76                        }
77                    });
78                }
79                _ => {
80                    BACKGROUND_THREAD_CONNECT_ENABLED
81                        .store(false, std::sync::atomic::Ordering::Relaxed);
82                }
83            }
84        }) {
85            let _ = tx.downgrade();
86            BACKGROUND_THREAD_CONNECT_ENABLED.store(false, std::sync::atomic::Ordering::Relaxed);
87        };
88
89        tx
90    });
91}
92
93/// This tower layer injects futures with a oneshot channel, and then sends them to the background runtime for processing.
94#[derive(Copy, Clone)]
95pub struct BackgroundProcessorLayer;
96
97impl BackgroundProcessorLayer {
98    /// A new background proccess layer shortcut.
99    pub fn new() -> Self {
100        Self
101    }
102}
103impl<S> Layer<S> for BackgroundProcessorLayer {
104    type Service = BackgroundProcessor<S>;
105    fn layer(&self, service: S) -> Self::Service {
106        BackgroundProcessor::new(service)
107    }
108}
109
110impl std::fmt::Debug for BackgroundProcessorLayer {
111    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
112        f.debug_struct("BackgroundProcessorLayer").finish()
113    }
114}
115
116/// Send to the background runtime.
117fn send_to_background_runtime(future: impl Future<Output = ()> + Send + 'static) {
118    let tx = CONNECT_THREAD_POOL.get().expect(
119        "background runtime should be initialized by calling init_background_runtime before use",
120    );
121
122    if let Err(SendError(_)) = tx.send(Box::pin(future)) {
123        log::error!("Failed to send future - background connect runtime channel is closed. Abandoning task.");
124    }
125}
126
127/// This tower service injects futures with a oneshot channel, and then sends them to the background runtime for processing.
128#[derive(Debug, Clone)]
129pub struct BackgroundProcessor<S> {
130    inner: S,
131}
132
133impl<S> BackgroundProcessor<S> {
134    /// Setup a new connect background processor.
135    pub fn new(inner: S) -> Self {
136        BackgroundProcessor { inner }
137    }
138}
139
140impl<S, Request> Service<Request> for BackgroundProcessor<S>
141where
142    S: Service<Request>,
143    S::Response: Send + 'static,
144    S::Error: Into<BoxError> + Send,
145    S::Future: Send + 'static,
146{
147    type Response = S::Response;
148    type Error = BoxError;
149    type Future = BackgroundResponseFuture<S::Response>;
150
151    fn poll_ready(
152        &mut self,
153        cx: &mut std::task::Context<'_>,
154    ) -> std::task::Poll<Result<(), Self::Error>> {
155        match self.inner.poll_ready(cx) {
156            Poll::Pending => Poll::Pending,
157            Poll::Ready(r) => Poll::Ready(r.map_err(Into::into)),
158        }
159    }
160
161    fn call(&mut self, req: Request) -> Self::Future {
162        let response = self.inner.call(req);
163        let (mut tx, rx) = tokio::sync::oneshot::channel();
164
165        let future = async move {
166            select! {
167                _ = tx.closed() => (),
168                result = response => {
169                    let _ = tx.send(result.map_err(Into::into));
170                }
171            }
172        };
173
174        send_to_background_runtime(future);
175        BackgroundResponseFuture::new(rx)
176    }
177}
178
179pin_project! {
180    #[derive(Debug)]
181    /// A new background response future.
182    pub struct BackgroundResponseFuture<S> {
183        #[pin]
184        rx: tokio::sync::oneshot::Receiver<Result<S, BoxError>>,
185    }
186}
187
188impl<S> BackgroundResponseFuture<S> {
189    pub(crate) fn new(rx: tokio::sync::oneshot::Receiver<Result<S, BoxError>>) -> Self {
190        BackgroundResponseFuture { rx }
191    }
192}
193
194impl<S> Future for BackgroundResponseFuture<S>
195where
196    S: Send + 'static,
197{
198    type Output = Result<S, BoxError>;
199
200    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
201        let this = self.project();
202        match this.rx.poll(cx) {
203            Poll::Ready(v) => match v {
204                Ok(v) => Poll::Ready(v.map_err(Into::into)),
205                Err(err) => Poll::Ready(Err(Box::new(err) as BoxError)),
206            },
207            Poll::Pending => Poll::Pending,
208        }
209    }
210}