1use std::{
11 any::Any,
12 fmt::Debug,
13 num::NonZeroUsize,
14 ops::Range,
15 sync::{Arc, Weak},
16};
17
18use crate::{
19 codec::{self, Codec, DecodeError, Framing, FramingError},
20 transport::{AsyncFrameRead, AsyncFrameWrite},
21};
22
23use async_lock::Mutex as AsyncMutex;
24use bitflags::bitflags;
25use bytes::{Bytes, BytesMut};
26use enum_as_inner::EnumAsInner;
27use futures_util::AsyncRead;
28pub use req::RequestContext;
29pub use req::{OwnedResponseFuture, ResponseFuture};
30
bitflags! {
    /// Per-connection feature switches, configured through [`Builder::with_feature`] and
    /// [`Builder::without_feature`] and readable via [`Sender::get_feature_flags`].
    ///
    /// The default value has no flags set. Bit 0 is currently unused.
    #[derive(Default, Debug, Clone, Copy)]
    pub struct Feature : u32 {
        /// Enables automatic responses.
        /// NOTE(review): the behavior is implemented outside this view — confirm what is
        /// answered automatically.
        const ENABLE_AUTO_RESPONSE = 1 << 1;

        /// Opts out of receiving inbound requests.
        /// NOTE(review): presumably such requests are reported as
        /// `InboundError::DisabledInbound` — confirm in the driver.
        const NO_RECEIVE_REQUEST = 1 << 2;

        /// Opts out of receiving inbound notifications.
        /// NOTE(review): presumably reported as `InboundError::DisabledInbound` — confirm.
        const NO_RECEIVE_NOTIFY = 1 << 3;
    }
}
52
/// Concrete connection state shared (behind an `Arc`) by every [`Sender`] handle and
/// referenced weakly by the inbound driver created in [`Builder::build`].
struct ConnectionImpl<C, T, R, U> {
    /// Codec used to encode outbound and decode inbound messages.
    codec: Arc<C>,
    /// Write half of the transport; the async mutex serializes concurrent writers.
    write: AsyncMutex<T>,
    /// Request bookkeeping; `()` when requests are disabled (see [`GetRequestContext`]).
    reqs: R,
    /// Channel to the inbound driver, used for deferred writes and close requests.
    tx_drive: flume::Sender<InboundDriverDirective>,
    /// Feature flags the connection was built with.
    features: Feature,
    /// Arbitrary user data attached at build time.
    user_data: U,
    /// Makes the type `!Unpin`.
    /// NOTE(review): the reason is not visible here — presumably something relies on a
    /// stable address; confirm before removing.
    _unpin: std::marker::PhantomPinned,
}
70
/// Object-safe view of [`ConnectionImpl`], letting handles store `Arc<dyn Connection>`
/// without carrying the codec/transport/request/user-data type parameters.
trait Connection: Send + Sync + 'static + Debug {
    /// Codec used by this connection.
    fn codec(&self) -> &dyn Codec;
    /// Write half of the transport, guarded by an async mutex.
    fn write(&self) -> &AsyncMutex<dyn AsyncFrameWrite>;
    /// Request bookkeeping, or `None` when the request feature is disabled.
    fn reqs(&self) -> Option<&RequestContext>;
    /// Sender side of the channel to the inbound driver.
    fn tx_drive(&self) -> &flume::Sender<InboundDriverDirective>;
    /// Feature flags the connection was built with.
    fn feature_flag(&self) -> Feature;
    /// User data attached at build time.
    fn user_data(&self) -> &dyn UserData;
}
80
81impl<C, T, R, U> Connection for ConnectionImpl<C, T, R, U>
82where
83 C: Codec,
84 T: AsyncFrameWrite,
85 R: GetRequestContext,
86 U: UserData,
87{
88 fn codec(&self) -> &dyn Codec {
89 &*self.codec
90 }
91
92 fn write(&self) -> &AsyncMutex<dyn AsyncFrameWrite> {
93 &self.write
94 }
95
96 fn reqs(&self) -> Option<&RequestContext> {
97 self.reqs.get_req_con()
98 }
99
100 fn tx_drive(&self) -> &flume::Sender<InboundDriverDirective> {
101 &self.tx_drive
102 }
103
104 fn feature_flag(&self) -> Feature {
105 self.features
106 }
107
108 fn user_data(&self) -> &dyn UserData {
109 &self.user_data
110 }
111}
112
113impl<C, T, R, U> ConnectionImpl<C, T, R, U>
114where
115 C: Codec,
116 T: AsyncFrameWrite,
117 R: GetRequestContext,
118 U: UserData,
119{
120 fn dyn_ref(&self) -> &dyn Connection {
121 self
122 }
123}
124
/// Source of the [`RequestContext`] used by a connection.
///
/// Returning `None` disables requests for the connection: `Sender::request*` then fails with
/// `SendError::RequestDisabled`.
pub trait GetRequestContext: std::fmt::Debug + Send + Sync + 'static {
    /// Returns the request context, or `None` if requests are disabled.
    fn get_req_con(&self) -> Option<&RequestContext>;
}
130
131impl GetRequestContext for Arc<RequestContext> {
132 fn get_req_con(&self) -> Option<&RequestContext> {
133 Some(self)
134 }
135}
136
137impl GetRequestContext for RequestContext {
138 fn get_req_con(&self) -> Option<&RequestContext> {
139 Some(self)
140 }
141}
142
143impl GetRequestContext for () {
144 fn get_req_con(&self) -> Option<&RequestContext> {
145 None
146 }
147}
148
149impl<C, T, R: Debug, U> Debug for ConnectionImpl<C, T, R, U> {
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151 f.debug_struct("ConnectionBody").field("reqs", &self.reqs).finish()
152 }
153}
154
/// Arbitrary thread-safe data that can be attached to a connection and recovered later by
/// downcasting.
pub trait UserData: Any + Send + Sync + 'static {
    /// Returns `self` as a type-erased `Any`, so the concrete type can be recovered with
    /// `downcast_ref`.
    fn as_any(&self) -> &(dyn Any + Send + Sync + 'static);
}

/// Every sized, thread-safe `'static` type is usable as user data.
impl<T> UserData for T
where
    T: Any + Send + Sync + 'static,
{
    fn as_any(&self) -> &(dyn Any + Send + Sync + 'static) {
        let erased: &(dyn Any + Send + Sync + 'static) = self;
        erased
    }
}

/// Allows `&dyn UserData` to be used wherever an `AsRef<dyn Any + Send + Sync>` is expected.
impl AsRef<dyn Any + Send + Sync + 'static> for dyn UserData {
    fn as_ref(&self) -> &(dyn Any + Send + Sync + 'static) {
        UserData::as_any(self)
    }
}
174
175#[derive(Clone)]
176pub struct OwnedUserData(Arc<dyn Connection>);
177
178impl OwnedUserData {
179 pub fn raw(&self) -> &dyn UserData {
180 self.0.user_data()
181 }
182}
183
184impl std::ops::Deref for OwnedUserData {
185 type Target = dyn Any;
186
187 fn deref(&self) -> &Self::Target {
188 self.0.user_data().as_any()
189 }
190}
191
/// Commands sent from connection handles to the inbound driver future through
/// `ConnectionImpl::tx_drive`.
#[derive(Debug)]
enum InboundDriverDirective {
    /// Perform a queued write on the caller's behalf.
    DeferredWrite(DeferredWrite),

    /// Close the connection (sent by [`Sender::close`]).
    Close,
}
206
/// Write operations handed to the inbound driver instead of being performed inline by the
/// caller (used by the `*_deferred` APIs and by dropping an unanswered request).
#[derive(Debug)]
enum DeferredWrite {
    /// Send the given predefined error as the response to the given request.
    ErrorResponse(msg::RequestInner, codec::PredefinedResponseError),

    /// Write an already-encoded message as-is.
    Raw(bytes::BytesMut),

