1use crate::client::stream::ClientStream;
2use crate::client::{
3 ClientCaller, ClientCallerBlocking, ClientFacts, ClientTransport, task::ClientTaskDone,
4};
5use crate::error::RpcIntErr;
6use captains_log::filter::LogFilter;
7use crossfire::{MAsyncRx, MAsyncTx, MTx, RecvTimeoutError, mpmc};
8use orb::prelude::{AsyncExec, AsyncTime};
9use std::fmt;
10use std::marker::PhantomData;
11use std::sync::Arc;
12use std::sync::atomic::{
13 AtomicBool, AtomicUsize,
14 Ordering::{Acquire, Relaxed, Release, SeqCst},
15};
16use std::time::Duration;
17
18pub struct ConnPool<F: ClientFacts, P: ClientTransport> {
37 tx_async: MAsyncTx<mpmc::Array<F::Task>>,
38 tx: MTx<mpmc::Array<F::Task>>,
39 inner: Arc<ConnPoolInner<F, P>>,
40}
41
42impl<F: ClientFacts, P: ClientTransport> Clone for ConnPool<F, P> {
43 fn clone(&self) -> Self {
44 Self { tx_async: self.tx_async.clone(), tx: self.tx.clone(), inner: self.inner.clone() }
45 }
46}
47
48struct ConnPoolInner<F: ClientFacts, P: ClientTransport> {
49 facts: Arc<F>,
50 logger: Arc<LogFilter>,
51 rx: MAsyncRx<mpmc::Array<F::Task>>,
52 addr: String,
53 conn_id: String,
54 is_ok: AtomicBool,
56 worker_count: AtomicUsize,
58 connected_worker_count: AtomicUsize,
60 _phan: PhantomData<fn(&P)>,
63}
64
/// Retry / ping / idle-wait interval used by the worker loops.
const ONE_SEC: Duration = Duration::from_millis(1000);
66
67impl<F: ClientFacts, P: ClientTransport> ConnPool<F, P> {
68 pub fn new(facts: Arc<F>, rt: &P::RT, addr: &str, mut channel_size: usize) -> Self {
69 let config = facts.get_config();
70 if config.thresholds > 0 {
71 if channel_size < config.thresholds {
72 channel_size = config.thresholds;
73 }
74 } else if channel_size == 0 {
75 channel_size = 128;
76 }
77 let (tx_async, rx) = mpmc::bounded_async(channel_size);
78 let tx = tx_async.clone().into();
79 let conn_id = format!("to {}", addr);
80 let inner = Arc::new(ConnPoolInner {
81 logger: facts.new_logger(),
82 facts: facts.clone(),
83 rx,
84 addr: addr.to_string(),
85 conn_id,
86 is_ok: AtomicBool::new(true),
87 worker_count: AtomicUsize::new(0),
88 connected_worker_count: AtomicUsize::new(0),
89 _phan: Default::default(),
90 });
91 let s = Self { tx_async, tx, inner };
92 s.spawn(rt);
93 s
94 }
95
96 #[inline(always)]
97 pub fn is_healthy(&self) -> bool {
98 self.inner.is_ok.load(Relaxed)
99 }
100
101 #[inline]
102 pub fn get_addr(&self) -> &str {
103 &self.inner.addr
104 }
105
106 #[inline]
107 pub async fn send_req(&self, task: F::Task) {
108 ClientCaller::send_req(self, task).await;
109 }
110
111 #[inline]
112 pub fn send_req_blocking(&self, task: F::Task) {
113 ClientCallerBlocking::send_req_blocking(self, task);
114 }
115
116 #[inline]
119 pub fn spawn(&self, rt: &P::RT) {
120 let worker_id = self.inner.worker_count.fetch_add(1, Acquire);
121 self.inner.clone().spawn_worker(rt, worker_id);
122 }
123}
124
125impl<F: ClientFacts, P: ClientTransport> Drop for ConnPoolInner<F, P> {
126 fn drop(&mut self) {
127 self.cleanup();
128 logger_trace!(self.logger, "{} dropped", self);
129 }
130}
131
132impl<F: ClientFacts, P: ClientTransport> ClientCaller for ConnPool<F, P> {
133 type Facts = F;
134 #[inline]
135 async fn send_req(&self, task: F::Task) {
136 self.tx_async.send(task).await.expect("submit");
137 }
138}
139
140impl<F: ClientFacts, P: ClientTransport> ClientCallerBlocking for ConnPool<F, P> {
141 type Facts = F;
142 #[inline]
143 fn send_req_blocking(&self, task: F::Task) {
144 self.tx.send(task).expect("submit");
145 }
146}
147
148impl<F: ClientFacts, P: ClientTransport> fmt::Display for ConnPoolInner<F, P> {
149 #[inline]
150 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
151 write!(f, "ConnPool {}", self.conn_id)
152 }
153}
154
155impl<F: ClientFacts, P: ClientTransport> ConnPoolInner<F, P> {
156 fn spawn_worker(self: Arc<Self>, rt: &P::RT, worker_id: usize) {
157 let _rt = rt.clone();
158 rt.spawn_detach(async move {
159 logger_trace!(&self.logger, "{} worker_id={} running", self, worker_id);
160 self.run(_rt, worker_id).await;
161 self.worker_count.fetch_sub(1, SeqCst);
162 logger_trace!(&self.logger, "{} worker_id={} exit", self, worker_id);
163 });
164 }
165
166 #[inline(always)]
167 fn get_workers(&self) -> usize {
168 self.worker_count.load(SeqCst)
169 }
170
171 #[inline(always)]
177 fn set_err(&self) {
178 self.is_ok.store(false, SeqCst);
179 }
180
181 #[inline]
182 async fn connect(&self, rt: &P::RT) -> Result<ClientStream<F, P>, RpcIntErr> {
183 ClientStream::connect(self.facts.clone(), rt, &self.addr, &self.conn_id, None).await
184 }
185
186 #[inline(always)]
187 async fn _run_worker(
188 &self, _worker_id: usize, stream: &mut ClientStream<F, P>,
189 ) -> Result<(), RpcIntErr> {
190 loop {
191 match self.rx.recv().await {
192 Ok(task) => {
193 stream.send_task(task, false).await?;
194 while let Ok(task) = self.rx.try_recv() {
195 stream.send_task(task, false).await?;
196 }
197 stream.flush_req().await?;
198 }
199 Err(_) => {
200 stream.flush_req().await?;
201 return Ok(());
202 }
203 }
204 }
205 }
206
207 async fn run_worker(
208 &self, worker_id: usize, stream: &mut ClientStream<F, P>,
209 ) -> Result<(), RpcIntErr> {
210 self.connected_worker_count.fetch_add(1, Acquire);
211 let r = self._run_worker(worker_id, stream).await;
212 logger_trace!(self.logger, "{} worker {} exit: {}", self, worker_id, r.is_ok());
213 self.connected_worker_count.fetch_add(1, Release);
214 r
215 }
216
217 async fn run(self: &Arc<Self>, rt: P::RT, mut worker_id: usize) {
222 'CONN_LOOP: loop {
223 match self.connect(&rt).await {
224 Ok(mut stream) => {
225 logger_trace!(self.logger, "{} worker={} connected", self, worker_id);
226 if worker_id == 0 {
227 'MONITOR: loop {
229 if self.get_workers() > 1 {
230 <P::RT as AsyncTime>::sleep(ONE_SEC).await;
231 if stream.ping().await.is_err() {
232 self.set_err();
233 continue 'CONN_LOOP;
235 }
236 } else {
237 match self
238 .rx
239 .recv_with_timer(<P::RT as AsyncTime>::sleep(ONE_SEC))
240 .await
241 {
242 Err(RecvTimeoutError::Disconnected) => {
243 return;
244 }
245 Err(RecvTimeoutError::Timeout) => {
246 if stream.ping().await.is_err() {
247 self.set_err();
248 self.cleanup();
249 continue 'CONN_LOOP;
250 }
251 }
252 Ok(task) => {
253 if stream.get_inflight_count() > 0
254 && self.get_workers() == 1
255 && self
256 .worker_count
257 .compare_exchange(1, 2, SeqCst, Relaxed)
258 .is_ok()
259 {
260 worker_id = 1;
263 self.clone().spawn_worker(&rt, 0);
264 }
265 if stream.send_task(task, true).await.is_err() {
266 self.set_err();
267 if worker_id == 0 {
268 self.cleanup();
269 <P::RT as AsyncTime>::sleep(ONE_SEC).await;
270 continue 'CONN_LOOP;
271 } else {
272 return;
273 }
274 } else if worker_id > 0 {
275 logger_trace!(
276 self.logger,
277 "{} worker={} break monitor",
278 self,
279 worker_id
280 );
281 break 'MONITOR;
283 }
284 }
285 }
286 }
287 }
288 }
289 if worker_id > 0 {
290 if self.run_worker(worker_id, &mut stream).await.is_err() {
291 self.set_err();
292 }
294 return;
296 }
297 }
298 Err(e) => {
299 self.set_err();
300 error!("connect failed to {}: {}", self.addr, e);
301 self.cleanup();
302 <P::RT as AsyncTime>::sleep(ONE_SEC).await;
303 }
304 }
305 }
306 }
307
308 fn cleanup(&self) {
309 while let Ok(mut task) = self.rx.try_recv() {
310 task.set_rpc_error(RpcIntErr::Unreachable);
311 logger_trace!(self.logger, "{} set task err due not not healthy", self);
312 self.facts.error_handle(task);
313 }
314 }
315}