rpc_it/
rpc.rs

1//! RPC connection implementation
2//!
3//! # Cautions
4//!
//! - `write` operations are not cancellation-safe in an async context. If you abort an async
//!   write task such as `request`, `notify`, or the `response*` series, the write stream may be
//!   left in a corrupted state, which can invalidate any subsequent write operation.
8//!
9
10use std::{
11    any::Any,
12    fmt::Debug,
13    num::NonZeroUsize,
14    ops::Range,
15    sync::{Arc, Weak},
16};
17
18use crate::{
19    codec::{self, Codec, DecodeError, Framing, FramingError},
20    transport::{AsyncFrameRead, AsyncFrameWrite},
21};
22
23use async_lock::Mutex as AsyncMutex;
24use bitflags::bitflags;
25use bytes::{Bytes, BytesMut};
26use enum_as_inner::EnumAsInner;
27use futures_util::AsyncRead;
28pub use req::RequestContext;
29pub use req::{OwnedResponseFuture, ResponseFuture};
30
31/* ---------------------------------------------------------------------------------------------- */
32/*                                          FEATURE FLAGS                                         */
33/* ---------------------------------------------------------------------------------------------- */
34bitflags! {
35    #[derive(Default, Debug, Clone, Copy)]
36    pub struct Feature : u32 {
37        /// For inbound request messages, if the request object was dropped without sending
38        /// response, the background driver will automatically send 'unhandled' response to the
39        /// remote end.
40        ///
41        /// If you don't want any undesired response to be sent, or you're creating a client handle,
42        /// which usually does not receive any request, you can disable this feature.
43        const ENABLE_AUTO_RESPONSE =            1 << 1;
44
45        /// Do not receive any request from the remote end.
46        const NO_RECEIVE_REQUEST =              1 << 2;
47
48        /// Do not receive any notification from the remote end.
49        const NO_RECEIVE_NOTIFY =               1 << 3;
50    }
51}
52
53/* ---------------------------------------------------------------------------------------------- */
54/*                                          BACKED TRAIT                                          */
55/* ---------------------------------------------------------------------------------------------- */
56/// Creates RPC connection from [`crate::transport::AsyncReadFrame`] and
57/// [`crate::transport::AsyncWriteFrame`], and [`crate::codec::Codec`].
58///
59/// For unsupported features(e.g. notify from client), the codec should return
60/// [`crate::codec::EncodeError::UnsupportedFeature`] error.
61struct ConnectionImpl<C, T, R, U> {
62    codec: Arc<C>,
63    write: AsyncMutex<T>,
64    reqs: R,
65    tx_drive: flume::Sender<InboundDriverDirective>,
66    features: Feature,
67    user_data: U,
68    _unpin: std::marker::PhantomPinned,
69}
70
71/// Wraps connection implementation with virtual dispatch.
72trait Connection: Send + Sync + 'static + Debug {
73    fn codec(&self) -> &dyn Codec;
74    fn write(&self) -> &AsyncMutex<dyn AsyncFrameWrite>;
75    fn reqs(&self) -> Option<&RequestContext>;
76    fn tx_drive(&self) -> &flume::Sender<InboundDriverDirective>;
77    fn feature_flag(&self) -> Feature;
78    fn user_data(&self) -> &dyn UserData;
79}
80
81impl<C, T, R, U> Connection for ConnectionImpl<C, T, R, U>
82where
83    C: Codec,
84    T: AsyncFrameWrite,
85    R: GetRequestContext,
86    U: UserData,
87{
88    fn codec(&self) -> &dyn Codec {
89        &*self.codec
90    }
91
92    fn write(&self) -> &AsyncMutex<dyn AsyncFrameWrite> {
93        &self.write
94    }
95
96    fn reqs(&self) -> Option<&RequestContext> {
97        self.reqs.get_req_con()
98    }
99
100    fn tx_drive(&self) -> &flume::Sender<InboundDriverDirective> {
101        &self.tx_drive
102    }
103
104    fn feature_flag(&self) -> Feature {
105        self.features
106    }
107
108    fn user_data(&self) -> &dyn UserData {
109        &self.user_data
110    }
111}
112
113impl<C, T, R, U> ConnectionImpl<C, T, R, U>
114where
115    C: Codec,
116    T: AsyncFrameWrite,
117    R: GetRequestContext,
118    U: UserData,
119{
120    fn dyn_ref(&self) -> &dyn Connection {
121        self
122    }
123}
124
125/// Trick to get the request context from the generic type, and to cast [`ConnectionBody`] to
126/// `dyn` [`Connection`] trait object.
127pub trait GetRequestContext: std::fmt::Debug + Send + Sync + 'static {
128    fn get_req_con(&self) -> Option<&RequestContext>;
129}
130
131impl GetRequestContext for Arc<RequestContext> {
132    fn get_req_con(&self) -> Option<&RequestContext> {
133        Some(self)
134    }
135}
136
137impl GetRequestContext for RequestContext {
138    fn get_req_con(&self) -> Option<&RequestContext> {
139        Some(self)
140    }
141}
142
143impl GetRequestContext for () {
144    fn get_req_con(&self) -> Option<&RequestContext> {
145        None
146    }
147}
148
149impl<C, T, R: Debug, U> Debug for ConnectionImpl<C, T, R, U> {
150    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151        f.debug_struct("ConnectionBody").field("reqs", &self.reqs).finish()
152    }
153}
154
155/* --------------------------------------- User Data Trait -------------------------------------- */
/// Arbitrary, thread-safe data that can be attached to a connection and recovered later by
/// downcasting through [`Any`].
pub trait UserData: Any + Send + Sync + 'static {
    /// Returns `self` as a type-erased [`Any`] reference, ready for `downcast_ref`.
    fn as_any(&self) -> &(dyn Any + Send + Sync + 'static);
}

