use std::fmt;
use std::io;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, RwLock, Weak};
use std::time::Duration;
use dashmap::DashMap;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{Mutex, Notify};
use tokio::task::JoinHandle;
use crate::atom::{Atom, AtomTable};
use crate::distribution::resolver::NodeResolver;
/// How long an outbound TCP connect may take before it fails with
/// [`ConnectError::Timeout`]; used by [`ConnectionManager::new`].
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Reasons an outbound distribution connect can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectError {
    /// The resolver could not turn the node name into a socket address.
    ResolveFailure,
    /// The remote host refused the TCP connection.
    ConnectionRefused,
    /// The TCP connect did not finish within the configured timeout.
    Timeout,
    /// Any other connect failure, carrying the underlying error's text.
    Io(String),
}

impl fmt::Display for ConnectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only `Io` needs interpolation; every other variant maps to a fixed message.
        let message = match self {
            Self::ResolveFailure => "distribution node resolution failed",
            Self::ConnectionRefused => "distribution TCP connection refused",
            Self::Timeout => "distribution TCP connection timed out",
            Self::Io(detail) => {
                return write!(f, "distribution TCP connection failed: {detail}");
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for ConnectError {}
/// Why a registered distribution connection was taken down.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionDownReason {
    /// The remote side closed the stream.
    PeerClosed,
    /// Reading from the socket failed or a frame header was unusable.
    ReadError,
    /// Writing to the socket failed.
    WriteError,
    /// The connection was dropped locally via `disconnect_node`.
    ManualDisconnect,
}
/// Notification handed to the connection-down hook when a registered connection is
/// removed from the manager's table.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConnectionDownEvent {
    /// The node whose connection went down.
    pub node: Atom,
    /// Why it went down.
    pub reason: ConnectionDownReason,
}
/// Callback invoked with each [`ConnectionDownEvent`].
type ConnectionDownCallback = dyn Fn(ConnectionDownEvent) + Send + Sync + 'static;
/// Maps an inbound peer's socket address to its node, or `None` if it cannot be named yet.
type InboundIdentifier = dyn Fn(SocketAddr) -> Option<Atom> + Send + Sync + 'static;
/// Receives one frame, split into its control bytes and payload bytes.
type ControlFrameHandler = dyn Fn(&[u8], &[u8]) + Send + Sync + 'static;
/// Shared, replaceable slot for the callback that observes connection-down events.
///
/// Clones share one slot, so a callback registered through any clone is visible to all
/// of them. A poisoned lock is recovered rather than propagated.
#[derive(Clone, Default)]
pub struct ConnectionDownHook {
    callback: Arc<RwLock<Option<Arc<ConnectionDownCallback>>>>,
}

impl ConnectionDownHook {
    /// Creates a hook with no callback registered.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Installs `callback`, replacing any previously registered one.
    pub fn register<F>(&self, callback: F)
    where
        F: Fn(ConnectionDownEvent) + Send + Sync + 'static,
    {
        self.replace(Some(Arc::new(callback)));
    }

    /// Removes the registered callback, if any.
    pub fn unregister(&self) {
        self.replace(None);
    }

    /// Whether a callback is currently installed.
    #[must_use]
    pub fn is_registered(&self) -> bool {
        self.current().is_some()
    }

    /// Calls the installed callback (if any) with `event`.
    fn invoke(&self, event: ConnectionDownEvent) {
        // Snapshot first so the lock is released before user code runs; the callback can
        // then re-register or unregister without deadlocking.
        if let Some(callback) = self.current() {
            callback(event);
        }
    }

    /// Overwrites the stored callback, recovering from a poisoned lock.
    fn replace(&self, next: Option<Arc<ConnectionDownCallback>>) {
        let mut guard = match self.callback.write() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        *guard = next;
    }

