// auxon_sdk/mutation_plane_client/parent_connection.rs
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};

use thiserror::Error;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpSocket, TcpStream};
use tokio_rustls::client::TlsStream;
use tokio_rustls::TlsConnector;
use url::Url;

use crate::mutation_plane::protocol::{LeafwardsMessage, RootwardsMessage};

/// How the TLS layer of a `modality-mutation-tls` connection is configured.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum TlsMode {
    /// Use the `crate::tls::SECURE` client configuration.
    Secure,
    /// Use the `crate::tls::INSECURE` client configuration; chosen only when the caller
    /// passes `allow_insecure_tls = true`.
    Insecure,
}

16pub enum MutationParentConnection {
17 Tcp(TcpStream),
18 Tls(TlsStream<TcpStream>),
19}
20
21impl MutationParentConnection {
22 pub async fn connect(
23 endpoint: &Url,
24 allow_insecure_tls: bool,
25 ) -> Result<MutationParentConnection, MutationParentClientInitializationError> {
26 let endpoint = RootwardsEndpoint::parse_and_resolve(endpoint, allow_insecure_tls).await?;
27
28 let remote_addr = endpoint
30 .addrs
31 .into_iter()
32 .next()
33 .ok_or(MutationParentClientInitializationError::NoIps)?;
34
35 let local_addr: SocketAddr = if remote_addr.is_ipv4() {
36 "0.0.0.0:0"
37 } else {
38 "[::]:0"
39 }
40 .parse()?;
41
42 let socket = if remote_addr.is_ipv4() {
43 TcpSocket::new_v4().map_err(MutationParentClientInitializationError::SocketInit)?
44 } else {
45 TcpSocket::new_v6().map_err(MutationParentClientInitializationError::SocketInit)?
46 };
47
48 socket
49 .bind(local_addr)
50 .map_err(MutationParentClientInitializationError::SocketInit)?;
51 let stream = socket.connect(remote_addr).await.map_err(|error| {
52 MutationParentClientInitializationError::SocketConnection { error, remote_addr }
53 })?;
54
55 if let Some(tls_mode) = endpoint.tls_mode {
56 let config = match tls_mode {
57 TlsMode::Secure => crate::tls::SECURE.clone(),
58 TlsMode::Insecure => crate::tls::INSECURE.clone(),
59 };
60 let cx = TlsConnector::from(config);
61 let stream = cx.connect(endpoint.cert_domain.try_into()?, stream).await?;
62 Ok(MutationParentConnection::Tls(stream))
63 } else {
64 Ok(MutationParentConnection::Tcp(stream))
65 }
66 }
67
68 pub async fn write_msg(&mut self, msg: &RootwardsMessage) -> Result<(), CommsError> {
69 let msg_buf = minicbor::to_vec(msg)?;
70 let msg_len = msg_buf.len() as u32;
71
72 match self {
73 MutationParentConnection::Tcp(s) => {
74 s.write_all(&msg_len.to_be_bytes())
75 .await
76 .map_err(minicbor::encode::Error::Write)?;
77 s.write_all(&msg_buf)
78 .await
79 .map_err(minicbor::encode::Error::Write)?;
80 }
81 MutationParentConnection::Tls(s) => {
82 s.write_all(&msg_len.to_be_bytes())
84 .await
85 .map_err(minicbor::encode::Error::Write)?;
86 s.write_all(&msg_buf)
87 .await
88 .map_err(minicbor::encode::Error::Write)?;
89 }
90 }
91
92 Ok(())
93 }
94
95 pub async fn read_msg(&mut self) -> Result<LeafwardsMessage, CommsError> {
96 match self {
97 MutationParentConnection::Tcp(s) => {
98 let msg_len = s.read_u32().await?; let mut msg_buf = vec![0u8; msg_len as usize];
100 s.read_exact(msg_buf.as_mut_slice()).await?;
101
102 Ok(minicbor::decode::<LeafwardsMessage>(&msg_buf)?)
103 }
104 MutationParentConnection::Tls(s) => {
105 let msg_len = s.read_u32().await?; let mut msg_buf = vec![0u8; msg_len as usize];
107 s.read_exact(msg_buf.as_mut_slice()).await?;
108
109 Ok(minicbor::decode::<LeafwardsMessage>(&msg_buf)?)
110 }
111 }
112 }
113}
/// Port used for `modality-mutation://` endpoints whose URL does not specify one.
pub const MODALITY_MUTATION_CONNECT_PORT_DEFAULT: u16 = 14_192;
/// Port used for `modality-mutation-tls://` endpoints whose URL does not specify one.
pub const MODALITY_MUTATION_CONNECT_TLS_PORT_DEFAULT: u16 = 14_194;

