// modality_mutation_plane_client/parent_connection.rs
use std::net::SocketAddr;

use modality_mutation_plane::protocol::{LeafwardsMessage, RootwardsMessage};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpSocket, TcpStream};
use tokio_native_tls::TlsStream;
use url::Url;

/// How strictly the TLS handshake with the mutation parent validates the
/// server certificate.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum TlsMode {
    /// Normal certificate validation; invalid certificates are rejected.
    Secure,
    /// Certificate validation is disabled (`danger_accept_invalid_certs(true)`),
    /// so invalid certificates are accepted.
    Insecure,
}

15pub enum MutationParentConnection {
16 Tcp(TcpStream),
17 Tls(TlsStream<TcpStream>),
18}
19
20impl MutationParentConnection {
21 pub async fn connect(
22 endpoint: &Url,
23 allow_insecure_tls: bool,
24 ) -> Result<MutationParentConnection, MutationParentClientInitializationError> {
25 let endpoint = RootwardsEndpoint::parse_and_resolve(endpoint, allow_insecure_tls).await?;
26
27 let remote_addr = endpoint
29 .addrs
30 .into_iter()
31 .next()
32 .ok_or(MutationParentClientInitializationError::NoIps)?;
33
34 let local_addr: SocketAddr = if remote_addr.is_ipv4() {
35 "0.0.0.0:0"
36 } else {
37 "[::]:0"
38 }
39 .parse()?;
40
41 let socket = if remote_addr.is_ipv4() {
42 TcpSocket::new_v4().map_err(MutationParentClientInitializationError::SocketInit)?
43 } else {
44 TcpSocket::new_v6().map_err(MutationParentClientInitializationError::SocketInit)?
45 };
46
47 socket
48 .bind(local_addr)
49 .map_err(MutationParentClientInitializationError::SocketInit)?;
50 let stream = socket.connect(remote_addr).await.map_err(|error| {
51 MutationParentClientInitializationError::SocketConnection { error, remote_addr }
52 })?;
53
54 if let Some(tls_mode) = endpoint.tls_mode {
55 let cx = native_tls::TlsConnector::builder()
56 .danger_accept_invalid_certs(match tls_mode {
57 TlsMode::Secure => false,
58 TlsMode::Insecure => true,
59 })
60 .build()?;
61 let cx = tokio_native_tls::TlsConnector::from(cx);
62 let stream = cx.connect(&endpoint.cert_domain, stream).await?;
63 Ok(MutationParentConnection::Tls(stream))
64 } else {
65 Ok(MutationParentConnection::Tcp(stream))
66 }
67 }
68
69 pub async fn write_msg(&mut self, msg: &RootwardsMessage) -> Result<(), CommsError> {
70 let msg_buf = minicbor::to_vec(msg)?;
71 let msg_len = msg_buf.len() as u32;
72
73 match self {
74 MutationParentConnection::Tcp(s) => {
75 s.write_all(&msg_len.to_be_bytes())
76 .await
77 .map_err(minicbor::encode::Error::Write)?;
78 s.write_all(&msg_buf)
79 .await
80 .map_err(minicbor::encode::Error::Write)?;
81 }
82 MutationParentConnection::Tls(s) => {
83 s.write_all(&msg_len.to_be_bytes())
85 .await
86 .map_err(minicbor::encode::Error::Write)?;
87 s.write_all(&msg_buf)
88 .await
89 .map_err(minicbor::encode::Error::Write)?;
90 }
91 }
92
93 Ok(())
94 }
95
96 pub async fn read_msg(&mut self) -> Result<LeafwardsMessage, CommsError> {
97 match self {
98 MutationParentConnection::Tcp(s) => {
99 let msg_len = s.read_u32().await?; let mut msg_buf = vec![0u8; msg_len as usize];
101 s.read_exact(msg_buf.as_mut_slice()).await?;
102
103 Ok(minicbor::decode::<LeafwardsMessage>(&msg_buf)?)
104 }
105 MutationParentConnection::Tls(s) => {
106 let msg_len = s.read_u32().await?; let mut msg_buf = vec![0u8; msg_len as usize];
108 s.read_exact(msg_buf.as_mut_slice()).await?;
109
110 Ok(minicbor::decode::<LeafwardsMessage>(&msg_buf)?)
111 }
112 }
113 }
114}
/// Port used for `modality-mutation://` URLs that do not specify one.
pub const MODALITY_MUTATION_CONNECT_PORT_DEFAULT: u16 = 14_192;
/// Port used for `modality-mutation-tls://` URLs that do not specify one.
pub const MODALITY_MUTATION_CONNECT_TLS_PORT_DEFAULT: u16 = 14_194;

