use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::{mpsc, oneshot, Mutex};
use uuid::Uuid;
use crate::node::NodeId;
use crate::remote::WireEnvelope;
use crate::system_actors::HandshakeRequest;
/// Error produced by transport operations, carrying a human-readable description.
#[derive(Clone, Debug)]
pub struct TransportError {
    /// Description of the failure; `Display` renders it after a `transport error: ` prefix.
    pub message: String,
}

impl TransportError {
    /// Builds an error from anything convertible into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message: String = message.into();
        TransportError { message }
    }
}

impl std::fmt::Display for TransportError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("transport error: ")?;
        f.write_str(&self.message)
    }
}

impl std::error::Error for TransportError {}
/// Delivery mechanism for [`WireEnvelope`]s addressed to other nodes.
///
/// The `Send + Sync + 'static` bound lets implementations be shared as `Arc<dyn Transport>`.
#[async_trait]
pub trait Transport: Send + Sync + 'static {
    /// Delivers `envelope` to `target_node` without waiting for a reply.
    async fn send(
        &self,
        target_node: &NodeId,
        envelope: WireEnvelope,
    ) -> Result<(), TransportError>;

    /// Delivers `envelope` to `target_node` and resolves with the reply envelope.
    ///
    /// How replies are correlated with requests is implementation-defined.
    async fn send_request(
        &self,
        target_node: &NodeId,
        envelope: WireEnvelope,
    ) -> Result<WireEnvelope, TransportError>;

    /// Reports whether `node` is currently considered reachable by this transport.
    async fn is_reachable(&self, node: &NodeId) -> bool;
}
pub struct InMemoryTransport {
routes: Arc<Mutex<HashMap<NodeId, mpsc::Sender<WireEnvelope>>>>,
pending: Arc<Mutex<HashMap<Uuid, oneshot::Sender<WireEnvelope>>>>,
local_node: NodeId,
connected: Arc<Mutex<std::collections::HashSet<NodeId>>>,
handshake_info: Arc<Mutex<HashMap<NodeId, HandshakeRequest>>>,
}
impl InMemoryTransport {
pub fn new(local_node: NodeId) -> Self {
Self {
routes: Arc::new(Mutex::new(HashMap::new())),
pending: Arc::new(Mutex::new(HashMap::new())),
local_node,
connected: Arc::new(Mutex::new(std::collections::HashSet::new())),
handshake_info: Arc::new(Mutex::new(HashMap::new())),
}
}
pub async fn register_node(&self, node: NodeId) -> mpsc::Receiver<WireEnvelope> {
let (tx, rx) = mpsc::channel(256);
self.routes.lock().await.insert(node, tx);
rx
}
pub async fn link(&self, other: &InMemoryTransport) {
let self_entries: Vec<_> = {
let routes = self.routes.lock().await;
routes.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
};
let other_entries: Vec<_> = {
let routes = other.routes.lock().await;
routes.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
};
{
let mut routes = self.routes.lock().await;
for (node, sender) in other_entries {
routes.insert(node, sender);
}
}
{
let mut routes = other.routes.lock().await;
for (node, sender) in self_entries {
routes.insert(node, sender);
}
}
self.connected.lock().await.insert(other.local_node.clone());
other.connected.lock().await.insert(self.local_node.clone());
let self_info: Vec<_> = {
let info = self.handshake_info.lock().await;
info.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
};
let other_info: Vec<_> = {
let info = other.handshake_info.lock().await;
info.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
};
{
let mut info = self.handshake_info.lock().await;
for (node, req) in other_info {
info.insert(node, req);
}
}
{
let mut info = other.handshake_info.lock().await;
for (node, req) in self_info {
info.insert(node, req);
}
}
}
pub async fn set_handshake_info(&self, request: HandshakeRequest) {
let node = request.node_id.clone();
self.handshake_info.lock().await.insert(node, request);
}
pub async fn complete_request(
&self,
request_id: Uuid,
reply: WireEnvelope,
) -> Result<(), TransportError> {
let sender = self
.pending
.lock()
.await
.remove(&request_id)
.ok_or_else(|| TransportError::new(format!("no pending request for {request_id}")))?;
sender
.send(reply)
.map_err(|_| TransportError::new("reply receiver dropped"))
}
}
#[async_trait]
impl Transport for InMemoryTransport {
    /// Enqueues `envelope` on the channel registered for `target_node`.
    ///
    /// Waits while that channel is full.
    ///
    /// # Errors
    /// Fails if there is no route to `target_node`, or if the receiving end has been dropped.
    async fn send(
        &self,
        target_node: &NodeId,
        envelope: WireEnvelope,
    ) -> Result<(), TransportError> {
        // Clone the sender and release the route-table lock *before* awaiting. `Sender::send`
        // waits while the channel is full. Holding the lock across that wait would stall every
        // other send, `register_node`, `link` and `connect` on this transport. It could even
        // deadlock a consumer that replies through this same transport.
        let sender = self
            .routes
            .lock()
            .await
            .get(target_node)
            .cloned()
            .ok_or_else(|| TransportError::new(format!("no route to {target_node}")))?;
        sender
            .send(envelope)
            .await
            .map_err(|_| TransportError::new(format!("channel closed for {target_node}")))
    }

