use async_trait::async_trait;
use iroh::{
endpoint::{presets, Connection},
protocol::{AcceptError, ProtocolHandler, Router},
Endpoint, EndpointAddr, EndpointId, SecretKey,
};
use tracing::{debug, warn};
use crate::endpoint::{MaEndpoint, DEFAULT_INBOX_CAPACITY};
use crate::error::{Error, Result};
use crate::inbox::Inbox;
use crate::iroh::channel::Channel;
use crate::outbox::Outbox;
use crate::transport::transport_string;
use crate::{Document, Message};
use std::collections::BTreeMap;
use std::time::{SystemTime, UNIX_EPOCH};
/// Upper bound, in bytes, on the payload read from a single inbound stream by
/// `InboxProtocolHandler` (1 MiB).
const DEFAULT_MAX_INBOUND_MESSAGE_SIZE: usize = 1 << 20;
/// An `MaEndpoint` implementation backed by an iroh [`Endpoint`].
///
/// Each registered service protocol is passed to the router as an accepted
/// protocol name and gets its own [`Inbox`] that inbound messages are pushed into.
pub struct IrohEndpoint {
/// The underlying iroh endpoint, used both for dialing (`connect`) and as the
/// endpoint the accept router is built on.
endpoint: Endpoint,
/// Normalized protocol names in registration order; `service` and
/// `remove_service` keep this in sync with the keys of `inboxes`.
protocols: Vec<String>,
/// Inbox per registered protocol, keyed by normalized protocol name.
inboxes: BTreeMap<String, Inbox<Message>>,
/// Accept-loop router; `None` until `start_router` spawns one.
router: Option<Router>,
}
impl IrohEndpoint {
pub async fn new(secret_bytes: [u8; 32]) -> Result<Self> {
let secret = SecretKey::from_bytes(&secret_bytes);
let endpoint = Endpoint::builder(presets::N0)
.secret_key(secret)
.bind()
.await
.map_err(|e| Error::Transport(format!("endpoint bind failed: {e}")))?;
endpoint.online().await;
debug!(
endpoint_id = %endpoint.id(),
"iroh endpoint online"
);
Ok(Self {
endpoint,
protocols: Vec::new(),
inboxes: BTreeMap::new(),
router: None,
})
}
pub fn inner(&self) -> &Endpoint {
&self.endpoint
}
pub fn into_inner(self) -> Endpoint {
self.endpoint
}
pub fn endpoint_id(&self) -> EndpointId {
self.endpoint.id()
}
pub async fn open(&self, target: &str, protocol: &str) -> Result<Channel> {
let addr = self.resolve_addr(target)?;
self.open_addr(addr, protocol).await
}
async fn open_addr(&self, addr: EndpointAddr, protocol: &str) -> Result<Channel> {
let connection = self
.endpoint
.connect(addr, protocol.as_bytes())
.await
.map_err(|e| Error::Transport(format!("connect failed: {e}")))?;
let (send, _recv) = connection
.open_bi()
.await
.map_err(|e| Error::Transport(format!("open_bi failed: {e}")))?;
Ok(Channel::new(connection, send))
}
pub async fn close(self) {
if let Some(router) = self.router {
let _ = router.shutdown().await;
return;
}
self.endpoint.close().await;
}
pub fn start_router(&mut self) {
if self.router.is_some() {
return;
}
let mut builder = Router::builder(self.endpoint.clone());
for protocol in &self.protocols {
if let Some(inbox) = self.inboxes.get(protocol) {
let handler = InboxProtocolHandler::new(protocol.clone(), inbox.clone());
builder = builder.accept(protocol.as_bytes(), handler);
}
}
self.router = Some(builder.spawn());
}
pub fn remove_service(&mut self, protocol: &str) -> bool {
let normalized = normalize_protocol(protocol);
let removed = self.inboxes.remove(&normalized).is_some();
if !removed {
return false;
}
self.protocols.retain(|p| p != &normalized);
self.reload_router_if_running();
true
}
fn reload_router_if_running(&mut self) {
if self.router.is_none() {
return;
}
self.router.take();
self.start_router();
}
fn resolve_addr(&self, endpoint_id: &str) -> Result<EndpointAddr> {
let target_id: EndpointId = endpoint_id
.trim()
.parse()
.map_err(|e| Error::Transport(format!("invalid endpoint id: {e}")))?;
let mut addr = EndpointAddr::new(target_id);
if let Some(relay_url) = self.endpoint.addr().relay_urls().next() {
addr = addr.with_relay_url(relay_url.clone());
}
Ok(addr)
}
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl MaEndpoint for IrohEndpoint {
    /// This endpoint's iroh id, rendered as a string.
    fn id(&self) -> String {
        self.endpoint_id().to_string()
    }

    /// Registers `protocol` (normalized) and returns its inbox.
    ///
    /// Repeated calls for the same protocol hand back clones of the same
    /// inbox. Adding a new protocol starts the router, or rebuilds it if it is
    /// already running, so the protocol is accepted.
    fn service(&mut self, protocol: &str) -> Inbox<Message> {
        let key = normalize_protocol(protocol);
        if !self.protocols.iter().any(|known| *known == key) {
            self.protocols.push(key.clone());
        }
        match self.inboxes.get(&key) {
            Some(shared) => shared.clone(),
            None => {
                let fresh: Inbox<Message> = Inbox::new(DEFAULT_INBOX_CAPACITY);
                self.inboxes.insert(key, fresh.clone());
                if self.router.is_none() {
                    self.start_router();
                } else {
                    self.reload_router_if_running();
                }
                fresh
            }
        }
    }

    /// Every registered protocol, formatted as a transport string for this endpoint.
    fn services(&self) -> Vec<String> {
        let own_id = self.endpoint_id().to_string();
        let mut listed = Vec::with_capacity(self.protocols.len());
        for proto in &self.protocols {
            listed.push(transport_string(&own_id, proto));
        }
        listed
    }

    /// Opens a channel to `endpoint_id` on `protocol` and wraps it in an
    /// [`Outbox`] addressed to `did`. The document argument is unused here.
    async fn connect_outbox(
        &self,
        _doc: &Document,
        endpoint_id: &str,
        did: &str,
        protocol: &str,
    ) -> Result<Outbox> {
        let transport = self.open(endpoint_id, protocol).await?;
        Ok(Outbox::from_transport(
            transport,
            did.to_owned(),
            protocol.to_owned(),
        ))
    }

    /// Validates and CBOR-encodes `message`, sends it over a fresh channel to
    /// `target` on `protocol`, then closes that channel.
    async fn send_to(&self, target: &str, protocol: &str, message: &Message) -> Result<()> {
        message.headers().validate()?;
        let encoded = message.to_cbor()?;
        let mut outbound = self.open(target, protocol).await?;
        outbound.send(&encoded).await?;
        outbound.close();
        Ok(())
    }
}
/// Protocol handler that decodes inbound streams into [`Message`]s and pushes
/// them into a single inbox.
#[derive(Debug, Clone)]
struct InboxProtocolHandler {
    /// Protocol name; used only as log context.
    protocol: String,
    /// Destination for every accepted message.
    inbox: Inbox<Message>,
    /// Maximum number of bytes read from one inbound stream.
    max_message_size: usize,
}

impl InboxProtocolHandler {
    /// Builds a handler for `protocol` feeding `inbox`, using the default
    /// per-stream size limit.
    fn new(protocol: String, inbox: Inbox<Message>) -> Self {
        let max_message_size = DEFAULT_MAX_INBOUND_MESSAGE_SIZE;
        Self {
            protocol,
            inbox,
            max_message_size,
        }
    }
}
impl ProtocolHandler for InboxProtocolHandler {
    /// Serves one inbound connection stream by stream.
    ///
    /// Each bidirectional stream is read to the end (up to `max_message_size`
    /// bytes), decoded as CBOR into a [`Message`], header-validated and pushed
    /// into the inbox. Streams that fail any step are logged and skipped. The
    /// loop ends, returning `Ok(())`, once `accept_bi` fails.
    async fn accept(&self, connection: Connection) -> std::result::Result<(), AcceptError> {
        loop {
            let (mut send, mut recv) = match connection.accept_bi().await {
                Ok(pair) => pair,
                Err(reason) => {
                    debug!(
                        protocol = %self.protocol,
                        remote = %connection.remote_id(),
                        error = %reason,
                        "inbound connection closed"
                    );
                    return Ok(());
                }
            };

            // Nothing is ever written back on `send`; finish it in both arms.
            let payload = match recv.read_to_end(self.max_message_size).await {
                Ok(bytes) => {
                    let _ = send.finish();
                    bytes
                }
                Err(reason) => {
                    warn!(
                        protocol = %self.protocol,
                        remote = %connection.remote_id(),
                        error = %reason,
                        "failed to read inbound stream"
                    );
                    let _ = send.finish();
                    continue;
                }
            };

            let message = match Message::from_cbor(&payload) {
                Ok(decoded) => decoded,
                Err(reason) => {
                    warn!(
                        protocol = %self.protocol,
                        remote = %connection.remote_id(),
                        error = %reason,
                        "invalid inbound message payload"
                    );
                    continue;
                }
            };

            if let Err(reason) = message.headers().validate() {
                warn!(
                    protocol = %self.protocol,
                    remote = %connection.remote_id(),
                    error = %reason,
                    "invalid inbound message headers"
                );
                continue;
            }

            // A ttl of 0 is passed through as an expiry of 0; otherwise the
            // expiry is creation time plus ttl, saturating on overflow.
            let expires_at = match message.ttl {
                0 => 0,
                ttl => message_created_at_secs(message.created_at).saturating_add(ttl),
            };
            self.inbox.push(now_secs(), expires_at, message);
        }
    }
}
/// Canonical form of a protocol name.
///
/// Surrounding whitespace is trimmed and any run of leading `/` is collapsed to
/// exactly one. Blank input yields an empty string.
fn normalize_protocol(input: &str) -> String {
    match input.trim() {
        "" => String::new(),
        trimmed => {
            let body = trimmed.trim_start_matches('/');
            let mut canonical = String::with_capacity(body.len() + 1);
            canonical.push('/');
            canonical.push_str(body);
            canonical
        }
    }
}
/// Current wall-clock time in whole seconds since the UNIX epoch.
///
/// Returns 0 instead of panicking when the system clock reads earlier than the
/// epoch, the same way `message_created_at_secs` clamps unusable timestamps.
/// This is called from the inbound stream handler, where a misconfigured clock
/// should not crash message handling.
fn now_secs() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs())
}
/// Converts a floating-point `created_at` timestamp to whole seconds.
///
/// NaN, both infinities, zero and negative values map to 0. Finite values at
/// or above `u64::MAX` saturate to `u64::MAX`. Everything else is floored.
// The `as` cast truncates toward zero (which is floor for the strictly
// positive, finite values that reach it) and saturates at `u64::MAX`, so the
// truncation and sign-loss lints do not indicate a real problem here.
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
fn message_created_at_secs(created_at: f64) -> u64 {
    let usable = created_at.is_finite() && created_at > 0.0;
    if usable {
        created_at as u64
    } else {
        0
    }
}
#[cfg(test)]
mod tests {
    use super::{message_created_at_secs, normalize_protocol, IrohEndpoint};
    use crate::endpoint::MaEndpoint;

