1use crate::client::{Connection as Client, Config as ClientConfig, ReconStrat};
2use crate::server::{Connection as Server, Config as ServerConfig};
3use crate::util::{watch, TimeoutReader, ByteStream};
4use crate::packet::{Packet, PlainBytes};
5use crate::packet::builder::{PacketReceiver, PacketReceiverError};
6use crate::error::TaskError;
7use crate::handler::{client, server, TaskHandle, SendBack};
8
9use tokio::io::AsyncWriteExt;
10use tokio::sync::oneshot;
11use tokio::time::{interval, Duration, MissedTickBehavior};
12
13use std::io;
14
15
16pub(crate) fn client<S, P>(
18 byte_stream: S,
19 cfg: ClientConfig,
20 mut recon_strat: Option<ReconStrat<S>>
21) -> Client<P>
22where
23 S: ByteStream,
24 P: Packet<PlainBytes> + Send + 'static,
25 P::Header: Send
26{
27 let (sender, mut cfg_rx, mut bg_handler) = client::Handler::new(cfg);
28
29 let (tx_close, mut rx_close) = oneshot::channel();
30 let task = tokio::spawn(async move {
31 client_bg_reconnect!(
32 client_bg_stream(
33 byte_stream,
34 bg_handler,
35 cfg_rx,
36 rx_close,
37 recon_strat,
38 |stream, cfg| {
39 Ok(PacketStream::new(stream, cfg.timeout, cfg.body_limit))
40 }
41 )
42 );
43 });
44
45 let task = TaskHandle { close: tx_close, task };
46
47 Client::new_raw(sender, task)
48}
49
50pub(crate) fn server<S, P>(stream: S, cfg: ServerConfig) -> Server<P>
52where
53 S: ByteStream,
54 P: Packet<PlainBytes> + Send + 'static,
55 P::Header: Send
56{
57 let stream = PacketStream::new(stream, cfg.timeout, cfg.body_limit);
58 let (receiver, mut cfg_rx, mut bg_handler) = server::Handler::new(cfg);
59
60 let (tx_close, mut rx_close) = oneshot::channel();
61 let task = tokio::spawn(async move {
62 let r = server_bg_stream(
63 stream,
64 &mut bg_handler,
65 &mut cfg_rx,
66 &mut rx_close
67 ).await;
68
69 if let Err(e) = &r {
70 tracing::error!("server_bg_stream error {:?}", e)
71 }
72
73 r
74 });
75
76 let task = TaskHandle { close: tx_close, task };
77
78 Server::new_raw(receiver, task)
79}
80
81struct PacketStream<S, P>
83where
84 S: ByteStream,
85 P: Packet<PlainBytes>
86{
87 stream: TimeoutReader<S>,
88 builder: PacketReceiver<P, PlainBytes>
90}
91
92impl<S, P> PacketStream<S, P>
93where
94 S: ByteStream,
95 P: Packet<PlainBytes>
96{
97 fn new(stream: S, timeout: Duration, body_limit: u32) -> Self {
98 Self {
99 stream: TimeoutReader::new(stream, timeout),
100 builder: PacketReceiver::new(body_limit)
101 }
102 }
103
104 fn timeout(&self) -> Duration {
105 self.stream.timeout()
106 }
107
108 async fn send(&mut self, packet: P) -> Result<(), io::Error> {
109 let bytes = packet.into_bytes();
110 let slice = bytes.as_slice();
111 self.stream.write_all(slice).await?;
112 self.stream.flush().await?;
113 Ok(())
114 }
115
116 async fn receive(&mut self) -> Result<P, PacketReceiverError<P::Header>> {
118 self.builder.read_header(&mut self.stream, |_| Ok(())).await?;
119 self.builder.read_body(&mut self.stream, |_| Ok(())).await
120 }
121
122 async fn shutdown(&mut self) -> Result<(), io::Error> {
123 self.stream.shutdown().await
124 }
125}
126
// Instantiate the background driver functions called by `client` (through
// `client_bg_reconnect!`) and by `server`: `client_bg_stream` and `server_bg_stream`.
// Arguments appear to be: generated function name, handler type, packet byte type
// and config type.
// NOTE(review): the generated signatures are defined by `bg_stream!`, which is not
// visible here; confirm the argument meanings against the macro definition.
bg_stream!(
    client_bg_stream, client::Handler<P, PlainBytes>, PlainBytes, ClientConfig
);
bg_stream!(
    server_bg_stream, server::Handler<P, PlainBytes>, PlainBytes, ServerConfig
);
133
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packet::test::TestPacket;
    use crate::server::Message;
    use crate::util::PinnedFuture;

    use tokio::net::{TcpStream, TcpListener};
    use tokio::time::{sleep, Duration};

    /// Opens a loopback TCP connection and returns `(connecting side, accepted side)`.
    async fn tcp_streams() -> (TcpStream, TcpStream) {
        // Port 0 lets the OS pick a free port.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Drive the connect and the accept concurrently.
        let connect = TcpStream::connect(addr);
        let accept = listener.accept();
        let (connect, accept) = tokio::join!(connect, accept);

        (connect.unwrap(), accept.unwrap().0)
    }

    /// Exercises all three exchange patterns over one plaintext connection:
    /// 1. a plain request/response,
    /// 2. a request answered by two streamed responses (`request_receiver`),
    /// 3. a request followed by two client-streamed packets (`request_sender`).
    #[tokio::test]
    async fn test_plain_stream() {
        let timeout = Duration::from_secs(1);

        let (alice, bob) = tcp_streams().await;

        let alice: Client<TestPacket<_>> = client(alice, ClientConfig {
            timeout,
            body_limit: 200
        }, None);

        let bob_task = tokio::spawn(async move {
            let mut bob: Server<TestPacket<_>> = server(bob, ServerConfig {
                timeout,
                body_limit: 200
            });

            // 1) Request (1, 2) is answered with (3, 4).
            let req = bob.receive().await.unwrap();
            match req {
                Message::Request(req, resp) => {
                    assert_eq!(req.num1, 1);
                    assert_eq!(req.num2, 2);

                    let res = TestPacket::new(3, 4);
                    resp.send(res).unwrap();
                },
                _ => panic!("expected request")
            };

            // 2) Request (5, 6) is answered with a stream of (7, 8) and (9, 10).
            let req = bob.receive().await.unwrap();
            match req {
                Message::RequestReceiver(req, stream) => {
                    assert_eq!(req.num1, 5);
                    assert_eq!(req.num2, 6);

                    let res = TestPacket::new(7, 8);
                    stream.send(res).await.unwrap();

                    let res = TestPacket::new(9, 10);
                    stream.send(res).await.unwrap();
                },
                _ => panic!("expected stream")
            };

            // 3) Request (11, 12) is followed by (13, 14) and (15, 16) from the client.
            let req = bob.receive().await.unwrap();
            match req {
                Message::RequestSender(req, mut stream) => {
                    assert_eq!(req.num1, 11);
                    assert_eq!(req.num2, 12);

                    let res = stream.receive().await.unwrap();
                    assert_eq!(res.num1, 13);
                    assert_eq!(res.num2, 14);

                    let res = stream.receive().await.unwrap();
                    assert_eq!(res.num1, 15);
                    assert_eq!(res.num2, 16);
                },
                _ => panic!("expected stream")
            };

            // Wait for the server's background task to finish, presumably once
            // alice closes the connection below.
            bob.wait().await.unwrap();
        });

        let req = TestPacket::new(1, 2);
        let res = alice.request(req).await.unwrap();
        assert_eq!(res.num1, 3);
        assert_eq!(res.num2, 4);

        let req = TestPacket::new(5, 6);
        let mut stream = alice.request_receiver(req).await.unwrap();

        let res = stream.receive().await.unwrap();
        assert_eq!(res.num1, 7);
        assert_eq!(res.num2, 8);

        let res = stream.receive().await.unwrap();
        assert_eq!(res.num1, 9);
        assert_eq!(res.num2, 10);
        drop(stream);

        let req = TestPacket::new(11, 12);
        let stream = alice.request_sender(req).await.unwrap();

        let req = TestPacket::new(13, 14);
        stream.send(req).await.unwrap();

        let req = TestPacket::new(15, 16);
        stream.send(req).await.unwrap();
        drop(stream);

        println!("waiting for alice to close");

        alice.close().await.unwrap();

        bob_task.await.unwrap();
    }

    /// Checks that a client configured with a `ReconStrat` recovers after the server
    /// drops its first connection.
    ///
    /// On the first connection the server answers one request, sleeps 100 ms (longer
    /// than the 20 ms timeout) and aborts. This presumably makes the client time out
    /// and reconnect. On a later connection the server echoes requests until it sees
    /// `num1 == 3`, then waits for the client to close.
    #[tokio::test]
    async fn test_plain_stream_reconnect() {
        let timeout = Duration::from_millis(20);

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            // Number of connections accepted so far.
            let mut c = 0;
            'main: loop {
                c += 1;

                let accept = listener.accept().await.unwrap().0;

                let mut bob: Server<TestPacket<_>> = server(
                    accept,
                    ServerConfig {
                        timeout,
                        body_limit: 200
                    }
                );

                loop {

                    // `None` means this connection is gone: accept the next one.
                    let req = bob.receive().await;
                    let req = match req {
                        Some(r) => r,
                        None => continue 'main
                    };

                    match req {
                        Message::Request(req, resp) => {
                            // Echo the request back unchanged.
                            let res = TestPacket::new(req.num1, req.num2);
                            resp.send(res).unwrap();

                            if req.num1 == 3 {
                                break
                            }
                        },
                        _ => panic!("expected request")
                    };

                    // First connection only: stall past the timeout, then kill it.
                    if c == 1 {
                        sleep(Duration::from_millis(100)).await;
                        bob.abort();
                        continue 'main;
                    }

                }

                bob.wait().await.expect("bob failed");
                break
            }
        });

        let alice: Client<TestPacket<_>> = client(
            TcpStream::connect(addr).await.unwrap(),
            ClientConfig {
                timeout,
                body_limit: 200
            },
            // Reconnect after a short delay; fail loudly if it keeps failing.
            // `SocketAddr` is `Copy`, so the `async move` block copies `addr` and the
            // closure stays reusable without an explicit clone.
            Some(ReconStrat::new(move |err_count| {
                assert!(err_count < 10);
                PinnedFuture::new(async move {
                    sleep(Duration::from_millis(10)).await;
                    TcpStream::connect(addr).await
                })
            }))
        );

        let req = TestPacket::new(1, 2);
        let res = alice.request(req).await.unwrap();
        assert_eq!(res.num1, 1);
        assert_eq!(res.num2, 2);

        let mut retry_counter = 0;

        // Requests may fail while the connection is being re-established; retry
        // (bounded) until the echo of (3, 4) comes back.
        loop {

            assert!(retry_counter < 10);

            let req = TestPacket::new(3, 4);
            let res = alice.request(req).await;
            let res = match res {
                Ok(r) => r,
                Err(_) => {
                    retry_counter += 1;
                    sleep(Duration::from_millis(100)).await;
                    continue
                }
            };
            assert_eq!(res.num1, 3);
            assert_eq!(res.num2, 4);
            break

        }

        alice.close().await.unwrap();

        server.await.unwrap();
    }
}
375}