razor_stream/client/pool.rs

use crate::client::stream::ClientStream;
use crate::client::{
    ClientCaller, ClientCallerBlocking, ClientFacts, ClientTransport, task::ClientTaskDone,
};
use crate::error::RpcIntErr;
use captains_log::filter::LogFilter;
use crossfire::{MAsyncRx, MAsyncTx, MTx, RecvTimeoutError, mpmc};
use orb::prelude::*;
use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;
use std::sync::atomic::{
    AtomicBool, AtomicUsize,
    Ordering::{Acquire, Relaxed, Release, SeqCst},
};
use std::time::Duration;

/// Connection pool to a single server address (usable from both async and blocking contexts)
///
/// A background worker accepts tasks posted through a bounded channel.
///
/// Even when the server address is unreachable, the worker coroutine does not exit
/// until the ClientPool is dropped.
///
/// The background coroutine will:
/// - monitor the address with a ping task (action 0)
/// - clean up queued tasks via error_handle while the address is unhealthy
///
/// If the connection is healthy and tasks keep arriving, the worker spawns another
/// coroutine dedicated to monitoring.
///
/// Design considerations:
/// - Incoming tasks may never stop until the faulty pool is removed from the pools
///   collection.
/// - If pings were mixed into a stream carrying real business tasks, they could be
///   blocked by the stream's in-flight message throttler.
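///
/// # Example
///
/// A minimal usage sketch; `MyFacts: ClientFacts` and `MyTransport: ClientTransport`
/// are hypothetical implementations, and `task` is an instance of `MyFacts::Task`:
///
/// ```ignore
/// let facts = Arc::new(MyFacts::default()); // hypothetical ClientFacts impl
/// // One pool per server address; the last argument bounds the task channel.
/// let pool: ClientPool<MyFacts, MyTransport> = ClientPool::new(facts, "127.0.0.1:9000", 128);
/// // Submit a task from an async context:
/// pool.send_req(task).await;
/// // Or from a blocking (non-async) context:
/// // pool.send_req_blocking(task);
/// ```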
pub struct ClientPool<F: ClientFacts, P: ClientTransport> {
    tx_async: MAsyncTx<mpmc::Array<F::Task>>,
    tx: MTx<mpmc::Array<F::Task>>,
    inner: Arc<ClientPoolInner<F, P>>,
}

impl<F: ClientFacts, P: ClientTransport> Clone for ClientPool<F, P> {
    fn clone(&self) -> Self {
        Self { tx_async: self.tx_async.clone(), tx: self.tx.clone(), inner: self.inner.clone() }
    }
}

struct ClientPoolInner<F: ClientFacts, P: ClientTransport> {
    facts: Arc<F>,
    logger: Arc<LogFilter>,
    rx: MAsyncRx<mpmc::Array<F::Task>>,
    addr: String,
    conn_id: String,
    /// Whether the connection is healthy
    is_ok: AtomicBool,
    /// Dynamic worker count (not the monitor)
    worker_count: AtomicUsize,
    /// Dynamic count of workers holding a live connection (not the monitor)
    connected_worker_count: AtomicUsize,
    ///// Set by user
    //limit: AtomicUsize, // TODO
    _phan: PhantomData<fn(&P)>,
}

const ONE_SEC: Duration = Duration::from_secs(1);

impl<F: ClientFacts, P: ClientTransport> ClientPool<F, P> {
    pub fn new(facts: Arc<F>, addr: &str, mut channel_size: usize) -> Self {
        let config = facts.get_config();
        // The channel must hold at least `thresholds` tasks; when no threshold
        // is configured, fall back to a default capacity of 128.
        if config.thresholds > 0 {
            if channel_size < config.thresholds {
                channel_size = config.thresholds;
            }
        } else if channel_size == 0 {
            channel_size = 128;
        }
        let (tx_async, rx) = mpmc::bounded_async(channel_size);
        let tx = tx_async.clone().into();
        let conn_id = format!("to {}", addr);
        let inner = Arc::new(ClientPoolInner {
            logger: facts.new_logger(),
            facts: facts.clone(),
            rx,
            addr: addr.to_string(),
            conn_id,
            is_ok: AtomicBool::new(true),
            worker_count: AtomicUsize::new(0),
            connected_worker_count: AtomicUsize::new(0),
            _phan: Default::default(),
        });
        let s = Self { tx_async, tx, inner };
        s.spawn();
        s
    }

    #[inline(always)]
    pub fn is_healthy(&self) -> bool {
        self.inner.is_ok.load(Relaxed)
    }

    #[inline]
    pub fn get_addr(&self) -> &str {
        &self.inner.addr
    }

    #[inline]
    pub async fn send_req(&self, task: F::Task) {
        ClientCaller::send_req(self, task).await;
    }

    #[inline]
    pub fn send_req_blocking(&self, task: F::Task) {
        ClientCallerBlocking::send_req_blocking(self, task);
    }

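    /// Spawn an additional worker coroutine; the previous value of `worker_count`
    /// becomes the new worker's id, so the first worker spawned is the monitor
    /// (worker_id 0).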
    #[inline]
    pub fn spawn(&self) {
        let worker_id = self.inner.worker_count.fetch_add(1, Acquire);
        self.inner.clone().spawn_worker(worker_id);
    }
}

impl<F: ClientFacts, P: ClientTransport> Drop for ClientPoolInner<F, P> {
    fn drop(&mut self) {
        self.cleanup();
        logger_trace!(self.logger, "{} dropped", self);
    }
}

impl<F: ClientFacts, P: ClientTransport> ClientCaller for ClientPool<F, P> {
    type Facts = F;
    #[inline]
    async fn send_req(&self, task: F::Task) {
        self.tx_async.send(task).await.expect("submit");
    }
}

impl<F: ClientFacts, P: ClientTransport> ClientCallerBlocking for ClientPool<F, P> {
    type Facts = F;
    #[inline]
    fn send_req_blocking(&self, task: F::Task) {
        self.tx.send(task).expect("submit");
    }
}

impl<F: ClientFacts, P: ClientTransport> fmt::Display for ClientPoolInner<F, P> {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "ConnPool {}", self.conn_id)
    }
}

impl<F: ClientFacts, P: ClientTransport> ClientPoolInner<F, P> {
    fn spawn_worker(self: Arc<Self>, worker_id: usize) {
        let facts = self.facts.clone();
        facts.spawn_detach(async move {
            logger_trace!(&self.logger, "{} worker_id={} running", self, worker_id);
            self.run(worker_id).await;
            self.worker_count.fetch_sub(1, SeqCst);
            logger_trace!(&self.logger, "{} worker_id={} exit", self, worker_id);
        });
    }

    #[inline(always)]
    fn get_workers(&self) -> usize {
        self.worker_count.load(SeqCst)
    }

    #[inline(always)]
    fn get_healthy_workers(&self) -> usize {
        self.connected_worker_count.load(SeqCst)
    }

    #[inline(always)]
    fn set_err(&self) {
        self.is_ok.store(false, SeqCst);
    }

    #[inline]
    async fn connect(&self) -> Result<ClientStream<F, P>, RpcIntErr> {
        ClientStream::connect(self.facts.clone(), &self.addr, &self.conn_id, None).await
    }

    #[inline(always)]
    async fn _run_worker(
        &self, _worker_id: usize, stream: &mut ClientStream<F, P>,
    ) -> Result<(), RpcIntErr> {
        loop {
            match self.rx.recv().await {
                Ok(task) => {
                    stream.send_task(task, false).await?;
                    // Drain whatever else is already queued, then flush once.
                    while let Ok(task) = self.rx.try_recv() {
                        stream.send_task(task, false).await?;
                    }
                    stream.flush_req().await?;
                }
                Err(_) => {
                    // All senders dropped: flush pending requests and exit.
                    stream.flush_req().await?;
                    return Ok(());
                }
            }
        }
    }

    async fn run_worker(
        &self, worker_id: usize, stream: &mut ClientStream<F, P>,
    ) -> Result<(), RpcIntErr> {
        self.connected_worker_count.fetch_add(1, Acquire);
        let r = self._run_worker(worker_id, stream).await;
        logger_trace!(self.logger, "{} worker {} exit: {}", self, worker_id, r.is_ok());
        self.connected_worker_count.fetch_sub(1, Release);
        r
    }

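    /// Main loop of a pool worker. Worker 0 acts as the monitor: while other
    /// workers exist it pings the server once per second; when it is the only
    /// worker it also consumes tasks. If a task arrives while requests are still
    /// in flight, the monitor promotes itself to a regular worker (worker_id = 1)
    /// and spawns a replacement monitor as worker 0, so pings never get stuck
    /// behind business traffic.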
    async fn run(self: &Arc<Self>, mut worker_id: usize) {
        'CONN_LOOP: loop {
            match self.connect().await {
                Ok(mut stream) => {
                    logger_trace!(self.logger, "{} worker={} connected", self, worker_id);
                    if worker_id == 0 {
                        // act as the monitor
                        'MONITOR: loop {
                            if self.get_workers() > 1 {
                                F::sleep(ONE_SEC).await;
                                if stream.ping().await.is_err() {
                                    self.set_err();
                                    // don't clean up the channel unless only one worker is left
                                    continue 'CONN_LOOP;
                                }
                            } else {
                                match self.rx.recv_with_timer(F::sleep(ONE_SEC)).await {
                                    Err(RecvTimeoutError::Disconnected) => {
                                        return;
                                    }
                                    Err(RecvTimeoutError::Timeout) => {
                                        if stream.ping().await.is_err() {
                                            self.set_err();
                                            self.cleanup();
                                            continue 'CONN_LOOP;
                                        }
                                    }
                                    Ok(task) => {
                                        if stream.get_inflight_count() > 0
                                            && self.get_workers() == 1
                                        {
                                            if self
                                                .worker_count
                                                .compare_exchange(1, 2, SeqCst, Relaxed)
                                                .is_ok()
                                            {
                                                // There may be a lag before the new connection
                                                // is up, so take over as worker 1 and spawn a
                                                // fresh monitor as worker 0.
                                                worker_id = 1;
                                                self.clone().spawn_worker(0);
                                            }
                                        }
                                        if stream.send_task(task, true).await.is_err() {
                                            self.set_err();
                                            if worker_id == 0 {
                                                self.cleanup();
                                                F::sleep(ONE_SEC).await;
                                                continue 'CONN_LOOP;
                                            } else {
                                                return;
                                            }
                                        } else if worker_id > 0 {
                                            logger_trace!(
                                                self.logger,
                                                "{} worker={} break monitor",
                                                self,
                                                worker_id
                                            );
                                            // taken over as run_worker
                                            break 'MONITOR;
                                        }
                                    }
                                }
                            }
                        }
                    }
                    if worker_id > 0 {
                        if self.run_worker(worker_id, &mut stream).await.is_err() {
                            self.set_err();
                            // don't clean up the channel unless only one worker is left
                        }
                        // TODO: the worker should exit automatically once idle_time has passed
                        return;
                    }
                }
                Err(e) => {
                    self.set_err();
                    error!("connect failed to {}: {}", self.addr, e);
                    self.cleanup();
                    F::sleep(ONE_SEC).await;
                }
            }
        }
    }

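    /// Drain every task still queued in the channel, mark it with
    /// `RpcIntErr::Unreachable`, and hand it to `facts.error_handle`.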
    fn cleanup(&self) {
        while let Ok(mut task) = self.rx.try_recv() {
            task.set_rpc_error(RpcIntErr::Unreachable);
            logger_trace!(self.logger, "{} set task err due to unhealthy connection", self);
            self.facts.error_handle(task);
        }
    }
}
307}