/// Every `'static`, thread-safe type qualifies as user data.
///
/// NOTE: wrapper types such as `Box<dyn UserData>` also match this impl. Calling `as_any` on the
/// wrapper itself therefore erases the wrapper, not the value inside it.
impl<T> UserData for T
where
    T: Any + Send + Sync + 'static,
{
    fn as_any(&self) -> &(dyn Any + Send + Sync + 'static) {
        let erased: &(dyn Any + Send + Sync + 'static) = self;
        erased
    }
}

impl AsRef<dyn Any + Send + Sync + 'static> for dyn UserData {
    /// Same as [`UserData::as_any`], dispatched through the trait object's vtable.
    fn as_ref(&self) -> &(dyn Any + Send + Sync + 'static) {
        UserData::as_any(self)
    }
}
174
175#[derive(Clone)]
176pub struct OwnedUserData(Arc<dyn Connection>);
177
178impl OwnedUserData {
179    pub fn raw(&self) -> &dyn UserData {
180        self.0.user_data()
181    }
182}
183
184impl std::ops::Deref for OwnedUserData {
185    type Target = dyn Any;
186
187    fn deref(&self) -> &Self::Target {
188        self.0.user_data().as_any()
189    }
190}
191
192/* -------------------------------------- Driver Directives ------------------------------------- */
193
194/// Which couldn't be handled within the non-async drop handlers ...
195///
196/// - A received request object is dropped before response is sent, therefore an 'aborted' message
197///   should be sent to the receiver
198#[derive(Debug)]
199enum InboundDriverDirective {
200    /// Defer sending response to the background task.
201    DeferredWrite(DeferredWrite),
202
203    /// Manually close the connection.
204    Close,
205}
206
207#[derive(Debug)]
208enum DeferredWrite {
209    /// Send error response to the request
210    ErrorResponse(msg::RequestInner, codec::PredefinedResponseError),
211
212    /// Send request was deferred. This is used for non-blocking response, etc.
213    Raw(bytes::BytesMut),
214
215    /// Send flush request from background
216    Flush,
217}
218
219/* ---------------------------------------------------------------------------------------------- */
220/*                                             HANDLES                                            */
221/* ---------------------------------------------------------------------------------------------- */
222/// Bidirectional RPC handle. It can serve as both client and server.
223#[derive(Clone, Debug)]
224pub struct Transceiver(Sender, flume::Receiver<msg::RecvMsg>);
225
226impl std::ops::Deref for Transceiver {
227    type Target = Sender;
228
229    fn deref(&self) -> &Self::Target {
230        &self.0
231    }
232}
233
234impl Transceiver {
235    pub fn into_sender(self) -> Sender {
236        self.0
237    }
238
239    pub fn clone_sender(&self) -> Sender {
240        self.0.clone()
241    }
242
243    pub fn blocking_recv(&self) -> Result<msg::RecvMsg, RecvError> {
244        Ok(self.1.recv()?)
245    }
246
247    pub async fn recv(&self) -> Result<msg::RecvMsg, RecvError> {
248        Ok(self.1.recv_async().await?)
249    }
250
251    pub fn try_recv(&self) -> Result<msg::RecvMsg, TryRecvError> {
252        self.1.try_recv().map_err(|e| match e {
253            flume::TryRecvError::Empty => TryRecvError::Empty,
254            flume::TryRecvError::Disconnected => TryRecvError::Disconnected,
255        })
256    }
257}
258
259impl From<Transceiver> for Sender {
260    fn from(value: Transceiver) -> Self {
261        value.0
262    }
263}
264
265/// Send-only handle. This holds strong reference to the connection.
266#[derive(Clone, Debug)]
267pub struct Sender(Arc<dyn Connection>);
268
269/// Reused buffer over multiple RPC request/responses
270///
271/// To minimize the memory allocation during sender payload serialization, reuse this buffer over
272/// multiple RPC request/notification.
273#[derive(Debug, Clone, Default)]
274pub struct WriteBuffer {
275    value: BytesMut,
276}
277
278impl WriteBuffer {
279    pub(crate) fn prepare(&mut self) {
280        self.value.clear();
281    }
282
283    pub fn reserve(&mut self, size: usize) {
284        self.value.clear();
285        self.value.reserve(size);
286    }
287}
288
289#[derive(Debug, thiserror::Error)]
290pub enum CallError {
291    #[error("Failed to send request: {0}")]
292    SendFailed(#[from] SendError),
293
294    #[error("Failed to receive response: {0}")]
295    FlushFailed(#[from] std::io::Error),
296
297    #[error("Failed to receive response: {0}")]
298    RecvFailed(#[from] RecvError),
299
300    #[error("Remote returned invalid return type: {0}")]
301    ParseFailed(DecodeError, msg::Response),
302
303    #[error("Remote returned error response")]
304    ErrorResponse(msg::Response),
305}
306
307#[derive(Debug, EnumAsInner, thiserror::Error)]
308pub enum TypedCallError<E> {
309    #[error("Failed to send request: {0}")]
310    SendFailed(#[from] super::SendError),
311
312    #[error("Failed to receive response: {0}")]
313    FlushFailed(#[from] std::io::Error),
314
315    #[error("Failed to receive response: {0}")]
316    RecvFailed(#[from] super::RecvError),
317
318    #[error("Remote returned error response")]
319    Response(#[from] super::ResponseError<E>),
320}
321
322impl Sender {
323    /// A shortcut for request, flush, and receive response.
324    pub async fn call<R: serde::de::DeserializeOwned>(
325        &self,
326        method: &str,
327        params: &impl serde::Serialize,
328    ) -> Result<R, CallError> {
329        let response = self.request(method, params).await?;
330
331        self.0.__flush().await?;
332
333        let msg = response.await?;
334
335        if msg.is_error {
336            return Err(CallError::ErrorResponse(msg));
337        }
338
339        match msg.parse() {
340            Ok(value) => Ok(value),
341            Err(err) => Err(CallError::ParseFailed(err, msg)),
342        }
343    }
344
345    /// A shortcut for strictly typed request
346    pub async fn call_with_err<R: serde::de::DeserializeOwned, E: serde::de::DeserializeOwned>(
347        &self,
348        method: &str,
349        params: &impl serde::Serialize,
350    ) -> Result<R, TypedCallError<E>> {
351        let response = self.request(method, params).await?;
352
353        self.0.__flush().await?;
354
355        let msg = response.await?;
356        Ok(msg.result()?)
357    }
358
359    /// Send request, and create response future which will be resolved when the response is
360    /// received.
361    pub async fn request<T: serde::Serialize>(
362        &self,
363        method: &str,
364        params: &T,
365    ) -> Result<ResponseFuture, SendError> {
366        self.request_with_reuse(&mut Default::default(), method, params).await
367    }
368
369    /// Send notification message.
370    pub async fn notify<T: serde::Serialize>(
371        &self,
372        method: &str,
373        params: &T,
374    ) -> Result<(), SendError> {
375        self.notify_with_reuse(&mut Default::default(), method, params).await
376    }
377
378    #[doc(hidden)]
379    pub async fn request_with_reuse<T: serde::Serialize>(
380        &self,
381        buf: &mut WriteBuffer,
382        method: &str,
383        params: &T,
384    ) -> Result<ResponseFuture, SendError> {
385        let fut = self.prepare_request(buf, method, params)?;
386        self.0.__write_raw(&mut buf.value).await?;
387
388        Ok(fut)
389    }
390
391    #[doc(hidden)]
392    pub async fn notify_with_reuse<T: serde::Serialize>(
393        &self,
394        buf: &mut WriteBuffer,
395        method: &str,
396        params: &T,
397    ) -> Result<(), SendError> {
398        buf.prepare();
399
400        self.0.codec().encode_notify(method, params, &mut buf.value)?;
401        self.0.__write_raw(&mut buf.value).await?;
402
403        Ok(())
404    }
405
406    /// Sends a request and returns a future that will be resolved when the response is received.
407    ///
408    /// This method is non-blocking, as the message writing will be deferred to the background
409    pub fn request_deferred<T: serde::Serialize>(
410        &self,
411        method: &str,
412        params: &T,
413    ) -> Result<ResponseFuture, SendError> {
414        self.request_deferred_with_reuse(&mut Default::default(), method, params)
415    }
416
417    /// Send deferred notification. This method is non-blocking, as the message writing will be
418    /// deferred to the background driver worker.
419    pub fn notify_deferred<T: serde::Serialize>(
420        &self,
421        method: &str,
422        params: &T,
423    ) -> Result<(), SendError> {
424        self.notify_deferred_with_reuse(&mut Default::default(), method, params)
425    }
426
427    /// Sends a request and returns a future that will be resolved when the response is received.
428    ///
429    /// This method is non-blocking, as the message writing will be deferred to the background
430    #[doc(hidden)]
431    pub fn request_deferred_with_reuse<T: serde::Serialize>(
432        &self,
433        buffer: &mut WriteBuffer,
434        method: &str,
435        params: &T,
436    ) -> Result<ResponseFuture, SendError> {
437        buffer.prepare();
438        let fut = self.prepare_request(buffer, method, params);
439
440        self.0
441            .tx_drive()
442            .send(InboundDriverDirective::DeferredWrite(DeferredWrite::Raw(buffer.value.split())))
443            .map_err(|_| SendError::Disconnected)?;
444
445        fut
446    }
447
448    /// Send deferred notification. This method is non-blocking, as the message writing will be
449    /// deferred to the background driver worker.
450    #[doc(hidden)]
451    pub fn notify_deferred_with_reuse<T: serde::Serialize>(
452        &self,
453        buffer: &mut WriteBuffer,
454        method: &str,
455        params: &T,
456    ) -> Result<(), SendError> {
457        buffer.prepare();
458        self.0.codec().encode_notify(method, params, &mut buffer.value)?;
459
460        self.0
461            .tx_drive()
462            .send(InboundDriverDirective::DeferredWrite(DeferredWrite::Raw(buffer.value.split())))
463            .map_err(|_| SendError::Disconnected)
464    }
465
466    fn prepare_request<T: serde::Serialize>(
467        &self,
468        buf: &mut WriteBuffer,
469        method: &str,
470        params: &T,
471    ) -> Result<ResponseFuture, SendError> {
472        let Some(req) = self.0.reqs() else { return Err(SendError::RequestDisabled) };
473
474        buf.prepare();
475        let req_id_hint = req.next_req_id_base();
476        let req_id_hash =
477            self.0.codec().encode_request(method, req_id_hint, params, &mut buf.value)?;
478
479        // Registering request always preceded than sending request. If request was not sent due to
480        // I/O issue or cancellation, the request will be unregistered on the drop of the
481        // `ResponseFuture`.
482        let slot_id = req.register_req(req_id_hash);
483
484        Ok(ResponseFuture::new(&self.0, slot_id))
485    }
486
487    /// Check if the connection is disconnected.
488    pub fn is_disconnected(&self) -> bool {
489        self.0.tx_drive().is_disconnected()
490    }
491
492    /// Closes the connection. If it's already closed, it'll return false.
493    ///
494    /// # Caution
495    ///
496    /// If multiple threads calls this method at the same time, more than one thread may return
497    /// true, as the close operation is lazy.
498    pub fn close(self) -> bool {
499        self.0.tx_drive().send(InboundDriverDirective::Close).is_ok()
500    }
501
502    /// Flush underlying write stream.
503    pub async fn flush(&self) -> std::io::Result<()> {
504        self.0.__flush().await
505    }
506
507    /// Perform flush from background task.
508    pub fn flush_deferred(&self) -> bool {
509        self.0.tx_drive().send(InboundDriverDirective::DeferredWrite(DeferredWrite::Flush)).is_ok()
510    }
511
512    /// Is sending request enabled?
513    pub fn is_request_enabled(&self) -> bool {
514        self.0.reqs().is_some()
515    }
516
517    /// Get feature flags
518    pub fn get_feature_flags(&self) -> Feature {
519        self.0.feature_flag()
520    }
521}
522
523mod req {
524    use std::{
525        borrow::Cow,
526        mem::replace,
527        num::NonZeroU64,
528        sync::{atomic::AtomicU64, Arc},
529        task::Poll,
530    };
531
532    use dashmap::DashMap;
533    use futures_util::{task::AtomicWaker, FutureExt};
534    use parking_lot::Mutex;
535
536    use crate::RecvError;
537
538    use super::{msg, Connection};
539
540    /// RPC request context. Stores request ID and response receiver context.
541    #[derive(Debug, Default)]
542    pub struct RequestContext {
543        /// We may not run out of 64-bit sequential integers ...
544        req_id_gen: AtomicU64,
545        waiters: DashMap<NonZeroU64, RequestSlot>,
546    }
547
548    #[derive(Clone, Debug)]
549    pub(super) struct RequestSlotId(NonZeroU64);
550
551    #[derive(Debug)]
552    struct RequestSlot {
553        waker: AtomicWaker,
554        value: Mutex<Option<msg::Response>>,
555    }
556
557    impl RequestContext {
558        /// Never returns 0.
559        pub(super) fn next_req_id_base(&self) -> NonZeroU64 {
560            // SAFETY: 1 + 2 * N is always odd, and non-zero on wrapping condition.
561            // > Additionally, to see it wraps, we need to send 2^63 requests ...
562            unsafe {
563                NonZeroU64::new_unchecked(
564                    1 + self.req_id_gen.fetch_add(2, std::sync::atomic::Ordering::Relaxed),
565                )
566            }
567        }
568
569        #[must_use]
570        pub(super) fn register_req(&self, req_id_hash: NonZeroU64) -> RequestSlotId {
571            let slot = RequestSlot { waker: AtomicWaker::new(), value: Mutex::new(None) };
572            if self.waiters.insert(req_id_hash, slot).is_some() {
573                panic!("Request ID collision")
574            }
575            RequestSlotId(req_id_hash)
576        }
577
578        /// Routes given response to appropriate handler. Returns `Err` if no handler is found.
579        pub(super) fn route_response(
580            &self,
581            req_id_hash: NonZeroU64,
582            response: msg::Response,
583        ) -> Result<(), msg::Response> {
584            let Some(slot) = self.waiters.get(&req_id_hash) else {
585                return Err(response);
586            };
587
588            let mut value = slot.value.lock();
589            value.replace(response);
590            slot.waker.wake();
591
592            Ok(())
593        }
594
595        /// Wake up all waiters. This is to abort all pending response futures that are waiting on
596        /// closed connections. This doesn't do more than that, as the request context itself can
597        /// be shared among multiple connections.
598        pub(super) fn wake_up_all(&self) {
599            for x in self.waiters.iter() {
600                x.waker.wake();
601            }
602        }
603    }
604
605    /// When dropped, the response handler will be unregistered from the queue.
606    #[must_use = "futures do nothing unless you `.await` or poll them"]
607    #[derive(Debug)]
608    pub struct ResponseFuture<'a>(ResponseFutureInner<'a>);
609
610    #[must_use = "futures do nothing unless you `.await` or poll them"]
611    #[derive(Debug)]
612    pub struct OwnedResponseFuture(ResponseFuture<'static>);
613
614    #[derive(Debug)]
615    enum ResponseFutureInner<'a> {
616        Waiting(Cow<'a, Arc<dyn Connection>>, NonZeroU64),
617        Finished,
618    }
619
620    impl<'a> ResponseFuture<'a> {
621        pub(super) fn new(handle: &'a Arc<dyn Connection>, slot_id: RequestSlotId) -> Self {
622            Self(ResponseFutureInner::Waiting(Cow::Borrowed(handle), slot_id.0))
623        }
624
625        pub fn to_owned(mut self) -> OwnedResponseFuture {
626            let state = replace(&mut self.0, ResponseFutureInner::Finished);
627
628            match state {
629                ResponseFutureInner::Waiting(conn, id) => OwnedResponseFuture(ResponseFuture(
630                    ResponseFutureInner::Waiting(Cow::Owned(conn.into_owned()), id),
631                )),
632                ResponseFutureInner::Finished => {
633                    OwnedResponseFuture(ResponseFuture(ResponseFutureInner::Finished))
634                }
635            }
636        }
637
638        pub fn try_recv(&mut self) -> Result<Option<msg::Response>, RecvError> {
639            use ResponseFutureInner::*;
640
641            match &mut self.0 {
642                Waiting(conn, hash) => {
643                    if conn.__is_disconnected() {
644                        // Let the 'drop' trait erase the request from the queue.
645                        return Err(RecvError::Disconnected);
646                    }
647
648                    let mut value = None;
649                    conn.reqs().unwrap().waiters.remove_if(hash, |_, elem| {
650                        if let Some(v) = elem.value.lock().take() {
651                            value = Some(v);
652                            true
653                        } else {
654                            false
655                        }
656                    });
657
658                    if value.is_some() {
659                        self.0 = Finished;
660                    }
661
662                    Ok(value)
663                }
664
665                Finished => Ok(None),
666            }
667        }
668    }
669
670    impl<'a> std::future::Future for ResponseFuture<'a> {
671        type Output = Result<msg::Response, RecvError>;
672
673        fn poll(
674            mut self: std::pin::Pin<&mut Self>,
675            cx: &mut std::task::Context<'_>,
676        ) -> std::task::Poll<Self::Output> {
677            use ResponseFutureInner::*;
678
679            match &mut self.0 {
680                Waiting(conn, hash) => {
681                    if conn.__is_disconnected() {
682                        // Let the 'drop' trait erase the request from the queue.
683                        return Poll::Ready(Err(RecvError::Disconnected));
684                    }
685
686                    let mut value = None;
687                    conn.reqs().unwrap().waiters.remove_if(hash, |_, elem| {
688                        if let Some(v) = elem.value.lock().take() {
689                            value = Some(v);
690                            true
691                        } else {
692                            elem.waker.register(cx.waker());
693                            false
694                        }
695                    });
696
697                    if let Some(value) = value {
698                        self.0 = Finished;
699                        Poll::Ready(Ok(value))
700                    } else {
701                        Poll::Pending
702                    }
703                }
704
705                Finished => panic!("ResponseFuture polled after completion"),
706            }
707        }
708    }
709
710    impl std::future::Future for OwnedResponseFuture {
711        type Output = Result<msg::Response, RecvError>;
712
713        fn poll(
714            mut self: std::pin::Pin<&mut Self>,
715            cx: &mut std::task::Context<'_>,
716        ) -> std::task::Poll<Self::Output> {
717            self.0.poll_unpin(cx)
718        }
719    }
720
721    impl futures_util::future::FusedFuture for ResponseFuture<'_> {
722        fn is_terminated(&self) -> bool {
723            matches!(self.0, ResponseFutureInner::Finished)
724        }
725    }
726
727    impl futures_util::future::FusedFuture for OwnedResponseFuture {
728        fn is_terminated(&self) -> bool {
729            self.0.is_terminated()
730        }
731    }
732
733    impl std::ops::Deref for OwnedResponseFuture {
734        type Target = ResponseFuture<'static>;
735
736        fn deref(&self) -> &Self::Target {
737            &self.0
738        }
739    }
740
741    impl std::ops::DerefMut for OwnedResponseFuture {
742        fn deref_mut(&mut self) -> &mut Self::Target {
743            &mut self.0
744        }
745    }
746
747    impl Drop for ResponseFuture<'_> {
748        fn drop(&mut self) {
749            let state = std::mem::replace(&mut self.0, ResponseFutureInner::Finished);
750            let ResponseFutureInner::Waiting(conn, hash) = state else { return };
751            let reqs = conn.reqs().unwrap();
752            assert!(
753                reqs.waiters.remove(&hash).is_some(),
754                "Request lifespan must be bound to this future."
755            );
756        }
757    }
758}
759
760/* ---------------------------------------------------------------------------------------------- */
/*                                        ERROR DEFINITIONS                                       */
762/* ---------------------------------------------------------------------------------------------- */
763#[derive(Debug, thiserror::Error)]
764pub enum SendError {
765    #[error("Encoding outbound message failed: {0}")]
766    CodecError(#[from] crate::codec::EncodeError),
767
768    #[error("Error during preparing send: {0}")]
769    SendSetupFailed(std::io::Error),
770
771    #[error("Error during sending message to write stream: {0}")]
772    IoError(#[from] std::io::Error),
773
774    #[error("Request feature is disabled for this connection")]
775    RequestDisabled,
776
777    #[error("Channel is already closed!")]
778    Disconnected,
779}
780
781#[derive(Debug, thiserror::Error)]
782pub enum RecvError {
783    #[error("Channel has been closed")]
784    Disconnected,
785}
786
787#[derive(Debug, thiserror::Error)]
788pub enum TryRecvError {
789    #[error("Connection has been disconnected")]
790    Disconnected,
791
792    #[error("Channel is empty")]
793    Empty,
794}
795
796impl From<flume::RecvError> for RecvError {
797    fn from(value: flume::RecvError) -> Self {
798        match value {
799            flume::RecvError::Disconnected => Self::Disconnected,
800        }
801    }
802}
803
804/* ---------------------------------------------------------------------------------------------- */
805/*                                             BUILDER                                            */
806/* ---------------------------------------------------------------------------------------------- */
807
808//  - Create async worker task to handle receive
809pub struct Builder<Tw, Tr, C, E, R, U> {
810    codec: C,
811    write: Tw,
812    read: Tr,
813    ev: E,
814    reqs: R,
815    user_data: U,
816
817    /// Required configurations
818    cfg: BuilderOtherConfig,
819}
820
821#[derive(Default)]
822struct BuilderOtherConfig {
823    inbound_channel_cap: Option<NonZeroUsize>,
824    feature_flag: Feature,
825}
826
827impl Default for Builder<(), (), (), EmptyEventListener, (), ()> {
828    fn default() -> Self {
829        Self {
830            codec: (),
831            write: (),
832            read: (),
833            cfg: Default::default(),
834            ev: EmptyEventListener,
835            reqs: (),
836            user_data: (),
837        }
838    }
839}
840
841impl<Tw, Tr, C, E, R, U> Builder<Tw, Tr, C, E, R, U> {
842    /// Specify codec to use.
843    pub fn with_codec<C1: Codec>(
844        self,
845        codec: impl Into<Arc<C1>>,
846    ) -> Builder<Tw, Tr, Arc<C1>, E, R, U> {
847        Builder {
848            codec: codec.into(),
849            write: self.write,
850            read: self.read,
851            cfg: self.cfg,
852            ev: self.ev,
853            reqs: self.reqs,
854            user_data: self.user_data,
855        }
856    }
857
858    /// Specify write frame to use
859    pub fn with_write<Tw2>(self, write: Tw2) -> Builder<Tw2, Tr, C, E, R, U>
860    where
861        Tw2: AsyncFrameWrite,
862    {
863        Builder {
864            codec: self.codec,
865            write,
866            read: self.read,
867            cfg: self.cfg,
868            ev: self.ev,
869            reqs: self.reqs,
870            user_data: self.user_data,
871        }
872    }
873
874    /// Specify [`AsyncFrameRead`] to use
875    pub fn with_read<Tr2>(self, read: Tr2) -> Builder<Tw, Tr2, C, E, R, U>
876    where
877        Tr2: AsyncFrameRead,
878    {
879        Builder {
880            codec: self.codec,
881            write: self.write,
882            read,
883            cfg: self.cfg,
884            ev: self.ev,
885            reqs: self.reqs,
886            user_data: self.user_data,
887        }
888    }
889
890    /// Specify [`InboundEventSubscriber`] to use. This is used to handle errnous events from
891    /// inbound driver
892    pub fn with_event_listener<E2: InboundEventSubscriber>(
893        self,
894        ev: E2,
895    ) -> Builder<Tw, Tr, C, E2, R, U> {
896        Builder {
897            codec: self.codec,
898            write: self.write,
899            read: self.read,
900            cfg: self.cfg,
901            ev,
902            reqs: self.reqs,
903            user_data: self.user_data,
904        }
905    }
906
907    /// Specify addtional user data to store in the connection.
908    pub fn with_user_data<U2: UserData>(self, user_data: U2) -> Builder<Tw, Tr, C, E, R, U2> {
909        Builder {
910            codec: self.codec,
911            write: self.write,
912            read: self.read,
913            cfg: self.cfg,
914            ev: self.ev,
915            reqs: self.reqs,
916            user_data,
917        }
918    }
919
920    /// Set the read frame with stream reader and framing decoder.
921    ///
922    /// # Parameters
923    ///
924    /// - `default_readbuf_reserve`: When the framing decoder does not provide the next buffer size,
925    ///   this value is used to pre-allocate the buffer for the next [`AsyncReadFrame::poll_read`]
926    ///   call.
927    pub fn with_read_stream<Tr2, F>(
928        self,
929        read: Tr2,
930        framing: F,
931        default_readbuf_reserve: usize,
932    ) -> Builder<Tw, impl AsyncFrameRead, C, E, R, U>
933    where
934        Tr2: AsyncRead + Send + Sync + 'static,
935        F: Framing,
936    {
937        use std::pin::Pin;
938        use std::task::{Context, Poll};
939
940        struct FramingReader<T, F> {
941            reader: T,
942            framing: F,
943            buf: bytes::BytesMut,
944            nreserve: usize,
945            state_skip_reading: bool,
946        }
947
948        impl<T, F> AsyncFrameRead for FramingReader<T, F>
949        where
950            T: AsyncRead + Sync + Send + 'static,
951            F: Framing,
952        {
953            fn poll_next(
954                self: Pin<&mut Self>,
955                cx: &mut Context<'_>,
956            ) -> Poll<std::io::Result<Bytes>> {
957                // SAFETY: We won't move this value, for sure, right?
958                let FramingReader { reader, framing, buf, nreserve, state_skip_reading } =
959                    unsafe { self.get_unchecked_mut() };
960
961                let mut reader = unsafe { Pin::new_unchecked(reader) };
962
963                loop {
964                    // Read until the transport returns 'pending'
965                    let size_hint = framing.next_buffer_size();
966                    let size_required = size_hint.map(|x| x.get());
967
968                    while !*state_skip_reading && size_required.is_some_and(|x| buf.len() < x) {
969                        let n_req_size = size_required.unwrap_or(0).saturating_sub(buf.len());
970                        let num_reserve = (*nreserve).max(n_req_size);
971
972                        let old_cursor = buf.len();
973                        buf.reserve(num_reserve);
974
975                        unsafe {
976                            buf.set_len(old_cursor + num_reserve);
977                            match reader.as_mut().poll_read(cx, buf) {
978                                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
979                                Poll::Ready(Ok(n)) => {
980                                    buf.set_len(old_cursor + n);
981                                }
982                                Poll::Pending => {
983                                    buf.set_len(old_cursor);
984
985                                    // Skip reading until all prepared buffer is consumed.
986                                    *state_skip_reading = true;
987                                    break;
988                                }
989                            }
990                        }
991                    }
992
993                    // Try decode the buffer.
994                    match framing.try_framing(&buf[..]) {
995                        Ok(Some(x)) => {
996                            framing.advance();
997
998                            debug_assert!(x.valid_data_end <= x.next_frame_start);
999                            debug_assert!(x.next_frame_start <= buf.len());
1000
1001                            let mut frame = buf.split_to(x.next_frame_start);
1002                            break Poll::Ready(Ok(frame.split_off(x.valid_data_end).into()));
1003                        }
1004
1005                        // Just poll once more, until it retunrs 'Pending' ...
1006                        Ok(None) => {
1007                            *state_skip_reading = false;
1008                            continue;
1009                        }
1010
1011                        Err(FramingError::Broken(why)) => {
1012                            return Poll::Ready(Err(std::io::Error::new(
1013                                std::io::ErrorKind::InvalidData,
1014                                why,
1015                            )))
1016                        }
1017
1018                        Err(FramingError::Recoverable(num_discard_bytes)) => {
1019                            debug_assert!(num_discard_bytes <= buf.len());
1020
1021                            let _ = buf.split_to(num_discard_bytes);
1022                            continue;
1023                        }
1024                    }
1025                }
1026            }
1027        }
1028
1029        Builder {
1030            codec: self.codec,
1031            write: self.write,
1032            read: FramingReader {
1033                reader: read,
1034                framing,
1035                buf: BytesMut::default(),
1036                nreserve: default_readbuf_reserve,
1037                state_skip_reading: false,
1038            },
1039            cfg: self.cfg,
1040            ev: self.ev,
1041            reqs: self.reqs,
1042            user_data: self.user_data,
1043        }
1044    }
1045
1046    /// Enable request features with default request context.
1047    pub fn with_request(self) -> Builder<Tw, Tr, C, E, RequestContext, U> {
1048        Builder {
1049            codec: self.codec,
1050            write: self.write,
1051            read: self.read,
1052            cfg: self.cfg,
1053            ev: self.ev,
1054            reqs: RequestContext::default(),
1055            user_data: self.user_data,
1056        }
1057    }
1058
1059    /// Enable request features with custom request context. e.g. you can use
1060    /// [`Arc<RequestContext>`] to share the request context between multiple connections.
1061    pub fn with_request_context(
1062        self,
1063        reqs: impl GetRequestContext,
1064    ) -> Builder<Tw, Tr, C, E, impl GetRequestContext, U> {
1065        Builder {
1066            codec: self.codec,
1067            write: self.write,
1068            read: self.read,
1069            cfg: self.cfg,
1070            ev: self.ev,
1071            reqs,
1072            user_data: self.user_data,
1073        }
1074    }
1075
1076    /// Setting zero will create unbounded channel. Default is unbounded.
1077    pub fn with_inbound_channel_capacity(mut self, capacity: usize) -> Self {
1078        self.cfg.inbound_channel_cap = capacity.try_into().ok();
1079        self
1080    }
1081
1082    /// Enable specified feature flags. This applies bit-or-assign operation on feature flags.
1083    pub fn with_feature(mut self, feature: Feature) -> Self {
1084        self.cfg.feature_flag |= feature;
1085        self
1086    }
1087
1088    /// Disable specified feature flags.
1089    pub fn without_feature(mut self, feature: Feature) -> Self {
1090        self.cfg.feature_flag &= !feature;
1091        self
1092    }
1093}
1094
1095impl<Tw, Tr, C, E, R, U> Builder<Tw, Tr, Arc<C>, E, R, U>
1096where
1097    Tw: AsyncFrameWrite,
1098    Tr: AsyncFrameRead,
1099    C: Codec,
1100    E: InboundEventSubscriber,
1101    R: GetRequestContext,
1102    U: UserData,
1103{
1104    /// Build the connection from provided parameters.
1105    ///
1106    /// To start the connection, you need to spawn the returned future to the executor.
1107    #[must_use = "The connection will be closed immediately if you don't spawn the future!"]
1108    pub fn build(self) -> (Transceiver, impl std::future::Future<Output = ()> + Send) {
1109        let (tx_inb_drv, rx_inb_drv) = flume::unbounded();
1110        let (tx_in_msg, rx_in_msg) =
1111            if let Some(chan_cap) = self.cfg.inbound_channel_cap.map(|x| x.get()) {
1112                flume::bounded(chan_cap)
1113            } else {
1114                flume::unbounded()
1115            };
1116
1117        let conn: Arc<dyn Connection>;
1118        let this = ConnectionImpl {
1119            codec: self.codec,
1120            write: AsyncMutex::new(self.write),
1121            reqs: self.reqs,
1122            tx_drive: tx_inb_drv,
1123            features: self.cfg.feature_flag,
1124            user_data: self.user_data,
1125            _unpin: std::marker::PhantomPinned,
1126        };
1127
1128        let this = Arc::new(this);
1129        let fut_driver = ConnectionImpl::inbound_event_handler(DriverBody {
1130            w_this: Arc::downgrade(&this),
1131            read: self.read,
1132            ev_subs: self.ev,
1133            rx_drive: rx_inb_drv,
1134            tx_msg: tx_in_msg,
1135        });
1136
1137        conn = this;
1138        (Transceiver(Sender(conn), rx_in_msg), fut_driver)
1139    }
1140}
1141
1142/* ---------------------------------------------------------------------------------------------- */
1143/*                                      EVENT LISTENER TRAIT                                      */
1144/* ---------------------------------------------------------------------------------------------- */
1145
1146#[derive(Debug, thiserror::Error)]
1147pub enum InboundError {
1148    #[error("Response packet is received, but we haven't enabled request feature!")]
1149    RedundantResponse(msg::Response),
1150
1151    #[error("Response packet is not routed.")]
1152    ExpiredResponse(msg::Response),
1153
1154    #[error("Response hash was restored as 0, which is invalid.")]
1155    ResponseHashZero(msg::Response),
1156
1157    #[error("Failed to decode inbound type: {} bytes, error was = {0} ", .1.len())]
1158    InboundDecodeError(DecodeError, bytes::Bytes),
1159
1160    #[error("Disabled notification/request is received")]
1161    DisabledInbound(msg::RecvMsg),
1162}
1163
1164/// This trait is to notify the event from the connection.
1165pub trait InboundEventSubscriber: Send + Sync + 'static {
1166    /// Called when the inbound stream is closed. If channel is closed
1167    ///
1168    /// # Arguments
1169    ///
1170    /// - `closed_by_us`: Are we closing? Otherwise, the read stream is closed by the remote.
1171    /// - `result`:
1172    ///     - Closing result of the read stream. Usually, if `closed_by_us` is true, this is
1173    ///       `Ok(())`, otherwise, may contain error related to network stream disconnection.
1174    fn on_close(&self, closed_by_us: bool, result: std::io::Result<()>) {
1175        let _ = (closed_by_us, result);
1176    }
1177
1178    /// When an errnous response is received.
1179    ///
1180    /// -
1181    fn on_inbound_error(&self, error: InboundError) {
1182        let _ = (error,);
1183    }
1184}
1185
1186/// Placeholder implementation of event listener.
1187pub struct EmptyEventListener;
1188
1189impl InboundEventSubscriber for EmptyEventListener {}
1190
1191/* ---------------------------------------------------------------------------------------------- */
1192/*                                            MESSAGES                                            */
1193/* ---------------------------------------------------------------------------------------------- */
1194pub trait Message {
1195    fn payload(&self) -> &[u8];
1196    fn raw(&self) -> &[u8];
1197    fn raw_bytes(&self) -> Bytes;
1198    fn codec(&self) -> &dyn Codec;
1199
1200    /// Parses payload as speicfied type.
1201    fn parse<'a, T: serde::de::Deserialize<'a>>(&'a self) -> Result<T, DecodeError> {
1202        let mut dst = None;
1203        self.parse_in_place(&mut dst)?;
1204        Ok(dst.unwrap())
1205    }
1206
1207    /// Parses payload as speicfied type, in-place. It is marked as hidden to follow origianl method
1208    /// [`serde::Deserialize::deserialize_in_place`]'s visibility convention, which is '*almost
1209    /// never what newbies are looking for*'.
1210    #[doc(hidden)]
1211    fn parse_in_place<'a, T: serde::de::Deserialize<'a>>(
1212        &'a self,
1213        dst: &mut Option<T>,
1214    ) -> Result<(), DecodeError> {
1215        self.codec().decode_payload(self.payload(), &mut |de| {
1216            if let Some(dst) = dst.as_mut() {
1217                T::deserialize_in_place(de, dst)
1218            } else {
1219                *dst = Some(T::deserialize(de)?);
1220                Ok(())
1221            }
1222        })
1223    }
1224}
1225
1226pub trait MessageReqId: Message {
1227    fn req_id(&self) -> codec::ReqIdRef;
1228}
1229
1230pub trait MessageMethodName: Message {
1231    fn method_raw(&self) -> &[u8];
1232    fn method(&self) -> Option<&str> {
1233        std::str::from_utf8(self.method_raw()).ok()
1234    }
1235}
1236
1237pub trait ExtractUserData {
1238    fn extract_sender(&self) -> Sender;
1239    fn user_data_raw(&self) -> &dyn UserData;
1240    fn user_data_owned(&self) -> OwnedUserData;
1241    fn user_data<T: UserData>(&self) -> Option<&T> {
1242        self.user_data_raw().as_any().downcast_ref()
1243    }
1244}
1245
1246impl ExtractUserData for Sender {
1247    fn extract_sender(&self) -> Sender {
1248        self.clone()
1249    }
1250
1251    fn user_data_raw(&self) -> &dyn UserData {
1252        self.0.user_data()
1253    }
1254
1255    fn user_data_owned(&self) -> OwnedUserData {
1256        OwnedUserData(self.0.clone())
1257    }
1258}
1259
1260/// Common content of inbound message
1261#[derive(Debug)]
1262struct InboundBody {
1263    buffer: Bytes,
1264    payload: Range<usize>,
1265    codec: Arc<dyn Codec>,
1266}
1267
1268pub mod msg {
1269    use std::{ops::Range, sync::Arc};
1270
1271    use enum_as_inner::EnumAsInner;
1272
1273    use crate::{
1274        codec::{DecodeError, EncodeError, PredefinedResponseError, ReqId, ReqIdRef},
1275        rpc::MessageReqId,
1276    };
1277
1278    use super::{Connection, DeferredWrite, ExtractUserData, Message, UserData, WriteBuffer};
1279
1280    macro_rules! impl_message {
1281        ($t:ty) => {
1282            impl super::Message for $t {
1283                fn payload(&self) -> &[u8] {
1284                    &self.h.buffer[self.h.payload.clone()]
1285                }
1286
1287                fn raw(&self) -> &[u8] {
1288                    &self.h.buffer
1289                }
1290
1291                fn raw_bytes(&self) -> bytes::Bytes {
1292                    self.h.buffer.clone()
1293                }
1294
1295                fn codec(&self) -> &dyn super::Codec {
1296                    &*self.h.codec
1297                }
1298            }
1299        };
1300    }
1301
1302    macro_rules! impl_method_name {
1303        ($t:ty) => {
1304            impl super::MessageMethodName for $t {
1305                fn method_raw(&self) -> &[u8] {
1306                    &self.h.buffer[self.method.clone()]
1307                }
1308            }
1309        };
1310    }
1311
1312    macro_rules! impl_req_id {
1313        ($t:ty) => {
1314            impl super::MessageReqId for $t {
1315                fn req_id(&self) -> ReqIdRef {
1316                    self.req_id.make_ref(&self.h.buffer)
1317                }
1318            }
1319        };
1320    }
1321
1322    /* ------------------------------------- Request Logics ------------------------------------- */
1323    #[derive(Debug)]
1324    pub struct RequestInner {
1325        pub(super) h: super::InboundBody,
1326        pub(super) method: Range<usize>,
1327        pub(super) req_id: ReqId,
1328    }
1329
1330    impl_message!(RequestInner);
1331    impl_method_name!(RequestInner);
1332    impl_req_id!(RequestInner);
1333
1334    #[derive(Debug)]
1335    pub struct Request {
1336        pub(super) body: Option<(RequestInner, Arc<dyn Connection>)>,
1337    }
1338
1339    impl std::ops::Deref for Request {
1340        type Target = RequestInner;
1341
1342        fn deref(&self) -> &Self::Target {
1343            &self.body.as_ref().unwrap().0
1344        }
1345    }
1346
1347    impl ExtractUserData for Request {
1348        fn user_data_raw(&self) -> &dyn UserData {
1349            self.body.as_ref().unwrap().1.user_data()
1350        }
1351        fn user_data_owned(&self) -> super::OwnedUserData {
1352            super::OwnedUserData(self.body.as_ref().unwrap().1.clone())
1353        }
1354        fn extract_sender(&self) -> crate::Sender {
1355            super::Sender(self.body.as_ref().unwrap().1.clone())
1356        }
1357    }
1358
1359    impl Request {
1360        pub async fn response<T: serde::Serialize>(
1361            self,
1362            value: Result<&T, &T>,
1363        ) -> Result<(), super::SendError> {
1364            let mut buf = Default::default();
1365            self.response_with_reuse(&mut buf, value).await
1366        }
1367
1368        /// Response with given value. If the value is `Err`, the response will be sent as error
1369        #[doc(hidden)]
1370        pub async fn response_with_reuse<T: serde::Serialize>(
1371            self,
1372            buf: &mut super::WriteBuffer,
1373            value: Result<&T, &T>,
1374        ) -> Result<(), super::SendError> {
1375            let conn = self.prepare_response(buf, value)?;
1376            conn.__write_raw(&mut buf.value).await?;
1377            Ok(())
1378        }
1379
1380        /// Explicitly abort the request. This is useful when you want to cancel the request
1381        pub async fn abort(self) -> Result<(), super::SendError> {
1382            self.error_predefined(PredefinedResponseError::Aborted).await
1383        }
1384
1385        /// Notify handler received notify handler ...
1386        pub async fn error_notify_handler(self) -> Result<(), super::SendError> {
1387            self.error_predefined(PredefinedResponseError::NotifyHandler).await
1388        }
1389
1390        /// Response with 'parse failed' predefined error type.
1391        pub async fn error_parse_failed<T>(self) -> Result<(), super::SendError> {
1392            let err = PredefinedResponseError::ParseFailed(std::any::type_name::<T>().into());
1393            self.error_predefined(err).await
1394        }
1395
1396        /// Response 'internal' error with given error code.
1397        pub async fn error_internal(
1398            self,
1399            errc: i32,
1400            detail: impl Into<Option<String>>,
1401        ) -> Result<(), super::SendError> {
1402            let err = if let Some(detail) = detail.into() {
1403                PredefinedResponseError::InternalDetailed(errc, detail)
1404            } else {
1405                PredefinedResponseError::Internal(errc)
1406            };
1407
1408            self.error_predefined(err).await
1409        }
1410
1411        async fn error_predefined(
1412            mut self,
1413            err: super::codec::PredefinedResponseError,
1414        ) -> Result<(), super::SendError> {
1415            let (inner, conn) = self.body.take().unwrap();
1416            conn.__send_err_predef(&mut Default::default(), &inner, &err).await
1417        }
1418
1419        /* ---------------------------------- Deferred Version ---------------------------------- */
1420        #[doc(hidden)]
1421        pub fn response_deferred_with_reuse<T: serde::Serialize>(
1422            self,
1423            buffer: &mut WriteBuffer,
1424            value: Result<&T, &T>,
1425        ) -> Result<(), super::SendError> {
1426            buffer.prepare();
1427            let conn = self.prepare_response(buffer, value).unwrap();
1428
1429            conn.tx_drive()
1430                .send(super::InboundDriverDirective::DeferredWrite(
1431                    DeferredWrite::Raw(buffer.value.split()).into(),
1432                ))
1433                .map_err(|_| super::SendError::Disconnected)
1434        }
1435
1436        pub fn response_deferred<T: serde::Serialize>(
1437            self,
1438            value: Result<&T, &T>,
1439        ) -> Result<(), super::SendError> {
1440            self.response_deferred_with_reuse(&mut Default::default(), value)
1441        }
1442
1443        pub fn abort_deferred(self) -> Result<(), super::SendError> {
1444            self.error_predefined_deferred(PredefinedResponseError::Aborted)
1445        }
1446
1447        pub fn error_notify_handler_deferred(self) -> Result<(), super::SendError> {
1448            self.error_predefined_deferred(PredefinedResponseError::NotifyHandler)
1449        }
1450
1451        pub fn error_parse_failed_deferred<T>(self) -> Result<(), super::SendError> {
1452            let err = PredefinedResponseError::ParseFailed(std::any::type_name::<T>().into());
1453            self.error_predefined_deferred(err)
1454        }
1455
1456        pub fn error_internal_deferred(
1457            self,
1458            errc: i32,
1459            detail: impl Into<Option<String>>,
1460        ) -> Result<(), super::SendError> {
1461            let err = if let Some(detail) = detail.into() {
1462                PredefinedResponseError::InternalDetailed(errc, detail)
1463            } else {
1464                PredefinedResponseError::Internal(errc)
1465            };
1466
1467            self.error_predefined_deferred(err)
1468        }
1469
1470        fn error_predefined_deferred(
1471            mut self,
1472            err: super::codec::PredefinedResponseError,
1473        ) -> Result<(), super::SendError> {
1474            let (inner, conn) = self.body.take().unwrap();
1475            conn.tx_drive()
1476                .send(super::InboundDriverDirective::DeferredWrite(
1477                    DeferredWrite::ErrorResponse(inner, err).into(),
1478                ))
1479                .map_err(|_| super::SendError::Disconnected)
1480        }
1481
1482        /* ------------------------------------ Inner Methods ----------------------------------- */
1483        fn prepare_response<T: serde::Serialize>(
1484            mut self,
1485            buf: &mut super::WriteBuffer,
1486            value: Result<&T, &T>,
1487        ) -> Result<Arc<dyn Connection>, EncodeError> {
1488            let (inner, conn) = self.body.take().unwrap();
1489            buf.prepare();
1490
1491            let encode_as_error = value.is_err();
1492            let value = value.unwrap_or_else(|x| x);
1493            inner.codec().encode_response(
1494                inner.req_id(),
1495                encode_as_error,
1496                value,
1497                &mut buf.value,
1498            )?;
1499
1500            Ok(conn.clone())
1501        }
1502    }
1503
1504    impl Drop for Request {
1505        fn drop(&mut self) {
1506            if let Some((inner, conn)) = self.body.take() {
1507                conn.tx_drive()
1508                    .send(super::InboundDriverDirective::DeferredWrite(
1509                        super::DeferredWrite::ErrorResponse(
1510                            inner,
1511                            PredefinedResponseError::Unhandled,
1512                        ),
1513                    ))
1514                    .ok();
1515            }
1516        }
1517    }
1518
1519    /* -------------------------------------- Notify Logics ------------------------------------- */
1520    #[derive(Debug)]
1521    pub struct Notify {
1522        pub(super) h: super::InboundBody,
1523        pub(super) method: Range<usize>,
1524        pub(super) sender: Arc<dyn Connection>,
1525    }
1526
1527    impl_message!(Notify);
1528    impl_method_name!(Notify);
1529
1530    impl ExtractUserData for Notify {
1531        fn user_data_raw(&self) -> &dyn UserData {
1532            self.sender.user_data()
1533        }
1534
1535        fn user_data_owned(&self) -> super::OwnedUserData {
1536            super::OwnedUserData(self.sender.clone())
1537        }
1538
1539        fn extract_sender(&self) -> crate::Sender {
1540            super::Sender(self.sender.clone())
1541        }
1542    }
1543
1544    /* ---------------------------------- Received Message Type --------------------------------- */
1545
1546    #[derive(Debug, EnumAsInner)]
1547    pub enum RecvMsg {
1548        Request(Request),
1549        Notify(Notify),
1550    }
1551
1552    impl ExtractUserData for RecvMsg {
1553        fn user_data_raw(&self) -> &dyn UserData {
1554            match self {
1555                Self::Request(x) => x.user_data_raw(),
1556                Self::Notify(x) => x.user_data_raw(),
1557            }
1558        }
1559
1560        fn user_data_owned(&self) -> super::OwnedUserData {
1561            match self {
1562                Self::Request(x) => x.user_data_owned(),
1563                Self::Notify(x) => x.user_data_owned(),
1564            }
1565        }
1566
1567        fn extract_sender(&self) -> crate::Sender {
1568            match self {
1569                Self::Request(x) => x.extract_sender(),
1570                Self::Notify(x) => x.extract_sender(),
1571            }
1572        }
1573    }
1574
1575    impl super::Message for RecvMsg {
1576        fn payload(&self) -> &[u8] {
1577            match self {
1578                Self::Request(x) => x.payload(),
1579                Self::Notify(x) => x.payload(),
1580            }
1581        }
1582
1583        fn raw(&self) -> &[u8] {
1584            match self {
1585                Self::Request(x) => x.raw(),
1586                Self::Notify(x) => x.raw(),
1587            }
1588        }
1589
1590        fn raw_bytes(&self) -> bytes::Bytes {
1591            match self {
1592                Self::Request(x) => x.raw_bytes(),
1593                Self::Notify(x) => x.raw_bytes(),
1594            }
1595        }
1596
1597        fn codec(&self) -> &dyn crate::codec::Codec {
1598            match self {
1599                Self::Request(x) => x.codec(),
1600                Self::Notify(x) => x.codec(),
1601            }
1602        }
1603    }
1604
1605    impl super::MessageMethodName for RecvMsg {
1606        fn method_raw(&self) -> &[u8] {
1607            match self {
1608                Self::Request(x) => x.method_raw(),
1609                Self::Notify(x) => x.method_raw(),
1610            }
1611        }
1612    }
1613
1614    /* ------------------------------------- Response Logics ------------------------------------ */
1615    #[derive(Debug)]
1616    pub struct Response {
1617        pub(super) h: super::InboundBody,
1618        pub(super) req_id: ReqId,
1619
1620        /// Should we interpret the payload as error object?
1621        pub(super) is_error: bool,
1622    }
1623
1624    impl_message!(Response);
1625    impl_req_id!(Response);
1626
1627    impl Response {
1628        pub fn is_error(&self) -> bool {
1629            self.is_error
1630        }
1631
1632        pub fn result<'a, T: serde::Deserialize<'a>, E: serde::Deserialize<'a>>(
1633            &'a self,
1634        ) -> Result<T, ResponseError<E>> {
1635            if !self.is_error {
1636                Ok(self.parse()?)
1637            } else {
1638                let codec = self.codec();
1639                if let Some(predef) = codec.try_decode_predef_error(self.payload()) {
1640                    Err(ResponseError::Predefined(predef))
1641                } else {
1642                    Err(ResponseError::Typed(self.parse()?))
1643                }
1644            }
1645        }
1646    }
1647
1648    #[derive(EnumAsInner, Debug, thiserror::Error)]
1649    pub enum ResponseError<T> {
1650        #[error("Requested type returned")]
1651        Typed(T),
1652
1653        #[error("Predefined error returned: {0}")]
1654        Predefined(PredefinedResponseError),
1655
1656        #[error("Decode error: {0}")]
1657        DecodeError(#[from] DecodeError),
1658    }
1659}
1660
1661/* ---------------------------------------------------------------------------------------------- */
1662/*                                      DRIVER IMPLEMENTATION                                     */
1663/* ---------------------------------------------------------------------------------------------- */
1664
1665struct DriverBody<C, T, E, R, U, Tr> {
1666    w_this: Weak<ConnectionImpl<C, T, R, U>>,
1667    read: Tr,
1668    ev_subs: E,
1669    rx_drive: flume::Receiver<InboundDriverDirective>,
1670    tx_msg: flume::Sender<msg::RecvMsg>,
1671}
1672
1673mod inner {
1674    //! Internal driver context. It receives the request from the connection, and dispatches it to
1675    //! the handler.
1676
1677    use bytes::BytesMut;
1678    use capture_it::capture;
1679    use futures_util::{future::FusedFuture, FutureExt};
1680    use std::{future::poll_fn, num::NonZeroU64, sync::Arc};
1681
1682    use crate::{
1683        codec::{self, Codec, InboundFrameType},
1684        rpc::{DeferredWrite, SendError},
1685        transport::{AsyncFrameRead, AsyncFrameWrite, FrameReader},
1686    };
1687
1688    use super::{
1689        msg, ConnectionImpl, DriverBody, Feature, GetRequestContext, InboundBody,
1690        InboundDriverDirective, InboundError, InboundEventSubscriber, MessageReqId, UserData,
1691    };
1692
1693    /// Request-acceting version of connection driver
1694    impl<C, T, R, U> ConnectionImpl<C, T, R, U>
1695    where
1696        C: Codec,
1697        T: AsyncFrameWrite,
1698        R: GetRequestContext,
1699        U: UserData,
1700    {
1701        pub(crate) async fn inbound_event_handler<Tr, E>(body: DriverBody<C, T, E, R, U, Tr>)
1702        where
1703            Tr: AsyncFrameRead,
1704            E: InboundEventSubscriber,
1705        {
1706            body.execute().await;
1707        }
1708    }
1709
1710    impl<C, T, E, R, U, Tr> DriverBody<C, T, E, R, U, Tr>
1711    where
1712        C: Codec,
1713        T: AsyncFrameWrite,
1714        E: InboundEventSubscriber,
1715        R: GetRequestContext,
1716        U: UserData,
1717        Tr: AsyncFrameRead,
1718    {
1719        async fn execute(self) {
1720            let DriverBody { w_this, mut read, mut ev_subs, rx_drive: rx_msg, tx_msg } = self;
1721
1722            use futures_util::future::Fuse;
1723            let mut fut_drive_msg = Fuse::terminated();
1724            let mut fut_read = Fuse::terminated();
1725            let mut close_from_remote = false;
1726
1727            /* ----------------------------- Background Sender Task ----------------------------- */
1728            // Tasks that perform deferred write operations in the background. It handles messages
1729            // sent from non-async non-blocking context, such as 'unhandled' response pushed inside
1730            // `Drop` handler.
1731            let (tx_bg_sender, rx_bg_sender) = flume::unbounded();
1732            let fut_bg_sender = capture!([w_this], async move {
1733                let mut pool = super::WriteBuffer::default();
1734                while let Ok(msg) = rx_bg_sender.recv_async().await {
1735                    let Some(this) = w_this.upgrade() else { break };
1736                    match msg {
1737                        DeferredWrite::ErrorResponse(req, err) => {
1738                            let err = this.dyn_ref().__send_err_predef(&mut pool, &req, &err).await;
1739
1740                            // Ignore all other error types, as this message is triggered
1741                            // crate-internally on very limited situations
1742                            if let Err(SendError::IoError(e)) = err {
1743                                return Err(e);
1744                            }
1745                        }
1746                        DeferredWrite::Raw(mut msg) => {
1747                            this.dyn_ref().__write_raw(&mut msg).await?;
1748                        }
1749                        DeferredWrite::Flush => {
1750                            this.dyn_ref().__flush().await?;
1751                        }
1752                    }
1753                }
1754
1755                Ok::<(), std::io::Error>(())
1756            });
1757            let fut_bg_sender = fut_bg_sender.fuse();
1758            let mut fut_bg_sender = std::pin::pin!(fut_bg_sender);
1759            let mut read = std::pin::pin!(read);
1760
1761            /* ------------------------------------ App Loop ------------------------------------ */
1762            loop {
1763                if fut_drive_msg.is_terminated() {
1764                    fut_drive_msg = rx_msg.recv_async().fuse();
1765                }
1766
1767                if fut_read.is_terminated() {
1768                    fut_read = poll_fn(|cx| read.as_mut().poll_next(cx)).fuse();
1769                }
1770
1771                futures_util::select! {
1772                    msg = fut_drive_msg => {
1773                        match msg {
1774                            Ok(InboundDriverDirective::DeferredWrite(msg)) => {
1775                                // This may not fail.
1776                                tx_bg_sender.send_async(msg).await.ok();
1777                            }
1778
1779                            Err(_) | Ok(InboundDriverDirective::Close) => {
1780                                // Connection disposed by 'us' (by dropping RPC handle). i.e.
1781                                // `close_from_remote = false`
1782                                break;
1783                            }
1784                        }
1785                    }
1786
1787                    inbound = fut_read => {
1788                        let Some(this) = w_this.upgrade() else { break };
1789
1790                        match inbound {
1791                            Ok(bytes) => {
1792                                Self::on_read(
1793                                    &this,
1794                                    &mut ev_subs,
1795                                    bytes,
1796                                    &tx_msg
1797                                )
1798                                .await;
1799                            }
1800                            Err(e) => {
1801                                close_from_remote = true;
1802                                ev_subs.on_close(false, Err(e));
1803
1804                                break;
1805                            }
1806                        }
1807                    }
1808
1809                    result = fut_bg_sender => {
1810                        if let Err(err) = result {
1811                            close_from_remote = true;
1812                            ev_subs.on_close(true, Err(err));
1813                        }
1814
1815                        break;
1816                    }
1817                }
1818            }
1819
            // If we're exiting while the handle is still alive, manually close the write stream.
1821            if let Some(x) = w_this.upgrade() {
1822                x.dyn_ref().__close().await.ok();
1823            }
1824
            // Report a clean, locally-initiated close unless a close was already reported above.
1826            if !close_from_remote {
1827                ev_subs.on_close(true, Ok(())); // We're closing this
1828            }
1829
            // Cancel all pending requests.
1831            'cancel: {
1832                let Some(this) = w_this.upgrade() else { break 'cancel };
1833                let Some(reqs) = this.reqs.get_req_con() else { break 'cancel };
1834
                // Drop the driver-side receiver so that handles recognize the connection as 'disconnected'.
1836                drop(fut_drive_msg);
1837                drop(rx_msg);
1838
1839                // Wake up all pending responses
1840                reqs.wake_up_all();
1841            }
1842        }
1843
1844        async fn on_read(
1845            this: &Arc<ConnectionImpl<C, T, R, U>>,
1846            ev_subs: &mut E,
1847            frame: bytes::Bytes,
1848            tx_msg: &flume::Sender<msg::RecvMsg>,
1849        ) {
1850            let parsed = this.codec.decode_inbound(&frame);
1851            let (header, payload_span) = match parsed {
1852                Ok(x) => x,
1853                Err(e) => {
1854                    ev_subs.on_inbound_error(InboundError::InboundDecodeError(e, frame));
1855                    return;
1856                }
1857            };
1858
1859            let h = InboundBody { buffer: frame, payload: payload_span, codec: this.codec.clone() };
1860            match header {
1861                InboundFrameType::Notify { .. } | InboundFrameType::Request { .. } => {
1862                    let (msg, disabled) = match header {
1863                        InboundFrameType::Notify { method } => (
1864                            msg::RecvMsg::Notify(msg::Notify { h, method, sender: this.clone() }),
1865                            this.features.contains(Feature::NO_RECEIVE_NOTIFY),
1866                        ),
1867                        InboundFrameType::Request { method, req_id } => (
1868                            msg::RecvMsg::Request(msg::Request {
1869                                body: Some((msg::RequestInner { h, method, req_id }, this.clone())),
1870                            }),
1871                            this.features.contains(Feature::NO_RECEIVE_REQUEST),
1872                        ),
1873                        _ => unreachable!(),
1874                    };
1875
1876                    if disabled {
1877                        ev_subs.on_inbound_error(InboundError::DisabledInbound(msg));
1878                    } else {
1879                        tx_msg.send_async(msg).await.ok();
1880                    }
1881                }
1882
1883                InboundFrameType::Response { req_id, req_id_hash, is_error } => {
1884                    let response = msg::Response { h, is_error, req_id };
1885
1886                    let Some(reqs) = this.reqs.get_req_con() else {
1887                        ev_subs.on_inbound_error(InboundError::RedundantResponse(response));
1888                        return;
1889                    };
1890
1891                    let Some(req_id_hash) = NonZeroU64::new(req_id_hash) else {
1892                        ev_subs.on_inbound_error(InboundError::ResponseHashZero(response));
1893                        return;
1894                    };
1895
1896                    if let Err(msg) = reqs.route_response(req_id_hash, response) {
1897                        ev_subs.on_inbound_error(InboundError::ExpiredResponse(msg));
1898                    }
1899                }
1900            }
1901        }
1902    }
1903
1904    /// SAFETY: `ConnectionImpl` is `!Unpin`, thus it is safe to pin the reference.
1905    macro_rules! pin {
1906        ($this:ident, $ident:ident) => {
1907            let mut $ident = $this.write().lock().await;
1908            let mut $ident = unsafe { std::pin::Pin::new_unchecked(&mut *$ident) };
1909        };
1910    }
1911
1912    impl dyn super::Connection {
1913        pub(crate) async fn __send_err_predef(
1914            &self,
1915            buf: &mut super::WriteBuffer,
1916            recv: &msg::RequestInner,
1917            error: &codec::PredefinedResponseError,
1918        ) -> Result<(), super::SendError> {
1919            buf.prepare();
1920
1921            self.codec().encode_response_predefined(recv.req_id(), error, &mut buf.value)?;
1922            self.__write_raw(&mut buf.value).await?;
1923            Ok(())
1924        }
1925
1926        pub(crate) async fn __write_raw(&self, buf: &mut BytesMut) -> std::io::Result<()> {
1927            pin!(self, write);
1928
1929            write.as_mut().begin_write_frame(buf.len())?;
1930            let mut reader = FrameReader::new(buf);
1931
1932            while !reader.is_empty() {
1933                poll_fn(|cx| write.as_mut().poll_write(cx, &mut reader)).await?;
1934            }
1935
1936            Ok(())
1937        }
1938
1939        pub(crate) async fn __flush(&self) -> std::io::Result<()> {
1940            pin!(self, write);
1941            poll_fn(|cx| write.as_mut().poll_flush(cx)).await
1942        }
1943
1944        pub(crate) async fn __close(&self) -> std::io::Result<()> {
1945            pin!(self, write);
1946            poll_fn(|cx| write.as_mut().poll_close(cx)).await
1947        }
1948
1949        pub(crate) fn __is_disconnected(&self) -> bool {
1950            self.tx_drive().is_disconnected()
1951        }
1952    }
1953}