1use super::throttler::Throttler;
11use crate::client::task::ClientTaskDone;
12use crate::client::timer::ClientTaskTimer;
13use crate::{client::*, proto};
14use captains_log::filter::LogFilter;
15use crossfire::*;
16use futures::pin_mut;
17use orb::prelude::*;
18use std::time::Duration;
19use std::{
20 cell::UnsafeCell,
21 fmt,
22 future::Future,
23 mem::transmute,
24 pin::Pin,
25 sync::{
26 Arc,
27 atomic::{AtomicBool, AtomicU64, Ordering},
28 },
29 task::{Context, Poll},
30};
31use sync_utils::time::DelayedTime;
32
/// Client-side handle for one RPC connection.
///
/// The receive loop spawned in `ClientStream::new` keeps its own `Arc` to `inner`.
/// Dropping this handle (see the `Drop` impl) releases `close_tx`, stops task
/// registration on the timer and marks the shared state closed.
pub struct ClientStream<F: ClientFacts, P: ClientTransport> {
    /// Sender half of the close-notification channel; the receiver half is stored in
    /// `inner.close_rx`. Set to `Some` in `new` and taken in `Drop`.
    close_tx: Option<MTx<()>>,
    /// State shared with the detached receive loop.
    inner: Arc<ClientStreamInner<F, P>>,
}
45
46impl<F: ClientFacts, P: ClientTransport> ClientStream<F, P> {
47 #[inline]
49 pub fn connect(
50 facts: Arc<F>, addr: &str, conn_id: &str, last_resp_ts: Option<Arc<AtomicU64>>,
51 ) -> impl Future<Output = Result<Self, RpcIntErr>> + Send {
52 async move {
53 let client_id = facts.get_client_id();
54 let conn = P::connect(addr, conn_id, facts.get_config()).await?;
55 Ok(Self::new(facts, conn, client_id, conn_id.to_string(), last_resp_ts))
56 }
57 }
58
59 #[inline]
60 fn new(
61 facts: Arc<F>, conn: P, client_id: u64, conn_id: String,
62 last_resp_ts: Option<Arc<AtomicU64>>,
63 ) -> Self {
64 let (_close_tx, _close_rx) = mpmc::unbounded_async::<()>();
65 let inner = Arc::new(ClientStreamInner::new(
66 facts,
67 conn,
68 client_id,
69 conn_id,
70 _close_rx,
71 last_resp_ts,
72 ));
73 logger_debug!(inner.logger, "{:?} connected", inner);
74 let _inner = inner.clone();
75 inner.facts.spawn_detach(async move {
76 _inner.receive_loop().await;
77 });
78 Self { close_tx: Some(_close_tx), inner }
79 }
80
81 #[inline]
82 pub fn get_codec(&self) -> &F::Codec {
83 &self.inner.codec
84 }
85
86 #[inline(always)]
90 pub async fn ping(&mut self) -> Result<(), RpcIntErr> {
91 self.inner.send_ping_req().await
92 }
93
94 #[inline(always)]
95 pub fn get_last_resp_ts(&self) -> u64 {
96 if let Some(ts) = self.inner.last_resp_ts.as_ref() { ts.load(Ordering::Relaxed) } else { 0 }
97 }
98
99 #[inline(always)]
101 pub fn is_closed(&self) -> bool {
102 self.inner.closed.load(Ordering::SeqCst)
103 }
104
105 pub async fn set_error_and_exit(&mut self) {
109 self.inner.has_err.store(true, Ordering::SeqCst);
110 self.inner.conn.close_conn::<F>(&self.inner.logger).await;
111 if let Some(close_tx) = self.close_tx.as_ref() {
112 let _ = close_tx.send(()); }
114 }
115
116 #[inline(always)]
124 pub async fn send_task(&mut self, task: F::Task, need_flush: bool) -> Result<(), RpcIntErr> {
125 self.inner.send_task(task, need_flush).await
126 }
127
128 #[inline(always)]
131 pub async fn flush_req(&mut self) -> Result<(), RpcIntErr> {
132 self.inner.flush_req().await
133 }
134
135 #[inline]
137 pub fn will_block(&self) -> bool {
138 self.inner.throttler.nearly_full()
139 }
140
141 #[inline]
143 pub fn get_inflight_count(&self) -> usize {
144 self.inner.throttler.get_inflight_count()
146 }
147}
148
149impl<F: ClientFacts, P: ClientTransport> Drop for ClientStream<F, P> {
150 fn drop(&mut self) {
151 self.close_tx.take();
152 let timer = self.inner.get_timer_mut();
153 timer.stop_reg_task();
154 self.inner.closed.store(true, Ordering::SeqCst);
155 }
156}
157
158impl<F: ClientFacts, P: ClientTransport> fmt::Debug for ClientStream<F, P> {
159 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
160 self.inner.fmt(f)
161 }
162}
163
/// Connection state shared by a `ClientStream` handle and its detached receive loop.
///
/// Field declaration order is also drop order (after the custom `Drop` impl runs).
struct ClientStreamInner<F: ClientFacts, P: ClientTransport> {
    /// Client id written into every request header.
    client_id: u64,
    /// Underlying transport.
    conn: P,
    /// Request sequence counter; starts at 1 and is advanced by `seq_update`.
    seq: AtomicU64,
    /// Receiving side of the close-notification channel, passed to `read_resp`.
    close_rx: MAsyncRx<()>,
    /// Set when the handle is dropped, a close notification arrives, or an I/O error
    /// occurs; `send_task` rejects new tasks once it is set.
    closed: AtomicBool,
    /// Task timer, mutated through `&self` via `get_timer_mut`.
    timer: UnsafeCell<ClientTaskTimer<F>>,
    /// Fatal-error flag; the receive loop's future returns an error once it observes it.
    has_err: AtomicBool,
    /// In-flight task accounting and back-pressure (`throttle`, `nearly_full`).
    throttler: Throttler,
    /// Optional shared timestamp, updated by `recv_some` after each received response.
    last_resp_ts: Option<Arc<AtomicU64>>,
    /// Scratch buffer reused by `ReqHead::encode` / `ReqHead::encode_ping`.
    encode_buf: UnsafeCell<Vec<u8>>,
    codec: F::Codec,
    logger: Arc<LogFilter>,
    facts: Arc<F>,
}
180
// SAFETY — NOTE(review): not established by this file alone; confirm.
// `ClientStreamInner` contains `UnsafeCell`s (`timer`, `encode_buf`) that are mutated
// through `&self`. On the send side they are reached via `ClientStream` methods taking
// `&mut self`, which serializes those callers. The receive loop spawned in
// `ClientStream::new` accesses `timer` concurrently with them, so soundness assumes
// `ClientTaskTimer` tolerates that split. It also assumes `P`, `F::Codec` and the tasks
// it holds are safe to move and share across threads.
unsafe impl<F: ClientFacts, P: ClientTransport> Send for ClientStreamInner<F, P> {}

