1use std::{
2 io,
3 net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs},
4};
5
6use crate::{sys::MptcpSocketBuilder, MptcpExt, MptcpOpt, MptcpSocket};
7
8pub trait MptcpStreamExt {
10 type Output;
11
12 fn connect_mptcp_opt<A: ToSocketAddrs>(
24 addr: A,
25 opt: MptcpOpt,
26 ) -> io::Result<MptcpSocket<Self::Output>>;
27
28 fn connect_mptcp<A: ToSocketAddrs>(addr: A) -> io::Result<MptcpSocket<Self::Output>> {
40 Self::connect_mptcp_opt(addr, MptcpOpt::Fallback)
41 }
42
43 fn connect_mptcp_force<A: ToSocketAddrs>(addr: A) -> io::Result<Self::Output> {
55 Ok(Self::connect_mptcp_opt(addr, MptcpOpt::NoFallback)?.into_socket())
56 }
57}
58
59pub trait MptcpListenerExt {
61 type Output;
62
63 fn bind_mptcp_opt<A: ToSocketAddrs>(
74 addr: A,
75 opt: MptcpOpt,
76 ) -> io::Result<MptcpSocket<Self::Output>>;
77
78 fn bind_mptcp<A: ToSocketAddrs>(addr: A) -> io::Result<MptcpSocket<Self::Output>> {
89 Self::bind_mptcp_opt(addr, MptcpOpt::Fallback)
90 }
91
92 fn bind_mptcp_force<A: ToSocketAddrs>(addr: A) -> io::Result<Self::Output> {
103 Ok(Self::bind_mptcp_opt(addr, MptcpOpt::NoFallback)?.into_socket())
104 }
105}
106
/// Resolves `addr` and feeds each resulting socket address, in order, to `f`,
/// returning the first successful result.
///
/// Later addresses are not tried once `f` succeeds.
///
/// # Errors
///
/// * An error from resolving `addr` itself is returned immediately.
/// * If `f` fails for every address, the error from the *last* attempt is returned.
/// * If `addr` resolves to no addresses at all, an
///   [`io::ErrorKind::InvalidInput`] error is returned.
fn resolve_each_addr<A: ToSocketAddrs, F, T>(addr: &A, mut f: F) -> io::Result<T>
where
    F: FnMut(SocketAddr) -> io::Result<T>,
{
    let mut last_error = None;

    // `find_map` stops at the first `Some`, i.e. the first successful attempt;
    // failures along the way overwrite `last_error` so only the latest survives.
    let first_success = addr.to_socket_addrs()?.find_map(|candidate| match f(candidate) {
        Ok(value) => Some(value),
        Err(error) => {
            last_error = Some(error);
            None
        }
    });

    first_success.ok_or_else(|| {
        last_error.unwrap_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "could not resolve to any address",
            )
        })
    })
}
126
127impl MptcpStreamExt for TcpStream {
128 type Output = Self;
129
130 fn connect_mptcp_opt<A: ToSocketAddrs>(
131 addr: A,
132 opt: MptcpOpt,
133 ) -> io::Result<MptcpSocket<Self::Output>> {
134 match resolve_each_addr(&addr, |addr| {
135 MptcpSocketBuilder::new_for_addr(addr)?.connect(addr)
136 }) {
137 Ok(sock) => Ok(MptcpSocket::Mptcp(sock.into())),
138 Err(_) if matches!(opt, MptcpOpt::Fallback) => {
139 Ok(MptcpSocket::Tcp(Self::connect(addr)?))
140 }
141 Err(err) => Err(err),
142 }
143 }
144}
145
// Opts `TcpStream` into the `MptcpExt` helpers. The empty body means every
// item is taken from the trait's default implementations.
impl MptcpExt for TcpStream {}
147
148impl From<MptcpSocket<TcpStream>> for TcpStream {
149 fn from(socket: MptcpSocket<TcpStream>) -> Self {
150 socket.into_socket()
151 }
152}
153
154impl MptcpListenerExt for TcpListener {
155 type Output = Self;
156
157 fn bind_mptcp_opt<A: ToSocketAddrs>(
158 addr: A,
159 opt: MptcpOpt,
160 ) -> io::Result<MptcpSocket<Self::Output>> {
161 match resolve_each_addr(&addr, |addr| {
162 MptcpSocketBuilder::new_for_addr(addr)?.bind(addr)
163 }) {
164 Ok(sock) => Ok(MptcpSocket::Mptcp(sock.into())),
165 Err(_) if matches!(opt, MptcpOpt::Fallback) => Ok(MptcpSocket::Tcp(Self::bind(addr)?)),
166 Err(err) => Err(err),
167 }
168 }
169}
170
171impl From<MptcpSocket<TcpListener>> for TcpListener {
172 fn from(socket: MptcpSocket<TcpListener>) -> Self {
173 socket.into_socket()
174 }
175}
176
#[cfg(all(test, target_os = "linux"))]
mod test {
    use std::net::{IpAddr, Ipv4Addr};

