1use super::throttler::Throttler;
11use crate::client::task::ClientTaskDone;
12use crate::client::timer::ClientTaskTimer;
13use crate::{client::*, proto};
14use captains_log::filter::LogFilter;
15use crossfire::null::CloseHandle;
16use futures_util::pin_mut;
17use orb::prelude::*;
18use std::time::Duration;
19use std::{
20 cell::UnsafeCell,
21 fmt,
22 future::Future,
23 mem::transmute,
24 pin::Pin,
25 sync::{
26 Arc,
27 atomic::{AtomicBool, AtomicU64, Ordering},
28 },
29 task::{Context, Poll},
30};
31
/// Client-side handle to a single RPC connection (the sending half).
///
/// Responses are read by a detached receive-loop task (spawned in `connect`) that shares
/// `inner` with this handle.
pub struct ClientStream<F: ClientFacts, P: ClientTransport> {
    /// Sender half of the close channel. It is taken (dropped) in `Drop`; its receiver is
    /// stored in `inner.close_rx` and handed to `read_resp` by the receive loop.
    close_tx: Option<CloseHandle<mpsc::Null>>,
    /// Connection state shared with the background receive loop.
    inner: Arc<ClientStreamInner<F, P>>,
}
44
45impl<F: ClientFacts, P: ClientTransport> ClientStream<F, P> {
46 #[inline]
48 pub async fn connect(
49 facts: Arc<F>, rt: Option<&<P::RT as AsyncRuntime>::Exec>, addr: &str, conn_id: &str,
50 last_resp_ts: Option<Arc<AtomicU64>>,
51 ) -> Result<Self, RpcIntErr> {
52 let client_id = facts.get_client_id();
53 let conn = P::connect(addr, conn_id, facts.get_config()).await?;
54 let this = Self::new(facts, conn, client_id, conn_id.to_string(), last_resp_ts);
55 let inner = this.inner.clone();
56 let f = inner.receive_loop();
57 if let Some(_rt) = rt {
58 _rt.spawn_detach(f);
59 } else {
60 P::RT::spawn_detach(f);
61 }
62 Ok(this)
63 }
64
65 #[inline]
66 fn new(
67 facts: Arc<F>, conn: P, client_id: u64, conn_id: String,
68 last_resp_ts: Option<Arc<AtomicU64>>,
69 ) -> Self {
70 let (_close_tx, _close_rx) = mpsc::new::<mpsc::Null, _, _>();
71 let inner = Arc::new(ClientStreamInner::new(
72 facts,
73 conn,
74 client_id,
75 conn_id,
76 _close_rx,
77 last_resp_ts,
78 ));
79 logger_debug!(inner.logger, "{:?} connected", inner);
80 Self { close_tx: Some(_close_tx), inner }
81 }
82
83 #[inline]
84 pub fn get_codec(&self) -> &F::Codec {
85 &self.inner.codec
86 }
87
88 #[inline(always)]
92 pub async fn ping(&mut self) -> Result<(), RpcIntErr> {
93 self.inner.send_ping_req().await
94 }
95
96 #[inline(always)]
97 pub fn get_last_resp_ts(&self) -> u64 {
98 if let Some(ts) = self.inner.last_resp_ts.as_ref() { ts.load(Ordering::Relaxed) } else { 0 }
99 }
100
101 #[inline(always)]
103 pub fn is_closed(&self) -> bool {
104 self.inner.closed.load(Ordering::SeqCst)
105 }
106
107 pub async fn set_error_and_exit(&mut self) {
112 self.inner.has_err.store(true, Ordering::SeqCst);
114 self.inner.conn.close_conn::<F>(&self.inner.logger).await;
115 }
116
117 #[inline(always)]
125 pub async fn send_task(&mut self, task: F::Task, need_flush: bool) -> Result<(), RpcIntErr> {
126 self.inner.send_task(task, need_flush).await
127 }
128
129 #[inline(always)]
132 pub async fn flush_req(&mut self) -> Result<(), RpcIntErr> {
133 self.inner.flush_req().await
134 }
135
136 #[inline]
138 pub fn will_block(&self) -> bool {
139 self.inner.throttler.nearly_full()
140 }
141
142 #[inline]
144 pub fn get_inflight_count(&self) -> usize {
145 self.inner.throttler.get_inflight_count()
147 }
148}
149
150impl<F: ClientFacts, P: ClientTransport> Drop for ClientStream<F, P> {
151 fn drop(&mut self) {
152 self.close_tx.take();
153 let timer = self.inner.get_timer_mut();
154 timer.stop_reg_task();
155 self.inner.closed.store(true, Ordering::SeqCst);
156 }
157}
158
159impl<F: ClientFacts, P: ClientTransport> fmt::Debug for ClientStream<F, P> {
160 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
161 self.inner.fmt(f)
162 }
163}
164
/// Connection state shared between a [`ClientStream`] handle and its receive loop.
///
/// `close_rx`, `timer` and `encode_buf` are `UnsafeCell`s that are mutated through `&self`
/// via the `get_*_mut`-style helpers; nothing in this type synchronizes those accesses.
struct ClientStreamInner<F: ClientFacts, P: ClientTransport> {
    /// Client identifier written into every request header.
    client_id: u64,
    /// Underlying transport connection.
    conn: P,
    /// Next request sequence number; starts at 1 and is bumped by `seq_update`.
    seq: AtomicU64,
    /// Receiver of the close channel, passed to `read_resp` while the stream is open.
    close_rx: UnsafeCell<AsyncRx<mpsc::Null>>,
    /// Set when the stream is dropped, the close signal arrives, or an I/O error occurs;
    /// new sends are rejected once it is set.
    closed: AtomicBool,
    /// Tracks sent tasks (pending counter, registration, timeouts from `task_timeout`).
    timer: UnsafeCell<ClientTaskTimer<F>>,
    /// Set when the connection is considered broken; the receive loop exits when it sees it.
    has_err: AtomicBool,
    /// Back-pressure on in-flight tasks (`throttle`, `nearly_full`, `add_task`).
    throttler: Throttler,
    /// Optional timestamp of the most recently received response, shared with the owner.
    last_resp_ts: Option<Arc<AtomicU64>>,
    /// Scratch buffer reused for encoding request headers.
    encode_buf: UnsafeCell<Vec<u8>>,
    /// Codec used by `ReqHead::encode` and `read_resp`.
    codec: F::Codec,
    /// Per-connection logger obtained from `facts.new_logger()`.
    logger: Arc<LogFilter>,
    /// Client-wide facts: configuration, ids, timestamps and the error handler.
    facts: Arc<F>,
}
183
// SAFETY: NOTE(review): these impls are asserted by hand because `ClientStreamInner` holds
// `UnsafeCell` fields (`close_rx`, `timer`, `encode_buf`) that are handed out as `&mut`
// through `&self`. Soundness relies on callers never creating overlapping `&mut` borrows of
// the same cell. `close_rx` appears to be used only by the receive loop and `encode_buf` only
// by the sending path, but `timer` is reached from both the sending path (`send_task`,
// `send_request`) and the receive loop (`recv_one_resp`, `time_reach`) — TODO confirm that
// `ClientTaskTimer` tolerates this before relying on these impls.
unsafe impl<F: ClientFacts, P: ClientTransport> Send for ClientStreamInner<F, P> {}