    /// Sends `envelope` and waits for the reply handed to `InMemoryTransport::complete_request`.
    ///
    /// The envelope must carry a `request_id` that is not already in flight. No timeout is
    /// applied. If this future is dropped before the reply arrives, the pending slot stays
    /// registered until `complete_request` is called for that id.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - the envelope has no `request_id`, or that id is already pending;
    /// - the send itself fails (the slot is released in that case);
    /// - the reply sender is dropped without replying.
    async fn send_request(
        &self,
        target_node: &NodeId,
        envelope: WireEnvelope,
    ) -> Result<WireEnvelope, TransportError> {
        let request_id = envelope
            .request_id
            .ok_or_else(|| TransportError::new("send_request requires a request_id"))?;
        let (tx, rx) = oneshot::channel();
        {
            let mut pending = self.pending.lock().await;
            // Inserting over an in-flight entry would drop that caller's reply sender. That
            // caller would then fail with "reply sender dropped", so reject the duplicate.
            if pending.contains_key(&request_id) {
                return Err(TransportError::new(format!(
                    "request {request_id} is already pending"
                )));
            }
            pending.insert(request_id, tx);
        }
        if let Err(e) = self.send(target_node, envelope).await {
            self.pending.lock().await.remove(&request_id);
            return Err(e);
        }
        rx.await
            .map_err(|_| TransportError::new("reply sender dropped"))
    }

    /// Returns whether `node` is in this transport's connected set.
    async fn is_reachable(&self, node: &NodeId) -> bool {
        self.connected.lock().await.contains(node)
    }
}
impl InMemoryTransport {
pub async fn connect(&self, node: &NodeId) -> Result<(), TransportError> {
let routes = self.routes.lock().await;
if routes.contains_key(node) {
self.connected.lock().await.insert(node.clone());
Ok(())
} else {
Err(TransportError::new(format!("no route to {node}")))
}
}
pub async fn disconnect(&self, node: &NodeId) -> Result<(), TransportError> {
self.connected.lock().await.remove(node);
Ok(())
}
pub async fn handshake(
&self,
node: &NodeId,
request: crate::system_actors::HandshakeRequest,
) -> Result<crate::system_actors::HandshakeResponse, TransportError> {
let info = self.handshake_info.lock().await;
let remote_info = info.get(node).ok_or_else(|| {
TransportError::new(format!("no handshake info registered for {node}"))
})?;
Ok(crate::system_actors::validate_handshake(remote_info, &request))
}
}
/// Maps node ids to the [`Transport`] responsible for reaching them.
pub struct TransportRegistry {
    transports: Mutex<HashMap<NodeId, Arc<dyn Transport>>>,
}

impl TransportRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        let transports = Mutex::new(HashMap::new());
        TransportRegistry { transports }
    }

    /// Associates `transport` with `node`, replacing any earlier association.
    pub async fn register(&self, node: NodeId, transport: Arc<dyn Transport>) {
        let mut table = self.transports.lock().await;
        table.insert(node, transport);
    }

    /// Forgets the transport registered for `node`; a no-op if there is none.
    pub async fn unregister(&self, node: &NodeId) {
        let mut table = self.transports.lock().await;
        table.remove(node);
    }

    /// Returns a shared handle to the transport registered for `node`, if any.
    pub async fn get(&self, node: &NodeId) -> Option<Arc<dyn Transport>> {
        let table = self.transports.lock().await;
        table.get(node).map(Arc::clone)
    }

    /// Sends `envelope` to `target_node` through its registered transport.
    ///
    /// The lookup goes through [`get`](Self::get), which releases the registry lock before the
    /// transport is invoked.
    ///
    /// # Errors
    /// Fails if nothing is registered for `target_node`; otherwise propagates the transport's
    /// own result.
    pub async fn send(
        &self,
        target_node: &NodeId,
        envelope: WireEnvelope,
    ) -> Result<(), TransportError> {
        match self.get(target_node).await {
            Some(transport) => transport.send(target_node, envelope).await,
            None => Err(TransportError::new(format!(
                "no transport for {target_node}"
            ))),
        }
    }
}

impl Default for TransportRegistry {
    /// Equivalent to [`TransportRegistry::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::interceptor::SendMode;
    use crate::node::ActorId;
    use crate::remote::WireHeaders;
    use crate::system_actors::{HandshakeResponse, RejectionReason};
    use crate::version::WireVersion;

