auxon_sdk/mutation_plane_client/
parent_connection.rs

1use crate::mutation_plane::protocol::{LeafwardsMessage, RootwardsMessage};
2use std::net::SocketAddr;
3use thiserror::Error;
4use tokio::io::{AsyncReadExt, AsyncWriteExt};
5use tokio::net::{TcpSocket, TcpStream};
6use tokio_rustls::client::TlsStream;
7use tokio_rustls::TlsConnector;
8use url::Url;
9
/// Which TLS client configuration to use for a `modality-mutation-tls` connection.
///
/// `Secure` selects `crate::tls::SECURE`; `Insecure` selects `crate::tls::INSECURE`.
/// The insecure mode is only chosen when the caller explicitly allows it.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum TlsMode {
    /// Use the secure client configuration.
    Secure,
    /// Use the insecure client configuration (opt-in via `allow_insecure_tls`).
    Insecure,
}
15
16pub enum MutationParentConnection {
17    Tcp(TcpStream),
18    Tls(TlsStream<TcpStream>),
19}
20
21impl MutationParentConnection {
22    pub async fn connect(
23        endpoint: &Url,
24        allow_insecure_tls: bool,
25    ) -> Result<MutationParentConnection, MutationParentClientInitializationError> {
26        let endpoint = RootwardsEndpoint::parse_and_resolve(endpoint, allow_insecure_tls).await?;
27
28        // take the first addr, arbitrarily
29        let remote_addr = endpoint
30            .addrs
31            .into_iter()
32            .next()
33            .ok_or(MutationParentClientInitializationError::NoIps)?;
34
35        let local_addr: SocketAddr = if remote_addr.is_ipv4() {
36            "0.0.0.0:0"
37        } else {
38            "[::]:0"
39        }
40        .parse()?;
41
42        let socket = if remote_addr.is_ipv4() {
43            TcpSocket::new_v4().map_err(MutationParentClientInitializationError::SocketInit)?
44        } else {
45            TcpSocket::new_v6().map_err(MutationParentClientInitializationError::SocketInit)?
46        };
47
48        socket
49            .bind(local_addr)
50            .map_err(MutationParentClientInitializationError::SocketInit)?;
51        let stream = socket.connect(remote_addr).await.map_err(|error| {
52            MutationParentClientInitializationError::SocketConnection { error, remote_addr }
53        })?;
54
55        if let Some(tls_mode) = endpoint.tls_mode {
56            let config = match tls_mode {
57                TlsMode::Secure => crate::tls::SECURE.clone(),
58                TlsMode::Insecure => crate::tls::INSECURE.clone(),
59            };
60            let cx = TlsConnector::from(config);
61            let stream = cx.connect(endpoint.cert_domain.try_into()?, stream).await?;
62            Ok(MutationParentConnection::Tls(stream))
63        } else {
64            Ok(MutationParentConnection::Tcp(stream))
65        }
66    }
67
68    pub async fn write_msg(&mut self, msg: &RootwardsMessage) -> Result<(), CommsError> {
69        let msg_buf = minicbor::to_vec(msg)?;
70        let msg_len = msg_buf.len() as u32;
71
72        match self {
73            MutationParentConnection::Tcp(s) => {
74                s.write_all(&msg_len.to_be_bytes())
75                    .await
76                    .map_err(minicbor::encode::Error::Write)?;
77                s.write_all(&msg_buf)
78                    .await
79                    .map_err(minicbor::encode::Error::Write)?;
80            }
81            MutationParentConnection::Tls(s) => {
82                // We have to use write_all here, because https://github.com/tokio-rs/tls/issues/41
83                s.write_all(&msg_len.to_be_bytes())
84                    .await
85                    .map_err(minicbor::encode::Error::Write)?;
86                s.write_all(&msg_buf)
87                    .await
88                    .map_err(minicbor::encode::Error::Write)?;
89            }
90        }
91
92        Ok(())
93    }
94
95    pub async fn read_msg(&mut self) -> Result<LeafwardsMessage, CommsError> {
96        match self {
97            MutationParentConnection::Tcp(s) => {
98                let msg_len = s.read_u32().await?; // yes, this is big-endian
99                let mut msg_buf = vec![0u8; msg_len as usize];
100                s.read_exact(msg_buf.as_mut_slice()).await?;
101
102                Ok(minicbor::decode::<LeafwardsMessage>(&msg_buf)?)
103            }
104            MutationParentConnection::Tls(s) => {
105                let msg_len = s.read_u32().await?; // yes, this is big-endian
106                let mut msg_buf = vec![0u8; msg_len as usize];
107                s.read_exact(msg_buf.as_mut_slice()).await?;
108
109                Ok(minicbor::decode::<LeafwardsMessage>(&msg_buf)?)
110            }
111        }
112    }
113}
/// Port used for plain-TCP mutation-plane connections when the URL specifies none.
pub const MODALITY_MUTATION_CONNECT_PORT_DEFAULT: u16 = 14_192;
/// Port used for TLS mutation-plane connections when the URL specifies none.
pub const MODALITY_MUTATION_CONNECT_TLS_PORT_DEFAULT: u16 = 14_194;

