
razor_stream/client/stream.rs

//! [ClientStream] represents a client-side connection.
//!
//! On Drop, the connection is closed on the write side; the response reader coroutine does not exit
//! until all the ClientTasks have a response or `task_timeout` is reached.
//!
//! The user sends packets in sequence, with a throttler controlling the IO depth of in-flight packets.
//! An internal timer registers each request through a channel, and when the response
//! is received, it can optionally notify the user through a user-defined channel or another mechanism.

use super::throttler::Throttler;
use crate::client::task::ClientTaskDone;
use crate::client::timer::ClientTaskTimer;
use crate::{client::*, proto};
use captains_log::filter::LogFilter;
use crossfire::null::CloseHandle;
use futures::pin_mut;
use orb::prelude::*;
use std::time::Duration;
use std::{
    cell::UnsafeCell,
    fmt,
    future::Future,
    mem::transmute,
    pin::Pin,
    sync::{
        Arc,
        atomic::{AtomicBool, AtomicU64, Ordering},
    },
    task::{Context, Poll},
};
use sync_utils::time::DelayedTime;

/// ClientStream represents a client-side connection.
///
/// On Drop, the connection is closed on the write side. The response reader coroutine does not exit
/// until all the ClientTasks have a response or `task_timeout` is reached.
///
/// The user sends packets in sequence, with a throttler controlling the IO depth of in-flight packets.
/// An internal timer registers each request through a channel, and when the response
/// is received, it can optionally notify the user through a user-defined channel or another mechanism.
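///
/// A minimal usage sketch (illustrative only; `MyFacts` and `MyTransport` stand for
/// user-provided [ClientFacts] and [ClientTransport] implementations and are not part of
/// this module):
///
/// ```ignore
/// let facts = Arc::new(MyFacts::new());
/// let mut stream = ClientStream::<MyFacts, MyTransport>::connect(
///     facts, "127.0.0.1:9000", "conn-1", None).await?;
/// // Queue tasks from a single sender; the throttler bounds the in-flight depth.
/// for task in tasks {
///     stream.send_task(task, false).await?;
/// }
/// // The transport may buffer, so flush explicitly once the batch is queued.
/// stream.flush_req().await?;
/// // Dropping the stream closes the write side; the reader keeps running until all
/// // pending tasks get a response or `task_timeout` is reached.
/// drop(stream);
/// ```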
pub struct ClientStream<F: ClientFacts, P: ClientTransport> {
    close_tx: Option<CloseHandle<mpsc::Null>>,
    inner: Arc<ClientStreamInner<F, P>>,
}

impl<F: ClientFacts, P: ClientTransport> ClientStream<F, P> {
    /// Makes a streaming connection to the server; returns a [ClientStream] on success.
    #[inline]
    pub fn connect(
        facts: Arc<F>, addr: &str, conn_id: &str, last_resp_ts: Option<Arc<AtomicU64>>,
    ) -> impl Future<Output = Result<Self, RpcIntErr>> + Send {
        async move {
            let client_id = facts.get_client_id();
            let conn = P::connect(addr, conn_id, facts.get_config()).await?;
            Ok(Self::new(facts, conn, client_id, conn_id.to_string(), last_resp_ts))
        }
    }

    #[inline]
    fn new(
        facts: Arc<F>, conn: P, client_id: u64, conn_id: String,
        last_resp_ts: Option<Arc<AtomicU64>>,
    ) -> Self {
        let (_close_tx, _close_rx) = mpsc::new::<mpsc::Null, _, _>();
        let inner = Arc::new(ClientStreamInner::new(
            facts,
            conn,
            client_id,
            conn_id,
            _close_rx,
            last_resp_ts,
        ));
        logger_debug!(inner.logger, "{:?} connected", inner);
        let _inner = inner.clone();
        inner.facts.spawn_detach(async move {
            _inner.receive_loop().await;
        });
        Self { close_tx: Some(_close_tx), inner }
    }

    #[inline]
    pub fn get_codec(&self) -> &F::Codec {
        &self.inner.codec
    }

    /// Should be called in sender threads.
    ///
    /// NOTE: the ping is skipped if the throttler is full.
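    ///
    /// A hedged sketch of a sender-side keepalive loop (`some_async_sleep` is illustrative,
    /// not part of this crate):
    ///
    /// ```ignore
    /// loop {
    ///     some_async_sleep(Duration::from_secs(10)).await;
    ///     if stream.ping().await.is_err() {
    ///         break; // connection closed or the write failed
    ///     }
    /// }
    /// ```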
    #[inline(always)]
    pub async fn ping(&mut self) -> Result<(), RpcIntErr> {
        self.inner.send_ping_req().await
    }

    #[inline(always)]
    pub fn get_last_resp_ts(&self) -> u64 {
        if let Some(ts) = self.inner.last_resp_ts.as_ref() { ts.load(Ordering::Relaxed) } else { 0 }
    }

    /// Since the sender and receiver run in two threads, the connection might be closed from either side.
    #[inline(always)]
    pub fn is_closed(&self) -> bool {
        self.inner.closed.load(Ordering::SeqCst)
    }

    /// Force the receiver to exit.
    ///
    /// Call this when connectivity probes detect that the server is unreachable,
    /// and then just let the client drop.
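    ///
    /// A hedged sketch of a probe-driven shutdown (`probe_server_reachable` is illustrative,
    /// not part of this crate):
    ///
    /// ```ignore
    /// if !probe_server_reachable(addr).await {
    ///     stream.set_error_and_exit().await;
    ///     drop(stream); // write side closes; the receiver exits since has_err is set
    /// }
    /// ```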
    pub async fn set_error_and_exit(&mut self) {
        // TODO review usage when doing ConnProbe
        self.inner.has_err.store(true, Ordering::SeqCst);
        self.inner.conn.close_conn::<F>(&self.inner.logger).await;
    }

    /// send_task() should only be called without parallelism (a single sender at a time).
    ///
    /// NOTE: After sending, this will wait for responses if there are too many in-flight tasks in the throttler.
    ///
    /// Since the transport layer might buffer data, the user should always flush explicitly:
    /// set `need_flush` = true for urgent messages, or call flush_req() afterwards.
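    ///
    /// A hedged sketch of the flush pattern (illustrative only):
    ///
    /// ```ignore
    /// // Urgent message: flush immediately.
    /// stream.send_task(urgent_task, true).await?;
    /// // Bulk messages: let the transport batch them, then flush once at the end.
    /// for t in bulk_tasks {
    ///     stream.send_task(t, false).await?;
    /// }
    /// stream.flush_req().await?;
    /// ```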
    #[inline(always)]
    pub async fn send_task(&mut self, task: F::Task, need_flush: bool) -> Result<(), RpcIntErr> {
        self.inner.send_task(task, need_flush).await
    }

    /// Since the transport layer might buffer data, the user should always flush explicitly,
    /// either by setting `need_flush` = true on send_task() for urgent messages or by calling flush_req().
    #[inline(always)]
    pub async fn flush_req(&mut self) -> Result<(), RpcIntErr> {
        self.inner.flush_req().await
    }

    /// Check the throttler to see whether a future send_task() might block.
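    ///
    /// A hedged sketch of how a caller might use this (illustrative only):
    ///
    /// ```ignore
    /// if stream.will_block() {
    ///     // Push buffered requests onto the wire before we park on the throttler.
    ///     stream.flush_req().await?;
    /// }
    /// stream.send_task(task, false).await?;
    /// ```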
    #[inline]
    pub fn will_block(&self) -> bool {
        self.inner.throttler.nearly_full()
    }

    /// Get the number of tasks that have been sent but have not yet received a response.
    #[inline]
    pub fn get_inflight_count(&self) -> usize {
        // TODO: confirm whether ping tasks are counted
        self.inner.throttler.get_inflight_count()
    }
}

impl<F: ClientFacts, P: ClientTransport> Drop for ClientStream<F, P> {
    fn drop(&mut self) {
        self.close_tx.take();
        let timer = self.inner.get_timer_mut();
        timer.stop_reg_task();
        self.inner.closed.store(true, Ordering::SeqCst);
    }
}

impl<F: ClientFacts, P: ClientTransport> fmt::Debug for ClientStream<F, P> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        self.inner.fmt(f)
    }
}

struct ClientStreamInner<F: ClientFacts, P: ClientTransport> {
    client_id: u64,
    conn: P,
    seq: AtomicU64,
    // NOTE: close_rx is an AsyncRx (cheaper waker registration) but is not Sync; to work around
    // `&AsyncRx` not being Send, we hand out a mut ref via UnsafeCell to satisfy the borrow checker.
    close_rx: UnsafeCell<AsyncRx<mpsc::Null>>,
    closed: AtomicBool, // flag set by either the sender or the receiver on their exit
    timer: UnsafeCell<ClientTaskTimer<F>>,
    // TODO: can closed and has_err be merged?
    has_err: AtomicBool,
    throttler: Throttler,
    last_resp_ts: Option<Arc<AtomicU64>>,
    encode_buf: UnsafeCell<Vec<u8>>,
    codec: F::Codec,
    logger: Arc<LogFilter>,
    facts: Arc<F>,
}

unsafe impl<F: ClientFacts, P: ClientTransport> Send for ClientStreamInner<F, P> {}

unsafe impl<F: ClientFacts, P: ClientTransport> Sync for ClientStreamInner<F, P> {}

impl<F: ClientFacts, P: ClientTransport> fmt::Debug for ClientStreamInner<F, P> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        self.conn.fmt(f)
    }
}

impl<F: ClientFacts, P: ClientTransport> ClientStreamInner<F, P> {
    pub fn new(
        facts: Arc<F>, conn: P, client_id: u64, conn_id: String, close_rx: AsyncRx<mpsc::Null>,
        last_resp_ts: Option<Arc<AtomicU64>>,
    ) -> Self {
        let config = facts.get_config();
        let mut thresholds = config.thresholds;
        if thresholds == 0 {
            thresholds = 128;
        }
        let client_inner = Self {
            client_id,
            conn,
            close_rx: UnsafeCell::new(close_rx),
            closed: AtomicBool::new(false),
            seq: AtomicU64::new(1),
            encode_buf: UnsafeCell::new(Vec::with_capacity(1024)),
            throttler: Throttler::new(thresholds),
            last_resp_ts,
            has_err: AtomicBool::new(false),
            codec: F::Codec::default(),
            logger: facts.new_logger(),
            timer: UnsafeCell::new(ClientTaskTimer::new(conn_id, config.task_timeout, thresholds)),
            facts,
        };
        logger_trace!(client_inner.logger, "{:?} throttler is set to {}", client_inner, thresholds);
        client_inner
    }

    #[inline(always)]
    fn get_timer_mut(&self) -> &mut ClientTaskTimer<F> {
        unsafe { transmute(self.timer.get()) }
    }

    #[inline(always)]
    fn get_close_rx(&self) -> &mut AsyncRx<mpsc::Null> {
        unsafe { transmute(self.close_rx.get()) }
    }

    #[inline(always)]
    fn get_encoded_buf(&self) -> &mut Vec<u8> {
        unsafe { transmute(self.encode_buf.get()) }
    }

    /// Works directly on the socket stream; on failure, the task is handed to error_handle()
    /// and the connection is marked closed.
    async fn send_task(&self, mut task: F::Task, mut need_flush: bool) -> Result<(), RpcIntErr> {
        if self.throttler.nearly_full() {
            need_flush = true;
        }
        let timer = self.get_timer_mut();
        timer.pending_task_count_ref().fetch_add(1, Ordering::SeqCst);
        // It's possible the receiver sets closed after pending_task_count is incremented;
        // keep this check until reviewed.
        if self.closed.load(Ordering::Acquire) {
            logger_warn!(
                self.logger,
                "{:?} sending task {:?} failed: {}",
                self,
                task,
                RpcIntErr::IO,
            );
            task.set_rpc_error(RpcIntErr::IO);
            self.facts.error_handle(task);
            timer.pending_task_count_ref().fetch_sub(1, Ordering::SeqCst); // rollback
            return Err(RpcIntErr::IO);
        }
        match self.send_request(task, need_flush).await {
            Err(_) => {
                self.closed.store(true, Ordering::SeqCst);
                self.has_err.store(true, Ordering::SeqCst);
                return Err(RpcIntErr::IO);
            }
            Ok(_) => {
                // task was registered to the timer in send_request(); wait on the throttler
                self.throttler.throttle().await;
                return Ok(());
            }
        }
    }

    #[inline(always)]
    async fn flush_req(&self) -> Result<(), RpcIntErr> {
        if let Err(e) = self.conn.flush_req::<F>(&self.logger).await {
            logger_warn!(self.logger, "{:?} flush_req flush err: {}", self, e);
            self.closed.store(true, Ordering::SeqCst);
            self.has_err.store(true, Ordering::SeqCst);
            let timer = self.get_timer_mut();
            timer.stop_reg_task();
            return Err(RpcIntErr::IO);
        }
        Ok(())
    }

    #[inline(always)]
    async fn send_request(&self, mut task: F::Task, need_flush: bool) -> Result<(), RpcIntErr> {
        let seq = self.seq_update();
        task.set_seq(seq);
        let buf = self.get_encoded_buf();
        match proto::ReqHead::encode(&self.codec, buf, self.client_id, &task) {
            Err(_) => {
                logger_warn!(&self.logger, "{:?} send_req encode req {:?} err", self, task);
                return Err(RpcIntErr::Encode);
            }
            Ok(blob_buf) => {
                if let Err(e) =
                    self.conn.write_req::<F>(&self.logger, buf, blob_buf, need_flush).await
                {
                    logger_warn!(
                        self.logger,
                        "{:?} send_req write req {:?} err: {:?}",
                        self,
                        task,
                        e
                    );
                    self.closed.store(true, Ordering::SeqCst);
                    self.has_err.store(true, Ordering::SeqCst);
                    let timer = self.get_timer_mut();
                    // TODO check stop_reg_task
                    // rollback counter
                    timer.pending_task_count_ref().fetch_sub(1, Ordering::SeqCst);
                    timer.stop_reg_task();
                    logger_warn!(self.logger, "{:?} sending task {:?} err: {}", self, task, e);
                    task.set_rpc_error(RpcIntErr::IO);
                    self.facts.error_handle(task);
                    return Err(RpcIntErr::IO);
                } else {
                    let wg = self.throttler.add_task();
                    let timer = self.get_timer_mut();
                    logger_trace!(self.logger, "{:?} send task {:?} ok", self, task);
                    timer.reg_task(task, wg).await;
                }
                return Ok(());
            }
        }
    }

    #[inline(always)]
    async fn send_ping_req(&self) -> Result<(), RpcIntErr> {
        if self.closed.load(Ordering::Acquire) {
            logger_warn!(self.logger, "{:?} send_ping_req skip as conn closed", self);
            return Err(RpcIntErr::IO);
        }
        // PING is not counted in the throttler
        let buf = self.get_encoded_buf();
        proto::ReqHead::encode_ping(buf, self.client_id, self.seq_update());
        // Ping does not need reg_task and has no error_handle; it just keeps the connection
        // alive. The connection prober can monitor the liveness of ClientConn.
        if let Err(e) = self.conn.write_req::<F>(&self.logger, buf, None, true).await {
            logger_warn!(self.logger, "{:?} send ping err: {:?}", self, e);
            self.closed.store(true, Ordering::SeqCst);
            return Err(RpcIntErr::IO);
        }
        Ok(())
    }

    // Receive a batch of responses; returns Err(e) on connection error or when the connection
    // has been closed and there is nothing more pending to receive.
    async fn recv_some(&self) -> Result<(), RpcIntErr> {
        for _ in 0i32..20 {
            // The underlying rpc socket is buffered and might not yield to the runtime.
            // Return if recv_one_resp runs too long, allowing the timer to fire every second.
            match self.recv_one_resp().await {
                Err(e) => {
                    return Err(e);
                }
                Ok(_) => {
                    if let Some(last_resp_ts) = self.last_resp_ts.as_ref() {
                        last_resp_ts.store(DelayedTime::get(), Ordering::Relaxed);
                    }
                }
            }
        }
        Ok(())
    }

    async fn recv_one_resp(&self) -> Result<(), RpcIntErr> {
        let timer = self.get_timer_mut();
        loop {
            if self.closed.load(Ordering::Acquire) {
                logger_trace!(self.logger, "{:?} read_resp from already close", self.conn);
                // ensure tasks are received on normal exit
                if timer.check_pending_tasks_empty() || self.has_err.load(Ordering::Relaxed) {
                    return Err(RpcIntErr::IO);
                }
                // When the ClientStream (sender) has been dropped, read_resp runs without close_rx
                // and relies on the timer.
                if let Err(_e) = self
                    .conn
                    .read_resp(self.facts.as_ref(), &self.logger, &self.codec, None, timer)
                    .await
                {
                    self.closed.store(true, Ordering::SeqCst);
                    return Err(RpcIntErr::IO);
                }
            } else {
                // Block here for a new header without timeout.
                // NOTE: close_rx is an AsyncRx, which is not Sync; to work around `&AsyncRx` not
                // being Send, we hand out a mut ref to satisfy the borrow checker.
                match self
                    .conn
                    .read_resp(
                        self.facts.as_ref(),
                        &self.logger,
                        &self.codec,
                        Some(self.get_close_rx()),
                        timer,
                    )
                    .await
                {
                    Err(_e) => {
                        return Err(RpcIntErr::IO);
                    }
                    Ok(r) => {
                        // TODO FIXME
                        if !r {
                            self.closed.store(true, Ordering::SeqCst);
                            continue;
                        }
                    }
                }
            }
        }
    }

    async fn receive_loop(&self) {
        let mut tick = <F as AsyncTime>::interval(Duration::from_secs(1));
        loop {
            let f = self.recv_some();
            pin_mut!(f);
            let selector = ReceiverTimerFuture::new(self, &mut tick, &mut f);
            match selector.await {
                Ok(_) => {}
                Err(e) => {
                    logger_debug!(self.logger, "{:?} receive_loop error: {}", self, e);
                    self.closed.store(true, Ordering::SeqCst);
                    let timer = self.get_timer_mut();
                    timer.clean_pending_tasks(self.facts.as_ref());
                    // pending_task_count > 0 means some tasks may still remain in the pending channel
                    while timer.pending_task_count_ref().load(Ordering::SeqCst) > 0 {
                        // After the 'closed' flag has taken effect, pending_task_count will not
                        // keep growing, so this drain loop will not run forever; the sleep just
                        // paces the retries.
                        timer.clean_pending_tasks(self.facts.as_ref());
                        <F as AsyncTime>::sleep(Duration::from_secs(1)).await;
                    }
                    return;
                }
            }
        }
    }

    // Adjust the waiting queue
    fn time_reach(&self) {
        logger_trace!(
            self.logger,
            "{:?} has {} pending_tasks",
            self,
            self.throttler.get_inflight_count()
        );
        let timer = self.get_timer_mut();
        timer.adjust_task_queue(self.facts.as_ref());
    }

    #[inline(always)]
    fn seq_update(&self) -> u64 {
        self.seq.fetch_add(1, Ordering::SeqCst)
    }
}

impl<F: ClientFacts, P: ClientTransport> Drop for ClientStreamInner<F, P> {
    fn drop(&mut self) {
        let timer = self.get_timer_mut();
        timer.clean_pending_tasks(self.facts.as_ref());
    }
}

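// A hand-rolled "select" between the periodic tick, the task timer, and the receive future.
// The polling order below is assumed intentional: ticks and sent-task timeouts are serviced
// on every wake-up, even while the receive future stays Pending.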
struct ReceiverTimerFuture<'a, F, P, I, FR>
where
    F: ClientFacts,
    P: ClientTransport,
    I: TimeInterval,
    FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
{
    client: &'a ClientStreamInner<F, P>,
    inv: Pin<&'a mut I>,
    recv_future: Pin<&'a mut FR>,
}

impl<'a, F, P, I, FR> ReceiverTimerFuture<'a, F, P, I, FR>
where
    F: ClientFacts,
    P: ClientTransport,
    I: TimeInterval,
    FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
{
    fn new(client: &'a ClientStreamInner<F, P>, inv: &'a mut I, recv_future: &'a mut FR) -> Self {
        Self { inv: Pin::new(inv), client, recv_future: Pin::new(recv_future) }
    }
}

// Resolves with the receive future's result: Ok(()) when a receive batch completes, Err(e) on a
// connection error. Also resolves to Err(RpcIntErr::IO) immediately once has_err has been set.
impl<'a, F, P, I, FR> Future for ReceiverTimerFuture<'a, F, P, I, FR>
where
    F: ClientFacts,
    P: ClientTransport,
    I: TimeInterval,
    FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
{
    type Output = Result<(), RpcIntErr>;

    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
        let mut _self = self.get_mut();
        // Drain all ready ticks (calling time_reach for each) and keep polling until Pending so
        // the ticker re-registers its waker after firing.
        while let Poll::Ready(_) = _self.inv.as_mut().poll_tick(ctx) {
            _self.client.time_reach();
        }
        if _self.client.has_err.load(Ordering::Relaxed) {
            // When the sentinel detects the peer is unreachable, recv_some might be blocked;
            // the interval will wake us anyway, so just exit.
            return Poll::Ready(Err(RpcIntErr::IO));
        }
        // Even if the receive future is blocked, we must poll_sent_task in order to detect timeout events.
        _self.client.get_timer_mut().poll_sent_task(ctx);
        if let Poll::Ready(r) = _self.recv_future.as_mut().poll(ctx) {
            return Poll::Ready(r);
        }
        return Poll::Pending;
    }
}