use std::collections::{vec_deque, VecDeque};
use std::{fmt, iter::FromIterator, iter::FusedIterator, net::SocketAddr};
use ntex_util::future::Either;
/// A connection target: something that can report a host name, an optional
/// port and, optionally, an already-resolved socket address.
pub trait Address
where
    Self: Unpin + 'static,
{
    /// Host name of the target (may be empty when the target has no name).
    fn host(&self) -> &str;

    /// Port carried explicitly by the target, if any.
    fn port(&self) -> Option<u16>;

    /// Pre-resolved socket address of the target.
    ///
    /// The default implementation reports none; implementors that already
    /// know their address override it.
    fn addr(&self) -> Option<SocketAddr> {
        Option::None
    }
}
impl Address for String {
    /// The whole string is the host; a `:port` suffix, if present, stays in it.
    fn host(&self) -> &str {
        self.as_str()
    }

    /// An owned string carries no separate port.
    fn port(&self) -> Option<u16> {
        Option::None
    }
}
impl Address for &'static str {
    /// The string slice itself is the host; a `:port` suffix, if present, stays in it.
    fn host(&self) -> &str {
        *self
    }

    /// A string slice carries no separate port.
    fn port(&self) -> Option<u16> {
        Option::None
    }
}
impl Address for SocketAddr {
    /// A raw socket address has no host name.
    fn host(&self) -> &str {
        ""
    }

    /// NOTE(review): reports no port even though the address has one, so
    /// `Connect::port()` is `0` for socket-address targets (the module tests
    /// assert this). The real port is only available through `addr()`.
    fn port(&self) -> Option<u16> {
        Option::None
    }

    /// The address is already resolved, so it is returned as-is.
    fn addr(&self) -> Option<SocketAddr> {
        let resolved = *self;
        Some(resolved)
    }
}
/// A connection request: the caller's target plus a fallback port and any
/// socket addresses attached to it.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Connect<T> {
    /// The connection target as supplied by the caller.
    pub(super) req: T,
    /// Port used when `req.port()` is `None`; `0` when nothing is known.
    pub(super) port: u16,
    /// No address, a single address, or a queue of several addresses.
    pub(super) addr: Option<Either<SocketAddr, VecDeque<SocketAddr>>>,
}
impl<T: Address> Connect<T> {
    /// Creates a request for `req` with no addresses attached.
    ///
    /// When the host string has the form `name:port` with a valid `u16`
    /// port, that port becomes the fallback port; otherwise it is `0`.
    pub fn new(req: T) -> Connect<T> {
        let port = parse(req.host()).1.unwrap_or(0);
        Connect {
            req,
            port,
            addr: None,
        }
    }

    /// Creates a request for `req` that is already bound to `addr`.
    ///
    /// The host string is not parsed here; the fallback port is `0`.
    pub fn with(req: T, addr: SocketAddr) -> Connect<T> {
        let addr = Some(Either::Left(addr));
        Connect { req, port: 0, addr }
    }

    /// Replaces the fallback port used when `req.port()` is `None`.
    pub fn set_port(self, port: u16) -> Self {
        Connect { port, ..self }
    }

    /// Stores `addr` as the single attached address.
    ///
    /// Passing `None` leaves any previously attached addresses untouched.
    pub fn set_addr(self, addr: Option<SocketAddr>) -> Self {
        match addr {
            Some(single) => Connect {
                addr: Some(Either::Left(single)),
                ..self
            },
            None => self,
        }
    }

    /// Replaces the attached addresses with `addrs`.
    ///
    /// An empty input clears them, a single address is stored on its own,
    /// and two or more are kept as a queue in iteration order.
    pub fn set_addrs<I>(mut self, addrs: I) -> Self
    where
        I: IntoIterator<Item = SocketAddr>,
    {
        let mut queue = VecDeque::from_iter(addrs);
        self.addr = if queue.len() >= 2 {
            Some(Either::Right(queue))
        } else {
            queue.pop_front().map(Either::Left)
        };
        self
    }

    /// Host name of the target, exactly as `req.host()` reports it.
    pub fn host(&self) -> &str {
        let req = &self.req;
        req.host()
    }

    /// Port of the target: `req.port()` when it reports one, otherwise the
    /// fallback port.
    pub fn port(&self) -> u16 {
        match self.req.port() {
            Some(explicit) => explicit,
            None => self.port,
        }
    }

