use crate::client::stream::ClientStream;
use crate::client::{
    ClientCaller, ClientCallerBlocking, ClientFacts, ClientTransport, task::ClientTaskDone,
};
use crate::error::RpcIntErr;
use captains_log::filter::LogFilter;
use crossfire::{MAsyncRx, MAsyncTx, MTx, RecvTimeoutError, mpmc};
use orb::AsyncRuntime;
use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;
use std::sync::atomic::{
    AtomicBool, AtomicUsize,
    Ordering::{Acquire, Relaxed, Release, SeqCst},
};
use std::time::Duration;

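/// A pool of client connections to one address, fed by a shared bounded
/// MPMC task queue. Clones are cheap: they share the same queue and the
/// same set of workers.
///
/// A minimal usage sketch; `MyFacts`, `TcpTransport`, `rt`, and `task` are
/// placeholders for the caller's own `ClientFacts` / `ClientTransport`
/// implementations, an `AsyncRuntime` handle, and a request task:
///
/// ```ignore
/// let facts = Arc::new(MyFacts::default());
/// let pool: ClientPool<MyFacts, TcpTransport> =
///     ClientPool::new(facts, &rt, "127.0.0.1:9000", 256);
/// // Completion is delivered through the task itself (see ClientTaskDone).
/// pool.send_req(task).await;
/// ```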
pub struct ClientPool<F: ClientFacts, P: ClientTransport> {
    /// Async sender side of the shared task queue.
    tx_async: MAsyncTx<mpmc::Array<F::Task>>,
    /// Blocking sender side of the same queue.
    tx: MTx<mpmc::Array<F::Task>>,
    inner: Arc<ClientPoolInner<F, P>>,
}

impl<F: ClientFacts, P: ClientTransport> Clone for ClientPool<F, P> {
    fn clone(&self) -> Self {
        Self { tx_async: self.tx_async.clone(), tx: self.tx.clone(), inner: self.inner.clone() }
    }
}

struct ClientPoolInner<F: ClientFacts, P: ClientTransport> {
    facts: Arc<F>,
    logger: Arc<LogFilter>,
    /// Receiver side of the shared task queue, consumed by all workers.
    rx: MAsyncRx<mpmc::Array<F::Task>>,
    addr: String,
    conn_id: String,
    /// Health flag; cleared whenever a connect, send, or ping fails.
    is_ok: AtomicBool,
    /// Number of live workers, including the monitor (worker 0).
    worker_count: AtomicUsize,
    /// Number of workers currently holding a connected stream.
    connected_worker_count: AtomicUsize,
    /// Ties the pool to its transport type without storing one.
    _phan: PhantomData<fn(&P)>,
}

const ONE_SEC: Duration = Duration::from_secs(1);

impl<F: ClientFacts, P: ClientTransport> ClientPool<F, P> {
    pub fn new<RT: AsyncRuntime + Clone>(
        facts: Arc<F>, rt: &RT, addr: &str, mut channel_size: usize,
    ) -> Self {
        let config = facts.get_config();
        // Queue capacity must cover config.thresholds; fall back to 128
        // when neither the config nor the caller specifies a size.
        if config.thresholds > 0 {
            if channel_size < config.thresholds {
                channel_size = config.thresholds;
            }
        } else if channel_size == 0 {
            channel_size = 128;
        }
        let (tx_async, rx) = mpmc::bounded_async(channel_size);
        let tx = tx_async.clone().into();
        let conn_id = format!("to {}", addr);
        let inner = Arc::new(ClientPoolInner {
            logger: facts.new_logger(),
            facts: facts.clone(),
            rx,
            addr: addr.to_string(),
            conn_id,
            is_ok: AtomicBool::new(true),
            worker_count: AtomicUsize::new(0),
            connected_worker_count: AtomicUsize::new(0),
            _phan: Default::default(),
        });
        let s = Self { tx_async, tx, inner };
        // Start worker 0, which doubles as the connection monitor.
        s.spawn::<RT>(rt);
        s
    }

    #[inline(always)]
    pub fn is_healthy(&self) -> bool {
        self.inner.is_ok.load(Relaxed)
    }

    #[inline]
    pub fn get_addr(&self) -> &str {
        &self.inner.addr
    }

    #[inline]
    pub async fn send_req(&self, task: F::Task) {
        ClientCaller::send_req(self, task).await;
    }

    #[inline]
    pub fn send_req_blocking(&self, task: F::Task) {
        ClientCallerBlocking::send_req_blocking(self, task);
    }

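    /// Adds one more worker (with its own connection) to the pool. `new`
    /// calls this once for worker 0; call it again to scale out when a
    /// single connection cannot keep up.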
    #[inline]
    pub fn spawn<RT: AsyncRuntime + Clone>(&self, rt: &RT) {
        let worker_id = self.inner.worker_count.fetch_add(1, Acquire);
        self.inner.clone().spawn_worker(rt, worker_id);
    }
}

impl<F: ClientFacts, P: ClientTransport> Drop for ClientPoolInner<F, P> {
    fn drop(&mut self) {
        self.cleanup();
        logger_trace!(self.logger, "{} dropped", self);
    }
}

impl<F: ClientFacts, P: ClientTransport> ClientCaller for ClientPool<F, P> {
    type Facts = F;
    #[inline]
    async fn send_req(&self, task: F::Task) {
        self.tx_async.send(task).await.expect("submit");
    }
}

impl<F: ClientFacts, P: ClientTransport> ClientCallerBlocking for ClientPool<F, P> {
    type Facts = F;
    #[inline]
    fn send_req_blocking(&self, task: F::Task) {
        self.tx.send(task).expect("submit");
    }
}

impl<F: ClientFacts, P: ClientTransport> fmt::Display for ClientPoolInner<F, P> {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "ConnPool {}", self.conn_id)
    }
}

impl<F: ClientFacts, P: ClientTransport> ClientPoolInner<F, P> {
    fn spawn_worker<RT: AsyncRuntime + Clone>(self: Arc<Self>, rt: &RT, worker_id: usize) {
        let _rt = rt.clone();
        rt.spawn_detach(async move {
            logger_trace!(&self.logger, "{} worker_id={} running", self, worker_id);
            self.run(_rt, worker_id).await;
            self.worker_count.fetch_sub(1, SeqCst);
            logger_trace!(&self.logger, "{} worker_id={} exit", self, worker_id);
        });
    }

    #[inline(always)]
    fn get_workers(&self) -> usize {
        self.worker_count.load(SeqCst)
    }

    /// Mark the pool unhealthy so `is_healthy()` reports false.
    #[inline(always)]
    fn set_err(&self) {
        self.is_ok.store(false, SeqCst);
    }

    #[inline]
    async fn connect<RT: AsyncRuntime>(&self, rt: &RT) -> Result<ClientStream<F, P>, RpcIntErr> {
        ClientStream::connect(self.facts.clone(), rt, &self.addr, &self.conn_id, None).await
    }

    #[inline(always)]
    async fn _run_worker(
        &self, _worker_id: usize, stream: &mut ClientStream<F, P>,
    ) -> Result<(), RpcIntErr> {
        loop {
            match self.rx.recv().await {
                Ok(task) => {
                    // Batch: send the received task, drain whatever else is
                    // already queued, then flush once.
                    stream.send_task(task, false).await?;
                    while let Ok(task) = self.rx.try_recv() {
                        stream.send_task(task, false).await?;
                    }
                    stream.flush_req().await?;
                }
                Err(_) => {
                    // All senders are gone: flush what is pending and exit.
                    stream.flush_req().await?;
                    return Ok(());
                }
            }
        }
    }

    async fn run_worker(
        &self, worker_id: usize, stream: &mut ClientStream<F, P>,
    ) -> Result<(), RpcIntErr> {
        self.connected_worker_count.fetch_add(1, Acquire);
        let r = self._run_worker(worker_id, stream).await;
        logger_trace!(self.logger, "{} worker {} exit ok={}", self, worker_id, r.is_ok());
        // Decrement to balance the fetch_add above; this worker no longer
        // holds a connected stream.
        self.connected_worker_count.fetch_sub(1, Release);
        r
    }

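    /// Connection loop for a single worker. Worker 0 acts as the monitor:
    /// while other workers exist it only pings once per second to keep the
    /// connection warm and detect failures; while it is the last worker it
    /// also serves traffic. When a task arrives with requests still in
    /// flight, it promotes itself to a regular worker (id 1) and spawns a
    /// replacement monitor, so the pool grows under load. Workers with a
    /// non-zero id go straight to `run_worker`. On errors the pool is
    /// marked unhealthy; only worker 0 reconnects and retries.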
    async fn run<RT: AsyncRuntime + Clone>(self: &Arc<Self>, rt: RT, mut worker_id: usize) {
        'CONN_LOOP: loop {
            match self.connect::<RT>(&rt).await {
                Ok(mut stream) => {
                    logger_trace!(self.logger, "{} worker={} connected", self, worker_id);
                    if worker_id == 0 {
                        'MONITOR: loop {
                            if self.get_workers() > 1 {
                                // Other workers serve traffic; just keep the
                                // connection alive and watch for failures.
                                RT::sleep(ONE_SEC).await;
                                if stream.ping().await.is_err() {
                                    self.set_err();
                                    continue 'CONN_LOOP;
                                }
                            } else {
                                match self.rx.recv_with_timer(RT::sleep(ONE_SEC)).await {
                                    Err(RecvTimeoutError::Disconnected) => {
                                        return;
                                    }
                                    Err(RecvTimeoutError::Timeout) => {
                                        if stream.ping().await.is_err() {
                                            self.set_err();
                                            self.cleanup();
                                            continue 'CONN_LOOP;
                                        }
                                    }
                                    Ok(task) => {
                                        if stream.get_inflight_count() > 0
                                            && self.get_workers() == 1
                                            && self
                                                .worker_count
                                                .compare_exchange(1, 2, SeqCst, Relaxed)
                                                .is_ok()
                                        {
                                            // Load is building up: become a
                                            // regular worker and hand the
                                            // monitor role to a new worker 0.
                                            worker_id = 1;
                                            self.clone().spawn_worker::<RT>(&rt, 0);
                                        }
                                        if stream.send_task(task, true).await.is_err() {
                                            self.set_err();
                                            if worker_id == 0 {
                                                self.cleanup();
                                                RT::sleep(ONE_SEC).await;
                                                continue 'CONN_LOOP;
                                            } else {
                                                return;
                                            }
                                        } else if worker_id > 0 {
                                            logger_trace!(
                                                self.logger,
                                                "{} worker={} break monitor",
                                                self,
                                                worker_id
                                            );
                                            break 'MONITOR;
                                        }
                                    }
                                }
                            }
                        }
                    }
                    if worker_id > 0 {
                        if self.run_worker(worker_id, &mut stream).await.is_err() {
                            self.set_err();
                        }
                        // Plain workers never reconnect; that is worker 0's job.
                        return;
                    }
                }
                Err(e) => {
                    self.set_err();
                    error!("connect failed to {}: {}", self.addr, e);
                    self.cleanup();
                    RT::sleep(ONE_SEC).await;
                }
            }
        }
    }

    /// Fail everything still queued: drain the channel and complete each
    /// task with `RpcIntErr::Unreachable` via the facts' error handler.
    fn cleanup(&self) {
        while let Ok(mut task) = self.rx.try_recv() {
            task.set_rpc_error(RpcIntErr::Unreachable);
            logger_trace!(self.logger, "{} set task err because pool is not healthy", self);
            self.facts.error_handle(task);
        }
    }
}
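
// Callers that balance across several pools typically consult
// `is_healthy()` before submitting. A sketch, with `pools` and `task` as
// placeholders for the caller's own values:
//
// ```ignore
// if let Some(pool) = pools.iter().find(|p| p.is_healthy()) {
//     pool.send_req(task).await;
// }
// ```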