use super::Dest; use super::dest_transport::DestTransport;
use super::error::{NetworkError, NetworkResult};
use super::wire_handle::WireHandle;
use actr_protocol::{ActrId, PayloadType};
use async_trait::async_trait;
use either::Either;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, Notify, RwLock};
use tokio_util::sync::CancellationToken;
#[async_trait]
pub trait WireBuilder: Send + Sync {
async fn create_connections(&self, dest: &Dest) -> NetworkResult<Vec<Arc<dyn WireHandle>>>;
async fn create_connections_with_cancel(
&self,
dest: &Dest,
cancel_token: Option<CancellationToken>,
) -> NetworkResult<Vec<Arc<dyn WireHandle>>> {
if let Some(ref token) = cancel_token {
if token.is_cancelled() {
return Err(NetworkError::ConnectionClosed(
"Connection creation cancelled".to_string(),
));
}
}
self.create_connections(dest).await
}
}
/// Per-destination registry entry.
///
/// * `Left(notify)` — a connection attempt is in flight; waiters park on `notify`.
/// * `Right(transport)` — the destination has an established [`DestTransport`].
type DestState = Either<Arc<Notify>, Arc<DestTransport>>;
/// Registry of per-destination transports.
///
/// Creates a [`DestTransport`] lazily on first use through the configured
/// [`WireBuilder`] and keeps one registry entry per [`Dest`].
pub struct PeerTransport {
// Identity of the local actor; only read through the `test-utils` accessor.
#[allow(dead_code)]
local_id: ActrId,
// Destination -> state (connecting marker or established transport).
transports: Arc<RwLock<HashMap<Dest, DestState>>>,
// Factory used to open the wire connections for a destination.
conn_factory: Arc<dyn WireBuilder>,
// Cancellation tokens of connection attempts that are currently in flight.
pending_tokens: Arc<Mutex<HashMap<Dest, CancellationToken>>>,
// Destinations for which `close_transport` is currently running; new
// transport requests for these destinations are rejected.
#[allow(unused)]
closing_peers: Arc<RwLock<HashSet<Dest>>>,
}
impl PeerTransport {
pub fn new(local_id: ActrId, conn_factory: Arc<dyn WireBuilder>) -> Self {
Self {
local_id,
transports: Arc::new(RwLock::new(HashMap::new())),
conn_factory,
pending_tokens: Arc::new(Mutex::new(HashMap::new())),
closing_peers: Arc::new(RwLock::new(HashSet::new())),
}
}
/// Reports whether `dest` is currently marked as being closed.
#[allow(dead_code)]
pub async fn is_closing(&self, dest: &Dest) -> bool {
    let closing = self.closing_peers.read().await;
    closing.contains(dest)
}
pub async fn is_connecting(&self, dest: &Dest) -> bool {
let transports = self.transports.read().await;
matches!(transports.get(dest), Some(Either::Left(_)))
}
#[cfg_attr(feature = "opentelemetry", tracing::instrument(
skip_all,
name = "PeerTransport.get_or_create_transport",
fields(dest = ?dest.as_actor_id().map(|id| id))
))]
pub(crate) async fn get_or_create_transport(
&self,
dest: &Dest,
) -> NetworkResult<Arc<DestTransport>> {
if self.closing_peers.read().await.contains(dest) {
return Err(NetworkError::ConnectionClosed(format!(
"Destination {:?} is being closed.",
dest
)));
}
loop {
let state_opt = {
let transports = self.transports.read().await;
transports.get(dest).cloned()
};
match state_opt {
Some(Either::Right(transport)) => {
tracing::debug!("Reusing existing DestTransport: {:?}", dest);
return Ok(transport);
}
Some(Either::Left(notify)) => {
tracing::debug!("Waiting for ongoing connection: {:?}", dest);
notify.notified().await;
if self.closing_peers.read().await.contains(dest) {
return Err(NetworkError::ConnectionClosed(format!(
"Destination {:?} was closed while waiting",
dest
)));
}
continue;
}
None => {
}
}
let notify = {
let mut transports = self.transports.write().await;
match transports.get(dest) {
Some(Either::Right(transport)) => {
return Ok(Arc::clone(transport));
}
Some(Either::Left(notify)) => {
Arc::clone(notify)
}
None => {
if self.closing_peers.read().await.contains(dest) {
return Err(NetworkError::ConnectionClosed(format!(
"Destination {:?} is being closed",
dest
)));
}
let notify = Arc::new(Notify::new());
transports.insert(dest.clone(), Either::Left(Arc::clone(¬ify)));
tracing::debug!("Inserted Connecting state for: {:?}", dest);
Arc::clone(¬ify)
}
}
};
let is_creator = {
let transports = self.transports.read().await;
matches!(transports.get(dest), Some(Either::Left(n)) if Arc::ptr_eq(n, ¬ify))
};
if !is_creator {
tracing::debug!("Another thread is creating connection: {:?}", dest);
match tokio::time::timeout(Duration::from_secs(10), notify.notified()).await {
Ok(_) => continue,
Err(e) => {
return Err(NetworkError::TimeoutError(format!(
"Timeout waiting for notification: {:?} {}",
dest, e
)));
}
}
}
tracing::info!("Creating new connection for: {:?}", dest);
let cancel_token = CancellationToken::new();
{
let mut tokens = self.pending_tokens.lock().await;
tokens.insert(dest.clone(), cancel_token.clone());
}
let result = async {
let connections = self
.conn_factory
.create_connections_with_cancel(dest, Some(cancel_token.clone()))
.await?;
if connections.is_empty() {
return Err(NetworkError::ConfigurationError(format!(
"Connection factory returned no connections: {dest:?}"
)));
}
tracing::info!(
"Creating DestTransport: {:?} ({} connections)",
dest,
connections.len()
);
let transport = DestTransport::new(dest.clone(), connections).await?;
Ok(Arc::new(transport))
}
.await;
{
let mut tokens = self.pending_tokens.lock().await;
tokens.remove(dest);
}
let mut transports = self.transports.write().await;
match result {
Ok(transport) => {
tracing::info!("Connection established: {:?}", dest);
transports.insert(dest.clone(), Either::Right(Arc::clone(&transport)));
drop(transports);
self.spawn_ready_monitor(dest.clone(), Arc::clone(&transport));
notify.notify_waiters();
return Ok(transport);
}
Err(e) => {
tracing::error!("Connection failed: {:?}: {}", dest, e);
transports.remove(dest);
drop(transports);
notify.notify_waiters();
return Err(e);
}
}
}
}
/// Sends `data` tagged with `payload_type` to `dest`, obtaining (or creating)
/// the destination transport first.
///
/// # Errors
/// Propagates failures from `get_or_create_transport` and from the
/// transport's own `send`.
pub async fn send(
    &self,
    dest: &Dest,
    payload_type: PayloadType,
    data: &[u8],
) -> NetworkResult<()> {
    self.get_or_create_transport(dest)
        .await?
        .send(payload_type, data)
        .await
}
/// Closes and unregisters the transport for `dest`.
///
/// While this runs, `dest` is listed in `closing_peers`, which makes
/// `get_or_create_transport` reject new requests for it. The marker is
/// removed again before returning, on success and on failure alike.
///
/// A destination that is still connecting is left alone (the creating caller
/// owns its cleanup), and a destination without any state is a no-op.
///
/// # Errors
/// Returns the error produced by the transport's `close`.
pub async fn close_transport(&self, dest: &Dest) -> NetworkResult<()> {
    self.closing_peers.write().await.insert(dest.clone());

    // Run the fallible part in its own block so that an early `?` exit cannot
    // skip the marker cleanup below; a marker left behind would make the
    // destination permanently unreachable.
    let outcome: NetworkResult<()> = async {
        let current_state = {
            let transports = self.transports.read().await;
            transports.get(dest).cloned()
        };
        match current_state {
            Some(Either::Left(_)) => {
                tracing::debug!(
                    "Ignoring close request for connecting destination {:?}; creator owns retry/cleanup",
                    dest
                );
            }
            Some(Either::Right(_)) => {
                {
                    let mut tokens = self.pending_tokens.lock().await;
                    if let Some(token) = tokens.remove(dest) {
                        tracing::info!("Cancelling in-progress connection for {:?}", dest);
                        token.cancel();
                    }
                }
                // Take the entry out under the lock, close it after the lock
                // has been released.
                let state = {
                    let mut transports = self.transports.write().await;
                    transports.remove(dest)
                };
                if let Some(Either::Right(transport)) = state {
                    tracing::info!("Closing DestTransport: {:?}", dest);
                    transport.close().await?;
                }
            }
            None => {
                tracing::debug!(
                    "Ignoring close request for {:?}; no transport state exists",
                    dest
                );
            }
        }
        Ok(())
    }
    .await;

    self.closing_peers.write().await.remove(dest);
    outcome
}
/// Closes the transport for `dest` only if it is established and still
/// belongs to the WebRTC session `(peer_id, session_id)`.
///
/// Returns `Ok(true)` when the transport was closed and `Ok(false)` when the
/// event was stale (no established transport, or the session did not match).
///
/// # Errors
/// Propagates the error of `close_transport`.
pub(crate) async fn close_transport_if_webrtc_session(
    &self,
    dest: &Dest,
    peer_id: &ActrId,
    session_id: u64,
) -> NetworkResult<bool> {
    // Snapshot the established transport; the registry lock is released at
    // the end of this statement, before anything is awaited on the transport.
    let established = match self.transports.read().await.get(dest) {
        Some(Either::Right(found)) => Some(Arc::clone(found)),
        _ => None,
    };

    let session_is_current = match &established {
        Some(candidate) => candidate.matches_webrtc_session(peer_id, session_id).await,
        None => false,
    };
    if !session_is_current {
        tracing::debug!(
            "Stale close event for {:?} (session {} mismatch or transport absent)",
            dest,
            session_id
        );
        return Ok(false);
    }

    self.close_transport(dest).await?;
    Ok(true)
}
#[allow(dead_code)]
pub async fn close_all(&self) -> NetworkResult<()> {
let mut transports = self.transports.write().await;
tracing::info!("Closing all DestTransports (count: {})", transports.len());
for (dest, state) in transports.drain() {
match state {
Either::Right(transport) => {
if let Err(e) = transport.close().await {
tracing::warn!("Failed to close DestTransport {:?}: {}", dest, e);
}
}
Either::Left(_notify) => {
tracing::debug!("Skipped Connecting state for: {:?}", dest);
}
}
}
Ok(())
}
/// Number of registry entries, counting both connecting and established
/// destinations.
#[cfg(feature = "test-utils")]
pub async fn dest_count(&self) -> usize {
    let registry = self.transports.read().await;
    registry.len()
}
/// Identity of the local actor this manager was created for.
#[cfg(feature = "test-utils")]
#[inline]
pub fn local_id(&self) -> &ActrId {
    let Self { local_id, .. } = self;
    local_id
}
/// Snapshot of every destination that currently has a registry entry.
#[cfg(feature = "test-utils")]
pub async fn list_dests(&self) -> Vec<Dest> {
    let registry = self.transports.read().await;
    let mut dests = Vec::with_capacity(registry.len());
    dests.extend(registry.keys().cloned());
    dests
}
/// Reports whether `dest` has a registry entry (connecting or established).
#[cfg(feature = "test-utils")]
pub async fn has_dest(&self, dest: &Dest) -> bool {
    let registry = self.transports.read().await;
    registry.contains_key(dest)
}
/// Spawns a background task that watches the ready set of `transport` and
/// unregisters + closes it once the set becomes empty after having been
/// non-empty at least once.
///
/// The task ends when the watch channel is closed or after the first
/// "became empty" event; it removes the registry entry only when that entry
/// is still this exact transport.
fn spawn_ready_monitor(&self, dest: Dest, transport: Arc<DestTransport>) {
    let registry = Arc::clone(&self.transports);
    tokio::spawn(async move {
        let mut ready_rx = transport.watch_ready();
        let mut seen_ready = !ready_rx.borrow().is_empty();

        while ready_rx.changed().await.is_ok() {
            let now_empty = ready_rx.borrow().is_empty();
            if !now_empty {
                seen_ready = true;
                continue;
            }
            if !seen_ready {
                // Still waiting for the first connection to become ready.
                continue;
            }

            let mut entries = registry.write().await;
            let is_ours = matches!(
                entries.get(&dest),
                Some(Either::Right(existing)) if Arc::ptr_eq(existing, &transport)
            );
            if is_ours {
                entries.remove(&dest);
                drop(entries);
                tracing::warn!(
                    "Removing DestTransport for {:?} after all connections closed",
                    dest
                );
                if let Err(e) = transport.close().await {
                    tracing::warn!("Failed to close DestTransport {:?}: {}", dest, e);
                }
            }
            break;
        }
    });
}
/// Spawns a periodic health check over all established transports.
///
/// On every tick each established transport is inspected:
/// * without any healthy connection it is unregistered and closed;
/// * otherwise `retry_failed_connections` is invoked on it.
///
/// The returned handle belongs to a task that loops forever; abort it to
/// stop the checker.
#[cfg(feature = "test-utils")]
pub fn spawn_health_checker(&self, interval: Duration) -> tokio::task::JoinHandle<()> {
    let transports = Arc::clone(&self.transports);
    let conn_factory = Arc::clone(&self.conn_factory);
    tokio::spawn(async move {
        let mut interval_timer = tokio::time::interval(interval);
        interval_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        loop {
            interval_timer.tick().await;
            // Snapshot under the read lock so no lock is held while the
            // individual transports are inspected.
            let snapshot: Vec<(Dest, Arc<DestTransport>)> = {
                let transports_read = transports.read().await;
                transports_read
                    .iter()
                    .filter_map(|(dest, state)| match state {
                        Either::Right(transport) => {
                            Some((dest.clone(), Arc::clone(transport)))
                        }
                        Either::Left(_) => None,
                    })
                    .collect()
            };
            for (dest_clone, transport) in snapshot {
                if transport.has_healthy_connection().await {
                    tracing::debug!("Triggering smart reconnect for: {:?}", dest_clone);
                    if let Err(e) = transport
                        .retry_failed_connections(&dest_clone, conn_factory.as_ref())
                        .await
                    {
                        tracing::warn!("Smart reconnect failed for {:?}: {}", dest_clone, e);
                    }
                    continue;
                }

                tracing::warn!("All connections failed for {:?}, will remove", dest_clone);
                // The registry may have changed since the snapshot. Only
                // remove the entry if it is still this exact transport, so a
                // Connecting marker or a newer replacement is never touched.
                let removed = {
                    let mut transports_write = transports.write().await;
                    let still_current = matches!(
                        transports_write.get(&dest_clone),
                        Some(Either::Right(existing)) if Arc::ptr_eq(existing, &transport)
                    );
                    if still_current {
                        transports_write.remove(&dest_clone);
                    }
                    still_current
                };
                if removed {
                    tracing::info!(
                        "Removing completely failed DestTransport: {:?}",
                        dest_clone
                    );
                    if let Err(e) = transport.close().await {
                        tracing::warn!(
                            "Failed to close DestTransport {:?}: {}",
                            dest_clone,
                            e
                        );
                    }
                }
            }
        }
    })
}
}
/// Dropping the manager only emits a debug log line; nothing in this
/// implementation closes the registered transports.
impl Drop for PeerTransport {
fn drop(&mut self) {
tracing::debug!("PeerTransport dropped");
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Factory stub that never produces a connection.
    struct TestFactory;

    #[async_trait]
    impl WireBuilder for TestFactory {
        async fn create_connections(
            &self,
            _dest: &Dest,
        ) -> NetworkResult<Vec<Arc<dyn WireHandle>>> {
            Ok(Vec::new())
        }
    }

    fn create_test_factory() -> Arc<dyn WireBuilder> {
        Arc::new(TestFactory)
    }

    /// Builds a manager for `local_id` backed by the stub factory.
    fn manager_for(local_id: ActrId) -> PeerTransport {
        PeerTransport::new(local_id, create_test_factory())
    }

    #[tokio::test]
    async fn test_transport_manager_creation() {
        let expected_id = ActrId::default();
        let mgr = manager_for(expected_id.clone());

        assert_eq!(mgr.dest_count().await, 0);
        assert_eq!(mgr.local_id(), &expected_id);
    }

    #[tokio::test]
    async fn test_list_dests() {
        let mgr = manager_for(ActrId::default());

        let dests = mgr.list_dests().await;
        assert_eq!(dests.len(), 0);
    }

    #[tokio::test]
    async fn test_has_dest() {
        let mgr = manager_for(ActrId::default());

        let dest = Dest::shell();
        assert!(!mgr.has_dest(&dest).await);
    }
}