use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use serde::Serialize;
use net::adapter::net::{
AckReason, ChannelConfig, ChannelConfigRegistry, ChannelName, ChannelPublisher, EntityKeypair,
MeshNode, MeshNodeConfig, MigrationSubprotocolHandler, PublishConfig, PublishReport, Stream,
StreamConfig, StreamStats,
};
use net::adapter::Adapter;
use net::event::StoredEvent;
use crate::error::{Result, SdkError};
/// Options accepted by [`Mesh::subscribe_channel_with`].
///
/// The default carries no token, in which case the subscription goes through
/// the node's plain (token-less) subscribe call.
#[derive(Clone, Debug, Default)]
pub struct SubscribeOptions {
    /// Permission token to present with the subscription request, if any.
    pub token: Option<net::adapter::net::PermissionToken>,
}
/// Fluent configuration for a [`Mesh`].
///
/// Create one with [`MeshBuilder::new`] (or [`Mesh::builder`]), adjust it with
/// the chained setters, and finish with [`MeshBuilder::build`].
pub struct MeshBuilder {
    /// Local socket address the node is configured with.
    bind_addr: SocketAddr,
    /// 32-byte key passed to [`MeshNodeConfig::new`].
    psk: [u8; 32],
    /// Heartbeat interval (default 5 s).
    heartbeat_interval: Duration,
    /// Session timeout (default 30 s).
    session_timeout: Duration,
    /// Number of shards (default 4).
    num_shards: u16,
    /// Identity to build with; a fresh keypair is generated when `None`.
    identity: Option<crate::identity::Identity>,
    /// Optional subnet assignment.
    subnet: Option<net::adapter::net::SubnetId>,
    /// Optional subnet policy.
    subnet_policy: Option<Arc<net::adapter::net::SubnetPolicy>>,
    /// Externally visible address to use instead of a discovered one.
    #[cfg(feature = "nat-traversal")]
    reflex_override: Option<SocketAddr>,
    /// Whether to ask the node to attempt port mapping.
    #[cfg(feature = "port-mapping")]
    try_port_mapping: bool,
}

impl MeshBuilder {
    /// Create a builder with default settings.
    ///
    /// Defaults: 5 s heartbeat, 30 s session timeout, 4 shards, no identity,
    /// no subnet or subnet policy.
    ///
    /// # Errors
    /// Returns [`SdkError::Config`] if `bind_addr` is not a valid socket address.
    pub fn new(bind_addr: &str, psk: &[u8; 32]) -> Result<Self> {
        let bind_addr = match bind_addr.parse::<SocketAddr>() {
            Ok(parsed) => parsed,
            Err(e) => return Err(SdkError::Config(format!("invalid bind address: {}", e))),
        };
        Ok(Self {
            bind_addr,
            psk: *psk,
            heartbeat_interval: Duration::from_secs(5),
            session_timeout: Duration::from_secs(30),
            num_shards: 4,
            identity: None,
            subnet: None,
            subnet_policy: None,
            #[cfg(feature = "nat-traversal")]
            reflex_override: None,
            #[cfg(feature = "port-mapping")]
            try_port_mapping: false,
        })
    }

    /// Build with `identity`, which supplies the keypair and the token cache.
    pub fn identity(self, identity: crate::identity::Identity) -> Self {
        Self {
            identity: Some(identity),
            ..self
        }
    }

    /// Set the heartbeat interval in milliseconds.
    pub fn heartbeat_ms(self, ms: u64) -> Self {
        Self {
            heartbeat_interval: Duration::from_millis(ms),
            ..self
        }
    }

    /// Set the session timeout in milliseconds.
    pub fn session_timeout_ms(self, ms: u64) -> Self {
        Self {
            session_timeout: Duration::from_millis(ms),
            ..self
        }
    }

    /// Set the number of shards.
    pub fn shards(self, n: u16) -> Self {
        Self {
            num_shards: n,
            ..self
        }
    }