    use crate::sys::is_mptcp_enabled;

    use super::*;

    #[test]
    fn test_resolve_each_addr() {
        let addr = "127.0.0.1:80";
        let result = resolve_each_addr(&addr, |addr| {
            assert_eq!(addr.port(), 80);
            assert_eq!(addr.ip(), IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
            Ok(())
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_resolve_each_addr_error() {
        let addr = "thisisanerror";
        let result = resolve_each_addr(&addr, |_| Ok(()));
        assert!(result.is_err());
    }

    /// When every candidate fails, the error from the last attempt is surfaced.
    #[test]
    fn test_resolve_each_addr_returns_last_error() {
        let addrs = [
            SocketAddr::from(([127, 0, 0, 1], 1)),
            SocketAddr::from(([127, 0, 0, 1], 2)),
        ];
        let result: io::Result<()> = resolve_each_addr(&&addrs[..], |addr| {
            Err(io::Error::new(io::ErrorKind::Other, addr.port().to_string()))
        });
        assert_eq!(result.unwrap_err().to_string(), "2");
    }

    /// An address list with no entries yields `InvalidInput` without calling `f`.
    #[test]
    fn test_resolve_each_addr_empty() {
        let addrs: [SocketAddr; 0] = [];
        let err = resolve_each_addr(&&addrs[..], |_| -> io::Result<()> {
            panic!("callback must not run for an empty address list")
        })
        .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[test]
    fn test_mptcp_socket() {
        let mptcp_enabled = is_mptcp_enabled();

        let listener = TcpListener::bind_mptcp("127.0.0.1:0");
        if mptcp_enabled {
            assert!(matches!(listener, Ok(MptcpSocket::Mptcp(..))));
        } else {
            assert!(matches!(listener, Ok(MptcpSocket::Tcp(..))));
        }

        let listener = listener.unwrap().into_socket();
        let local_addr = listener.local_addr().unwrap();

        let stream = TcpStream::connect_mptcp(local_addr);
        if mptcp_enabled {
            assert!(matches!(stream, Ok(MptcpSocket::Mptcp(..))));
        } else {
            assert!(matches!(stream, Ok(MptcpSocket::Tcp(..))));
        }
    }

    #[test]
    fn test_mptcp_no_fallback() {
        // The failure assertions below only hold when MPTCP is unavailable.
        if is_mptcp_enabled() {
            return;
        }

        let listener = TcpListener::bind_mptcp_force("127.0.0.1:0");
        assert!(listener.is_err());

        // Target a live plain-TCP listener so the only reason the connect can
        // fail is the refused fallback. Connecting to port 0 always fails,
        // which would make this assertion pass regardless of MPTCP handling.
        let tcp_listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let local_addr = tcp_listener.local_addr().unwrap();
        let stream = TcpStream::connect_mptcp_force(local_addr);
        assert!(stream.is_err());
    }
}