/// URL scheme selecting a plain (non-TLS) TCP connection to the parent.
pub const MODALITY_MUTATION_URL_SCHEME: &str = "modality-mutation";
/// URL scheme selecting a TLS connection to the parent.
pub const MODALITY_MUTATION_TLS_URL_SCHEME: &str = "modality-mutation-tls";

121struct RootwardsEndpoint {
122 cert_domain: String,
123 addrs: Vec<SocketAddr>,
124 tls_mode: Option<TlsMode>,
125}
126
127impl RootwardsEndpoint {
128 async fn parse_and_resolve(
129 url: &Url,
130 allow_insecure_tls: bool,
131 ) -> Result<RootwardsEndpoint, ParseRootwardsEndpointError> {
132 let host = match url.host() {
133 Some(h) => h,
134 None => return Err(ParseRootwardsEndpointError::MissingHost),
135 };
136
137 let is_tls = match url.scheme() {
138 MODALITY_MUTATION_URL_SCHEME => false,
139 MODALITY_MUTATION_TLS_URL_SCHEME => true,
140 s => return Err(ParseRootwardsEndpointError::InvalidScheme(s.to_string())),
141 };
142 let port = match url.port() {
143 Some(p) => p,
144 _ => {
145 if is_tls {
146 MODALITY_MUTATION_CONNECT_TLS_PORT_DEFAULT
147 } else {
148 MODALITY_MUTATION_CONNECT_PORT_DEFAULT
149 }
150 }
151 };
152
153 let addrs = match host {
154 url::Host::Domain(domain) => tokio::net::lookup_host((domain, port)).await?.collect(),
155 url::Host::Ipv4(addr) => vec![SocketAddr::from((addr, port))],
156 url::Host::Ipv6(addr) => vec![SocketAddr::from((addr, port))],
157 };
158
159 let tls_mode = match (is_tls, allow_insecure_tls) {
160 (true, true) => Some(TlsMode::Insecure),
161 (true, false) => Some(TlsMode::Secure),
162 (false, _) => None,
163 };
164
165 Ok(RootwardsEndpoint {
166 cert_domain: host.to_string(),
167 addrs,
168 tls_mode,
169 })
170 }
171}
172
173#[derive(Debug, Error)]
174pub enum MutationParentClientInitializationError {
175 #[error("DNS Error: No IPs")]
176 NoIps,
177
178 #[error("Socket initialization error")]
179 SocketInit(#[source] std::io::Error),
180
181 #[error("Socket connection error. Remote address: {}", remote_addr)]
182 SocketConnection {
183 #[source]
184 error: std::io::Error,
185 remote_addr: SocketAddr,
186 },
187
188 #[error("TLS Error")]
189 Tls(#[from] native_tls::Error),
190
191 #[error("Client local address parsing failed.")]
192 ClientLocalAddrParse(#[from] std::net::AddrParseError),
193
194 #[error("Error parsing endpoint")]
195 ParseIngestEndpoint(#[from] ParseRootwardsEndpointError),
196}
197
198#[derive(Debug, Error)]
199pub enum CommsError {
200 #[error("Marshalling Error (Write)")]
201 CborEncode(#[from] minicbor::encode::Error<std::io::Error>),
202
203 #[error("Marshalling Error (Read)")]
204 CborDecode(#[from] minicbor::decode::Error),
205
206 #[error("IO")]
207 Io(#[from] std::io::Error),
208}
209
210#[derive(Debug, Error)]
211pub enum ParseRootwardsEndpointError {
212 #[error("Url most contain a host")]
213 MissingHost,
214
215 #[error(
217 "Invalid URL scheme '{0}'. Must be one of 'modality-mutation' or 'modality-mutation-tls'"
218 )]
219 InvalidScheme(String),
220
221 #[error("IO Error")]
222 Io(#[from] std::io::Error),
223}