// razor_stream/client/stream.rs
//! [ClientStream] represents a client-side connection.
//!
//! On Drop, the write side of the connection is closed; the response reader coroutine
//! will not exit until every ClientTask has a response or `task_timeout` is reached.
//!
//! The user sends packets in sequence, with a throttler controlling the IO depth of in-flight packets.
//! An internal timer then registers the request through a channel, and when the response
//! is received, it can optionally notify the user through a user-defined channel or another mechanism.

10use super::throttler::Throttler;
11use crate::client::task::ClientTaskDone;
12use crate::client::timer::ClientTaskTimer;
13use crate::{client::*, proto};
14use captains_log::filter::LogFilter;
15use crossfire::null::CloseHandle;
16use futures_util::pin_mut;
17use orb::prelude::*;
18use std::time::Duration;
19use std::{
20    cell::UnsafeCell,
21    fmt,
22    future::Future,
23    mem::transmute,
24    pin::Pin,
25    sync::{
26        Arc,
27        atomic::{AtomicBool, AtomicU64, Ordering},
28    },
29    task::{Context, Poll},
30};
31
/// ClientStream represents a client-side connection.
///
/// On Drop, the connection will be closed on the write-side. The response reader coroutine will not exit
/// until all the ClientTasks have a response or after `task_timeout` is reached.
///
/// The user sends packets in sequence, with a throttler controlling the IO depth of in-flight packets.
/// An internal timer then registers the request through a channel, and when the response
/// is received, it can optionally notify the user through a user-defined channel or another mechanism.
pub struct ClientStream<F: ClientFacts, P: ClientTransport> {
    // Write-side close handle; taken (dropped) in Drop so the receive loop's
    // close_rx observes the channel closing and can begin draining.
    close_tx: Option<CloseHandle<mpsc::Null>>,
    // State shared with the detached receive_loop coroutine spawned in connect().
    inner: Arc<ClientStreamInner<F, P>>,
}
44
45impl<F: ClientFacts, P: ClientTransport> ClientStream<F, P> {
46    /// Make a streaming connection to the server, returns [ClientStream] on success
47    #[inline]
48    pub async fn connect<RT: orb::AsyncRuntime>(
49        facts: Arc<F>, rt: &RT, addr: &str, conn_id: &str, last_resp_ts: Option<Arc<AtomicU64>>,
50    ) -> Result<Self, RpcIntErr> {
51        let client_id = facts.get_client_id();
52        let conn = P::connect(addr, conn_id, facts.get_config()).await?;
53        let this = Self::new(facts, conn, client_id, conn_id.to_string(), last_resp_ts);
54        let inner = this.inner.clone();
55        rt.spawn_detach(async move {
56            inner.receive_loop::<RT>().await;
57        });
58        Ok(this)
59    }
60
61    #[inline]
62    fn new(
63        facts: Arc<F>, conn: P, client_id: u64, conn_id: String,
64        last_resp_ts: Option<Arc<AtomicU64>>,
65    ) -> Self {
66        let (_close_tx, _close_rx) = mpsc::new::<mpsc::Null, _, _>();
67        let inner = Arc::new(ClientStreamInner::new(
68            facts,
69            conn,
70            client_id,
71            conn_id,
72            _close_rx,
73            last_resp_ts,
74        ));
75        logger_debug!(inner.logger, "{:?} connected", inner);
76        Self { close_tx: Some(_close_tx), inner }
77    }
78
79    #[inline]
80    pub fn get_codec(&self) -> &F::Codec {
81        &self.inner.codec
82    }
83
84    /// Should be call in sender threads
85    ///
86    /// NOTE: will skip if throttler is full
87    #[inline(always)]
88    pub async fn ping(&mut self) -> Result<(), RpcIntErr> {
89        self.inner.send_ping_req().await
90    }
91
92    #[inline(always)]
93    pub fn get_last_resp_ts(&self) -> u64 {
94        if let Some(ts) = self.inner.last_resp_ts.as_ref() { ts.load(Ordering::Relaxed) } else { 0 }
95    }
96
97    /// Since sender and receiver are two threads, might be close on either side
98    #[inline(always)]
99    pub fn is_closed(&self) -> bool {
100        self.inner.closed.load(Ordering::SeqCst)
101    }
102
103    /// Force the receiver to exit.
104    ///
105    /// You can call it when connectivity probes detect that a server is unreachable.
106    /// And then just let the Client drop
107    pub async fn set_error_and_exit(&mut self) {
108        // TODO review usage when doing ConnProbe
109        self.inner.has_err.store(true, Ordering::SeqCst);
110        self.inner.conn.close_conn::<F>(&self.inner.logger).await;
111    }
112
113    /// send_task() should only be called without parallelism.
114    ///
115    /// NOTE: After send, will wait for response if too many inflight task in throttler.
116    ///
117    /// Since the transport layer might have buffer, user should always call flush explicitly.
118    /// You can set `need_flush` = true for some urgent messages, or call flush_req() explicitly.
119    ///
120    #[inline(always)]
121    pub async fn send_task(&mut self, task: F::Task, need_flush: bool) -> Result<(), RpcIntErr> {
122        self.inner.send_task(task, need_flush).await
123    }
124
125    /// Since the transport layer might have buffer, user should always call flush explicitly.
126    /// you can set `need_flush` = true for some urgent message, or call flush_req() explicitly.
127    #[inline(always)]
128    pub async fn flush_req(&mut self) -> Result<(), RpcIntErr> {
129        self.inner.flush_req().await
130    }
131
132    /// Check the throttler and see if future send_task() might be blocked
133    #[inline]
134    pub fn will_block(&self) -> bool {
135        self.inner.throttler.nearly_full()
136    }
137
138    /// Get the task sent but not yet received response
139    #[inline]
140    pub fn get_inflight_count(&self) -> usize {
141        // TODO confirm ping task counted ?
142        self.inner.throttler.get_inflight_count()
143    }
144}
145
impl<F: ClientFacts, P: ClientTransport> Drop for ClientStream<F, P> {
    fn drop(&mut self) {
        // Drop the close handle first: this closes the null channel, which the
        // receive loop observes through close_rx while blocked in read_resp.
        self.close_tx.take();
        // Stop accepting new task registrations so the receiver can drain and exit.
        let timer = self.inner.get_timer_mut();
        timer.stop_reg_task();
        // Finally mark the stream closed for both sender and receiver paths.
        self.inner.closed.store(true, Ordering::SeqCst);
    }
}
154
155impl<F: ClientFacts, P: ClientTransport> fmt::Debug for ClientStream<F, P> {
156    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
157        self.inner.fmt(f)
158    }
159}
160
// Connection state shared between the sender (ClientStream methods) and the
// receiver (the detached receive_loop coroutine).
struct ClientStreamInner<F: ClientFacts, P: ClientTransport> {
    // Identifier of this client, embedded into every request head.
    client_id: u64,
    // Underlying transport connection, used by both sender and receiver paths.
    conn: P,
    // Monotonic request sequence counter, starting from 1.
    seq: AtomicU64,
    // NOTE: because close_rx is AsyncRx (lower cost to register waker), but does not have Sync, to solve the & does
    // not have Send problem, we convert this to mut ref to make borrow checker shutup
    close_rx: UnsafeCell<AsyncRx<mpsc::Null>>,
    closed: AtomicBool, // flag set by either sender or receiver on their exit
    // Tracks sent tasks and their timeouts; accessed mutably via get_timer_mut().
    timer: UnsafeCell<ClientTaskTimer<F>>,
    // TODO can closed and has_err merge ?
    has_err: AtomicBool,
    // Bounds the IO depth of in-flight requests.
    throttler: Throttler,
    // Optional externally-shared timestamp of the last received response.
    last_resp_ts: Option<Arc<AtomicU64>>,
    // Reusable request-encoding buffer (sender path only).
    encode_buf: UnsafeCell<Vec<u8>>,
    codec: F::Codec,
    logger: Arc<LogFilter>,
    facts: Arc<F>,
}
179
// SAFETY assumption: the UnsafeCell fields (close_rx, timer, encode_buf) are each
// only mutated from one side at a time (sender vs. receiver discipline); the rest
// of the struct is atomics or immutable state. NOTE(review): this invariant is
// enforced by usage convention, not the type system — confirm it holds.
unsafe impl<F: ClientFacts, P: ClientTransport> Send for ClientStreamInner<F, P> {}

