modality_mutation_plane_client/
parent_connection.rs

1use modality_mutation_plane::protocol::{LeafwardsMessage, RootwardsMessage};
2use std::net::SocketAddr;
3use thiserror::Error;
4use tokio::io::{AsyncReadExt, AsyncWriteExt};
5use tokio::net::{TcpSocket, TcpStream};
6use tokio_native_tls::TlsStream;
7use url::Url;
8
/// Certificate validation policy for a TLS connection to the mutation plane parent.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum TlsMode {
    /// Reject invalid server certificates.
    Secure,
    /// Accept invalid server certificates (validation is skipped; unsafe on untrusted networks).
    Insecure,
}
14
15pub enum MutationParentConnection {
16    Tcp(TcpStream),
17    Tls(TlsStream<TcpStream>),
18}
19
20impl MutationParentConnection {
21    pub async fn connect(
22        endpoint: &Url,
23        allow_insecure_tls: bool,
24    ) -> Result<MutationParentConnection, MutationParentClientInitializationError> {
25        let endpoint = RootwardsEndpoint::parse_and_resolve(endpoint, allow_insecure_tls).await?;
26
27        // take the first addr, arbitrarily
28        let remote_addr = endpoint
29            .addrs
30            .into_iter()
31            .next()
32            .ok_or(MutationParentClientInitializationError::NoIps)?;
33
34        let local_addr: SocketAddr = if remote_addr.is_ipv4() {
35            "0.0.0.0:0"
36        } else {
37            "[::]:0"
38        }
39        .parse()?;
40
41        let socket = if remote_addr.is_ipv4() {
42            TcpSocket::new_v4().map_err(MutationParentClientInitializationError::SocketInit)?
43        } else {
44            TcpSocket::new_v6().map_err(MutationParentClientInitializationError::SocketInit)?
45        };
46
47        socket
48            .bind(local_addr)
49            .map_err(MutationParentClientInitializationError::SocketInit)?;
50        let stream = socket.connect(remote_addr).await.map_err(|error| {
51            MutationParentClientInitializationError::SocketConnection { error, remote_addr }
52        })?;
53
54        if let Some(tls_mode) = endpoint.tls_mode {
55            let cx = native_tls::TlsConnector::builder()
56                .danger_accept_invalid_certs(match tls_mode {
57                    TlsMode::Secure => false,
58                    TlsMode::Insecure => true,
59                })
60                .build()?;
61            let cx = tokio_native_tls::TlsConnector::from(cx);
62            let stream = cx.connect(&endpoint.cert_domain, stream).await?;
63            Ok(MutationParentConnection::Tls(stream))
64        } else {
65            Ok(MutationParentConnection::Tcp(stream))
66        }
67    }
68
69    pub async fn write_msg(&mut self, msg: &RootwardsMessage) -> Result<(), CommsError> {
70        let msg_buf = minicbor::to_vec(msg)?;
71        let msg_len = msg_buf.len() as u32;
72
73        match self {
74            MutationParentConnection::Tcp(s) => {
75                s.write_all(&msg_len.to_be_bytes())
76                    .await
77                    .map_err(minicbor::encode::Error::Write)?;
78                s.write_all(&msg_buf)
79                    .await
80                    .map_err(minicbor::encode::Error::Write)?;
81            }
82            MutationParentConnection::Tls(s) => {
83                // We have to use write_all here, because https://github.com/tokio-rs/tls/issues/41
84                s.write_all(&msg_len.to_be_bytes())
85                    .await
86                    .map_err(minicbor::encode::Error::Write)?;
87                s.write_all(&msg_buf)
88                    .await
89                    .map_err(minicbor::encode::Error::Write)?;
90            }
91        }
92
93        Ok(())
94    }
95
96    pub async fn read_msg(&mut self) -> Result<LeafwardsMessage, CommsError> {
97        match self {
98            MutationParentConnection::Tcp(s) => {
99                let msg_len = s.read_u32().await?; // yes, this is big-endian
100                let mut msg_buf = vec![0u8; msg_len as usize];
101                s.read_exact(msg_buf.as_mut_slice()).await?;
102
103                Ok(minicbor::decode::<LeafwardsMessage>(&msg_buf)?)
104            }
105            MutationParentConnection::Tls(s) => {
106                let msg_len = s.read_u32().await?; // yes, this is big-endian
107                let mut msg_buf = vec![0u8; msg_len as usize];
108                s.read_exact(msg_buf.as_mut_slice()).await?;
109
110                Ok(minicbor::decode::<LeafwardsMessage>(&msg_buf)?)
111            }
112        }
113    }
114}
/// Port used for `modality-mutation` URLs that do not specify one.
pub const MODALITY_MUTATION_CONNECT_PORT_DEFAULT: u16 = 14_192;
/// Port used for `modality-mutation-tls` URLs that do not specify one.
pub const MODALITY_MUTATION_CONNECT_TLS_PORT_DEFAULT: u16 = 14_194;

