// lafere/plain/mod.rs

1use crate::client::{Connection as Client, Config as ClientConfig, ReconStrat};
2use crate::server::{Connection as Server, Config as ServerConfig};
3use crate::util::{watch, TimeoutReader, ByteStream};
4use crate::packet::{Packet, PlainBytes};
5use crate::packet::builder::{PacketReceiver, PacketReceiverError};
6use crate::error::TaskError;
7use crate::handler::{client, server, TaskHandle, SendBack};
8
9use tokio::io::AsyncWriteExt;
10use tokio::sync::oneshot;
11use tokio::time::{interval, Duration, MissedTickBehavior};
12
13use std::io;
14
15
16/// Creates a new client from a stream, without using any encryption.
17pub(crate) fn client<S, P>(
18	byte_stream: S,
19	cfg: ClientConfig,
20	mut recon_strat: Option<ReconStrat<S>>
21) -> Client<P>
22where
23	S: ByteStream,
24	P: Packet<PlainBytes> + Send + 'static,
25	P::Header: Send
26{
27	let (sender, mut cfg_rx, mut bg_handler) = client::Handler::new(cfg);
28
29	let (tx_close, mut rx_close) = oneshot::channel();
30	let task = tokio::spawn(async move {
31		client_bg_reconnect!(
32			client_bg_stream(
33				byte_stream,
34				bg_handler,
35				cfg_rx,
36				rx_close,
37				recon_strat,
38				|stream, cfg| {
39					Ok(PacketStream::new(stream, cfg.timeout, cfg.body_limit))
40				}
41			)
42		);
43	});
44
45	let task = TaskHandle { close: tx_close, task };
46
47	Client::new_raw(sender, task)
48}
49
50/// Creates a new server from a stream, without using any encryption.
51pub(crate) fn server<S, P>(stream: S, cfg: ServerConfig) -> Server<P>
52where
53	S: ByteStream,
54	P: Packet<PlainBytes> + Send + 'static,
55	P::Header: Send
56{
57	let stream = PacketStream::new(stream, cfg.timeout, cfg.body_limit);
58	let (receiver, mut cfg_rx, mut bg_handler) = server::Handler::new(cfg);
59
60	let (tx_close, mut rx_close) = oneshot::channel();
61	let task = tokio::spawn(async move {
62		let r = server_bg_stream(
63			stream,
64			&mut bg_handler,
65			&mut cfg_rx,
66			&mut rx_close
67		).await;
68
69		if let Err(e) = &r {
70			tracing::error!("server_bg_stream error {:?}", e)
71		}
72
73		r
74	});
75
76	let task = TaskHandle { close: tx_close, task };
77
78	Server::new_raw(receiver, task)
79}
80
81/// inner manages a stream
82struct PacketStream<S, P>
83where
84	S: ByteStream,
85	P: Packet<PlainBytes>
86{
87	stream: TimeoutReader<S>,
88	// buffer to receive a message
89	builder: PacketReceiver<P, PlainBytes>
90}
91
92impl<S, P> PacketStream<S, P>
93where
94	S: ByteStream,
95	P: Packet<PlainBytes>
96{
97	fn new(stream: S, timeout: Duration, body_limit: u32) -> Self {
98		Self {
99			stream: TimeoutReader::new(stream, timeout),
100			builder: PacketReceiver::new(body_limit)
101		}
102	}
103
104	fn timeout(&self) -> Duration {
105		self.stream.timeout()
106	}
107
108	async fn send(&mut self, packet: P) -> Result<(), io::Error> {
109		let bytes = packet.into_bytes();
110		let slice = bytes.as_slice();
111		self.stream.write_all(slice).await?;
112		self.stream.flush().await?;
113		Ok(())
114	}
115
116	/// this function is abort safe
117	async fn receive(&mut self) -> Result<P, PacketReceiverError<P::Header>> {
118		self.builder.read_header(&mut self.stream, |_| Ok(())).await?;
119		self.builder.read_body(&mut self.stream, |_| Ok(())).await
120	}
121
122	async fn shutdown(&mut self) -> Result<(), io::Error> {
123		self.stream.shutdown().await
124	}
125}
126
// Generate the async functions `client_bg_stream` and `server_bg_stream`,
// which `client` and `server` above run inside their spawned background
// tasks.
// NOTE(review): `bg_stream!` is defined elsewhere in the crate. Several
// imports at the top of this file (e.g. `watch`, `interval`,
// `MissedTickBehavior`, `TaskError`, `SendBack`) are not referenced directly
// in this file and are presumably used by these expansions. Do not remove
// them without checking.
bg_stream!(
	client_bg_stream, client::Handler<P, PlainBytes>, PlainBytes, ClientConfig
);
bg_stream!(
	server_bg_stream, server::Handler<P, PlainBytes>, PlainBytes, ServerConfig
);
133
134
#[cfg(test)]
mod tests {
	use super::*;
	use crate::packet::test::TestPacket;
	use crate::server::Message;
	use crate::util::PinnedFuture;

	use tokio::net::{TcpStream, TcpListener};
	use tokio::time::{sleep, Duration};


	/// Creates a pair of TCP streams on the loopback interface that are
	/// connected to each other.
	async fn tcp_streams() -> (TcpStream, TcpStream) {
		let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
		let addr = listener.local_addr().unwrap();

		let connect = TcpStream::connect(addr);
		let accept = listener.accept();
		let (connect, accept) = tokio::join!(connect, accept);

		(connect.unwrap(), accept.unwrap().0)
	}