/// URL scheme selecting a plaintext TCP connection to the mutation plane parent.
pub const MODALITY_MUTATION_URL_SCHEME: &str = "modality-mutation";
/// URL scheme selecting a TLS connection to the mutation plane parent.
pub const MODALITY_MUTATION_TLS_URL_SCHEME: &str = "modality-mutation-tls";

120struct RootwardsEndpoint {
121 cert_domain: String,
122 addrs: Vec<SocketAddr>,
123 tls_mode: Option<TlsMode>,
124}
125
126impl RootwardsEndpoint {
127 async fn parse_and_resolve(
128 url: &Url,
129 allow_insecure_tls: bool,
130 ) -> Result<RootwardsEndpoint, ParseRootwardsEndpointError> {
131 let host = match url.host() {
132 Some(h) => h,
133 None => return Err(ParseRootwardsEndpointError::MissingHost),
134 };
135
136 let is_tls = match url.scheme() {
137 MODALITY_MUTATION_URL_SCHEME => false,
138 MODALITY_MUTATION_TLS_URL_SCHEME => true,
139 s => return Err(ParseRootwardsEndpointError::InvalidScheme(s.to_string())),
140 };
141 let port = match url.port() {
142 Some(p) => p,
143 _ => {
144 if is_tls {
145 MODALITY_MUTATION_CONNECT_TLS_PORT_DEFAULT
146 } else {
147 MODALITY_MUTATION_CONNECT_PORT_DEFAULT
148 }
149 }
150 };
151
152 let addrs = match host {
153 url::Host::Domain(domain) => tokio::net::lookup_host((domain, port)).await?.collect(),
154 url::Host::Ipv4(addr) => vec![SocketAddr::from((addr, port))],
155 url::Host::Ipv6(addr) => vec![SocketAddr::from((addr, port))],
156 };
157
158 let tls_mode = match (is_tls, allow_insecure_tls) {
159 (true, true) => Some(TlsMode::Insecure),
160 (true, false) => Some(TlsMode::Secure),
161 (false, _) => None,
162 };
163
164 Ok(RootwardsEndpoint {
165 cert_domain: host.to_string(),
166 addrs,
167 tls_mode,
168 })
169 }
170}
171
172#[derive(Debug, Error)]
173pub enum MutationParentClientInitializationError {
174 #[error("DNS Error: No IPs")]
175 NoIps,
176
177 #[error("Socket initialization error")]
178 SocketInit(#[source] std::io::Error),
179
180 #[error("Socket connection error. Remote address: {}", remote_addr)]
181 SocketConnection {
182 #[source]
183 error: std::io::Error,
184 remote_addr: SocketAddr,
185 },
186
187 #[error(transparent)]
188 InvalidDnsName(#[from] tokio_rustls::rustls::pki_types::InvalidDnsNameError),
189
190 #[error(transparent)]
191 Io(#[from] std::io::Error),
192
193 #[error("Client local address parsing failed.")]
194 ClientLocalAddrParse(#[from] std::net::AddrParseError),
195
196 #[error("Error parsing endpoint")]
197 ParseIngestEndpoint(#[from] ParseRootwardsEndpointError),
198
199 #[error("Mutation plane authentication failed: {0}")]
200 AuthenticationFailed(String),
201
202 #[error("Mutation plane auth outcome received for a different participant")]
203 AuthWrongParticipant,
204
205 #[error("Unexpected auth response")]
206 UnexpectedAuthResponse,
207
208 #[error(transparent)]
209 CommsError(#[from] CommsError),
210}
211
212#[derive(Debug, Error)]
213pub enum CommsError {
214 #[error("Marshalling Error (Write)")]
215 CborEncode(#[from] minicbor::encode::Error<std::io::Error>),
216
217 #[error("Marshalling Error (Read)")]
218 CborDecode(#[from] minicbor::decode::Error),
219
220 #[error("IO")]
221 Io(#[from] std::io::Error),
222}
223
224#[derive(Debug, Error)]
225pub enum ParseRootwardsEndpointError {
226 #[error("Url most contain a host")]
227 MissingHost,
228
229 #[error(
231 "Invalid URL scheme '{0}'. Must be one of 'modality-mutation' or 'modality-mutation-tls'"
232 )]
233 InvalidScheme(String),
234
235 #[error("IO Error")]
236 Io(#[from] std::io::Error),
237}