Skip to main content

razor_stream/client/
pool.rs

1use crate::client::stream::ClientStream;
2use crate::client::{
3    ClientCaller, ClientCallerBlocking, ClientFacts, ClientTransport, task::ClientTaskDone,
4};
5use crate::error::RpcIntErr;
6use captains_log::filter::LogFilter;
7use crossfire::{MAsyncRx, MAsyncTx, MTx, RecvTimeoutError, mpmc};
8use orb::AsyncRuntime;
9use std::fmt;
10use std::marker::PhantomData;
11use std::sync::Arc;
12use std::sync::atomic::{
13    AtomicBool, AtomicUsize,
14    Ordering::{Acquire, Relaxed, Release, SeqCst},
15};
16use std::time::Duration;
17
18/// Connection pool to the one server address (supports async and blocking context)
19///
/// There's a worker accepting tasks posted to a bounded channel.
21///
22/// Even when the server address is not reachable, the worker coroutine will not exit,
23/// until ClientPool is dropped.
24///
25/// The background coroutine will:
26/// - monitor the address with ping task (action 0)
27/// - cleanup the task in channel with error_handle when the address is unhealthy
28///
/// If the connection is healthy and there's incoming traffic, the worker will spawn another
/// coroutine for monitoring purposes.
31///
32/// considering:
/// - Incoming tasks might never stop until the faulty pool is removed from the pools collection
/// - If pings were mixed with real business tasks, they might be blocked due to the throttler
///   of in-flight messages in the stream.
pub struct ClientPool<F: ClientFacts, P: ClientTransport> {
    /// Async sender side of the bounded task channel (used from async context).
    tx_async: MAsyncTx<mpmc::Array<F::Task>>,
    /// Blocking sender converted from `tx_async` (used from sync context).
    tx: MTx<mpmc::Array<F::Task>>,
    /// Shared state plus the receiver consumed by the background worker(s).
    inner: Arc<ClientPoolInner<F, P>>,
}
41
42impl<F: ClientFacts, P: ClientTransport> Clone for ClientPool<F, P> {
43    fn clone(&self) -> Self {
44        Self { tx_async: self.tx_async.clone(), tx: self.tx.clone(), inner: self.inner.clone() }
45    }
46}
47
struct ClientPoolInner<F: ClientFacts, P: ClientTransport> {
    /// User-supplied callbacks/config factory.
    facts: Arc<F>,
    logger: Arc<LogFilter>,
    /// Receiver side of the task channel; shared by all workers.
    rx: MAsyncRx<mpmc::Array<F::Task>>,
    /// Server address this pool connects to.
    addr: String,
    /// Identifier used in log lines, formatted as "to <addr>".
    conn_id: String,
    /// whether connection is healthy?
    is_ok: AtomicBool,
    /// dynamic worker count (not the monitor)
    worker_count: AtomicUsize,
    /// workers currently holding an established connection inside run_worker
    connected_worker_count: AtomicUsize,
    ///// Set by user
    //limit: AtomicUsize, // TODO
    /// Ties `P` to the type without storing one; `fn(&P)` keeps Send/Sync auto traits.
    _phan: PhantomData<fn(&P)>,
}
64
/// Interval for the ping/monitor timers and the reconnect back-off sleeps.
const ONE_SEC: Duration = Duration::from_secs(1);
66
67impl<F: ClientFacts, P: ClientTransport> ClientPool<F, P> {
68    pub fn new<RT: AsyncRuntime + Clone>(
69        facts: Arc<F>, rt: &RT, addr: &str, mut channel_size: usize,
70    ) -> Self {
71        let config = facts.get_config();
72        if config.thresholds > 0 {
73            if channel_size < config.thresholds {
74                channel_size = config.thresholds;
75            }
76        } else if channel_size == 0 {
77            channel_size = 128;
78        }
79        let (tx_async, rx) = mpmc::bounded_async(channel_size);
80        let tx = tx_async.clone().into();
81        let conn_id = format!("to {}", addr);
82        let inner = Arc::new(ClientPoolInner {
83            logger: facts.new_logger(),
84            facts: facts.clone(),
85            rx,
86            addr: addr.to_string(),
87            conn_id,
88            is_ok: AtomicBool::new(true),
89            worker_count: AtomicUsize::new(0),
90            connected_worker_count: AtomicUsize::new(0),
91            _phan: Default::default(),
92        });
93        let s = Self { tx_async, tx, inner };
94        s.spawn::<RT>(rt);
95        s
96    }
97
98    #[inline(always)]
99    pub fn is_healthy(&self) -> bool {
100        self.inner.is_ok.load(Relaxed)
101    }
102
103    #[inline]
104    pub fn get_addr(&self) -> &str {
105        &self.inner.addr
106    }
107
108    #[inline]
109    pub async fn send_req(&self, task: F::Task) {
110        ClientCaller::send_req(self, task).await;
111    }
112
113    #[inline]
114    pub fn send_req_blocking(&self, task: F::Task) {
115        ClientCallerBlocking::send_req_blocking(self, task);
116    }
117
118    /// by default there's one worker thread after initiation, but you can pre-spawn more thread if
119    /// the connection is not enough to achieve desired throughput.
120    #[inline]
121    pub fn spawn<RT: AsyncRuntime + Clone>(&self, rt: &RT) {
122        let worker_id = self.inner.worker_count.fetch_add(1, Acquire);
123        self.inner.clone().spawn_worker(rt, worker_id);
124    }
125}
126
127impl<F: ClientFacts, P: ClientTransport> Drop for ClientPoolInner<F, P> {
128    fn drop(&mut self) {
129        self.cleanup();
130        logger_trace!(self.logger, "{} dropped", self);
131    }
132}
133
134impl<F: ClientFacts, P: ClientTransport> ClientCaller for ClientPool<F, P> {
135    type Facts = F;
136    #[inline]
137    async fn send_req(&self, task: F::Task) {
138        self.tx_async.send(task).await.expect("submit");
139    }
140}
141
142impl<F: ClientFacts, P: ClientTransport> ClientCallerBlocking for ClientPool<F, P> {
143    type Facts = F;
144    #[inline]
145    fn send_req_blocking(&self, task: F::Task) {
146        self.tx.send(task).expect("submit");
147    }
148}
149
150impl<F: ClientFacts, P: ClientTransport> fmt::Display for ClientPoolInner<F, P> {
151    #[inline]
152    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
153        write!(f, "ConnPool {}", self.conn_id)
154    }
155}
156
157impl<F: ClientFacts, P: ClientTransport> ClientPoolInner<F, P> {
158    fn spawn_worker<RT: AsyncRuntime + Clone>(self: Arc<Self>, rt: &RT, worker_id: usize) {
159        let _rt = rt.clone();
160        rt.spawn_detach(async move {
161            logger_trace!(&self.logger, "{} worker_id={} running", self, worker_id);
162            self.run(_rt, worker_id).await;
163            self.worker_count.fetch_sub(1, SeqCst);
164            logger_trace!(&self.logger, "{} worker_id={} exit", self, worker_id);
165        });
166    }
167
168    #[inline(always)]
169    fn get_workers(&self) -> usize {
170        self.worker_count.load(SeqCst)
171    }
172
173    //    #[inline(always)]
174    //    fn get_healthy_workers(&self) -> usize {
175    //        self.connected_worker_count.load(SeqCst)
176    //    }
177
178    #[inline(always)]
179    fn set_err(&self) {
180        self.is_ok.store(false, SeqCst);
181    }
182
183    #[inline]
184    async fn connect<RT: AsyncRuntime>(&self, rt: &RT) -> Result<ClientStream<F, P>, RpcIntErr> {
185        ClientStream::connect(self.facts.clone(), rt, &self.addr, &self.conn_id, None).await
186    }
187
188    #[inline(always)]
189    async fn _run_worker(
190        &self, _worker_id: usize, stream: &mut ClientStream<F, P>,
191    ) -> Result<(), RpcIntErr> {
192        loop {
193            match self.rx.recv().await {
194                Ok(task) => {
195                    stream.send_task(task, false).await?;
196                    while let Ok(task) = self.rx.try_recv() {
197                        stream.send_task(task, false).await?;
198                    }
199                    stream.flush_req().await?;
200                }
201                Err(_) => {
202                    stream.flush_req().await?;
203                    return Ok(());
204                }
205            }
206        }
207    }
208
209    async fn run_worker(
210        &self, worker_id: usize, stream: &mut ClientStream<F, P>,
211    ) -> Result<(), RpcIntErr> {
212        self.connected_worker_count.fetch_add(1, Acquire);
213        let r = self._run_worker(worker_id, stream).await;
214        logger_trace!(self.logger, "{} worker {} exit: {}", self, worker_id, r.is_ok());
215        self.connected_worker_count.fetch_add(1, Release);
216        r
217    }
218
219    /// The worker maintains connection state,
220    /// connection attempts happens after we spawn.
221    /// If the address is dead, the thread might exit after multiple attempts, and later re-spawn
222    /// when the needs arrives.
223    async fn run<RT: AsyncRuntime + Clone>(self: &Arc<Self>, rt: RT, mut worker_id: usize) {
224        'CONN_LOOP: loop {
225            match self.connect::<RT>(&rt).await {
226                Ok(mut stream) => {
227                    logger_trace!(self.logger, "{} worker={} connected", self, worker_id);
228                    if worker_id == 0 {
229                        // act as monitor
230                        'MONITOR: loop {
231                            if self.get_workers() > 1 {
232                                RT::sleep(ONE_SEC).await;
233                                if stream.ping().await.is_err() {
234                                    self.set_err();
235                                    // don't cleanup the channel unless only one worker left
236                                    continue 'CONN_LOOP;
237                                }
238                            } else {
239                                match self.rx.recv_with_timer(RT::sleep(ONE_SEC)).await {
240                                    Err(RecvTimeoutError::Disconnected) => {
241                                        return;
242                                    }
243                                    Err(RecvTimeoutError::Timeout) => {
244                                        if stream.ping().await.is_err() {
245                                            self.set_err();
246                                            self.cleanup();
247                                            continue 'CONN_LOOP;
248                                        }
249                                    }
250                                    Ok(task) => {
251                                        if stream.get_inflight_count() > 0
252                                            && self.get_workers() == 1
253                                            && self
254                                                .worker_count
255                                                .compare_exchange(1, 2, SeqCst, Relaxed)
256                                                .is_ok()
257                                        {
258                                            // there's might be a lag to connect,
259                                            // so we are spawning identity with new worker,
260                                            worker_id = 1;
261                                            self.clone().spawn_worker::<RT>(&rt, 0);
262                                        }
263                                        if stream.send_task(task, true).await.is_err() {
264                                            self.set_err();
265                                            if worker_id == 0 {
266                                                self.cleanup();
267                                                RT::sleep(ONE_SEC).await;
268                                                continue 'CONN_LOOP;
269                                            } else {
270                                                return;
271                                            }
272                                        } else if worker_id > 0 {
273                                            logger_trace!(
274                                                self.logger,
275                                                "{} worker={} break monitor",
276                                                self,
277                                                worker_id
278                                            );
279                                            // taken over as run_worker.
280                                            break 'MONITOR;
281                                        }
282                                    }
283                                }
284                            }
285                        }
286                    }
287                    if worker_id > 0 {
288                        if self.run_worker(worker_id, &mut stream).await.is_err() {
289                            self.set_err();
290                            // don't cleanup the channel unless only one worker left
291                        }
292                        // TODO If worker will exit automiatically when idle_time passed
293                        return;
294                    }
295                }
296                Err(e) => {
297                    self.set_err();
298                    error!("connect failed to {}: {}", self.addr, e);
299                    self.cleanup();
300                    RT::sleep(ONE_SEC).await;
301                }
302            }
303        }
304    }
305
306    fn cleanup(&self) {
307        while let Ok(mut task) = self.rx.try_recv() {
308            task.set_rpc_error(RpcIntErr::Unreachable);
309            logger_trace!(self.logger, "{} set task err due not not healthy", self);
310            self.facts.error_handle(task);
311        }
312    }
313}