1use super::Comm;
2use parking_lot::{Mutex, MutexGuard};
3use std::error::Error;
4use std::io::{Read, Write};
5use std::net::SocketAddr;
6use std::net::TcpStream;
7use std::sync::Arc;
8use std::time::Duration;
9
10#[allow(clippy::module_name_repetitions)]
11pub struct TcpComm {
12 addr: SocketAddr,
13 stream: Mutex<Option<TcpStream>>,
14 timeout: Duration,
15 busy: Mutex<()>,
16}
17
18#[allow(clippy::module_name_repetitions)]
19pub type TcpCommunicator = Arc<TcpComm>;
20
/// Passes an I/O error through, first discarding the cached connection when
/// that connection can no longer be trusted.
///
/// * `$stream` – the guarded `Option<TcpStream>` slot. It is reset to `None`
///   when the connection is discarded, so the next call reconnects.
/// * `$err` – the `std::io::Error` to inspect. It is evaluated exactly once
///   and returned unchanged.
/// * `$any` – when `true`, discard the connection on every error. This is
///   used for writes, where a failed `write_all` may already have sent part
///   of the buffer.
///
/// When `$any` is `false`, the connection is discarded on timeouts and on
/// errors meaning the peer is gone. Timeouts are matched as both `TimedOut`
/// and `WouldBlock`: with `set_read_timeout`/`set_write_timeout`, Unix
/// typically reports an expired timeout as `WouldBlock`, while Windows
/// reports `TimedOut`. After such an error the stream may hold a partially
/// consumed reply, so reusing it would desynchronise later reads.
macro_rules! handle_tcp_stream_error {
    ($stream: expr, $err: expr, $any: expr) => {{
        let err = $err;
        if $any
            || matches!(
                err.kind(),
                std::io::ErrorKind::TimedOut
                    | std::io::ErrorKind::WouldBlock
                    | std::io::ErrorKind::UnexpectedEof
                    | std::io::ErrorKind::ConnectionReset
                    | std::io::ErrorKind::ConnectionAborted
                    | std::io::ErrorKind::BrokenPipe
                    | std::io::ErrorKind::NotConnected
            )
        {
            $stream.take();
        }
        err
    }};
}
29
30impl Comm for TcpComm {
31 fn lock(&self) -> MutexGuard<()> {
32 self.busy.lock()
33 }
34 fn reconnect(&self) {
35 self.stream.lock().take();
36 }
37 fn write(&self, buf: &[u8]) -> Result<(), std::io::Error> {
38 let mut stream = self.get_stream()?;
39 stream
40 .as_mut()
41 .unwrap()
42 .write_all(buf)
43 .map_err(|e| handle_tcp_stream_error!(stream, e, true))
44 }
45 fn read_exact(&self, buf: &mut [u8]) -> Result<(), std::io::Error> {
46 let mut stream = self.get_stream()?;
47 stream
48 .as_mut()
49 .unwrap()
50 .read_exact(buf)
51 .map_err(|e| handle_tcp_stream_error!(stream, e, false))
52 }
53}
54
55impl TcpComm {
56 pub fn create(path: &str, timeout: Duration) -> Result<Self, Box<dyn Error>> {
57 Ok(Self {
58 addr: path.parse()?,
59 stream: <_>::default(),
60 busy: <_>::default(),
61 timeout,
62 })
63 }
64 fn get_stream(&self) -> Result<MutexGuard<Option<TcpStream>>, std::io::Error> {
65 let mut lock = self.stream.lock();
66 if lock.as_mut().is_none() {
67 let stream = TcpStream::connect_timeout(&self.addr, self.timeout)?;
68 stream.set_read_timeout(Some(self.timeout))?;
69 stream.set_write_timeout(Some(self.timeout))?;
70 stream.set_nodelay(true)?;
71 lock.replace(stream);
72 }
73 Ok(lock)
74 }
75}