use super::{Location, Topology, TopologyMode};
use crate::identifiers::RoleName;
use crate::mutex_lock;
use crate::runtime::sync::{mpsc, Mutex};
use async_trait::async_trait;
use cfg_if::cfg_if;
#[cfg(target_arch = "wasm32")]
use futures::{SinkExt, StreamExt};
use std::collections::BTreeMap;
use std::sync::Arc;
use thiserror::Error;
/// Errors returned by [`Transport`] operations.
#[derive(Debug, Error)]
pub enum TransportError {
    /// A connection could not be established; the payload describes the failure.
    #[error("connection failed: {0}")]
    ConnectionFailed(String),
    /// A send failed; the payload describes the failure.
    #[error("send failed: {0}")]
    SendFailed(String),
    /// A receive failed; the payload describes the failure.
    #[error("receive failed: {0}")]
    ReceiveFailed(String),
    /// The operation timed out.
    #[error("timeout")]
    Timeout,
    /// The underlying channel is closed. `InMemoryChannelTransport` returns this when
    /// sending on, or receiving from, a channel that has been closed.
    #[error("channel closed")]
    ChannelClosed,
    /// No channel is registered for the given peer role.
    #[error("unknown role: {0}")]
    UnknownRole(RoleName),
    /// The transport is not ready for use.
    #[error("transport not ready")]
    NotReady,
    /// Wrapped I/O error; `#[from]` lets `?` convert a `std::io::Error` into this variant.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
}
/// Result type used by transport operations.
pub type TransportResult<T> = Result<T, TransportError>;
/// A typed message that can be converted to and from a raw byte payload.
pub trait TransportMessage: Send + Sync + 'static {
    /// Serializes the message into an owned byte buffer.
    fn to_bytes(&self) -> Vec<u8>;
    /// Parses a message from `bytes`; on failure returns a human-readable reason.
    ///
    /// The `Self: Sized` bound keeps this constructor from making the trait
    /// unusable as a trait object.
    fn from_bytes(bytes: &[u8]) -> Result<Self, String>
    where
        Self: Sized;
}
/// A [`TransportMessage`] whose wire form is exactly its wrapped bytes (no encoding).
#[derive(Debug, Clone)]
pub struct ByteMessage(pub Vec<u8>);

impl TransportMessage for ByteMessage {
    /// Returns a fresh copy of the wrapped bytes.
    fn to_bytes(&self) -> Vec<u8> {
        self.0.to_vec()
    }

    /// Wraps a copy of `bytes`; this conversion never fails.
    fn from_bytes(bytes: &[u8]) -> Result<Self, String> {
        let owned = Vec::from(bytes);
        Ok(Self(owned))
    }
}
/// Asynchronous transport that moves raw byte messages between named roles.
///
/// The `Send + Sync + 'static` bound allows implementations to be shared and boxed,
/// e.g. as the `Box<dyn Transport>` returned by `TransportFactory::create`.
#[async_trait]
pub trait Transport: Send + Sync + 'static {
    /// Sends `message` to the peer identified by `to_role`.
    async fn send(&self, to_role: &RoleName, message: Vec<u8>) -> TransportResult<()>;
    /// Waits for the next message from the peer identified by `from_role`.
    async fn recv(&self, from_role: &RoleName) -> TransportResult<Vec<u8>>;
    /// Reports whether the transport considers `role` connected. The exact meaning is
    /// implementation-defined (`InMemoryChannelTransport` always returns `true`).
    fn is_connected(&self, role: &RoleName) -> bool;
    /// Closes the transport.
    async fn close(&self) -> TransportResult<()>;
}
/// Storage form of a peer's receiver inside [`InMemoryChannelTransport`].
///
/// On native targets each receiver sits behind its own lock, so `recv` can release
/// the map-wide lock before awaiting. On `wasm32` the bare receiver is stored and is
/// taken out of the map for the duration of a `recv`.
#[cfg(not(target_arch = "wasm32"))]
type ReceiverSlot = Arc<Mutex<mpsc::Receiver<Vec<u8>>>>;
#[cfg(target_arch = "wasm32")]
type ReceiverSlot = mpsc::Receiver<Vec<u8>>;

/// In-process [`Transport`] that exchanges messages over bounded `mpsc` channels.
///
/// Each instance acts for a single role. Peers are linked with
/// [`InMemoryChannelTransport::connect`], which creates one channel in each direction.
pub struct InMemoryChannelTransport {
    /// The role this transport sends and receives on behalf of.
    role: RoleName,
    /// Outgoing channel to each connected peer, keyed by the peer's role.
    senders: Arc<Mutex<BTreeMap<RoleName, mpsc::Sender<Vec<u8>>>>>,
    /// Incoming channel from each connected peer, keyed by the peer's role.
    receivers: Arc<Mutex<BTreeMap<RoleName, ReceiverSlot>>>,
}