unsafe impl<F: ClientFacts, P: ClientTransport> Sync for ClientStreamInner<F, P> {}
183
impl<F: ClientFacts, P: ClientTransport> fmt::Debug for ClientStreamInner<F, P> {
    // Renders the underlying transport connection's representation.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        self.conn.fmt(f)
    }
}
189
impl<F: ClientFacts, P: ClientTransport> ClientStreamInner<F, P> {
    /// Build the shared inner state.
    ///
    /// `config.thresholds` sets the throttler depth and timer capacity; a value
    /// of 0 falls back to a default of 128.
    pub fn new(
        facts: Arc<F>, conn: P, client_id: u64, conn_id: String, close_rx: AsyncRx<mpsc::Null>,
        last_resp_ts: Option<Arc<AtomicU64>>,
    ) -> Self {
        let config = facts.get_config();
        let mut thresholds = config.thresholds;
        if thresholds == 0 {
            thresholds = 128;
        }
        let client_inner = Self {
            client_id,
            conn,
            close_rx: UnsafeCell::new(close_rx),
            closed: AtomicBool::new(false),
            seq: AtomicU64::new(1),
            encode_buf: UnsafeCell::new(Vec::with_capacity(1024)),
            throttler: Throttler::new(thresholds),
            last_resp_ts,
            has_err: AtomicBool::new(false),
            codec: F::Codec::default(),
            logger: facts.new_logger(),
            timer: UnsafeCell::new(ClientTaskTimer::new(conn_id, config.task_timeout, thresholds)),
            facts,
        };
        logger_trace!(client_inner.logger, "{:?} throttler is set to {}", client_inner, thresholds,);
        client_inner
    }

