Skip to main content

couchbase_core/memdx/
connection.rs

/*
 *
 *  * Copyright (c) 2025 Couchbase, Inc.
 *  *
 *  * Licensed under the Apache License, Version 2.0 (the "License");
 *  * you may not use this file except in compliance with the License.
 *  * You may obtain a copy of the License at
 *  *
 *  *    http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS,
 *  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  * See the License for the specific language governing permissions and
 *  * limitations under the License.
 *
 */
18
19use crate::memdx::error::Error;
20use crate::memdx::error::Result;
21use crate::tls_config::TlsConfig;
22use socket2::TcpKeepalive;
23use std::fmt::Debug;
24use std::io;
25use std::net::{IpAddr, Ipv4Addr, SocketAddr};
26use std::time::Duration;
27use tokio::io::{AsyncRead, AsyncWrite};
28use tokio::net::TcpStream;
29use tokio::time::{timeout_at, Instant};
30
31use crate::address::Address;
32#[cfg(all(feature = "rustls-tls", not(feature = "native-tls")))]
33use {
34    tokio_rustls::rustls::pki_types::DnsName, tokio_rustls::rustls::pki_types::ServerName,
35    tokio_rustls::TlsConnector,
36};
37
38#[derive(Debug)]
39pub struct ConnectOptions {
40    pub deadline: Instant,
41    pub tcp_keep_alive_time: Duration,
42}
43
44pub trait Stream: Debug + AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static {}
45
46impl Stream for TcpStream {}
47
48#[derive(Debug)]
49pub enum ConnectionType {
50    Tcp(TcpConnection),
51    Tls(TlsConnection),
52}
53
54impl ConnectionType {
55    pub fn into_inner(self) -> Box<dyn Stream> {
56        match self {
57            ConnectionType::Tcp(connection) => Box::new(connection.stream),
58            ConnectionType::Tls(connection) => Box::new(connection.stream),
59        }
60    }
61
62    pub fn local_addr(&self) -> &SocketAddr {
63        match self {
64            ConnectionType::Tcp(connection) => &connection.local_addr,
65            ConnectionType::Tls(connection) => &connection.local_addr,
66        }
67    }
68
69    pub fn peer_addr(&self) -> &SocketAddr {
70        match self {
71            ConnectionType::Tcp(connection) => &connection.peer_addr,
72            ConnectionType::Tls(connection) => &connection.peer_addr,
73        }
74    }
75}
76
77#[derive(Debug)]
78pub struct TcpConnection {
79    stream: TcpStream,
80
81    local_addr: SocketAddr,
82    peer_addr: SocketAddr,
83}
84
85impl TcpConnection {
86    async fn tcp_stream(
87        addr: &str,
88        opts: &ConnectOptions,
89    ) -> Result<(TcpStream, SocketAddr, SocketAddr)> {
90        let tcp_socket = timeout_at(opts.deadline, TcpStream::connect(addr))
91            .await
92            .map_err(|e| {
93                Error::new_connection_failed_error(
94                    "failed to connect to server within timeout",
95                    Box::new(io::Error::new(io::ErrorKind::TimedOut, e)),
96                )
97            })?
98            .map_err(|e| {
99                Error::new_connection_failed_error("failed to create tcp stream", Box::new(e))
100            })?;
101
102        let local_addr = tcp_socket
103            .local_addr()
104            .unwrap_or_else(|_e| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0));
105
106        let peer_addr = tcp_socket
107            .peer_addr()
108            .unwrap_or_else(|_e| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0));
109
110        // Tokio doesn't expose a keep alive function, but they just call into socket2 for set_linger.
111        socket2::SockRef::from(&tcp_socket)
112            .set_tcp_keepalive(&TcpKeepalive::new().with_time(opts.tcp_keep_alive_time))?;
113
114        tcp_socket.set_nodelay(false).map_err(|e| {
115            Error::new_connection_failed_error("failed to set tcp nodelay", Box::new(e))
116        })?;
117
118        Ok((tcp_socket, local_addr, peer_addr))
119    }
120
121    pub async fn connect(addr: Address, opts: ConnectOptions) -> Result<TcpConnection> {
122        let (stream, local_addr, peer_addr) =
123            TcpConnection::tcp_stream(addr.to_string().as_str(), &opts).await?;
124
125        Ok(TcpConnection {
126            stream,
127            local_addr,
128            peer_addr,
129        })
130    }
131
132    fn local_addr(&self) -> &SocketAddr {
133        &self.local_addr
134    }
135
136    fn peer_addr(&self) -> &SocketAddr {
137        &self.peer_addr
138    }
139}
140
/// A TLS connection layered over TCP, plus the addresses captured for the underlying socket.
///
/// The concrete stream type depends on the enabled TLS backend. If both `native-tls` and
/// `rustls-tls` are enabled, `native-tls` wins.
#[derive(Debug)]
pub struct TlsConnection {
    // rustls backend, selected only when native-tls is not also enabled.
    #[cfg(all(not(feature = "native-tls"), feature = "rustls-tls"))]
    stream: tokio_rustls::client::TlsStream<tokio::net::TcpStream>,
    // native-tls backend.
    #[cfg(feature = "native-tls")]
    stream: tokio_native_tls::TlsStream<tokio::net::TcpStream>,

