use super::limits::CHANNEL_BUFFER_SIZE_COUNT_DEFAULT;
use super::transport::{InMemoryChannelTransport, Transport, TransportError};
use crate::{QueueCapacity, RoleName};
use async_trait::async_trait;
use cfg_if::cfg_if;
use std::collections::BTreeMap;
use std::sync::Arc;
cfg_if! {
if #[cfg(target_arch = "wasm32")] {
use futures::lock::Mutex;
} else {
use tokio::sync::Mutex;
}
}
/// Produces [`Transport`] instances for protocol roles.
///
/// The `Send + Sync` supertraits allow a single factory to be shared between tasks.
#[async_trait]
pub trait TransportFactory: Send + Sync {
    /// Returns a boxed transport for `role`.
    ///
    /// # Errors
    ///
    /// Returns a [`TransportError`] when the implementation cannot provide a transport.
    async fn create(&self, role: &RoleName) -> Result<Box<dyn Transport>, TransportError>;
}
/// Role-ordered registry of the transports a factory has created.
type RoleTransportMap = BTreeMap<RoleName, Arc<InMemoryChannelTransport>>;

/// A [`TransportFactory`] backed by [`InMemoryChannelTransport`]s.
///
/// Every transport it creates is recorded in a role-ordered map held behind
/// `Arc<Mutex<..>>`, so copies of the factory can share the same registry.
#[derive(Debug)]
pub struct InMemoryTransportFactory {
    /// Channel capacity handed to each newly created transport.
    buffer_size: QueueCapacity,
    /// Transports created so far, keyed (and ordered) by role.
    transports: Arc<Mutex<RoleTransportMap>>,
}
impl InMemoryTransportFactory {
    /// Creates an empty factory using [`CHANNEL_BUFFER_SIZE_COUNT_DEFAULT`] as the channel
    /// capacity.
    ///
    /// # Panics
    ///
    /// Panics if [`QueueCapacity::try_new`] rejects the default constant, which would mean
    /// the constant itself is out of bounds (a programming error, not a runtime condition).
    pub fn new() -> Self {
        let default_capacity = QueueCapacity::try_new(CHANNEL_BUFFER_SIZE_COUNT_DEFAULT)
            .expect("default channel buffer size must be within bounds");
        Self::with_buffer_size(default_capacity)
    }

    /// Creates an empty factory whose transports use `buffer_size` as their channel capacity.
    pub fn with_buffer_size(buffer_size: QueueCapacity) -> Self {
        let registry = BTreeMap::new();
        Self {
            buffer_size,
            transports: Arc::new(Mutex::new(registry)),
        }
    }

    /// Returns the transport registered for `role`, creating it on first request.
    ///
    /// A freshly created transport is `connect`ed to every transport already in the
    /// registry before it is inserted. The registry lock stays held for the whole call
    /// (including the `connect` awaits), so two concurrent calls cannot both create a
    /// transport for the same role.
    pub async fn get_or_create(&self, role: &RoleName) -> Arc<InMemoryChannelTransport> {
        let mut registry = self.transports.lock().await;

        if let Some(found) = registry.get(role) {
            return Arc::clone(found);
        }

        let fresh = Arc::new(InMemoryChannelTransport::with_buffer_size(
            role.clone(),
            self.buffer_size,
        ));
        // Wire the newcomer up to every peer that was registered before it.
        for peer in registry.values() {
            fresh.connect(peer).await;
        }
        registry.insert(role.clone(), Arc::clone(&fresh));
        fresh
    }

    /// Returns a snapshot of the registry, ordered by role.
    ///
    /// The map itself is cloned, so later registrations do not appear in it; the transports
    /// inside are the same `Arc`s the factory holds.
    pub async fn transports(&self) -> BTreeMap<RoleName, Arc<InMemoryChannelTransport>> {
        let registry = self.transports.lock().await;
        registry.clone()
    }