    // Mutable access to the timer through its UnsafeCell.
    // SAFETY assumption: callers uphold the single-side access discipline described
    // on the unsafe Sync impl — TODO confirm no concurrent aliasing is possible.
    #[inline(always)]
    fn get_timer_mut(&self) -> &mut ClientTaskTimer<F> {
        unsafe { transmute(self.timer.get()) }
    }

    // Mutable access to close_rx. AsyncRx is not Sync, so the receiver borrows it
    // mutably through the UnsafeCell (only the receive path is expected to call this).
    #[inline(always)]
    fn get_close_rx(&self) -> &mut AsyncRx<mpsc::Null> {
        unsafe { transmute(self.close_rx.get()) }
    }

    // Reusable encoding buffer; only the sender path is expected to touch it.
    #[inline(always)]
    fn get_encoded_buf(&self) -> &mut Vec<u8> {
        unsafe { transmute(self.encode_buf.get()) }
    }

    /// Send one task directly on the socket stream. On failure, the task is handed
    /// back to the user via `facts.error_handle()` with an IO error set.
    async fn send_task(&self, mut task: F::Task, mut need_flush: bool) -> Result<(), RpcIntErr> {
        // Near the in-flight limit, flush eagerly so responses can drain the queue.
        if self.throttler.nearly_full() {
            need_flush = true;
        }
        let timer = self.get_timer_mut();
        timer.pending_task_count_ref().fetch_add(1, Ordering::SeqCst);
        // It's possible the receiver set `closed` after pending_task_count increased;
        // keep this check until reviewed.
        if self.closed.load(Ordering::Acquire) {
            logger_warn!(
                self.logger,
                "{:?} sending task {:?} failed: {}",
                self,
                task,
                RpcIntErr::IO,
            );
            task.set_rpc_error(RpcIntErr::IO);
            self.facts.error_handle(task);
            timer.pending_task_count_ref().fetch_sub(1, Ordering::SeqCst); // rollback
            return Err(RpcIntErr::IO);
        }
        match self.send_request(task, need_flush).await {
            Err(_) => {
                self.closed.store(true, Ordering::SeqCst);
                self.has_err.store(true, Ordering::SeqCst);
                return Err(RpcIntErr::IO);
            }
            Ok(_) => {
                // The task was registered to the notifier inside send_request; now
                // apply backpressure when too many requests are in flight.
                self.throttler.throttle().await;
                return Ok(());
            }
        }
    }

