// razor_stream/client/stream.rs
//! [ClientStream] represents a client-side connection.
//!
//! On Drop, the connection is closed on the write side; the response reader coroutine does not
//! exit until every ClientTask has a response or `task_timeout` is reached.
//!
//! The user sends packets in sequence, with a throttler controlling the IO depth of in-flight packets.
//! An internal timer then registers the request through a channel, and when the response
//! is received, it can optionally notify the user through a user-defined channel or another mechanism.

use std::time::Duration;
use std::{
    cell::UnsafeCell,
    fmt,
    future::Future,
    mem::transmute,
    pin::Pin,
    sync::{
        Arc,
        atomic::{AtomicBool, AtomicU64, Ordering},
    },
    task::{Context, Poll},
};

use captains_log::filter::LogFilter;
use crossfire::null::CloseHandle;
use futures_util::pin_mut;
use orb::prelude::*;

use super::throttler::Throttler;
use crate::client::task::ClientTaskDone;
use crate::client::timer::ClientTaskTimer;
use crate::{client::*, proto};
31
32/// ClientStream represents a client-side connection.
33///
34/// On Drop, the connection will be closed on the write-side. The response reader coroutine will not exit
35/// until all the ClientTasks have a response or after `task_timeout` is reached.
36///
37/// The user sends packets in sequence, with a throttler controlling the IO depth of in-flight packets.
38/// An internal timer then registers the request through a channel, and when the response
39/// is received, it can optionally notify the user through a user-defined channel or another mechanism.
40pub struct ClientStream<F: ClientFacts, P: ClientTransport> {
41    close_tx: Option<CloseHandle<mpsc::Null>>,
42    inner: Arc<ClientStreamInner<F, P>>,
43}
44
45impl<F: ClientFacts, P: ClientTransport> ClientStream<F, P> {
46    /// Make a streaming connection to the server, returns [ClientStream] on success
47    #[inline]
48    pub async fn connect(
49        facts: Arc<F>, addr: &str, conn_id: &str, last_resp_ts: Option<Arc<AtomicU64>>,
50    ) -> Result<Self, RpcIntErr> {
51        let client_id = facts.get_client_id();
52        let conn = P::connect(addr, conn_id, facts.get_config()).await?;
53        Ok(Self::new(facts, conn, client_id, conn_id.to_string(), last_resp_ts))
54    }
55
56    #[inline]
57    fn new(
58        facts: Arc<F>, conn: P, client_id: u64, conn_id: String,
59        last_resp_ts: Option<Arc<AtomicU64>>,
60    ) -> Self {
61        let (_close_tx, _close_rx) = mpsc::new::<mpsc::Null, _, _>();
62        let inner = Arc::new(ClientStreamInner::new(
63            facts,
64            conn,
65            client_id,
66            conn_id,
67            _close_rx,
68            last_resp_ts,
69        ));
70        logger_debug!(inner.logger, "{:?} connected", inner);
71        let _inner = inner.clone();
72        inner.facts.spawn_detach(async move {
73            _inner.receive_loop().await;
74        });
75        Self { close_tx: Some(_close_tx), inner }
76    }
77
78    #[inline]
79    pub fn get_codec(&self) -> &F::Codec {
80        &self.inner.codec
81    }
82
83    /// Should be call in sender threads
84    ///
85    /// NOTE: will skip if throttler is full
86    #[inline(always)]
87    pub async fn ping(&mut self) -> Result<(), RpcIntErr> {
88        self.inner.send_ping_req().await
89    }
90
91    #[inline(always)]
92    pub fn get_last_resp_ts(&self) -> u64 {
93        if let Some(ts) = self.inner.last_resp_ts.as_ref() { ts.load(Ordering::Relaxed) } else { 0 }
94    }
95
96    /// Since sender and receiver are two threads, might be close on either side
97    #[inline(always)]
98    pub fn is_closed(&self) -> bool {
99        self.inner.closed.load(Ordering::SeqCst)
100    }
101
102    /// Force the receiver to exit.
103    ///
104    /// You can call it when connectivity probes detect that a server is unreachable.
105    /// And then just let the Client drop
106    pub async fn set_error_and_exit(&mut self) {
107        // TODO review usage when doing ConnProbe
108        self.inner.has_err.store(true, Ordering::SeqCst);
109        self.inner.conn.close_conn::<F>(&self.inner.logger).await;
110    }
111
112    /// send_task() should only be called without parallelism.
113    ///
114    /// NOTE: After send, will wait for response if too many inflight task in throttler.
115    ///
116    /// Since the transport layer might have buffer, user should always call flush explicitly.
117    /// You can set `need_flush` = true for some urgent messages, or call flush_req() explicitly.
118    ///
119    #[inline(always)]
120    pub async fn send_task(&mut self, task: F::Task, need_flush: bool) -> Result<(), RpcIntErr> {
121        self.inner.send_task(task, need_flush).await
122    }
123
124    /// Since the transport layer might have buffer, user should always call flush explicitly.
125    /// you can set `need_flush` = true for some urgent message, or call flush_req() explicitly.
126    #[inline(always)]
127    pub async fn flush_req(&mut self) -> Result<(), RpcIntErr> {
128        self.inner.flush_req().await
129    }
130
131    /// Check the throttler and see if future send_task() might be blocked
132    #[inline]
133    pub fn will_block(&self) -> bool {
134        self.inner.throttler.nearly_full()
135    }
136
137    /// Get the task sent but not yet received response
138    #[inline]
139    pub fn get_inflight_count(&self) -> usize {
140        // TODO confirm ping task counted ?
141        self.inner.throttler.get_inflight_count()
142    }
143}
144
145impl<F: ClientFacts, P: ClientTransport> Drop for ClientStream<F, P> {
146    fn drop(&mut self) {
147        self.close_tx.take();
148        let timer = self.inner.get_timer_mut();
149        timer.stop_reg_task();
150        self.inner.closed.store(true, Ordering::SeqCst);
151    }
152}
153
154impl<F: ClientFacts, P: ClientTransport> fmt::Debug for ClientStream<F, P> {
155    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
156        self.inner.fmt(f)
157    }
158}
159
160struct ClientStreamInner<F: ClientFacts, P: ClientTransport> {
161    client_id: u64,
162    conn: P,
163    seq: AtomicU64,
164    // NOTE: because close_rx is AsyncRx (lower cost to register waker), but does not have Sync, to solve the & does
165    // not have Send problem, we convert this to mut ref to make borrow checker shutup
166    close_rx: UnsafeCell<AsyncRx<mpsc::Null>>,
167    closed: AtomicBool, // flag set by either sender or receive on there exit
168    timer: UnsafeCell<ClientTaskTimer<F>>,
169    // TODO can closed and has_err merge ?
170    has_err: AtomicBool,
171    throttler: Throttler,
172    last_resp_ts: Option<Arc<AtomicU64>>,
173    encode_buf: UnsafeCell<Vec<u8>>,
174    codec: F::Codec,
175    logger: Arc<LogFilter>,
176    facts: Arc<F>,
177}
178
179unsafe impl<F: ClientFacts, P: ClientTransport> Send for ClientStreamInner<F, P> {}
180
181unsafe impl<F: ClientFacts, P: ClientTransport> Sync for ClientStreamInner<F, P> {}
182
183impl<F: ClientFacts, P: ClientTransport> fmt::Debug for ClientStreamInner<F, P> {
184    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
185        self.conn.fmt(f)
186    }
187}
188
189impl<F: ClientFacts, P: ClientTransport> ClientStreamInner<F, P> {
190    pub fn new(
191        facts: Arc<F>, conn: P, client_id: u64, conn_id: String, close_rx: AsyncRx<mpsc::Null>,
192        last_resp_ts: Option<Arc<AtomicU64>>,
193    ) -> Self {
194        let config = facts.get_config();
195        let mut thresholds = config.thresholds;
196        if thresholds == 0 {
197            thresholds = 128;
198        }
199        let client_inner = Self {
200            client_id,
201            conn,
202            close_rx: UnsafeCell::new(close_rx),
203            closed: AtomicBool::new(false),
204            seq: AtomicU64::new(1),
205            encode_buf: UnsafeCell::new(Vec::with_capacity(1024)),
206            throttler: Throttler::new(thresholds),
207            last_resp_ts,
208            has_err: AtomicBool::new(false),
209            codec: F::Codec::default(),
210            logger: facts.new_logger(),
211            timer: UnsafeCell::new(ClientTaskTimer::new(conn_id, config.task_timeout, thresholds)),
212            facts,
213        };
214        logger_trace!(client_inner.logger, "{:?} throttler is set to {}", client_inner, thresholds,);
215        client_inner
216    }
217
218    #[inline(always)]
219    fn get_timer_mut(&self) -> &mut ClientTaskTimer<F> {
220        unsafe { transmute(self.timer.get()) }
221    }
222
223    #[inline(always)]
224    fn get_close_rx(&self) -> &mut AsyncRx<mpsc::Null> {
225        unsafe { transmute(self.close_rx.get()) }
226    }
227
228    #[inline(always)]
229    fn get_encoded_buf(&self) -> &mut Vec<u8> {
230        unsafe { transmute(self.encode_buf.get()) }
231    }
232
233    /// Directly work on the socket steam, when failed
234    async fn send_task(&self, mut task: F::Task, mut need_flush: bool) -> Result<(), RpcIntErr> {
235        if self.throttler.nearly_full() {
236            need_flush = true;
237        }
238        let timer = self.get_timer_mut();
239        timer.pending_task_count_ref().fetch_add(1, Ordering::SeqCst);
240        // It's possible receiver set close after pending_task_count increase, keep this until
241        // review
242        if self.closed.load(Ordering::Acquire) {
243            logger_warn!(
244                self.logger,
245                "{:?} sending task {:?} failed: {}",
246                self,
247                task,
248                RpcIntErr::IO,
249            );
250            task.set_rpc_error(RpcIntErr::IO);
251            self.facts.error_handle(task);
252            timer.pending_task_count_ref().fetch_sub(1, Ordering::SeqCst); // rollback
253            return Err(RpcIntErr::IO);
254        }
255        match self.send_request(task, need_flush).await {
256            Err(_) => {
257                self.closed.store(true, Ordering::SeqCst);
258                self.has_err.store(true, Ordering::SeqCst);
259                return Err(RpcIntErr::IO);
260            }
261            Ok(_) => {
262                // register task to norifier
263                self.throttler.throttle().await;
264                return Ok(());
265            }
266        }
267    }
268
269    #[inline(always)]
270    async fn flush_req(&self) -> Result<(), RpcIntErr> {
271        if let Err(e) = self.conn.flush_req::<F>(&self.logger).await {
272            logger_warn!(self.logger, "{:?} flush_req flush err: {}", self, e);
273            self.closed.store(true, Ordering::SeqCst);
274            self.has_err.store(true, Ordering::SeqCst);
275            let timer = self.get_timer_mut();
276            timer.stop_reg_task();
277            return Err(RpcIntErr::IO);
278        }
279        Ok(())
280    }
281
282    #[inline(always)]
283    async fn send_request(&self, mut task: F::Task, need_flush: bool) -> Result<(), RpcIntErr> {
284        let seq = self.seq_update();
285        task.set_seq(seq);
286        let buf = self.get_encoded_buf();
287        match proto::ReqHead::encode(&self.codec, buf, self.client_id, &task) {
288            Err(_) => {
289                logger_warn!(&self.logger, "{:?} send_req encode req {:?} err", self, task);
290                return Err(RpcIntErr::Encode);
291            }
292            Ok(blob_buf) => {
293                if let Err(e) =
294                    self.conn.write_req::<F>(&self.logger, buf, blob_buf, need_flush).await
295                {
296                    logger_warn!(
297                        self.logger,
298                        "{:?} send_req write req {:?} err: {:?}",
299                        self,
300                        task,
301                        e
302                    );
303                    self.closed.store(true, Ordering::SeqCst);
304                    self.has_err.store(true, Ordering::SeqCst);
305                    let timer = self.get_timer_mut();
306                    // TODO check stop_reg_task
307                    // rollback counter
308                    timer.pending_task_count_ref().fetch_sub(1, Ordering::SeqCst);
309                    timer.stop_reg_task();
310                    logger_warn!(self.logger, "{:?} sending task {:?} err: {}", self, task, e);
311                    task.set_rpc_error(RpcIntErr::IO);
312                    self.facts.error_handle(task);
313                    return Err(RpcIntErr::IO);
314                } else {
315                    let wg = self.throttler.add_task();
316                    let timer = self.get_timer_mut();
317                    logger_trace!(self.logger, "{:?} send task {:?} ok", self, task);
318                    timer.reg_task(task, wg).await;
319                }
320                return Ok(());
321            }
322        }
323    }
324
325    #[inline(always)]
326    async fn send_ping_req(&self) -> Result<(), RpcIntErr> {
327        if self.closed.load(Ordering::Acquire) {
328            logger_warn!(self.logger, "{:?} send_ping_req skip as conn closed", self);
329            return Err(RpcIntErr::IO);
330        }
331        // PING does not counted in throttler
332        let buf = self.get_encoded_buf();
333        proto::ReqHead::encode_ping(buf, self.client_id, self.seq_update());
334        // Ping does not need to reg_task, and have no error_handle, just to keep the connection
335        // alive. Connection Prober can monitor the liveness of ClientConn
336        if let Err(e) = self.conn.write_req::<F>(&self.logger, buf, None, true).await {
337            logger_warn!(self.logger, "{:?} send ping err: {:?}", self, e);
338            self.closed.store(true, Ordering::SeqCst);
339            return Err(RpcIntErr::IO);
340        }
341        Ok(())
342    }
343
344    // return Ok(false) when close_rx has close and nothing more pending resp to receive
345    async fn recv_some(&self) -> Result<(), RpcIntErr> {
346        for _ in 0i32..20 {
347            // Underlayer rpc socket is buffered, might not yeal to runtime
348            // return if recv_one_resp runs too long, allow timer to be fire at each second
349            match self.recv_one_resp().await {
350                Err(e) => {
351                    return Err(e);
352                }
353                Ok(_) => {
354                    if let Some(last_resp_ts) = self.last_resp_ts.as_ref() {
355                        last_resp_ts.store(self.facts.get_timestamp(), Ordering::Release);
356                    }
357                }
358            }
359        }
360        Ok(())
361    }
362
363    async fn recv_one_resp(&self) -> Result<(), RpcIntErr> {
364        let timer = self.get_timer_mut();
365        loop {
366            if self.closed.load(Ordering::Acquire) {
367                logger_trace!(self.logger, "{:?} read_resp from already close", self.conn);
368                // ensure task receive on normal exit
369                if timer.check_pending_tasks_empty() || self.has_err.load(Ordering::Relaxed) {
370                    return Err(RpcIntErr::IO);
371                }
372                // When ClientStream(sender) dropped, receiver will be timer
373                if let Err(_e) = self
374                    .conn
375                    .read_resp(self.facts.as_ref(), &self.logger, &self.codec, None, timer)
376                    .await
377                {
378                    self.closed.store(true, Ordering::SeqCst);
379                    return Err(RpcIntErr::IO);
380                }
381            } else {
382                // Block here for new header without timeout
383                // NOTE: because close_rx is AsyncRx, which does not have Sync, to solve the & does
384                // not have Send problem, we convert this to mut ref to make borrow checker shutup
385                match self
386                    .conn
387                    .read_resp(
388                        self.facts.as_ref(),
389                        &self.logger,
390                        &self.codec,
391                        Some(self.get_close_rx()),
392                        timer,
393                    )
394                    .await
395                {
396                    Err(_e) => {
397                        return Err(RpcIntErr::IO);
398                    }
399                    Ok(r) => {
400                        // TODO FIXME
401                        if !r {
402                            self.closed.store(true, Ordering::SeqCst);
403                            continue;
404                        }
405                    }
406                }
407            }
408        }
409    }
410
411    async fn receive_loop(&self) {
412        let mut tick = <F as AsyncTime>::interval(Duration::from_secs(1));
413        loop {
414            let f = self.recv_some();
415            pin_mut!(f);
416            let selector = ReceiverTimerFuture::new(self, &mut tick, &mut f);
417            match selector.await {
418                Ok(_) => {}
419                Err(e) => {
420                    logger_debug!(self.logger, "{:?} receive_loop error: {}", self, e);
421                    self.closed.store(true, Ordering::SeqCst);
422                    let timer = self.get_timer_mut();
423                    timer.clean_pending_tasks(self.facts.as_ref());
424                    // If pending_task_count > 0 means some tasks may still remain in the pending chan
425                    while timer.pending_task_count_ref().load(Ordering::SeqCst) > 0 {
426                        // After the 'closed' flag has taken effect,
427                        // pending_task_count will not keep growing,
428                        // so there is no need to sleep here.
429                        timer.clean_pending_tasks(self.facts.as_ref());
430                        <F as AsyncTime>::sleep(Duration::from_secs(1)).await;
431                    }
432                    return;
433                }
434            }
435        }
436    }
437
438    // Adjust the waiting queue
439    fn time_reach(&self) {
440        logger_trace!(
441            self.logger,
442            "{:?} has {} pending_tasks",
443            self,
444            self.throttler.get_inflight_count()
445        );
446        let timer = self.get_timer_mut();
447        timer.adjust_task_queue(self.facts.as_ref());
448        return;
449    }
450
451    #[inline(always)]
452    fn seq_update(&self) -> u64 {
453        self.seq.fetch_add(1, Ordering::SeqCst)
454    }
455}
456
457impl<F: ClientFacts, P: ClientTransport> Drop for ClientStreamInner<F, P> {
458    fn drop(&mut self) {
459        let timer = self.get_timer_mut();
460        timer.clean_pending_tasks(self.facts.as_ref());
461    }
462}
463
464struct ReceiverTimerFuture<'a, F, P, I, FR>
465where
466    F: ClientFacts,
467    P: ClientTransport,
468    I: TimeInterval,
469    FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
470{
471    client: &'a ClientStreamInner<F, P>,
472    inv: Pin<&'a mut I>,
473    recv_future: Pin<&'a mut FR>,
474}
475
476impl<'a, F, P, I, FR> ReceiverTimerFuture<'a, F, P, I, FR>
477where
478    F: ClientFacts,
479    P: ClientTransport,
480    I: TimeInterval,
481    FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
482{
483    fn new(client: &'a ClientStreamInner<F, P>, inv: &'a mut I, recv_future: &'a mut FR) -> Self {
484        Self { inv: Pin::new(inv), client, recv_future: Pin::new(recv_future) }
485    }
486}
487
488// Return Ok(true) to indicate Ok
489// Return Ok(false) when client sender has close normally
490// Err(e) when connection error
491impl<'a, F, P, I, FR> Future for ReceiverTimerFuture<'a, F, P, I, FR>
492where
493    F: ClientFacts,
494    P: ClientTransport,
495    I: TimeInterval,
496    FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
497{
498    type Output = Result<(), RpcIntErr>;
499
500    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
501        let mut _self = self.get_mut();
502        // In case ticker not fire, and ensure ticker schedule after ready
503        while _self.inv.as_mut().poll_tick(ctx).is_ready() {
504            _self.client.time_reach();
505        }
506        if _self.client.has_err.load(Ordering::Relaxed) {
507            // When sentinel detect peer unreachable, recv_some mighe blocked, at least inv will
508            // wait us, just exit
509            return Poll::Ready(Err(RpcIntErr::IO));
510        }
511        _self.client.get_timer_mut().poll_sent_task(ctx);
512        // Even if receive future has block, we should poll_sent_task in order to detect timeout event
513        if let Poll::Ready(r) = _self.recv_future.as_mut().poll(ctx) {
514            return Poll::Ready(r);
515        }
516        return Poll::Pending;
517    }
518}