Skip to main content

razor_stream/client/
pool.rs

1use crate::client::stream::ClientStream;
2use crate::client::{
3    ClientCaller, ClientCallerBlocking, ClientFacts, ClientTransport, task::ClientTaskDone,
4};
5use crate::error::RpcIntErr;
6use captains_log::filter::LogFilter;
7use crossfire::{MAsyncRx, MAsyncTx, MTx, RecvTimeoutError, mpmc};
8use orb::prelude::{AsyncExec, AsyncRuntime, AsyncTime};
9use std::fmt;
10use std::marker::PhantomData;
11use std::sync::Arc;
12use std::sync::atomic::{
13    AtomicBool, AtomicUsize,
14    Ordering::{Acquire, Relaxed, Release, SeqCst},
15};
16use std::time::Duration;
17
18/// Connection pool to the one server address (supports async and blocking context)
19///
20/// There's a worker accepting task post in bounded channel.
21///
22/// Even when the server address is not reachable, the worker coroutine will not exit,
23/// until ConnPool is dropped.
24///
25/// The background coroutine will:
26/// - monitor the address with ping task (action 0)
27/// - cleanup the task in channel with error_handle when the address is unhealthy
28///
29/// If the connection is healthy and there's incoming, the worker will spawn another coroutine for
30/// monitor purpose.
31///
32/// considering:
33/// - The task incoming might never stop until faulty pool remove from pools collection
34/// - If ping mixed with task with real business, might blocked due to throttler of in-flight
35///   message in the stream.
36pub struct ConnPool<F: ClientFacts, P: ClientTransport> {
37    tx_async: MAsyncTx<mpmc::Array<F::Task>>,
38    tx: MTx<mpmc::Array<F::Task>>,
39    inner: Arc<ConnPoolInner<F, P>>,
40}
41
42impl<F: ClientFacts, P: ClientTransport> Clone for ConnPool<F, P> {
43    fn clone(&self) -> Self {
44        Self { tx_async: self.tx_async.clone(), tx: self.tx.clone(), inner: self.inner.clone() }
45    }
46}
47
48struct ConnPoolInner<F: ClientFacts, P: ClientTransport> {
49    facts: Arc<F>,
50    logger: Arc<LogFilter>,
51    rx: MAsyncRx<mpmc::Array<F::Task>>,
52    addr: String,
53    conn_id: String,
54    /// whether connection is healthy?
55    is_ok: AtomicBool,
56    /// dynamic worker count (not the monitor)
57    worker_count: AtomicUsize,
58    /// dynamic worker count (not the monitor)
59    connected_worker_count: AtomicUsize,
60    ///// Set by user
61    //limit: AtomicUsize, // TODO
62    _phan: PhantomData<fn(&P)>,
63}
64
/// One second: the ping period of the monitor and the pause before a reconnect attempt.
const ONE_SEC: Duration = Duration::from_millis(1_000);
66
67impl<F: ClientFacts, P: ClientTransport> ConnPool<F, P> {
68    /// # Argument
69    ///
70    /// - `rt`: When we are in orb async context, just pass None, otherwise (in thread context),
71    ///   pass the AsyncRuntime::Exec.
72    pub fn new(
73        facts: Arc<F>, rt: Option<&<P::RT as AsyncRuntime>::Exec>, addr: &str,
74        mut channel_size: usize,
75    ) -> Self {
76        let config = facts.get_config();
77        if config.thresholds > 0 {
78            if channel_size < config.thresholds {
79                channel_size = config.thresholds;
80            }
81        } else if channel_size == 0 {
82            channel_size = 128;
83        }
84        let (tx_async, rx) = mpmc::bounded_async(channel_size);
85        let tx = tx_async.clone().into();
86        let conn_id = format!("to {}", addr);
87        let inner = Arc::new(ConnPoolInner {
88            logger: facts.new_logger(),
89            facts: facts.clone(),
90            rx,
91            addr: addr.to_string(),
92            conn_id,
93            is_ok: AtomicBool::new(true),
94            worker_count: AtomicUsize::new(0),
95            connected_worker_count: AtomicUsize::new(0),
96            _phan: Default::default(),
97        });
98        let s = Self { tx_async, tx, inner };
99        s.spawn(rt);
100        s
101    }
102
103    #[inline(always)]
104    pub fn is_healthy(&self) -> bool {
105        self.inner.is_ok.load(Relaxed)
106    }
107
108    #[inline]
109    pub fn get_addr(&self) -> &str {
110        &self.inner.addr
111    }
112
113    #[inline]
114    pub async fn send_req(&self, task: F::Task) {
115        ClientCaller::send_req(self, task).await;
116    }
117
118    #[inline]
119    pub fn send_req_blocking(&self, task: F::Task) {
120        ClientCallerBlocking::send_req_blocking(self, task);
121    }
122
123    /// by default there's one worker thread after initiation, but you can pre-spawn more thread if
124    /// the connection is not enough to achieve desired throughput.
125    #[inline]
126    pub fn spawn(&self, rt: Option<&<P::RT as AsyncRuntime>::Exec>) {
127        let worker_id = self.inner.worker_count.fetch_add(1, Acquire);
128        self.inner.clone().spawn_worker(rt, worker_id);
129    }
130}
131
132impl<F: ClientFacts, P: ClientTransport> Drop for ConnPoolInner<F, P> {
133    fn drop(&mut self) {
134        self.cleanup();
135        logger_trace!(self.logger, "{} dropped", self);
136    }
137}
138
139impl<F: ClientFacts, P: ClientTransport> ClientCaller for ConnPool<F, P> {
140    type Facts = F;
141    #[inline]
142    async fn send_req(&self, task: F::Task) {
143        self.tx_async.send(task).await.expect("submit");
144    }
145}
146
147impl<F: ClientFacts, P: ClientTransport> ClientCallerBlocking for ConnPool<F, P> {
148    type Facts = F;
149    #[inline]
150    fn send_req_blocking(&self, task: F::Task) {
151        self.tx.send(task).expect("submit");
152    }
153}
154
155impl<F: ClientFacts, P: ClientTransport> fmt::Display for ConnPoolInner<F, P> {
156    #[inline]
157    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
158        write!(f, "ConnPool {}", self.conn_id)
159    }
160}
161
162impl<F: ClientFacts, P: ClientTransport> ConnPoolInner<F, P> {
163    #[inline]
164    fn spawn_worker(self: Arc<Self>, rt: Option<&<P::RT as AsyncRuntime>::Exec>, worker_id: usize) {
165        let f = async move {
166            logger_trace!(&self.logger, "{} worker_id={} running", self, worker_id);
167            self.run(worker_id).await;
168            self.worker_count.fetch_sub(1, SeqCst);
169            logger_trace!(&self.logger, "{} worker_id={} exit", self, worker_id);
170        };
171        if let Some(_rt) = rt {
172            _rt.spawn_detach(f);
173        } else {
174            P::RT::spawn_detach(f);
175        }
176    }
177
178    #[inline(always)]
179    fn get_workers(&self) -> usize {
180        self.worker_count.load(SeqCst)
181    }
182
183    //    #[inline(always)]
184    //    fn get_healthy_workers(&self) -> usize {
185    //        self.connected_worker_count.load(SeqCst)
186    //    }
187
188    #[inline(always)]
189    fn set_err(&self) {
190        self.is_ok.store(false, SeqCst);
191    }
192
193    #[inline]
194    async fn connect(&self) -> Result<ClientStream<F, P>, RpcIntErr> {
195        ClientStream::connect(self.facts.clone(), None, &self.addr, &self.conn_id, None).await
196    }
197
198    #[inline(always)]
199    async fn _run_worker(
200        &self, _worker_id: usize, stream: &mut ClientStream<F, P>,
201    ) -> Result<(), RpcIntErr> {
202        loop {
203            match self.rx.recv().await {
204                Ok(task) => {
205                    stream.send_task(task, false).await?;
206                    while let Ok(task) = self.rx.try_recv() {
207                        stream.send_task(task, false).await?;
208                    }
209                    stream.flush_req().await?;
210                }
211                Err(_) => {
212                    stream.flush_req().await?;
213                    return Ok(());
214                }
215            }
216        }
217    }
218
219    async fn run_worker(
220        &self, worker_id: usize, stream: &mut ClientStream<F, P>,
221    ) -> Result<(), RpcIntErr> {
222        self.connected_worker_count.fetch_add(1, Acquire);
223        let r = self._run_worker(worker_id, stream).await;
224        logger_trace!(self.logger, "{} worker {} exit: {}", self, worker_id, r.is_ok());
225        self.connected_worker_count.fetch_add(1, Release);
226        r
227    }
228
229    /// The worker maintains connection state,
230    /// connection attempts happens after we spawn.
231    /// If the address is dead, the thread might exit after multiple attempts, and later re-spawn
232    /// when the needs arrives.
233    async fn run(self: &Arc<Self>, mut worker_id: usize) {
234        'CONN_LOOP: loop {
235            match self.connect().await {
236                Ok(mut stream) => {
237                    logger_trace!(self.logger, "{} worker={} connected", self, worker_id);
238                    if worker_id == 0 {
239                        // act as monitor
240                        'MONITOR: loop {
241                            if self.get_workers() > 1 {
242                                <P::RT as AsyncTime>::sleep(ONE_SEC).await;
243                                if stream.ping().await.is_err() {
244                                    self.set_err();
245                                    // don't cleanup the channel unless only one worker left
246                                    continue 'CONN_LOOP;
247                                }
248                            } else {
249                                match self
250                                    .rx
251                                    .recv_with_timer(<P::RT as AsyncTime>::sleep(ONE_SEC))
252                                    .await
253                                {
254                                    Err(RecvTimeoutError::Disconnected) => {
255                                        return;
256                                    }
257                                    Err(RecvTimeoutError::Timeout) => {
258                                        if stream.ping().await.is_err() {
259                                            self.set_err();
260                                            self.cleanup();
261                                            continue 'CONN_LOOP;
262                                        }
263                                    }
264                                    Ok(task) => {
265                                        if stream.get_inflight_count() > 0
266                                            && self.get_workers() == 1
267                                            && self
268                                                .worker_count
269                                                .compare_exchange(1, 2, SeqCst, Relaxed)
270                                                .is_ok()
271                                        {
272                                            // there's might be a lag to connect,
273                                            // so we are spawning identity with new worker,
274                                            worker_id = 1;
275                                            self.clone().spawn_worker(None, 0);
276                                        }
277                                        if stream.send_task(task, true).await.is_err() {
278                                            self.set_err();
279                                            if worker_id == 0 {
280                                                self.cleanup();
281                                                <P::RT as AsyncTime>::sleep(ONE_SEC).await;
282                                                continue 'CONN_LOOP;
283                                            } else {
284                                                return;
285                                            }
286                                        } else if worker_id > 0 {
287                                            logger_trace!(
288                                                self.logger,
289                                                "{} worker={} break monitor",
290                                                self,
291                                                worker_id
292                                            );
293                                            // taken over as run_worker.
294                                            break 'MONITOR;
295                                        }
296                                    }
297                                }
298                            }
299                        }
300                    }
301                    if worker_id > 0 {
302                        if self.run_worker(worker_id, &mut stream).await.is_err() {
303                            self.set_err();
304                            // don't cleanup the channel unless only one worker left
305                        }
306                        // TODO If worker will exit automiatically when idle_time passed
307                        return;
308                    }
309                }
310                Err(e) => {
311                    self.set_err();
312                    error!("connect failed to {}: {}", self.addr, e);
313                    self.cleanup();
314                    <P::RT as AsyncTime>::sleep(ONE_SEC).await;
315                }
316            }
317        }
318    }
319
320    fn cleanup(&self) {
321        while let Ok(mut task) = self.rx.try_recv() {
322            task.set_rpc_error(RpcIntErr::Unreachable);
323            logger_trace!(self.logger, "{} set task err due not not healthy", self);
324            self.facts.error_handle(task);
325        }
326    }
327}