unsafe impl<F: ClientFacts, P: ClientTransport> Sync for ClientStreamInner<F, P> {}
184
185impl<F: ClientFacts, P: ClientTransport> fmt::Debug for ClientStreamInner<F, P> {
186 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
187 self.conn.fmt(f)
188 }
189}
190
191impl<F: ClientFacts, P: ClientTransport> ClientStreamInner<F, P> {
192 pub fn new(
193 facts: Arc<F>, conn: P, client_id: u64, conn_id: String, close_rx: MAsyncRx<()>,
194 last_resp_ts: Option<Arc<AtomicU64>>,
195 ) -> Self {
196 let config = facts.get_config();
197 let mut thresholds = config.thresholds;
198 if thresholds == 0 {
199 thresholds = 128;
200 }
201 let client_inner = Self {
202 client_id,
203 conn,
204 close_rx,
205 closed: AtomicBool::new(false),
206 seq: AtomicU64::new(1),
207 encode_buf: UnsafeCell::new(Vec::with_capacity(1024)),
208 throttler: Throttler::new(thresholds),
209 last_resp_ts,
210 has_err: AtomicBool::new(false),
211 codec: F::Codec::default(),
212 logger: facts.new_logger(),
213 timer: UnsafeCell::new(ClientTaskTimer::new(conn_id, config.task_timeout, thresholds)),
214 facts,
215 };
216 logger_trace!(client_inner.logger, "{:?} throttler is set to {}", client_inner, thresholds,);
217 client_inner
218 }
219
220 #[inline(always)]
221 fn get_timer_mut(&self) -> &mut ClientTaskTimer<F> {
222 unsafe { transmute(self.timer.get()) }
223 }
224
225 #[inline(always)]
226 fn get_encoded_buf(&self) -> &mut Vec<u8> {
227 unsafe { transmute(self.encode_buf.get()) }
228 }
229
230 async fn send_task(&self, mut task: F::Task, mut need_flush: bool) -> Result<(), RpcIntErr> {
232 if self.throttler.nearly_full() {
233 need_flush = true;
234 }
235 let timer = self.get_timer_mut();
236 timer.pending_task_count_ref().fetch_add(1, Ordering::SeqCst);
237 if self.closed.load(Ordering::Acquire) {
240 logger_warn!(
241 self.logger,
242 "{:?} sending task {:?} failed: {}",
243 self,
244 task,
245 RpcIntErr::IO,
246 );
247 task.set_rpc_error(RpcIntErr::IO);
248 self.facts.error_handle(task);
249 timer.pending_task_count_ref().fetch_sub(1, Ordering::SeqCst); return Err(RpcIntErr::IO);
251 }
252 match self.send_request(task, need_flush).await {
253 Err(_) => {
254 self.closed.store(true, Ordering::SeqCst);
255 self.has_err.store(true, Ordering::SeqCst);
256 return Err(RpcIntErr::IO);
257 }
258 Ok(_) => {
259 self.throttler.throttle().await;
261 return Ok(());
262 }
263 }
264 }
265
266 #[inline(always)]
267 async fn flush_req(&self) -> Result<(), RpcIntErr> {
268 if let Err(e) = self.conn.flush_req::<F>(&self.logger).await {
269 logger_warn!(self.logger, "{:?} flush_req flush err: {}", self, e);
270 self.closed.store(true, Ordering::SeqCst);
271 self.has_err.store(true, Ordering::SeqCst);
272 let timer = self.get_timer_mut();
273 timer.stop_reg_task();
274 return Err(RpcIntErr::IO);
275 }
276 Ok(())
277 }
278
279 #[inline(always)]
280 async fn send_request(&self, mut task: F::Task, need_flush: bool) -> Result<(), RpcIntErr> {
281 let seq = self.seq_update();
282 task.set_seq(seq);
283 let buf = self.get_encoded_buf();
284 match proto::ReqHead::encode(&self.codec, buf, self.client_id, &task) {
285 Err(_) => {
286 logger_warn!(&self.logger, "{:?} send_req encode req {:?} err", self, task);
287 return Err(RpcIntErr::Encode);
288 }
289 Ok(blob_buf) => {
290 if let Err(e) =
291 self.conn.write_req::<F>(&self.logger, buf, blob_buf, need_flush).await
292 {
293 logger_warn!(
294 self.logger,
295 "{:?} send_req write req {:?} err: {:?}",
296 self,
297 task,
298 e
299 );
300 self.closed.store(true, Ordering::SeqCst);
301 self.has_err.store(true, Ordering::SeqCst);
302 let timer = self.get_timer_mut();
303 timer.pending_task_count_ref().fetch_sub(1, Ordering::SeqCst);
306 timer.stop_reg_task();
307 logger_warn!(self.logger, "{:?} sending task {:?} err: {}", self, task, e);
308 task.set_rpc_error(RpcIntErr::IO);
309 self.facts.error_handle(task);
310 return Err(RpcIntErr::IO);
311 } else {
312 let wg = self.throttler.add_task();
313 let timer = self.get_timer_mut();
314 logger_trace!(self.logger, "{:?} send task {:?} ok", self, task);
315 timer.reg_task(task, wg).await;
316 }
317 return Ok(());
318 }
319 }
320 }
321
322 #[inline(always)]
323 async fn send_ping_req(&self) -> Result<(), RpcIntErr> {
324 if self.closed.load(Ordering::Acquire) {
325 logger_warn!(self.logger, "{:?} send_ping_req skip as conn closed", self);
326 return Err(RpcIntErr::IO);
327 }
328 let buf = self.get_encoded_buf();
330 proto::ReqHead::encode_ping(buf, self.client_id, self.seq_update());
331 if let Err(e) = self.conn.write_req::<F>(&self.logger, buf, None, true).await {
334 logger_warn!(self.logger, "{:?} send ping err: {:?}", self, e);
335 self.closed.store(true, Ordering::SeqCst);
336 return Err(RpcIntErr::IO);
337 }
338 Ok(())
339 }
340
341 async fn recv_some(&self) -> Result<(), RpcIntErr> {
343 for _ in 0i32..20 {
344 match self.recv_one_resp().await {
347 Err(e) => {
348 return Err(e);
349 }
350 Ok(_) => {
351 if let Some(last_resp_ts) = self.last_resp_ts.as_ref() {
352 last_resp_ts.store(DelayedTime::get(), Ordering::Relaxed);
353 }
354 }
355 }
356 }
357 Ok(())
358 }
359
360 async fn recv_one_resp(&self) -> Result<(), RpcIntErr> {
361 let timer = self.get_timer_mut();
362 loop {
363 if self.closed.load(Ordering::Acquire) {
364 logger_trace!(self.logger, "{:?} read_resp from already close", self.conn);
365 if timer.check_pending_tasks_empty() || self.has_err.load(Ordering::Relaxed) {
367 return Err(RpcIntErr::IO);
368 }
369 if let Err(_e) = self
370 .conn
371 .read_resp(self.facts.as_ref(), &self.logger, &self.codec, None, timer)
372 .await
373 {
374 self.closed.store(true, Ordering::SeqCst);
375 return Err(RpcIntErr::IO);
376 }
377 } else {
378 match self
380 .conn
381 .read_resp(
382 self.facts.as_ref(),
383 &self.logger,
384 &self.codec,
385 Some(&self.close_rx),
386 timer,
387 )
388 .await
389 {
390 Err(_e) => {
391 return Err(RpcIntErr::IO);
392 }
393 Ok(r) => {
394 if !r {
396 self.closed.store(true, Ordering::SeqCst);
397 continue;
398 }
399 }
400 }
401 }
402 }
403 }
404
405 async fn receive_loop(&self) {
406 let mut tick = <F as AsyncTime>::tick(Duration::from_secs(1));
407 loop {
408 let f = self.recv_some();
409 pin_mut!(f);
410 let selector = ReceiverTimerFuture::new(self, &mut tick, &mut f);
411 match selector.await {
412 Ok(_) => {}
413 Err(e) => {
414 logger_debug!(self.logger, "{:?} receive_loop error: {}", self, e);
415 self.closed.store(true, Ordering::SeqCst);
416 let timer = self.get_timer_mut();
417 timer.clean_pending_tasks(self.facts.as_ref());
418 while timer.pending_task_count_ref().load(Ordering::SeqCst) > 0 {
420 timer.clean_pending_tasks(self.facts.as_ref());
424 <F as AsyncTime>::sleep(Duration::from_secs(1)).await;
425 }
426 return;
427 }
428 }
429 }
430 }
431
432 fn time_reach(&self) {
434 logger_trace!(
435 self.logger,
436 "{:?} has {} pending_tasks",
437 self,
438 self.throttler.get_inflight_count()
439 );
440 let timer = self.get_timer_mut();
441 timer.adjust_task_queue(self.facts.as_ref());
442 return;
443 }
444
445 #[inline(always)]
446 fn seq_update(&self) -> u64 {
447 self.seq.fetch_add(1, Ordering::SeqCst)
448 }
449}
450
451impl<F: ClientFacts, P: ClientTransport> Drop for ClientStreamInner<F, P> {
452 fn drop(&mut self) {
453 let timer = self.get_timer_mut();
454 timer.clean_pending_tasks(self.facts.as_ref());
455 }
456}
457
458struct ReceiverTimerFuture<'a, F, P, I, FR>
459where
460 F: ClientFacts,
461 P: ClientTransport,
462 I: TimeInterval,
463 FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
464{
465 client: &'a ClientStreamInner<F, P>,
466 inv: Pin<&'a mut I>,
467 recv_future: Pin<&'a mut FR>,
468}
469
470impl<'a, F, P, I, FR> ReceiverTimerFuture<'a, F, P, I, FR>
471where
472 F: ClientFacts,
473 P: ClientTransport,
474 I: TimeInterval,
475 FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
476{
477 fn new(client: &'a ClientStreamInner<F, P>, inv: &'a mut I, recv_future: &'a mut FR) -> Self {
478 Self { inv: Pin::new(inv), client, recv_future: Pin::new(recv_future) }
479 }
480}
481
482impl<'a, F, P, I, FR> Future for ReceiverTimerFuture<'a, F, P, I, FR>
486where
487 F: ClientFacts,
488 P: ClientTransport,
489 I: TimeInterval,
490 FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
491{
492 type Output = Result<(), RpcIntErr>;
493
494 fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
495 let mut _self = self.get_mut();
496 while let Poll::Ready(_) = _self.inv.as_mut().poll_tick(ctx) {
498 _self.client.time_reach();
499 }
500 if _self.client.has_err.load(Ordering::Relaxed) {
501 return Poll::Ready(Err(RpcIntErr::IO));
504 }
505 _self.client.get_timer_mut().poll_sent_task(ctx);
506 if let Poll::Ready(r) = _self.recv_future.as_mut().poll(ctx) {
508 return Poll::Ready(r);
509 }
510 return Poll::Pending;
511 }
512}