1use super::throttler::Throttler;
11use crate::client::task::ClientTaskDone;
12use crate::client::timer::ClientTaskTimer;
13use crate::{client::*, proto};
14use captains_log::filter::LogFilter;
15use crossfire::null::CloseHandle;
16use futures_util::pin_mut;
17use orb::prelude::*;
18use std::time::Duration;
19use std::{
20 cell::UnsafeCell,
21 fmt,
22 future::Future,
23 mem::transmute,
24 pin::Pin,
25 sync::{
26 Arc,
27 atomic::{AtomicBool, AtomicU64, Ordering},
28 },
29 task::{Context, Poll},
30};
31
/// User-facing handle to one client connection.
///
/// The connection state lives in a shared [`ClientStreamInner`]; `ClientStream::connect`
/// clones that `Arc` into a detached receive-loop task, so the inner state outlives this
/// handle until that task finishes.
pub struct ClientStream<F: ClientFacts, P: ClientTransport> {
    /// Sending half of the close channel whose receiving half is held by the inner state.
    /// Taken (and thereby dropped) in `Drop`; presumably observed by the receive loop as
    /// the close signal — confirm against `read_resp`.
    close_tx: Option<CloseHandle<mpsc::Null>>,
    /// Shared connection state, also referenced by the receive loop.
    inner: Arc<ClientStreamInner<F, P>>,
}
44
45impl<F: ClientFacts, P: ClientTransport> ClientStream<F, P> {
46 #[inline]
48 pub async fn connect<RT: orb::AsyncRuntime>(
49 facts: Arc<F>, rt: &RT, addr: &str, conn_id: &str, last_resp_ts: Option<Arc<AtomicU64>>,
50 ) -> Result<Self, RpcIntErr> {
51 let client_id = facts.get_client_id();
52 let conn = P::connect(addr, conn_id, facts.get_config()).await?;
53 let this = Self::new(facts, conn, client_id, conn_id.to_string(), last_resp_ts);
54 let inner = this.inner.clone();
55 rt.spawn_detach(async move {
56 inner.receive_loop::<RT>().await;
57 });
58 Ok(this)
59 }
60
61 #[inline]
62 fn new(
63 facts: Arc<F>, conn: P, client_id: u64, conn_id: String,
64 last_resp_ts: Option<Arc<AtomicU64>>,
65 ) -> Self {
66 let (_close_tx, _close_rx) = mpsc::new::<mpsc::Null, _, _>();
67 let inner = Arc::new(ClientStreamInner::new(
68 facts,
69 conn,
70 client_id,
71 conn_id,
72 _close_rx,
73 last_resp_ts,
74 ));
75 logger_debug!(inner.logger, "{:?} connected", inner);
76 Self { close_tx: Some(_close_tx), inner }
77 }
78
79 #[inline]
80 pub fn get_codec(&self) -> &F::Codec {
81 &self.inner.codec
82 }
83
84 #[inline(always)]
88 pub async fn ping(&mut self) -> Result<(), RpcIntErr> {
89 self.inner.send_ping_req().await
90 }
91
92 #[inline(always)]
93 pub fn get_last_resp_ts(&self) -> u64 {
94 if let Some(ts) = self.inner.last_resp_ts.as_ref() { ts.load(Ordering::Relaxed) } else { 0 }
95 }
96
97 #[inline(always)]
99 pub fn is_closed(&self) -> bool {
100 self.inner.closed.load(Ordering::SeqCst)
101 }
102
103 pub async fn set_error_and_exit(&mut self) {
108 self.inner.has_err.store(true, Ordering::SeqCst);
110 self.inner.conn.close_conn::<F>(&self.inner.logger).await;
111 }
112
113 #[inline(always)]
121 pub async fn send_task(&mut self, task: F::Task, need_flush: bool) -> Result<(), RpcIntErr> {
122 self.inner.send_task(task, need_flush).await
123 }
124
125 #[inline(always)]
128 pub async fn flush_req(&mut self) -> Result<(), RpcIntErr> {
129 self.inner.flush_req().await
130 }
131
132 #[inline]
134 pub fn will_block(&self) -> bool {
135 self.inner.throttler.nearly_full()
136 }
137
138 #[inline]
140 pub fn get_inflight_count(&self) -> usize {
141 self.inner.throttler.get_inflight_count()
143 }
144}
145
146impl<F: ClientFacts, P: ClientTransport> Drop for ClientStream<F, P> {
147 fn drop(&mut self) {
148 self.close_tx.take();
149 let timer = self.inner.get_timer_mut();
150 timer.stop_reg_task();
151 self.inner.closed.store(true, Ordering::SeqCst);
152 }
153}
154
155impl<F: ClientFacts, P: ClientTransport> fmt::Debug for ClientStream<F, P> {
156 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
157 self.inner.fmt(f)
158 }
159}
160
/// Connection state shared (through `Arc`) between the user-facing `ClientStream` and the
/// detached receive loop.
///
/// Several fields are `UnsafeCell`s that are mutated through `&self` via the
/// `get_*_mut`-style accessors; see the `Sync` impl for the access pattern.
struct ClientStreamInner<F: ClientFacts, P: ClientTransport> {
    /// Client identifier passed to the request-header encoders.
    client_id: u64,
    /// Underlying transport connection.
    conn: P,
    /// Sequence counter; `seq_update` returns the current value and increments it.
    seq: AtomicU64,
    /// Receiving half of the close channel, handed to `read_resp` while the connection is open.
    close_rx: UnsafeCell<AsyncRx<mpsc::Null>>,
    /// Set once the connection is closing/closed; send paths refuse new work when set.
    closed: AtomicBool,
    /// Task timer tracking registered and pending tasks; accessed from both the send side and
    /// the receive loop.
    timer: UnsafeCell<ClientTaskTimer<F>>,
    /// Set on unrecoverable errors; makes the receive loop stop.
    has_err: AtomicBool,
    /// In-flight limiter (threshold from config, 128 when configured as 0).
    throttler: Throttler,
    /// Optional shared slot for the timestamp of the last received response.
    last_resp_ts: Option<Arc<AtomicU64>>,
    /// Scratch buffer reused for encoding request headers.
    encode_buf: UnsafeCell<Vec<u8>>,
    /// Codec used for request/response encoding.
    codec: F::Codec,
    /// Per-connection logger created from `facts`.
    logger: Arc<LogFilter>,
    /// Client facts: config, timestamps, error handling for failed tasks.
    facts: Arc<F>,
}
179
// SAFETY(review): this overrides auto-trait inference, which would only make the type `Send`
// if every field type (including `P`, `F::Codec`, and the types inside the `UnsafeCell`s)
// were `Send`. No such bounds are stated here, so soundness relies on the concrete `F`/`P`
// types being safe to move across threads. NOTE(review): not enforced by the compiler — confirm.
unsafe impl<F: ClientFacts, P: ClientTransport> Send for ClientStreamInner<F, P> {}
181
// SAFETY(review): `UnsafeCell` fields make this type `!Sync` by default. Shared references hand
// out `&mut` into them via `get_timer_mut` / `get_close_rx` / `get_encoded_buf`. Visible pattern:
// - `encode_buf`: used only by `send_request` / `send_ping_req`, which are reached from
//   `ClientStream` methods taking `&mut self` (one sender at a time per handle).
// - `close_rx`: used only by `recv_one_resp` on the receive loop.
// - `timer`: reached from the send side, the receive loop, and `ClientStream::drop`,
//   potentially concurrently. NOTE(review): soundness depends on `ClientTaskTimer`'s internals
//   tolerating this; not verifiable from this file.
unsafe impl<F: ClientFacts, P: ClientTransport> Sync for ClientStreamInner<F, P> {}
183
184impl<F: ClientFacts, P: ClientTransport> fmt::Debug for ClientStreamInner<F, P> {
185 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
186 self.conn.fmt(f)
187 }
188}
189
190impl<F: ClientFacts, P: ClientTransport> ClientStreamInner<F, P> {
191 pub fn new(
192 facts: Arc<F>, conn: P, client_id: u64, conn_id: String, close_rx: AsyncRx<mpsc::Null>,
193 last_resp_ts: Option<Arc<AtomicU64>>,
194 ) -> Self {
195 let config = facts.get_config();
196 let mut thresholds = config.thresholds;
197 if thresholds == 0 {
198 thresholds = 128;
199 }
200 let client_inner = Self {
201 client_id,
202 conn,
203 close_rx: UnsafeCell::new(close_rx),
204 closed: AtomicBool::new(false),
205 seq: AtomicU64::new(1),
206 encode_buf: UnsafeCell::new(Vec::with_capacity(1024)),
207 throttler: Throttler::new(thresholds),
208 last_resp_ts,
209 has_err: AtomicBool::new(false),
210 codec: F::Codec::default(),
211 logger: facts.new_logger(),
212 timer: UnsafeCell::new(ClientTaskTimer::new(conn_id, config.task_timeout, thresholds)),
213 facts,
214 };
215 logger_trace!(client_inner.logger, "{:?} throttler is set to {}", client_inner, thresholds,);
216 client_inner
217 }
218
219 #[inline(always)]
220 fn get_timer_mut(&self) -> &mut ClientTaskTimer<F> {
221 unsafe { transmute(self.timer.get()) }
222 }
223
224 #[inline(always)]
225 fn get_close_rx(&self) -> &mut AsyncRx<mpsc::Null> {
226 unsafe { transmute(self.close_rx.get()) }
227 }
228
229 #[inline(always)]
230 fn get_encoded_buf(&self) -> &mut Vec<u8> {
231 unsafe { transmute(self.encode_buf.get()) }
232 }
233
234 async fn send_task(&self, mut task: F::Task, mut need_flush: bool) -> Result<(), RpcIntErr> {
236 if self.throttler.nearly_full() {
237 need_flush = true;
238 }
239 let timer = self.get_timer_mut();
240 timer.pending_task_count_ref().fetch_add(1, Ordering::SeqCst);
241 if self.closed.load(Ordering::Acquire) {
244 logger_warn!(
245 self.logger,
246 "{:?} sending task {:?} failed: {}",
247 self,
248 task,
249 RpcIntErr::IO,
250 );
251 task.set_rpc_error(RpcIntErr::IO);
252 self.facts.error_handle(task);
253 timer.pending_task_count_ref().fetch_sub(1, Ordering::SeqCst); return Err(RpcIntErr::IO);
255 }
256 match self.send_request(task, need_flush).await {
257 Err(_) => {
258 self.closed.store(true, Ordering::SeqCst);
259 self.has_err.store(true, Ordering::SeqCst);
260 return Err(RpcIntErr::IO);
261 }
262 Ok(_) => {
263 self.throttler.throttle().await;
265 return Ok(());
266 }
267 }
268 }
269
270 #[inline(always)]
271 async fn flush_req(&self) -> Result<(), RpcIntErr> {
272 if let Err(e) = self.conn.flush_req::<F>(&self.logger).await {
273 logger_warn!(self.logger, "{:?} flush_req flush err: {}", self, e);
274 self.closed.store(true, Ordering::SeqCst);
275 self.has_err.store(true, Ordering::SeqCst);
276 let timer = self.get_timer_mut();
277 timer.stop_reg_task();
278 return Err(RpcIntErr::IO);
279 }
280 Ok(())
281 }
282
283 #[inline(always)]
284 async fn send_request(&self, mut task: F::Task, need_flush: bool) -> Result<(), RpcIntErr> {
285 let seq = self.seq_update();
286 task.set_seq(seq);
287 let buf = self.get_encoded_buf();
288 match proto::ReqHead::encode(&self.codec, buf, self.client_id, &task) {
289 Err(_) => {
290 logger_warn!(&self.logger, "{:?} send_req encode req {:?} err", self, task);
291 return Err(RpcIntErr::Encode);
292 }
293 Ok(blob_buf) => {
294 if let Err(e) =
295 self.conn.write_req::<F>(&self.logger, buf, blob_buf, need_flush).await
296 {
297 logger_warn!(
298 self.logger,
299 "{:?} send_req write req {:?} err: {:?}",
300 self,
301 task,
302 e
303 );
304 self.closed.store(true, Ordering::SeqCst);
305 self.has_err.store(true, Ordering::SeqCst);
306 let timer = self.get_timer_mut();
307 timer.pending_task_count_ref().fetch_sub(1, Ordering::SeqCst);
310 timer.stop_reg_task();
311 logger_warn!(self.logger, "{:?} sending task {:?} err: {}", self, task, e);
312 task.set_rpc_error(RpcIntErr::IO);
313 self.facts.error_handle(task);
314 return Err(RpcIntErr::IO);
315 } else {
316 let wg = self.throttler.add_task();
317 let timer = self.get_timer_mut();
318 logger_trace!(self.logger, "{:?} send task {:?} ok", self, task);
319 timer.reg_task(task, wg).await;
320 }
321 return Ok(());
322 }
323 }
324 }
325
326 #[inline(always)]
327 async fn send_ping_req(&self) -> Result<(), RpcIntErr> {
328 if self.closed.load(Ordering::Acquire) {
329 logger_warn!(self.logger, "{:?} send_ping_req skip as conn closed", self);
330 return Err(RpcIntErr::IO);
331 }
332 let buf = self.get_encoded_buf();
334 proto::ReqHead::encode_ping(buf, self.client_id, self.seq_update());
335 if let Err(e) = self.conn.write_req::<F>(&self.logger, buf, None, true).await {
338 logger_warn!(self.logger, "{:?} send ping err: {:?}", self, e);
339 self.closed.store(true, Ordering::SeqCst);
340 return Err(RpcIntErr::IO);
341 }
342 Ok(())
343 }
344
345 async fn recv_some(&self) -> Result<(), RpcIntErr> {
347 for _ in 0i32..20 {
348 match self.recv_one_resp().await {
351 Err(e) => {
352 return Err(e);
353 }
354 Ok(_) => {
355 if let Some(last_resp_ts) = self.last_resp_ts.as_ref() {
356 last_resp_ts.store(self.facts.get_timestamp(), Ordering::Release);
357 }
358 }
359 }
360 }
361 Ok(())
362 }
363
364 async fn recv_one_resp(&self) -> Result<(), RpcIntErr> {
365 let timer = self.get_timer_mut();
366 loop {
367 if self.closed.load(Ordering::Acquire) {
368 logger_trace!(self.logger, "{:?} read_resp from already close", self.conn);
369 if timer.check_pending_tasks_empty() || self.has_err.load(Ordering::Relaxed) {
371 return Err(RpcIntErr::IO);
372 }
373 if let Err(_e) = self
375 .conn
376 .read_resp(self.facts.as_ref(), &self.logger, &self.codec, None, timer)
377 .await
378 {
379 self.closed.store(true, Ordering::SeqCst);
380 return Err(RpcIntErr::IO);
381 }
382 } else {
383 match self
387 .conn
388 .read_resp(
389 self.facts.as_ref(),
390 &self.logger,
391 &self.codec,
392 Some(self.get_close_rx()),
393 timer,
394 )
395 .await
396 {
397 Err(_e) => {
398 return Err(RpcIntErr::IO);
399 }
400 Ok(r) => {
401 if !r {
403 self.closed.store(true, Ordering::SeqCst);
404 continue;
405 }
406 }
407 }
408 }
409 }
410 }
411
412 async fn receive_loop<RT: AsyncRuntime>(&self) {
413 let mut tick = <RT as AsyncTime>::interval(Duration::from_secs(1));
414 loop {
415 let f = self.recv_some();
416 pin_mut!(f);
417 let selector = ReceiverTimerFuture::new(self, &mut tick, &mut f);
418 match selector.await {
419 Ok(_) => {}
420 Err(e) => {
421 logger_debug!(self.logger, "{:?} receive_loop error: {}", self, e);
422 self.closed.store(true, Ordering::SeqCst);
423 let timer = self.get_timer_mut();
424 timer.clean_pending_tasks(self.facts.as_ref());
425 while timer.pending_task_count_ref().load(Ordering::SeqCst) > 0 {
427 timer.clean_pending_tasks(self.facts.as_ref());
431 <RT as AsyncTime>::sleep(Duration::from_secs(1)).await;
432 }
433 return;
434 }
435 }
436 }
437 }
438
439 fn time_reach(&self) {
441 logger_trace!(
442 self.logger,
443 "{:?} has {} pending_tasks",
444 self,
445 self.throttler.get_inflight_count()
446 );
447 let timer = self.get_timer_mut();
448 timer.adjust_task_queue(self.facts.as_ref());
449 return;
450 }
451
452 #[inline(always)]
453 fn seq_update(&self) -> u64 {
454 self.seq.fetch_add(1, Ordering::SeqCst)
455 }
456}
457
458impl<F: ClientFacts, P: ClientTransport> Drop for ClientStreamInner<F, P> {
459 fn drop(&mut self) {
460 let timer = self.get_timer_mut();
461 timer.clean_pending_tasks(self.facts.as_ref());
462 }
463}
464
/// Future used by the receive loop to run one receive future while also servicing a
/// periodic interval.
///
/// Each poll handles ready ticks (calling `time_reach`), bails out with `RpcIntErr::IO` if
/// `has_err` is set, polls the timer's sent-task hook, and finally polls `recv_future`.
struct ReceiverTimerFuture<'a, F, P, I, FR>
where
    F: ClientFacts,
    P: ClientTransport,
    I: TimeInterval,
    FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
{
    /// Shared connection state being serviced.
    client: &'a ClientStreamInner<F, P>,
    /// Housekeeping interval (pinned via `Pin::new`, so `I: Unpin` is required at construction).
    inv: Pin<&'a mut I>,
    /// The receive future whose result this future resolves to.
    recv_future: Pin<&'a mut FR>,
}
476
477impl<'a, F, P, I, FR> ReceiverTimerFuture<'a, F, P, I, FR>
478where
479 F: ClientFacts,
480 P: ClientTransport,
481 I: TimeInterval,
482 FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
483{
484 fn new(client: &'a ClientStreamInner<F, P>, inv: &'a mut I, recv_future: &'a mut FR) -> Self {
485 Self { inv: Pin::new(inv), client, recv_future: Pin::new(recv_future) }
486 }
487}
488
489impl<'a, F, P, I, FR> Future for ReceiverTimerFuture<'a, F, P, I, FR>
493where
494 F: ClientFacts,
495 P: ClientTransport,
496 I: TimeInterval,
497 FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
498{
499 type Output = Result<(), RpcIntErr>;
500
501 fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
502 let mut _self = self.get_mut();
503 while _self.inv.as_mut().poll_tick(ctx).is_ready() {
505 _self.client.time_reach();
506 }
507 if _self.client.has_err.load(Ordering::Relaxed) {
508 return Poll::Ready(Err(RpcIntErr::IO));
511 }
512 _self.client.get_timer_mut().poll_sent_task(ctx);
513 if let Poll::Ready(r) = _self.recv_future.as_mut().poll(ctx) {
515 return Poll::Ready(r);
516 }
517 return Poll::Pending;
518 }
519}