    /// Fixed, non-zero secret so every test endpoint is built from the same key.
    fn test_secret() -> [u8; 32] {
        let mut bytes = [0u8; 32];
        bytes[0] = 42;
        bytes
    }

    /// Builds a signed broadcast message from a fixed identity DID.
    fn test_message() -> crate::Message {
        use crate::{Did, SigningKey};
        let did =
            Did::new_identity("k51qzi5uqu5dkkciu33khkzbcmxtyhn376i1e83tya8kuy7z9euedzyr5nhoew")
                .expect("valid did");
        let did_id = did.id();
        let sk = SigningKey::generate(did).expect("signing key");
        crate::Message::new(
            did_id,
            String::new(),
            crate::service::CONTENT_TYPE_BROADCAST,
            b"test".to_vec(),
            &sk,
        )
        .expect("message")
    }

    #[test]
    fn normalize_protocol_adds_single_leading_slash() {
        assert_eq!(normalize_protocol("ma/inbox/0.0.1"), "/ma/inbox/0.0.1");
        assert_eq!(normalize_protocol("/ma/inbox/0.0.1"), "/ma/inbox/0.0.1");
        assert_eq!(normalize_protocol("  //ma/custom/1.0 "), "/ma/custom/1.0");
        assert_eq!(normalize_protocol("   "), "");
    }

