use std::{
collections::{BTreeMap, VecDeque},
net::SocketAddr,
num::NonZeroUsize,
sync::{Arc, atomic::AtomicU16},
task::{Poll, Waker},
time::Duration,
};
use amaru_kernel::{NonEmptyBytes, Peer};
use amaru_ouroboros::{ConnectionId, ConnectionProvider, ToSocketAddrs};
use parking_lot::Mutex;
use pure_stage::BoxFuture;
use tokio_util::bytes::{Buf, Bytes, BytesMut};
/// In-process implementation of `ConnectionProvider`.
///
/// Listeners, connection endpoints and parked wakers all live in a shared
/// `Inner`; cloning the provider only clones the `Arc`, so every clone
/// observes the same listeners and connections.
#[derive(Clone, Default)]
pub struct InMemoryConnectionProvider {
// Shared state behind all clones of this provider.
inner: Arc<Inner>,
}
impl ConnectionProvider for InMemoryConnectionProvider {
fn listen(&self, addr: SocketAddr) -> BoxFuture<'static, std::io::Result<SocketAddr>> {
let inner = self.inner.clone();
let provider = self.clone();
Box::pin(async move {
tracing::debug!("InMemoryConnectionProvider::listen for {addr}");
{
let mut listeners = inner.listeners.lock();
if let Some(old_listener) = listeners.remove(&addr) {
tracing::info!(%addr, "removing existing listener for restart");
let mut connections = inner.connection_endpoints.lock();
for pending in old_listener.pending_connects {
if let Some(initiator_conn_id) = *pending.responder_endpoint.peer_conn_id.lock()
&& let Some(initiator_endpoint) = connections.remove(&initiator_conn_id)
{
if let Some(waker) = initiator_endpoint.recv_waker.lock().take() {
waker.wake();
}
tracing::debug!(
"cleaned up orphaned initiator connection {initiator_conn_id} during listener restart"
);
}
}
}
listeners.insert(addr, Listener { pending_connects: VecDeque::new() });
}
provider.wake_connect_wakers(&addr);
tracing::debug!("listener bound to {addr}");
Ok(addr)
})
}
/// Resolves with the next connection queued on the listener at `listener_addr`.
///
/// On success the responder endpoint is registered in the connection table,
/// its id is published into the initiator's peer slot, and the synthetic
/// initiator address is returned as the `Peer`.
///
/// The future stays pending while the queue is empty or while no listener is
/// bound at `listener_addr`; it is woken whenever a `connect` succeeds.
fn accept(&self, listener_addr: SocketAddr) -> BoxFuture<'static, std::io::Result<(Peer, ConnectionId)>> {
    let inner = self.inner.clone();
    Box::pin(std::future::poll_fn(move |cx| {
        let pending = {
            let mut listeners = inner.listeners.lock();
            let pending = listeners.get_mut(&listener_addr).and_then(|l| l.pending_connects.pop_front());
            if pending.is_none() {
                // Register the waker while still holding the `listeners` lock.
                // `connect` enqueues under that lock and drains `accept_wakers`
                // afterwards, so registering after releasing it could miss the
                // wake-up and leave this task parked with a connection queued.
                let mut wakers = inner.accept_wakers.lock();
                // Skip the push when this task is already parked, so repeated
                // polls do not pile up duplicate wakers.
                if !wakers.iter().any(|w| w.will_wake(cx.waker())) {
                    wakers.push(cx.waker().clone());
                }
            }
            pending
        };
        let Some(pending) = pending else {
            return Poll::Pending;
        };
        let conn_id = inner.register_endpoint(pending.responder_endpoint);
        // Let the initiator side learn which connection id its peer got.
        *pending.initiator_peer_conn_id_slot.lock() = Some(conn_id);
        tracing::debug!("accepted in-memory connection from {} with id {conn_id}", pending.initiator_addr);
        Poll::Ready(Ok((Peer::from_addr(&pending.initiator_addr), conn_id)))
    }))
}
/// Connects to the first address in `addr` that currently has a listener.
///
/// A pair of endpoints sharing two byte queues is created: the responder
/// half is queued on the listener for a later `accept`, the initiator half
/// is registered immediately and its connection id is returned.
///
/// While none of the addresses has a listener the future stays pending and
/// is woken when `listen` binds one of them. `_timeout` is not consulted by
/// this implementation.
///
/// # Errors
/// Fails with "no addresses provided" when `addr` is empty.
fn connect(&self, addr: Vec<SocketAddr>, _timeout: Duration) -> BoxFuture<'static, std::io::Result<ConnectionId>> {
    let inner = self.inner.clone();
    let provider = self.clone();
    tracing::debug!("InMemoryConnectionProvider::connect called for {addr:?}");
    Box::pin(std::future::poll_fn(move |cx| {
        if addr.is_empty() {
            return Poll::Ready(Err(std::io::Error::other("no addresses provided")));
        }
        let target_addr = {
            let listeners = inner.listeners.lock();
            let found = addr.iter().copied().find(|a| listeners.contains_key(a));
            if found.is_none() {
                // Register the waker while still holding the `listeners` lock.
                // `listen` inserts under that lock and drains `connect_wakers`
                // afterwards, so registering after releasing it could miss the
                // wake-up and park this task forever.
                let mut wakers = inner.connect_wakers.lock();
                for a in &addr {
                    let parked = wakers.entry(*a).or_default();
                    // Avoid accumulating duplicates across repeated polls.
                    if !parked.iter().any(|w| w.will_wake(cx.waker())) {
                        parked.push(cx.waker().clone());
                    }
                }
            }
            found
        };
        let Some(target_addr) = target_addr else {
            return Poll::Pending;
        };
        // Two one-directional queues; each endpoint reads what the other writes.
        let initiator_to_responder = Arc::new(Mutex::new(VecDeque::new()));
        let responder_to_initiator = Arc::new(Mutex::new(VecDeque::new()));
        // Each endpoint learns its peer's id through a shared slot.
        let initiator_peer_conn_id_slot = Arc::new(Mutex::new(None));
        let responder_peer_conn_id_slot = Arc::new(Mutex::new(None));
        let initiator_endpoint = ConnectionEndpoint {
            read_queue: responder_to_initiator.clone(),
            write_queue: initiator_to_responder.clone(),
            peer_conn_id: initiator_peer_conn_id_slot.clone(),
            ..Default::default()
        };
        let responder_endpoint = ConnectionEndpoint {
            read_queue: initiator_to_responder,
            write_queue: responder_to_initiator,
            peer_conn_id: responder_peer_conn_id_slot.clone(),
            ..Default::default()
        };
        let queued = inner.connect(&target_addr, responder_endpoint, initiator_peer_conn_id_slot, cx.waker());
        if !queued {
            // `Inner::connect` has parked our waker for `target_addr`.
            return Poll::Pending;
        }
        let conn_id = inner.register_endpoint(initiator_endpoint);
        *responder_peer_conn_id_slot.lock() = Some(conn_id);
        provider.wake_accept_wakers();
        tracing::debug!("connected to {target_addr} with id {conn_id}");
        Poll::Ready(Ok(conn_id))
    }))
}
fn connect_addrs(
&self,
addr: ToSocketAddrs,
timeout: Duration,
) -> BoxFuture<'static, std::io::Result<ConnectionId>> {
let inner = self.inner.clone();
Box::pin(async move {
let addrs = addr.to_socket_addrs().map_err(std::io::Error::other)?;
let provider = InMemoryConnectionProvider { inner };
provider.connect(addrs, timeout).await
})
}
/// Appends a copy of `data` to the write queue of `conn` and wakes the reader.
///
/// When the peer's connection id is already known only that endpoint's
/// `recv` waker is woken; otherwise every other registered connection is
/// woken, since any of them could turn out to be the peer.
///
/// # Errors
/// Fails when `conn` is not (or no longer) registered.
fn send(&self, conn: ConnectionId, data: NonEmptyBytes) -> BoxFuture<'static, std::io::Result<()>> {
    let this = self.clone();
    Box::pin(async move {
        // Snapshot what we need so the connection table is not locked while
        // pushing data or waking tasks.
        let (outgoing, known_peer) = {
            let endpoints = this.inner.connection_endpoints.lock();
            match endpoints.get(&conn) {
                Some(ep) => (ep.write_queue.clone(), *ep.peer_conn_id.lock()),
                None => {
                    return Err(std::io::Error::other(format!("connection {conn} not found for send")));
                }
            }
        };
        outgoing.lock().push_back(Bytes::copy_from_slice(&data));
        match known_peer {
            Some(peer_id) => this.wake_recv_wakers(peer_id),
            None => {
                let others: Vec<ConnectionId> =
                    this.inner.connection_endpoints.lock().keys().copied().filter(|id| *id != conn).collect();
                for id in others {
                    this.wake_recv_wakers(id);
                }
            }
        }
        Ok(())
    })
}
fn recv(&self, conn: ConnectionId, bytes: NonZeroUsize) -> BoxFuture<'static, std::io::Result<NonEmptyBytes>> {
let inner = self.inner.clone();
Box::pin(std::future::poll_fn(move |cx| {
let (read_buffer, read_queue, recv_waker) = {
let connections = inner.connection_endpoints.lock();
match connections.get(&conn) {
Some(e) => (e.read_buffer.clone(), e.read_queue.clone(), e.recv_waker.clone()),
None => {
return Poll::Ready(Err(std::io::Error::other(format!(
"connection {conn} not found for recv"
))));
}
}
};
*recv_waker.lock() = Some(cx.waker().clone());
let mut buffer = read_buffer.lock();
{
let mut rq = read_queue.lock();
while let Some(data) = rq.pop_front() {
buffer.extend_from_slice(&data);
}
}
if buffer.remaining() >= bytes.get() {
#[expect(clippy::expect_used)]
let result = buffer.copy_to_bytes(bytes.get()).try_into().expect("guaranteed by NonZeroUsize");
return Poll::Ready(Ok(result));
}
Poll::Pending
}))
}
fn close(&self, conn: ConnectionId) -> BoxFuture<'static, std::io::Result<()>> {
let inner = self.inner.clone();
Box::pin(async move {
let (removed, peer) = {
let mut connections = inner.connection_endpoints.lock();
let Some(removed) = connections.remove(&conn) else {
tracing::debug!("in-memory connection {conn} already closed");
return Ok(());
};
let peer_id = *removed.peer_conn_id.lock();
let peer = peer_id.and_then(|id| connections.remove(&id));
(removed, peer)
};
if let Some(waker) = removed.recv_waker.lock().take() {
waker.wake();
}
if let Some(peer) = peer
&& let Some(waker) = peer.recv_waker.lock().take()
{
waker.wake();
}
tracing::debug!("closed in-memory connection {conn}");
Ok(())
})
}
}
impl InMemoryConnectionProvider {
    /// Wakes every task currently parked in `accept`, emptying the waker list.
    fn wake_accept_wakers(&self) {
        // The lock guard is dropped at the end of this statement, before waking.
        let parked = std::mem::take(&mut *self.inner.accept_wakers.lock());
        parked.into_iter().for_each(Waker::wake);
    }

