1use crate::client::stream::ClientStream;
2use crate::client::{
3 ClientCaller, ClientCallerBlocking, ClientFacts, ClientTransport, task::ClientTaskDone,
4};
5use crate::error::RpcIntErr;
6use captains_log::filter::LogFilter;
7use crossfire::{MAsyncRx, MAsyncTx, MTx, RecvTimeoutError, mpmc};
8use orb::prelude::{AsyncExec, AsyncRuntime, AsyncTime};
9use std::fmt;
10use std::marker::PhantomData;
11use std::sync::Arc;
12use std::sync::atomic::{
13 AtomicBool, AtomicUsize,
14 Ordering::{Acquire, Relaxed, Release, SeqCst},
15};
16use std::time::Duration;
17
/// Cloneable handle to a set of connection workers that all talk to one address.
///
/// Tasks submitted through the handle are pushed into a bounded MPMC channel;
/// worker tasks sharing `ConnPoolInner` receive them from the other end.
pub struct ConnPool<F: ClientFacts, P: ClientTransport> {
    /// Async sender side of the task channel (used by `ClientCaller::send_req`).
    tx_async: MAsyncTx<mpmc::Array<F::Task>>,
    /// Blocking sender for the same channel, converted from a clone of `tx_async`
    /// (used by `ClientCallerBlocking::send_req_blocking`).
    tx: MTx<mpmc::Array<F::Task>>,
    /// State shared with the spawned worker tasks.
    inner: Arc<ConnPoolInner<F, P>>,
}
41
42impl<F: ClientFacts, P: ClientTransport> Clone for ConnPool<F, P> {
43 fn clone(&self) -> Self {
44 Self { tx_async: self.tx_async.clone(), tx: self.tx.clone(), inner: self.inner.clone() }
45 }
46}
47
/// State shared between every `ConnPool` handle and the worker tasks.
struct ConnPoolInner<F: ClientFacts, P: ClientTransport> {
    /// User-supplied facts: provides the config, the logger and the error handler.
    facts: Arc<F>,
    /// Logger obtained from `facts.new_logger()` when the pool is created.
    logger: Arc<LogFilter>,
    /// Receiving side of the task channel, polled by the workers.
    rx: MAsyncRx<mpmc::Array<F::Task>>,
    /// Address every worker connects to.
    addr: String,
    /// Identifier used in log output and passed to `ClientStream::connect`
    /// (built as `"to {addr}"`).
    conn_id: String,
    /// Health flag: starts `true` and is set to `false` by `set_err`.
    /// Nothing in this section sets it back to `true`.
    is_ok: AtomicBool,
    /// Number of worker tasks accounted for: incremented when a worker is
    /// spawned, decremented when its task finishes.
    worker_count: AtomicUsize,
    /// Counter touched by `run_worker` when a worker enters and leaves its
    /// send loop; it is never read in this section.
    connected_worker_count: AtomicUsize,
    /// Marker tying the type to transport `P` without storing a value of it.
    _phan: PhantomData<fn(&P)>,
}
64
/// One-second interval used for the idle receive timeout, the ping period and
/// the back-off before reconnecting.
const ONE_SEC: Duration = Duration::from_secs(1);
66
67impl<F: ClientFacts, P: ClientTransport> ConnPool<F, P> {
68 pub fn new(
73 facts: Arc<F>, rt: Option<&<P::RT as AsyncRuntime>::Exec>, addr: &str,
74 mut channel_size: usize,
75 ) -> Self {
76 let config = facts.get_config();
77 if config.thresholds > 0 {
78 if channel_size < config.thresholds {
79 channel_size = config.thresholds;
80 }
81 } else if channel_size == 0 {
82 channel_size = 128;
83 }
84 let (tx_async, rx) = mpmc::bounded_async(channel_size);
85 let tx = tx_async.clone().into();
86 let conn_id = format!("to {}", addr);
87 let inner = Arc::new(ConnPoolInner {
88 logger: facts.new_logger(),
89 facts: facts.clone(),
90 rx,
91 addr: addr.to_string(),
92 conn_id,
93 is_ok: AtomicBool::new(true),
94 worker_count: AtomicUsize::new(0),
95 connected_worker_count: AtomicUsize::new(0),
96 _phan: Default::default(),
97 });
98 let s = Self { tx_async, tx, inner };
99 s.spawn(rt);
100 s
101 }
102
103 #[inline(always)]
104 pub fn is_healthy(&self) -> bool {
105 self.inner.is_ok.load(Relaxed)
106 }
107
108 #[inline]
109 pub fn get_addr(&self) -> &str {
110 &self.inner.addr
111 }
112
113 #[inline]
114 pub async fn send_req(&self, task: F::Task) {
115 ClientCaller::send_req(self, task).await;
116 }
117
118 #[inline]
119 pub fn send_req_blocking(&self, task: F::Task) {
120 ClientCallerBlocking::send_req_blocking(self, task);
121 }
122
123 #[inline]
126 pub fn spawn(&self, rt: Option<&<P::RT as AsyncRuntime>::Exec>) {
127 let worker_id = self.inner.worker_count.fetch_add(1, Acquire);
128 self.inner.clone().spawn_worker(rt, worker_id);
129 }
130}
131
132impl<F: ClientFacts, P: ClientTransport> Drop for ConnPoolInner<F, P> {
133 fn drop(&mut self) {
134 self.cleanup();
135 logger_trace!(self.logger, "{} dropped", self);
136 }
137}
138
139impl<F: ClientFacts, P: ClientTransport> ClientCaller for ConnPool<F, P> {
140 type Facts = F;
141 #[inline]
142 async fn send_req(&self, task: F::Task) {
143 self.tx_async.send(task).await.expect("submit");
144 }
145}
146
147impl<F: ClientFacts, P: ClientTransport> ClientCallerBlocking for ConnPool<F, P> {
148 type Facts = F;
149 #[inline]
150 fn send_req_blocking(&self, task: F::Task) {
151 self.tx.send(task).expect("submit");
152 }
153}
154
155impl<F: ClientFacts, P: ClientTransport> fmt::Display for ConnPoolInner<F, P> {
156 #[inline]
157 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
158 write!(f, "ConnPool {}", self.conn_id)
159 }
160}
161
162impl<F: ClientFacts, P: ClientTransport> ConnPoolInner<F, P> {
163 #[inline]
164 fn spawn_worker(self: Arc<Self>, rt: Option<&<P::RT as AsyncRuntime>::Exec>, worker_id: usize) {
165 let f = async move {
166 logger_trace!(&self.logger, "{} worker_id={} running", self, worker_id);
167 self.run(worker_id).await;
168 self.worker_count.fetch_sub(1, SeqCst);
169 logger_trace!(&self.logger, "{} worker_id={} exit", self, worker_id);
170 };
171 if let Some(_rt) = rt {
172 _rt.spawn_detach(f);
173 } else {
174 P::RT::spawn_detach(f);
175 }
176 }
177
178 #[inline(always)]
179 fn get_workers(&self) -> usize {
180 self.worker_count.load(SeqCst)
181 }
182
183 #[inline(always)]
189 fn set_err(&self) {
190 self.is_ok.store(false, SeqCst);
191 }
192
193 #[inline]
194 async fn connect(&self) -> Result<ClientStream<F, P>, RpcIntErr> {
195 ClientStream::connect(self.facts.clone(), None, &self.addr, &self.conn_id, None).await
196 }
197
198 #[inline(always)]
199 async fn _run_worker(
200 &self, _worker_id: usize, stream: &mut ClientStream<F, P>,
201 ) -> Result<(), RpcIntErr> {
202 loop {
203 match self.rx.recv().await {
204 Ok(task) => {
205 stream.send_task(task, false).await?;
206 while let Ok(task) = self.rx.try_recv() {
207 stream.send_task(task, false).await?;
208 }
209 stream.flush_req().await?;
210 }
211 Err(_) => {
212 stream.flush_req().await?;
213 return Ok(());
214 }
215 }
216 }
217 }
218
219 async fn run_worker(
220 &self, worker_id: usize, stream: &mut ClientStream<F, P>,
221 ) -> Result<(), RpcIntErr> {
222 self.connected_worker_count.fetch_add(1, Acquire);
223 let r = self._run_worker(worker_id, stream).await;
224 logger_trace!(self.logger, "{} worker {} exit: {}", self, worker_id, r.is_ok());
225 self.connected_worker_count.fetch_add(1, Release);
226 r
227 }
228
    /// Main loop of one worker task: connect, then either act as the monitor
    /// (worker id 0) or forward queued tasks over the connection (id > 0).
    ///
    /// A failed connect marks the pool unhealthy, fails every queued task via
    /// `cleanup`, sleeps `ONE_SEC` and retries. The function returns when the
    /// channel reports `Disconnected`, when a worker with a non-zero id
    /// finishes `run_worker`, or when a freshly promoted worker fails to send.
    async fn run(self: &Arc<Self>, mut worker_id: usize) {
        'CONN_LOOP: loop {
            match self.connect().await {
                Ok(mut stream) => {
                    logger_trace!(self.logger, "{} worker={} connected", self, worker_id);
                    if worker_id == 0 {
                        // Worker 0 keeps checking the connection with pings.
                        'MONITOR: loop {
                            if self.get_workers() > 1 {
                                // Other workers exist: only ping once per second and
                                // leave the queued tasks to them.
                                <P::RT as AsyncTime>::sleep(ONE_SEC).await;
                                if stream.ping().await.is_err() {
                                    // Reconnect; queued tasks are not failed on this path.
                                    self.set_err();
                                    continue 'CONN_LOOP;
                                }
                            } else {
                                // Sole worker: wait up to one second for a task.
                                match self
                                    .rx
                                    .recv_with_timer(<P::RT as AsyncTime>::sleep(ONE_SEC))
                                    .await
                                {
                                    Err(RecvTimeoutError::Disconnected) => {
                                        // NOTE(review): presumably every sender is gone — confirm
                                        // against the channel's semantics.
                                        return;
                                    }
                                    Err(RecvTimeoutError::Timeout) => {
                                        // Idle for a second: check the connection.
                                        if stream.ping().await.is_err() {
                                            self.set_err();
                                            self.cleanup();
                                            continue 'CONN_LOOP;
                                        }
                                    }
                                    Ok(task) => {
                                        // If requests are still in flight and this is still the
                                        // only worker, atomically bump the count 1 -> 2, take
                                        // over as worker 1 and spawn a replacement worker 0
                                        // (with `None`, i.e. through `P::RT::spawn_detach`).
                                        if stream.get_inflight_count() > 0
                                            && self.get_workers() == 1
                                            && self
                                                .worker_count
                                                .compare_exchange(1, 2, SeqCst, Relaxed)
                                                .is_ok()
                                        {
                                            worker_id = 1;
                                            self.clone().spawn_worker(None, 0);
                                        }
                                        if stream.send_task(task, true).await.is_err() {
                                            self.set_err();
                                            if worker_id == 0 {
                                                // Still the monitor: fail queued tasks, back
                                                // off, then reconnect.
                                                self.cleanup();
                                                <P::RT as AsyncTime>::sleep(ONE_SEC).await;
                                                continue 'CONN_LOOP;
                                            } else {
                                                // Just promoted and the send failed: exit; the
                                                // replacement worker 0 carries on.
                                                return;
                                            }
                                        } else if worker_id > 0 {
                                            // Promoted and the send succeeded: leave the
                                            // monitor loop and serve as a regular worker.
                                            logger_trace!(
                                                self.logger,
                                                "{} worker={} break monitor",
                                                self,
                                                worker_id
                                            );
                                            break 'MONITOR;
                                        }
                                    }
                                }
                            }
                        }
                    }
                    if worker_id > 0 {
                        // Regular worker: forward tasks until the send loop ends, then exit.
                        if self.run_worker(worker_id, &mut stream).await.is_err() {
                            self.set_err();
                        }
                        return;
                    }
                }
                Err(e) => {
                    // Connect failed: mark unhealthy, fail queued tasks, retry after a pause.
                    self.set_err();
                    error!("connect failed to {}: {}", self.addr, e);
                    self.cleanup();
                    <P::RT as AsyncTime>::sleep(ONE_SEC).await;
                }
            }
        }
    }
319
320 fn cleanup(&self) {
321 while let Ok(mut task) = self.rx.try_recv() {
322 task.set_rpc_error(RpcIntErr::Unreachable);
323 logger_trace!(self.logger, "{} set task err due not not healthy", self);
324 self.facts.error_handle(task);
325 }
326 }
327}