// SAFETY: see the note on the `Send` impl directly above.
unsafe impl<F: ClientFacts, P: ClientTransport> Sync for ClientStreamInner<F, P> {}
187
188impl<F: ClientFacts, P: ClientTransport> fmt::Debug for ClientStreamInner<F, P> {
189 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
190 self.conn.fmt(f)
191 }
192}
193
194impl<F: ClientFacts, P: ClientTransport> ClientStreamInner<F, P> {
195 pub fn new(
196 facts: Arc<F>, conn: P, client_id: u64, conn_id: String, close_rx: AsyncRx<mpsc::Null>,
197 last_resp_ts: Option<Arc<AtomicU64>>,
198 ) -> Self {
199 let config = facts.get_config();
200 let mut thresholds = config.thresholds;
201 if thresholds == 0 {
202 thresholds = 128;
203 }
204 let client_inner = Self {
205 client_id,
206 conn,
207 close_rx: UnsafeCell::new(close_rx),
208 closed: AtomicBool::new(false),
209 seq: AtomicU64::new(1),
210 encode_buf: UnsafeCell::new(Vec::with_capacity(1024)),
211 throttler: Throttler::new(thresholds),
212 last_resp_ts,
213 has_err: AtomicBool::new(false),
214 codec: F::Codec::default(),
215 logger: facts.new_logger(),
216 timer: UnsafeCell::new(ClientTaskTimer::new(conn_id, config.task_timeout, thresholds)),
217 facts,
218 };
219 logger_trace!(client_inner.logger, "{:?} throttler is set to {}", client_inner, thresholds,);
220 client_inner
221 }
222
223 #[inline(always)]
224 fn get_timer_mut(&self) -> &mut ClientTaskTimer<F> {
225 unsafe { transmute(self.timer.get()) }
226 }
227
228 #[inline(always)]
229 fn get_close_rx(&self) -> &mut AsyncRx<mpsc::Null> {
230 unsafe { transmute(self.close_rx.get()) }
231 }
232
233 #[inline(always)]
234 fn get_encoded_buf(&self) -> &mut Vec<u8> {
235 unsafe { transmute(self.encode_buf.get()) }
236 }
237
238 async fn send_task(&self, mut task: F::Task, mut need_flush: bool) -> Result<(), RpcIntErr> {
240 if self.throttler.nearly_full() {
241 need_flush = true;
242 }
243 let timer = self.get_timer_mut();
244 timer.pending_task_count_ref().fetch_add(1, Ordering::SeqCst);
245 if self.closed.load(Ordering::Acquire) {
248 logger_warn!(
249 self.logger,
250 "{:?} sending task {:?} failed: {}",
251 self,
252 task,
253 RpcIntErr::IO,
254 );
255 task.set_rpc_error(RpcIntErr::IO);
256 self.facts.error_handle(task);
257 timer.pending_task_count_ref().fetch_sub(1, Ordering::SeqCst); return Err(RpcIntErr::IO);
259 }
260 match self.send_request(task, need_flush).await {
261 Err(_) => {
262 self.closed.store(true, Ordering::SeqCst);
263 self.has_err.store(true, Ordering::SeqCst);
264 return Err(RpcIntErr::IO);
265 }
266 Ok(_) => {
267 self.throttler.throttle().await;
269 return Ok(());
270 }
271 }
272 }
273
274 #[inline(always)]
275 async fn flush_req(&self) -> Result<(), RpcIntErr> {
276 if let Err(e) = self.conn.flush_req::<F>(&self.logger).await {
277 logger_warn!(self.logger, "{:?} flush_req flush err: {}", self, e);
278 self.closed.store(true, Ordering::SeqCst);
279 self.has_err.store(true, Ordering::SeqCst);
280 let timer = self.get_timer_mut();
281 timer.stop_reg_task();
282 return Err(RpcIntErr::IO);
283 }
284 Ok(())
285 }
286
287 #[inline(always)]
288 async fn send_request(&self, mut task: F::Task, need_flush: bool) -> Result<(), RpcIntErr> {
289 let seq = self.seq_update();
290 task.set_seq(seq);
291 let buf = self.get_encoded_buf();
292 match proto::ReqHead::encode(&self.codec, buf, self.client_id, &task) {
293 Err(_) => {
294 logger_warn!(&self.logger, "{:?} send_req encode req {:?} err", self, task);
295 return Err(RpcIntErr::Encode);
296 }
297 Ok(blob_buf) => {
298 if let Err(e) =
299 self.conn.write_req::<F>(&self.logger, buf, blob_buf, need_flush).await
300 {
301 logger_warn!(
302 self.logger,
303 "{:?} send_req write req {:?} err: {:?}",
304 self,
305 task,
306 e
307 );
308 self.closed.store(true, Ordering::SeqCst);
309 self.has_err.store(true, Ordering::SeqCst);
310 let timer = self.get_timer_mut();
311 timer.pending_task_count_ref().fetch_sub(1, Ordering::SeqCst);
314 timer.stop_reg_task();
315 logger_warn!(self.logger, "{:?} sending task {:?} err: {}", self, task, e);
316 task.set_rpc_error(RpcIntErr::IO);
317 self.facts.error_handle(task);
318 return Err(RpcIntErr::IO);
319 } else {
320 let wg = self.throttler.add_task();
321 let timer = self.get_timer_mut();
322 logger_trace!(self.logger, "{:?} send task {:?} ok", self, task);
323 timer.reg_task(task, wg).await;
324 }
325 return Ok(());
326 }
327 }
328 }
329
330 #[inline(always)]
331 async fn send_ping_req(&self) -> Result<(), RpcIntErr> {
332 if self.closed.load(Ordering::Acquire) {
333 logger_warn!(self.logger, "{:?} send_ping_req skip as conn closed", self);
334 return Err(RpcIntErr::IO);
335 }
336 let buf = self.get_encoded_buf();
338 proto::ReqHead::encode_ping(buf, self.client_id, self.seq_update());
339 if let Err(e) = self.conn.write_req::<F>(&self.logger, buf, None, true).await {
342 logger_warn!(self.logger, "{:?} send ping err: {:?}", self, e);
343 self.closed.store(true, Ordering::SeqCst);
344 return Err(RpcIntErr::IO);
345 }
346 Ok(())
347 }
348
349 async fn recv_some(&self) -> Result<(), RpcIntErr> {
351 for _ in 0i32..20 {
352 match self.recv_one_resp().await {
355 Err(e) => {
356 return Err(e);
357 }
358 Ok(_) => {
359 if let Some(last_resp_ts) = self.last_resp_ts.as_ref() {
360 last_resp_ts.store(self.facts.get_timestamp(), Ordering::Release);
361 }
362 }
363 }
364 }
365 Ok(())
366 }
367
368 async fn recv_one_resp(&self) -> Result<(), RpcIntErr> {
369 let timer = self.get_timer_mut();
370 loop {
371 if self.closed.load(Ordering::Acquire) {
372 logger_trace!(self.logger, "{:?} read_resp from already close", self.conn);
373 if timer.check_pending_tasks_empty() || self.has_err.load(Ordering::Relaxed) {
375 return Err(RpcIntErr::IO);
376 }
377 if let Err(_e) = self
379 .conn
380 .read_resp(self.facts.as_ref(), &self.logger, &self.codec, None, timer)
381 .await
382 {
383 self.closed.store(true, Ordering::SeqCst);
384 return Err(RpcIntErr::IO);
385 }
386 } else {
387 match self
391 .conn
392 .read_resp(
393 self.facts.as_ref(),
394 &self.logger,
395 &self.codec,
396 Some(self.get_close_rx()),
397 timer,
398 )
399 .await
400 {
401 Err(_e) => {
402 return Err(RpcIntErr::IO);
403 }
404 Ok(r) => {
405 if !r {
407 self.closed.store(true, Ordering::SeqCst);
408 continue;
409 }
410 }
411 }
412 }
413 }
414 }
415
416 async fn receive_loop(self: Arc<Self>) {
417 let mut tick = <P::RT as AsyncTime>::interval(Duration::from_secs(1));
418 loop {
419 let f = self.recv_some();
420 pin_mut!(f);
421 let selector = ReceiverTimerFuture::new(&self, &mut tick, &mut f);
422 match selector.await {
423 Ok(_) => {}
424 Err(e) => {
425 logger_debug!(self.logger, "{:?} receive_loop error: {}", self, e);
426 self.closed.store(true, Ordering::SeqCst);
427 let timer = self.get_timer_mut();
428 timer.clean_pending_tasks(self.facts.as_ref());
429 while timer.pending_task_count_ref().load(Ordering::SeqCst) > 0 {
431 timer.clean_pending_tasks(self.facts.as_ref());
435 <P::RT as AsyncTime>::sleep(Duration::from_secs(1)).await;
436 }
437 return;
438 }
439 }
440 }
441 }
442
443 fn time_reach(&self) {
445 logger_trace!(
446 self.logger,
447 "{:?} has {} pending_tasks",
448 self,
449 self.throttler.get_inflight_count()
450 );
451 let timer = self.get_timer_mut();
452 timer.adjust_task_queue(self.facts.as_ref());
453 return;
454 }
455
456 #[inline(always)]
457 fn seq_update(&self) -> u64 {
458 self.seq.fetch_add(1, Ordering::SeqCst)
459 }
460}
461
462impl<F: ClientFacts, P: ClientTransport> Drop for ClientStreamInner<F, P> {
463 fn drop(&mut self) {
464 let timer = self.get_timer_mut();
465 timer.clean_pending_tasks(self.facts.as_ref());
466 }
467}
468
/// Future used by the receive loop to drive one receive future alongside the interval timer.
///
/// Each poll services elapsed ticks, checks the connection's `has_err` flag, polls sent-task
/// registration on the timer, and then polls `recv_future`.
struct ReceiverTimerFuture<'a, F, P, I, FR>
where
    F: ClientFacts,
    P: ClientTransport,
    I: TimeInterval,
    FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
{
    /// Connection whose timer and error flag are serviced.
    client: &'a ClientStreamInner<F, P>,
    /// Periodic interval; each tick triggers `ClientStreamInner::time_reach`.
    inv: Pin<&'a mut I>,
    /// The receive future being driven (`recv_some` in `receive_loop`).
    recv_future: Pin<&'a mut FR>,
}
480
481impl<'a, F, P, I, FR> ReceiverTimerFuture<'a, F, P, I, FR>
482where
483 F: ClientFacts,
484 P: ClientTransport,
485 I: TimeInterval,
486 FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
487{
488 fn new(client: &'a ClientStreamInner<F, P>, inv: &'a mut I, recv_future: &'a mut FR) -> Self {
489 Self { inv: Pin::new(inv), client, recv_future: Pin::new(recv_future) }
490 }
491}
492
493impl<'a, F, P, I, FR> Future for ReceiverTimerFuture<'a, F, P, I, FR>
497where
498 F: ClientFacts,
499 P: ClientTransport,
500 I: TimeInterval,
501 FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
502{
503 type Output = Result<(), RpcIntErr>;
504
505 fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
506 let mut _self = self.get_mut();
507 while _self.inv.as_mut().poll_tick(ctx).is_ready() {
509 _self.client.time_reach();
510 }
511 if _self.client.has_err.load(Ordering::Relaxed) {
512 return Poll::Ready(Err(RpcIntErr::IO));
515 }
516 _self.client.get_timer_mut().poll_sent_task(ctx);
517 if let Poll::Ready(r) = _self.recv_future.as_mut().poll(ctx) {
519 return Poll::Ready(r);
520 }
521 return Poll::Pending;
522 }
523}