ais_keystore_lib/shared/conn.rs

1//! Abstractions around TCP and Unix socket connections used by the client.
2
3use dusa_collection_utils::core::errors::{ErrorArrayItem, Errors};
4use simple_comms::protocol::proto::Proto;
5use tokio::net::{TcpStream, UnixStream};
6
7use super::consts::{PORT, SOCKETPATH};
8
9/// Mutable reference wrapper over the two supported stream types.
10pub enum StreamMut<'a> {
11    Tcp(&'a mut Box<TcpStream>),
12    Unix(&'a mut Box<UnixStream>),
13}
14
15/// Represents an open connection to the keystore.
16pub struct ConnectionStream {
17    pub addy: Option<String>,
18    pub protocol: Proto,
19    pub tcp_stream: Option<Box<TcpStream>>,
20    pub unix_stream: Option<Box<UnixStream>>,
21}
22
23impl ConnectionStream {
24    /// Create a new TCP connection to the keystore server.
25    pub async fn new_tcp_connection(addy: String) -> Result<Self, ErrorArrayItem> {
26        let protocol = Proto::TCP;
27        let address = format!("{}:{}", addy, PORT);
28
29        let stream = TcpStream::connect(address)
30            .await
31            .map_err(ErrorArrayItem::from)?;
32
33        Ok(Self {
34            addy: Some(addy),
35            protocol,
36            tcp_stream: Some(Box::new(stream)),
37            unix_stream: None,
38        })
39    }
40
41    /// Create a new Unix domain socket connection to the keystore server.
42    pub async fn new_unix_connection() -> Result<Self, ErrorArrayItem> {
43        let protocol = Proto::UNIX;
44
45        let stream = UnixStream::connect(SOCKETPATH)
46            .await
47            .map_err(ErrorArrayItem::from)?;
48
49        Ok(Self {
50            addy: None,
51            protocol,
52            tcp_stream: None,
53            unix_stream: Some(Box::new(stream)),
54        })
55    }
56
57    /// Ensure the internal connection is alive, reconnecting if necessary.
58    pub async fn ensure_connection(&mut self) -> Result<Self, ErrorArrayItem> {
59        let conn_proto: Proto = self.get_protocol();
60
61        match self.get_stream_mut() {
62            StreamMut::Tcp(tcp) => {
63                if tcp.peer_addr().is_ok() {
64                    return Ok(Self {
65                        addy: self.addy.clone(),
66                        protocol: self.protocol,
67                        tcp_stream: self.tcp_stream.take(),
68                        unix_stream: self.unix_stream.take(),
69                    });
70                }
71            }
72            StreamMut::Unix(unix) => {
73                if unix.peer_addr().is_ok() {
74                    return Ok(Self {
75                        addy: self.addy.clone(),
76                        protocol: self.protocol,
77                        tcp_stream: self.tcp_stream.take(),
78                        unix_stream: self.unix_stream.take(),
79                    });
80                }
81            }
82        }
83
84        match conn_proto {
85            Proto::TCP => {
86                if let Some(ref address) = self.addy {
87                    let new_stream = TcpStream::connect(format!("{}:{}", address, PORT))
88                        .await
89                        .map_err(ErrorArrayItem::from)?;
90                    self.tcp_stream = Some(Box::new(new_stream));
91                    self.unix_stream = None;
92                    Ok(Self {
93                        addy: self.addy.clone(),
94                        protocol: self.protocol,
95                        tcp_stream: self.tcp_stream.take(),
96                        unix_stream: self.unix_stream.take(),
97                    })
98                } else {
99                    Err(ErrorArrayItem::new(
100                        Errors::Network,
101                        "Missing address for TCP connection",
102                    ))
103                }
104            }
105            Proto::UNIX => {
106                let new_stream = UnixStream::connect(SOCKETPATH)
107                    .await
108                    .map_err(ErrorArrayItem::from)?;
109                self.unix_stream = Some(Box::new(new_stream));
110                self.tcp_stream = None;
111                Ok(Self {
112                    addy: self.addy.clone(),
113                    protocol: self.protocol,
114                    tcp_stream: self.tcp_stream.take(),
115                    unix_stream: self.unix_stream.take(),
116                })
117            }
118        }
119    }
120
121    /// Obtain a mutable reference to the underlying stream regardless of the protocol.
122    pub fn get_stream_mut(&mut self) -> StreamMut {
123
124        if let Some(tcp) = self.tcp_stream.as_mut() {
125            return StreamMut::Tcp(tcp)
126        };
127        
128        if let Some(unix) = self.unix_stream.as_mut() {
129            return StreamMut::Unix(unix)
130        };
131
132        unreachable!()
133    }
134
135    /// Return the protocol associated with this connection.
136    pub fn get_protocol(&self) -> Proto {
137        self.protocol
138    }
139}