    /// Flush buffered requests on the transport; marks the stream closed and
    /// errored on failure.
    #[inline(always)]
    async fn flush_req(&self) -> Result<(), RpcIntErr> {
        if let Err(e) = self.conn.flush_req::<F>(&self.logger).await {
            logger_warn!(self.logger, "{:?} flush_req flush err: {}", self, e);
            self.closed.store(true, Ordering::SeqCst);
            self.has_err.store(true, Ordering::SeqCst);
            let timer = self.get_timer_mut();
            timer.stop_reg_task();
            return Err(RpcIntErr::IO);
        }
        Ok(())
    }

    /// Assign a sequence number, encode and write the request, then register the
    /// task with the timer so its response (or timeout) can be matched later.
    /// On a write error the pending counter is rolled back and the task is handed
    /// to `facts.error_handle()`.
    #[inline(always)]
    async fn send_request(&self, mut task: F::Task, need_flush: bool) -> Result<(), RpcIntErr> {
        let seq = self.seq_update();
        task.set_seq(seq);
        let buf = self.get_encoded_buf();
        match proto::ReqHead::encode(&self.codec, buf, self.client_id, &task) {
            Err(_) => {
                logger_warn!(&self.logger, "{:?} send_req encode req {:?} err", self, task);
                return Err(RpcIntErr::Encode);
            }
            Ok(blob_buf) => {
                if let Err(e) =
                    self.conn.write_req::<F>(&self.logger, buf, blob_buf, need_flush).await
                {
                    logger_warn!(
                        self.logger,
                        "{:?} send_req write req {:?} err: {:?}",
                        self,
                        task,
                        e
                    );
                    self.closed.store(true, Ordering::SeqCst);
                    self.has_err.store(true, Ordering::SeqCst);
                    let timer = self.get_timer_mut();
                    // TODO check stop_reg_task
                    // rollback counter
                    timer.pending_task_count_ref().fetch_sub(1, Ordering::SeqCst);
                    timer.stop_reg_task();
                    logger_warn!(self.logger, "{:?} sending task {:?} err: {}", self, task, e);
                    task.set_rpc_error(RpcIntErr::IO);
                    self.facts.error_handle(task);
                    return Err(RpcIntErr::IO);
                } else {
                    let wg = self.throttler.add_task();
                    let timer = self.get_timer_mut();
                    logger_trace!(self.logger, "{:?} send task {:?} ok", self, task);
                    timer.reg_task(task, wg).await;
                }
                return Ok(());
            }
        }
    }

    /// Send a PING frame to keep the connection alive.
    #[inline(always)]
    async fn send_ping_req(&self) -> Result<(), RpcIntErr> {
        if self.closed.load(Ordering::Acquire) {
            logger_warn!(self.logger, "{:?} send_ping_req skip as conn closed", self);
            return Err(RpcIntErr::IO);
        }
        // PING is not counted in the throttler.
        let buf = self.get_encoded_buf();
        proto::ReqHead::encode_ping(buf, self.client_id, self.seq_update());
        // Ping does not need reg_task and has no error_handle; it only keeps the
        // connection alive. A Connection Prober can monitor the liveness of ClientConn.
        if let Err(e) = self.conn.write_req::<F>(&self.logger, buf, None, true).await {
            logger_warn!(self.logger, "{:?} send ping err: {:?}", self, e);
            self.closed.store(true, Ordering::SeqCst);
            return Err(RpcIntErr::IO);
        }
        Ok(())
    }

