1use std::marker::PhantomData;
2use std::net::SocketAddr;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::time::Duration;
6
7use async_trait::async_trait;
8use futures::{future, FutureExt, Sink, SinkExt, Stream, StreamExt};
9use tokio::sync::{mpsc, Mutex, Notify};
10
11use crate::error::{RSocketError, ERR_CONN_CLOSED};
12use crate::frame::{self, Frame};
13use crate::payload::{Payload, SetupPayload, SetupPayloadBuilder};
14use crate::runtime;
15use crate::spi::{ClientResponder, Flux, RSocket};
16use crate::transport::{
17 self, Connection, DuplexSocket, FrameSink, FrameStream, Splitter, Transport,
18};
19use crate::Result;
20
21#[derive(Clone)]
22pub struct Client {
23 closed: Arc<Notify>,
24 socket: DuplexSocket,
25 closing: mpsc::Sender<()>,
26}
27
28pub struct ClientBuilder<T, C> {
29 transport: Option<T>,
30 setup: SetupPayloadBuilder,
31 responder: Option<ClientResponder>,
32 closer: Option<Box<dyn FnMut() + Send + Sync>>,
33 mtu: usize,
34 _c: PhantomData<C>,
35}
36
37impl<T, C> ClientBuilder<T, C>
38where
39 T: Send + Sync + Transport<Conn = C>,
40 C: Send + Sync + Connection,
41{
42 pub(crate) fn new() -> ClientBuilder<T, C> {
43 ClientBuilder {
44 transport: None,
45 responder: None,
46 setup: SetupPayload::builder(),
47 closer: None,
48 mtu: 0,
49 _c: PhantomData,
50 }
51 }
52
53 pub fn fragment(mut self, mtu: usize) -> Self {
54 if mtu > 0 && mtu < transport::MIN_MTU {
55 warn!("invalid fragment mtu: at least {}!", transport::MIN_MTU)
56 } else {
57 self.mtu = mtu;
58 }
59 self
60 }
61
62 pub fn transport(mut self, transport: T) -> Self {
63 self.transport = Some(transport);
64 self
65 }
66
67 pub fn setup(mut self, setup: Payload) -> Self {
68 let (d, m) = setup.split();
69 self.setup = self.setup.set_data_bytes(d);
70 self.setup = self.setup.set_metadata_bytes(m);
71 self
72 }
73
74 pub fn keepalive(
75 mut self,
76 tick_period: Duration,
77 ack_timeout: Duration,
78 missed_acks: u64,
79 ) -> Self {
80 self.setup = self
81 .setup
82 .set_keepalive(tick_period, ack_timeout, missed_acks);
83 self
84 }
85
86 pub fn mime_type(
87 mut self,
88 metadata_mime_type: impl Into<String>,
89 data_mime_type: impl Into<String>,
90 ) -> Self {
91 self = self.metadata_mime_type(metadata_mime_type);
92 self = self.data_mime_type(data_mime_type);
93 self
94 }
95
96 pub fn data_mime_type(mut self, mime_type: impl Into<String>) -> Self {
97 self.setup = self.setup.set_data_mime_type(mime_type);
98 self
99 }
100
101 pub fn metadata_mime_type(mut self, mime_type: impl Into<String>) -> Self {
102 self.setup = self.setup.set_metadata_mime_type(mime_type);
103 self
104 }
105
106 pub fn acceptor(mut self, acceptor: ClientResponder) -> Self {
107 self.responder = Some(acceptor);
108 self
109 }
110
111 pub fn on_close(mut self, callback: Box<dyn FnMut() + Sync + Send>) -> Self {
112 self.closer = Some(callback);
113 self
114 }
115}
116
117impl<T, C> ClientBuilder<T, C>
118where
119 T: Send + Sync + Transport<Conn = C> + 'static,
120 C: Send + Sync + Connection + 'static,
121{
122 pub async fn start(mut self) -> Result<Client> {
123 let tp: T = self.transport.take().expect("missint transport");
124
125 let splitter = if self.mtu == 0 {
126 None
127 } else {
128 Some(Splitter::new(self.mtu))
129 };
130
131 let (snd_tx, mut snd_rx) = mpsc::unbounded_channel::<Frame>();
132 let cloned_snd_tx = snd_tx.clone();
133 let mut socket = DuplexSocket::new(1, snd_tx, splitter).await;
134
135 let mut cloned_socket = socket.clone();
136
137 if let Some(f) = self.responder {
138 let responder = f();
139 socket.bind_responder(responder).await;
140 }
141
142 let conn = tp.connect().await?;
143 let (mut sink, mut stream) = conn.split();
144
145 let setup = self.setup.build();
146
147 let tick_period = setup.keepalive_interval();
149 runtime::spawn(async move {
150 loop {
151 match tokio::time::timeout(tick_period, snd_rx.recv()).await {
153 Ok(Some(frame)) => {
154 if let frame::Body::Error(e) = frame.get_body_ref() {
155 if e.get_code() == ERR_CONN_CLOSED {
156 break;
157 }
158 }
159 if let Err(e) = sink.send(frame).await {
160 error!("write frame failed: {}", e);
161 break;
162 }
163 }
164 Ok(None) => break,
165 Err(_) => {
166 let keepalive_frame =
168 frame::Keepalive::builder(0, Frame::FLAG_RESPOND).build();
169 if let Err(e) = sink.send(keepalive_frame).await {
170 error!("write frame failed: {}", e);
171 break;
172 }
173 }
174 }
175 }
176 });
177
178 let closer = self.closer.take();
180 let close_notify = Arc::new(Notify::new());
181 let close_notify_clone = close_notify.clone();
182 let (closing, mut closing_rx) = mpsc::channel::<()>(1);
183
184 let (read_tx, mut read_rx) = mpsc::unbounded_channel::<Frame>();
185
186 runtime::spawn(async move {
188 loop {
189 tokio::select! {
190 res = stream.next() => {
191 match res {
192 Some(next) => match next {
193 Ok(frame) => {
194 if let Err(e) = read_tx.send(frame) {
195 error!("forward frame failed: {}", e);
196 break;
197 }
198 }
199 Err(e) => {
200 error!("read frame failed: {}", e);
201 break;
202 }
203 }
204 None => break,
205 }
206 }
207 _ = closing_rx.recv() => {
208 break
209 }
210 }
211 }
212 });
213
214 runtime::spawn(async move {
216 while let Some(next) = read_rx.recv().await {
217 if let Err(e) = cloned_socket.dispatch(next, None).await {
218 error!("dispatch frame failed: {}", e);
219 break;
220 }
221 }
222
223 let close_frame = frame::Error::builder(0, 0)
225 .set_code(ERR_CONN_CLOSED)
226 .build();
227 if let Err(e) = cloned_snd_tx.send(close_frame) {
228 debug!("send close notify frame failed: {}", e);
229 }
230
231 close_notify_clone.notify_one();
233
234 if let Some(mut invoke) = closer {
236 invoke();
237 }
238 });
239
240 socket.setup(setup).await?;
241
242 Ok(Client::new(socket, close_notify, closing))
243 }
244}
245
246impl Client {
247 fn new(socket: DuplexSocket, closed: Arc<Notify>, closing: mpsc::Sender<()>) -> Client {
248 Client {
249 socket,
250 closed,
251 closing,
252 }
253 }
254
255 pub async fn wait_for_close(self) {
256 self.closed.notified().await
257 }
258}
259
260#[async_trait]
261impl RSocket for Client {
262 async fn metadata_push(&self, req: Payload) -> Result<()> {
263 self.socket.metadata_push(req).await
264 }
265
266 async fn fire_and_forget(&self, req: Payload) -> Result<()> {
267 self.socket.fire_and_forget(req).await
268 }
269
270 async fn request_response(&self, req: Payload) -> Result<Option<Payload>> {
271 self.socket.request_response(req).await
272 }
273
274 fn request_stream(&self, req: Payload) -> Flux<Result<Payload>> {
275 self.socket.request_stream(req)
276 }
277
278 fn request_channel(&self, reqs: Flux<Result<Payload>>) -> Flux<Result<Payload>> {
279 self.socket.request_channel(reqs)
280 }
281}