// rsocket_rust/core/client.rs

1use std::marker::PhantomData;
2use std::net::SocketAddr;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::time::Duration;
6
7use async_trait::async_trait;
8use futures::{future, FutureExt, Sink, SinkExt, Stream, StreamExt};
9use tokio::sync::{mpsc, Mutex, Notify};
10
11use crate::error::{RSocketError, ERR_CONN_CLOSED};
12use crate::frame::{self, Frame};
13use crate::payload::{Payload, SetupPayload, SetupPayloadBuilder};
14use crate::runtime;
15use crate::spi::{ClientResponder, Flux, RSocket};
16use crate::transport::{
17    self, Connection, DuplexSocket, FrameSink, FrameStream, Splitter, Transport,
18};
19use crate::Result;
20
21#[derive(Clone)]
22pub struct Client {
23    closed: Arc<Notify>,
24    socket: DuplexSocket,
25    closing: mpsc::Sender<()>,
26}
27
28pub struct ClientBuilder<T, C> {
29    transport: Option<T>,
30    setup: SetupPayloadBuilder,
31    responder: Option<ClientResponder>,
32    closer: Option<Box<dyn FnMut() + Send + Sync>>,
33    mtu: usize,
34    _c: PhantomData<C>,
35}
36
37impl<T, C> ClientBuilder<T, C>
38where
39    T: Send + Sync + Transport<Conn = C>,
40    C: Send + Sync + Connection,
41{
42    pub(crate) fn new() -> ClientBuilder<T, C> {
43        ClientBuilder {
44            transport: None,
45            responder: None,
46            setup: SetupPayload::builder(),
47            closer: None,
48            mtu: 0,
49            _c: PhantomData,
50        }
51    }
52
53    pub fn fragment(mut self, mtu: usize) -> Self {
54        if mtu > 0 && mtu < transport::MIN_MTU {
55            warn!("invalid fragment mtu: at least {}!", transport::MIN_MTU)
56        } else {
57            self.mtu = mtu;
58        }
59        self
60    }
61
62    pub fn transport(mut self, transport: T) -> Self {
63        self.transport = Some(transport);
64        self
65    }
66
67    pub fn setup(mut self, setup: Payload) -> Self {
68        let (d, m) = setup.split();
69        self.setup = self.setup.set_data_bytes(d);
70        self.setup = self.setup.set_metadata_bytes(m);
71        self
72    }
73
74    pub fn keepalive(
75        mut self,
76        tick_period: Duration,
77        ack_timeout: Duration,
78        missed_acks: u64,
79    ) -> Self {
80        self.setup = self
81            .setup
82            .set_keepalive(tick_period, ack_timeout, missed_acks);
83        self
84    }
85
86    pub fn mime_type(
87        mut self,
88        metadata_mime_type: impl Into<String>,
89        data_mime_type: impl Into<String>,
90    ) -> Self {
91        self = self.metadata_mime_type(metadata_mime_type);
92        self = self.data_mime_type(data_mime_type);
93        self
94    }
95
96    pub fn data_mime_type(mut self, mime_type: impl Into<String>) -> Self {
97        self.setup = self.setup.set_data_mime_type(mime_type);
98        self
99    }
100
101    pub fn metadata_mime_type(mut self, mime_type: impl Into<String>) -> Self {
102        self.setup = self.setup.set_metadata_mime_type(mime_type);
103        self
104    }
105
106    pub fn acceptor(mut self, acceptor: ClientResponder) -> Self {
107        self.responder = Some(acceptor);
108        self
109    }
110
111    pub fn on_close(mut self, callback: Box<dyn FnMut() + Sync + Send>) -> Self {
112        self.closer = Some(callback);
113        self
114    }
115}
116
117impl<T, C> ClientBuilder<T, C>
118where
119    T: Send + Sync + Transport<Conn = C> + 'static,
120    C: Send + Sync + Connection + 'static,
121{
122    pub async fn start(mut self) -> Result<Client> {
123        let tp: T = self.transport.take().expect("missint transport");
124
125        let splitter = if self.mtu == 0 {
126            None
127        } else {
128            Some(Splitter::new(self.mtu))
129        };
130
131        let (snd_tx, mut snd_rx) = mpsc::unbounded_channel::<Frame>();
132        let cloned_snd_tx = snd_tx.clone();
133        let mut socket = DuplexSocket::new(1, snd_tx, splitter).await;
134
135        let mut cloned_socket = socket.clone();
136
137        if let Some(f) = self.responder {
138            let responder = f();
139            socket.bind_responder(responder).await;
140        }
141
142        let conn = tp.connect().await?;
143        let (mut sink, mut stream) = conn.split();
144
145        let setup = self.setup.build();
146
147        // begin write loop
148        let tick_period = setup.keepalive_interval();
149        runtime::spawn(async move {
150            loop {
151                // send keepalive if timeout
152                match tokio::time::timeout(tick_period, snd_rx.recv()).await {
153                    Ok(Some(frame)) => {
154                        if let frame::Body::Error(e) = frame.get_body_ref() {
155                            if e.get_code() == ERR_CONN_CLOSED {
156                                break;
157                            }
158                        }
159                        if let Err(e) = sink.send(frame).await {
160                            error!("write frame failed: {}", e);
161                            break;
162                        }
163                    }
164                    Ok(None) => break,
165                    Err(_) => {
166                        // keepalive
167                        let keepalive_frame =
168                            frame::Keepalive::builder(0, Frame::FLAG_RESPOND).build();
169                        if let Err(e) = sink.send(keepalive_frame).await {
170                            error!("write frame failed: {}", e);
171                            break;
172                        }
173                    }
174                }
175            }
176        });
177
178        // begin read loop
179        let closer = self.closer.take();
180        let close_notify = Arc::new(Notify::new());
181        let close_notify_clone = close_notify.clone();
182        let (closing, mut closing_rx) = mpsc::channel::<()>(1);
183
184        let (read_tx, mut read_rx) = mpsc::unbounded_channel::<Frame>();
185
186        // read frames from stream, then writes into channel
187        runtime::spawn(async move {
188            loop {
189                tokio::select! {
190                    res = stream.next() => {
191                        match res {
192                            Some(next) => match next {
193                                Ok(frame) => {
194                                    if let Err(e) = read_tx.send(frame) {
195                                        error!("forward frame failed: {}", e);
196                                        break;
197                                    }
198                                }
199                                Err(e) => {
200                                    error!("read frame failed: {}", e);
201                                    break;
202                                }
203                            }
204                            None => break,
205                        }
206                    }
207                    _ = closing_rx.recv() => {
208                        break
209                    }
210                }
211            }
212        });
213
214        // process frames
215        runtime::spawn(async move {
216            while let Some(next) = read_rx.recv().await {
217                if let Err(e) = cloned_socket.dispatch(next, None).await {
218                    error!("dispatch frame failed: {}", e);
219                    break;
220                }
221            }
222
223            // workaround: send a notify frame that the connection has been closed.
224            let close_frame = frame::Error::builder(0, 0)
225                .set_code(ERR_CONN_CLOSED)
226                .build();
227            if let Err(e) = cloned_snd_tx.send(close_frame) {
228                debug!("send close notify frame failed: {}", e);
229            }
230
231            // notify client closed
232            close_notify_clone.notify_one();
233
234            // invoke on_close handler
235            if let Some(mut invoke) = closer {
236                invoke();
237            }
238        });
239
240        socket.setup(setup).await?;
241
242        Ok(Client::new(socket, close_notify, closing))
243    }
244}
245
246impl Client {
247    fn new(socket: DuplexSocket, closed: Arc<Notify>, closing: mpsc::Sender<()>) -> Client {
248        Client {
249            socket,
250            closed,
251            closing,
252        }
253    }
254
255    pub async fn wait_for_close(self) {
256        self.closed.notified().await
257    }
258}
259
260#[async_trait]
261impl RSocket for Client {
262    async fn metadata_push(&self, req: Payload) -> Result<()> {
263        self.socket.metadata_push(req).await
264    }
265
266    async fn fire_and_forget(&self, req: Payload) -> Result<()> {
267        self.socket.fire_and_forget(req).await
268    }
269
270    async fn request_response(&self, req: Payload) -> Result<Option<Payload>> {
271        self.socket.request_response(req).await
272    }
273
274    fn request_stream(&self, req: Payload) -> Flux<Result<Payload>> {
275        self.socket.request_stream(req)
276    }
277
278    fn request_channel(&self, reqs: Flux<Result<Payload>>) -> Flux<Result<Payload>> {
279        self.socket.request_channel(reqs)
280    }
281}