1use super::throttler::Throttler;
11use crate::client::task::ClientTaskDone;
12use crate::client::timer::ClientTaskTimer;
13use crate::{client::*, proto};
14use captains_log::filter::LogFilter;
15use crossfire::null::CloseHandle;
16use futures_util::pin_mut;
17use orb::prelude::*;
18use std::time::Duration;
19use std::{
20 cell::UnsafeCell,
21 fmt,
22 future::Future,
23 mem::transmute,
24 pin::Pin,
25 sync::{
26 Arc,
27 atomic::{AtomicBool, AtomicU64, Ordering},
28 },
29 task::{Context, Poll},
30};
31
32pub struct ClientStream<F: ClientFacts, P: ClientTransport> {
41 close_tx: Option<CloseHandle<mpsc::Null>>,
42 inner: Arc<ClientStreamInner<F, P>>,
43}
44
45impl<F: ClientFacts, P: ClientTransport> ClientStream<F, P> {
46 #[inline]
48 pub async fn connect(
49 facts: Arc<F>, addr: &str, conn_id: &str, last_resp_ts: Option<Arc<AtomicU64>>,
50 ) -> Result<Self, RpcIntErr> {
51 let client_id = facts.get_client_id();
52 let conn = P::connect(addr, conn_id, facts.get_config()).await?;
53 Ok(Self::new(facts, conn, client_id, conn_id.to_string(), last_resp_ts))
54 }
55
56 #[inline]
57 fn new(
58 facts: Arc<F>, conn: P, client_id: u64, conn_id: String,
59 last_resp_ts: Option<Arc<AtomicU64>>,
60 ) -> Self {
61 let (_close_tx, _close_rx) = mpsc::new::<mpsc::Null, _, _>();
62 let inner = Arc::new(ClientStreamInner::new(
63 facts,
64 conn,
65 client_id,
66 conn_id,
67 _close_rx,
68 last_resp_ts,
69 ));
70 logger_debug!(inner.logger, "{:?} connected", inner);
71 let _inner = inner.clone();
72 inner.facts.spawn_detach(async move {
73 _inner.receive_loop().await;
74 });
75 Self { close_tx: Some(_close_tx), inner }
76 }
77
78 #[inline]
79 pub fn get_codec(&self) -> &F::Codec {
80 &self.inner.codec
81 }
82
83 #[inline(always)]
87 pub async fn ping(&mut self) -> Result<(), RpcIntErr> {
88 self.inner.send_ping_req().await
89 }
90
91 #[inline(always)]
92 pub fn get_last_resp_ts(&self) -> u64 {
93 if let Some(ts) = self.inner.last_resp_ts.as_ref() { ts.load(Ordering::Relaxed) } else { 0 }
94 }
95
96 #[inline(always)]
98 pub fn is_closed(&self) -> bool {
99 self.inner.closed.load(Ordering::SeqCst)
100 }
101
102 pub async fn set_error_and_exit(&mut self) {
107 self.inner.has_err.store(true, Ordering::SeqCst);
109 self.inner.conn.close_conn::<F>(&self.inner.logger).await;
110 }
111
112 #[inline(always)]
120 pub async fn send_task(&mut self, task: F::Task, need_flush: bool) -> Result<(), RpcIntErr> {
121 self.inner.send_task(task, need_flush).await
122 }
123
124 #[inline(always)]
127 pub async fn flush_req(&mut self) -> Result<(), RpcIntErr> {
128 self.inner.flush_req().await
129 }
130
131 #[inline]
133 pub fn will_block(&self) -> bool {
134 self.inner.throttler.nearly_full()
135 }
136
137 #[inline]
139 pub fn get_inflight_count(&self) -> usize {
140 self.inner.throttler.get_inflight_count()
142 }
143}
144
145impl<F: ClientFacts, P: ClientTransport> Drop for ClientStream<F, P> {
146 fn drop(&mut self) {
147 self.close_tx.take();
148 let timer = self.inner.get_timer_mut();
149 timer.stop_reg_task();
150 self.inner.closed.store(true, Ordering::SeqCst);
151 }
152}
153
154impl<F: ClientFacts, P: ClientTransport> fmt::Debug for ClientStream<F, P> {
155 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
156 self.inner.fmt(f)
157 }
158}
159
160struct ClientStreamInner<F: ClientFacts, P: ClientTransport> {
161 client_id: u64,
162 conn: P,
163 seq: AtomicU64,
164 close_rx: UnsafeCell<AsyncRx<mpsc::Null>>,
167 closed: AtomicBool, timer: UnsafeCell<ClientTaskTimer<F>>,
169 has_err: AtomicBool,
171 throttler: Throttler,
172 last_resp_ts: Option<Arc<AtomicU64>>,
173 encode_buf: UnsafeCell<Vec<u8>>,
174 codec: F::Codec,
175 logger: Arc<LogFilter>,
176 facts: Arc<F>,
177}
178
179unsafe impl<F: ClientFacts, P: ClientTransport> Send for ClientStreamInner<F, P> {}
180
181unsafe impl<F: ClientFacts, P: ClientTransport> Sync for ClientStreamInner<F, P> {}
182
183impl<F: ClientFacts, P: ClientTransport> fmt::Debug for ClientStreamInner<F, P> {
184 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
185 self.conn.fmt(f)
186 }
187}
188
189impl<F: ClientFacts, P: ClientTransport> ClientStreamInner<F, P> {
190 pub fn new(
191 facts: Arc<F>, conn: P, client_id: u64, conn_id: String, close_rx: AsyncRx<mpsc::Null>,
192 last_resp_ts: Option<Arc<AtomicU64>>,
193 ) -> Self {
194 let config = facts.get_config();
195 let mut thresholds = config.thresholds;
196 if thresholds == 0 {
197 thresholds = 128;
198 }
199 let client_inner = Self {
200 client_id,
201 conn,
202 close_rx: UnsafeCell::new(close_rx),
203 closed: AtomicBool::new(false),
204 seq: AtomicU64::new(1),
205 encode_buf: UnsafeCell::new(Vec::with_capacity(1024)),
206 throttler: Throttler::new(thresholds),
207 last_resp_ts,
208 has_err: AtomicBool::new(false),
209 codec: F::Codec::default(),
210 logger: facts.new_logger(),
211 timer: UnsafeCell::new(ClientTaskTimer::new(conn_id, config.task_timeout, thresholds)),
212 facts,
213 };
214 logger_trace!(client_inner.logger, "{:?} throttler is set to {}", client_inner, thresholds,);
215 client_inner
216 }
217
218 #[inline(always)]
219 fn get_timer_mut(&self) -> &mut ClientTaskTimer<F> {
220 unsafe { transmute(self.timer.get()) }
221 }
222
223 #[inline(always)]
224 fn get_close_rx(&self) -> &mut AsyncRx<mpsc::Null> {
225 unsafe { transmute(self.close_rx.get()) }
226 }
227
228 #[inline(always)]
229 fn get_encoded_buf(&self) -> &mut Vec<u8> {
230 unsafe { transmute(self.encode_buf.get()) }
231 }
232
233 async fn send_task(&self, mut task: F::Task, mut need_flush: bool) -> Result<(), RpcIntErr> {
235 if self.throttler.nearly_full() {
236 need_flush = true;
237 }
238 let timer = self.get_timer_mut();
239 timer.pending_task_count_ref().fetch_add(1, Ordering::SeqCst);
240 if self.closed.load(Ordering::Acquire) {
243 logger_warn!(
244 self.logger,
245 "{:?} sending task {:?} failed: {}",
246 self,
247 task,
248 RpcIntErr::IO,
249 );
250 task.set_rpc_error(RpcIntErr::IO);
251 self.facts.error_handle(task);
252 timer.pending_task_count_ref().fetch_sub(1, Ordering::SeqCst); return Err(RpcIntErr::IO);
254 }
255 match self.send_request(task, need_flush).await {
256 Err(_) => {
257 self.closed.store(true, Ordering::SeqCst);
258 self.has_err.store(true, Ordering::SeqCst);
259 return Err(RpcIntErr::IO);
260 }
261 Ok(_) => {
262 self.throttler.throttle().await;
264 return Ok(());
265 }
266 }
267 }
268
269 #[inline(always)]
270 async fn flush_req(&self) -> Result<(), RpcIntErr> {
271 if let Err(e) = self.conn.flush_req::<F>(&self.logger).await {
272 logger_warn!(self.logger, "{:?} flush_req flush err: {}", self, e);
273 self.closed.store(true, Ordering::SeqCst);
274 self.has_err.store(true, Ordering::SeqCst);
275 let timer = self.get_timer_mut();
276 timer.stop_reg_task();
277 return Err(RpcIntErr::IO);
278 }
279 Ok(())
280 }
281
282 #[inline(always)]
283 async fn send_request(&self, mut task: F::Task, need_flush: bool) -> Result<(), RpcIntErr> {
284 let seq = self.seq_update();
285 task.set_seq(seq);
286 let buf = self.get_encoded_buf();
287 match proto::ReqHead::encode(&self.codec, buf, self.client_id, &task) {
288 Err(_) => {
289 logger_warn!(&self.logger, "{:?} send_req encode req {:?} err", self, task);
290 return Err(RpcIntErr::Encode);
291 }
292 Ok(blob_buf) => {
293 if let Err(e) =
294 self.conn.write_req::<F>(&self.logger, buf, blob_buf, need_flush).await
295 {
296 logger_warn!(
297 self.logger,
298 "{:?} send_req write req {:?} err: {:?}",
299 self,
300 task,
301 e
302 );
303 self.closed.store(true, Ordering::SeqCst);
304 self.has_err.store(true, Ordering::SeqCst);
305 let timer = self.get_timer_mut();
306 timer.pending_task_count_ref().fetch_sub(1, Ordering::SeqCst);
309 timer.stop_reg_task();
310 logger_warn!(self.logger, "{:?} sending task {:?} err: {}", self, task, e);
311 task.set_rpc_error(RpcIntErr::IO);
312 self.facts.error_handle(task);
313 return Err(RpcIntErr::IO);
314 } else {
315 let wg = self.throttler.add_task();
316 let timer = self.get_timer_mut();
317 logger_trace!(self.logger, "{:?} send task {:?} ok", self, task);
318 timer.reg_task(task, wg).await;
319 }
320 return Ok(());
321 }
322 }
323 }
324
325 #[inline(always)]
326 async fn send_ping_req(&self) -> Result<(), RpcIntErr> {
327 if self.closed.load(Ordering::Acquire) {
328 logger_warn!(self.logger, "{:?} send_ping_req skip as conn closed", self);
329 return Err(RpcIntErr::IO);
330 }
331 let buf = self.get_encoded_buf();
333 proto::ReqHead::encode_ping(buf, self.client_id, self.seq_update());
334 if let Err(e) = self.conn.write_req::<F>(&self.logger, buf, None, true).await {
337 logger_warn!(self.logger, "{:?} send ping err: {:?}", self, e);
338 self.closed.store(true, Ordering::SeqCst);
339 return Err(RpcIntErr::IO);
340 }
341 Ok(())
342 }
343
344 async fn recv_some(&self) -> Result<(), RpcIntErr> {
346 for _ in 0i32..20 {
347 match self.recv_one_resp().await {
350 Err(e) => {
351 return Err(e);
352 }
353 Ok(_) => {
354 if let Some(last_resp_ts) = self.last_resp_ts.as_ref() {
355 last_resp_ts.store(self.facts.get_timestamp(), Ordering::Release);
356 }
357 }
358 }
359 }
360 Ok(())
361 }
362
363 async fn recv_one_resp(&self) -> Result<(), RpcIntErr> {
364 let timer = self.get_timer_mut();
365 loop {
366 if self.closed.load(Ordering::Acquire) {
367 logger_trace!(self.logger, "{:?} read_resp from already close", self.conn);
368 if timer.check_pending_tasks_empty() || self.has_err.load(Ordering::Relaxed) {
370 return Err(RpcIntErr::IO);
371 }
372 if let Err(_e) = self
374 .conn
375 .read_resp(self.facts.as_ref(), &self.logger, &self.codec, None, timer)
376 .await
377 {
378 self.closed.store(true, Ordering::SeqCst);
379 return Err(RpcIntErr::IO);
380 }
381 } else {
382 match self
386 .conn
387 .read_resp(
388 self.facts.as_ref(),
389 &self.logger,
390 &self.codec,
391 Some(self.get_close_rx()),
392 timer,
393 )
394 .await
395 {
396 Err(_e) => {
397 return Err(RpcIntErr::IO);
398 }
399 Ok(r) => {
400 if !r {
402 self.closed.store(true, Ordering::SeqCst);
403 continue;
404 }
405 }
406 }
407 }
408 }
409 }
410
411 async fn receive_loop(&self) {
412 let mut tick = <F as AsyncTime>::interval(Duration::from_secs(1));
413 loop {
414 let f = self.recv_some();
415 pin_mut!(f);
416 let selector = ReceiverTimerFuture::new(self, &mut tick, &mut f);
417 match selector.await {
418 Ok(_) => {}
419 Err(e) => {
420 logger_debug!(self.logger, "{:?} receive_loop error: {}", self, e);
421 self.closed.store(true, Ordering::SeqCst);
422 let timer = self.get_timer_mut();
423 timer.clean_pending_tasks(self.facts.as_ref());
424 while timer.pending_task_count_ref().load(Ordering::SeqCst) > 0 {
426 timer.clean_pending_tasks(self.facts.as_ref());
430 <F as AsyncTime>::sleep(Duration::from_secs(1)).await;
431 }
432 return;
433 }
434 }
435 }
436 }
437
438 fn time_reach(&self) {
440 logger_trace!(
441 self.logger,
442 "{:?} has {} pending_tasks",
443 self,
444 self.throttler.get_inflight_count()
445 );
446 let timer = self.get_timer_mut();
447 timer.adjust_task_queue(self.facts.as_ref());
448 return;
449 }
450
451 #[inline(always)]
452 fn seq_update(&self) -> u64 {
453 self.seq.fetch_add(1, Ordering::SeqCst)
454 }
455}
456
457impl<F: ClientFacts, P: ClientTransport> Drop for ClientStreamInner<F, P> {
458 fn drop(&mut self) {
459 let timer = self.get_timer_mut();
460 timer.clean_pending_tasks(self.facts.as_ref());
461 }
462}
463
464struct ReceiverTimerFuture<'a, F, P, I, FR>
465where
466 F: ClientFacts,
467 P: ClientTransport,
468 I: TimeInterval,
469 FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
470{
471 client: &'a ClientStreamInner<F, P>,
472 inv: Pin<&'a mut I>,
473 recv_future: Pin<&'a mut FR>,
474}
475
476impl<'a, F, P, I, FR> ReceiverTimerFuture<'a, F, P, I, FR>
477where
478 F: ClientFacts,
479 P: ClientTransport,
480 I: TimeInterval,
481 FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
482{
483 fn new(client: &'a ClientStreamInner<F, P>, inv: &'a mut I, recv_future: &'a mut FR) -> Self {
484 Self { inv: Pin::new(inv), client, recv_future: Pin::new(recv_future) }
485 }
486}
487
488impl<'a, F, P, I, FR> Future for ReceiverTimerFuture<'a, F, P, I, FR>
492where
493 F: ClientFacts,
494 P: ClientTransport,
495 I: TimeInterval,
496 FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
497{
498 type Output = Result<(), RpcIntErr>;
499
500 fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
501 let mut _self = self.get_mut();
502 while _self.inv.as_mut().poll_tick(ctx).is_ready() {
504 _self.client.time_reach();
505 }
506 if _self.client.has_err.load(Ordering::Relaxed) {
507 return Poll::Ready(Err(RpcIntErr::IO));
510 }
511 _self.client.get_timer_mut().poll_sent_task(ctx);
512 if let Poll::Ready(r) = _self.recv_future.as_mut().poll(ctx) {
514 return Poll::Ready(r);
515 }
516 return Poll::Pending;
517 }
518}