// mptcp/std.rs

1use std::{
2    io,
3    net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs},
4};
5
6use crate::{sys::MptcpSocketBuilder, MptcpExt, MptcpOpt, MptcpSocket};
7
8/// Extension trait for std::net::TcpStream to support MPTCP.
9pub trait MptcpStreamExt {
10    type Output;
11
12    /// Establishes an MPTCP connection with the given address and MptcpOpt.
13    ///
14    /// # Arguments
15    ///
16    /// * `addr` - The address to connect to.
17    /// * `opt` - The MptcpOpt options for the connection.
18    ///
19    /// # Returns
20    ///
21    /// Returns an `io::Result` containing the MptcpSocket if the connection is successful,
22    /// or an `io::Error` if an error occurs during the connection.
23    fn connect_mptcp_opt<A: ToSocketAddrs>(
24        addr: A,
25        opt: MptcpOpt,
26    ) -> io::Result<MptcpSocket<Self::Output>>;
27
28    /// Establishes an MPTCP connection with the given address. If MPTCP cannot be used
29    /// the connection will fallback to a regular TCP connection.
30    ///
31    /// # Arguments
32    ///
33    /// * `addr` - The address to connect to.
34    ///
35    /// # Returns
36    ///
37    /// Returns an `io::Result` containing the MptcpSocket if the connection is successful,
38    /// or an `io::Error` if an error occurs during the connection.
39    fn connect_mptcp<A: ToSocketAddrs>(addr: A) -> io::Result<MptcpSocket<Self::Output>> {
40        Self::connect_mptcp_opt(addr, MptcpOpt::Fallback)
41    }
42
43    /// Establishes an MPTCP connection with the given address. Returns an error even if
44    /// MPTCP cannot be used. See `connect_mptcp` for a version that falls back to TCP.
45    ///
46    /// # Arguments
47    ///
48    /// * `addr` - The address to connect to.
49    ///
50    /// # Returns
51    ///
52    /// Returns an `io::Result` containing the MptcpSocket if the connection is successful,
53    /// or an `io::Error` if an error occurs during the connection.
54    fn connect_mptcp_force<A: ToSocketAddrs>(addr: A) -> io::Result<Self::Output> {
55        Ok(Self::connect_mptcp_opt(addr, MptcpOpt::NoFallback)?.into_socket())
56    }
57}
58
59/// Extension trait for std::net::TcpListener.
60pub trait MptcpListenerExt {
61    type Output;
62
63    /// Binds an MPTCP socket to the specified address with the given MptcpOpt.
64    ///
65    /// # Arguments
66    ///
67    /// * `addr` - The address to bind the socket to.
68    /// * `opt` - The MptcpOpt to use for the socket.
69    ///
70    /// # Returns
71    ///
72    /// Returns an `io::Result` containing the MptcpSocket with the specified MptcpOpt.
73    fn bind_mptcp_opt<A: ToSocketAddrs>(
74        addr: A,
75        opt: MptcpOpt,
76    ) -> io::Result<MptcpSocket<Self::Output>>;
77
78    /// Binds an MPTCP socket to the specified address. If MPTCP cannot be used
79    /// the connection will fallback to a regular TCP connection.
80    ///
81    /// # Arguments
82    ///
83    /// * `addr` - The address to bind the socket to.
84    ///
85    /// # Returns
86    ///
87    /// Returns an `io::Result` containing the MptcpSocket with the default MptcpOpt (Fallback).
88    fn bind_mptcp<A: ToSocketAddrs>(addr: A) -> io::Result<MptcpSocket<Self::Output>> {
89        Self::bind_mptcp_opt(addr, MptcpOpt::Fallback)
90    }
91
92    /// Binds an MPTCP socket to the specified address. Returns an error even if
93    /// MPTCP cannot be used. See `bind_mptcp` for a version that falls back to TCP.
94    ///
95    /// # Arguments
96    ///
97    /// * `addr` - The address to bind the socket to.
98    ///
99    /// # Returns
100    ///
101    /// Returns an `io::Result` containing the MptcpSocket with the MptcpOpt set to NoFallback.
102    fn bind_mptcp_force<A: ToSocketAddrs>(addr: A) -> io::Result<Self::Output> {
103        Ok(Self::bind_mptcp_opt(addr, MptcpOpt::NoFallback)?.into_socket())
104    }
105}
106
/// Resolves `addr` and calls `f` on each resulting address, in resolution order,
/// stopping at the first call that succeeds.
///
/// # Returns
///
/// The value produced by the first successful call to `f`.
///
/// # Errors
///
/// * The resolution error, if `addr` cannot be resolved.
/// * The error from the last attempt, if every call to `f` failed.
/// * An `InvalidInput` error, if `addr` resolved to no addresses at all.
fn resolve_each_addr<A: ToSocketAddrs, F, T>(addr: &A, mut f: F) -> io::Result<T>
where
    F: FnMut(SocketAddr) -> io::Result<T>,
{
    let mut failure: Option<io::Error> = None;

    for candidate in addr.to_socket_addrs()? {
        let error = match f(candidate) {
            Ok(value) => return Ok(value),
            Err(error) => error,
        };
        // Only the most recent failure is kept and reported.
        failure = Some(error);
    }

    match failure {
        Some(error) => Err(error),
        None => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "could not resolve to any address",
        )),
    }
}
126
127impl MptcpStreamExt for TcpStream {
128    type Output = Self;
129
130    fn connect_mptcp_opt<A: ToSocketAddrs>(
131        addr: A,
132        opt: MptcpOpt,
133    ) -> io::Result<MptcpSocket<Self::Output>> {
134        match resolve_each_addr(&addr, |addr| {
135            MptcpSocketBuilder::new_for_addr(addr)?.connect(addr)
136        }) {
137            Ok(sock) => Ok(MptcpSocket::Mptcp(sock.into())),
138            Err(_) if matches!(opt, MptcpOpt::Fallback) => {
139                Ok(MptcpSocket::Tcp(Self::connect(addr)?))
140            }
141            Err(err) => Err(err),
142        }
143    }
144}
145
146impl MptcpExt for TcpStream {}
147
148impl From<MptcpSocket<TcpStream>> for TcpStream {
149    fn from(socket: MptcpSocket<TcpStream>) -> Self {
150        socket.into_socket()
151    }
152}
153
154impl MptcpListenerExt for TcpListener {
155    type Output = Self;
156
157    fn bind_mptcp_opt<A: ToSocketAddrs>(
158        addr: A,
159        opt: MptcpOpt,
160    ) -> io::Result<MptcpSocket<Self::Output>> {
161        match resolve_each_addr(&addr, |addr| {
162            MptcpSocketBuilder::new_for_addr(addr)?.bind(addr)
163        }) {
164            Ok(sock) => Ok(MptcpSocket::Mptcp(sock.into())),
165            Err(_) if matches!(opt, MptcpOpt::Fallback) => Ok(MptcpSocket::Tcp(Self::bind(addr)?)),
166            Err(err) => Err(err),
167        }
168    }
169}
170
171impl From<MptcpSocket<TcpListener>> for TcpListener {
172    fn from(socket: MptcpSocket<TcpListener>) -> Self {
173        socket.into_socket()
174    }
175}
176
#[cfg(all(test, target_os = "linux"))]
mod test {
    use std::net::{IpAddr, Ipv4Addr};