    /// Clones the stored callback out of the slot, recovering from a poisoned lock.
    fn current(&self) -> Option<Arc<ConnectionDownCallback>> {
        match self.callback.read() {
            Ok(guard) => guard.clone(),
            Err(poisoned) => poisoned.into_inner().clone(),
        }
    }
}
/// One established distribution link to a remote node.
///
/// The write half of the TCP stream sits behind an async mutex, so concurrent writers
/// never interleave partial writes. The read half is serviced by a task spawned when the
/// connection is registered with the manager. The `down` flag only ever goes from
/// `false` to `true`.
pub struct DistConnection {
    node: Atom,
    peer_addr: SocketAddr,
    writer: Mutex<OwnedWriteHalf>,
    // Set by the first `mark_down`; every later call is a no-op.
    down: AtomicBool,
    // Weak back-reference: if the manager is gone, `mark_down` just skips the table update.
    manager: Weak<ConnectionManagerInner>,
}

impl DistConnection {
    /// Wraps an already-split stream; the connection starts out up.
    fn new(
        node: Atom,
        peer_addr: SocketAddr,
        writer: OwnedWriteHalf,
        manager: Weak<ConnectionManagerInner>,
    ) -> Self {
        Self {
            node,
            peer_addr,
            writer: Mutex::new(writer),
            down: AtomicBool::new(false),
            manager,
        }
    }

    /// The remote node this connection belongs to.
    #[must_use]
    pub fn node(&self) -> Atom {
        self.node
    }

    /// Socket address of the remote end.
    #[must_use]
    pub fn peer_addr(&self) -> SocketAddr {
        self.peer_addr
    }

    /// Whether the connection has been marked down (peer close, read/write failure or
    /// manual disconnect). Once `true`, it stays `true`.
    #[must_use]
    pub fn is_down(&self) -> bool {
        self.down.load(Ordering::Acquire)
    }

    /// Writes all of `bytes` to the peer.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::NotConnected`] without touching the socket if the
    /// connection is already down. Otherwise returns the underlying I/O error when the
    /// write fails, and in that case the connection is marked down with
    /// [`ConnectionDownReason::WriteError`].
    pub async fn write_raw(self: &Arc<Self>, bytes: &[u8]) -> io::Result<()> {
        let result = {
            let mut writer = self.writer.lock().await;
            // Checked under the lock so a write queued behind a failing one, or issued
            // after the manager dropped this connection (e.g. `disconnect_node`), does
            // not reach a socket the rest of the system already considers gone.
            if self.is_down() {
                return Err(io::Error::new(
                    io::ErrorKind::NotConnected,
                    "distribution connection is down",
                ));
            }
            writer.write_all(bytes).await
        };
        if result.is_err() {
            self.mark_down(ConnectionDownReason::WriteError);
        }
        result
    }

    /// Marks the connection down exactly once. If the owning manager still exists, it is
    /// asked to drop this connection from its table and notify the down hook.
    fn mark_down(self: &Arc<Self>, reason: ConnectionDownReason) {
        if self.down.swap(true, Ordering::AcqRel) {
            return;
        }
        if let Some(manager) = self.manager.upgrade() {
            manager.connection_down(self.node, self, reason);
        }
    }
}
/// Handle to a running accept loop started by [`ConnectionManager::listen`].
///
/// Dropping the handle stops the loop.
pub struct AcceptHandle {
    local_addr: SocketAddr,
    shutdown: Arc<Notify>,
    task: JoinHandle<()>,
}

impl AcceptHandle {
    /// The address the listener is actually bound to (useful after binding port 0).
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }

    /// Asks the accept loop to stop.
    ///
    /// Uses `notify_one` so the request is kept as a permit even when the loop is not
    /// waiting on `notified()` at this instant, for example while it handles a freshly
    /// accepted socket between iterations. `notify_waiters` stores no permit, so the stop
    /// request would be lost in that window and the loop would keep running.
    pub fn shutdown(&self) {
        self.shutdown.notify_one();
    }

    /// Whether the accept-loop task has finished.
    #[must_use]
    pub fn is_finished(&self) -> bool {
        self.task.is_finished()
    }
}