/// URL scheme selecting a plaintext TCP connection to the parent.
pub const MODALITY_MUTATION_URL_SCHEME: &str = "modality-mutation";
/// URL scheme selecting a TLS connection to the parent.
pub const MODALITY_MUTATION_TLS_URL_SCHEME: &str = "modality-mutation-tls";
120
121struct RootwardsEndpoint {
122    cert_domain: String,
123    addrs: Vec<SocketAddr>,
124    tls_mode: Option<TlsMode>,
125}
126
127impl RootwardsEndpoint {
128    async fn parse_and_resolve(
129        url: &Url,
130        allow_insecure_tls: bool,
131    ) -> Result<RootwardsEndpoint, ParseRootwardsEndpointError> {
132        let host = match url.host() {
133            Some(h) => h,
134            None => return Err(ParseRootwardsEndpointError::MissingHost),
135        };
136
137        let is_tls = match url.scheme() {
138            MODALITY_MUTATION_URL_SCHEME => false,
139            MODALITY_MUTATION_TLS_URL_SCHEME => true,
140            s => return Err(ParseRootwardsEndpointError::InvalidScheme(s.to_string())),
141        };
142        let port = match url.port() {
143            Some(p) => p,
144            _ => {
145                if is_tls {
146                    MODALITY_MUTATION_CONNECT_TLS_PORT_DEFAULT
147                } else {
148                    MODALITY_MUTATION_CONNECT_PORT_DEFAULT
149                }
150            }
151        };
152
153        let addrs = match host {
154            url::Host::Domain(domain) => tokio::net::lookup_host((domain, port)).await?.collect(),
155            url::Host::Ipv4(addr) => vec![SocketAddr::from((addr, port))],
156            url::Host::Ipv6(addr) => vec![SocketAddr::from((addr, port))],
157        };
158
159        let tls_mode = match (is_tls, allow_insecure_tls) {
160            (true, true) => Some(TlsMode::Insecure),
161            (true, false) => Some(TlsMode::Secure),
162            (false, _) => None,
163        };
164
165        Ok(RootwardsEndpoint {
166            cert_domain: host.to_string(),
167            addrs,
168            tls_mode,
169        })
170    }
171}
172
173#[derive(Debug, Error)]
174pub enum MutationParentClientInitializationError {
175    #[error("DNS Error: No IPs")]
176    NoIps,
177
178    #[error("Socket initialization error")]
179    SocketInit(#[source] std::io::Error),
180
181    #[error("Socket connection error. Remote address: {}", remote_addr)]
182    SocketConnection {
183        #[source]
184        error: std::io::Error,
185        remote_addr: SocketAddr,
186    },
187
188    #[error("TLS Error")]
189    Tls(#[from] native_tls::Error),
190
191    #[error("Client local address parsing failed.")]
192    ClientLocalAddrParse(#[from] std::net::AddrParseError),
193
194    #[error("Error parsing endpoint")]
195    ParseIngestEndpoint(#[from] ParseRootwardsEndpointError),
196}
197
198#[derive(Debug, Error)]
199pub enum CommsError {
200    #[error("Marshalling Error (Write)")]
201    CborEncode(#[from] minicbor::encode::Error<std::io::Error>),
202
203    #[error("Marshalling Error (Read)")]
204    CborDecode(#[from] minicbor::decode::Error),
205
206    #[error("IO")]
207    Io(#[from] std::io::Error),
208}
209
210#[derive(Debug, Error)]
211pub enum ParseRootwardsEndpointError {
212    #[error("Url most contain a host")]
213    MissingHost,
214
215    // TODO update with the real thing
216    #[error(
217        "Invalid URL scheme '{0}'. Must be one of 'modality-mutation' or 'modality-mutation-tls'"
218    )]
219    InvalidScheme(String),
220
221    #[error("IO Error")]
222    Io(#[from] std::io::Error),
223}