    use crate::sys::is_mptcp_enabled;

    use super::*;

    #[test]
    fn test_resolve_each_addr() {
        let addr = "127.0.0.1:80";
        let result = resolve_each_addr(&addr, |addr| {
            assert_eq!(addr.port(), 80);
            assert_eq!(addr.ip(), IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_resolve_each_addr_error() {
        let addr = "thisisanerror";
        let result = resolve_each_addr(&addr, |_| Ok(()));
        assert!(result.is_err());
    }

    #[test]
    fn test_resolve_each_addr_stops_at_first_success() {
        let candidates: [SocketAddr; 3] = [
            "127.0.0.1:1".parse().unwrap(),
            "127.0.0.1:2".parse().unwrap(),
            "127.0.0.1:3".parse().unwrap(),
        ];
        let slice: &[SocketAddr] = &candidates;
        let mut attempts = 0;
        let result = resolve_each_addr(&slice, |addr| {
            attempts += 1;
            if addr.port() == 2 {
                Ok(addr.port())
            } else {
                Err(io::Error::new(io::ErrorKind::Other, "refused"))
            }
        });
        assert_eq!(result.unwrap(), 2);
        // The third address must not be tried once the second one succeeded.
        assert_eq!(attempts, 2);
    }

    #[test]
    fn test_resolve_each_addr_empty() {
        let none: &[SocketAddr] = &[];
        let err = resolve_each_addr(&none, |_| Ok(())).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[test]
    fn test_mptcp_socket() {
        let mptcp_enabled = is_mptcp_enabled();

        let listener = TcpListener::bind_mptcp("127.0.0.1:0");
        if mptcp_enabled {
            assert!(matches!(listener, Ok(MptcpSocket::Mptcp(..))));
        } else {
            assert!(matches!(listener, Ok(MptcpSocket::Tcp(..))));
        }

        let listener = listener.unwrap().into_socket();
        let local_addr = listener.local_addr().unwrap();

        let stream = TcpStream::connect_mptcp(local_addr);
        if mptcp_enabled {
            assert!(matches!(stream, Ok(MptcpSocket::Mptcp(..))));
        } else {
            assert!(matches!(stream, Ok(MptcpSocket::Tcp(..))));
        }
    }

    #[test]
    fn test_mptcp_no_fallback() {
        if is_mptcp_enabled() {
            // If the system supports MPTCP, we cannot test the no fallback option
            return;
        }

        let listener = TcpListener::bind_mptcp_force("127.0.0.1:0");
        assert!(listener.is_err());

        // Target a live plain-TCP listener. Connecting to port 0 would fail for any
        // protocol and make this assertion pass vacuously.
        let tcp_listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let local_addr = tcp_listener.local_addr().unwrap();
        let stream = TcpStream::connect_mptcp_force(local_addr);
        assert!(stream.is_err());
    }
}