    // Receive a bounded batch of responses before returning, so the receive loop's
    // one-second timer gets a chance to fire between batches.
    async fn recv_some(&self) -> Result<(), RpcIntErr> {
        for _ in 0i32..20 {
            // The underlying rpc socket is buffered and might not yield to the runtime;
            // return after a bounded batch so the timer can fire each second.
            match self.recv_one_resp().await {
                Err(e) => {
                    return Err(e);
                }
                Ok(_) => {
                    if let Some(last_resp_ts) = self.last_resp_ts.as_ref() {
                        last_resp_ts.store(self.facts.get_timestamp(), Ordering::Release);
                    }
                }
            }
        }
        Ok(())
    }

    // Read responses in a loop; returns Err(IO) once the stream is closed and no
    // pending task remains (or an error was flagged).
    // NOTE(review): as written this loop has no Ok(()) exit, so the Ok arm in
    // recv_some (and its last_resp_ts update) appears unreachable — confirm whether
    // read_resp returning Ok(true) should break out with Ok here (see TODO FIXME below).
    async fn recv_one_resp(&self) -> Result<(), RpcIntErr> {
        let timer = self.get_timer_mut();
        loop {
            if self.closed.load(Ordering::Acquire) {
                logger_trace!(self.logger, "{:?} read_resp from already close", self.conn);
                // ensure remaining tasks still receive responses on a normal exit
                if timer.check_pending_tasks_empty() || self.has_err.load(Ordering::Relaxed) {
                    return Err(RpcIntErr::IO);
                }
                // When the ClientStream (sender) is dropped, keep draining responses
                // without the close channel; presumably the timer bounds the wait via
                // task_timeout — TODO confirm.
                if let Err(_e) = self
                    .conn
                    .read_resp(self.facts.as_ref(), &self.logger, &self.codec, None, timer)
                    .await
                {
                    self.closed.store(true, Ordering::SeqCst);
                    return Err(RpcIntErr::IO);
                }
            } else {
                // Block here for a new header without timeout.
                // NOTE: because close_rx is AsyncRx, which does not have Sync, we hand
                // the transport a mutable reference obtained through the UnsafeCell.
                match self
                    .conn
                    .read_resp(
                        self.facts.as_ref(),
                        &self.logger,
                        &self.codec,
                        Some(self.get_close_rx()),
                        timer,
                    )
                    .await
                {
                    Err(_e) => {
                        return Err(RpcIntErr::IO);
                    }
                    Ok(r) => {
                        // TODO FIXME
                        // r == false indicates the close channel fired: mark closed and
                        // fall into the drain branch above on the next iteration.
                        if !r {
                            self.closed.store(true, Ordering::SeqCst);
                            continue;
                        }
                    }
                }
            }
        }
    }

    // Receiver coroutine body: alternately poll the receive future and a 1s tick
    // that drives task-timeout handling. On error, mark the stream closed and keep
    // failing pending tasks until the pending counter drains, then exit.
    async fn receive_loop<RT: AsyncRuntime>(&self) {
        let mut tick = <RT as AsyncTime>::interval(Duration::from_secs(1));
        loop {
            let f = self.recv_some();
            pin_mut!(f);
            let selector = ReceiverTimerFuture::new(self, &mut tick, &mut f);
            match selector.await {
                Ok(_) => {}
                Err(e) => {
                    logger_debug!(self.logger, "{:?} receive_loop error: {}", self, e);
                    self.closed.store(true, Ordering::SeqCst);
                    let timer = self.get_timer_mut();
                    timer.clean_pending_tasks(self.facts.as_ref());
                    // pending_task_count > 0 means some tasks may still remain in the pending chan
                    while timer.pending_task_count_ref().load(Ordering::SeqCst) > 0 {
                        // Once the `closed` flag has taken effect, pending_task_count
                        // stops growing, so this drain loop terminates; sleep between
                        // cleanup passes to let stragglers arrive.
                        timer.clean_pending_tasks(self.facts.as_ref());
                        <RT as AsyncTime>::sleep(Duration::from_secs(1)).await;
                    }
                    return;
                }
            }
        }
    }

