razor_stream/client/
stream.rs

//! [ClientStream] represents a client-side connection.
//!
//! On Drop, the connection will be closed on the write-side. The response reader coroutine will not exit
//! until all the ClientTasks have a response or after `task_timeout` is reached.
//!
//! The user sends packets in sequence, with a throttler controlling the IO depth of in-flight packets.
//! An internal timer then registers the request through a channel, and when the response
//! is received, it can optionally notify the user through a user-defined channel or another mechanism.

use super::throttler::Throttler;
use crate::client::task::ClientTaskDone;
use crate::client::timer::ClientTaskTimer;
use crate::{client::*, proto};
use captains_log::filter::LogFilter;
use crossfire::*;
use futures::pin_mut;
use orb::prelude::*;
use std::time::Duration;
use std::{
    cell::UnsafeCell,
    fmt,
    future::Future,
    mem::transmute,
    pin::Pin,
    sync::{
        Arc,
        atomic::{AtomicBool, AtomicU64, Ordering},
    },
    task::{Context, Poll},
};
use sync_utils::time::DelayedTime;

/// ClientStream represents a client-side connection.
///
/// On Drop, the connection will be closed on the write-side. The response reader coroutine will not exit
/// until all the ClientTasks have a response or after `task_timeout` is reached.
///
/// The user sends packets in sequence, with a throttler controlling the IO depth of in-flight packets.
/// An internal timer then registers the request through a channel, and when the response
/// is received, it can optionally notify the user through a user-defined channel or another mechanism.
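///
/// A minimal usage sketch (assumes user-provided `MyFacts` / `MyTransport` implementations of
/// [ClientFacts] / [ClientTransport]; the address, `conn_id`, and error handling are illustrative):
///
/// ```ignore
/// // Connect, send a batch of tasks, then flush the transport buffer.
/// let mut stream: ClientStream<MyFacts, MyTransport> =
///     ClientStream::connect(facts.clone(), "127.0.0.1:9000", "conn-0", None).await?;
/// for task in tasks {
///     // need_flush = false lets the transport keep buffering requests for batching.
///     stream.send_task(task, false).await?;
/// }
/// // The transport layer might buffer, so flush explicitly once the batch is sent.
/// stream.flush_req().await?;
/// ```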
pub struct ClientStream<F: ClientFacts, P: ClientTransport> {
    close_tx: Option<MTx<()>>,
    inner: Arc<ClientStreamInner<F, P>>,
}

impl<F: ClientFacts, P: ClientTransport> ClientStream<F, P> {
    /// Make a streaming connection to the server, returning a [ClientStream] on success
    #[inline]
    pub fn connect(
        facts: Arc<F>, addr: &str, conn_id: &str, last_resp_ts: Option<Arc<AtomicU64>>,
    ) -> impl Future<Output = Result<Self, RpcIntErr>> + Send {
        async move {
            let client_id = facts.get_client_id();
            let conn = P::connect(addr, conn_id, facts.get_config()).await?;
            Ok(Self::new(facts, conn, client_id, conn_id.to_string(), last_resp_ts))
        }
    }

    #[inline]
    fn new(
        facts: Arc<F>, conn: P, client_id: u64, conn_id: String,
        last_resp_ts: Option<Arc<AtomicU64>>,
    ) -> Self {
        let (_close_tx, _close_rx) = mpmc::unbounded_async::<()>();
        let inner = Arc::new(ClientStreamInner::new(
            facts,
            conn,
            client_id,
            conn_id,
            _close_rx,
            last_resp_ts,
        ));
        logger_debug!(inner.logger, "{:?} connected", inner);
        let _inner = inner.clone();
        inner.facts.spawn_detach(async move {
            _inner.receive_loop().await;
        });
        Self { close_tx: Some(_close_tx), inner }
    }

    #[inline]
    pub fn get_codec(&self) -> &F::Codec {
        &self.inner.codec
    }

    /// Should be called in sender threads
    ///
    /// NOTE: will be skipped if the throttler is full
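    ///
    /// A keepalive sketch (the one-second cadence and the `MyFacts` facts type are illustrative
    /// assumptions, not prescribed by this module):
    ///
    /// ```ignore
    /// // Periodically ping so a connection prober can watch get_last_resp_ts() for liveness.
    /// while !stream.is_closed() {
    ///     let _ = stream.ping().await;
    ///     <MyFacts as AsyncTime>::sleep(Duration::from_secs(1)).await;
    /// }
    /// ```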
    #[inline(always)]
    pub async fn ping(&mut self) -> Result<(), RpcIntErr> {
        self.inner.send_ping_req().await
    }

    #[inline(always)]
    pub fn get_last_resp_ts(&self) -> u64 {
        if let Some(ts) = self.inner.last_resp_ts.as_ref() { ts.load(Ordering::Relaxed) } else { 0 }
    }

    /// Since the sender and receiver run in two threads, the connection might be closed on either side
    #[inline(always)]
    pub fn is_closed(&self) -> bool {
        self.inner.closed.load(Ordering::SeqCst)
    }

    /// Force the receiver to exit.
    ///
    /// You can call it when connectivity probes detect that a server is unreachable.
    pub async fn set_error_and_exit(&mut self) {
        self.inner.has_err.store(true, Ordering::SeqCst);
        self.inner.conn.close_conn::<F>(&self.inner.logger).await;
        if let Some(close_tx) = self.close_tx.as_ref() {
            let _ = close_tx.send(()); // This is equivalent to ClientStream::drop
        }
    }

    /// send_task() should only be called without parallelism.
    ///
    /// NOTE: After sending, this will wait for responses if there are too many in-flight tasks in the throttler.
    ///
    /// Since the transport layer might buffer, the user should always call flush explicitly.
    /// You can set `need_flush` = true for some urgent messages, or call flush_req() explicitly.
    ///
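    /// A sketch of the flush pattern (the task values here are illustrative):
    ///
    /// ```ignore
    /// // Ordinary request: let the transport buffer it for batching.
    /// stream.send_task(ordinary_task, false).await?;
    /// // Urgent request: ask the transport to flush immediately.
    /// stream.send_task(urgent_task, true).await?;
    /// // Otherwise, flush explicitly once the batch is done.
    /// stream.flush_req().await?;
    /// ```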
    #[inline(always)]
    pub async fn send_task(&mut self, task: F::Task, need_flush: bool) -> Result<(), RpcIntErr> {
        self.inner.send_task(task, need_flush).await
    }

    /// Since the transport layer might buffer, the user should always call flush explicitly.
    /// You can set `need_flush` = true for some urgent messages, or call flush_req() explicitly.
    #[inline(always)]
    pub async fn flush_req(&mut self) -> Result<(), RpcIntErr> {
        self.inner.flush_req().await
    }

    /// Check the throttler to see whether a future send_task() call might block
    #[inline]
    pub fn will_block(&self) -> bool {
        self.inner.throttler.nearly_full()
    }

    /// Get the number of tasks that have been sent but have not yet received a response
    #[inline]
    pub fn get_inflight_count(&self) -> usize {
        // TODO: confirm whether ping tasks are counted
        self.inner.throttler.get_inflight_count()
    }
}

impl<F: ClientFacts, P: ClientTransport> Drop for ClientStream<F, P> {
    fn drop(&mut self) {
        self.close_tx.take();
        let timer = self.inner.get_timer_mut();
        timer.stop_reg_task();
        self.inner.closed.store(true, Ordering::SeqCst);
    }
}

impl<F: ClientFacts, P: ClientTransport> fmt::Debug for ClientStream<F, P> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        self.inner.fmt(f)
    }
}

struct ClientStreamInner<F: ClientFacts, P: ClientTransport> {
    client_id: u64,
    conn: P,
    seq: AtomicU64,
    close_rx: MAsyncRx<()>, // When ClientStream (the sender) is dropped, the receiving end is the timer
    closed: AtomicBool,     // flag set by either the sender or the receiver on their exit
    timer: UnsafeCell<ClientTaskTimer<F>>,
    // TODO: can closed and has_err be merged?
    has_err: AtomicBool,
    throttler: Throttler,
    last_resp_ts: Option<Arc<AtomicU64>>,
    encode_buf: UnsafeCell<Vec<u8>>,
    codec: F::Codec,
    logger: Arc<LogFilter>,
    facts: Arc<F>,
}

unsafe impl<F: ClientFacts, P: ClientTransport> Send for ClientStreamInner<F, P> {}

unsafe impl<F: ClientFacts, P: ClientTransport> Sync for ClientStreamInner<F, P> {}

impl<F: ClientFacts, P: ClientTransport> fmt::Debug for ClientStreamInner<F, P> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        self.conn.fmt(f)
    }
}

impl<F: ClientFacts, P: ClientTransport> ClientStreamInner<F, P> {
    pub fn new(
        facts: Arc<F>, conn: P, client_id: u64, conn_id: String, close_rx: MAsyncRx<()>,
        last_resp_ts: Option<Arc<AtomicU64>>,
    ) -> Self {
        let config = facts.get_config();
        let mut thresholds = config.thresholds;
        if thresholds == 0 {
            thresholds = 128;
        }
        let client_inner = Self {
            client_id,
            conn,
            close_rx,
            closed: AtomicBool::new(false),
            seq: AtomicU64::new(1),
            encode_buf: UnsafeCell::new(Vec::with_capacity(1024)),
            throttler: Throttler::new(thresholds),
            last_resp_ts,
            has_err: AtomicBool::new(false),
            codec: F::Codec::default(),
            logger: facts.new_logger(),
            timer: UnsafeCell::new(ClientTaskTimer::new(conn_id, config.task_timeout, thresholds)),
            facts,
        };
        logger_trace!(client_inner.logger, "{:?} throttler is set to {}", client_inner, thresholds,);
        client_inner
    }

    #[inline(always)]
    fn get_timer_mut(&self) -> &mut ClientTaskTimer<F> {
        unsafe { transmute(self.timer.get()) }
    }

    #[inline(always)]
    fn get_encoded_buf(&self) -> &mut Vec<u8> {
        unsafe { transmute(self.encode_buf.get()) }
    }

    /// Works directly on the socket stream; marks the connection as closed and errored on failure
    async fn send_task(&self, mut task: F::Task, mut need_flush: bool) -> Result<(), RpcIntErr> {
        if self.throttler.nearly_full() {
            need_flush = true;
        }
        let timer = self.get_timer_mut();
        timer.pending_task_count_ref().fetch_add(1, Ordering::SeqCst);
        // It's possible the receiver set `closed` after pending_task_count was incremented;
        // keep this check until reviewed
        if self.closed.load(Ordering::Acquire) {
            logger_warn!(
                self.logger,
                "{:?} sending task {:?} failed: {}",
                self,
                task,
                RpcIntErr::IO,
            );
            task.set_rpc_error(RpcIntErr::IO);
            self.facts.error_handle(task);
            timer.pending_task_count_ref().fetch_sub(1, Ordering::SeqCst); // rollback
            return Err(RpcIntErr::IO);
        }
        match self.send_request(task, need_flush).await {
            Err(_) => {
                self.closed.store(true, Ordering::SeqCst);
                self.has_err.store(true, Ordering::SeqCst);
                return Err(RpcIntErr::IO);
            }
            Ok(_) => {
                // register task to notifier
                self.throttler.throttle().await;
                return Ok(());
            }
        }
    }

    #[inline(always)]
    async fn flush_req(&self) -> Result<(), RpcIntErr> {
        if let Err(e) = self.conn.flush_req::<F>(&self.logger).await {
            logger_warn!(self.logger, "{:?} flush_req flush err: {}", self, e);
            self.closed.store(true, Ordering::SeqCst);
            self.has_err.store(true, Ordering::SeqCst);
            let timer = self.get_timer_mut();
            timer.stop_reg_task();
            return Err(RpcIntErr::IO);
        }
        Ok(())
    }

    #[inline(always)]
    async fn send_request(&self, mut task: F::Task, need_flush: bool) -> Result<(), RpcIntErr> {
        let seq = self.seq_update();
        task.set_seq(seq);
        let buf = self.get_encoded_buf();
        match proto::ReqHead::encode(&self.codec, buf, self.client_id, &task) {
            Err(_) => {
                logger_warn!(&self.logger, "{:?} send_req encode req {:?} err", self, task);
                return Err(RpcIntErr::Encode);
            }
            Ok(blob_buf) => {
                if let Err(e) =
                    self.conn.write_req::<F>(&self.logger, buf, blob_buf, need_flush).await
                {
                    logger_warn!(
                        self.logger,
                        "{:?} send_req write req {:?} err: {:?}",
                        self,
                        task,
                        e
                    );
                    self.closed.store(true, Ordering::SeqCst);
                    self.has_err.store(true, Ordering::SeqCst);
                    let timer = self.get_timer_mut();
                    // TODO check stop_reg_task
                    // rollback counter
                    timer.pending_task_count_ref().fetch_sub(1, Ordering::SeqCst);
                    timer.stop_reg_task();
                    logger_warn!(self.logger, "{:?} sending task {:?} err: {}", self, task, e);
                    task.set_rpc_error(RpcIntErr::IO);
                    self.facts.error_handle(task);
                    return Err(RpcIntErr::IO);
                } else {
                    let wg = self.throttler.add_task();
                    let timer = self.get_timer_mut();
                    logger_trace!(self.logger, "{:?} send task {:?} ok", self, task);
                    timer.reg_task(task, wg).await;
                }
                return Ok(());
            }
        }
    }

    #[inline(always)]
    async fn send_ping_req(&self) -> Result<(), RpcIntErr> {
        if self.closed.load(Ordering::Acquire) {
            logger_warn!(self.logger, "{:?} send_ping_req skip as conn closed", self);
            return Err(RpcIntErr::IO);
        }
        // PING is not counted in the throttler
        let buf = self.get_encoded_buf();
        proto::ReqHead::encode_ping(buf, self.client_id, self.seq_update());
        // Ping does not need reg_task and has no error_handle; it just keeps the connection
        // alive. A connection prober can monitor the liveness of ClientConn
        if let Err(e) = self.conn.write_req::<F>(&self.logger, buf, None, true).await {
            logger_warn!(self.logger, "{:?} send ping err: {:?}", self, e);
            self.closed.store(true, Ordering::SeqCst);
            return Err(RpcIntErr::IO);
        }
        Ok(())
    }

    // return Ok(false) when close_rx has closed and there are no more pending responses to receive
    async fn recv_some(&self) -> Result<(), RpcIntErr> {
        for _ in 0i32..20 {
            // The underlying rpc socket is buffered and might not yield to the runtime;
            // return if recv_one_resp runs too long, allowing the timer to fire each second
            match self.recv_one_resp().await {
                Err(e) => {
                    return Err(e);
                }
                Ok(_) => {
                    if let Some(last_resp_ts) = self.last_resp_ts.as_ref() {
                        last_resp_ts.store(DelayedTime::get(), Ordering::Relaxed);
                    }
                }
            }
        }
        Ok(())
    }

    async fn recv_one_resp(&self) -> Result<(), RpcIntErr> {
        let timer = self.get_timer_mut();
        loop {
            if self.closed.load(Ordering::Acquire) {
                logger_trace!(self.logger, "{:?} read_resp from already closed", self.conn);
                // ensure tasks are received on normal exit
                if timer.check_pending_tasks_empty() || self.has_err.load(Ordering::Relaxed) {
                    return Err(RpcIntErr::IO);
                }
                if let Err(_e) = self
                    .conn
                    .read_resp(self.facts.as_ref(), &self.logger, &self.codec, None, timer)
                    .await
                {
                    self.closed.store(true, Ordering::SeqCst);
                    return Err(RpcIntErr::IO);
                }
            } else {
                // Block here for a new header without timeout
                match self
                    .conn
                    .read_resp(
                        self.facts.as_ref(),
                        &self.logger,
                        &self.codec,
                        Some(&self.close_rx),
                        timer,
                    )
                    .await
                {
                    Err(_e) => {
                        return Err(RpcIntErr::IO);
                    }
                    Ok(r) => {
                        // TODO FIXME
                        if !r {
                            self.closed.store(true, Ordering::SeqCst);
                            continue;
                        }
                    }
                }
            }
        }
    }

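    // Reader coroutine: repeatedly drives recv_some() together with a 1-second tick via
    // ReceiverTimerFuture. On error it marks the stream closed and keeps cleaning pending
    // tasks until none remain, then exits.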
    async fn receive_loop(&self) {
        let mut tick = <F as AsyncTime>::tick(Duration::from_secs(1));
        loop {
            let f = self.recv_some();
            pin_mut!(f);
            let selector = ReceiverTimerFuture::new(self, &mut tick, &mut f);
            match selector.await {
                Ok(_) => {}
                Err(e) => {
                    logger_debug!(self.logger, "{:?} receive_loop error: {}", self, e);
                    self.closed.store(true, Ordering::SeqCst);
                    let timer = self.get_timer_mut();
                    timer.clean_pending_tasks(self.facts.as_ref());
                    // If pending_task_count > 0, some tasks may still remain in the pending chan
                    while timer.pending_task_count_ref().load(Ordering::SeqCst) > 0 {
                        // After the 'closed' flag has taken effect,
                        // pending_task_count will not keep growing,
                        // so there is no need to sleep here.
                        timer.clean_pending_tasks(self.facts.as_ref());
                        <F as AsyncTime>::sleep(Duration::from_secs(1)).await;
                    }
                    return;
                }
            }
        }
    }

    // Adjust the waiting queue
    fn time_reach(&self) {
        logger_trace!(
            self.logger,
            "{:?} has {} pending_tasks",
            self,
            self.throttler.get_inflight_count()
        );
        let timer = self.get_timer_mut();
        timer.adjust_task_queue(self.facts.as_ref());
        return;
    }

    #[inline(always)]
    fn seq_update(&self) -> u64 {
        self.seq.fetch_add(1, Ordering::SeqCst)
    }
}

impl<F: ClientFacts, P: ClientTransport> Drop for ClientStreamInner<F, P> {
    fn drop(&mut self) {
        let timer = self.get_timer_mut();
        timer.clean_pending_tasks(self.facts.as_ref());
    }
}

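// Combined future that multiplexes the periodic tick and the response-receiving future:
// each poll first drains ready ticks (adjusting the task queue), then checks for a forced
// error, polls sent-task timeouts, and finally polls the receive future.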
struct ReceiverTimerFuture<'a, F, P, I, FR>
where
    F: ClientFacts,
    P: ClientTransport,
    I: TimeInterval,
    FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
{
    client: &'a ClientStreamInner<F, P>,
    inv: Pin<&'a mut I>,
    recv_future: Pin<&'a mut FR>,
}

impl<'a, F, P, I, FR> ReceiverTimerFuture<'a, F, P, I, FR>
where
    F: ClientFacts,
    P: ClientTransport,
    I: TimeInterval,
    FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
{
    fn new(client: &'a ClientStreamInner<F, P>, inv: &'a mut I, recv_future: &'a mut FR) -> Self {
        Self { inv: Pin::new(inv), client, recv_future: Pin::new(recv_future) }
    }
}

// Return Ok(true) to indicate success
// Return Ok(false) when the client sender has closed normally
// Err(e) on connection error
impl<'a, F, P, I, FR> Future for ReceiverTimerFuture<'a, F, P, I, FR>
where
    F: ClientFacts,
    P: ClientTransport,
    I: TimeInterval,
    FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
{
    type Output = Result<(), RpcIntErr>;

    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
        let mut _self = self.get_mut();
        // In case the ticker did not fire, and to ensure the ticker is scheduled again after Ready
        while let Poll::Ready(_) = _self.inv.as_mut().poll_tick(ctx) {
            _self.client.time_reach();
        }
        if _self.client.has_err.load(Ordering::Relaxed) {
            // When the sentinel detects the peer is unreachable, recv_some might be blocked;
            // at least inv will wake us, so just exit
            return Poll::Ready(Err(RpcIntErr::IO));
        }
        _self.client.get_timer_mut().poll_sent_task(ctx);
        // Even if the receive future is blocked, we should poll_sent_task in order to detect timeout events
        if let Poll::Ready(r) = _self.recv_future.as_mut().poll(ctx) {
            return Poll::Ready(r);
        }
        return Poll::Pending;
    }
}