ais_keystore_lib/shared/
conn.rs1use dusa_collection_utils::core::errors::{ErrorArrayItem, Errors};
4use simple_comms::protocol::proto::Proto;
5use tokio::net::{TcpStream, UnixStream};
6
7use super::consts::{PORT, SOCKETPATH};
8
9pub enum StreamMut<'a> {
11 Tcp(&'a mut Box<TcpStream>),
12 Unix(&'a mut Box<UnixStream>),
13}
14
15pub struct ConnectionStream {
17 pub addy: Option<String>,
18 pub protocol: Proto,
19 pub tcp_stream: Option<Box<TcpStream>>,
20 pub unix_stream: Option<Box<UnixStream>>,
21}
22
23impl ConnectionStream {
24 pub async fn new_tcp_connection(addy: String) -> Result<Self, ErrorArrayItem> {
26 let protocol = Proto::TCP;
27 let address = format!("{}:{}", addy, PORT);
28
29 let stream = TcpStream::connect(address)
30 .await
31 .map_err(ErrorArrayItem::from)?;
32
33 Ok(Self {
34 addy: Some(addy),
35 protocol,
36 tcp_stream: Some(Box::new(stream)),
37 unix_stream: None,
38 })
39 }
40
41 pub async fn new_unix_connection() -> Result<Self, ErrorArrayItem> {
43 let protocol = Proto::UNIX;
44
45 let stream = UnixStream::connect(SOCKETPATH)
46 .await
47 .map_err(ErrorArrayItem::from)?;
48
49 Ok(Self {
50 addy: None,
51 protocol,
52 tcp_stream: None,
53 unix_stream: Some(Box::new(stream)),
54 })
55 }
56
57 pub async fn ensure_connection(&mut self) -> Result<Self, ErrorArrayItem> {
59 let conn_proto: Proto = self.get_protocol();
60
61 match self.get_stream_mut() {
62 StreamMut::Tcp(tcp) => {
63 if tcp.peer_addr().is_ok() {
64 return Ok(Self {
65 addy: self.addy.clone(),
66 protocol: self.protocol,
67 tcp_stream: self.tcp_stream.take(),
68 unix_stream: self.unix_stream.take(),
69 });
70 }
71 }
72 StreamMut::Unix(unix) => {
73 if unix.peer_addr().is_ok() {
74 return Ok(Self {
75 addy: self.addy.clone(),
76 protocol: self.protocol,
77 tcp_stream: self.tcp_stream.take(),
78 unix_stream: self.unix_stream.take(),
79 });
80 }
81 }
82 }
83
84 match conn_proto {
85 Proto::TCP => {
86 if let Some(ref address) = self.addy {
87 let new_stream = TcpStream::connect(format!("{}:{}", address, PORT))
88 .await
89 .map_err(ErrorArrayItem::from)?;
90 self.tcp_stream = Some(Box::new(new_stream));
91 self.unix_stream = None;
92 Ok(Self {
93 addy: self.addy.clone(),
94 protocol: self.protocol,
95 tcp_stream: self.tcp_stream.take(),
96 unix_stream: self.unix_stream.take(),
97 })
98 } else {
99 Err(ErrorArrayItem::new(
100 Errors::Network,
101 "Missing address for TCP connection",
102 ))
103 }
104 }
105 Proto::UNIX => {
106 let new_stream = UnixStream::connect(SOCKETPATH)
107 .await
108 .map_err(ErrorArrayItem::from)?;
109 self.unix_stream = Some(Box::new(new_stream));
110 self.tcp_stream = None;
111 Ok(Self {
112 addy: self.addy.clone(),
113 protocol: self.protocol,
114 tcp_stream: self.tcp_stream.take(),
115 unix_stream: self.unix_stream.take(),
116 })
117 }
118 }
119 }
120
121 pub fn get_stream_mut(&mut self) -> StreamMut {
123
124 if let Some(tcp) = self.tcp_stream.as_mut() {
125 return StreamMut::Tcp(tcp)
126 };
127
128 if let Some(unix) = self.unix_stream.as_mut() {
129 return StreamMut::Unix(unix)
130 };
131
132 unreachable!()
133 }
134
135 pub fn get_protocol(&self) -> Proto {
137 self.protocol
138 }
139}