    // Timer tick handler: log the in-flight count and let the timer adjust its
    // waiting queue (expiring timed-out tasks).
    fn time_reach(&self) {
        logger_trace!(
            self.logger,
            "{:?} has {} pending_tasks",
            self,
            self.throttler.get_inflight_count()
        );
        let timer = self.get_timer_mut();
        timer.adjust_task_queue(self.facts.as_ref());
        return;
    }

    // Allocate the next request sequence number (returns the pre-increment value;
    // the counter starts at 1).
    #[inline(always)]
    fn seq_update(&self) -> u64 {
        self.seq.fetch_add(1, Ordering::SeqCst)
    }
}
457
458impl<F: ClientFacts, P: ClientTransport> Drop for ClientStreamInner<F, P> {
459    fn drop(&mut self) {
460        let timer = self.get_timer_mut();
461        timer.clean_pending_tasks(self.facts.as_ref());
462    }
463}
464
// Combined selector future used by receive_loop: a single `.await` services the
// receive future, the tick interval, and the timer's sent-task polling together.
struct ReceiverTimerFuture<'a, F, P, I, FR>
where
    F: ClientFacts,
    P: ClientTransport,
    I: TimeInterval,
    FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
{
    // Shared connection state: consulted for has_err, timer access and tick handling.
    client: &'a ClientStreamInner<F, P>,
    // One-second tick interval driving task-timeout processing.
    inv: Pin<&'a mut I>,
    // The pinned recv_some() future for the current batch.
    recv_future: Pin<&'a mut FR>,
}
476
477impl<'a, F, P, I, FR> ReceiverTimerFuture<'a, F, P, I, FR>
478where
479    F: ClientFacts,
480    P: ClientTransport,
481    I: TimeInterval,
482    FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
483{
484    fn new(client: &'a ClientStreamInner<F, P>, inv: &'a mut I, recv_future: &'a mut FR) -> Self {
485        Self { inv: Pin::new(inv), client, recv_future: Pin::new(recv_future) }
486    }
487}
488
// Resolves to Ok(()) when a receive batch completes, or Err on a connection error
// or when has_err was flagged.
// NOTE(review): an earlier comment described Ok(true)/Ok(false) returns, but the
// Output type is Result<(), RpcIntErr> — confirm the intended contract.
impl<'a, F, P, I, FR> Future for ReceiverTimerFuture<'a, F, P, I, FR>
where
    F: ClientFacts,
    P: ClientTransport,
    I: TimeInterval,
    FR: Future<Output = Result<(), RpcIntErr>> + Unpin,
{
    type Output = Result<(), RpcIntErr>;

    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
        let mut _self = self.get_mut();
        // Drain every due tick (the ticker might not fire otherwise), and ensure the
        // ticker re-registers its waker after being ready.
        while _self.inv.as_mut().poll_tick(ctx).is_ready() {
            _self.client.time_reach();
        }
        if _self.client.has_err.load(Ordering::Relaxed) {
            // When a sentinel detects the peer unreachable, recv_some might be blocked;
            // at worst the interval will wake us, so just exit with an error.
            return Poll::Ready(Err(RpcIntErr::IO));
        }
        // Even if the receive future is blocked, poll_sent_task must run so task
        // timeout events are still detected.
        _self.client.get_timer_mut().poll_sent_task(ctx);
        if let Poll::Ready(r) = _self.recv_future.as_mut().poll(ctx) {
            return Poll::Ready(r);
        }
        return Poll::Pending;
    }
}