    /// Empties the registry.
    ///
    /// Only the factory's map is cleared. Callers still holding an `Arc` to a transport keep
    /// it, and the next `get_or_create` for any role builds a new transport.
    pub async fn clear(&self) {
        let mut registry = self.transports.lock().await;
        registry.clear();
    }
}
impl Default for InMemoryTransportFactory {
fn default() -> Self {
Self::new()
}
}
/// Cloning is shallow: the copy shares the original's transport registry through the same
/// `Arc`, so a role registered via either handle is visible through both.
impl Clone for InMemoryTransportFactory {
    fn clone(&self) -> Self {
        let Self {
            buffer_size,
            transports,
        } = self;
        Self {
            buffer_size: *buffer_size,
            transports: Arc::clone(transports),
        }
    }
}
#[async_trait]
impl TransportFactory for InMemoryTransportFactory {
async fn create(&self, role: &RoleName) -> Result<Box<dyn Transport>, TransportError> {
let transport = self.get_or_create(role).await;
Ok(Box::new((*transport).clone()))
}
}
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use crate::topology::Message;

    #[tokio::test]
    async fn test_in_memory_factory_creates_transport() {
        let factory = InMemoryTransportFactory::new();
        let transport = factory.create(&RoleName::from_static("Alice")).await.unwrap();
        // Bob has never been registered, and `get_or_create` only connects a new transport
        // to roles already in the registry, so Alice cannot be connected to Bob yet.
        assert!(!transport.is_connected(&RoleName::from_static("Bob")));
    }

    #[tokio::test]
    async fn test_in_memory_factory_connects_transports() {
        let factory = InMemoryTransportFactory::new();
        let alice = factory.get_or_create(&RoleName::from_static("Alice")).await;
        let bob = factory.get_or_create(&RoleName::from_static("Bob")).await;
        let msg = Message::new(b"Hello Bob".to_vec()).unwrap();
        alice.send(&RoleName::from_static("Bob"), msg).await.unwrap();
        let received = bob.recv(&RoleName::from_static("Alice")).await.unwrap();
        assert_eq!(received.as_bytes(), b"Hello Bob");
    }

    #[tokio::test]
    async fn test_in_memory_factory_custom_buffer_size() {
        let factory = InMemoryTransportFactory::with_buffer_size(
            QueueCapacity::try_new(64).expect("test buffer size in range"),
        );
        let _transport = factory.create(&RoleName::from_static("Alice")).await.unwrap();
    }

    #[tokio::test]
    async fn test_in_memory_factory_reuses_transport() {
        let factory = InMemoryTransportFactory::new();
        let t1 = factory.get_or_create(&RoleName::from_static("Alice")).await;
        let t2 = factory.get_or_create(&RoleName::from_static("Alice")).await;
        assert!(Arc::ptr_eq(&t1, &t2));
    }

    #[tokio::test]
    async fn test_in_memory_factory_clear() {
        let factory = InMemoryTransportFactory::new();
        factory.get_or_create(&RoleName::from_static("Alice")).await;
        factory.get_or_create(&RoleName::from_static("Bob")).await;
        assert_eq!(factory.transports().await.len(), 2);
        factory.clear().await;
        assert!(factory.transports().await.is_empty());
    }

    #[tokio::test]
    async fn test_in_memory_factory_clone_shares_state() {
        let factory1 = InMemoryTransportFactory::new();
        let factory2 = factory1.clone();
        factory1.get_or_create(&RoleName::from_static("Alice")).await;
        assert_eq!(factory2.transports().await.len(), 1);
    }

    #[tokio::test]
    async fn test_in_memory_factory_transports_are_sorted_by_role() {
        let factory = InMemoryTransportFactory::new();
        factory.get_or_create(&RoleName::from_static("Zed")).await;
        factory.get_or_create(&RoleName::from_static("Alice")).await;
        factory.get_or_create(&RoleName::from_static("Bob")).await;
        let roles: Vec<_> = factory
            .transports()
            .await
            .keys()
            .map(ToString::to_string)
            .collect();
        assert_eq!(roles, vec!["Alice", "Bob", "Zed"]);
    }
}