    /// Iterates over the attached addresses without removing them.
    ///
    /// An address reported by `req.addr()` takes precedence over anything
    /// stored with `set_addr`/`set_addrs`.
    pub fn addrs(&self) -> ConnectAddrsIter<'_> {
        let inner = match self.req.addr() {
            Some(direct) => Either::Left(Some(direct)),
            None => match self.addr {
                Some(Either::Right(ref queue)) => Either::Right(queue.iter()),
                Some(Either::Left(single)) => Either::Left(Some(single)),
                None => Either::Left(None),
            },
        };
        ConnectAddrsIter { inner }
    }

    /// Moves the stored addresses out, leaving none behind.
    ///
    /// When `req.addr()` reports an address, only that address is yielded
    /// and the stored addresses are left in place.
    pub fn take_addrs(&mut self) -> ConnectTakeAddrsIter {
        let inner = if let Some(direct) = self.req.addr() {
            Either::Left(Some(direct))
        } else {
            match self.addr.take() {
                Some(Either::Right(queue)) => Either::Right(queue.into_iter()),
                Some(Either::Left(single)) => Either::Left(Some(single)),
                None => Either::Left(None),
            }
        };
        ConnectTakeAddrsIter { inner }
    }

    /// Borrows the original connection target.
    pub fn get_ref(&self) -> &T {
        &self.req
    }
}
impl<T: Address> From<T> for Connect<T> {
    /// Builds a request for `req`; identical to calling `Connect::new(req)`.
    fn from(req: T) -> Self {
        Self::new(req)
    }
}
impl<T: Address> fmt::Display for Connect<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}:{}", self.host(), self.port())
}
}
/// Borrowing iterator over the addresses of a `Connect`, produced by
/// `Connect::addrs`.
#[derive(Clone)]
pub struct ConnectAddrsIter<'q> {
    // Either a slot holding at most one address, or an iterator over a queue.
    inner: Either<Option<SocketAddr>, vec_deque::Iter<'q, SocketAddr>>,
}
impl<'q> Iterator for ConnectAddrsIter<'q> {
    type Item = SocketAddr;

    /// Yields the next address, copying it out of the borrowed queue.
    fn next(&mut self) -> Option<SocketAddr> {
        match &mut self.inner {
            Either::Left(single) => single.take(),
            Either::Right(queue) => queue.next().copied(),
        }
    }

    /// Exact bounds: `0` or `1` for the single slot, otherwise the queue
    /// iterator's own hint.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match &self.inner {
            Either::Right(queue) => queue.size_hint(),
            Either::Left(single) => {
                let n = usize::from(single.is_some());
                (n, Some(n))
            }
        }
    }
}
impl fmt::Debug for ConnectAddrsIter<'_> {
    /// Prints the addresses still to be yielded as a list, iterating a clone
    /// so `self` is not advanced.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let remaining = self.clone();
        let mut list = f.debug_list();
        list.entries(remaining);
        list.finish()
    }
}
// `size_hint` is exact for both variants, and neither variant yields again
// once it has returned `None` (`Option::take` stays `None`; the std deque
// iterator is fused), so both marker traits hold.
impl<'q> ExactSizeIterator for ConnectAddrsIter<'q> {}
impl<'q> FusedIterator for ConnectAddrsIter<'q> {}
/// Owning iterator over addresses moved out of a `Connect` by
/// `Connect::take_addrs`.
#[derive(Debug)]
pub struct ConnectTakeAddrsIter {
    // Either a slot holding at most one address, or the drained queue.
    inner: Either<Option<SocketAddr>, vec_deque::IntoIter<SocketAddr>>,
}
impl Iterator for ConnectTakeAddrsIter {
    type Item = SocketAddr;

    /// Yields the next owned address.
    fn next(&mut self) -> Option<SocketAddr> {
        match &mut self.inner {
            Either::Left(single) => single.take(),
            Either::Right(queue) => queue.next(),
        }
    }