impl Drop for AcceptHandle {
    fn drop(&mut self) {
        // Request a graceful stop, then abort in case the loop never observes it.
        self.shutdown.notify_one();
        self.task.abort();
    }
}
/// State shared by every clone of a [`ConnectionManager`]; connections refer back to it
/// through a `Weak`.
struct ConnectionManagerInner {
    /// Live connections, one per remote node.
    connections: DashMap<Atom, Arc<DistConnection>>,
    /// Interns node names and resolves node atoms back to names.
    atom_table: Arc<AtomTable>,
    /// Resolves node names to socket addresses for outbound connects.
    resolver: Arc<dyn NodeResolver + Send + Sync>,
    /// Upper bound on each outbound TCP connect.
    connect_timeout: Duration,
    /// Callback slot fired when a registered connection goes down.
    connection_down_hook: ConnectionDownHook,
    /// Optional function naming inbound peers by socket address.
    inbound_identifier: RwLock<Option<Arc<InboundIdentifier>>>,
    /// Optional callback receiving each frame's control and payload bytes.
    control_frame_handler: RwLock<Option<Arc<ControlFrameHandler>>>,
    /// Accepted sockets that no identifier has named yet.
    pending_inbound: DashMap<SocketAddr, TcpStream>,
}

impl ConnectionManagerInner {
    /// Drops `connection` from the table and fires the down hook, but only if it is still
    /// the connection registered for `node`. A connection that was already removed or
    /// superseded is ignored, so each registered connection yields at most one event.
    fn connection_down(
        &self,
        node: Atom,
        connection: &Arc<DistConnection>,
        reason: ConnectionDownReason,
    ) {
        if self
            .connections
            .remove_if(&node, |_, registered| Arc::ptr_eq(registered, connection))
            .is_none()
        {
            return;
        }
        let event = ConnectionDownEvent { node, reason };
        self.connection_down_hook.invoke(event);
    }
}
#[derive(Clone)]
pub struct ConnectionManager {
inner: Arc<ConnectionManagerInner>,
}
impl ConnectionManager {
#[must_use]
pub fn new(atom_table: Arc<AtomTable>, resolver: Arc<dyn NodeResolver + Send + Sync>) -> Self {
Self::with_connect_timeout(atom_table, resolver, DEFAULT_CONNECT_TIMEOUT)
}
#[must_use]
pub fn with_connect_timeout(
atom_table: Arc<AtomTable>,
resolver: Arc<dyn NodeResolver + Send + Sync>,
connect_timeout: Duration,
) -> Self {
Self {
inner: Arc::new(ConnectionManagerInner {
connections: DashMap::new(),
atom_table,
resolver,
connect_timeout,
connection_down_hook: ConnectionDownHook::new(),
inbound_identifier: RwLock::new(None),
control_frame_handler: RwLock::new(None),
pending_inbound: DashMap::new(),
}),
}
}
#[must_use]
pub fn connect_timeout(&self) -> Duration {
self.inner.connect_timeout
}
#[must_use]
pub fn connection_down_hook(&self) -> ConnectionDownHook {
self.inner.connection_down_hook.clone()
}
pub fn register_connection_down<F>(&self, callback: F)
where
F: Fn(ConnectionDownEvent) + Send + Sync + 'static,
{
self.inner.connection_down_hook.register(callback);
}
pub fn register_inbound_identifier<F>(&self, identifier: F)
where
F: Fn(SocketAddr) -> Option<Atom> + Send + Sync + 'static,
{
let mut slot = self
.inner
.inbound_identifier
.write()
.unwrap_or_else(|error| error.into_inner());
let identifier: Arc<dyn Fn(SocketAddr) -> Option<Atom> + Send + Sync> =
Arc::new(identifier);
*slot = Some(identifier.clone());
drop(slot);
self.identify_pending_inbound(&identifier);
}
pub fn register_control_frame_handler<F>(&self, handler: F)
where
F: Fn(&[u8], &[u8]) + Send + Sync + 'static,
{
let mut slot = self
.inner
.control_frame_handler
.write()
.unwrap_or_else(|error| error.into_inner());
*slot = Some(Arc::new(handler));
}
pub fn unregister_inbound_identifier(&self) {
let mut slot = self
.inner
.inbound_identifier
.write()
.unwrap_or_else(|error| error.into_inner());
*slot = None;
}
#[must_use]
pub fn connection_count(&self) -> usize {
self.inner.connections.len()
}
#[must_use]
pub fn pending_inbound_count(&self) -> usize {
self.inner.pending_inbound.len()
}
#[must_use]
pub fn get_connection(&self, node: Atom) -> Option<Arc<DistConnection>> {
self.inner
.connections
.get(&node)
.map(|entry| Arc::clone(entry.value()))
}
#[must_use]
pub fn connected_nodes(&self) -> Vec<Atom> {
let mut nodes: Vec<_> = self
.inner
.connections
.iter()
.map(|entry| *entry.key())
.collect();
nodes.sort_unstable_by_key(|node| node.index());
nodes
}
pub async fn connect_node(&self, node: Atom) -> bool {
if self.get_connection(node).is_some() {
return true;
}
let Some(node_name) = self.inner.atom_table.resolve(node).map(str::to_owned) else {
return false;
};
self.connect(&node_name).await.is_ok()
}
pub fn disconnect_node(&self, node: Atom) -> bool {
let Some(connection) = self.get_connection(node) else {
return true;
};
connection.mark_down(ConnectionDownReason::ManualDisconnect);
true
}
pub async fn start(
listen_addr: SocketAddr,
resolver: Arc<dyn NodeResolver + Send + Sync>,
) -> io::Result<(Self, AcceptHandle)> {
let manager = Self::new(Arc::new(AtomTable::with_common_atoms()), resolver);
let handle = manager.listen(listen_addr).await?;
Ok((manager, handle))
}
pub async fn listen(&self, listen_addr: SocketAddr) -> io::Result<AcceptHandle> {
let listener = TcpListener::bind(listen_addr).await?;
let local_addr = listener.local_addr()?;
let shutdown = Arc::new(Notify::new());
let task_shutdown = Arc::clone(&shutdown);
let manager = self.clone();
let task = tokio::spawn(async move {
manager.accept_loop(listener, task_shutdown).await;
});
Ok(AcceptHandle {
local_addr,
shutdown,
task,
})
}
pub async fn connect(&self, node_name: &str) -> Result<Arc<DistConnection>, ConnectError> {
let addr = self
.inner
.resolver
.resolve(node_name)
.await
.map_err(|_| ConnectError::ResolveFailure)?;
let stream = match tokio::time::timeout(
self.inner.connect_timeout,
TcpStream::connect(addr),
)
.await
{
Ok(Ok(stream)) => stream,
Ok(Err(error)) if error.kind() == io::ErrorKind::ConnectionRefused => {
return Err(ConnectError::ConnectionRefused);
}
Ok(Err(error)) => return Err(ConnectError::Io(error.to_string())),
Err(_) => return Err(ConnectError::Timeout),
};
let node = self.inner.atom_table.intern(node_name);
let peer_addr = stream.peer_addr().unwrap_or(addr);
Ok(self.register_connection(node, peer_addr, stream))
}
fn register_connection(
&self,
node: Atom,
peer_addr: SocketAddr,
stream: TcpStream,
) -> Arc<DistConnection> {
let (read_half, write_half) = stream.into_split();
let connection = Arc::new(DistConnection::new(
node,
peer_addr,
write_half,
Arc::downgrade(&self.inner),
));
self.inner.connections.insert(node, Arc::clone(&connection));
self.spawn_read_lifecycle(Arc::clone(&connection), read_half);
connection
}
#[cfg(test)]
pub(crate) fn register_test_connection(
&self,
node: Atom,
peer_addr: SocketAddr,
stream: std::net::TcpStream,
) -> io::Result<Arc<DistConnection>> {
stream.set_nonblocking(true)?;
let stream = TcpStream::from_std(stream)?;
Ok(self.register_connection(node, peer_addr, stream))
}
fn spawn_read_lifecycle(&self, connection: Arc<DistConnection>, mut read_half: OwnedReadHalf) {
let manager = Arc::clone(&self.inner);
tokio::spawn(async move {
loop {
let mut header = [0_u8; 8];
match read_half.read_exact(&mut header).await {
Ok(0) => {
connection.mark_down(ConnectionDownReason::PeerClosed);
break;
}
Ok(_) => {
let control_len =
u32::from_be_bytes([header[0], header[1], header[2], header[3]])
as usize;
let payload_len =
u32::from_be_bytes([header[4], header[5], header[6], header[7]])
as usize;
let Some(total_len) = control_len.checked_add(payload_len) else {
connection.mark_down(ConnectionDownReason::ReadError);
break;
};
let mut frame = vec![0_u8; total_len];
if read_half.read_exact(&mut frame).await.is_err() {
connection.mark_down(ConnectionDownReason::ReadError);
break;
}
let handler = manager
.control_frame_handler
.read()
.unwrap_or_else(|error| error.into_inner())
.clone();
if let Some(handler) = handler {
let (control, payload) = frame.split_at(control_len);
handler(control, payload);
}
}
Err(_) => {
connection.mark_down(ConnectionDownReason::ReadError);
break;
}
}
}
});
}
async fn accept_loop(&self, listener: TcpListener, shutdown: Arc<Notify>) {
loop {
tokio::select! {
_ = shutdown.notified() => {
break;
}
accepted = listener.accept() => {
let Ok((stream, peer_addr)) = accepted else {
continue;
};
self.handle_accepted(stream, peer_addr);
}
}
}
}
fn handle_accepted(&self, stream: TcpStream, peer_addr: SocketAddr) {
let identifier = self
.inner
.inbound_identifier
.read()
.unwrap_or_else(|error| error.into_inner())
.clone();
if let Some(node) = identifier
.as_ref()
.and_then(|identifier| identifier(peer_addr))
{
self.register_connection(node, peer_addr, stream);
} else {
self.inner.pending_inbound.insert(peer_addr, stream);
}
}
fn identify_pending_inbound(&self, identifier: &Arc<InboundIdentifier>) {
let identified: Vec<_> = self
.inner
.pending_inbound
.iter()
.filter_map(|entry| identifier(*entry.key()).map(|node| (*entry.key(), node)))
.collect();
for (peer_addr, node) in identified {
if let Some((_, stream)) = self.inner.pending_inbound.remove(&peer_addr) {
self.register_connection(node, peer_addr, stream);
}
}
}
}
// Integration-style tests over real loopback TCP sockets. Several assertions depend on
// when spawned accept/read tasks get to run relative to the test body. They assume the
// default `#[tokio::test]` runtime flavour (presumably current-thread; confirm) and use
// short sleeps to let background tasks make progress.
#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::net::TcpListener;
use super::*;
use crate::distribution::resolver::StaticResolver;
// Builds a manager over a fresh atom table and the given static resolver.
fn manager_with_resolver(resolver: Arc<StaticResolver>) -> ConnectionManager {
ConnectionManager::new(Arc::new(AtomTable::with_common_atoms()), resolver)
}
// A freshly built manager has an empty connection table.
#[tokio::test]
async fn empty_manager_has_no_connections() {
let manager = manager_with_resolver(Arc::new(StaticResolver::new(
std::collections::HashMap::new(),
)));
let node = manager.inner.atom_table.intern("missing@127.0.0.1");
assert_eq!(manager.connection_count(), 0);
assert!(manager.get_connection(node).is_none());
}
// `connect` registers the returned connection under the interned node atom.
// NOTE(review): the accepting task drops its stream right away. The assertion appears to
// rely on running before the reader task observes that EOF (no await in between).
#[tokio::test]
async fn outbound_connect_inserts_table_entry() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.unwrap_or_else(|error| {
panic!("failed to bind local listener: {error}");
});
let addr = listener.local_addr().unwrap_or_else(|error| {
panic!("failed to inspect local listener: {error}");
});
tokio::spawn(async move {
let _accepted = listener.accept().await;
});
let resolver = Arc::new(StaticResolver::new(std::collections::HashMap::from([(
"remote@127.0.0.1".to_string(),
addr,
)])));
let manager = manager_with_resolver(resolver);
let connection = manager
.connect("remote@127.0.0.1")
.await
.unwrap_or_else(|error| panic!("connect failed: {error}"));
let node = manager.inner.atom_table.intern("remote@127.0.0.1");
assert!(Arc::ptr_eq(
&connection,
&manager
.get_connection(node)
.expect("connection should be present"),
));
}
// A second `connect_node` for an already-connected node reuses the existing entry
// instead of dialing again, so exactly one connection is listed.
#[tokio::test]
async fn connect_node_is_idempotent_and_lists_node() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.unwrap_or_else(|error| panic!("failed to bind local listener: {error}"));
let addr = listener
.local_addr()
.unwrap_or_else(|error| panic!("failed to inspect local listener: {error}"));
tokio::spawn(async move {
let _accepted = listener.accept().await;
});
let resolver = Arc::new(StaticResolver::new(std::collections::HashMap::from([(
"remote@127.0.0.1".to_string(),
addr,
)])));
let manager = manager_with_resolver(resolver);
let node = manager.inner.atom_table.intern("remote@127.0.0.1");
assert!(manager.connect_node(node).await);
assert!(manager.connect_node(node).await);
assert_eq!(manager.connected_nodes(), vec![node]);
assert_eq!(manager.connection_count(), 1);
}
// A node the resolver does not know makes `connect_node` return false and leaves the
// table empty.
#[tokio::test]
async fn connect_node_returns_false_for_unresolved_node() {
let manager = manager_with_resolver(Arc::new(StaticResolver::new(
std::collections::HashMap::new(),
)));
let node = manager.inner.atom_table.intern("missing@127.0.0.1");
assert!(!manager.connect_node(node).await);
assert!(manager.connected_nodes().is_empty());
}
// Without an identifier, inbound sockets are parked as pending. Registering an
// identifier promotes them into the connection table.
#[tokio::test]
async fn inbound_connection_waits_for_identification_seam() {
let resolver = Arc::new(StaticResolver::new(std::collections::HashMap::new()));
let manager = manager_with_resolver(resolver);
let accept = manager
.listen("127.0.0.1:0".parse().unwrap_or_else(|error| {
panic!("failed to parse listen address: {error}");
}))
.await
.unwrap_or_else(|error| panic!("failed to start accept loop: {error}"));
let pending_stream = TcpStream::connect(accept.local_addr())
.await
.unwrap_or_else(|error| panic!("failed to open pending inbound stream: {error}"));
// Give the accept loop a chance to run and park the socket.
tokio::time::sleep(Duration::from_millis(25)).await;
assert_eq!(manager.connection_count(), 0);
assert_eq!(manager.pending_inbound_count(), 1);
let node = manager.inner.atom_table.intern("client@127.0.0.1");
manager.register_inbound_identifier(move |_| Some(node));
tokio::time::sleep(Duration::from_millis(25)).await;
assert!(manager.get_connection(node).is_some());
assert_eq!(manager.pending_inbound_count(), 0);
// Keep the client end open until the assertions above have run.
drop(pending_stream);
}
// Closing the remote end removes the connection and fires the down hook exactly once.
#[tokio::test]
async fn dropping_peer_removes_connection_and_notifies_once() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.unwrap_or_else(|error| {
panic!("failed to bind local listener: {error}");
});
let addr = listener.local_addr().unwrap_or_else(|error| {
panic!("failed to inspect local listener: {error}");
});
let accepted = tokio::spawn(async move { listener.accept().await });
let resolver = Arc::new(StaticResolver::new(std::collections::HashMap::from([(
"remote@127.0.0.1".to_string(),
addr,
)])));
let manager = manager_with_resolver(resolver);
let callback_count = Arc::new(AtomicUsize::new(0));
let callback_count_for_hook = Arc::clone(&callback_count);
manager.register_connection_down(move |_| {
callback_count_for_hook.fetch_add(1, Ordering::SeqCst);
});
let node = manager.inner.atom_table.intern("remote@127.0.0.1");
let _connection = manager
.connect("remote@127.0.0.1")
.await
.unwrap_or_else(|error| panic!("connect failed: {error}"));
let Ok(Ok((remote_stream, _))) = accepted.await else {
panic!("listener did not accept test connection");
};
drop(remote_stream);
// Let the reader task observe EOF and mark the connection down.
tokio::time::sleep(Duration::from_millis(50)).await;
assert!(manager.get_connection(node).is_none());
assert_eq!(callback_count.load(Ordering::SeqCst), 1);
}
// `disconnect_node` removes the connection and fires the hook once with
// `ManualDisconnect`. A repeated call is a harmless no-op that still returns true.
#[tokio::test]
async fn manual_disconnect_removes_connection_and_notifies_once() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.unwrap_or_else(|error| panic!("failed to bind local listener: {error}"));
let addr = listener
.local_addr()
.unwrap_or_else(|error| panic!("failed to inspect local listener: {error}"));
tokio::spawn(async move {
// Hold the remote end open briefly so a peer close does not race the disconnect.
let _accepted = listener.accept().await;
tokio::time::sleep(Duration::from_millis(100)).await;
});
let resolver = Arc::new(StaticResolver::new(std::collections::HashMap::from([(
"remote@127.0.0.1".to_string(),
addr,
)])));
let manager = manager_with_resolver(resolver);
let callback_count = Arc::new(AtomicUsize::new(0));
let callback_count_for_hook = Arc::clone(&callback_count);
manager.register_connection_down(move |event| {
assert_eq!(event.reason, ConnectionDownReason::ManualDisconnect);
callback_count_for_hook.fetch_add(1, Ordering::SeqCst);
});
let node = manager.inner.atom_table.intern("remote@127.0.0.1");
assert!(manager.connect_node(node).await);
assert!(manager.disconnect_node(node));
assert!(manager.disconnect_node(node));
assert!(manager.get_connection(node).is_none());
assert!(manager.connected_nodes().is_empty());
assert_eq!(callback_count.load(Ordering::SeqCst), 1);
}
// After the remote closes, writes eventually fail. The reader or the writer may be the
// one that marks the connection down, but the hook still fires exactly once.
#[tokio::test]
async fn write_error_removes_connection_and_notifies_once() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.unwrap_or_else(|error| {
panic!("failed to bind local listener: {error}");
});
let addr = listener.local_addr().unwrap_or_else(|error| {
panic!("failed to inspect local listener: {error}");
});
let accepted = tokio::spawn(async move { listener.accept().await });
let resolver = Arc::new(StaticResolver::new(std::collections::HashMap::from([(
"remote@127.0.0.1".to_string(),
addr,
)])));
let manager = manager_with_resolver(resolver);
let callback_count = Arc::new(AtomicUsize::new(0));
let callback_count_for_hook = Arc::clone(&callback_count);
manager.register_connection_down(move |_| {
callback_count_for_hook.fetch_add(1, Ordering::SeqCst);
});
let node = manager.inner.atom_table.intern("remote@127.0.0.1");
let connection = manager
.connect("remote@127.0.0.1")
.await
.unwrap_or_else(|error| panic!("connect failed: {error}"));
let Ok(Ok((remote_stream, _))) = accepted.await else {
panic!("listener did not accept test connection");
};
drop(remote_stream);
// The first writes after a remote close can still succeed (kernel buffering), so
// probe a few times until one fails.
for _ in 0..8 {
if connection.write_raw(b"probe").await.is_err() {
break;
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
tokio::time::sleep(Duration::from_millis(25)).await;
assert!(manager.get_connection(node).is_none());
assert_eq!(callback_count.load(Ordering::SeqCst), 1);
}
}