impl InMemoryChannelTransport {
    /// Creates a transport for `role` with no peers connected.
    pub fn new(role: RoleName) -> Self {
        Self {
            role,
            senders: Arc::new(Mutex::new(BTreeMap::new())),
            receivers: Arc::new(Mutex::new(BTreeMap::new())),
        }
    }

    /// Links `self` and `other` with one bounded channel (capacity 32) in each direction.
    ///
    /// An existing link between the same two roles is replaced.
    pub async fn connect(&self, other: &InMemoryChannelTransport) {
        let (tx1, rx1) = mpsc::channel(32);
        let (tx2, rx2) = mpsc::channel(32);
        mutex_lock!(self.senders).insert(other.role.clone(), tx1);
        mutex_lock!(other.receivers).insert(self.role.clone(), Self::receiver_slot(rx1));
        mutex_lock!(other.senders).insert(self.role.clone(), tx2);
        mutex_lock!(self.receivers).insert(other.role.clone(), Self::receiver_slot(rx2));
    }

    /// Returns the role this transport acts for.
    pub fn role(&self) -> &RoleName {
        &self.role
    }

    /// Wraps a newly created receiver in its own lock for storage in the receiver map.
    #[cfg(not(target_arch = "wasm32"))]
    fn receiver_slot(receiver: mpsc::Receiver<Vec<u8>>) -> ReceiverSlot {
        Arc::new(Mutex::new(receiver))
    }

    /// Stores a newly created receiver as-is (wasm keeps bare receivers in the map).
    #[cfg(target_arch = "wasm32")]
    fn receiver_slot(receiver: mpsc::Receiver<Vec<u8>>) -> ReceiverSlot {
        receiver
    }
}

#[async_trait]
impl Transport for InMemoryChannelTransport {
    /// Sends `message` to `to_role`, waiting for buffer space if the channel is full.
    ///
    /// # Errors
    /// - [`TransportError::UnknownRole`] if no channel to `to_role` is registered.
    /// - [`TransportError::ChannelClosed`] if the channel has been closed.
    async fn send(&self, to_role: &RoleName, message: Vec<u8>) -> TransportResult<()> {
        // Clone the sender and release the map lock before awaiting. Otherwise a full
        // channel would stall sends to every other role, as well as `connect` and `close`.
        let sender = {
            let senders = mutex_lock!(self.senders);
            senders
                .get(to_role)
                .cloned()
                .ok_or_else(|| TransportError::UnknownRole(to_role.clone()))?
        };
        cfg_if! {
            if #[cfg(target_arch = "wasm32")] {
                // `SinkExt::send` takes `&mut self`.
                let mut sender = sender;
                sender
                    .send(message)
                    .await
                    .map_err(|_| TransportError::ChannelClosed)
            } else {
                sender
                    .send(message)
                    .await
                    .map_err(|_| TransportError::ChannelClosed)
            }
        }
    }

    /// Waits for the next message from `from_role`.
    ///
    /// # Errors
    /// - [`TransportError::UnknownRole`] if no channel from `from_role` is registered.
    /// - [`TransportError::ChannelClosed`] if the channel closed before a message arrived.
    async fn recv(&self, from_role: &RoleName) -> TransportResult<Vec<u8>> {
        cfg_if! {
            if #[cfg(target_arch = "wasm32")] {
                // Take the receiver out of the map so the map lock is not held while
                // awaiting, then put it back afterwards.
                // NOTE(review): if this future is dropped mid-await, the receiver is lost
                // and later calls for this role report `UnknownRole`.
                let mut receiver = {
                    let mut receivers = mutex_lock!(self.receivers);
                    receivers
                        .remove(from_role)
                        .ok_or_else(|| TransportError::UnknownRole(from_role.clone()))?
                };
                let result = receiver.next().await;
                {
                    let mut receivers = mutex_lock!(self.receivers);
                    receivers.insert(from_role.clone(), receiver);
                }
                result.ok_or(TransportError::ChannelClosed)
            } else {
                // Clone this peer's slot and release the map lock before awaiting. That way
                // a pending receive does not block receives from other roles, `connect`, or
                // `close`. Only the per-peer lock is held across the await, and the receiver
                // stays in its slot, so dropping this future mid-await does not lose it.
                let slot = {
                    let receivers = mutex_lock!(self.receivers);
                    receivers
                        .get(from_role)
                        .cloned()
                        .ok_or_else(|| TransportError::UnknownRole(from_role.clone()))?
                };
                let mut receiver = mutex_lock!(slot);
                let message = receiver.recv().await;
                message.ok_or(TransportError::ChannelClosed)
            }
        }
    }

    /// NOTE(review): always returns `true`, even for roles that were never connected
    /// or after `close`; it does not reflect actual reachability.
    fn is_connected(&self, _role: &RoleName) -> bool {
        true
    }

    /// Removes every peer channel from this transport. After this, `send` and `recv`
    /// for those peers fail with [`TransportError::UnknownRole`].
    ///
    /// A `recv` that is already waiting keeps its own handle to the receiver. On
    /// `wasm32` that handle is re-inserted into the map when the wait completes.
    async fn close(&self) -> TransportResult<()> {
        mutex_lock!(self.senders).clear();
        mutex_lock!(self.receivers).clear();
        Ok(())
    }
}
/// Builds transports and classifies how one role should reach another.
pub struct TransportFactory;

