Skip to main content

razor_stream/client/
stream.rs

//! [ClientStream] represents a client-side connection.
//!
//! On Drop, the connection is closed on the write-side; the response reader coroutine does not
//! exit until every ClientTask has received a response or `task_timeout` is reached.
//!
//! The user sends packets in sequence, with a throttler controlling the IO depth of in-flight packets.
//! An internal timer then registers the request through a channel, and when the response
//! is received, it can optionally notify the user through a user-defined channel or another mechanism.
9
10use super::throttler::Throttler;
11use crate::client::task::ClientTaskDone;
12use crate::client::timer::ClientTaskTimer;
13use crate::{client::*, proto};
14use captains_log::filter::LogFilter;
15use crossfire::null::CloseHandle;
16use futures_util::pin_mut;
17use orb::prelude::*;
18use std::time::Duration;
19use std::{
20    cell::UnsafeCell,
21    fmt,
22    future::Future,
23    mem::transmute,
24    pin::Pin,
25    sync::{
26        Arc,
27        atomic::{AtomicBool, AtomicU64, Ordering},
28    },
29    task::{Context, Poll},
30};
31
/// ClientStream represents a client-side connection.
///
/// On Drop, the connection will be closed on the write-side. The response reader coroutine will not exit
/// until all the ClientTasks have a response or after `task_timeout` is reached.
///
/// The user sends packets in sequence, with a throttler controlling the IO depth of in-flight packets.
/// An internal timer then registers the request through a channel, and when the response
/// is received, it can optionally notify the user through a user-defined channel or another mechanism.
pub struct ClientStream<F: ClientFacts, P: ClientTransport> {
    // Sender half of the close channel. It is taken (and thereby dropped) in `Drop`;
    // presumably this is what tells the receive side that the sender is gone.
    close_tx: Option<CloseHandle<mpsc::Null>>,
    // State shared with the detached receive loop that `connect` spawns.
    inner: Arc<ClientStreamInner<F, P>>,
}
44
45impl<F: ClientFacts, P: ClientTransport> ClientStream<F, P> {
46    /// Make a streaming connection to the server, returns [ClientStream] on success
47    #[inline]
48    pub async fn connect(
49        facts: Arc<F>, rt: Option<&<P::RT as AsyncRuntime>::Exec>, addr: &str, conn_id: &str,
50        last_resp_ts: Option<Arc<AtomicU64>>,
51    ) -> Result<Self, RpcIntErr> {
52        let client_id = facts.get_client_id();
53        let conn = P::connect(addr, conn_id, facts.get_config()).await?;
54        let this = Self::new(facts, conn, client_id, conn_id.to_string(), last_resp_ts);
55        let inner = this.inner.clone();
56        let f = inner.receive_loop();
57        if let Some(_rt) = rt {
58            _rt.spawn_detach(f);
59        } else {
60            P::RT::spawn_detach(f);
61        }
62        Ok(this)
63    }
64
65    #[inline]
66    fn new(
67        facts: Arc<F>, conn: P, client_id: u64, conn_id: String,
68        last_resp_ts: Option<Arc<AtomicU64>>,
69    ) -> Self {
70        let (_close_tx, _close_rx) = mpsc::new::<mpsc::Null, _, _>();
71        let inner = Arc::new(ClientStreamInner::new(
72            facts,
73            conn,
74            client_id,
75            conn_id,
76            _close_rx,
77            last_resp_ts,
78        ));
79        logger_debug!(inner.logger, "{:?} connected", inner);
80        Self { close_tx: Some(_close_tx), inner }
81    }
82
83    #[inline]
84    pub fn get_codec(&self) -> &F::Codec {
85        &self.inner.codec
86    }
87
88    /// Should be call in sender threads
89    ///
90    /// NOTE: will skip if throttler is full
91    #[inline(always)]
92    pub async fn ping(&mut self) -> Result<(), RpcIntErr> {
93        self.inner.send_ping_req().await
94    }
95
96    #[inline(always)]
97    pub fn get_last_resp_ts(&self) -> u64 {
98        if let Some(ts) = self.inner.last_resp_ts.as_ref() { ts.load(Ordering::Relaxed) } else { 0 }
99    }
100
101    /// Since sender and receiver are two threads, might be close on either side
102    #[inline(always)]
103    pub fn is_closed(&self) -> bool {
104        self.inner.closed.load(Ordering::SeqCst)
105    }
106
107    /// Force the receiver to exit.
108    ///
109    /// You can call it when connectivity probes detect that a server is unreachable.
110    /// And then just let the Client drop
111    pub async fn set_error_and_exit(&mut self) {
112        // TODO review usage when doing ConnProbe
113        self.inner.has_err.store(true, Ordering::SeqCst);
114        self.inner.conn.close_conn::<F>(&self.inner.logger).await;
115    }
116
117    /// send_task() should only be called without parallelism.
118    ///
119    /// NOTE: After send, will wait for response if too many inflight task in throttler.
120    ///
121    /// Since the transport layer might have buffer, user should always call flush explicitly.
122    /// You can set `need_flush` = true for some urgent messages, or call flush_req() explicitly.
123    ///
124    #[inline(always)]
125    pub async fn send_task(&mut self, task: F::Task, need_flush: bool) -> Result<(), RpcIntErr> {
126        self.inner.send_task(task, need_flush).await
127    }
128
129    /// Since the transport layer might have buffer, user should always call flush explicitly.
130    /// you can set `need_flush` = true for some urgent message, or call flush_req() explicitly.
131    #[inline(always)]
132    pub async fn flush_req(&mut self) -> Result<(), RpcIntErr> {
133        self.inner.flush_req().await
134    }
135
136    /// Check the throttler and see if future send_task() might be blocked
137    #[inline]
138    pub fn will_block(&self) -> bool {
139        self.inner.throttler.nearly_full()
140    }
141
142    /// Get the task sent but not yet received response
143    #[inline]
144    pub fn get_inflight_count(&self) -> usize {
145        // TODO confirm ping task counted ?
146        self.inner.throttler.get_inflight_count()
147    }
148}
149
150impl<F: ClientFacts, P: ClientTransport> Drop for ClientStream<F, P> {
151    fn drop(&mut self) {
152        self.close_tx.take();
153        let timer = self.inner.get_timer_mut();
154        timer.stop_reg_task();
155        self.inner.closed.store(true, Ordering::SeqCst);
156    }
157}
158
// Debug output is delegated to the shared inner state.
impl<F: ClientFacts, P: ClientTransport> fmt::Debug for ClientStream<F, P> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // NOTE(review): which `fmt` this resolves to depends on the formatting traits in
        // scope; presumably `Debug` — confirm.
        self.inner.fmt(f)
    }
}
164
// State shared between the sending side (`ClientStream`) and the detached receive loop.
struct ClientStreamInner<F: ClientFacts, P: ClientTransport> {
    // Client id taken from the facts at connect time; passed to every request-head encode.
    client_id: u64,
    // The transport connection used for writing requests and reading responses.
    conn: P,
    // Sequence counter: starts at 1 and is post-incremented for each request and ping.
    seq: AtomicU64,
    // NOTE: close_rx is an AsyncRx (lower cost to register a waker) which is not Sync.
    // To get around the "& is not Send" problem it is kept in an UnsafeCell and handed out
    // as a mutable reference, which keeps the borrow checker quiet.
    close_rx: UnsafeCell<AsyncRx<mpsc::Null>>,
    closed: AtomicBool, // flag set by either the sender or the receiver on their exit
    // Task timer; reached through `get_timer_mut()` from both the send and the receive paths.
    timer: UnsafeCell<ClientTaskTimer<F>>,
    // TODO can closed and has_err merge ?
    has_err: AtomicBool,
    // Tracks in-flight tasks and throttles the sender.
    throttler: Throttler,
    // Optional shared cell that `recv_some` stores the current timestamp into.
    last_resp_ts: Option<Arc<AtomicU64>>,
    // Scratch buffer reused for encoding requests and pings.
    encode_buf: UnsafeCell<Vec<u8>>,
    codec: F::Codec,
    logger: Arc<LogFilter>,
    facts: Arc<F>,
}
183
// SAFETY (NOTE(review)): asserted manually because the `UnsafeCell` fields (and the non-Sync
// `AsyncRx`) stop the auto traits from being derived. Soundness presumably relies on the
// sender and the receive loop never touching the same cell concurrently — confirm.
unsafe impl<F: ClientFacts, P: ClientTransport> Send for ClientStreamInner<F, P> {}