    /// Wakes every task parked in `connect` for `addr` and forgets their wakers.
    fn wake_connect_wakers(&self, addr: &SocketAddr) {
        let parked = self.inner.connect_wakers.lock().remove(addr);
        for waker in parked.into_iter().flatten() {
            waker.wake();
        }
    }

    /// Wakes the task parked in `recv` on `conn`, if the connection exists and
    /// a waker is stored; the stored waker is consumed.
    fn wake_recv_wakers(&self, conn: ConnectionId) {
        let parked = self.inner.connection_endpoints.lock().get(&conn).and_then(|ep| ep.recv_waker.lock().take());
        if let Some(waker) = parked {
            waker.wake();
        }
    }
}
/// One side of an in-memory connection.
///
/// `connect` creates endpoints in pairs so that one endpoint's `write_queue`
/// is the other endpoint's `read_queue`.
struct ConnectionEndpoint {
// Bytes already drained from `read_queue` but not yet handed out by `recv`.
read_buffer: Arc<Mutex<BytesMut>>,
// Chunks sent by the peer that `recv` has not drained yet.
read_queue: Arc<Mutex<VecDeque<Bytes>>>,
// Chunks pushed by `send` on this endpoint.
write_queue: Arc<Mutex<VecDeque<Bytes>>>,
// Waker stored by the most recent `recv` poll on this endpoint, if any.
recv_waker: Arc<Mutex<Option<Waker>>>,
// Connection id of the other endpoint, once that endpoint has been registered.
peer_conn_id: Arc<Mutex<Option<ConnectionId>>>,
}
impl Default for ConnectionEndpoint {
fn default() -> Self {
Self {
read_buffer: Arc::new(Mutex::new(BytesMut::with_capacity(65536))),
read_queue: Arc::new(Mutex::new(VecDeque::new())),
write_queue: Arc::new(Mutex::new(VecDeque::new())),
recv_waker: Arc::new(Mutex::new(None)),
peer_conn_id: Arc::new(Mutex::new(None)),
}
}
}
impl std::fmt::Debug for ConnectionEndpoint {
    /// Formats the endpoint as a struct named `InMemoryEndpoint`; the stored
    /// waker is rendered as the placeholder `"..."`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("InMemoryEndpoint");
        out.field("read_buffer", &self.read_buffer);
        out.field("read_queue", &self.read_queue);
        out.field("write_queue", &self.write_queue);
        out.field("recv_waker", &"...");
        out.field("peer_conn_id", &self.peer_conn_id);
        out.finish()
    }
}
/// State of one bound listener address.
struct Listener {
// Connections queued by `connect`, handed out front-first by `accept`.
pending_connects: VecDeque<PendingConnect>,
}
/// A connection that has been initiated but not yet accepted.
struct PendingConnect {
// Responder half of the connection; registered when `accept` picks it up.
responder_endpoint: ConnectionEndpoint,
// Synthetic 127.0.0.1 address allocated for the initiator.
initiator_addr: SocketAddr,
// The initiator endpoint's peer slot; `accept` stores the responder's id here.
initiator_peer_conn_id_slot: Arc<Mutex<Option<ConnectionId>>>,
}
/// State shared by all clones of `InMemoryConnectionProvider`.
struct Inner {
// Bound listeners, keyed by the address they listen on.
listeners: Mutex<BTreeMap<SocketAddr, Listener>>,
// Every registered endpoint, keyed by its connection id.
connection_endpoints: Mutex<BTreeMap<ConnectionId, ConnectionEndpoint>>,
// Id that the next registered endpoint will receive.
next_connection_id: Mutex<ConnectionId>,
// Counter used to hand out synthetic initiator ports.
last_connection_port: AtomicU16,
// Tasks parked in `accept`.
accept_wakers: Mutex<Vec<Waker>>,
// Tasks parked in `connect`, keyed by the address they are waiting for.
connect_wakers: Mutex<BTreeMap<SocketAddr, Vec<Waker>>>,
}
impl Default for Inner {
    /// Starts with no listeners, no connections and no parked wakers; ids
    /// begin at `ConnectionId::initial()` and synthetic ports at 5000.
    fn default() -> Self {
        let first_id = ConnectionId::initial();
        let first_port = 5000;
        Self {
            listeners: Mutex::default(),
            connection_endpoints: Mutex::default(),
            next_connection_id: Mutex::new(first_id),
            last_connection_port: AtomicU16::new(first_port),
            accept_wakers: Mutex::default(),
            connect_wakers: Mutex::default(),
        }
    }
}
impl Inner {
    /// Stores `endpoint` under a freshly allocated connection id and returns that id.
    fn register_endpoint(&self, endpoint: ConnectionEndpoint) -> ConnectionId {
        // Allocate the id in its own scope so the id lock is released before
        // the connection table is locked.
        let assigned = {
            let mut cursor = self.next_connection_id.lock();
            let current = *cursor;
            *cursor = current.next();
            current
        };
        self.connection_endpoints.lock().insert(assigned, endpoint);
        assigned
    }

