1use super::throttler::Throttler;
11use crate::client::task::ClientTaskDone;
12use crate::client::timer::ClientTaskTimer;
13use crate::{client::*, proto};
14use captains_log::filter::LogFilter;
15use crossfire::null::CloseHandle;
16use futures::pin_mut;
17use orb::prelude::*;
18use std::time::Duration;
19use std::{
20 cell::UnsafeCell,
21 fmt,
22 future::Future,
23 mem::transmute,
24 pin::Pin,
25 sync::{
26 Arc,
27 atomic::{AtomicBool, AtomicU64, Ordering},
28 },
29 task::{Context, Poll},
30};
31use sync_utils::time::DelayedTime;
32
/// Client-side handle to one transport connection.
///
/// Sending happens through this handle; responses are read by a detached receive loop
/// (spawned in `new`) that shares `inner` with it.
pub struct ClientStream<F: ClientFacts, P: ClientTransport> {
    // Sender side of the channel whose receiver lives in `inner.close_rx`; taken
    // (and thereby dropped) in `Drop`.
    close_tx: Option<CloseHandle<mpsc::Null>>,
    // State shared with the spawned receive loop.
    inner: Arc<ClientStreamInner<F, P>>,
}
45
46impl<F: ClientFacts, P: ClientTransport> ClientStream<F, P> {
47 #[inline]
49 pub fn connect(
50 facts: Arc<F>, addr: &str, conn_id: &str, last_resp_ts: Option<Arc<AtomicU64>>,
51 ) -> impl Future<Output = Result<Self, RpcIntErr>> + Send {
52 async move {
53 let client_id = facts.get_client_id();
54 let conn = P::connect(addr, conn_id, facts.get_config()).await?;
55 Ok(Self::new(facts, conn, client_id, conn_id.to_string(), last_resp_ts))
56 }
57 }
58
59 #[inline]
60 fn new(
61 facts: Arc<F>, conn: P, client_id: u64, conn_id: String,
62 last_resp_ts: Option<Arc<AtomicU64>>,
63 ) -> Self {
64 let (_close_tx, _close_rx) = mpsc::new::<mpsc::Null, _, _>();
65 let inner = Arc::new(ClientStreamInner::new(
66 facts,
67 conn,
68 client_id,
69 conn_id,
70 _close_rx,
71 last_resp_ts,
72 ));
73 logger_debug!(inner.logger, "{:?} connected", inner);
74 let _inner = inner.clone();
75 inner.facts.spawn_detach(async move {
76 _inner.receive_loop().await;
77 });
78 Self { close_tx: Some(_close_tx), inner }
79 }
80
81 #[inline]
82 pub fn get_codec(&self) -> &F::Codec {
83 &self.inner.codec
84 }
85
86 #[inline(always)]
90 pub async fn ping(&mut self) -> Result<(), RpcIntErr> {
91 self.inner.send_ping_req().await
92 }
93
94 #[inline(always)]
95 pub fn get_last_resp_ts(&self) -> u64 {
96 if let Some(ts) = self.inner.last_resp_ts.as_ref() { ts.load(Ordering::Relaxed) } else { 0 }
97 }
98
99 #[inline(always)]
101 pub fn is_closed(&self) -> bool {
102 self.inner.closed.load(Ordering::SeqCst)
103 }
104
105 pub async fn set_error_and_exit(&mut self) {
110 self.inner.has_err.store(true, Ordering::SeqCst);
112 self.inner.conn.close_conn::<F>(&self.inner.logger).await;
113 }
114
115 #[inline(always)]
123 pub async fn send_task(&mut self, task: F::Task, need_flush: bool) -> Result<(), RpcIntErr> {
124 self.inner.send_task(task, need_flush).await
125 }
126
127 #[inline(always)]
130 pub async fn flush_req(&mut self) -> Result<(), RpcIntErr> {
131 self.inner.flush_req().await
132 }
133
134 #[inline]
136 pub fn will_block(&self) -> bool {
137 self.inner.throttler.nearly_full()
138 }
139
140 #[inline]
142 pub fn get_inflight_count(&self) -> usize {
143 self.inner.throttler.get_inflight_count()
145 }
146}
147
148impl<F: ClientFacts, P: ClientTransport> Drop for ClientStream<F, P> {
149 fn drop(&mut self) {
150 self.close_tx.take();
151 let timer = self.inner.get_timer_mut();
152 timer.stop_reg_task();
153 self.inner.closed.store(true, Ordering::SeqCst);
154 }
155}
156
157impl<F: ClientFacts, P: ClientTransport> fmt::Debug for ClientStream<F, P> {
158 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
159 self.inner.fmt(f)
160 }
161}
162
/// State shared between a `ClientStream` handle and its detached receive loop.
struct ClientStreamInner<F: ClientFacts, P: ClientTransport> {
    // Client identifier written into every request header.
    client_id: u64,
    // Underlying transport connection.
    conn: P,
    // Request sequence counter; starts at 1 and is bumped once per request/ping.
    seq: AtomicU64,
    // Close-signal receiver, passed to `read_resp` while the connection is open.
    close_rx: UnsafeCell<AsyncRx<mpsc::Null>>,
    // Set once the connection stops accepting new tasks (error, close signal, or drop).
    closed: AtomicBool,
    // Tracks registered tasks, their timeouts, and the pending-send counter.
    timer: UnsafeCell<ClientTaskTimer<F>>,
    // Fatal-error flag; makes `ReceiverTimerFuture` resolve with `RpcIntErr::IO`.
    has_err: AtomicBool,
    // Bounds the number of in-flight tasks.
    throttler: Throttler,
    // Optional timestamp the receive path stores into after successful reads.
    last_resp_ts: Option<Arc<AtomicU64>>,
    // Reusable scratch buffer for encoding request headers.
    encode_buf: UnsafeCell<Vec<u8>>,
    codec: F::Codec,
    logger: Arc<LogFilter>,
    facts: Arc<F>,
}
181
// SAFETY: NOTE(review): `ClientStreamInner` contains `UnsafeCell`s (`close_rx`, `timer`,
// `encode_buf`) that are mutated through `&self`, so it is not automatically `Send`/`Sync`.
// These impls assert the access discipline makes sharing sound:
// - `close_rx` is only reached from `recv_one_resp`, which runs on the receive loop.
// - `encode_buf` is only reached from `send_request`/`send_ping_req`, which
//   `ClientStream` drives through `&mut self` methods.
// - `timer` is reached from both the send path and the receive loop, so soundness
//   also depends on `ClientTaskTimer`'s internals. Confirm.
// The impls also do not require `F` or `P` themselves to be `Send`/`Sync`.
unsafe impl<F: ClientFacts, P: ClientTransport> Send for ClientStreamInner<F, P> {}