impl TransportFactory {
    /// Returns the transport that `role` should use under `topology`.
    ///
    /// NOTE(review): every topology mode, including `Kubernetes` and `Consul`, currently
    /// yields an [`InMemoryChannelTransport`]. The namespace/datacenter payloads are
    /// ignored. Confirm this placeholder behavior is intended.
    pub fn create(topology: &Topology, role: &RoleName) -> Box<dyn Transport> {
        match &topology.mode {
            None
            | Some(TopologyMode::Local)
            | Some(TopologyMode::PerRole)
            | Some(TopologyMode::Kubernetes(_))
            | Some(TopologyMode::Consul(_)) => {
                Box::new(InMemoryChannelTransport::new(role.clone()))
            }
        }
    }

    /// Chooses the transport kind for reaching `to_role`, based only on where
    /// `topology` places `to_role`: local → `InMemory`, colocated → `SharedMemory`,
    /// remote → `Tcp`.
    ///
    /// NOTE(review): `_from_role` is ignored, so a colocated target is classified as
    /// `SharedMemory` regardless of who is sending.
    ///
    /// # Errors
    /// Propagates any error from `Topology::get_location` for `to_role`.
    pub fn transport_for_location(
        _from_role: &RoleName,
        to_role: &RoleName,
        topology: &Topology,
    ) -> Result<TransportType, super::TopologyError> {
        let location = topology.get_location(to_role)?;
        let kind = match location {
            Location::Local => TransportType::InMemory,
            Location::Colocated(_) => TransportType::SharedMemory,
            Location::Remote(_) => TransportType::Tcp,
        };
        Ok(kind)
    }
}
/// Kind of transport used to reach a peer role.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransportType {
    /// In-process channel; chosen for `Location::Local` targets.
    InMemory,
    /// Shared-memory transport; chosen for `Location::Colocated` targets.
    SharedMemory,
    /// TCP transport; chosen for `Location::Remote` targets.
    Tcp,
    /// WebSocket transport.
    WebSocket,
}

impl TransportType {
    /// Returns `true` for `InMemory` and `SharedMemory`, `false` for `Tcp` and `WebSocket`.
    pub fn is_local(&self) -> bool {
        match self {
            TransportType::InMemory | TransportType::SharedMemory => true,
            TransportType::Tcp | TransportType::WebSocket => false,
        }
    }
}
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;

    /// Shorthand for building a `RoleName` from a string literal.
    fn role(name: &'static str) -> RoleName {
        RoleName::from_static(name)
    }

    /// A connected pair exchanges one message in each direction.
    #[tokio::test]
    async fn test_in_memory_transport() {
        let alice = InMemoryChannelTransport::new(role("Alice"));
        let bob = InMemoryChannelTransport::new(role("Bob"));
        alice.connect(&bob).await;

        let to_bob = b"Hello Bob".to_vec();
        alice.send(&role("Bob"), to_bob.clone()).await.unwrap();
        let received_by_bob = bob.recv(&role("Alice")).await.unwrap();
        assert_eq!(received_by_bob, to_bob);

        let to_alice = b"Hello Alice".to_vec();
        bob.send(&role("Alice"), to_alice.clone()).await.unwrap();
        let received_by_alice = alice.recv(&role("Bob")).await.unwrap();
        assert_eq!(received_by_alice, to_alice);
    }

    /// Local, remote, and colocated targets map to InMemory, Tcp, and SharedMemory.
    #[test]
    fn test_transport_type_for_location() {
        let topology = Topology::builder()
            .local_role(role("Alice"))
            .remote_role(
                role("Bob"),
                crate::identifiers::Endpoint::new("localhost:8080").unwrap(),
            )
            .colocated_role(role("Carol"), role("Alice"))
            .build();

        let expectations = [
            ("Alice", TransportType::InMemory),
            ("Bob", TransportType::Tcp),
            ("Carol", TransportType::SharedMemory),
        ];
        for &(target, expected) in &expectations {
            let actual =
                TransportFactory::transport_for_location(&role("Alice"), &role(target), &topology)
                    .unwrap();
            assert_eq!(actual, expected, "transport for {}", target);
        }
    }

    /// Only the in-memory and shared-memory kinds count as local.
    #[test]
    fn test_transport_type_is_local() {
        let local = [TransportType::InMemory, TransportType::SharedMemory];
        let non_local = [TransportType::Tcp, TransportType::WebSocket];
        assert!(local.iter().all(TransportType::is_local));
        assert!(!non_local.iter().any(TransportType::is_local));
    }
}