    /// Returns the current synthetic port and advances the counter by one.
    fn new_port(&self) -> u16 {
        use std::sync::atomic::Ordering;
        self.last_connection_port.fetch_add(1, Ordering::SeqCst)
    }

    /// Queues a connection attempt on the listener bound at `target_addr`.
    ///
    /// Returns `true` when the attempt was queued. When no listener is bound
    /// there, `waker` is parked under `target_addr` and `false` is returned.
    /// A synthetic initiator port is allocated in either case.
    fn connect(
        &self,
        target_addr: &SocketAddr,
        responder_endpoint: ConnectionEndpoint,
        initiator_peer_conn_id_slot: Arc<Mutex<Option<ConnectionId>>>,
        waker: &Waker,
    ) -> bool {
        let initiator_addr = SocketAddr::from((std::net::Ipv4Addr::LOCALHOST, self.new_port()));
        let attempt = PendingConnect { responder_endpoint, initiator_addr, initiator_peer_conn_id_slot };
        // The listener lock lives only for this statement.
        let queued = match self.listeners.lock().get_mut(target_addr) {
            Some(listener) => {
                listener.pending_connects.push_back(attempt);
                true
            }
            None => false,
        };
        if queued {
            return true;
        }
        self.connect_wakers.lock().entry(*target_addr).or_default().push(waker.clone());
        false
    }
}
#[cfg(test)]
mod tests {
    use amaru_ouroboros_traits::ConnectionProvider;
    use super::*;

