use std::net::{TcpStream, ToSocketAddrs};
use std::sync::mpsc::channel;
use std::thread::spawn;
use super::{Error, Options};
pub fn connect(host: &str, port: u16, opts: &Options) -> Result<TcpStream, Error> {
let timeout = opts.connect_timeout;
let delay = opts.connect_delay;
let mut addrs: Vec<_> = (host, port)
.to_socket_addrs()?
.map(|addr| (0, addr))
.collect();
if let [(_prio, addr)] = addrs.as_slice() {
return TcpStream::connect_timeout(addr, timeout).map_err(Error::from);
}
addrs
.iter_mut()
.filter(|(_prio, addr)| addr.is_ipv6())
.enumerate()
.for_each(|(idx, (prio, _addr))| *prio = 2 * idx);
addrs
.iter_mut()
.filter(|(_prio, addr)| addr.is_ipv4())
.enumerate()
.for_each(|(idx, (prio, _addr))| *prio = 2 * idx + 1);
addrs.sort_unstable_by_key(|(prio, _addr)| *prio);
let mut first_err = None;
let (tx, rx) = channel();
for (_prio, addr) in addrs {
let tx = tx.clone();
spawn(move || {
let _ = tx.send(TcpStream::connect_timeout(&addr, timeout));
});
if let Ok(res) = rx.recv_timeout(delay) {
match res {
Ok(stream) => return Ok(stream),
Err(err) => first_err = first_err.or(Some(err)),
}
}
}
drop(tx);
for res in rx.iter() {
match res {
Ok(stream) => return Ok(stream),
Err(err) => first_err = first_err.or(Some(err)),
}
}
Err(first_err.unwrap().into())
}