1use std::time::Duration;
2use std::{fmt, io};
3
4use async_trait::async_trait;
5
6mod framed;
7pub use framed::*;
8
9mod inmemory;
10pub use inmemory::*;
11
12mod tcp;
13pub use tcp::*;
14
15#[cfg(test)]
16mod test;
17
18#[cfg(test)]
19pub use test::*;
20
21#[cfg(unix)]
22mod unix;
23
24#[cfg(unix)]
25pub use unix::*;
26
27#[cfg(windows)]
28mod windows;
29
30pub use tokio::io::{Interest, Ready};
31#[cfg(windows)]
32pub use windows::*;
33
/// Pause inserted by the retry loops in this module before a non-blocking read or write is
/// attempted again, so that a transport that is repeatedly "not ready" does not spin the CPU.
const SLEEP_DURATION: Duration = Duration::from_millis(1);
36
/// Capability of re-establishing a connection in place.
#[async_trait]
pub trait Reconnectable {
    /// Attempts to reconnect, reporting failure as an [`io::Error`].
    ///
    /// Takes `&mut self`, so an implementation is free to replace its internal state.
    ///
    /// NOTE(review): what state the value is left in after a failed attempt is decided by each
    /// implementation and is not specified by this trait.
    async fn reconnect(&mut self) -> io::Result<()>;
}
43
/// Low-level, non-blocking byte transport that can also be reconnected.
///
/// Implementors must additionally be [`Reconnectable`], [`fmt::Debug`], [`Send`] and [`Sync`].
/// The higher-level helpers in [`TransportExt`] are built purely on the three methods below.
#[async_trait]
pub trait Transport: Reconnectable + fmt::Debug + Send + Sync {
    /// Tries to read bytes into `buf` without waiting, returning how many bytes were read.
    ///
    /// As consumed by [`TransportExt`] in this file:
    /// * `Ok(0)` is interpreted as "no more data" (end of stream),
    /// * an error of kind [`io::ErrorKind::WouldBlock`] means "nothing available right now,
    ///   try again later".
    fn try_read(&self, buf: &mut [u8]) -> io::Result<usize>;

    /// Tries to write bytes from `buf` without waiting, returning how many bytes were written.
    ///
    /// As consumed by [`TransportExt`] in this file:
    /// * `Ok(0)` for a non-empty `buf` is interpreted as "cannot accept more data",
    /// * an error of kind [`io::ErrorKind::WouldBlock`] means "cannot write right now, try
    ///   again later".
    fn try_write(&self, buf: &[u8]) -> io::Result<usize>;

    /// Waits for the readiness described by `interest` and resolves to the [`Ready`] state
    /// reported by the transport.
    ///
    /// NOTE(review): presumably mirrors the semantics of tokio's `ready` methods, where a
    /// reported readiness may be stale and a following `try_*` call can still yield
    /// `WouldBlock` — confirm against the concrete implementations.
    async fn ready(&self, interest: Interest) -> io::Result<Ready>;
}
68
69#[async_trait]
70impl Transport for Box<dyn Transport> {
71 fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
72 Transport::try_read(AsRef::as_ref(self), buf)
73 }
74
75 fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
76 Transport::try_write(AsRef::as_ref(self), buf)
77 }
78
79 async fn ready(&self, interest: Interest) -> io::Result<Ready> {
80 Transport::ready(AsRef::as_ref(self), interest).await
81 }
82}
83
84#[async_trait]
85impl Reconnectable for Box<dyn Transport> {
86 async fn reconnect(&mut self) -> io::Result<()> {
87 Reconnectable::reconnect(AsMut::as_mut(self)).await
88 }
89}
90
/// Convenience operations layered on top of [`Transport`].
///
/// A blanket implementation in this file provides these methods for every `T: Transport`.
#[async_trait]
pub trait TransportExt {
    /// Waits until the transport reports readiness for reading.
    async fn readable(&self) -> io::Result<()>;

    /// Waits until the transport reports readiness for writing.
    async fn writeable(&self) -> io::Result<()>;

    /// Waits until the transport reports readiness for reading or writing.
    async fn readable_or_writeable(&self) -> io::Result<()>;

    /// Reads until `buf` is completely filled and returns the number of bytes read
    /// (`buf.len()`).
    ///
    /// An empty `buf` yields `Ok(0)` without touching the transport. If the transport
    /// signals end of stream before `buf` is full, the result is an error of kind
    /// [`io::ErrorKind::UnexpectedEof`].
    async fn read_exact(&self, buf: &mut [u8]) -> io::Result<usize>;

    /// Reads until the transport signals end of stream, appending everything to `buf`, and
    /// returns the number of bytes appended.
    async fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize>;