    /// Flush the write stream.
    Flush,
}
218
219#[derive(Clone, Debug)]
224pub struct Transceiver(Sender, flume::Receiver<msg::RecvMsg>);
225
226impl std::ops::Deref for Transceiver {
227 type Target = Sender;
228
229 fn deref(&self) -> &Self::Target {
230 &self.0
231 }
232}
233
234impl Transceiver {
235 pub fn into_sender(self) -> Sender {
236 self.0
237 }
238
239 pub fn clone_sender(&self) -> Sender {
240 self.0.clone()
241 }
242
243 pub fn blocking_recv(&self) -> Result<msg::RecvMsg, RecvError> {
244 Ok(self.1.recv()?)
245 }
246
247 pub async fn recv(&self) -> Result<msg::RecvMsg, RecvError> {
248 Ok(self.1.recv_async().await?)
249 }
250
251 pub fn try_recv(&self) -> Result<msg::RecvMsg, TryRecvError> {
252 self.1.try_recv().map_err(|e| match e {
253 flume::TryRecvError::Empty => TryRecvError::Empty,
254 flume::TryRecvError::Disconnected => TryRecvError::Disconnected,
255 })
256 }
257}
258
259impl From<Transceiver> for Sender {
260 fn from(value: Transceiver) -> Self {
261 value.0
262 }
263}
264
/// Cloneable handle used to send requests and notifications over a connection.
///
/// Every clone shares the same underlying connection.
#[derive(Clone, Debug)]
pub struct Sender(Arc<dyn Connection>);
268
269#[derive(Debug, Clone, Default)]
274pub struct WriteBuffer {
275 value: BytesMut,
276}
277
278impl WriteBuffer {
279 pub(crate) fn prepare(&mut self) {
280 self.value.clear();
281 }
282
283 pub fn reserve(&mut self, size: usize) {
284 self.value.clear();
285 self.value.reserve(size);
286 }
287}
288
289#[derive(Debug, thiserror::Error)]
290pub enum CallError {
291 #[error("Failed to send request: {0}")]
292 SendFailed(#[from] SendError),
293
294 #[error("Failed to receive response: {0}")]
295 FlushFailed(#[from] std::io::Error),
296
297 #[error("Failed to receive response: {0}")]
298 RecvFailed(#[from] RecvError),
299
300 #[error("Remote returned invalid return type: {0}")]
301 ParseFailed(DecodeError, msg::Response),
302
303 #[error("Remote returned error response")]
304 ErrorResponse(msg::Response),
305}
306
307#[derive(Debug, EnumAsInner, thiserror::Error)]
308pub enum TypedCallError<E> {
309 #[error("Failed to send request: {0}")]
310 SendFailed(#[from] super::SendError),
311
312 #[error("Failed to receive response: {0}")]
313 FlushFailed(#[from] std::io::Error),
314
315 #[error("Failed to receive response: {0}")]
316 RecvFailed(#[from] super::RecvError),
317
318 #[error("Remote returned error response")]
319 Response(#[from] super::ResponseError<E>),
320}
321
322impl Sender {
323 pub async fn call<R: serde::de::DeserializeOwned>(
325 &self,
326 method: &str,
327 params: &impl serde::Serialize,
328 ) -> Result<R, CallError> {
329 let response = self.request(method, params).await?;
330
331 self.0.__flush().await?;
332
333 let msg = response.await?;
334
335 if msg.is_error {
336 return Err(CallError::ErrorResponse(msg));
337 }
338
339 match msg.parse() {
340 Ok(value) => Ok(value),
341 Err(err) => Err(CallError::ParseFailed(err, msg)),
342 }
343 }
344
345 pub async fn call_with_err<R: serde::de::DeserializeOwned, E: serde::de::DeserializeOwned>(
347 &self,
348 method: &str,
349 params: &impl serde::Serialize,
350 ) -> Result<R, TypedCallError<E>> {
351 let response = self.request(method, params).await?;
352
353 self.0.__flush().await?;
354
355 let msg = response.await?;
356 Ok(msg.result()?)
357 }
358
359 pub async fn request<T: serde::Serialize>(
362 &self,
363 method: &str,
364 params: &T,
365 ) -> Result<ResponseFuture, SendError> {
366 self.request_with_reuse(&mut Default::default(), method, params).await
367 }
368
369 pub async fn notify<T: serde::Serialize>(
371 &self,
372 method: &str,
373 params: &T,
374 ) -> Result<(), SendError> {
375 self.notify_with_reuse(&mut Default::default(), method, params).await
376 }
377
378 #[doc(hidden)]
379 pub async fn request_with_reuse<T: serde::Serialize>(
380 &self,
381 buf: &mut WriteBuffer,
382 method: &str,
383 params: &T,
384 ) -> Result<ResponseFuture, SendError> {
385 let fut = self.prepare_request(buf, method, params)?;
386 self.0.__write_raw(&mut buf.value).await?;
387
388 Ok(fut)
389 }
390
391 #[doc(hidden)]
392 pub async fn notify_with_reuse<T: serde::Serialize>(
393 &self,
394 buf: &mut WriteBuffer,
395 method: &str,
396 params: &T,
397 ) -> Result<(), SendError> {
398 buf.prepare();
399
400 self.0.codec().encode_notify(method, params, &mut buf.value)?;
401 self.0.__write_raw(&mut buf.value).await?;
402
403 Ok(())
404 }
405
406 pub fn request_deferred<T: serde::Serialize>(
410 &self,
411 method: &str,
412 params: &T,
413 ) -> Result<ResponseFuture, SendError> {
414 self.request_deferred_with_reuse(&mut Default::default(), method, params)
415 }
416
417 pub fn notify_deferred<T: serde::Serialize>(
420 &self,
421 method: &str,
422 params: &T,
423 ) -> Result<(), SendError> {
424 self.notify_deferred_with_reuse(&mut Default::default(), method, params)
425 }
426
427 #[doc(hidden)]
431 pub fn request_deferred_with_reuse<T: serde::Serialize>(
432 &self,
433 buffer: &mut WriteBuffer,
434 method: &str,
435 params: &T,
436 ) -> Result<ResponseFuture, SendError> {
437 buffer.prepare();
438 let fut = self.prepare_request(buffer, method, params);
439
440 self.0
441 .tx_drive()
442 .send(InboundDriverDirective::DeferredWrite(DeferredWrite::Raw(buffer.value.split())))
443 .map_err(|_| SendError::Disconnected)?;
444
445 fut
446 }
447
448 #[doc(hidden)]
451 pub fn notify_deferred_with_reuse<T: serde::Serialize>(
452 &self,
453 buffer: &mut WriteBuffer,
454 method: &str,
455 params: &T,
456 ) -> Result<(), SendError> {
457 buffer.prepare();
458 self.0.codec().encode_notify(method, params, &mut buffer.value)?;
459
460 self.0
461 .tx_drive()
462 .send(InboundDriverDirective::DeferredWrite(DeferredWrite::Raw(buffer.value.split())))
463 .map_err(|_| SendError::Disconnected)
464 }
465
466 fn prepare_request<T: serde::Serialize>(
467 &self,
468 buf: &mut WriteBuffer,
469 method: &str,
470 params: &T,
471 ) -> Result<ResponseFuture, SendError> {
472 let Some(req) = self.0.reqs() else { return Err(SendError::RequestDisabled) };
473
474 buf.prepare();
475 let req_id_hint = req.next_req_id_base();
476 let req_id_hash =
477 self.0.codec().encode_request(method, req_id_hint, params, &mut buf.value)?;
478
479 let slot_id = req.register_req(req_id_hash);
483
484 Ok(ResponseFuture::new(&self.0, slot_id))
485 }
486
487 pub fn is_disconnected(&self) -> bool {
489 self.0.tx_drive().is_disconnected()
490 }
491
492 pub fn close(self) -> bool {
499 self.0.tx_drive().send(InboundDriverDirective::Close).is_ok()
500 }
501
502 pub async fn flush(&self) -> std::io::Result<()> {
504 self.0.__flush().await
505 }
506
507 pub fn flush_deferred(&self) -> bool {
509 self.0.tx_drive().send(InboundDriverDirective::DeferredWrite(DeferredWrite::Flush)).is_ok()
510 }
511
512 pub fn is_request_enabled(&self) -> bool {
514 self.0.reqs().is_some()
515 }
516
517 pub fn get_feature_flags(&self) -> Feature {
519 self.0.feature_flag()
520 }
521}
522
523mod req {
524 use std::{
525 borrow::Cow,
526 mem::replace,
527 num::NonZeroU64,
528 sync::{atomic::AtomicU64, Arc},
529 task::Poll,
530 };
531
532 use dashmap::DashMap;
533 use futures_util::{task::AtomicWaker, FutureExt};
534 use parking_lot::Mutex;
535
536 use crate::RecvError;
537
538 use super::{msg, Connection};
539
540 #[derive(Debug, Default)]
542 pub struct RequestContext {
543 req_id_gen: AtomicU64,
545 waiters: DashMap<NonZeroU64, RequestSlot>,
546 }
547
548 #[derive(Clone, Debug)]
549 pub(super) struct RequestSlotId(NonZeroU64);
550
551 #[derive(Debug)]
552 struct RequestSlot {
553 waker: AtomicWaker,
554 value: Mutex<Option<msg::Response>>,
555 }
556
557 impl RequestContext {
558 pub(super) fn next_req_id_base(&self) -> NonZeroU64 {
560 unsafe {
563 NonZeroU64::new_unchecked(
564 1 + self.req_id_gen.fetch_add(2, std::sync::atomic::Ordering::Relaxed),
565 )
566 }
567 }
568
569 #[must_use]
570 pub(super) fn register_req(&self, req_id_hash: NonZeroU64) -> RequestSlotId {
571 let slot = RequestSlot { waker: AtomicWaker::new(), value: Mutex::new(None) };
572 if self.waiters.insert(req_id_hash, slot).is_some() {
573 panic!("Request ID collision")
574 }
575 RequestSlotId(req_id_hash)
576 }
577
578 pub(super) fn route_response(
580 &self,
581 req_id_hash: NonZeroU64,
582 response: msg::Response,
583 ) -> Result<(), msg::Response> {
584 let Some(slot) = self.waiters.get(&req_id_hash) else {
585 return Err(response);
586 };
587
588 let mut value = slot.value.lock();
589 value.replace(response);
590 slot.waker.wake();
591
592 Ok(())
593 }
594
595 pub(super) fn wake_up_all(&self) {
599 for x in self.waiters.iter() {
600 x.waker.wake();
601 }
602 }
603 }
604
605 #[must_use = "futures do nothing unless you `.await` or poll them"]
607 #[derive(Debug)]
608 pub struct ResponseFuture<'a>(ResponseFutureInner<'a>);
609
610 #[must_use = "futures do nothing unless you `.await` or poll them"]
611 #[derive(Debug)]
612 pub struct OwnedResponseFuture(ResponseFuture<'static>);
613
614 #[derive(Debug)]
615 enum ResponseFutureInner<'a> {
616 Waiting(Cow<'a, Arc<dyn Connection>>, NonZeroU64),
617 Finished,
618 }
619
620 impl<'a> ResponseFuture<'a> {
621 pub(super) fn new(handle: &'a Arc<dyn Connection>, slot_id: RequestSlotId) -> Self {
622 Self(ResponseFutureInner::Waiting(Cow::Borrowed(handle), slot_id.0))
623 }
624
625 pub fn to_owned(mut self) -> OwnedResponseFuture {
626 let state = replace(&mut self.0, ResponseFutureInner::Finished);
627
628 match state {
629 ResponseFutureInner::Waiting(conn, id) => OwnedResponseFuture(ResponseFuture(
630 ResponseFutureInner::Waiting(Cow::Owned(conn.into_owned()), id),
631 )),
632 ResponseFutureInner::Finished => {
633 OwnedResponseFuture(ResponseFuture(ResponseFutureInner::Finished))
634 }
635 }
636 }
637
638 pub fn try_recv(&mut self) -> Result<Option<msg::Response>, RecvError> {
639 use ResponseFutureInner::*;
640
641 match &mut self.0 {
642 Waiting(conn, hash) => {
643 if conn.__is_disconnected() {
644 return Err(RecvError::Disconnected);
646 }
647
648 let mut value = None;
649 conn.reqs().unwrap().waiters.remove_if(hash, |_, elem| {
650 if let Some(v) = elem.value.lock().take() {
651 value = Some(v);
652 true
653 } else {
654 false
655 }
656 });
657
658 if value.is_some() {
659 self.0 = Finished;
660 }
661
662 Ok(value)
663 }
664
665 Finished => Ok(None),
666 }
667 }
668 }
669
670 impl<'a> std::future::Future for ResponseFuture<'a> {
671 type Output = Result<msg::Response, RecvError>;
672
673 fn poll(
674 mut self: std::pin::Pin<&mut Self>,
675 cx: &mut std::task::Context<'_>,
676 ) -> std::task::Poll<Self::Output> {
677 use ResponseFutureInner::*;
678
679 match &mut self.0 {
680 Waiting(conn, hash) => {
681 if conn.__is_disconnected() {
682 return Poll::Ready(Err(RecvError::Disconnected));
684 }
685
686 let mut value = None;
687 conn.reqs().unwrap().waiters.remove_if(hash, |_, elem| {
688 if let Some(v) = elem.value.lock().take() {
689 value = Some(v);
690 true
691 } else {
692 elem.waker.register(cx.waker());
693 false
694 }
695 });
696
697 if let Some(value) = value {
698 self.0 = Finished;
699 Poll::Ready(Ok(value))
700 } else {
701 Poll::Pending
702 }
703 }
704
705 Finished => panic!("ResponseFuture polled after completion"),
706 }
707 }
708 }
709
710 impl std::future::Future for OwnedResponseFuture {
711 type Output = Result<msg::Response, RecvError>;
712
713 fn poll(
714 mut self: std::pin::Pin<&mut Self>,
715 cx: &mut std::task::Context<'_>,
716 ) -> std::task::Poll<Self::Output> {
717 self.0.poll_unpin(cx)
718 }
719 }
720
721 impl futures_util::future::FusedFuture for ResponseFuture<'_> {
722 fn is_terminated(&self) -> bool {
723 matches!(self.0, ResponseFutureInner::Finished)
724 }
725 }
726
727 impl futures_util::future::FusedFuture for OwnedResponseFuture {
728 fn is_terminated(&self) -> bool {
729 self.0.is_terminated()
730 }
731 }
732
733 impl std::ops::Deref for OwnedResponseFuture {
734 type Target = ResponseFuture<'static>;
735
736 fn deref(&self) -> &Self::Target {
737 &self.0
738 }
739 }
740
741 impl std::ops::DerefMut for OwnedResponseFuture {
742 fn deref_mut(&mut self) -> &mut Self::Target {
743 &mut self.0
744 }
745 }
746
747 impl Drop for ResponseFuture<'_> {
748 fn drop(&mut self) {
749 let state = std::mem::replace(&mut self.0, ResponseFutureInner::Finished);
750 let ResponseFutureInner::Waiting(conn, hash) = state else { return };
751 let reqs = conn.reqs().unwrap();
752 assert!(
753 reqs.waiters.remove(&hash).is_some(),
754 "Request lifespan must be bound to this future."
755 );
756 }
757 }
758}
759
/// Errors that can occur while sending a request, notification or response.
#[derive(Debug, thiserror::Error)]
pub enum SendError {
    /// The codec failed to encode the outbound message.
    #[error("Encoding outbound message failed: {0}")]
    CodecError(#[from] crate::codec::EncodeError),

    /// An I/O error raised while preparing the send.
    /// NOTE(review): not constructed anywhere in this view — confirm where it is used.
    #[error("Error during preparing send: {0}")]
    SendSetupFailed(std::io::Error),

    /// Writing to the underlying stream failed.
    #[error("Error during sending message to write stream: {0}")]
    IoError(#[from] std::io::Error),

    /// A request was attempted on a connection built without request support.
    #[error("Request feature is disabled for this connection")]
    RequestDisabled,

    /// The inbound driver is gone, so deferred operations can no longer be queued.
    #[error("Channel is already closed!")]
    Disconnected,
}
780
/// Error returned by blocking/async receive operations.
#[derive(Debug, thiserror::Error)]
pub enum RecvError {
    /// The channel (or connection) has been closed.
    #[error("Channel has been closed")]
    Disconnected,
}

/// Error returned by non-blocking receive operations.
#[derive(Debug, thiserror::Error)]
pub enum TryRecvError {
    /// The connection has been closed.
    #[error("Connection has been disconnected")]
    Disconnected,

    /// Nothing is queued right now.
    #[error("Channel is empty")]
    Empty,
}
795
796impl From<flume::RecvError> for RecvError {
797 fn from(value: flume::RecvError) -> Self {
798 match value {
799 flume::RecvError::Disconnected => Self::Disconnected,
800 }
801 }
802}
803
/// Type-state builder for a connection; finish with `build`.
///
/// Type parameters: `Tw` frame writer, `Tr` frame reader, `C` codec, `E` event subscriber,
/// `R` request-context source, `U` user data.
pub struct Builder<Tw, Tr, C, E, R, U> {
    codec: C,
    write: Tw,
    read: Tr,
    ev: E,
    reqs: R,
    user_data: U,

    /// Settings that do not change the builder's type.
    cfg: BuilderOtherConfig,
}
820
/// Non-type-changing builder settings.
#[derive(Default)]
struct BuilderOtherConfig {
    /// Capacity of the inbound message channel; `None` selects an unbounded channel.
    inbound_channel_cap: Option<NonZeroUsize>,
    /// Feature flags to build the connection with.
    feature_flag: Feature,
}
826
827impl Default for Builder<(), (), (), EmptyEventListener, (), ()> {
828 fn default() -> Self {
829 Self {
830 codec: (),
831 write: (),
832 read: (),
833 cfg: Default::default(),
834 ev: EmptyEventListener,
835 reqs: (),
836 user_data: (),
837 }
838 }
839}
840
841impl<Tw, Tr, C, E, R, U> Builder<Tw, Tr, C, E, R, U> {
842 pub fn with_codec<C1: Codec>(
844 self,
845 codec: impl Into<Arc<C1>>,
846 ) -> Builder<Tw, Tr, Arc<C1>, E, R, U> {
847 Builder {
848 codec: codec.into(),
849 write: self.write,
850 read: self.read,
851 cfg: self.cfg,
852 ev: self.ev,
853 reqs: self.reqs,
854 user_data: self.user_data,
855 }
856 }
857
858 pub fn with_write<Tw2>(self, write: Tw2) -> Builder<Tw2, Tr, C, E, R, U>
860 where
861 Tw2: AsyncFrameWrite,
862 {
863 Builder {
864 codec: self.codec,
865 write,
866 read: self.read,
867 cfg: self.cfg,
868 ev: self.ev,
869 reqs: self.reqs,
870 user_data: self.user_data,
871 }
872 }
873
874 pub fn with_read<Tr2>(self, read: Tr2) -> Builder<Tw, Tr2, C, E, R, U>
876 where
877 Tr2: AsyncFrameRead,
878 {
879 Builder {
880 codec: self.codec,
881 write: self.write,
882 read,
883 cfg: self.cfg,
884 ev: self.ev,
885 reqs: self.reqs,
886 user_data: self.user_data,
887 }
888 }
889
890 pub fn with_event_listener<E2: InboundEventSubscriber>(
893 self,
894 ev: E2,
895 ) -> Builder<Tw, Tr, C, E2, R, U> {
896 Builder {
897 codec: self.codec,
898 write: self.write,
899 read: self.read,
900 cfg: self.cfg,
901 ev,
902 reqs: self.reqs,
903 user_data: self.user_data,
904 }
905 }
906
907 pub fn with_user_data<U2: UserData>(self, user_data: U2) -> Builder<Tw, Tr, C, E, R, U2> {
909 Builder {
910 codec: self.codec,
911 write: self.write,
912 read: self.read,
913 cfg: self.cfg,
914 ev: self.ev,
915 reqs: self.reqs,
916 user_data,
917 }
918 }
919
920 pub fn with_read_stream<Tr2, F>(
928 self,
929 read: Tr2,
930 framing: F,
931 default_readbuf_reserve: usize,
932 ) -> Builder<Tw, impl AsyncFrameRead, C, E, R, U>
933 where
934 Tr2: AsyncRead + Send + Sync + 'static,
935 F: Framing,
936 {
937 use std::pin::Pin;
938 use std::task::{Context, Poll};
939
940 struct FramingReader<T, F> {
941 reader: T,
942 framing: F,
943 buf: bytes::BytesMut,
944 nreserve: usize,
945 state_skip_reading: bool,
946 }
947
948 impl<T, F> AsyncFrameRead for FramingReader<T, F>
949 where
950 T: AsyncRead + Sync + Send + 'static,
951 F: Framing,
952 {
953 fn poll_next(
954 self: Pin<&mut Self>,
955 cx: &mut Context<'_>,
956 ) -> Poll<std::io::Result<Bytes>> {
957 let FramingReader { reader, framing, buf, nreserve, state_skip_reading } =
959 unsafe { self.get_unchecked_mut() };
960
961 let mut reader = unsafe { Pin::new_unchecked(reader) };
962
963 loop {
964 let size_hint = framing.next_buffer_size();
966 let size_required = size_hint.map(|x| x.get());
967
968 while !*state_skip_reading && size_required.is_some_and(|x| buf.len() < x) {
969 let n_req_size = size_required.unwrap_or(0).saturating_sub(buf.len());
970 let num_reserve = (*nreserve).max(n_req_size);
971
972 let old_cursor = buf.len();
973 buf.reserve(num_reserve);
974
975 unsafe {
976 buf.set_len(old_cursor + num_reserve);
977 match reader.as_mut().poll_read(cx, buf) {
978 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
979 Poll::Ready(Ok(n)) => {
980 buf.set_len(old_cursor + n);
981 }
982 Poll::Pending => {
983 buf.set_len(old_cursor);
984
985 *state_skip_reading = true;
987 break;
988 }
989 }
990 }
991 }
992
993 match framing.try_framing(&buf[..]) {
995 Ok(Some(x)) => {
996 framing.advance();
997
998 debug_assert!(x.valid_data_end <= x.next_frame_start);
999 debug_assert!(x.next_frame_start <= buf.len());
1000
1001 let mut frame = buf.split_to(x.next_frame_start);
1002 break Poll::Ready(Ok(frame.split_off(x.valid_data_end).into()));
1003 }
1004
1005 Ok(None) => {
1007 *state_skip_reading = false;
1008 continue;
1009 }
1010
1011 Err(FramingError::Broken(why)) => {
1012 return Poll::Ready(Err(std::io::Error::new(
1013 std::io::ErrorKind::InvalidData,
1014 why,
1015 )))
1016 }
1017
1018 Err(FramingError::Recoverable(num_discard_bytes)) => {
1019 debug_assert!(num_discard_bytes <= buf.len());
1020
1021 let _ = buf.split_to(num_discard_bytes);
1022 continue;
1023 }
1024 }
1025 }
1026 }
1027 }
1028
1029 Builder {
1030 codec: self.codec,
1031 write: self.write,
1032 read: FramingReader {
1033 reader: read,
1034 framing,
1035 buf: BytesMut::default(),
1036 nreserve: default_readbuf_reserve,
1037 state_skip_reading: false,
1038 },
1039 cfg: self.cfg,
1040 ev: self.ev,
1041 reqs: self.reqs,
1042 user_data: self.user_data,
1043 }
1044 }
1045
1046 pub fn with_request(self) -> Builder<Tw, Tr, C, E, RequestContext, U> {
1048 Builder {
1049 codec: self.codec,
1050 write: self.write,
1051 read: self.read,
1052 cfg: self.cfg,
1053 ev: self.ev,
1054 reqs: RequestContext::default(),
1055 user_data: self.user_data,
1056 }
1057 }
1058
1059 pub fn with_request_context(
1062 self,
1063 reqs: impl GetRequestContext,
1064 ) -> Builder<Tw, Tr, C, E, impl GetRequestContext, U> {
1065 Builder {
1066 codec: self.codec,
1067 write: self.write,
1068 read: self.read,
1069 cfg: self.cfg,
1070 ev: self.ev,
1071 reqs,
1072 user_data: self.user_data,
1073 }
1074 }
1075
1076 pub fn with_inbound_channel_capacity(mut self, capacity: usize) -> Self {
1078 self.cfg.inbound_channel_cap = capacity.try_into().ok();
1079 self
1080 }
1081
1082 pub fn with_feature(mut self, feature: Feature) -> Self {
1084 self.cfg.feature_flag |= feature;
1085 self
1086 }
1087
1088 pub fn without_feature(mut self, feature: Feature) -> Self {
1090 self.cfg.feature_flag &= !feature;
1091 self
1092 }
1093}
1094
1095impl<Tw, Tr, C, E, R, U> Builder<Tw, Tr, Arc<C>, E, R, U>
1096where
1097 Tw: AsyncFrameWrite,
1098 Tr: AsyncFrameRead,
1099 C: Codec,
1100 E: InboundEventSubscriber,
1101 R: GetRequestContext,
1102 U: UserData,
1103{
1104 #[must_use = "The connection will be closed immediately if you don't spawn the future!"]
1108 pub fn build(self) -> (Transceiver, impl std::future::Future<Output = ()> + Send) {
1109 let (tx_inb_drv, rx_inb_drv) = flume::unbounded();
1110 let (tx_in_msg, rx_in_msg) =
1111 if let Some(chan_cap) = self.cfg.inbound_channel_cap.map(|x| x.get()) {
1112 flume::bounded(chan_cap)
1113 } else {
1114 flume::unbounded()
1115 };
1116
1117 let conn: Arc<dyn Connection>;
1118 let this = ConnectionImpl {
1119 codec: self.codec,
1120 write: AsyncMutex::new(self.write),
1121 reqs: self.reqs,
1122 tx_drive: tx_inb_drv,
1123 features: self.cfg.feature_flag,
1124 user_data: self.user_data,
1125 _unpin: std::marker::PhantomPinned,
1126 };
1127
1128 let this = Arc::new(this);
1129 let fut_driver = ConnectionImpl::inbound_event_handler(DriverBody {
1130 w_this: Arc::downgrade(&this),
1131 read: self.read,
1132 ev_subs: self.ev,
1133 rx_drive: rx_inb_drv,
1134 tx_msg: tx_in_msg,
1135 });
1136
1137 conn = this;
1138 (Transceiver(Sender(conn), rx_in_msg), fut_driver)
1139 }
1140}
1141
/// Problems with inbound traffic, reported through
/// [`InboundEventSubscriber::on_inbound_error`].
#[derive(Debug, thiserror::Error)]
pub enum InboundError {
    /// A response arrived on a connection built without request support.
    #[error("Response packet is received, but we haven't enabled request feature!")]
    RedundantResponse(msg::Response),

    /// A response arrived that no registered waiter claims (presumably its future was
    /// dropped — confirm in the driver).
    #[error("Response packet is not routed.")]
    ExpiredResponse(msg::Response),

    /// The response's request-id hash was decoded as 0, which is never a valid key.
    #[error("Response hash was restored as 0, which is invalid.")]
    ResponseHashZero(msg::Response),

    /// An inbound frame could not be decoded; carries the error and the raw bytes.
    #[error("Failed to decode inbound type: {} bytes, error was = {0} ", .1.len())]
    InboundDecodeError(DecodeError, bytes::Bytes),

    /// A request or notification arrived although receiving it is disabled via [`Feature`].
    #[error("Disabled notification/request is received")]
    DisabledInbound(msg::RecvMsg),
}
1163
1164pub trait InboundEventSubscriber: Send + Sync + 'static {
1166 fn on_close(&self, closed_by_us: bool, result: std::io::Result<()>) {
1175 let _ = (closed_by_us, result);
1176 }
1177
1178 fn on_inbound_error(&self, error: InboundError) {
1182 let _ = (error,);
1183 }
1184}
1185
1186pub struct EmptyEventListener;
1188
1189impl InboundEventSubscriber for EmptyEventListener {}
1190
1191pub trait Message {
1195 fn payload(&self) -> &[u8];
1196 fn raw(&self) -> &[u8];
1197 fn raw_bytes(&self) -> Bytes;
1198 fn codec(&self) -> &dyn Codec;
1199
1200 fn parse<'a, T: serde::de::Deserialize<'a>>(&'a self) -> Result<T, DecodeError> {
1202 let mut dst = None;
1203 self.parse_in_place(&mut dst)?;
1204 Ok(dst.unwrap())
1205 }
1206
1207 #[doc(hidden)]
1211 fn parse_in_place<'a, T: serde::de::Deserialize<'a>>(
1212 &'a self,
1213 dst: &mut Option<T>,
1214 ) -> Result<(), DecodeError> {
1215 self.codec().decode_payload(self.payload(), &mut |de| {
1216 if let Some(dst) = dst.as_mut() {
1217 T::deserialize_in_place(de, dst)
1218 } else {
1219 *dst = Some(T::deserialize(de)?);
1220 Ok(())
1221 }
1222 })
1223 }
1224}
1225
1226pub trait MessageReqId: Message {
1227 fn req_id(&self) -> codec::ReqIdRef;
1228}
1229
1230pub trait MessageMethodName: Message {
1231 fn method_raw(&self) -> &[u8];
1232 fn method(&self) -> Option<&str> {
1233 std::str::from_utf8(self.method_raw()).ok()
1234 }
1235}
1236
1237pub trait ExtractUserData {
1238 fn extract_sender(&self) -> Sender;
1239 fn user_data_raw(&self) -> &dyn UserData;
1240 fn user_data_owned(&self) -> OwnedUserData;
1241 fn user_data<T: UserData>(&self) -> Option<&T> {
1242 self.user_data_raw().as_any().downcast_ref()
1243 }
1244}
1245
1246impl ExtractUserData for Sender {
1247 fn extract_sender(&self) -> Sender {
1248 self.clone()
1249 }
1250
1251 fn user_data_raw(&self) -> &dyn UserData {
1252 self.0.user_data()
1253 }
1254
1255 fn user_data_owned(&self) -> OwnedUserData {
1256 OwnedUserData(self.0.clone())
1257 }
1258}
1259
/// Data shared by every inbound message type: the raw bytes, where the payload lies in them,
/// and the codec that decodes it.
#[derive(Debug)]
struct InboundBody {
    /// Raw message bytes (returned by `Message::raw`).
    buffer: Bytes,
    /// Byte range of the payload within `buffer`.
    payload: Range<usize>,
    /// Codec used to decode the payload.
    codec: Arc<dyn Codec>,
}
1267
1268pub mod msg {
1269 use std::{ops::Range, sync::Arc};
1270
1271 use enum_as_inner::EnumAsInner;
1272
1273 use crate::{
1274 codec::{DecodeError, EncodeError, PredefinedResponseError, ReqId, ReqIdRef},
1275 rpc::MessageReqId,
1276 };
1277
1278 use super::{Connection, DeferredWrite, ExtractUserData, Message, UserData, WriteBuffer};
1279
// Implements `super::Message` for a type with an `h: super::InboundBody` field.
macro_rules! impl_message {
    ($t:ty) => {
        impl super::Message for $t {
            fn payload(&self) -> &[u8] {
                &self.h.buffer[self.h.payload.clone()]
            }

            fn raw(&self) -> &[u8] {
                &self.h.buffer
            }

            fn raw_bytes(&self) -> bytes::Bytes {
                self.h.buffer.clone()
            }

            fn codec(&self) -> &dyn super::Codec {
                &*self.h.codec
            }
        }
    };
}

// Implements `super::MessageMethodName` for a type with `h: InboundBody` and a
// `method: Range<usize>` field indexing into `h.buffer`.
macro_rules! impl_method_name {
    ($t:ty) => {
        impl super::MessageMethodName for $t {
            fn method_raw(&self) -> &[u8] {
                &self.h.buffer[self.method.clone()]
            }
        }
    };
}

// Implements `super::MessageReqId` for a type with `h: InboundBody` and a `req_id: ReqId`
// field, resolved against `h.buffer`.
macro_rules! impl_req_id {
    ($t:ty) => {
        impl super::MessageReqId for $t {
            fn req_id(&self) -> ReqIdRef {
                self.req_id.make_ref(&self.h.buffer)
            }
        }
    };
}
1321
/// Decoded header of an inbound request: the common body plus the locations of the method
/// name and the request id.
#[derive(Debug)]
pub struct RequestInner {
    pub(super) h: super::InboundBody,
    /// Byte range of the method name within `h.buffer`.
    pub(super) method: Range<usize>,
    /// Request id, resolved against `h.buffer` by `MessageReqId::req_id`.
    pub(super) req_id: ReqId,
}

impl_message!(RequestInner);
impl_method_name!(RequestInner);
impl_req_id!(RequestInner);
1333
1334 #[derive(Debug)]
1335 pub struct Request {
1336 pub(super) body: Option<(RequestInner, Arc<dyn Connection>)>,
1337 }
1338
1339 impl std::ops::Deref for Request {
1340 type Target = RequestInner;
1341
1342 fn deref(&self) -> &Self::Target {
1343 &self.body.as_ref().unwrap().0
1344 }
1345 }
1346
1347 impl ExtractUserData for Request {
1348 fn user_data_raw(&self) -> &dyn UserData {
1349 self.body.as_ref().unwrap().1.user_data()
1350 }
1351 fn user_data_owned(&self) -> super::OwnedUserData {
1352 super::OwnedUserData(self.body.as_ref().unwrap().1.clone())
1353 }
1354 fn extract_sender(&self) -> crate::Sender {
1355 super::Sender(self.body.as_ref().unwrap().1.clone())
1356 }
1357 }
1358
1359 impl Request {
1360 pub async fn response<T: serde::Serialize>(
1361 self,
1362 value: Result<&T, &T>,
1363 ) -> Result<(), super::SendError> {
1364 let mut buf = Default::default();
1365 self.response_with_reuse(&mut buf, value).await
1366 }
1367
1368 #[doc(hidden)]
1370 pub async fn response_with_reuse<T: serde::Serialize>(
1371 self,
1372 buf: &mut super::WriteBuffer,
1373 value: Result<&T, &T>,
1374 ) -> Result<(), super::SendError> {
1375 let conn = self.prepare_response(buf, value)?;
1376 conn.__write_raw(&mut buf.value).await?;
1377 Ok(())
1378 }
1379
1380 pub async fn abort(self) -> Result<(), super::SendError> {
1382 self.error_predefined(PredefinedResponseError::Aborted).await
1383 }
1384
1385 pub async fn error_notify_handler(self) -> Result<(), super::SendError> {
1387 self.error_predefined(PredefinedResponseError::NotifyHandler).await
1388 }
1389
1390 pub async fn error_parse_failed<T>(self) -> Result<(), super::SendError> {
1392 let err = PredefinedResponseError::ParseFailed(std::any::type_name::<T>().into());
1393 self.error_predefined(err).await
1394 }
1395
1396 pub async fn error_internal(
1398 self,
1399 errc: i32,
1400 detail: impl Into<Option<String>>,
1401 ) -> Result<(), super::SendError> {
1402 let err = if let Some(detail) = detail.into() {
1403 PredefinedResponseError::InternalDetailed(errc, detail)
1404 } else {
1405 PredefinedResponseError::Internal(errc)
1406 };
1407
1408 self.error_predefined(err).await
1409 }
1410
1411 async fn error_predefined(
1412 mut self,
1413 err: super::codec::PredefinedResponseError,
1414 ) -> Result<(), super::SendError> {
1415 let (inner, conn) = self.body.take().unwrap();
1416 conn.__send_err_predef(&mut Default::default(), &inner, &err).await
1417 }
1418
1419 #[doc(hidden)]
1421 pub fn response_deferred_with_reuse<T: serde::Serialize>(
1422 self,
1423 buffer: &mut WriteBuffer,
1424 value: Result<&T, &T>,
1425 ) -> Result<(), super::SendError> {
1426 buffer.prepare();
1427 let conn = self.prepare_response(buffer, value).unwrap();
1428
1429 conn.tx_drive()
1430 .send(super::InboundDriverDirective::DeferredWrite(
1431 DeferredWrite::Raw(buffer.value.split()).into(),
1432 ))
1433 .map_err(|_| super::SendError::Disconnected)
1434 }
1435
1436 pub fn response_deferred<T: serde::Serialize>(
1437 self,
1438 value: Result<&T, &T>,
1439 ) -> Result<(), super::SendError> {
1440 self.response_deferred_with_reuse(&mut Default::default(), value)
1441 }
1442
1443 pub fn abort_deferred(self) -> Result<(), super::SendError> {
1444 self.error_predefined_deferred(PredefinedResponseError::Aborted)
1445 }
1446
1447 pub fn error_notify_handler_deferred(self) -> Result<(), super::SendError> {
1448 self.error_predefined_deferred(PredefinedResponseError::NotifyHandler)
1449 }
1450
1451 pub fn error_parse_failed_deferred<T>(self) -> Result<(), super::SendError> {
1452 let err = PredefinedResponseError::ParseFailed(std::any::type_name::<T>().into());
1453 self.error_predefined_deferred(err)
1454 }
1455
1456 pub fn error_internal_deferred(
1457 self,
1458 errc: i32,
1459 detail: impl Into<Option<String>>,
1460 ) -> Result<(), super::SendError> {
1461 let err = if let Some(detail) = detail.into() {
1462 PredefinedResponseError::InternalDetailed(errc, detail)
1463 } else {
1464 PredefinedResponseError::Internal(errc)
1465 };
1466
1467 self.error_predefined_deferred(err)
1468 }
1469
1470 fn error_predefined_deferred(
1471 mut self,
1472 err: super::codec::PredefinedResponseError,
1473 ) -> Result<(), super::SendError> {
1474 let (inner, conn) = self.body.take().unwrap();
1475 conn.tx_drive()
1476 .send(super::InboundDriverDirective::DeferredWrite(
1477 DeferredWrite::ErrorResponse(inner, err).into(),
1478 ))
1479 .map_err(|_| super::SendError::Disconnected)
1480 }
1481
1482 fn prepare_response<T: serde::Serialize>(
1484 mut self,
1485 buf: &mut super::WriteBuffer,
1486 value: Result<&T, &T>,
1487 ) -> Result<Arc<dyn Connection>, EncodeError> {
1488 let (inner, conn) = self.body.take().unwrap();
1489 buf.prepare();
1490
1491 let encode_as_error = value.is_err();
1492 let value = value.unwrap_or_else(|x| x);
1493 inner.codec().encode_response(
1494 inner.req_id(),
1495 encode_as_error,
1496 value,
1497 &mut buf.value,
1498 )?;
1499
1500 Ok(conn.clone())
1501 }
1502 }
1503
1504 impl Drop for Request {
1505 fn drop(&mut self) {
1506 if let Some((inner, conn)) = self.body.take() {
1507 conn.tx_drive()
1508 .send(super::InboundDriverDirective::DeferredWrite(
1509 super::DeferredWrite::ErrorResponse(
1510 inner,
1511 PredefinedResponseError::Unhandled,
1512 ),
1513 ))
1514 .ok();
1515 }
1516 }
1517 }
1518
/// An inbound one-way notification received on a connection.
#[derive(Debug)]
pub struct Notify {
    /// The raw inbound frame, the span of its payload, and the codec that
    /// decoded it (built in the driver's `on_read`).
    pub(super) h: super::InboundBody,
    /// Method-name range reported by the codec's inbound decoder
    /// (presumably a byte range into the frame buffer — TODO confirm).
    pub(super) method: Range<usize>,
    /// The connection this notification arrived on.
    pub(super) sender: Arc<dyn Connection>,
}

// Message accessors and method-name lookup come from crate macros defined elsewhere.
impl_message!(Notify);
impl_method_name!(Notify);
1529
1530 impl ExtractUserData for Notify {
1531 fn user_data_raw(&self) -> &dyn UserData {
1532 self.sender.user_data()
1533 }
1534
1535 fn user_data_owned(&self) -> super::OwnedUserData {
1536 super::OwnedUserData(self.sender.clone())
1537 }
1538
1539 fn extract_sender(&self) -> crate::Sender {
1540 super::Sender(self.sender.clone())
1541 }
1542 }
1543
/// A message delivered to the user by the connection driver.
#[derive(Debug, EnumAsInner)]
pub enum RecvMsg {
    /// An inbound request; if it is dropped unanswered, its `Drop` impl queues
    /// an `Unhandled` error reply.
    Request(Request),
    /// An inbound one-way notification.
    Notify(Notify),
}
1551
1552 impl ExtractUserData for RecvMsg {
1553 fn user_data_raw(&self) -> &dyn UserData {
1554 match self {
1555 Self::Request(x) => x.user_data_raw(),
1556 Self::Notify(x) => x.user_data_raw(),
1557 }
1558 }
1559
1560 fn user_data_owned(&self) -> super::OwnedUserData {
1561 match self {
1562 Self::Request(x) => x.user_data_owned(),
1563 Self::Notify(x) => x.user_data_owned(),
1564 }
1565 }
1566
1567 fn extract_sender(&self) -> crate::Sender {
1568 match self {
1569 Self::Request(x) => x.extract_sender(),
1570 Self::Notify(x) => x.extract_sender(),
1571 }
1572 }
1573 }
1574
1575 impl super::Message for RecvMsg {
1576 fn payload(&self) -> &[u8] {
1577 match self {
1578 Self::Request(x) => x.payload(),
1579 Self::Notify(x) => x.payload(),
1580 }
1581 }
1582
1583 fn raw(&self) -> &[u8] {
1584 match self {
1585 Self::Request(x) => x.raw(),
1586 Self::Notify(x) => x.raw(),
1587 }
1588 }
1589
1590 fn raw_bytes(&self) -> bytes::Bytes {
1591 match self {
1592 Self::Request(x) => x.raw_bytes(),
1593 Self::Notify(x) => x.raw_bytes(),
1594 }
1595 }
1596
1597 fn codec(&self) -> &dyn crate::codec::Codec {
1598 match self {
1599 Self::Request(x) => x.codec(),
1600 Self::Notify(x) => x.codec(),
1601 }
1602 }
1603 }
1604
1605 impl super::MessageMethodName for RecvMsg {
1606 fn method_raw(&self) -> &[u8] {
1607 match self {
1608 Self::Request(x) => x.method_raw(),
1609 Self::Notify(x) => x.method_raw(),
1610 }
1611 }
1612 }
1613
/// An inbound response, routed by the driver to the request it answers.
#[derive(Debug)]
pub struct Response {
    /// The raw inbound frame, the span of its payload, and the decoding codec.
    pub(super) h: super::InboundBody,
    /// Identifier of the request this response answers, as decoded from the frame.
    pub(super) req_id: ReqId,

    /// Whether the peer flagged this as an error response.
    pub(super) is_error: bool,
}

// Message accessors and request-id lookup come from crate macros defined elsewhere.
impl_message!(Response);
impl_req_id!(Response);
1626
1627 impl Response {
1628 pub fn is_error(&self) -> bool {
1629 self.is_error
1630 }
1631
1632 pub fn result<'a, T: serde::Deserialize<'a>, E: serde::Deserialize<'a>>(
1633 &'a self,
1634 ) -> Result<T, ResponseError<E>> {
1635 if !self.is_error {
1636 Ok(self.parse()?)
1637 } else {
1638 let codec = self.codec();
1639 if let Some(predef) = codec.try_decode_predef_error(self.payload()) {
1640 Err(ResponseError::Predefined(predef))
1641 } else {
1642 Err(ResponseError::Typed(self.parse()?))
1643 }
1644 }
1645 }
1646 }
1647
/// Error returned by [`Response::result`] when the response is not a
/// successfully decoded `Ok` value.
#[derive(EnumAsInner, Debug, thiserror::Error)]
pub enum ResponseError<T> {
    /// The peer sent an error response whose payload decoded as the caller's
    /// error type `T`.
    // NOTE(review): the display text reads like a success message; consider rewording.
    #[error("Requested type returned")]
    Typed(T),

    /// The peer sent one of the codec's predefined error responses.
    #[error("Predefined error returned: {0}")]
    Predefined(PredefinedResponseError),

    /// The payload could not be decoded into the requested type.
    #[error("Decode error: {0}")]
    DecodeError(#[from] DecodeError),
}
1659}
1660
/// State moved into a connection's inbound driver (`DriverBody::execute`).
struct DriverBody<C, T, E, R, U, Tr> {
    /// Weak handle to the connection; several driver paths stop once it can no
    /// longer be upgraded.
    w_this: Weak<ConnectionImpl<C, T, R, U>>,
    /// Frame reader for the inbound side of the transport.
    read: Tr,
    /// Subscriber notified of inbound errors and of connection close.
    ev_subs: E,
    /// Receiver for driver directives (presumably the counterpart of the
    /// connection's `tx_drive` sender — TODO confirm at construction site).
    rx_drive: flume::Receiver<InboundDriverDirective>,
    /// Channel on which decoded requests and notifications are handed to the user.
    tx_msg: flume::Sender<msg::RecvMsg>,
}
1672
mod inner {
    //! Connection driver internals: the loop that reads inbound frames,
    //! dispatches them, and executes writes queued as `InboundDriverDirective`s,
    //! plus the low-level write helpers on `dyn Connection`.

    use bytes::BytesMut;
    use capture_it::capture;
    use futures_util::{future::FusedFuture, FutureExt};
    use std::{future::poll_fn, num::NonZeroU64, sync::Arc};

    use crate::{
        codec::{self, Codec, InboundFrameType},
        rpc::{DeferredWrite, SendError},
        transport::{AsyncFrameRead, AsyncFrameWrite, FrameReader},
    };

    use super::{
        msg, ConnectionImpl, DriverBody, Feature, GetRequestContext, InboundBody,
        InboundDriverDirective, InboundError, InboundEventSubscriber, MessageReqId, UserData,
    };

    impl<C, T, R, U> ConnectionImpl<C, T, R, U>
    where
        C: Codec,
        T: AsyncFrameWrite,
        R: GetRequestContext,
        U: UserData,
    {
        /// Runs the inbound driver described by `body` until the connection shuts down.
        pub(crate) async fn inbound_event_handler<Tr, E>(body: DriverBody<C, T, E, R, U, Tr>)
        where
            Tr: AsyncFrameRead,
            E: InboundEventSubscriber,
        {
            body.execute().await;
        }
    }

    impl<C, T, E, R, U, Tr> DriverBody<C, T, E, R, U, Tr>
    where
        C: Codec,
        T: AsyncFrameWrite,
        E: InboundEventSubscriber,
        R: GetRequestContext,
        U: UserData,
        Tr: AsyncFrameRead,
    {
        /// Driver main loop.
        ///
        /// Polls three futures together with `select!` until one of them ends the
        /// connection:
        /// - the next directive from `rx_drive`: `DeferredWrite`s are forwarded to the
        ///   background writer, while `Close` or a disconnected channel stops the loop;
        /// - the next inbound frame from `read`, handed to `on_read`; a read error is
        ///   reported via `on_close(false, Err(_))` and stops the loop;
        /// - the background writer, whose completion stops the loop (a write error is
        ///   reported via `on_close(true, Err(_))`).
        ///
        /// After the loop:
        /// 1. The writer is closed if the connection is still alive (errors ignored).
        /// 2. `on_close(true, Ok(()))` is reported unless `on_close` was already called.
        /// 3. If a request context exists, the directive receiver is dropped and
        ///    `wake_up_all` is called on it.
        ///
        /// NOTE(review): deferred writes still queued for the background writer when
        /// the loop exits are never driven, so they are dropped without being written.
        async fn execute(self) {
            let DriverBody { w_this, mut read, mut ev_subs, rx_drive: rx_msg, tx_msg } = self;

            use futures_util::future::Fuse;
            let mut fut_drive_msg = Fuse::terminated();
            let mut fut_read = Fuse::terminated();
            // True once `on_close` has already been reported inside the loop, so the
            // final "closed locally" notification is not sent twice. Despite the name,
            // it is also set on a local write failure.
            let mut close_from_remote = false;

            // Deferred writes run in this separate future, which is polled alongside
            // the directive and read futures in the `select!` below.
            let (tx_bg_sender, rx_bg_sender) = flume::unbounded();
            let fut_bg_sender = capture!([w_this], async move {
                // Scratch buffer reused for encoding deferred error responses.
                let mut pool = super::WriteBuffer::default();
                while let Ok(msg) = rx_bg_sender.recv_async().await {
                    // Stop quietly once the connection itself has been dropped.
                    let Some(this) = w_this.upgrade() else { break };
                    match msg {
                        DeferredWrite::ErrorResponse(req, err) => {
                            let err = this.dyn_ref().__send_err_predef(&mut pool, &req, &err).await;

                            // Only transport I/O failures end the writer; other send
                            // errors for this single reply are ignored.
                            if let Err(SendError::IoError(e)) = err {
                                return Err(e);
                            }
                        }
                        DeferredWrite::Raw(mut msg) => {
                            this.dyn_ref().__write_raw(&mut msg).await?;
                        }
                        DeferredWrite::Flush => {
                            this.dyn_ref().__flush().await?;
                        }
                    }
                }

                Ok::<(), std::io::Error>(())
            });
            let fut_bg_sender = fut_bg_sender.fuse();
            let mut fut_bg_sender = std::pin::pin!(fut_bg_sender);
            let mut read = std::pin::pin!(read);

            loop {
                // Re-arm whichever of these futures completed in the previous iteration.
                if fut_drive_msg.is_terminated() {
                    fut_drive_msg = rx_msg.recv_async().fuse();
                }

                if fut_read.is_terminated() {
                    fut_read = poll_fn(|cx| read.as_mut().poll_next(cx)).fuse();
                }

                futures_util::select! {
                    msg = fut_drive_msg => {
                        match msg {
                            Ok(InboundDriverDirective::DeferredWrite(msg)) => {
                                // Hand off to the background writer; the result is ignored.
                                tx_bg_sender.send_async(msg).await.ok();
                            }

                            Err(_) | Ok(InboundDriverDirective::Close) => {
                                break;
                            }
                        }
                    }

                    inbound = fut_read => {
                        let Some(this) = w_this.upgrade() else { break };

                        match inbound {
                            Ok(bytes) => {
                                Self::on_read(
                                    &this,
                                    &mut ev_subs,
                                    bytes,
                                    &tx_msg
                                )
                                .await;
                            }
                            Err(e) => {
                                close_from_remote = true;
                                ev_subs.on_close(false, Err(e));

                                break;
                            }
                        }
                    }

                    result = fut_bg_sender => {
                        if let Err(err) = result {
                            close_from_remote = true;
                            ev_subs.on_close(true, Err(err));
                        }

                        break;
                    }
                }
            }

            if let Some(x) = w_this.upgrade() {
                x.dyn_ref().__close().await.ok();
            }

            if !close_from_remote {
                ev_subs.on_close(true, Ok(()));
            }

            'cancel: {
                let Some(this) = w_this.upgrade() else { break 'cancel };
                let Some(reqs) = this.reqs.get_req_con() else { break 'cancel };

                // `fut_drive_msg` holds a borrow of `rx_msg` (flume `recv_async`), so it
                // must be dropped first. Dropping the receiver makes
                // `tx_drive().is_disconnected()` (see `__is_disconnected`) report true
                // before `wake_up_all` runs.
                drop(fut_drive_msg);
                drop(rx_msg);

                // Presumably wakes tasks awaiting outstanding requests so they observe
                // the disconnection — semantics live in `RequestContext`.
                reqs.wake_up_all();
            }
        }

        /// Decodes one inbound frame and dispatches it.
        ///
        /// - Decode failures are reported as `InboundError::InboundDecodeError`.
        /// - Requests and notifications become `msg::RecvMsg`s and are sent on
        ///   `tx_msg` (send failures are ignored). If the matching
        ///   `Feature::NO_RECEIVE_*` flag is set, they are reported as
        ///   `InboundError::DisabledInbound` instead.
        /// - Responses are routed through the connection's request context. These
        ///   cases are reported to the subscriber instead:
        ///   - no request context: `RedundantResponse`;
        ///   - `req_id_hash` of zero: `ResponseHashZero`;
        ///   - routing fails: `ExpiredResponse`.
        async fn on_read(
            this: &Arc<ConnectionImpl<C, T, R, U>>,
            ev_subs: &mut E,
            frame: bytes::Bytes,
            tx_msg: &flume::Sender<msg::RecvMsg>,
        ) {
            let parsed = this.codec.decode_inbound(&frame);
            let (header, payload_span) = match parsed {
                Ok(x) => x,
                Err(e) => {
                    ev_subs.on_inbound_error(InboundError::InboundDecodeError(e, frame));
                    return;
                }
            };

            let h = InboundBody { buffer: frame, payload: payload_span, codec: this.codec.clone() };
            match header {
                InboundFrameType::Notify { .. } | InboundFrameType::Request { .. } => {
                    let (msg, disabled) = match header {
                        InboundFrameType::Notify { method } => (
                            msg::RecvMsg::Notify(msg::Notify { h, method, sender: this.clone() }),
                            this.features.contains(Feature::NO_RECEIVE_NOTIFY),
                        ),
                        InboundFrameType::Request { method, req_id } => (
                            msg::RecvMsg::Request(msg::Request {
                                body: Some((msg::RequestInner { h, method, req_id }, this.clone())),
                            }),
                            this.features.contains(Feature::NO_RECEIVE_REQUEST),
                        ),
                        // The outer arm only admits `Notify` and `Request`.
                        _ => unreachable!(),
                    };

                    if disabled {
                        ev_subs.on_inbound_error(InboundError::DisabledInbound(msg));
                    } else {
                        tx_msg.send_async(msg).await.ok();
                    }
                }

                InboundFrameType::Response { req_id, req_id_hash, is_error } => {
                    let response = msg::Response { h, is_error, req_id };

                    let Some(reqs) = this.reqs.get_req_con() else {
                        ev_subs.on_inbound_error(InboundError::RedundantResponse(response));
                        return;
                    };

                    let Some(req_id_hash) = NonZeroU64::new(req_id_hash) else {
                        ev_subs.on_inbound_error(InboundError::ResponseHashZero(response));
                        return;
                    };

                    if let Err(msg) = reqs.route_response(req_id_hash, response) {
                        ev_subs.on_inbound_error(InboundError::ExpiredResponse(msg));
                    }
                }
            }
        }
    }

    /// Locks the connection's writer and rebinds `$ident` to a pinned mutable
    /// reference to it, for the `AsyncFrameWrite` calls that take `Pin<&mut _>`.
    /// The lock guard stays alive for the rest of the enclosing scope.
    macro_rules! pin {
        ($this:ident, $ident:ident) => {
            let mut $ident = $this.write().lock().await;
            // SAFETY (review, unverified): the writer lives inline in `ConnectionImpl`
            // (`write: AsyncMutex<T>`). `ConnectionImpl` is `!Unpin` via its
            // `PhantomPinned` field, and the callers visible here reach it through an
            // `Arc`. This is sound only if no code path ever moves a `ConnectionImpl`
            // out of its `Arc` (e.g. `Arc::try_unwrap`) — TODO confirm.
            let mut $ident = unsafe { std::pin::Pin::new_unchecked(&mut *$ident) };
        };
    }

    impl dyn super::Connection {
        /// Encodes `error` as a predefined error response to `recv` and writes it to
        /// the transport (no flush).
        ///
        /// `buf.prepare()` is called first (presumably resetting it so callers can
        /// reuse one buffer).
        pub(crate) async fn __send_err_predef(
            &self,
            buf: &mut super::WriteBuffer,
            recv: &msg::RequestInner,
            error: &codec::PredefinedResponseError,
        ) -> Result<(), super::SendError> {
            buf.prepare();

            self.codec().encode_response_predefined(recv.req_id(), error, &mut buf.value)?;
            self.__write_raw(&mut buf.value).await?;
            Ok(())
        }

        /// Writes `buf` as one frame.
        ///
        /// Announces the frame length with `begin_write_frame`, then polls
        /// `poll_write` until the `FrameReader` over `buf` is drained. The writer
        /// lock is held for the whole frame, so concurrent writers cannot interleave
        /// within it. Does not flush.
        pub(crate) async fn __write_raw(&self, buf: &mut BytesMut) -> std::io::Result<()> {
            pin!(self, write);

            write.as_mut().begin_write_frame(buf.len())?;
            let mut reader = FrameReader::new(buf);

            while !reader.is_empty() {
                poll_fn(|cx| write.as_mut().poll_write(cx, &mut reader)).await?;
            }

            Ok(())
        }

        /// Flushes the underlying writer while holding the writer lock.
        pub(crate) async fn __flush(&self) -> std::io::Result<()> {
            pin!(self, write);
            poll_fn(|cx| write.as_mut().poll_flush(cx)).await
        }

        /// Closes the underlying writer while holding the writer lock.
        pub(crate) async fn __close(&self) -> std::io::Result<()> {
            pin!(self, write);
            poll_fn(|cx| write.as_mut().poll_close(cx)).await
        }

        /// Returns `true` once every receiver of the driver directive channel has
        /// been dropped (e.g. after the driver's `execute` drops `rx_msg`).
        pub(crate) fn __is_disconnected(&self) -> bool {
            self.tx_drive().is_disconnected()
        }
    }
}