unsafe impl<F: ClientFacts, P: ClientTransport> Sync for ClientStreamInner<F, P> {}
185
186impl<F: ClientFacts, P: ClientTransport> fmt::Debug for ClientStreamInner<F, P> {
187 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
188 self.conn.fmt(f)
189 }
190}
191
192impl<F: ClientFacts, P: ClientTransport> ClientStreamInner<F, P> {
193 pub fn new(
194 facts: Arc<F>, conn: P, client_id: u64, conn_id: String, close_rx: AsyncRx<mpsc::Null>,
195 last_resp_ts: Option<Arc<AtomicU64>>,
196 ) -> Self {
197 let config = facts.get_config();
198 let mut thresholds = config.thresholds;
199 if thresholds == 0 {
200 thresholds = 128;
201 }
202 let client_inner = Self {
203 client_id,
204 conn,
205 close_rx: UnsafeCell::new(close_rx),
206 closed: AtomicBool::new(false),
207 seq: AtomicU64::new(1),
208 encode_buf: UnsafeCell::new(Vec::with_capacity(1024)),
209 throttler: Throttler::new(thresholds),
210 last_resp_ts,
211 has_err: AtomicBool::new(false),
212 codec: F::Codec::default(),
213 logger: facts.new_logger(),
214 timer: UnsafeCell::new(ClientTaskTimer::new(conn_id, config.task_timeout, thresholds)),
215 facts,
216 };
217 logger_trace!(client_inner.logger, "{:?} throttler is set to {}", client_inner, thresholds,);
218 client_inner
219 }
220
221 #[inline(always)]
222 fn get_timer_mut(&self) -> &mut ClientTaskTimer<F> {
223 unsafe { transmute(self.timer.get()) }
224 }
225
226 #[inline(always)]
227 fn get_close_rx(&self) -> &mut AsyncRx<mpsc::Null> {
228 unsafe { transmute(self.close_rx.get()) }
229 }
230
231 #[inline(always)]
232 fn get_encoded_buf(&self) -> &mut Vec<u8> {
233 unsafe { transmute(self.encode_buf.get()) }
234 }
235
236 async fn send_task(&self, mut task: F::Task, mut need_flush: bool) -> Result<(), RpcIntErr> {
238 if self.throttler.nearly_full() {
239 need_flush = true;
240 }
241 let timer = self.get_timer_mut();
242 timer.pending_task_count_ref().fetch_add(1, Ordering::SeqCst);
243 if self.closed.load(Ordering::Acquire) {
246 logger_warn!(
247 self.logger,
248 "{:?} sending task {:?} failed: {}",
249 self,
250 task,
251 RpcIntErr::IO,
252 );
253 task.set_rpc_error(RpcIntErr::IO);
254 self.facts.error_handle(task);
255 timer.pending_task_count_ref().fetch_sub(1, Ordering::SeqCst); return Err(RpcIntErr::IO);
257 }
258 match self.send_request(task, need_flush).await {
259 Err(_) => {
260 self.closed.store(true, Ordering::SeqCst);
261 self.has_err.store(true, Ordering::SeqCst);
262 return Err(RpcIntErr::IO);
263 }
264 Ok(_) => {
265 self.throttler.throttle().await;
267 return Ok(());
268 }
269 }
270 }
271
272 #[inline(always)]
273 async fn flush_req(&self) -> Result<(), RpcIntErr> {
274 if let Err(e) = self.conn.flush_req::<F>(&self.logger).await {
275 logger_warn!(self.logger, "{:?} flush_req flush err: {}", self, e);
276 self.closed.store(true, Ordering::SeqCst);
277 self.has_err.store(true, Ordering::SeqCst);
278 let timer = self.get_timer_mut();
279 timer.stop_reg_task();
280 return Err(RpcIntErr::IO);
281 }
282 Ok(())
283 }
284
285 #[inline(always)]
286 async fn send_request(&self, mut task: F::Task, need_flush: bool) -> Result<(), RpcIntErr> {
287 let seq = self.seq_update();
288 task.set_seq(seq);
289 let buf = self.get_encoded_buf();
290 match proto::ReqHead::encode(&self.codec, buf, self.client_id, &task) {
291 Err(_) => {
292 logger_warn!(&self.logger, "{:?} send_req encode req {:?} err", self, task);
293 return Err(RpcIntErr::Encode);
294 }
295 Ok(blob_buf) => {
296 if let Err(e) =
297 self.conn.write_req::<F>(&self.logger, buf, blob_buf, need_flush).await
298 {
299 logger_warn!(
300 self.logger,
301 "{:?} send_req write req {:?} err: {:?}",
302 self,
303 task,
304 e
305 );
306 self.closed.store(true, Ordering::SeqCst);
307 self.has_err.store(true, Ordering::SeqCst);
308 let timer = self.get_timer_mut();
309 timer.pending_task_count_ref().fetch_sub(1, Ordering::SeqCst);
312 timer.stop_reg_task();
313 logger_warn!(self.logger, "{:?} sending task {:?} err: {}", self, task, e);
314 task.set_rpc_error(RpcIntErr::IO);
315 self.facts.error_handle(task);
316 return Err(RpcIntErr::IO);
317 } else {
318 let wg = self.throttler.add_task();
319 let timer = self.get_timer_mut();
320 logger_trace!(self.logger, "{:?} send task {:?} ok", self, task);
321 timer.reg_task(task, wg).await;
322 }
323 return Ok(());
324 }
325 }
326 }
327
328 #[inline(always)]
329 async fn send_ping_req(&self) -> Result<(), RpcIntErr> {
330 if self.closed.load(Ordering::Acquire) {
331 logger_warn!(self.logger, "{:?} send_ping_req skip as conn closed", self);
332 return Err(RpcIntErr::IO);
333 }
334 let buf = self.get_encoded_buf();
336 proto::ReqHead::encode_ping(buf, self.client_id, self.seq_update());
337 if let Err(e) = self.conn.write_req::<F>(&self.logger, buf, None, true).await {
340 logger_warn!(self.logger, "{:?} send ping err: {:?}", self, e);
341 self.closed.store(true, Ordering::SeqCst);
342 return Err(RpcIntErr::IO);
343 }
344 Ok(())
345 }
346
347 async fn recv_some(&self) -> Result<(), RpcIntErr> {
349 for _ in 0i32..20 {
350 match self.recv_one_resp().await {
353 Err(e) => {
354 return Err(e);
355 }
356 Ok(_) => {
357 if let Some(last_resp_ts) = self.last_resp_ts.as_ref() {
358 last_resp_ts.store(DelayedTime::get(), Ordering::Relaxed);
359 }
360 }
361 }
362 }
363 Ok(())
364 }
365
366 async fn recv_one_resp(&self) -> Result<(), RpcIntErr> {
367 let timer = self.get_timer_mut();
368 loop {
369 if self.closed.load(Ordering::Acquire) {
370 logger_trace!(self.logger, "{:?} read_resp from already close", self.conn);
371 if timer.check_pending_tasks_empty() || self.has_err.load(Ordering::Relaxed) {
373 return Err(RpcIntErr::IO);
374 }
375 if let Err(_e) = self
377 .conn
378 .read_resp(self.facts.as_ref(), &self.logger, &self.codec, None, timer)
379 .await
380 {
381 self.closed.store(true, Ordering::SeqCst);
382 return Err(RpcIntErr::IO);
383 }
384 } else {
385 match self
389 .conn
390 .read_resp(
391 self.facts.as_ref(),
392 &self.logger,
393 &self.codec,
394 Some(self.get_close_rx()),
395 timer,
396 )
397 .await
398 {
399 Err(_e) => {
400 return Err(RpcIntErr::IO);
401 }
402 Ok(r) => {
403 if !r {
405 self.closed.store(true, Ordering::SeqCst);
406 continue;
407 }
408 }
409 }
410 }
411 }
412 }
413
414 async fn receive_loop(&self) {
415 let mut tick = <F as AsyncTime>::interval(Duration::from_secs(1));
416 loop {
417 let f = self.recv_some();
418 pin_mut!(f);
419 let selector = ReceiverTimerFuture::new(self, &mut tick, &mut f);
420 match selector.await {
421 Ok(_) => {}
422 Err(e) => {
423 logger_debug!(self.logger, "{:?} receive_loop error: {}", self, e);
424 self.closed.store(true, Ordering::SeqCst);
425 let timer = self.get_timer_mut();
426 timer.clean_pending_tasks(self.facts.as_ref());
427 while timer.pending_task_count_ref().load(Ordering::SeqCst) > 0 {
429 timer.clean_pending_tasks(self.facts.as_ref());
433 <F as AsyncTime>::sleep(Duration::from_secs(1)).await;
434 }
435 return;
436 }
437 }
438 }
439 }
440
441 fn time_reach(&self) {
443 logger_trace!(
444 self.logger,
445 "{:?} has {} pending_tasks",
446 self,
447 self.throttler.get_inflight_count()
448 );
449 let timer = self.get_timer_mut();
450 timer.adjust_task_queue(self.facts.as_ref());
451 return;
452 }
453
454 #[inline(always)]
455 fn seq_update(&self) -> u64 {
456 self.seq.fetch_add(1, Ordering::SeqCst)
457 }
458}
459
460impl<F: ClientFacts, P: ClientTransport> Drop for ClientStreamInner<F, P> {
461 fn drop(&mut self) {
462 let timer = self.get_timer_mut();
463 timer.clean_pending_tasks(self.facts.as_ref());
464 }
465}
466
467struct ReceiverTimerFuture<'a, F, P, I, FR>
468where
469 F: ClientFacts,
470 P: ClientTransport,
471 I: TimeInterval,
472 FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
473{
474 client: &'a ClientStreamInner<F, P>,
475 inv: Pin<&'a mut I>,
476 recv_future: Pin<&'a mut FR>,
477}
478
479impl<'a, F, P, I, FR> ReceiverTimerFuture<'a, F, P, I, FR>
480where
481 F: ClientFacts,
482 P: ClientTransport,
483 I: TimeInterval,
484 FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
485{
486 fn new(client: &'a ClientStreamInner<F, P>, inv: &'a mut I, recv_future: &'a mut FR) -> Self {
487 Self { inv: Pin::new(inv), client, recv_future: Pin::new(recv_future) }
488 }
489}
490
491impl<'a, F, P, I, FR> Future for ReceiverTimerFuture<'a, F, P, I, FR>
495where
496 F: ClientFacts,
497 P: ClientTransport,
498 I: TimeInterval,
499 FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
500{
501 type Output = Result<(), RpcIntErr>;
502
503 fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
504 let mut _self = self.get_mut();
505 while let Poll::Ready(_) = _self.inv.as_mut().poll_tick(ctx) {
507 _self.client.time_reach();
508 }
509 if _self.client.has_err.load(Ordering::Relaxed) {
510 return Poll::Ready(Err(RpcIntErr::IO));
513 }
514 _self.client.get_timer_mut().poll_sent_task(ctx);
515 if let Poll::Ready(r) = _self.recv_future.as_mut().poll(ctx) {
517 return Poll::Ready(r);
518 }
519 return Poll::Pending;
520 }
521}