    /// Reads until the transport signals end of stream and appends the data to `buf` as
    /// UTF-8 text, returning the number of bytes read.
    ///
    /// Data that is not valid UTF-8 results in an error of kind
    /// [`io::ErrorKind::InvalidData`].
    async fn read_to_string(&self, buf: &mut String) -> io::Result<usize>;

    /// Writes the whole of `buf` to the transport.
    ///
    /// An empty `buf` succeeds without touching the transport. If the transport accepts zero
    /// bytes while data remains, the result is an error of kind [`io::ErrorKind::WriteZero`].
    async fn write_all(&self, buf: &[u8]) -> io::Result<()>;
}
156
157#[async_trait]
158impl<T: Transport> TransportExt for T {
159 async fn readable(&self) -> io::Result<()> {
160 self.ready(Interest::READABLE).await?;
161 Ok(())
162 }
163
164 async fn writeable(&self) -> io::Result<()> {
165 self.ready(Interest::WRITABLE).await?;
166 Ok(())
167 }
168
169 async fn readable_or_writeable(&self) -> io::Result<()> {
170 self.ready(Interest::READABLE | Interest::WRITABLE).await?;
171 Ok(())
172 }
173
174 async fn read_exact(&self, buf: &mut [u8]) -> io::Result<usize> {
175 let mut i = 0;
176
177 while i < buf.len() {
178 self.readable().await?;
179
180 match self.try_read(&mut buf[i..]) {
181 Ok(0) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)),
188
189 Ok(n) => i += n,
190
191 Err(x) if x.kind() == io::ErrorKind::WouldBlock => {
194 tokio::time::sleep(SLEEP_DURATION).await
196 }
197
198 Err(x) => return Err(x),
199 }
200 }
201
202 Ok(i)
203 }
204
205 async fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
206 let mut i = 0;
207 let mut tmp = [0u8; 1024];
208
209 loop {
210 self.readable().await?;
211
212 match self.try_read(&mut tmp) {
213 Ok(0) => return Ok(i),
214 Ok(n) => {
215 buf.extend_from_slice(&tmp[..n]);
216 i += n;
217 }
218 Err(x)
219 if x.kind() == io::ErrorKind::WouldBlock
220 || x.kind() == io::ErrorKind::Interrupted =>
221 {
222 tokio::time::sleep(SLEEP_DURATION).await
224 }
225
226 Err(x) => return Err(x),
227 }
228 }
229 }
230
231 async fn read_to_string(&self, buf: &mut String) -> io::Result<usize> {
232 let mut tmp = Vec::new();
233 let n = self.read_to_end(&mut tmp).await?;
234 buf.push_str(
235 &String::from_utf8(tmp).map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?,
236 );
237 Ok(n)
238 }
239
240 async fn write_all(&self, buf: &[u8]) -> io::Result<()> {
241 let mut i = 0;
242
243 while i < buf.len() {
244 self.writeable().await?;
245
246 match self.try_write(&buf[i..]) {
247 Ok(0) => return Err(io::Error::from(io::ErrorKind::WriteZero)),
254
255 Ok(n) => i += n,
256
257 Err(x) if x.kind() == io::ErrorKind::WouldBlock => {
260 tokio::time::sleep(SLEEP_DURATION).await
262 }
263
264 Err(x) => return Err(x),
265 }
266 }
267
268 Ok(())
269 }
270}
271
/// Unit tests for the [`TransportExt`] helpers, driven by a mock transport whose behaviour is
/// supplied through closures.
///
/// Mock closures that need to remember how many times they were invoked keep their state in a
/// function-local atomic, so no `unsafe` is needed. Every such static belongs to exactly one
/// closure and therefore to exactly one test.
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU8, Ordering};

    use test_log::test;

    use super::*;

    #[test(tokio::test)]
    async fn read_exact_should_fail_if_try_read_encounters_error_other_than_would_block() {
        let transport = TestTransport {
            f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        let mut buf = [0; 1];
        assert_eq!(
            transport.read_exact(&mut buf).await.unwrap_err().kind(),
            io::ErrorKind::NotConnected
        );
    }

    #[test(tokio::test)]
    async fn read_exact_should_fail_if_try_read_returns_0_before_necessary_bytes_read() {
        let transport = TestTransport {
            f_try_read: Box::new(|_| Ok(0)),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        let mut buf = [0; 1];
        assert_eq!(
            transport.read_exact(&mut buf).await.unwrap_err().kind(),
            io::ErrorKind::UnexpectedEof
        );
    }

    #[test(tokio::test)]
    async fn read_exact_should_continue_to_call_try_read_until_buffer_is_filled() {
        let transport = TestTransport {
            // Yields one byte per call: 'a', 'b', 'c', ...
            f_try_read: Box::new(|buf| {
                static CNT: AtomicU8 = AtomicU8::new(0);
                buf[0] = b'a' + CNT.fetch_add(1, Ordering::SeqCst);
                Ok(1)
            }),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        let mut buf = [0; 3];
        assert_eq!(transport.read_exact(&mut buf).await.unwrap(), 3);
        assert_eq!(&buf, b"abc");
    }

    #[test(tokio::test)]
    async fn read_exact_should_continue_to_call_try_read_while_it_returns_would_block() {
        let transport = TestTransport {
            // Even-numbered calls (0, 2, 4) deliver one byte; odd-numbered calls report
            // `WouldBlock`. The letter advances on every call, hence "ace".
            f_try_read: Box::new(|buf| {
                static CNT: AtomicU8 = AtomicU8::new(0);
                let call = CNT.fetch_add(1, Ordering::SeqCst);
                buf[0] = b'a' + call;
                if call % 2 == 0 {
                    Ok(1)
                } else {
                    Err(io::Error::from(io::ErrorKind::WouldBlock))
                }
            }),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        let mut buf = [0; 3];
        assert_eq!(transport.read_exact(&mut buf).await.unwrap(), 3);
        assert_eq!(&buf, b"ace");
    }

    #[test(tokio::test)]
    async fn read_exact_should_return_0_if_given_a_buffer_of_0_len() {
        let transport = TestTransport {
            f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        let mut buf = [0; 0];
        assert_eq!(transport.read_exact(&mut buf).await.unwrap(), 0);
    }

    #[test(tokio::test)]
    async fn read_to_end_should_fail_if_try_read_encounters_error_other_than_would_block_and_interrupt(
    ) {
        let transport = TestTransport {
            f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        assert_eq!(
            transport
                .read_to_end(&mut Vec::new())
                .await
                .unwrap_err()
                .kind(),
            io::ErrorKind::NotConnected
        );
    }

    #[test(tokio::test)]
    async fn read_to_end_should_read_until_0_bytes_returned_from_try_read() {
        let transport = TestTransport {
            // First call delivers "hello"; every later call signals end of stream.
            f_try_read: Box::new(|buf| {
                static CNT: AtomicU8 = AtomicU8::new(0);
                if CNT.swap(1, Ordering::SeqCst) == 0 {
                    buf[..5].copy_from_slice(b"hello");
                    Ok(5)
                } else {
                    Ok(0)
                }
            }),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        let mut buf = Vec::new();
        assert_eq!(transport.read_to_end(&mut buf).await.unwrap(), 5);
        assert_eq!(buf, b"hello");
    }

    #[test(tokio::test)]
    async fn read_to_end_should_continue_reading_when_interrupt_or_would_block_encountered() {
        let transport = TestTransport {
            // Scripted sequence (zero-based call index): data, WouldBlock, data, Interrupted,
            // data, then end of stream.
            f_try_read: Box::new(|buf| {
                static CNT: AtomicU8 = AtomicU8::new(0);
                match CNT.fetch_add(1, Ordering::SeqCst) {
                    0 => {
                        buf[..6].copy_from_slice(b"hello ");
                        Ok(6)
                    }
                    1 => Err(io::Error::from(io::ErrorKind::WouldBlock)),
                    2 => {
                        buf[..5].copy_from_slice(b"world");
                        Ok(5)
                    }
                    3 => Err(io::Error::from(io::ErrorKind::Interrupted)),
                    4 => {
                        buf[..6].copy_from_slice(b", test");
                        Ok(6)
                    }
                    _ => Ok(0),
                }
            }),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        let mut buf = Vec::new();
        assert_eq!(transport.read_to_end(&mut buf).await.unwrap(), 17);
        assert_eq!(buf, b"hello world, test");
    }

    #[test(tokio::test)]
    async fn read_to_string_should_fail_if_try_read_encounters_error_other_than_would_block_and_interrupt(
    ) {
        let transport = TestTransport {
            f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        assert_eq!(
            transport
                .read_to_string(&mut String::new())
                .await
                .unwrap_err()
                .kind(),
            io::ErrorKind::NotConnected
        );
    }

    #[test(tokio::test)]
    async fn read_to_string_should_fail_if_non_utf8_characters_read() {
        let transport = TestTransport {
            // First call delivers a byte sequence that is not valid UTF-8; every later call
            // signals end of stream.
            f_try_read: Box::new(|buf| {
                static CNT: AtomicU8 = AtomicU8::new(0);
                if CNT.swap(1, Ordering::SeqCst) == 0 {
                    buf[..4].copy_from_slice(&[0, 159, 146, 150]);
                    Ok(4)
                } else {
                    Ok(0)
                }
            }),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        let mut buf = String::new();
        assert_eq!(
            transport.read_to_string(&mut buf).await.unwrap_err().kind(),
            io::ErrorKind::InvalidData
        );
    }

    #[test(tokio::test)]
    async fn read_to_string_should_read_until_0_bytes_returned_from_try_read() {
        let transport = TestTransport {
            // First call delivers "hello"; every later call signals end of stream.
            f_try_read: Box::new(|buf| {
                static CNT: AtomicU8 = AtomicU8::new(0);
                if CNT.swap(1, Ordering::SeqCst) == 0 {
                    buf[..5].copy_from_slice(b"hello");
                    Ok(5)
                } else {
                    Ok(0)
                }
            }),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        let mut buf = String::new();
        assert_eq!(transport.read_to_string(&mut buf).await.unwrap(), 5);
        assert_eq!(buf, "hello");
    }

    #[test(tokio::test)]
    async fn read_to_string_should_continue_reading_when_interrupt_or_would_block_encountered() {
        let transport = TestTransport {
            // Scripted sequence (zero-based call index): data, WouldBlock, data, Interrupted,
            // data, then end of stream.
            f_try_read: Box::new(|buf| {
                static CNT: AtomicU8 = AtomicU8::new(0);
                match CNT.fetch_add(1, Ordering::SeqCst) {
                    0 => {
                        buf[..6].copy_from_slice(b"hello ");
                        Ok(6)
                    }
                    1 => Err(io::Error::from(io::ErrorKind::WouldBlock)),
                    2 => {
                        buf[..5].copy_from_slice(b"world");
                        Ok(5)
                    }
                    3 => Err(io::Error::from(io::ErrorKind::Interrupted)),
                    4 => {
                        buf[..6].copy_from_slice(b", test");
                        Ok(6)
                    }
                    _ => Ok(0),
                }
            }),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        let mut buf = String::new();
        assert_eq!(transport.read_to_string(&mut buf).await.unwrap(), 17);
        assert_eq!(buf, "hello world, test");
    }

    #[test(tokio::test)]
    async fn write_all_should_fail_if_try_write_encounters_error_other_than_would_block() {
        let transport = TestTransport {
            f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
            f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
            ..Default::default()
        };

        assert_eq!(
            transport.write_all(b"abc").await.unwrap_err().kind(),
            io::ErrorKind::NotConnected
        );
    }

    #[test(tokio::test)]
    async fn write_all_should_fail_if_try_write_returns_0_before_all_bytes_written() {
        let transport = TestTransport {
            f_try_write: Box::new(|_| Ok(0)),
            f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
            ..Default::default()
        };

        assert_eq!(
            transport.write_all(b"abc").await.unwrap_err().kind(),
            io::ErrorKind::WriteZero
        );
    }

    #[test(tokio::test)]
    async fn write_all_should_continue_to_call_try_write_until_all_bytes_written() {
        let transport = TestTransport {
            // Accepts one byte per call and checks that bytes arrive in order: 'a', 'b', 'c'.
            f_try_write: Box::new(|buf| {
                static CNT: AtomicU8 = AtomicU8::new(0);
                let call = CNT.fetch_add(1, Ordering::SeqCst);
                assert_eq!(buf[0], b'a' + call);
                Ok(1)
            }),
            f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
            ..Default::default()
        };

        transport.write_all(b"abc").await.unwrap();
    }

    #[test(tokio::test)]
    async fn write_all_should_continue_to_call_try_write_while_it_returns_would_block() {
        let transport = TestTransport {
            // Even-numbered calls (0, 2, 4) accept one byte and check it against the call
            // index; odd-numbered calls report `WouldBlock`. Hence the payload "ace".
            f_try_write: Box::new(|buf| {
                static CNT: AtomicU8 = AtomicU8::new(0);
                let call = CNT.fetch_add(1, Ordering::SeqCst);
                if call % 2 == 0 {
                    assert_eq!(buf[0], b'a' + call);
                    Ok(1)
                } else {
                    Err(io::Error::from(io::ErrorKind::WouldBlock))
                }
            }),
            f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
            ..Default::default()
        };

        transport.write_all(b"ace").await.unwrap();
    }

    #[test(tokio::test)]
    async fn write_all_should_return_immediately_if_given_buffer_of_0_len() {
        let transport = TestTransport {
            f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
            f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
            ..Default::default()
        };

        // With nothing to write, `try_write` (which would fail) must never be called.
        let buf = [0; 0];
        transport.write_all(&buf).await.unwrap();
    }
}