    /// A message sent in each direction arrives unchanged on the other side.
    #[tokio::test]
    async fn test_connection_and_send_recv() -> std::io::Result<()> {
        let net = InMemoryConnectionProvider::default();
        let server_addr: SocketAddr = "127.0.0.1:9000".parse().unwrap();
        net.listen(server_addr).await?;
        let client = net.connect(vec![server_addr], Duration::from_secs(1)).await?;
        let (_peer, server) = net.accept(server_addr).await?;

        // Initiator -> responder.
        let ping = NonEmptyBytes::try_from(Bytes::from("hello from initiator")).unwrap();
        net.send(client, ping.clone()).await?;
        let got_ping = net.recv(server, ping.len()).await?;
        assert_eq!(got_ping.as_ref(), ping.as_ref());

        // Responder -> initiator.
        let pong = NonEmptyBytes::try_from(Bytes::from("hello from responder")).unwrap();
        net.send(server, pong.clone()).await?;
        let got_pong = net.recv(client, pong.len()).await?;
        assert_eq!(got_pong.as_ref(), pong.as_ref());

        net.close(client).await?;
        net.close(server).await?;
        Ok(())
    }

    /// Several messages sent back to back can be read one by one, in order.
    #[tokio::test]
    async fn test_multiple_messages() -> std::io::Result<()> {
        let net = InMemoryConnectionProvider::default();
        let server_addr: SocketAddr = "127.0.0.1:9001".parse().unwrap();
        net.listen(server_addr).await?;
        let client = net.connect(vec![server_addr], Duration::from_secs(1)).await?;
        let (_peer, server) = net.accept(server_addr).await?;

        let payloads = ["first", "second", "third"];
        for text in payloads {
            let chunk = NonEmptyBytes::try_from(Bytes::from(text)).unwrap();
            net.send(client, chunk).await?;
        }
        for text in payloads {
            let wanted = NonZeroUsize::new(text.len()).unwrap();
            let got = net.recv(server, wanted).await?;
            assert_eq!(got.as_ref(), text.as_bytes());
        }

        net.close(client).await?;
        net.close(server).await?;
        Ok(())
    }

    /// A `connect` issued before `listen` completes once the listener is bound.
    #[tokio::test]
    async fn test_connect_waits_for_listen() -> std::io::Result<()> {
        let net = InMemoryConnectionProvider::default();
        let server_addr: SocketAddr = "127.0.0.1:9002".parse().unwrap();

        let dialer = net.clone();
        let dialing = tokio::spawn(async move { dialer.connect(vec![server_addr], Duration::from_secs(10)).await });
        // Give the spawned task a chance to run before the listener exists.
        tokio::task::yield_now().await;

        net.listen(server_addr).await?;
        let client = dialing.await.unwrap()?;
        let (_peer, _server) = net.accept(server_addr).await?;
        net.close(client).await?;
        Ok(())
    }
}