    /// Builds a fire-and-forget (`Tell`) envelope addressed to actor 1 on `target_node`.
    fn test_envelope(target_node: &str, body: &[u8]) -> WireEnvelope {
        WireEnvelope {
            target: ActorId {
                node: NodeId(target_node.into()),
                local: 1,
            },
            target_name: "test".into(),
            message_type: "test::Msg".into(),
            send_mode: SendMode::Tell,
            headers: WireHeaders::new(),
            body: body.to_vec(),
            request_id: None,
            version: None,
        }
    }

    /// Builds an `Ask` envelope carrying request id `id`; otherwise identical to `test_envelope`.
    fn test_envelope_with_request_id(target_node: &str, body: &[u8], id: Uuid) -> WireEnvelope {
        WireEnvelope {
            message_type: "test::Ask".into(),
            send_mode: SendMode::Ask,
            request_id: Some(id),
            ..test_envelope(target_node, body)
        }
    }

    /// Builds a handshake descriptor for `node` with the given wire version and adapter name.
    fn test_handshake_req(node: &str, wire: &str, adapter: &str) -> HandshakeRequest {
        HandshakeRequest {
            node_id: NodeId(node.into()),
            wire_version: WireVersion::parse(wire).unwrap(),
            app_version: None,
            adapter: adapter.into(),
        }
    }

    /// Links node-1 (advertising `local`) with node-2 (advertising `remote`). Node-1 then
    /// handshakes with node-2, using `local` as its request.
    async fn handshake_node1_with_node2(
        local: HandshakeRequest,
        remote: HandshakeRequest,
    ) -> HandshakeResponse {
        let t1 = InMemoryTransport::new(NodeId("node-1".into()));
        let t2 = InMemoryTransport::new(NodeId("node-2".into()));
        t1.set_handshake_info(local.clone()).await;
        t2.set_handshake_info(remote).await;
        let _rx1 = t1.register_node(NodeId("node-1".into())).await;
        let _rx2 = t2.register_node(NodeId("node-2".into())).await;
        t1.link(&t2).await;
        t1.handshake(&NodeId("node-2".into()), local).await.unwrap()
    }

    #[tokio::test]
    async fn send_receive_roundtrip() {
        let transport = InMemoryTransport::new(NodeId("node-a".into()));
        let mut rx = transport.register_node(NodeId("node-b".into())).await;
        transport.connect(&NodeId("node-b".into())).await.unwrap();
        let envelope = test_envelope("node-b", b"hello");
        transport
            .send(&NodeId("node-b".into()), envelope)
            .await
            .unwrap();
        let received = rx.recv().await.unwrap();
        assert_eq!(received.body, b"hello");
        assert_eq!(received.message_type, "test::Msg");
    }

    #[tokio::test]
    async fn send_without_route_fails() {
        let transport = InMemoryTransport::new(NodeId("node-a".into()));
        let result = transport
            .send(&NodeId("node-z".into()), test_envelope("node-z", b"x"))
            .await;
        assert!(result.unwrap_err().message.contains("no route"));
    }

    #[tokio::test]
    async fn send_request_with_reply() {
        let transport = Arc::new(InMemoryTransport::new(NodeId("node-a".into())));
        let mut rx = transport.register_node(NodeId("node-b".into())).await;
        transport.connect(&NodeId("node-b".into())).await.unwrap();
        let request_id = Uuid::new_v4();
        let envelope = test_envelope_with_request_id("node-b", b"question", request_id);
        let transport_clone = Arc::clone(&transport);
        let handle = tokio::spawn(async move {
            transport_clone
                .send_request(&NodeId("node-b".into()), envelope)
                .await
        });
        let received = rx.recv().await.unwrap();
        assert_eq!(received.body, b"question");
        let reply = test_envelope_with_request_id("node-a", b"answer", request_id);
        transport.complete_request(request_id, reply).await.unwrap();
        let response = handle.await.unwrap().unwrap();
        assert_eq!(response.body, b"answer");
    }

    #[tokio::test]
    async fn send_request_requires_request_id() {
        let transport = InMemoryTransport::new(NodeId("node-a".into()));
        let _rx = transport.register_node(NodeId("node-b".into())).await;
        let result = transport
            .send_request(&NodeId("node-b".into()), test_envelope("node-b", b"q"))
            .await;
        let err = result.err().expect("send_request without request_id must fail");
        assert!(err.message.contains("requires a request_id"));
    }

    #[tokio::test]
    async fn failed_send_request_releases_pending_slot() {
        let transport = InMemoryTransport::new(NodeId("node-a".into()));
        let id = Uuid::new_v4();
        let result = transport
            .send_request(
                &NodeId("node-z".into()),
                test_envelope_with_request_id("node-z", b"q", id),
            )
            .await;
        assert!(result.is_err());
        let err = transport
            .complete_request(id, test_envelope_with_request_id("node-a", b"a", id))
            .await
            .unwrap_err();
        assert!(err.message.contains("no pending request"));
    }