    /// Place the node in subnet `id`.
    pub fn subnet(self, id: net::adapter::net::SubnetId) -> Self {
        Self {
            subnet: Some(id),
            ..self
        }
    }

    /// Attach a subnet policy to the node configuration.
    pub fn subnet_policy(
        self,
        policy: impl Into<Arc<net::adapter::net::SubnetPolicy>>,
    ) -> Self {
        Self {
            subnet_policy: Some(policy.into()),
            ..self
        }
    }

    /// Use `external` as the node's reflexive (externally visible) address.
    #[cfg(feature = "nat-traversal")]
    pub fn reflex_override(self, external: SocketAddr) -> Self {
        Self {
            reflex_override: Some(external),
            ..self
        }
    }

    /// Enable or disable the node's port-mapping attempt.
    #[cfg(feature = "port-mapping")]
    pub fn try_port_mapping(self, enabled: bool) -> Self {
        Self {
            try_port_mapping: enabled,
            ..self
        }
    }

    /// Create the node and wrap it in a [`Mesh`].
    ///
    /// A fresh [`ChannelConfigRegistry`] is installed on the node. When an
    /// identity was supplied, its token cache is installed too. The handshake
    /// is fixed at `with_handshake(3, 5 s)`; presumably this means 3 attempts
    /// with a 5 s timeout, which should be confirmed against `MeshNodeConfig`.
    ///
    /// # Errors
    /// Propagates any error from [`MeshNode::new`].
    pub async fn build(self) -> Result<Mesh> {
        let sdk_identity = self.identity;
        let keypair = match sdk_identity.as_ref() {
            Some(identity) => identity.keypair().as_ref().clone(),
            None => EntityKeypair::generate(),
        };

        let config = MeshNodeConfig::new(self.bind_addr, self.psk)
            .with_heartbeat_interval(self.heartbeat_interval)
            .with_session_timeout(self.session_timeout)
            .with_num_shards(self.num_shards)
            .with_handshake(3, Duration::from_secs(5));
        let config = match self.subnet {
            Some(subnet) => config.with_subnet(subnet),
            None => config,
        };
        let config = match self.subnet_policy {
            Some(policy) => config.with_subnet_policy(policy),
            None => config,
        };
        #[cfg(feature = "nat-traversal")]
        let config = match self.reflex_override {
            Some(external) => config.with_reflex_override(external),
            None => config,
        };
        #[cfg(feature = "port-mapping")]
        let config = if self.try_port_mapping {
            config.with_try_port_mapping(true)
        } else {
            config
        };

        let mut node = MeshNode::new(keypair, config).await?;
        let channel_configs = Arc::new(ChannelConfigRegistry::new());
        node.set_channel_configs(Arc::clone(&channel_configs));
        if let Some(identity) = &sdk_identity {
            node.set_token_cache(identity.token_cache().clone());
        }

        Ok(Mesh {
            node: Arc::new(node),
            channel_configs,
            identity: sdk_identity,
        })
    }
}
pub struct Mesh {
node: Arc<MeshNode>,
channel_configs: Arc<ChannelConfigRegistry>,
identity: Option<crate::identity::Identity>,
}
impl Mesh {
pub fn builder(bind_addr: &str, psk: &[u8; 32]) -> Result<MeshBuilder> {
MeshBuilder::new(bind_addr, psk)
}
pub fn public_key(&self) -> &[u8; 32] {
self.node.public_key()
}
pub fn node_id(&self) -> u64 {
self.node.node_id()
}
pub fn local_addr(&self) -> SocketAddr {
self.node.local_addr()
}
#[cfg(feature = "cortex")]
pub fn set_rpc_observer(&self, observer: Option<crate::mesh_rpc::RpcObserverHandle>) {
self.node.set_rpc_observer(observer);
}
#[cfg(all(feature = "net", feature = "cortex"))]
pub(crate) fn node(&self) -> &Arc<MeshNode> {
&self.node
}
#[cfg(all(feature = "net", feature = "cortex"))]
pub(crate) fn channel_configs_arc(&self) -> &Arc<ChannelConfigRegistry> {
&self.channel_configs
}
pub async fn connect(
&self,
peer_addr: &str,
peer_pubkey: &[u8; 32],
peer_node_id: u64,
) -> Result<()> {
let addr: SocketAddr = peer_addr
.parse()
.map_err(|e| SdkError::Config(format!("invalid peer address: {}", e)))?;
self.node.connect(addr, peer_pubkey, peer_node_id).await?;
Ok(())
}
pub async fn accept(&self, peer_node_id: u64) -> Result<SocketAddr> {
let (addr, _) = self.node.accept(peer_node_id).await?;
Ok(addr)
}
pub fn start(&self) {
self.node.start();
}
pub fn peer_count(&self) -> usize {
self.node.peer_count()
}
pub async fn send_to(&self, peer_addr: &str, event: &impl Serialize) -> Result<()> {
let addr: SocketAddr = peer_addr
.parse()
.map_err(|e| SdkError::Config(format!("invalid address: {}", e)))?;
let json = serde_json::to_vec(event)?;
let batch = net::event::Batch {
shard_id: 0,
events: vec![net::event::InternalEvent::new(Bytes::from(json), 0, 0)],
sequence_start: 0,
process_nonce: net::event::batch_process_nonce(),
};
self.node.send_to_peer(addr, batch).await?;
Ok(())
}
pub async fn send(&self, dest_node_id: u64, event: &impl Serialize) -> Result<()> {
let json = serde_json::to_vec(event)?;
let batch = net::event::Batch {
shard_id: 0,
events: vec![net::event::InternalEvent::new(Bytes::from(json), 0, 0)],
sequence_start: 0,
process_nonce: net::event::batch_process_nonce(),
};
self.node.send_routed(dest_node_id, batch).await?;
Ok(())
}
pub async fn send_raw_to(&self, peer_addr: &str, data: &[u8]) -> Result<()> {
let addr: SocketAddr = peer_addr
.parse()
.map_err(|e| SdkError::Config(format!("invalid address: {}", e)))?;
let batch = net::event::Batch {
shard_id: 0,
events: vec![net::event::InternalEvent::new(
Bytes::copy_from_slice(data),
0,
0,
)],
sequence_start: 0,
process_nonce: net::event::batch_process_nonce(),
};
self.node.send_to_peer(addr, batch).await?;
Ok(())
}
pub async fn recv(&self, limit: usize) -> Result<Vec<StoredEvent>> {
let result = self.node.poll_shard(0, None, limit).await?;
Ok(result.events)
}
pub async fn recv_shard(&self, shard_id: u16, limit: usize) -> Result<Vec<StoredEvent>> {
let result = self.node.poll_shard(shard_id, None, limit).await?;
Ok(result.events)
}
pub fn register_channel(&self, config: ChannelConfig) {
self.channel_configs.insert(config);
}
pub async fn subscribe_channel(
&self,
publisher_node_id: u64,
channel: &ChannelName,
) -> Result<()> {
self.subscribe_channel_with(publisher_node_id, channel, SubscribeOptions::default())
.await
}
pub async fn subscribe_channel_with(
&self,
publisher_node_id: u64,
channel: &ChannelName,
opts: SubscribeOptions,
) -> Result<()> {
let result = match opts.token {
Some(token) => {
self.node
.subscribe_channel_with_token(publisher_node_id, channel.clone(), token)
.await
}
None => {
self.node
.subscribe_channel(publisher_node_id, channel.clone())
.await
}
};
match result {
Ok(()) => Ok(()),
Err(e) => Err(adapter_to_channel_error(e)),
}
}
pub async fn unsubscribe_channel(
&self,
publisher_node_id: u64,
channel: &ChannelName,
) -> Result<()> {
match self
.node
.unsubscribe_channel(publisher_node_id, channel.clone())
.await
{
Ok(()) => Ok(()),
Err(e) => Err(adapter_to_channel_error(e)),
}
}
pub async fn publish(
&self,
channel: &ChannelName,
payload: Bytes,
config: PublishConfig,
) -> Result<PublishReport> {
let publisher = ChannelPublisher::new(channel.clone(), config);
Ok(self.node.publish(&publisher, payload).await?)
}
pub async fn publish_many(
&self,
channel: &ChannelName,
payloads: &[Bytes],
config: PublishConfig,
) -> Result<PublishReport> {
let publisher = ChannelPublisher::new(channel.clone(), config);
Ok(self.node.publish_many(&publisher, payloads).await?)
}
pub fn add_route(&self, dest_node_id: u64, next_hop_addr: &str) -> Result<()> {
let addr: SocketAddr = next_hop_addr
.parse()
.map_err(|e| SdkError::Config(format!("invalid address: {}", e)))?;
self.node.router().add_route(dest_node_id, addr);
Ok(())
}
pub fn remove_route(&self, dest_node_id: u64) {
self.node.router().remove_route(dest_node_id);
}
pub fn block_peer(&self, peer_addr: &str) -> Result<()> {
let addr: SocketAddr = peer_addr
.parse()
.map_err(|e| SdkError::Config(format!("invalid address: {}", e)))?;
self.node.block_peer(addr);
Ok(())
}
pub fn unblock_peer(&self, peer_addr: &str) -> Result<()> {
let addr: SocketAddr = peer_addr
.parse()
.map_err(|e| SdkError::Config(format!("invalid address: {}", e)))?;
self.node.unblock_peer(&addr);
Ok(())
}
pub fn discovered_nodes(&self) -> usize {
self.node.proximity_graph().node_count()
}
pub fn active_reroutes(&self) -> usize {
self.node.reroute_policy().active_reroutes()
}
pub fn open_stream(
&self,
peer_node_id: u64,
stream_id: u64,
config: StreamConfig,
) -> Result<Stream> {
self.node
.open_stream(peer_node_id, stream_id, config)
.map_err(SdkError::from)
}
pub fn close_stream(&self, peer_node_id: u64, stream_id: u64) {
self.node.close_stream(peer_node_id, stream_id);
}
pub async fn send_on_stream(&self, stream: &Stream, events: &[Bytes]) -> Result<()> {
self.node
.send_on_stream(stream, events)
.await
.map_err(SdkError::from)
}
pub async fn send_with_retry(
&self,
stream: &Stream,
events: &[Bytes],
max_retries: usize,
) -> Result<()> {
self.node
.send_with_retry(stream, events, max_retries)
.await
.map_err(SdkError::from)
}
pub async fn send_blocking(&self, stream: &Stream, events: &[Bytes]) -> Result<()> {
self.node
.send_blocking(stream, events)
.await
.map_err(SdkError::from)
}
pub fn stream_stats(&self, peer_node_id: u64, stream_id: u64) -> Option<StreamStats> {
self.node.stream_stats(peer_node_id, stream_id)
}
pub fn all_stream_stats(&self, peer_node_id: u64) -> Vec<(u64, StreamStats)> {
self.node.all_stream_stats(peer_node_id)
}
pub async fn announce_capabilities(
&self,
caps: crate::capabilities::CapabilitySet,
) -> Result<()> {
self.node.announce_capabilities(caps).await?;
Ok(())
}
pub async fn announce_capabilities_with(
&self,
caps: crate::capabilities::CapabilitySet,
ttl: std::time::Duration,
sign: bool,
) -> Result<()> {
self.node
.announce_capabilities_with(caps, ttl, sign)
.await?;
Ok(())
}
pub fn find_nodes(&self, filter: &crate::capabilities::CapabilityFilter) -> Vec<u64> {
self.node.find_nodes_by_filter(filter)
}
pub fn find_nodes_scoped(
&self,
filter: &crate::capabilities::CapabilityFilter,
scope: &crate::capabilities::ScopeFilter<'_>,
) -> Vec<u64> {
self.node.find_nodes_by_filter_scoped(filter, scope)
}
pub fn find_best_node(&self, req: &crate::capabilities::CapabilityRequirement) -> Option<u64> {
self.node.find_best_node(req)
}
pub fn find_best_node_scoped(
&self,
req: &crate::capabilities::CapabilityRequirement,
scope: &crate::capabilities::ScopeFilter<'_>,
) -> Option<u64> {
self.node.find_best_node_scoped(req, scope)
}
pub fn set_migration_handler(&mut self, handler: Arc<MigrationSubprotocolHandler>) {
self.node.set_migration_handler(handler);
}
pub async fn shutdown(self) -> Result<()> {
self.node.shutdown().await?;
Ok(())
}
pub fn inner(&self) -> &MeshNode {
&self.node
}
pub fn node_arc(&self) -> Arc<MeshNode> {
self.node.clone()
}
pub fn from_node_arc(
node: Arc<MeshNode>,
channel_configs: Arc<ChannelConfigRegistry>,
identity: Option<crate::identity::Identity>,
) -> Self {
Self {
node,
channel_configs,
identity,
}
}
pub fn identity(&self) -> Option<&crate::identity::Identity> {
self.identity.as_ref()
}
#[cfg(feature = "nat-traversal")]
pub fn nat_type(&self) -> net::adapter::net::traversal::classify::NatClass {
self.node.nat_class()
}
#[cfg(feature = "nat-traversal")]
pub fn reflex_addr(&self) -> Option<SocketAddr> {
self.node.reflex_addr()
}
#[cfg(feature = "nat-traversal")]
pub fn peer_nat_type(
&self,
peer_node_id: u64,
) -> net::adapter::net::traversal::classify::NatClass {
self.node.peer_nat_class(peer_node_id)
}
#[cfg(feature = "nat-traversal")]
pub async fn probe_reflex(&self, peer_node_id: u64) -> Result<SocketAddr> {
Ok(self.node.probe_reflex(peer_node_id).await?)
}
#[cfg(feature = "nat-traversal")]
pub async fn reclassify_nat(&self) {
self.node.reclassify_nat().await
}
#[cfg(feature = "nat-traversal")]
pub async fn connect_direct(
&self,
peer_node_id: u64,
peer_pubkey: &[u8; 32],
coordinator: u64,
) -> Result<()> {
self.node
.connect_direct(peer_node_id, peer_pubkey, coordinator)
.await?;
Ok(())
}
#[cfg(feature = "nat-traversal")]
pub fn traversal_stats(&self) -> net::adapter::net::traversal::TraversalStatsSnapshot {
self.node.traversal_stats()
}
#[cfg(feature = "nat-traversal")]
pub fn set_reflex_override(&self, external: SocketAddr) {
self.node.set_reflex_override(external);
}
#[cfg(feature = "nat-traversal")]
pub fn clear_reflex_override(&self) {
self.node.clear_reflex_override();
}
}
fn adapter_to_channel_error(err: net::error::AdapterError) -> SdkError {
use net::error::AdapterError;
if let AdapterError::Connection(ref msg) = err {
let prefix = "membership request rejected: ";
if let Some(tail) = msg.strip_prefix(prefix) {
let reason = parse_ack_reason(tail);
return SdkError::ChannelRejected(reason);
}
}
SdkError::from(err)
}
/// Parse a rejection reason rendered as `Some(<Variant>)`.
///
/// This looks like the `Debug` form of `Option<AckReason>`; confirm against
/// the node's error formatting. Surrounding whitespace is ignored. Returns
/// `None` for `None`, for unknown variant names, or for any other shape.
fn parse_ack_reason(s: &str) -> Option<AckReason> {
    let variant = s
        .trim()
        .strip_prefix("Some(")
        .and_then(|rest| rest.strip_suffix(')'))?;
    let reason = match variant {
        "Unauthorized" => AckReason::Unauthorized,
        "UnknownChannel" => AckReason::UnknownChannel,
        "RateLimited" => AckReason::RateLimited,
        "TooManyChannels" => AckReason::TooManyChannels,
        _ => return None,
    };
    Some(reason)
}