	/// Exercises three exchanges over one plain connection:
	/// - a request/response,
	/// - a request whose responses are streamed back,
	/// - a request followed by packets streamed from the client.
	#[tokio::test]
	async fn test_plain_stream() {
		let timeout = Duration::from_secs(1);

		let (alice, bob) = tcp_streams().await;

		let alice: Client<TestPacket<_>> = client(alice, ClientConfig {
			timeout,
			body_limit: 200
		}, None);

		let bob_task = tokio::spawn(async move {
			let mut bob: Server<TestPacket<_>> = server(bob, ServerConfig {
				timeout,
				body_limit: 200
			});

			// let's receive a request message
			let req = bob.receive().await.unwrap();
			match req {
				Message::Request(req, resp) => {
					assert_eq!(req.num1, 1);
					assert_eq!(req.num2, 2);

					// send response
					let res = TestPacket::new(3, 4);
					resp.send(res).unwrap();
				},
				_ => panic!("expected request")
			};

			let req = bob.receive().await.unwrap();
			match req {
				Message::RequestReceiver(req, stream) => {
					assert_eq!(req.num1, 5);
					assert_eq!(req.num2, 6);

					// stream two responses back to alice
					let res = TestPacket::new(7, 8);
					stream.send(res).await.unwrap();

					let res = TestPacket::new(9, 10);
					stream.send(res).await.unwrap();
				},
				_ => panic!("expected stream")
			};

			let req = bob.receive().await.unwrap();
			match req {
				Message::RequestSender(req, mut stream) => {
					assert_eq!(req.num1, 11);
					assert_eq!(req.num2, 12);

					// receive the two packets alice streams to us
					let res = stream.receive().await.unwrap();
					assert_eq!(res.num1, 13);
					assert_eq!(res.num2, 14);

					let res = stream.receive().await.unwrap();
					assert_eq!(res.num1, 15);
					assert_eq!(res.num2, 16);
				},
				_ => panic!("expected stream")
			};

			bob.wait().await.unwrap();
		});

		// let's make a request
		let req = TestPacket::new(1, 2);
		let res = alice.request(req).await.unwrap();
		assert_eq!(res.num1, 3);
		assert_eq!(res.num2, 4);

		// let's create a stream to listen
		let req = TestPacket::new(5, 6);
		let mut stream = alice.request_receiver(req).await.unwrap();

		let res = stream.receive().await.unwrap();
		assert_eq!(res.num1, 7);
		assert_eq!(res.num2, 8);

		let res = stream.receive().await.unwrap();
		assert_eq!(res.num1, 9);
		assert_eq!(res.num2, 10);
		drop(stream);

		// now request a stream.sender
		let req = TestPacket::new(11, 12);
		let stream = alice.request_sender(req).await.unwrap();

		let req = TestPacket::new(13, 14);
		stream.send(req).await.unwrap();

		let req = TestPacket::new(15, 16);
		stream.send(req).await.unwrap();
		drop(stream);

		println!("waiting for alice to close");

		alice.close().await.unwrap();

		// wait until bob's task finishes
		bob_task.await.unwrap();
	}

	/// Checks that a client with a `ReconStrat` keeps working after the
	/// server aborts the first connection.
	#[tokio::test]
	async fn test_plain_stream_reconnect() {
		let timeout = Duration::from_millis(20);

		let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
		let addr = listener.local_addr().unwrap();

		// Named `server_task` so that it does not visually shadow the
		// `server` function that is called inside it.
		let server_task = tokio::spawn(async move {
			// number of connections accepted so far
			let mut conn_count = 0;
			'main: loop {
				// The first connection (conn_count == 1) is aborted after it
				// has answered one request, which forces the client to
				// reconnect.
				conn_count += 1;

				let accept = listener.accept().await.unwrap().0;

				let mut bob: Server<TestPacket<_>> = server(
					accept,
					ServerConfig {
						timeout,
						body_limit: 200
					}
				);

				loop {
					// let's receive a request message
					let req = bob.receive().await;
					let req = match req {
						Some(r) => r,
						// this connection is gone, wait for the next one
						None => continue 'main
					};

					match req {
						Message::Request(req, resp) => {
							// echo the request back as the response
							let res = TestPacket::new(req.num1, req.num2);
							resp.send(res).unwrap();

							if req.num1 == 3 {
								break
							}
						},
						_ => panic!("expected request")
					};

					if conn_count == 1 {
						// Wait a moment before aborting, presumably so the
						// response above reaches the client first.
						sleep(Duration::from_millis(100)).await;
						bob.abort();
						continue 'main;
					}
				}

				bob.wait().await.expect("bob failed");
				break
			}
		});

		let alice: Client<TestPacket<_>> = client(
			TcpStream::connect(addr).await.unwrap(),
			ClientConfig {
				timeout,
				body_limit: 200
			},
			Some(ReconStrat::new(move |err_count| {
				assert!(err_count < 10);
				// `SocketAddr` is `Copy`, so every future gets its own copy
				// without an explicit clone.
				PinnedFuture::new(async move {
					sleep(Duration::from_millis(10)).await;
					TcpStream::connect(addr).await
				})
			}))
		);

		// first request should succeed
		let req = TestPacket::new(1, 2);
		let res = alice.request(req).await.unwrap();
		assert_eq!(res.num1, 1);
		assert_eq!(res.num2, 2);

		let mut retry_counter = 0;

		// The server aborts the first connection, so retry until the client
		// has reconnected and we get a response.
		loop {
			assert!(retry_counter < 10);

			let req = TestPacket::new(3, 4);
			let res = alice.request(req).await;
			let res = match res {
				Ok(r) => r,
				Err(_) => {
					retry_counter += 1;
					sleep(Duration::from_millis(100)).await;
					continue
				}
			};
			assert_eq!(res.num1, 3);
			assert_eq!(res.num2, 4);
			break
		}

		alice.close().await.unwrap();

		// wait until the server task finishes
		server_task.await.unwrap();
	}
}