    #[tokio::test]
    async fn is_reachable_false_for_unknown_true_after_connect() {
        let transport = InMemoryTransport::new(NodeId("node-a".into()));
        let _rx = transport.register_node(NodeId("node-b".into())).await;
        assert!(!transport.is_reachable(&NodeId("node-b".into())).await);
        assert!(!transport.is_reachable(&NodeId("node-c".into())).await);
        transport.connect(&NodeId("node-b".into())).await.unwrap();
        assert!(transport.is_reachable(&NodeId("node-b".into())).await);
        transport
            .disconnect(&NodeId("node-b".into()))
            .await
            .unwrap();
        assert!(!transport.is_reachable(&NodeId("node-b".into())).await);
    }

    #[tokio::test]
    async fn linked_transports_communicate() {
        let t1 = InMemoryTransport::new(NodeId("node-1".into()));
        let t2 = InMemoryTransport::new(NodeId("node-2".into()));
        let mut rx1 = t1.register_node(NodeId("node-1".into())).await;
        let mut rx2 = t2.register_node(NodeId("node-2".into())).await;
        t1.link(&t2).await;
        let envelope = test_envelope("node-2", b"from-t1");
        t1.send(&NodeId("node-2".into()), envelope).await.unwrap();
        let received = rx2.recv().await.unwrap();
        assert_eq!(received.body, b"from-t1");
        let envelope = test_envelope("node-1", b"from-t2");
        t2.send(&NodeId("node-1".into()), envelope).await.unwrap();
        let received = rx1.recv().await.unwrap();
        assert_eq!(received.body, b"from-t2");
    }

    #[tokio::test]
    async fn connect_fails_without_route() {
        let transport = InMemoryTransport::new(NodeId("node-a".into()));
        let result = transport.connect(&NodeId("node-unknown".into())).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().message.contains("no route"));
    }

    #[tokio::test]
    async fn transport_registry_send() {
        let transport = Arc::new(InMemoryTransport::new(NodeId("node-a".into())));
        let mut rx = transport.register_node(NodeId("node-b".into())).await;
        let registry = TransportRegistry::new();
        registry
            .register(NodeId("node-b".into()), transport.clone())
            .await;
        let envelope = test_envelope("node-b", b"via-registry");
        registry
            .send(&NodeId("node-b".into()), envelope)
            .await
            .unwrap();
        let received = rx.recv().await.unwrap();
        assert_eq!(received.body, b"via-registry");
    }

    #[tokio::test]
    async fn transport_registry_missing_node() {
        let registry = TransportRegistry::new();
        let envelope = test_envelope("node-x", b"lost");
        let result = registry.send(&NodeId("node-x".into()), envelope).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().message.contains("no transport"));
    }

    // Purely synchronous, so no async runtime is needed.
    #[test]
    fn transport_error_display() {
        let err = TransportError::new("connection refused");
        assert_eq!(format!("{err}"), "transport error: connection refused");
    }

    #[tokio::test]
    async fn handshake_compatible_accepted() {
        let resp = handshake_node1_with_node2(
            test_handshake_req("node-1", "0.2.0", "ractor"),
            test_handshake_req("node-2", "0.2.0", "ractor"),
        )
        .await;
        match resp {
            HandshakeResponse::Accepted { node_id, .. } => {
                assert_eq!(node_id, NodeId("node-2".into()));
            }
            _ => panic!("expected Accepted"),
        }
    }

    #[tokio::test]
    async fn handshake_incompatible_protocol_rejected() {
        let resp = handshake_node1_with_node2(
            test_handshake_req("node-1", "0.2.0", "ractor"),
            test_handshake_req("node-2", "1.0.0", "ractor"),
        )
        .await;
        assert!(matches!(
            resp,
            HandshakeResponse::Rejected {
                reason: RejectionReason::IncompatibleProtocol,
                ..
            }
        ));
    }

    #[tokio::test]
    async fn handshake_incompatible_adapter_rejected() {
        let resp = handshake_node1_with_node2(
            test_handshake_req("node-1", "0.2.0", "ractor"),
            test_handshake_req("node-2", "0.2.0", "kameo"),
        )
        .await;
        assert!(matches!(
            resp,
            HandshakeResponse::Rejected {
                reason: RejectionReason::IncompatibleAdapter,
                ..
            }
        ));
    }

    #[tokio::test]
    async fn handshake_no_info_returns_error() {
        let t1 = InMemoryTransport::new(NodeId("node-1".into()));
        let req = test_handshake_req("node-1", "0.2.0", "ractor");
        let result = t1.handshake(&NodeId("node-unknown".into()), req).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().message.contains("no handshake info"));
    }
}