/// URL scheme that selects a plain-TCP connection to the mutation-plane parent.
pub const MODALITY_MUTATION_URL_SCHEME: &str = "modality-mutation";
/// URL scheme that selects a TLS connection to the mutation-plane parent.
pub const MODALITY_MUTATION_TLS_URL_SCHEME: &str = "modality-mutation-tls";
119
120struct RootwardsEndpoint {
121    cert_domain: String,
122    addrs: Vec<SocketAddr>,
123    tls_mode: Option<TlsMode>,
124}
125
126impl RootwardsEndpoint {
127    async fn parse_and_resolve(
128        url: &Url,
129        allow_insecure_tls: bool,
130    ) -> Result<RootwardsEndpoint, ParseRootwardsEndpointError> {
131        let host = match url.host() {
132            Some(h) => h,
133            None => return Err(ParseRootwardsEndpointError::MissingHost),
134        };
135
136        let is_tls = match url.scheme() {
137            MODALITY_MUTATION_URL_SCHEME => false,
138            MODALITY_MUTATION_TLS_URL_SCHEME => true,
139            s => return Err(ParseRootwardsEndpointError::InvalidScheme(s.to_string())),
140        };
141        let port = match url.port() {
142            Some(p) => p,
143            _ => {
144                if is_tls {
145                    MODALITY_MUTATION_CONNECT_TLS_PORT_DEFAULT
146                } else {
147                    MODALITY_MUTATION_CONNECT_PORT_DEFAULT
148                }
149            }
150        };
151
152        let addrs = match host {
153            url::Host::Domain(domain) => tokio::net::lookup_host((domain, port)).await?.collect(),
154            url::Host::Ipv4(addr) => vec![SocketAddr::from((addr, port))],
155            url::Host::Ipv6(addr) => vec![SocketAddr::from((addr, port))],
156        };
157
158        let tls_mode = match (is_tls, allow_insecure_tls) {
159            (true, true) => Some(TlsMode::Insecure),
160            (true, false) => Some(TlsMode::Secure),
161            (false, _) => None,
162        };
163
164        Ok(RootwardsEndpoint {
165            cert_domain: host.to_string(),
166            addrs,
167            tls_mode,
168        })
169    }
170}
171
172#[derive(Debug, Error)]
173pub enum MutationParentClientInitializationError {
174    #[error("DNS Error: No IPs")]
175    NoIps,
176
177    #[error("Socket initialization error")]
178    SocketInit(#[source] std::io::Error),
179
180    #[error("Socket connection error. Remote address: {}", remote_addr)]
181    SocketConnection {
182        #[source]
183        error: std::io::Error,
184        remote_addr: SocketAddr,
185    },
186
187    #[error(transparent)]
188    InvalidDnsName(#[from] tokio_rustls::rustls::pki_types::InvalidDnsNameError),
189
190    #[error(transparent)]
191    Io(#[from] std::io::Error),
192
193    #[error("Client local address parsing failed.")]
194    ClientLocalAddrParse(#[from] std::net::AddrParseError),
195
196    #[error("Error parsing endpoint")]
197    ParseIngestEndpoint(#[from] ParseRootwardsEndpointError),
198
199    #[error("Mutation plane authentication failed: {0}")]
200    AuthenticationFailed(String),
201
202    #[error("Mutation plane auth outcome received for a different participant")]
203    AuthWrongParticipant,
204
205    #[error("Unexpected auth response")]
206    UnexpectedAuthResponse,
207
208    #[error(transparent)]
209    CommsError(#[from] CommsError),
210}
211
212#[derive(Debug, Error)]
213pub enum CommsError {
214    #[error("Marshalling Error (Write)")]
215    CborEncode(#[from] minicbor::encode::Error<std::io::Error>),
216
217    #[error("Marshalling Error (Read)")]
218    CborDecode(#[from] minicbor::decode::Error),
219
220    #[error("IO")]
221    Io(#[from] std::io::Error),
222}
223
224#[derive(Debug, Error)]
225pub enum ParseRootwardsEndpointError {
226    #[error("Url most contain a host")]
227    MissingHost,
228
229    // TODO update with the real thing
230    #[error(
231        "Invalid URL scheme '{0}'. Must be one of 'modality-mutation' or 'modality-mutation-tls'"
232    )]
233    InvalidScheme(String),
234
235    #[error("IO Error")]
236    Io(#[from] std::io::Error),
237}