rsocket_rust/transport/
socket.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use async_stream::stream;
6use async_trait::async_trait;
7use bytes::{Buf, BufMut, Bytes, BytesMut};
8use dashmap::{mapref::entry::Entry, DashMap};
9use futures::future::{AbortHandle, Abortable};
10use futures::{Sink, SinkExt, Stream, StreamExt};
11use tokio::sync::{mpsc, oneshot, RwLock};
12
13use super::fragmentation::{Joiner, Splitter};
14use super::misc::{debug_frame, Counter, StreamID};
15use super::spi::*;
16use crate::error::{self, RSocketError};
17use crate::frame::{self, Body, Frame};
18use crate::payload::{Payload, SetupPayload};
19use crate::spi::{Flux, RSocket, ServerResponder};
20use crate::utils::EmptyRSocket;
21use crate::{runtime, Result};
22
23#[derive(Clone)]
24pub(crate) struct DuplexSocket {
25    seq: StreamID,
26    responder: Responder,
27    tx: mpsc::UnboundedSender<Frame>,
28    handlers: Arc<DashMap<u32, Handler>>,
29    canceller: mpsc::Sender<u32>,
30    splitter: Option<Splitter>,
31    joiners: Arc<DashMap<u32, Joiner>>,
32    /// AbortHandles for streams and channels associated by sid
33    abort_handles: Arc<DashMap<u32, AbortHandle>>,
34}
35
36#[derive(Clone)]
37struct Responder {
38    inner: Arc<RwLock<Box<dyn RSocket>>>,
39}
40
41#[derive(Debug)]
42enum Handler {
43    ReqRR(oneshot::Sender<Result<Option<Payload>>>),
44    ResRR(Counter),
45    ReqRS(mpsc::Sender<Result<Payload>>),
46    ReqRC(mpsc::Sender<Result<Payload>>),
47}
48
49impl DuplexSocket {
50    pub(crate) async fn new(
51        first_stream_id: u32,
52        tx: mpsc::UnboundedSender<Frame>,
53        splitter: Option<Splitter>,
54    ) -> DuplexSocket {
55        let (canceller_tx, canceller_rx) = mpsc::channel::<u32>(32);
56        let socket = DuplexSocket {
57            seq: StreamID::from(first_stream_id),
58            tx,
59            canceller: canceller_tx,
60            responder: Responder::new(),
61            handlers: Arc::new(DashMap::new()),
62            joiners: Arc::new(DashMap::new()),
63            splitter,
64            abort_handles: Arc::new(DashMap::new()),
65        };
66
67        let cloned_socket = socket.clone();
68
69        runtime::spawn(async move {
70            cloned_socket.loop_canceller(canceller_rx).await;
71        });
72
73        socket
74    }
75
76    pub(crate) async fn setup(&mut self, setup: SetupPayload) -> Result<()> {
77        let mut bu = frame::Setup::builder(0, 0);
78        if let Some(s) = setup.data_mime_type() {
79            bu = bu.set_mime_data(s);
80        }
81        if let Some(s) = setup.metadata_mime_type() {
82            bu = bu.set_mime_metadata(s);
83        }
84        bu = bu.set_keepalive(setup.keepalive_interval());
85        bu = bu.set_lifetime(setup.keepalive_lifetime());
86        let (d, m) = setup.split();
87        if let Some(b) = d {
88            bu = bu.set_data(b);
89        }
90        if let Some(b) = m {
91            bu = bu.set_metadata(b);
92        }
93        self.tx.send(bu.build()).map_err(|e| e.into())
94    }
95
96    #[inline]
97    async fn register_handler(&self, sid: u32, handler: Handler) {
98        self.handlers.insert(sid, handler);
99    }
100
101    #[inline]
102    async fn loop_canceller(&self, mut rx: mpsc::Receiver<u32>) {
103        while let Some(sid) = rx.recv().await {
104            self.handlers.remove(&sid);
105        }
106    }
107
108    pub(crate) async fn dispatch(
109        &mut self,
110        frame: Frame,
111        acceptor: Option<&ServerResponder>,
112    ) -> Result<()> {
113        if let Some(frame) = self.join_frame(frame).await {
114            self.process_once(frame, acceptor).await;
115        }
116        Ok(())
117    }
118
119    #[inline]
120    async fn process_once(&mut self, msg: Frame, acceptor: Option<&ServerResponder>) {
121        let sid = msg.get_stream_id();
122        let flag = msg.get_flag();
123        debug_frame(false, &msg);
124        match msg.get_body() {
125            Body::Setup(v) => {
126                if let Err(e) = self
127                    .on_setup(acceptor, sid, flag, SetupPayload::from(v))
128                    .await
129                {
130                    let errmsg = format!("{}", e);
131                    let sending = frame::Error::builder(0, 0)
132                        .set_code(error::ERR_REJECT_SETUP)
133                        .set_data(Bytes::from(errmsg))
134                        .build();
135                    if self.tx.send(sending).is_err() {
136                        error!("Reject setup failed");
137                    }
138                }
139            }
140            Body::Resume(v) => {
141                // TODO: support resume
142            }
143            Body::ResumeOK(v) => {
144                // TODO: support resume ok
145            }
146            Body::MetadataPush(v) => {
147                let input = Payload::from(v);
148                self.on_metadata_push(input).await;
149            }
150            Body::RequestFNF(v) => {
151                let input = Payload::from(v);
152                self.on_fire_and_forget(sid, input).await;
153            }
154            Body::RequestResponse(v) => {
155                let input = Payload::from(v);
156                self.on_request_response(sid, flag, input).await;
157            }
158            Body::RequestStream(v) => {
159                let input = Payload::from(v);
160                self.on_request_stream(sid, flag, input).await;
161            }
162            Body::RequestChannel(v) => {
163                let input = Payload::from(v);
164                self.on_request_channel(sid, flag, input).await;
165            }
166            Body::Payload(v) => {
167                let input = Payload::from(v);
168                self.on_payload(sid, flag, input).await;
169            }
170            Body::Keepalive(v) => {
171                if flag & Frame::FLAG_RESPOND != 0 {
172                    debug!("got keepalive: {:?}", v);
173                    self.on_keepalive(v).await;
174                }
175            }
176            Body::RequestN(v) => {
177                // TODO: support RequestN
178            }
179            Body::Error(v) => {
180                // TODO: support error
181                self.on_error(sid, flag, v).await;
182            }
183            Body::Cancel() => {
184                self.on_cancel(sid, flag).await;
185            }
186            Body::Lease(v) => {
187                // TODO: support Lease
188            }
189        }
190    }
191
192    #[inline]
193    async fn join_frame(&self, input: Frame) -> Option<Frame> {
194        let (is_follow, is_payload) = input.is_followable_or_payload();
195        if !is_follow {
196            return Some(input);
197        }
198        let sid = input.get_stream_id();
199        if input.get_flag() & Frame::FLAG_FOLLOW != 0 {
200            // TODO: check conflict
201            self.joiners
202                .entry(sid)
203                .or_insert_with(Joiner::new)
204                .push(input);
205            return None;
206        }
207
208        if !is_payload {
209            return Some(input);
210        }
211
212        match self.joiners.remove(&sid) {
213            None => Some(input),
214            Some((_, mut joiner)) => {
215                joiner.push(input);
216                let flag = joiner.get_flag();
217                let first = joiner.first();
218                match &first.body {
219                    frame::Body::RequestResponse(_) => {
220                        let pa: Payload = joiner.into();
221                        let result = frame::RequestResponse::builder(sid, flag)
222                            .set_all(pa.split())
223                            .build();
224                        Some(result)
225                    }
226                    frame::Body::RequestStream(b) => {
227                        let n = b.get_initial_request_n();
228                        let pa: Payload = joiner.into();
229                        let result = frame::RequestStream::builder(sid, flag)
230                            .set_initial_request_n(n)
231                            .set_all(pa.split())
232                            .build();
233                        Some(result)
234                    }
235                    frame::Body::RequestFNF(_) => {
236                        let pa: Payload = joiner.into();
237                        let result = frame::RequestFNF::builder(sid, flag)
238                            .set_all(pa.split())
239                            .build();
240                        Some(result)
241                    }
242                    frame::Body::RequestChannel(b) => {
243                        let n = b.get_initial_request_n();
244                        let pa: Payload = joiner.into();
245                        let result = frame::RequestChannel::builder(sid, flag)
246                            .set_initial_request_n(n)
247                            .set_all(pa.split())
248                            .build();
249                        Some(result)
250                    }
251                    frame::Body::Payload(b) => {
252                        let pa: Payload = joiner.into();
253                        let result = frame::Payload::builder(sid, flag)
254                            .set_all(pa.split())
255                            .build();
256                        Some(result)
257                    }
258                    _ => unreachable!(),
259                }
260            }
261        }
262    }
263
264    #[inline]
265    async fn on_error(&mut self, sid: u32, flag: u16, input: frame::Error) {
266        self.joiners.remove(&sid);
267        // pick handler
268        if let Some((_, handler)) = self.handlers.remove(&sid) {
269            let desc = input
270                .get_data_utf8()
271                .map(|it| it.to_string())
272                .unwrap_or_default();
273            let e = RSocketError::must_new_from_code(input.get_code(), desc);
274            match handler {
275                Handler::ReqRR(tx) => {
276                    if tx.send(Err(e.into())).is_err() {
277                        error!("respond with error for REQUEST_RESPONSE failed!");
278                    }
279                }
280                Handler::ResRR(_) => unreachable!(),
281                Handler::ReqRS(tx) => {
282                    if (tx.send(Err(e.into())).await).is_err() {
283                        error!("respond with error for REQUEST_STREAM failed!");
284                    };
285                }
286                Handler::ReqRC(tx) => {
287                    if (tx.send(Err(e.into())).await).is_err() {
288                        error!("respond with error for REQUEST_CHANNEL failed!");
289                    }
290                }
291            }
292        }
293    }
294
295    #[inline]
296    async fn on_cancel(&mut self, sid: u32, _flag: u16) {
297        if let Some((sid, abort_handle)) = self.abort_handles.remove(&sid) {
298            abort_handle.abort();
299        }
300        self.joiners.remove(&sid);
301        if let Some((_, handler)) = self.handlers.remove(&sid) {
302            let e: Result<_> =
303                Err(RSocketError::RequestCancelled("request has been cancelled".into()).into());
304            match handler {
305                Handler::ReqRR(sender) => {
306                    info!("REQUEST_RESPONSE {} cancelled!", sid);
307                    if sender.send(e).is_err() {
308                        error!("notify cancel for REQUEST_RESPONSE failed: sid={}", sid);
309                    }
310                }
311                Handler::ResRR(c) => {
312                    let lefts = c.count_down();
313                    info!("REQUEST_RESPONSE {} cancelled: lefts={}", sid, lefts);
314                }
315                Handler::ReqRS(sender) => {
316                    info!("REQUEST_STREAM {} cancelled!", sid);
317                }
318                Handler::ReqRC(sender) => {
319                    info!("REQUEST_CHANNEL {} cancelled!", sid);
320                }
321            };
322        }
323    }
324
325    #[inline]
326    async fn on_payload(&mut self, sid: u32, flag: u16, input: Payload) {
327        match self.handlers.entry(sid) {
328            Entry::Occupied(o) => {
329                match o.get() {
330                    Handler::ReqRR(_) => match o.remove() {
331                        Handler::ReqRR(sender) => {
332                            if flag & Frame::FLAG_NEXT != 0 {
333                                if sender.send(Ok(Some(input))).is_err() {
334                                    error!("response successful payload for REQUEST_RESPONSE failed: sid={}",sid);
335                                }
336                            } else if sender.send(Ok(None)).is_err() {
337                                error!("response successful payload for REQUEST_RESPONSE failed: sid={}",sid);
338                            }
339                        }
340                        _ => unreachable!(),
341                    },
342                    Handler::ResRR(c) => unreachable!(),
343                    Handler::ReqRS(sender) => {
344                        if flag & Frame::FLAG_NEXT != 0 {
345                            if sender.is_closed() {
346                                self.send_cancel_frame(sid);
347                            } else if let Err(e) = sender.send(Ok(input)).await {
348                                error!(
349                                    "response successful payload for REQUEST_STREAM failed: sid={}",
350                                    sid
351                                );
352                                self.send_cancel_frame(sid);
353                            }
354                        }
355                        if flag & Frame::FLAG_COMPLETE != 0 {
356                            o.remove();
357                        }
358                    }
359                    Handler::ReqRC(sender) => {
360                        // TODO: support channel
361                        if flag & Frame::FLAG_NEXT != 0 {
362                            if sender.is_closed() {
363                                self.send_cancel_frame(sid);
364                            } else if (sender.clone().send(Ok(input)).await).is_err() {
365                                error!("response successful payload for REQUEST_CHANNEL failed: sid={}",sid);
366                                self.send_cancel_frame(sid);
367                            }
368                        }
369                        if flag & Frame::FLAG_COMPLETE != 0 {
370                            o.remove();
371                        }
372                    }
373                }
374            }
375            Entry::Vacant(_) => warn!("invalid payload id {}: no such request!", sid),
376        }
377    }
378
379    #[inline]
380    fn send_cancel_frame(&self, sid: u32) {
381        let cancel_frame = frame::Cancel::builder(sid, Frame::FLAG_COMPLETE).build();
382        if let Err(e) = self.tx.send(cancel_frame) {
383            error!("Sending CANCEL frame failed: sid={}, reason: {}", sid, e);
384        }
385    }
386
387    pub(crate) async fn bind_responder(&self, responder: Box<dyn RSocket>) {
388        self.responder.set(responder).await;
389    }
390
391    #[inline]
392    async fn on_setup(
393        &self,
394        acceptor: Option<&ServerResponder>,
395        sid: u32,
396        flag: u16,
397        setup: SetupPayload,
398    ) -> Result<()> {
399        match acceptor {
400            None => {
401                self.responder.set(Box::new(EmptyRSocket)).await;
402                Ok(())
403            }
404            Some(gen) => match gen(setup, Box::new(self.clone())) {
405                Ok(it) => {
406                    self.responder.set(it).await;
407                    Ok(())
408                }
409                Err(e) => Err(e),
410            },
411        }
412    }
413
414    #[inline]
415    async fn on_fire_and_forget(&mut self, sid: u32, input: Payload) {
416        if let Err(e) = self.responder.fire_and_forget(input).await {
417            error!("respond fire_and_forget failed: {:?}", e);
418        }
419    }
420
421    #[inline]
422    async fn on_request_response(&mut self, sid: u32, _flag: u16, input: Payload) {
423        let responder = self.responder.clone();
424        let canceller = self.canceller.clone();
425        let mut tx = self.tx.clone();
426        let splitter = self.splitter.clone();
427        let counter = Counter::new(2);
428        self.register_handler(sid, Handler::ResRR(counter.clone()))
429            .await;
430        runtime::spawn(async move {
431            // TODO: use future select
432            let result = responder.request_response(input).await;
433            if counter.count_down() == 0 {
434                // cancelled
435                return;
436            }
437
438            // async remove canceller
439            if (canceller.send(sid).await).is_err() {
440                error!("Send canceller failed: sid={}", sid);
441            }
442
443            match result {
444                Ok(Some(res)) => {
445                    Self::try_send_payload(
446                        &splitter,
447                        &mut tx,
448                        sid,
449                        res,
450                        Frame::FLAG_NEXT | Frame::FLAG_COMPLETE,
451                    )
452                    .await;
453                }
454                Ok(None) => {
455                    Self::try_send_complete(&mut tx, sid, Frame::FLAG_COMPLETE).await;
456                }
457                Err(e) => {
458                    let sending = frame::Error::builder(sid, 0)
459                        .set_code(error::ERR_APPLICATION)
460                        .set_data(Bytes::from(e.to_string()))
461                        .build();
462                    if let Err(e) = tx.send(sending) {
463                        error!("respond REQUEST_RESPONSE failed: {}", e);
464                    }
465                }
466            };
467        });
468    }
469
470    #[inline]
471    async fn on_request_stream(&self, sid: u32, flag: u16, input: Payload) {
472        let responder = self.responder.clone();
473        let mut tx = self.tx.clone();
474        let splitter = self.splitter.clone();
475        let abort_handles = self.abort_handles.clone();
476        runtime::spawn(async move {
477            let (abort_handle, abort_registration) = AbortHandle::new_pair();
478            abort_handles.insert(sid, abort_handle);
479            let mut payloads = Abortable::new(responder.request_stream(input), abort_registration);
480            while let Some(next) = payloads.next().await {
481                match next {
482                    Ok(it) => {
483                        Self::try_send_payload(&splitter, &mut tx, sid, it, Frame::FLAG_NEXT).await;
484                    }
485                    Err(e) => {
486                        let sending = frame::Error::builder(sid, 0)
487                            .set_code(error::ERR_APPLICATION)
488                            .set_data(Bytes::from(format!("{}", e)))
489                            .build();
490                        tx.send(sending).expect("Send stream response failed");
491                    }
492                };
493            }
494            abort_handles.remove(&sid);
495            let complete = frame::Payload::builder(sid, Frame::FLAG_COMPLETE).build();
496            tx.send(complete)
497                .expect("Send stream complete response failed");
498        });
499    }
500
501    #[inline]
502    async fn on_request_channel(&self, sid: u32, flag: u16, first: Payload) {
503        let responder = self.responder.clone();
504        let tx = self.tx.clone();
505        let (sender, mut receiver) = mpsc::channel::<Result<Payload>>(32);
506        sender.send(Ok(first)).await.expect("Send failed!");
507        self.register_handler(sid, Handler::ReqRC(sender)).await;
508        let abort_handles = self.abort_handles.clone();
509        runtime::spawn(async move {
510            // respond client channel
511            let outputs = responder.request_channel(Box::pin(stream! {
512                while let Some(it) = receiver.recv().await{
513                    yield it;
514                }
515            }));
516            let (abort_handle, abort_registration) = AbortHandle::new_pair();
517            abort_handles.insert(sid, abort_handle);
518            let mut outputs = Abortable::new(outputs, abort_registration);
519
520            // TODO: support custom RequestN.
521            let request_n = frame::RequestN::builder(sid, 0).build();
522
523            if let Err(e) = tx.send(request_n) {
524                error!("respond REQUEST_N failed: {}", e);
525            }
526
527            while let Some(next) = outputs.next().await {
528                let sending = match next {
529                    Ok(payload) => {
530                        let (data, metadata) = payload.split();
531                        let mut bu = frame::Payload::builder(sid, Frame::FLAG_NEXT);
532                        if let Some(b) = data {
533                            bu = bu.set_data(b);
534                        }
535                        if let Some(b) = metadata {
536                            bu = bu.set_metadata(b);
537                        }
538                        bu.build()
539                    }
540                    Err(e) => frame::Error::builder(sid, 0)
541                        .set_code(error::ERR_APPLICATION)
542                        .set_data(Bytes::from(format!("{}", e)))
543                        .build(),
544                };
545                tx.send(sending).expect("Send failed!");
546            }
547            abort_handles.remove(&sid);
548            let complete = frame::Payload::builder(sid, Frame::FLAG_COMPLETE).build();
549            if let Err(e) = tx.send(complete) {
550                error!("complete REQUEST_CHANNEL failed: {}", e);
551            }
552        });
553    }
554
555    #[inline]
556    async fn on_metadata_push(&mut self, input: Payload) {
557        if let Err(e) = self.responder.metadata_push(input).await {
558            error!("response metadata_push failed: {:?}", e);
559        }
560    }
561
562    #[inline]
563    async fn on_keepalive(&mut self, keepalive: frame::Keepalive) {
564        let (data, _) = keepalive.split();
565        let mut sending = frame::Keepalive::builder(0, 0);
566        if let Some(b) = data {
567            sending = sending.set_data(b);
568        }
569        if let Err(e) = self.tx.send(sending.build()) {
570            error!("respond KEEPALIVE failed: {}", e);
571        }
572    }
573
574    #[inline]
575    async fn try_send_channel(
576        splitter: &Option<Splitter>,
577        tx: &mut mpsc::UnboundedSender<Frame>,
578        sid: u32,
579        res: Payload,
580        flag: u16,
581    ) {
582        // TODO
583        match splitter {
584            Some(sp) => {
585                let mut cuts: usize = 0;
586                let mut prev: Option<Payload> = None;
587                for next in sp.cut(res, 4) {
588                    if let Some(cur) = prev.take() {
589                        let sending = if cuts == 1 {
590                            frame::RequestChannel::builder(sid, flag | Frame::FLAG_FOLLOW)
591                                .set_all(cur.split())
592                                .build()
593                        } else {
594                            frame::Payload::builder(sid, Frame::FLAG_FOLLOW)
595                                .set_all(cur.split())
596                                .build()
597                        };
598                        // send frame
599                        if let Err(e) = tx.send(sending) {
600                            error!("send request_channel failed: {}", e);
601                            return;
602                        }
603                    }
604                    prev = Some(next);
605                    cuts += 1;
606                }
607
608                let sending = if cuts == 0 {
609                    frame::RequestChannel::builder(sid, flag).build()
610                } else if cuts == 1 {
611                    frame::RequestChannel::builder(sid, flag)
612                        .set_all(prev.unwrap().split())
613                        .build()
614                } else {
615                    frame::Payload::builder(sid, 0)
616                        .set_all(prev.unwrap().split())
617                        .build()
618                };
619                // send frame
620                if let Err(e) = tx.send(sending) {
621                    error!("send request_channel failed: {}", e);
622                }
623            }
624            None => {
625                let sending = frame::RequestChannel::builder(sid, flag)
626                    .set_all(res.split())
627                    .build();
628                if let Err(e) = tx.send(sending) {
629                    error!("send request_channel failed: {}", e);
630                }
631            }
632        }
633    }
634
635    #[inline]
636    async fn try_send_complete(tx: &mut mpsc::UnboundedSender<Frame>, sid: u32, flag: u16) {
637        let sending = frame::Payload::builder(sid, flag).build();
638        if let Err(e) = tx.send(sending) {
639            error!("respond failed: {}", e);
640        }
641    }
642
643    #[inline]
644    async fn try_send_payload(
645        splitter: &Option<Splitter>,
646        tx: &mut mpsc::UnboundedSender<Frame>,
647        sid: u32,
648        res: Payload,
649        flag: u16,
650    ) {
651        match splitter {
652            Some(sp) => {
653                let mut cuts: usize = 0;
654                let mut prev: Option<Payload> = None;
655                for next in sp.cut(res, 0) {
656                    if let Some(cur) = prev.take() {
657                        let sending = if cuts == 1 {
658                            frame::Payload::builder(sid, flag | Frame::FLAG_FOLLOW)
659                                .set_all(cur.split())
660                                .build()
661                        } else {
662                            frame::Payload::builder(sid, Frame::FLAG_FOLLOW)
663                                .set_all(cur.split())
664                                .build()
665                        };
666                        // send frame
667                        if let Err(e) = tx.send(sending) {
668                            error!("send payload failed: {}", e);
669                            return;
670                        }
671                    }
672                    prev = Some(next);
673                    cuts += 1;
674                }
675
676                let sending = if cuts == 0 {
677                    frame::Payload::builder(sid, flag).build()
678                } else {
679                    frame::Payload::builder(sid, flag)
680                        .set_all(prev.unwrap().split())
681                        .build()
682                };
683                // send frame
684                if let Err(e) = tx.send(sending) {
685                    error!("send payload failed: {}", e);
686                }
687            }
688            None => {
689                let sending = frame::Payload::builder(sid, flag)
690                    .set_all(res.split())
691                    .build();
692                if let Err(e) = tx.send(sending) {
693                    error!("respond failed: {}", e);
694                }
695            }
696        }
697    }
698}
699
700#[async_trait]
701impl RSocket for DuplexSocket {
702    async fn metadata_push(&self, req: Payload) -> Result<()> {
703        let sid = self.seq.next();
704        let tx = self.tx.clone();
705        let (_d, m) = req.split();
706        let mut bu = frame::MetadataPush::builder(sid, 0);
707        if let Some(b) = m {
708            bu = bu.set_metadata(b);
709        }
710        tx.send(bu.build())?;
711        Ok(())
712    }
713
714    async fn fire_and_forget(&self, req: Payload) -> Result<()> {
715        let sid = self.seq.next();
716        let tx = self.tx.clone();
717        let splitter = self.splitter.clone();
718
719        match splitter {
720            Some(sp) => {
721                let mut cuts: usize = 0;
722                let mut prev: Option<Payload> = None;
723                for next in sp.cut(req, 0) {
724                    if let Some(cur) = prev.take() {
725                        let sending = if cuts == 1 {
726                            // make first frame as request_fnf.
727                            frame::RequestFNF::builder(sid, Frame::FLAG_FOLLOW)
728                                .set_all(cur.split())
729                                .build()
730                        } else {
731                            // make other frames as payload.
732                            frame::Payload::builder(sid, Frame::FLAG_FOLLOW)
733                                .set_all(cur.split())
734                                .build()
735                        };
736                        // send frame
737                        tx.send(sending)?;
738                    }
739                    prev = Some(next);
740                    cuts += 1;
741                }
742
743                let sending = if cuts == 0 {
744                    frame::RequestFNF::builder(sid, 0).build()
745                } else if cuts == 1 {
746                    frame::RequestFNF::builder(sid, 0)
747                        .set_all(prev.unwrap().split())
748                        .build()
749                } else {
750                    frame::Payload::builder(sid, 0)
751                        .set_all(prev.unwrap().split())
752                        .build()
753                };
754                // send frame
755                tx.send(sending)?;
756            }
757            None => {
758                let sending = frame::RequestFNF::builder(sid, 0)
759                    .set_all(req.split())
760                    .build();
761                tx.send(sending)?;
762            }
763        }
764        Ok(())
765    }
766
767    async fn request_response(&self, req: Payload) -> Result<Option<Payload>> {
768        let (tx, rx) = oneshot::channel::<Result<Option<Payload>>>();
769        let sid = self.seq.next();
770        let handlers = self.handlers.clone();
771        let sender = self.tx.clone();
772
773        let splitter = self.splitter.clone();
774
775        runtime::spawn(async move {
776            // register handler
777            handlers.insert(sid, Handler::ReqRR(tx));
778            match splitter {
779                Some(sp) => {
780                    let mut cuts: usize = 0;
781                    let mut prev: Option<Payload> = None;
782                    for next in sp.cut(req, 0) {
783                        if let Some(cur) = prev.take() {
784                            let sending = if cuts == 1 {
785                                // make first frame as request_response.
786                                frame::RequestResponse::builder(sid, Frame::FLAG_FOLLOW)
787                                    .set_all(cur.split())
788                                    .build()
789                            } else {
790                                // make other frames as payload.
791                                frame::Payload::builder(sid, Frame::FLAG_FOLLOW)
792                                    .set_all(cur.split())
793                                    .build()
794                            };
795                            // send frame
796                            if let Err(e) = sender.send(sending) {
797                                error!("send request_response failed: {}", e);
798                                return;
799                            }
800                        }
801                        prev = Some(next);
802                        cuts += 1;
803                    }
804
805                    let sending = if cuts == 0 {
806                        frame::RequestResponse::builder(sid, 0).build()
807                    } else if cuts == 1 {
808                        frame::RequestResponse::builder(sid, 0)
809                            .set_all(prev.unwrap().split())
810                            .build()
811                    } else {
812                        frame::Payload::builder(sid, 0)
813                            .set_all(prev.unwrap().split())
814                            .build()
815                    };
816                    // send frame
817                    if let Err(e) = sender.send(sending) {
818                        error!("send request_response failed: {}", e);
819                    }
820                }
821                None => {
822                    // crate request frame
823                    let sending = frame::RequestResponse::builder(sid, 0)
824                        .set_all(req.split())
825                        .build();
826                    // send frame
827                    if let Err(e) = sender.send(sending) {
828                        error!("send request_response failed: {}", e);
829                    }
830                }
831            }
832        });
833        match rx.await {
834            Ok(v) => v,
835            Err(_e) => Err(RSocketError::WithDescription("request_response failed".into()).into()),
836        }
837    }
838
839    fn request_stream(&self, input: Payload) -> Flux<Result<Payload>> {
840        let sid = self.seq.next();
841        let tx = self.tx.clone();
842        // register handler
843        let (sender, mut receiver) = mpsc::channel::<Result<Payload>>(32);
844        let handlers = self.handlers.clone();
845        let splitter = self.splitter.clone();
846        runtime::spawn(async move {
847            handlers.insert(sid, Handler::ReqRS(sender));
848            match splitter {
849                Some(sp) => {
850                    let mut cuts: usize = 0;
851                    let mut prev: Option<Payload> = None;
852                    // skip 4 bytes. (initial_request_n is u32)
853                    for next in sp.cut(input, 4) {
854                        if let Some(cur) = prev.take() {
855                            let sending: Frame = if cuts == 1 {
856                                // make first frame as request_stream.
857                                frame::RequestStream::builder(sid, Frame::FLAG_FOLLOW)
858                                    .set_all(cur.split())
859                                    .build()
860                            } else {
861                                // make other frames as payload.
862                                frame::Payload::builder(sid, Frame::FLAG_FOLLOW)
863                                    .set_all(cur.split())
864                                    .build()
865                            };
866                            // send frame
867                            if let Err(e) = tx.send(sending) {
868                                error!("send request_stream failed: {}", e);
869                                return;
870                            }
871                        }
872                        prev = Some(next);
873                        cuts += 1;
874                    }
875
876                    let sending = if cuts == 0 {
877                        frame::RequestStream::builder(sid, 0).build()
878                    } else if cuts == 1 {
879                        frame::RequestStream::builder(sid, 0)
880                            .set_all(prev.unwrap().split())
881                            .build()
882                    } else {
883                        frame::Payload::builder(sid, 0)
884                            .set_all(prev.unwrap().split())
885                            .build()
886                    };
887                    // send frame
888                    if let Err(e) = tx.send(sending) {
889                        error!("send request_stream failed: {}", e);
890                    }
891                }
892                None => {
893                    let sending = frame::RequestStream::builder(sid, 0)
894                        .set_all(input.split())
895                        .build();
896                    if let Err(e) = tx.send(sending) {
897                        error!("send request_stream failed: {}", e);
898                    }
899                }
900            }
901        });
902        Box::pin(stream! {
903            while let Some(it) = receiver.recv().await{
904                yield it;
905            }
906        })
907    }
908
909    fn request_channel(&self, mut reqs: Flux<Result<Payload>>) -> Flux<Result<Payload>> {
910        let sid = self.seq.next();
911        let mut tx = self.tx.clone();
912        // register handler
913        let (sender, mut receiver) = mpsc::channel::<Result<Payload>>(32);
914        let handlers = self.handlers.clone();
915        let splitter = self.splitter.clone();
916        runtime::spawn(async move {
917            handlers.insert(sid, Handler::ReqRC(sender));
918            let mut first = true;
919            while let Some(next) = reqs.next().await {
920                match next {
921                    Ok(it) => {
922                        if first {
923                            first = false;
924                            Self::try_send_channel(&splitter, &mut tx, sid, it, Frame::FLAG_NEXT)
925                                .await
926                        } else {
927                            Self::try_send_payload(&splitter, &mut tx, sid, it, Frame::FLAG_NEXT)
928                                .await
929                        }
930                    }
931                    Err(e) => {
932                        let sending = frame::Error::builder(sid, 0)
933                            .set_code(error::ERR_APPLICATION)
934                            .set_data(Bytes::from(format!("{}", e)))
935                            .build();
936                        if let Err(e) = tx.send(sending) {
937                            error!("send REQUEST_CHANNEL failed: {}", e);
938                        }
939                    }
940                };
941            }
942            let sending = frame::Payload::builder(sid, Frame::FLAG_COMPLETE).build();
943            if let Err(e) = tx.send(sending) {
944                error!("complete REQUEST_CHANNEL failed: {}", e);
945            }
946        });
947        Box::pin(stream! {
948            while let Some(it) = receiver.recv().await{
949                yield it;
950            }
951        })
952    }
953}
954
955impl From<Box<dyn RSocket>> for Responder {
956    fn from(input: Box<dyn RSocket>) -> Responder {
957        Responder {
958            inner: Arc::new(RwLock::new(input)),
959        }
960    }
961}
962
963impl Responder {
964    fn new() -> Responder {
965        let bx = Box::new(EmptyRSocket);
966        Responder {
967            inner: Arc::new(RwLock::new(bx)),
968        }
969    }
970
971    async fn set(&self, rs: Box<dyn RSocket>) {
972        let mut w = self.inner.write().await;
973        *w = rs;
974    }
975}
976
977#[async_trait]
978impl RSocket for Responder {
979    async fn metadata_push(&self, req: Payload) -> Result<()> {
980        let inner = self.inner.read().await;
981        (*inner).metadata_push(req).await
982    }
983
984    async fn fire_and_forget(&self, req: Payload) -> Result<()> {
985        let inner = self.inner.read().await;
986        (*inner).fire_and_forget(req).await
987    }
988
989    async fn request_response(&self, req: Payload) -> Result<Option<Payload>> {
990        let inner = self.inner.read().await;
991        (*inner).request_response(req).await
992    }
993
994    fn request_stream(&self, req: Payload) -> Flux<Result<Payload>> {
995        let inner = self.inner.clone();
996        Box::pin(stream! {
997            let r = inner.read().await;
998            let mut results = (*r).request_stream(req);
999            while let Some(next) = results.next().await {
1000                yield next;
1001            }
1002        })
1003    }
1004
1005    fn request_channel(&self, reqs: Flux<Result<Payload>>) -> Flux<Result<Payload>> {
1006        let inner = self.inner.clone();
1007        Box::pin(stream! {
1008            let r = inner.read().await;
1009            let mut results = (*r).request_channel(reqs);
1010            while let Some(next) = results.next().await{
1011                yield next;
1012            }
1013        })
1014    }
1015}