// SAFETY (NOTE(review)): same assumption as for `Send` above — confirm.
unsafe impl<F: ClientFacts, P: ClientTransport> Sync for ClientStreamInner<F, P> {}
187
// Debug output is delegated to the transport connection.
impl<F: ClientFacts, P: ClientTransport> fmt::Debug for ClientStreamInner<F, P> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // NOTE(review): which `fmt` this resolves to depends on the bounds of
        // `ClientTransport`, which are not visible here — confirm.
        self.conn.fmt(f)
    }
}
193
194impl<F: ClientFacts, P: ClientTransport> ClientStreamInner<F, P> {
195    pub fn new(
196        facts: Arc<F>, conn: P, client_id: u64, conn_id: String, close_rx: AsyncRx<mpsc::Null>,
197        last_resp_ts: Option<Arc<AtomicU64>>,
198    ) -> Self {
199        let config = facts.get_config();
200        let mut thresholds = config.thresholds;
201        if thresholds == 0 {
202            thresholds = 128;
203        }
204        let client_inner = Self {
205            client_id,
206            conn,
207            close_rx: UnsafeCell::new(close_rx),
208            closed: AtomicBool::new(false),
209            seq: AtomicU64::new(1),
210            encode_buf: UnsafeCell::new(Vec::with_capacity(1024)),
211            throttler: Throttler::new(thresholds),
212            last_resp_ts,
213            has_err: AtomicBool::new(false),
214            codec: F::Codec::default(),
215            logger: facts.new_logger(),
216            timer: UnsafeCell::new(ClientTaskTimer::new(conn_id, config.task_timeout, thresholds)),
217            facts,
218        };
219        logger_trace!(client_inner.logger, "{:?} throttler is set to {}", client_inner, thresholds,);
220        client_inner
221    }
222
223    #[inline(always)]
224    fn get_timer_mut(&self) -> &mut ClientTaskTimer<F> {
225        unsafe { transmute(self.timer.get()) }
226    }
227
228    #[inline(always)]
229    fn get_close_rx(&self) -> &mut AsyncRx<mpsc::Null> {
230        unsafe { transmute(self.close_rx.get()) }
231    }
232
233    #[inline(always)]
234    fn get_encoded_buf(&self) -> &mut Vec<u8> {
235        unsafe { transmute(self.encode_buf.get()) }
236    }
237
238    /// Directly work on the socket steam, when failed
239    async fn send_task(&self, mut task: F::Task, mut need_flush: bool) -> Result<(), RpcIntErr> {
240        if self.throttler.nearly_full() {
241            need_flush = true;
242        }
243        let timer = self.get_timer_mut();
244        timer.pending_task_count_ref().fetch_add(1, Ordering::SeqCst);
245        // It's possible receiver set close after pending_task_count increase, keep this until
246        // review
247        if self.closed.load(Ordering::Acquire) {
248            logger_warn!(
249                self.logger,
250                "{:?} sending task {:?} failed: {}",
251                self,
252                task,
253                RpcIntErr::IO,
254            );
255            task.set_rpc_error(RpcIntErr::IO);
256            self.facts.error_handle(task);
257            timer.pending_task_count_ref().fetch_sub(1, Ordering::SeqCst); // rollback
258            return Err(RpcIntErr::IO);
259        }
260        match self.send_request(task, need_flush).await {
261            Err(_) => {
262                self.closed.store(true, Ordering::SeqCst);
263                self.has_err.store(true, Ordering::SeqCst);
264                return Err(RpcIntErr::IO);
265            }
266            Ok(_) => {
267                // register task to norifier
268                self.throttler.throttle().await;
269                return Ok(());
270            }
271        }
272    }
273
274    #[inline(always)]
275    async fn flush_req(&self) -> Result<(), RpcIntErr> {
276        if let Err(e) = self.conn.flush_req::<F>(&self.logger).await {
277            logger_warn!(self.logger, "{:?} flush_req flush err: {}", self, e);
278            self.closed.store(true, Ordering::SeqCst);
279            self.has_err.store(true, Ordering::SeqCst);
280            let timer = self.get_timer_mut();
281            timer.stop_reg_task();
282            return Err(RpcIntErr::IO);
283        }
284        Ok(())
285    }
286
287    #[inline(always)]
288    async fn send_request(&self, mut task: F::Task, need_flush: bool) -> Result<(), RpcIntErr> {
289        let seq = self.seq_update();
290        task.set_seq(seq);
291        let buf = self.get_encoded_buf();
292        match proto::ReqHead::encode(&self.codec, buf, self.client_id, &task) {
293            Err(_) => {
294                logger_warn!(&self.logger, "{:?} send_req encode req {:?} err", self, task);
295                return Err(RpcIntErr::Encode);
296            }
297            Ok(blob_buf) => {
298                if let Err(e) =
299                    self.conn.write_req::<F>(&self.logger, buf, blob_buf, need_flush).await
300                {
301                    logger_warn!(
302                        self.logger,
303                        "{:?} send_req write req {:?} err: {:?}",
304                        self,
305                        task,
306                        e
307                    );
308                    self.closed.store(true, Ordering::SeqCst);
309                    self.has_err.store(true, Ordering::SeqCst);
310                    let timer = self.get_timer_mut();
311                    // TODO check stop_reg_task
312                    // rollback counter
313                    timer.pending_task_count_ref().fetch_sub(1, Ordering::SeqCst);
314                    timer.stop_reg_task();
315                    logger_warn!(self.logger, "{:?} sending task {:?} err: {}", self, task, e);
316                    task.set_rpc_error(RpcIntErr::IO);
317                    self.facts.error_handle(task);
318                    return Err(RpcIntErr::IO);
319                } else {
320                    let wg = self.throttler.add_task();
321                    let timer = self.get_timer_mut();
322                    logger_trace!(self.logger, "{:?} send task {:?} ok", self, task);
323                    timer.reg_task(task, wg).await;
324                }
325                return Ok(());
326            }
327        }
328    }
329
330    #[inline(always)]
331    async fn send_ping_req(&self) -> Result<(), RpcIntErr> {
332        if self.closed.load(Ordering::Acquire) {
333            logger_warn!(self.logger, "{:?} send_ping_req skip as conn closed", self);
334            return Err(RpcIntErr::IO);
335        }
336        // PING does not counted in throttler
337        let buf = self.get_encoded_buf();
338        proto::ReqHead::encode_ping(buf, self.client_id, self.seq_update());
339        // Ping does not need to reg_task, and have no error_handle, just to keep the connection
340        // alive. Connection Prober can monitor the liveness of ClientConn
341        if let Err(e) = self.conn.write_req::<F>(&self.logger, buf, None, true).await {
342            logger_warn!(self.logger, "{:?} send ping err: {:?}", self, e);
343            self.closed.store(true, Ordering::SeqCst);
344            return Err(RpcIntErr::IO);
345        }
346        Ok(())
347    }
348
349    // return Ok(false) when close_rx has close and nothing more pending resp to receive
350    async fn recv_some(&self) -> Result<(), RpcIntErr> {
351        for _ in 0i32..20 {
352            // Underlayer rpc socket is buffered, might not yeal to runtime
353            // return if recv_one_resp runs too long, allow timer to be fire at each second
354            match self.recv_one_resp().await {
355                Err(e) => {
356                    return Err(e);
357                }
358                Ok(_) => {
359                    if let Some(last_resp_ts) = self.last_resp_ts.as_ref() {
360                        last_resp_ts.store(self.facts.get_timestamp(), Ordering::Release);
361                    }
362                }
363            }
364        }
365        Ok(())
366    }
367
368    async fn recv_one_resp(&self) -> Result<(), RpcIntErr> {
369        let timer = self.get_timer_mut();
370        loop {
371            if self.closed.load(Ordering::Acquire) {
372                logger_trace!(self.logger, "{:?} read_resp from already close", self.conn);
373                // ensure task receive on normal exit
374                if timer.check_pending_tasks_empty() || self.has_err.load(Ordering::Relaxed) {
375                    return Err(RpcIntErr::IO);
376                }
377                // When ClientStream(sender) dropped, receiver will be timer
378                if let Err(_e) = self
379                    .conn
380                    .read_resp(self.facts.as_ref(), &self.logger, &self.codec, None, timer)
381                    .await
382                {
383                    self.closed.store(true, Ordering::SeqCst);
384                    return Err(RpcIntErr::IO);
385                }
386            } else {
387                // Block here for new header without timeout
388                // NOTE: because close_rx is AsyncRx, which does not have Sync, to solve the & does
389                // not have Send problem, we convert this to mut ref to make borrow checker shutup
390                match self
391                    .conn
392                    .read_resp(
393                        self.facts.as_ref(),
394                        &self.logger,
395                        &self.codec,
396                        Some(self.get_close_rx()),
397                        timer,
398                    )
399                    .await
400                {
401                    Err(_e) => {
402                        return Err(RpcIntErr::IO);
403                    }
404                    Ok(r) => {
405                        // TODO FIXME
406                        if !r {
407                            self.closed.store(true, Ordering::SeqCst);
408                            continue;
409                        }
410                    }
411                }
412            }
413        }
414    }
415
416    async fn receive_loop(self: Arc<Self>) {
417        let mut tick = <P::RT as AsyncTime>::interval(Duration::from_secs(1));
418        loop {
419            let f = self.recv_some();
420            pin_mut!(f);
421            let selector = ReceiverTimerFuture::new(&self, &mut tick, &mut f);
422            match selector.await {
423                Ok(_) => {}
424                Err(e) => {
425                    logger_debug!(self.logger, "{:?} receive_loop error: {}", self, e);
426                    self.closed.store(true, Ordering::SeqCst);
427                    let timer = self.get_timer_mut();
428                    timer.clean_pending_tasks(self.facts.as_ref());
429                    // If pending_task_count > 0 means some tasks may still remain in the pending chan
430                    while timer.pending_task_count_ref().load(Ordering::SeqCst) > 0 {
431                        // After the 'closed' flag has taken effect,
432                        // pending_task_count will not keep growing,
433                        // so there is no need to sleep here.
434                        timer.clean_pending_tasks(self.facts.as_ref());
435                        <P::RT as AsyncTime>::sleep(Duration::from_secs(1)).await;
436                    }
437                    return;
438                }
439            }
440        }
441    }
442
443    // Adjust the waiting queue
444    fn time_reach(&self) {
445        logger_trace!(
446            self.logger,
447            "{:?} has {} pending_tasks",
448            self,
449            self.throttler.get_inflight_count()
450        );
451        let timer = self.get_timer_mut();
452        timer.adjust_task_queue(self.facts.as_ref());
453        return;
454    }
455
456    #[inline(always)]
457    fn seq_update(&self) -> u64 {
458        self.seq.fetch_add(1, Ordering::SeqCst)
459    }
460}
461
462impl<F: ClientFacts, P: ClientTransport> Drop for ClientStreamInner<F, P> {
463    fn drop(&mut self) {
464        let timer = self.get_timer_mut();
465        timer.clean_pending_tasks(self.facts.as_ref());
466    }
467}
468
// Future that drives the receive future together with the periodic ticker; see its
// `Future` impl for the polling order.
struct ReceiverTimerFuture<'a, F, P, I, FR>
where
    F: ClientFacts,
    P: ClientTransport,
    I: TimeInterval,
    FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
{
    // Shared connection state (timer, error flag).
    client: &'a ClientStreamInner<F, P>,
    // The periodic ticker.
    inv: Pin<&'a mut I>,
    // The in-progress receive future.
    recv_future: Pin<&'a mut FR>,
}
480
481impl<'a, F, P, I, FR> ReceiverTimerFuture<'a, F, P, I, FR>
482where
483    F: ClientFacts,
484    P: ClientTransport,
485    I: TimeInterval,
486    FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
487{
488    fn new(client: &'a ClientStreamInner<F, P>, inv: &'a mut I, recv_future: &'a mut FR) -> Self {
489        Self { inv: Pin::new(inv), client, recv_future: Pin::new(recv_future) }
490    }
491}
492
493// Return Ok(true) to indicate Ok
494// Return Ok(false) when client sender has close normally
495// Err(e) when connection error
496impl<'a, F, P, I, FR> Future for ReceiverTimerFuture<'a, F, P, I, FR>
497where
498    F: ClientFacts,
499    P: ClientTransport,
500    I: TimeInterval,
501    FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
502{
503    type Output = Result<(), RpcIntErr>;
504
505    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
506        let mut _self = self.get_mut();
507        // In case ticker not fire, and ensure ticker schedule after ready
508        while _self.inv.as_mut().poll_tick(ctx).is_ready() {
509            _self.client.time_reach();
510        }
511        if _self.client.has_err.load(Ordering::Relaxed) {
512            // When sentinel detect peer unreachable, recv_some mighe blocked, at least inv will
513            // wait us, just exit
514            return Poll::Ready(Err(RpcIntErr::IO));
515        }
516        _self.client.get_timer_mut().poll_sent_task(ctx);
517        // Even if receive future has block, we should poll_sent_task in order to detect timeout event
518        if let Poll::Ready(r) = _self.recv_future.as_mut().poll(ctx) {
519            return Poll::Ready(r);
520        }
521        return Poll::Pending;
522    }
523}