use std::collections::{HashMap, VecDeque};
use std::io;
use std::io::ErrorKind;
use std::marker::PhantomData;
use std::net::SocketAddr;
use std::time::Duration;
use crate::service::dns::{BlockingDnsResolver, DnsQuery, DnsResolver};
use crate::service::endpoint::{Context, DisconnectReason, Endpoint, EndpointWithContext};
use crate::service::node::IONode;
use crate::service::select::{Selector, SelectorToken};
use crate::service::time::{SystemTimeClockSource, TimeSource};
use crate::stream::ConnectionInfoProvider;
pub mod dns;
pub mod endpoint;
mod node;
pub mod select;
pub mod time;
/// Minimum spacing, in nanoseconds, between two consecutive passes that advance a pending
/// endpoint towards a live connection (one second).
const ENDPOINT_CREATION_THROTTLE_NS: u64 = Duration::from_millis(1_000).as_nanos() as u64;
/// Identifier of an endpoint registered with an [`IOService`].
///
/// Wraps the [`SelectorToken`] handed out by the selector at registration time; once the
/// endpoint is connected, the same token keys its connection inside the service and is used to
/// register the connection with the selector.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Default)]
#[repr(transparent)]
pub struct Handle(SelectorToken);
/// Drives a set of endpoints: resolves their addresses, creates their connections, polls the
/// selector for readiness and recreates endpoints that get disconnected.
///
/// Type parameters: `S` selector, `E` endpoint, `C` context passed to endpoints (`()` when none),
/// `TS` time source, `D` DNS resolver.
pub struct IOService<S: Selector, E, C, TS, D: DnsResolver> {
/// Selector used to register connections and poll them for readiness.
selector: S,
/// Endpoints not yet connected, in processing order:
/// (handle, in-flight DNS query, time the query was started in ns, endpoint).
pending_endpoints: VecDeque<(Handle, D::Query, u64, E)>,
/// Connected endpoints keyed by their selector token.
io_nodes: HashMap<SelectorToken, IONode<S::Target, E>>,
/// Time (ns, per `time_source`) after which the next pending endpoint may be processed.
next_endpoint_create_time_ns: u64,
/// Marker for the context type `C`; no context value is stored.
context: PhantomData<C>,
/// Optional supplier of the time-to-live given to each connection; `None` disables
/// auto-disconnect.
auto_disconnect: Option<Box<dyn Fn() -> Duration>>,
/// Clock used for throttling, DNS timeouts and auto-disconnect deadlines.
time_source: TS,
/// Resolver used to start a DNS query for every (re)queued endpoint.
dns_resolver: D,
/// Optional maximum age (ns) of a DNS query before it fails with `TimedOut`.
dns_query_timeout_ns: Option<u64>,
}
/// Conversion of a selector into an [`IOService`] using the unit `()` context type, the
/// [`SystemTimeClockSource`] clock and the [`BlockingDnsResolver`].
pub trait IntoIOService<E> {
/// Consumes the selector and wraps it in a new [`IOService`] for endpoints of type `E`.
fn into_io_service(self) -> IOService<Self, E, (), SystemTimeClockSource, BlockingDnsResolver>
where
Self: Selector,
Self: Sized;
}
/// Conversion of a selector into an [`IOService`] whose endpoints receive a context of type `C`,
/// using the [`SystemTimeClockSource`] clock and the [`BlockingDnsResolver`].
pub trait IntoIOServiceWithContext<E, C: Context> {
/// Consumes the selector and wraps it in a new [`IOService`] for endpoints of type `E`.
fn into_io_service_with_context(self) -> IOService<Self, E, C, SystemTimeClockSource, BlockingDnsResolver>
where
Self: Selector,
Self: Sized;
}
impl<S: Selector, E, C, TS, D: DnsResolver> IOService<S, E, C, TS, D> {
pub fn new(selector: S, time_source: TS, dns_resolver: D) -> IOService<S, E, C, TS, D> {
Self {
selector,
pending_endpoints: VecDeque::new(),
io_nodes: HashMap::new(),
next_endpoint_create_time_ns: 0,
context: PhantomData,
auto_disconnect: None,
time_source,
dns_resolver,
dns_query_timeout_ns: None,
}
}
pub fn with_auto_disconnect(self, auto_disconnect: Duration) -> IOService<S, E, C, TS, D> {
self.with_auto_disconnect_supplier(move || auto_disconnect)
}
pub fn with_auto_disconnect_supplier<F>(self, f: F) -> IOService<S, E, C, TS, D>
where
F: Fn() -> Duration + 'static,
{
Self {
auto_disconnect: Some(Box::new(f)),
..self
}
}
pub fn with_dns_query_timeout(self, timeout: Duration) -> IOService<S, E, C, TS, D> {
Self {
dns_query_timeout_ns: Some(timeout.as_nanos() as u64),
..self
}
}
pub fn with_time_source<T: TimeSource>(self, time_source: T) -> IOService<S, E, C, T, D> {
IOService {
time_source,
pending_endpoints: Default::default(),
context: self.context,
auto_disconnect: self.auto_disconnect,
io_nodes: Default::default(),
next_endpoint_create_time_ns: self.next_endpoint_create_time_ns,
selector: self.selector,
dns_resolver: self.dns_resolver,
dns_query_timeout_ns: self.dns_query_timeout_ns,
}
}
pub fn with_dns_resolver<DR: DnsResolver>(self, dns_resolver: DR) -> IOService<S, E, C, TS, DR> {
IOService {
time_source: self.time_source,
pending_endpoints: Default::default(),
context: self.context,
auto_disconnect: self.auto_disconnect,
io_nodes: Default::default(),
next_endpoint_create_time_ns: self.next_endpoint_create_time_ns,
selector: self.selector,
dns_resolver,
dns_query_timeout_ns: self.dns_query_timeout_ns,
}
}
pub fn register(&mut self, endpoint: E) -> io::Result<Handle>
where
E: ConnectionInfoProvider,
TS: TimeSource,
{
let handle = Handle(self.selector.next_token());
let info = endpoint.connection_info();
let query = self.dns_resolver.new_query(info.host(), info.port())?;
let now = self.time_source.current_time_nanos();
self.pending_endpoints.push_back((handle, query, now, endpoint));
Ok(handle)
}
pub fn register_with_factory<F>(&mut self, endpoint_factory: F) -> io::Result<Handle>
where
E: ConnectionInfoProvider,
TS: TimeSource,
F: FnOnce(Handle) -> io::Result<E>,
{
let handle = Handle(self.selector.next_token());
let endpoint = endpoint_factory(handle)?;
let info = endpoint.connection_info();
let query = self.dns_resolver.new_query(info.host(), info.port())?;
let now = self.time_source.current_time_nanos();
self.pending_endpoints.push_back((handle, query, now, endpoint));
Ok(handle)
}
pub fn deregister(&mut self, handle: Handle) -> Option<E> {
match self.io_nodes.remove(&handle.0) {
Some(io_node) => Some(io_node.into_endpoint().1),
None => {
let mut index_to_remove = None;
for (index, endpoint) in self.pending_endpoints.iter().enumerate() {
if endpoint.0 == handle {
index_to_remove = Some(index);
break;
}
}
if let Some(index_to_remove) = index_to_remove {
self.pending_endpoints
.remove(index_to_remove)
.map(|(_, _, _, endpoint)| endpoint)
} else {
None
}
}
}
}
#[inline]
pub fn iter(&self) -> impl Iterator<Item = (Handle, &S::Target, &E)> {
self.io_nodes.values().map(|io_node| {
let (stream, (handle, endpoint)) = io_node.as_parts();
(*handle, stream, endpoint)
})
}
#[inline]
pub fn iter_mut(&mut self) -> impl Iterator<Item = (Handle, &mut S::Target, &mut E)> {
self.io_nodes.values_mut().map(|io_node| {
let (stream, (handle, endpoint)) = io_node.as_parts_mut();
(*handle, stream, endpoint)
})
}
#[inline]
pub fn pending(&self) -> impl Iterator<Item = (&Handle, &E)> {
self.pending_endpoints
.iter()
.map(|(handle, _, _, endpoint)| (handle, endpoint))
}
#[inline]
fn resolve_dns(&self, query: &mut impl DnsQuery, created_time_ns: u64) -> io::Result<Option<SocketAddr>>
where
TS: TimeSource,
{
if let Some(dns_query_timeout) = self.dns_query_timeout_ns {
let now = self.time_source.current_time_nanos();
if now > created_time_ns + dns_query_timeout {
return Err(io::Error::new(ErrorKind::TimedOut, "dns resolution timed out"));
}
}
match query.poll() {
Ok(addrs) => {
let addr = addrs
.into_iter()
.next()
.ok_or_else(|| io::Error::other("dns resolution dio not return any address"))?;
Ok(Some(addr))
}
Err(err) if err.kind() == ErrorKind::WouldBlock => Ok(None),
Err(err) => Err(err),
}
}
#[cold]
fn check_pending_endpoints<F>(&mut self, create_target: F) -> io::Result<()>
where
E: ConnectionInfoProvider,
TS: TimeSource,
F: FnOnce(&mut E, SocketAddr) -> io::Result<Option<<S as Selector>::Target>>,
{
let current_time_ns = self.time_source.current_time_nanos();
if current_time_ns > self.next_endpoint_create_time_ns {
if let Some((handle, mut query, query_time_ns, mut endpoint)) = self.pending_endpoints.pop_front() {
if let Some(addr) = self.resolve_dns(&mut query, query_time_ns)? {
match create_target(&mut endpoint, addr)? {
Some(stream) => {
let ttl = self.auto_disconnect.as_ref().map(|auto_disconnect| auto_disconnect());
let mut io_node = IONode::new(stream, handle, endpoint, ttl, &self.time_source, addr);
self.selector.register(handle.0, &mut io_node)?;
self.io_nodes.insert(handle.0, io_node);
}
None => {
let info = endpoint.connection_info();
let query = self.dns_resolver.new_query(info.host(), info.port())?;
let now = self.time_source.current_time_nanos();
self.pending_endpoints.push_back((handle, query, now, endpoint))
}
}
} else {
self.pending_endpoints
.push_back((handle, query, query_time_ns, endpoint))
}
}
self.next_endpoint_create_time_ns = current_time_ns + ENDPOINT_CREATION_THROTTLE_NS;
}
Ok(())
}
}
impl<S, E, TS, D> IOService<S, E, (), TS, D>
where
    S: Selector,
    E: Endpoint<Target = S::Target>,
    TS: TimeSource,
    D: DnsResolver,
{
    /// Drives the service once.
    ///
    /// In order, this:
    /// 1. advances at most one pending endpoint (throttled), creating its connection via
    ///    `Endpoint::create_target` once its address is resolved;
    /// 2. polls the selector for readiness;
    /// 3. if auto-disconnect is enabled, handles every connection whose deadline has passed;
    /// 4. invokes `action` on every connected endpoint together with its stream.
    ///
    /// A connection is dropped when it expired and `can_auto_disconnect` allowed it, or when
    /// `action` returned an error. It is then unregistered from the selector and, if
    /// `can_recreate` agrees, re-queued as a pending endpoint with a fresh DNS query. An expired
    /// endpoint that refuses auto-disconnect gets its deadline extended by another value from
    /// the auto-disconnect supplier.
    ///
    /// # Errors
    /// Propagates errors from advancing a pending endpoint (DNS failure or timeout,
    /// `create_target`, selector registration) and from the selector poll. Errors returned by
    /// `action` are not propagated; they only trigger the disconnect/recreate logic above.
    ///
    /// # Panics
    /// If a dropped endpoint refuses to be recreated, or if unregistering it from the selector
    /// or starting its replacement DNS query fails.
    pub fn poll<F>(&mut self, mut action: F) -> io::Result<()>
    where
        F: FnMut(&mut E::Target, &mut E) -> io::Result<()>,
    {
        if !self.pending_endpoints.is_empty() {
            self.check_pending_endpoints(|endpoint, addr| endpoint.create_target(addr))?;
        }
        self.selector.poll(&mut self.io_nodes)?;
        if let Some(auto_disconnect) = self.auto_disconnect.as_ref() {
            let current_time_ns = self.time_source.current_time_nanos();
            self.io_nodes.retain(|_token, io_node| {
                // Deadline not reached yet: keep the connection untouched.
                if current_time_ns <= io_node.disconnect_time_ns {
                    return true;
                }
                // Expired, but the endpoint wants to stay connected: push the deadline out.
                if !io_node.as_endpoint_mut().1.can_auto_disconnect() {
                    // `as_nanos()` is a u128; saturate instead of truncating with `as u64`.
                    let extend = u64::try_from(auto_disconnect().as_nanos()).unwrap_or(u64::MAX);
                    io_node.disconnect_time_ns = io_node.disconnect_time_ns.saturating_add(extend);
                    return true;
                }
                self.selector.unregister(io_node).unwrap();
                let (handle, mut endpoint) = io_node.endpoint.take().unwrap();
                if endpoint.can_recreate(DisconnectReason::auto_disconnect(io_node.ttl)) {
                    let info = endpoint.connection_info();
                    let query = self.dns_resolver.new_query(info.host(), info.port()).unwrap();
                    let now = self.time_source.current_time_nanos();
                    self.pending_endpoints.push_back((handle, query, now, endpoint));
                } else {
                    panic!("unrecoverable error when polling endpoint");
                }
                false
            });
        }
        self.io_nodes.retain(|_token, io_node| {
            let (target, (_, endpoint)) = io_node.as_parts_mut();
            let Err(err) = action(target, endpoint) else {
                return true;
            };
            self.selector.unregister(io_node).unwrap();
            let (handle, mut endpoint) = io_node.endpoint.take().unwrap();
            if endpoint.can_recreate(DisconnectReason::other(err)) {
                let info = endpoint.connection_info();
                let query = self.dns_resolver.new_query(info.host(), info.port()).unwrap();
                let now = self.time_source.current_time_nanos();
                self.pending_endpoints.push_back((handle, query, now, endpoint));
            } else {
                panic!("unrecoverable error when polling endpoint");
            }
            false
        });
        Ok(())
    }

    /// Runs `action` on the connected endpoint identified by `handle` and its stream.
    ///
    /// Returns `Ok(None)` without calling `action` when `handle` has no live connection (for
    /// example it is still pending or was deregistered), otherwise `Ok(Some(result))`.
    ///
    /// # Errors
    /// Returns the error produced by `action`. Unlike in `poll`, such an error does not
    /// disconnect or recreate the endpoint.
    pub fn dispatch<F, T>(&mut self, handle: Handle, mut action: F) -> io::Result<Option<T>>
    where
        F: FnMut(&mut E::Target, &mut E) -> std::io::Result<T>,
    {
        match self.io_nodes.get_mut(&handle.0) {
            Some(io_node) => {
                let (stream, (_, endpoint)) = io_node.as_parts_mut();
                let result = action(stream, endpoint)?;
                Ok(Some(result))
            }
            None => Ok(None),
        }
    }
}
impl<S, E, C, TS, D> IOService<S, E, C, TS, D>
where
    S: Selector,
    C: Context,
    E: EndpointWithContext<C, Target = S::Target>,
    TS: TimeSource,
    D: DnsResolver,
{
    /// Drives the service once, passing `ctx` to every endpoint callback.
    ///
    /// In order, this:
    /// 1. advances at most one pending endpoint (throttled), creating its connection via
    ///    `EndpointWithContext::create_target` once its address is resolved;
    /// 2. polls the selector for readiness;
    /// 3. if auto-disconnect is enabled, handles every connection whose deadline has passed;
    /// 4. invokes `action` on every connected endpoint together with its stream and `ctx`.
    ///
    /// A connection is dropped when it expired and `can_auto_disconnect` allowed it, or when
    /// `action` returned an error. It is then unregistered from the selector and, if
    /// `can_recreate` agrees, re-queued as a pending endpoint with a fresh DNS query. An expired
    /// endpoint that refuses auto-disconnect gets its deadline extended by another value from
    /// the auto-disconnect supplier.
    ///
    /// # Errors
    /// Propagates errors from advancing a pending endpoint (DNS failure or timeout,
    /// `create_target`, selector registration) and from the selector poll. Errors returned by
    /// `action` are not propagated; they only trigger the disconnect/recreate logic above.
    ///
    /// # Panics
    /// If a dropped endpoint refuses to be recreated, or if unregistering it from the selector
    /// or starting its replacement DNS query fails.
    pub fn poll<F>(&mut self, ctx: &mut C, mut action: F) -> io::Result<()>
    where
        F: FnMut(&mut E::Target, &mut C, &mut E) -> io::Result<()>,
    {
        if !self.pending_endpoints.is_empty() {
            self.check_pending_endpoints(|endpoint, addr| endpoint.create_target(addr, ctx))?;
        }
        self.selector.poll(&mut self.io_nodes)?;
        if let Some(auto_disconnect) = self.auto_disconnect.as_ref() {
            let current_time_ns = self.time_source.current_time_nanos();
            self.io_nodes.retain(|_token, io_node| {
                // Deadline not reached yet: keep the connection untouched.
                if current_time_ns <= io_node.disconnect_time_ns {
                    return true;
                }
                // Expired, but the endpoint wants to stay connected: push the deadline out.
                if !io_node.as_endpoint_mut().1.can_auto_disconnect(ctx) {
                    // `as_nanos()` is a u128; saturate instead of truncating with `as u64`.
                    let extend = u64::try_from(auto_disconnect().as_nanos()).unwrap_or(u64::MAX);
                    io_node.disconnect_time_ns = io_node.disconnect_time_ns.saturating_add(extend);
                    return true;
                }
                self.selector.unregister(io_node).unwrap();
                let (handle, mut endpoint) = io_node.endpoint.take().unwrap();
                if endpoint.can_recreate(DisconnectReason::auto_disconnect(io_node.ttl), ctx) {
                    let info = endpoint.connection_info();
                    let query = self.dns_resolver.new_query(info.host(), info.port()).unwrap();
                    let now = self.time_source.current_time_nanos();
                    self.pending_endpoints.push_back((handle, query, now, endpoint));
                } else {
                    panic!("unrecoverable error when polling endpoint");
                }
                false
            });
        }
        self.io_nodes.retain(|_token, io_node| {
            let (target, (_, endpoint)) = io_node.as_parts_mut();
            let Err(err) = action(target, ctx, endpoint) else {
                return true;
            };
            self.selector.unregister(io_node).unwrap();
            let (handle, mut endpoint) = io_node.endpoint.take().unwrap();
            if endpoint.can_recreate(DisconnectReason::other(err), ctx) {
                let info = endpoint.connection_info();
                let query = self.dns_resolver.new_query(info.host(), info.port()).unwrap();
                let now = self.time_source.current_time_nanos();
                self.pending_endpoints.push_back((handle, query, now, endpoint));
            } else {
                panic!("unrecoverable error when polling endpoint");
            }
            false
        });
        Ok(())
    }

    /// Runs `action` on the connected endpoint identified by `handle`, its stream and `ctx`.
    ///
    /// Returns `Ok(None)` without calling `action` when `handle` has no live connection (for
    /// example it is still pending or was deregistered), otherwise `Ok(Some(result))`.
    ///
    /// # Errors
    /// Returns the error produced by `action`. Unlike in `poll`, such an error does not
    /// disconnect or recreate the endpoint.
    pub fn dispatch<F, T>(&mut self, handle: Handle, ctx: &mut C, mut action: F) -> io::Result<Option<T>>
    where
        F: FnMut(&mut E::Target, &mut E, &mut C) -> std::io::Result<T>,
    {
        match self.io_nodes.get_mut(&handle.0) {
            Some(io_node) => {
                let (stream, (_, endpoint)) = io_node.as_parts_mut();
                let result = action(stream, endpoint, ctx)?;
                Ok(Some(result))
            }
            None => Ok(None),
        }
    }
}