    /// Exact bounds: `0` or `1` for the single slot, otherwise the queue
    /// iterator's own hint.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match &self.inner {
            Either::Right(queue) => queue.size_hint(),
            Either::Left(single) => {
                let n = if single.is_some() { 1 } else { 0 };
                (n, Some(n))
            }
        }
    }
}
// Same reasoning as for the borrowing iterator: exact `size_hint`, and no
// items after the first `None`.
impl FusedIterator for ConnectTakeAddrsIter {}
impl ExactSizeIterator for ConnectTakeAddrsIter {}
/// Splits a `host[:port]` string into its host part and an optional port.
///
/// The port is `Some` only when the text after the separator parses as a
/// `u16`; otherwise it is `None` and the host part is still returned.
///
/// * Bracketed IPv6 literals (`[::1]:8080`) are recognized: the colons
///   inside the brackets belong to the address, only a `:` directly after
///   the closing `]` introduces the port, and the returned host keeps its
///   brackets (`"[::1]"`).
/// * Any other input is split at its first `:`, so a bare IPv6 literal such
///   as `::1` yields no port (its tail `:1` is not a valid `u16`).
fn parse(host: &str) -> (&str, Option<u16>) {
    let (name, port_str) = if host.starts_with('[') {
        match host.find(']') {
            Some(end) => {
                // Keep the closing bracket in the host part.
                let (name, rest) = host.split_at(end + 1);
                (name, rest.strip_prefix(':'))
            }
            // Unterminated bracket: no port can be identified.
            None => (host, None),
        }
    } else {
        match host.find(':') {
            // `:` is ASCII, so both slice boundaries are char boundaries.
            Some(idx) => (&host[..idx], Some(&host[idx + 1..])),
            None => (host, None),
        }
    };
    (name, port_str.and_then(|p| p.parse::<u16>().ok()))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn address() {
        assert_eq!("test".host(), "test");
        assert_eq!("test".port(), None);
        let s = "test".to_string();
        assert_eq!(s.host(), "test");
        assert_eq!(s.port(), None);
    }

    #[test]
    fn connect() {
        let mut connect = Connect::new("www.rust-lang.org");
        assert_eq!(connect.host(), "www.rust-lang.org");
        assert_eq!(connect.port(), 0);
        assert_eq!(*connect.get_ref(), "www.rust-lang.org");
        connect = connect.set_port(80);
        assert_eq!(connect.port(), 80);
        let addrs = connect.addrs().clone();
        assert_eq!(format!("{:?}", addrs), "[]");
        assert!(connect.addrs().next().is_none());
        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        connect = connect.set_addrs(vec![addr]);
        let addrs = connect.addrs().clone();
        assert_eq!(format!("{:?}", addrs), "[127.0.0.1:8080]");
        let addrs: Vec<_> = connect.take_addrs().collect();
        assert_eq!(addrs.len(), 1);
        assert!(addrs.contains(&addr));
        let addr2: SocketAddr = "127.0.0.1:8081".parse().unwrap();
        connect = connect.set_addrs(vec![addr, addr2]);
        let addrs: Vec<_> = connect.addrs().collect();
        assert_eq!(addrs.len(), 2);
        assert!(addrs.contains(&addr));
        assert!(addrs.contains(&addr2));
        let addrs: Vec<_> = connect.take_addrs().collect();
        assert_eq!(addrs.len(), 2);
        assert!(addrs.contains(&addr));
        assert!(addrs.contains(&addr2));
        assert!(connect.addrs().next().is_none());
        connect = connect.set_addrs(vec![addr]);
        assert_eq!(format!("{}", connect), "www.rust-lang.org:80");
        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let mut connect = Connect::new(addr);
        assert_eq!(connect.host(), "");
        assert_eq!(connect.port(), 0);
        let addrs: Vec<_> = connect.addrs().collect();
        assert_eq!(addrs.len(), 1);
        assert!(addrs.contains(&addr));
        let addrs: Vec<_> = connect.take_addrs().collect();
        assert_eq!(addrs.len(), 1);
        assert!(addrs.contains(&addr));
    }

    /// `parse` on plain (non-bracketed) `host[:port]` strings.
    #[test]
    fn parse_host_and_port() {
        assert_eq!(
            parse("www.rust-lang.org:443"),
            ("www.rust-lang.org", Some(443))
        );
        assert_eq!(parse("www.rust-lang.org"), ("www.rust-lang.org", None));
        assert_eq!(parse("www.rust-lang.org:"), ("www.rust-lang.org", None));
        assert_eq!(parse("www.rust-lang.org:http"), ("www.rust-lang.org", None));
        assert_eq!(parse("www.rust-lang.org:70000"), ("www.rust-lang.org", None));
        assert_eq!(parse(""), ("", None));
    }

    /// `Connect::new` picks up a port embedded in the host string, and
    /// `set_port` overrides it.
    #[test]
    fn connect_port_from_host() {
        let connect = Connect::new("www.rust-lang.org:443");
        assert_eq!(connect.port(), 443);
        let connect = connect.set_port(80);
        assert_eq!(connect.port(), 80);
    }

    /// `Connect::with`, `set_addr(None)` and clearing via an empty `set_addrs`.
    #[test]
    fn connect_with_addr() {
        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let connect = Connect::with("www.rust-lang.org", addr);
        assert_eq!(connect.port(), 0);
        assert_eq!(connect.addrs().len(), 1);
        assert_eq!(connect.addrs().collect::<Vec<_>>(), vec![addr]);

        // `set_addr(None)` keeps the existing address.
        let mut connect = connect.set_addr(None);
        assert_eq!(connect.addrs().collect::<Vec<_>>(), vec![addr]);

        // An empty list clears the stored addresses.
        connect = connect.set_addrs(Vec::new());
        assert_eq!(connect.addrs().len(), 0);
        assert_eq!(connect.take_addrs().len(), 0);
    }
}