    local_addr: SocketAddr,
    peer_addr: SocketAddr,
}

// Each backend's stream must be usable as a `Box<dyn Stream>`.
#[cfg(all(not(feature = "native-tls"), feature = "rustls-tls"))]
impl Stream for tokio_rustls::client::TlsStream<tokio::net::TcpStream> {}

#[cfg(feature = "native-tls")]
impl Stream for tokio_native_tls::TlsStream<tokio::net::TcpStream> {}
157
158impl TlsConnection {
159    #[cfg(all(feature = "rustls-tls", not(feature = "native-tls")))]
160    pub async fn connect(
161        addr: Address,
162        tls_config: TlsConfig,
163        opts: ConnectOptions,
164    ) -> Result<TlsConnection> {
165        let (tcp_socket, local_addr, peer_addr) =
166            TcpConnection::tcp_stream(addr.to_string().as_str(), &opts).await?;
167
168        let connector = TlsConnector::from(tls_config);
169
170        let server_name = match DnsName::try_from(addr.host) {
171            Ok(name) => ServerName::DnsName(name),
172            Err(_e) => ServerName::IpAddress(tokio_rustls::rustls::pki_types::IpAddr::from(
173                peer_addr.ip(),
174            )),
175        };
176
177        let stream = timeout_at(opts.deadline, connector.connect(server_name, tcp_socket))
178            .await
179            .map_err(|e| {
180                Error::new_connection_failed_error(
181                    "failed to upgrade tcp stream to tls within timeout",
182                    Box::new(io::Error::new(io::ErrorKind::TimedOut, e)),
183                )
184            })?
185            .map_err(|e| {
186                Error::new_connection_failed_error(
187                    "failed to upgrade tcp stream to tls",
188                    Box::new(e),
189                )
190            })?;
191
192        Ok(TlsConnection {
193            stream,
194            local_addr,
195            peer_addr,
196        })
197    }
198
199    #[cfg(feature = "native-tls")]
200    pub async fn connect(
201        addr: Address,
202        tls_config: TlsConfig,
203        opts: ConnectOptions,
204    ) -> Result<TlsConnection> {
205        let (tcp_socket, local_addr, peer_addr) =
206            TcpConnection::tcp_stream(addr.to_string().as_str(), &opts).await?;
207
208        let tls_connector = tokio_native_tls::TlsConnector::from(tls_config);
209
210        let remote_addr = addr.to_string();
211        let stream = timeout_at(
212            opts.deadline,
213            tls_connector.connect(&remote_addr, tcp_socket),
214        )
215        .await
216        .map_err(|e| {
217            Error::new_connection_failed_error(
218                "failed to upgrade tcp stream to tls within timeout",
219                Box::new(io::Error::new(io::ErrorKind::TimedOut, e)),
220            )
221        })?
222        .map_err(|e| {
223            Error::new_connection_failed_error(
224                "failed to upgrade tcp stream to tls",
225                Box::new(io::Error::other(e)),
226            )
227        })?;
228
229        Ok(TlsConnection {
230            stream,
231            local_addr,
232            peer_addr,
233        })
234    }
235
236    fn local_addr(&self) -> &SocketAddr {
237        &self.local_addr
238    }
239
240    fn peer_addr(&self) -> &SocketAddr {
241        &self.peer_addr
242    }
243}