    #[test]
    fn message_created_at_secs_clamps_out_of_range_values() {
        assert_eq!(message_created_at_secs(f64::NAN), 0);
        assert_eq!(message_created_at_secs(f64::INFINITY), 0);
        assert_eq!(message_created_at_secs(-5.0), 0);
        assert_eq!(message_created_at_secs(12.9), 12);
        assert_eq!(message_created_at_secs(1e30), u64::MAX);
    }

    #[tokio::test]
    #[ignore = "requires iroh network runtime"]
    async fn service_returns_shared_inbox() {
        let mut endpoint = IrohEndpoint::new(test_secret()).await.unwrap();
        let inbox_a = endpoint.service("/ma/inbox/0.0.1");
        let inbox_b = endpoint.service("/ma/inbox/0.0.1");
        inbox_a.push(0, 0, test_message());
        assert_eq!(inbox_b.len(), 1, "cloned inbox should share the same queue");
        endpoint.close().await;
    }

    #[tokio::test]
    #[ignore = "requires iroh network runtime"]
    async fn service_auto_starts_router() {
        let mut endpoint = IrohEndpoint::new(test_secret()).await.unwrap();
        assert!(endpoint.router.is_none(), "router should start stopped");
        endpoint.service("/ma/inbox/0.0.1");
        assert!(
            endpoint.router.is_some(),
            "router should auto-start on first service registration"
        );
        endpoint.close().await;
    }

    #[tokio::test]
    #[ignore = "requires iroh network runtime"]
    async fn remove_service_updates_protocol_list() {
        let mut endpoint = IrohEndpoint::new(test_secret()).await.unwrap();
        let _inbox = endpoint.service("/ma/custom/1.0");
        assert!(endpoint
            .services()
            .iter()
            .any(|s| s.contains("/ma/custom/1.0")));
        let removed = endpoint.remove_service("/ma/custom/1.0");
        assert!(
            removed,
            "remove_service should return true for registered protocol"
        );
        assert!(
            endpoint
                .services()
                .iter()
                .all(|s| !s.contains("/ma/custom/1.0")),
            "protocol should be absent from services after removal"
        );
        endpoint.close().await;
    }

    #[tokio::test]
    #[ignore = "requires iroh network runtime"]
    async fn service_after_start_router_triggers_reload() {
        let mut endpoint = IrohEndpoint::new(test_secret()).await.unwrap();
        endpoint.service("/ma/inbox/0.0.1");
        endpoint.start_router();
        assert!(
            endpoint.router.is_some(),
            "router should be running after start_router"
        );
        endpoint.service("/ma/custom/1.0");
        assert!(
            endpoint.router.is_some(),
            "router should still be running after service addition"
        );
        assert!(
            endpoint
                .services()
                .iter()
                .any(|s| s.contains("/ma/custom/1.0")),
            "new service should appear in services() after hot-add"
        );
        endpoint.close().await;
    }
}