use iroh::{Endpoint, EndpointAddr, endpoint::presets, protocol::Router};
use n0_error::{Result, StdResultExt};
use crate::echo::Echo;
#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::fmt::init();

    // Start the token-protected server and wait until it is reachable before dialing it.
    let router = accept_side(b"secret!!").await?;
    let server_endpoint = router.endpoint();
    server_endpoint.online().await;
    let addr = server_endpoint.addr();

    // Scenario 1: a client without any auth hook is expected to be rejected.
    println!("-- no --");
    let no_auth = connect_side_no_auth(addr.clone()).await;
    println!("echo without auth: {:#}", no_auth.unwrap_err());

    // Scenario 2: a client presenting the wrong token is expected to be rejected.
    println!("-- wrong --");
    let wrong_auth = connect_side(addr.clone(), b"dunno").await;
    println!("echo with wrong auth: {:#}", wrong_auth.unwrap_err());

    // Scenario 3: a client presenting the correct token; the outcome is printed as-is.
    println!("-- correct --");
    let res = connect_side(addr.clone(), b"secret!!").await;
    println!("echo with correct auth: {res:?}");

    router.shutdown().await.anyerr()?;
    Ok(())
}
/// Authenticates to `remote_addr` with `token`, then performs one echo round trip.
///
/// The outgoing auth hook holds back every non-auth connection until the companion auth task
/// has completed the token handshake with that remote. The task runs for as long as
/// `_auth_guard` is alive.
///
/// The endpoint is closed explicitly before returning. `Echo::connect` only *queues* the
/// connection close, so dropping the endpoint right away could lose the close frame. The
/// server's echo handler would then wait for it until the connection times out.
async fn connect_side(remote_addr: EndpointAddr, token: &[u8]) -> Result<()> {
    let (auth_hook, auth_task) = auth::outgoing(token.to_vec());
    let endpoint = Endpoint::builder(presets::N0)
        .hooks(auth_hook)
        .bind()
        .await?;
    let _auth_guard = auth_task.spawn(endpoint.clone());
    let res = Echo::connect(&endpoint, remote_addr, b"hello there!").await;
    // Gracefully shut down so pending close frames are flushed, whatever the echo outcome.
    endpoint.close().await;
    res
}
/// Attempts one echo round trip to `remote_addr` from an endpoint without any auth hook.
///
/// Against the server built by `accept_side`, this is expected to fail: the server's
/// handshake hook rejects remotes that have not authenticated.
///
/// The endpoint is closed explicitly before returning, so that any queued connection close is
/// flushed gracefully instead of being dropped together with the endpoint.
async fn connect_side_no_auth(remote_addr: EndpointAddr) -> Result<()> {
    let endpoint = Endpoint::bind(presets::N0).await?;
    let res = Echo::connect(&endpoint, remote_addr, b"hello there!").await;
    endpoint.close().await;
    res
}
/// Builds the server side and returns the running router.
///
/// The endpoint's handshake hook admits auth-protocol connections unconditionally. Any other
/// connection is admitted only from remotes that already presented `token` over the auth
/// protocol. The router serves both the auth and the echo ALPNs.
async fn accept_side(token: &[u8]) -> Result<Router> {
    let (hook, auth_handler) = auth::incoming(token.to_vec());
    let endpoint = Endpoint::builder(presets::N0).hooks(hook).bind().await?;
    Ok(Router::builder(endpoint)
        .accept(auth::ALPN, auth_handler)
        .accept(echo::ALPN, Echo)
        .spawn())
}
mod echo {
    //! Minimal echo protocol.
    //!
    //! The client sends one message on a bidirectional stream and expects exactly the same
    //! bytes back.
    use iroh::{
        Endpoint, EndpointAddr,
        endpoint::Connection,
        protocol::{AcceptError, ProtocolHandler},
    };
    use n0_error::{Result, StdResultExt, anyerr};

    /// Echo protocol: a server-side handler, plus a client helper in `Echo::connect`.
    #[derive(Debug, Clone)]
    pub struct Echo;

    /// ALPN identifying the echo protocol.
    pub const ALPN: &[u8] = b"iroh-example/echo/0";

    impl Echo {
        /// Connects to `remote`, sends `message`, and checks that it is echoed back unchanged.
        ///
        /// # Errors
        ///
        /// Fails in any of these cases:
        /// - connecting, opening the stream, writing, or reading fails;
        /// - the response is longer than `message`;
        /// - the response differs from `message` (`"Received invalid response"`).
        pub async fn connect(
            endpoint: &Endpoint,
            remote: impl Into<EndpointAddr>,
            message: &[u8],
        ) -> Result<()> {
            let conn = endpoint.connect(remote, ALPN).await?;
            let (mut send, mut recv) = conn.open_bi().await.anyerr()?;
            send.write_all(message).await.anyerr()?;
            send.finish().anyerr()?;
            // A correct echo is exactly as long as the request. Bounding the read by the
            // message length, rather than a fixed 1000 bytes, avoids spuriously rejecting
            // long messages.
            let response = recv.read_to_end(message.len()).await.anyerr()?;
            conn.close(0u32.into(), b"bye!");
            if response == message {
                Ok(())
            } else {
                Err(anyerr!("Received invalid response"))
            }
        }
    }

    impl ProtocolHandler for Echo {
        /// Copies the first incoming bidirectional stream back to the sender.
        ///
        /// After finishing the send side, it waits until the peer closes the connection.
        async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
            let (mut send, mut recv) = connection.accept_bi().await?;
            tokio::io::copy(&mut recv, &mut send).await?;
            send.finish()?;
            connection.closed().await;
            Ok(())
        }
    }
}
mod auth {
    //! Token-based authentication implemented with endpoint hooks.
    //!
    //! Outgoing side: `OutgoingAuthHook` holds back every connection whose ALPN is not the auth
    //! ALPN until `OutgoingAuthTask` has presented the shared token to that remote and the remote
    //! accepted it. Remotes that accepted are cached, so later connections skip the handshake.
    //!
    //! Incoming side: `AuthProtocol` checks tokens presented over the auth ALPN and records the
    //! remote as allowed. `IncomingAuthHook` rejects any other connection from a remote that is
    //! not in that set.
    use std::{
        collections::{HashMap, HashSet, hash_map},
        sync::{Arc, Mutex},
    };

    use iroh::{
        Endpoint, EndpointAddr, EndpointId,
        endpoint::{
            AfterHandshakeOutcome, BeforeConnectOutcome, Connection, ConnectionError, EndpointHooks,
        },
        protocol::{AcceptError, ProtocolHandler},
    };
    use n0_error::{AnyError, Result, StackResultExt, StdResultExt, anyerr};
    use n0_future::task::AbortOnDropHandle;
    use tokio::{
        sync::{mpsc, oneshot},
        task::JoinSet,
    };
    use tracing::debug;

    /// ALPN of the authentication protocol.
    pub const ALPN: &[u8] = b"iroh-example/auth/0";
    /// Application close code with which the server signals that the token was accepted.
    const CLOSE_ACCEPTED: u32 = 1;
    /// Application close code with which the server signals that authentication was denied.
    const CLOSE_DENIED: u32 = 403;

    /// Outcome of one handshake. The error is shared (`Arc`) so that every caller waiting on the
    /// same remote can receive a clone of it.
    type AuthResult = Result<(), Arc<AnyError>>;

    /// Request from the hook to the task: authenticate with this remote and report the outcome.
    type AuthRequest = (EndpointId, oneshot::Sender<AuthResult>);

    /// Creates the client-side hook and the task that performs handshakes on its behalf.
    ///
    /// Install the hook on an endpoint, then run the task with `OutgoingAuthTask::spawn` on
    /// that same endpoint. Hooked connections can only proceed while the task is running.
    pub fn outgoing(token: Vec<u8>) -> (OutgoingAuthHook, OutgoingAuthTask) {
        let (tx, rx) = mpsc::channel(16);
        let hook = OutgoingAuthHook { tx };
        let task = OutgoingAuthTask {
            token,
            rx,
            allowed_remotes: Default::default(),
            pending_remotes: Default::default(),
            tasks: JoinSet::new(),
        };
        (hook, task)
    }

    /// Endpoint hook that gates outgoing connections on a successful token handshake.
    #[derive(Debug)]
    pub struct OutgoingAuthHook {
        tx: mpsc::Sender<AuthRequest>,
    }

    impl OutgoingAuthHook {
        /// Asks the auth task to authenticate with `remote_id` and waits for the outcome.
        ///
        /// Fails with "authenticator stopped" if the task is gone before it answers.
        async fn authenticate(&self, remote_id: EndpointId) -> Result<()> {
            let (tx, rx) = oneshot::channel();
            self.tx
                .send((remote_id, tx))
                .await
                .std_context("authenticator stopped")?;
            rx.await
                .std_context("authenticator stopped")?
                .context("failed to authenticate")
        }
    }

    impl EndpointHooks for OutgoingAuthHook {
        /// Lets auth-protocol connections through directly, so the handshake itself is not
        /// gated. Every other connection must first authenticate with the remote.
        async fn before_connect<'a>(
            &'a self,
            remote_addr: &'a EndpointAddr,
            alpn: &'a [u8],
        ) -> BeforeConnectOutcome {
            if alpn == ALPN {
                BeforeConnectOutcome::Accept
            } else {
                match self.authenticate(remote_addr.id).await {
                    Ok(()) => BeforeConnectOutcome::Accept,
                    Err(err) => {
                        debug!("authentication denied: {err:#}");
                        BeforeConnectOutcome::Reject
                    }
                }
            }
        }
    }

    /// Background task that runs handshakes requested by `OutgoingAuthHook`.
    ///
    /// Concurrent requests for the same remote share a single in-flight handshake.
    pub struct OutgoingAuthTask {
        token: Vec<u8>,
        rx: mpsc::Receiver<AuthRequest>,
        /// Remotes that have accepted our token.
        allowed_remotes: HashSet<EndpointId>,
        /// Callers waiting on an in-flight handshake, keyed by remote.
        pending_remotes: HashMap<EndpointId, Vec<oneshot::Sender<AuthResult>>>,
        tasks: JoinSet<(EndpointId, Result<()>)>,
    }

    impl OutgoingAuthTask {
        /// Spawns the task on the tokio runtime. Dropping the returned handle aborts it.
        pub fn spawn(self, endpoint: Endpoint) -> AbortOnDropHandle<()> {
            AbortOnDropHandle::new(tokio::spawn(self.run(endpoint)))
        }

        /// Serves requests until every hook sender is dropped, while driving the in-flight
        /// handshake tasks and dispatching their results.
        async fn run(mut self, endpoint: Endpoint) {
            loop {
                tokio::select! {
                    msg = self.rx.recv() => {
                        let Some((remote_id, tx)) = msg else {
                            break;
                        };
                        self.handle_msg(&endpoint, remote_id, tx);
                    }
                    Some(res) = self.tasks.join_next(), if !self.tasks.is_empty() => {
                        let (remote_id, res) = res.expect("connect task panicked");
                        let res = res.map_err(Arc::new);
                        self.handle_task(remote_id, res);
                    }
                }
            }
        }

        /// Handles one request for `remote_id`.
        ///
        /// Answers immediately if the remote is already authenticated. Otherwise it queues `tx`
        /// and starts a handshake, unless one is already in flight for that remote.
        fn handle_msg(
            &mut self,
            endpoint: &Endpoint,
            remote_id: EndpointId,
            tx: oneshot::Sender<AuthResult>,
        ) {
            if self.allowed_remotes.contains(&remote_id) {
                tx.send(Ok(())).ok();
                return;
            }
            match self.pending_remotes.entry(remote_id) {
                hash_map::Entry::Occupied(mut entry) => {
                    entry.get_mut().push(tx);
                }
                hash_map::Entry::Vacant(entry) => {
                    let endpoint = endpoint.clone();
                    let token = self.token.clone();
                    self.tasks.spawn(async move {
                        let res = Self::connect(endpoint, remote_id, token).await;
                        (remote_id, res)
                    });
                    entry.insert(vec![tx]);
                }
            }
        }

        /// Caches a successful handshake and sends the result to every waiter for `remote_id`.
        ///
        /// Failures are not cached, so the next request for that remote retries the handshake.
        fn handle_task(&mut self, remote_id: EndpointId, res: AuthResult) {
            if res.is_ok() {
                self.allowed_remotes.insert(remote_id);
            }
            let senders = self.pending_remotes.remove(&remote_id);
            for tx in senders.into_iter().flatten() {
                tx.send(res.clone()).ok();
            }
        }

        /// Performs one handshake with `remote_id`.
        ///
        /// Sends `token` on a unidirectional stream, then reads the verdict from the code the
        /// remote closes the connection with.
        async fn connect(endpoint: Endpoint, remote_id: EndpointId, token: Vec<u8>) -> Result<()> {
            let conn = endpoint.connect(remote_id, ALPN).await?;
            let mut stream = conn.open_uni().await.anyerr()?;
            stream.write_all(&token).await.anyerr()?;
            stream.finish().anyerr()?;
            let reason = conn.closed().await;
            // Compare the full `u64` close code. Casting it to `u32` first would truncate, so a
            // code such as `(1 << 32) | 1` would be misread as `CLOSE_ACCEPTED`.
            let close_code = match &reason {
                ConnectionError::ApplicationClosed(close) => Some(close.error_code.into_inner()),
                _ => None,
            };
            match close_code {
                Some(code) if code == u64::from(CLOSE_ACCEPTED) => Ok(()),
                Some(code) if code == u64::from(CLOSE_DENIED) => {
                    Err(anyerr!("authentication denied by remote"))
                }
                _ => Err(AnyError::from_std(reason)),
            }
        }
    }

    /// Creates the server-side hook and protocol handler.
    ///
    /// Both share the set of remotes that have presented `token`.
    pub fn incoming(token: Vec<u8>) -> (IncomingAuthHook, AuthProtocol) {
        let allowed_remotes: Arc<Mutex<HashSet<EndpointId>>> = Default::default();
        let hook = IncomingAuthHook {
            allowed_remotes: allowed_remotes.clone(),
        };
        let protocol = AuthProtocol {
            allowed_remotes,
            token,
        };
        (hook, protocol)
    }

    /// Endpoint hook that admits auth-protocol connections, plus connections from remotes that
    /// have already authenticated.
    #[derive(Debug)]
    pub struct IncomingAuthHook {
        allowed_remotes: Arc<Mutex<HashSet<EndpointId>>>,
    }

    impl EndpointHooks for IncomingAuthHook {
        /// Rejects connections on any ALPN other than the auth ALPN from unauthenticated remotes.
        ///
        /// Rejected connections are closed with `CLOSE_DENIED`.
        async fn after_handshake<'a>(
            &'a self,
            conn: &'a iroh::endpoint::ConnectionInfo,
        ) -> AfterHandshakeOutcome {
            if conn.alpn() == ALPN
                || self
                    .allowed_remotes
                    .lock()
                    .expect("poisoned")
                    .contains(&conn.remote_id())
            {
                AfterHandshakeOutcome::Accept
            } else {
                AfterHandshakeOutcome::Reject {
                    error_code: CLOSE_DENIED.into(),
                    reason: b"not authenticated".to_vec(),
                }
            }
        }
    }

    /// Protocol handler that validates a presented token and marks the remote as allowed.
    #[derive(Debug, Clone)]
    pub struct AuthProtocol {
        token: Vec<u8>,
        allowed_remotes: Arc<Mutex<HashSet<EndpointId>>>,
    }

    impl ProtocolHandler for AuthProtocol {
        /// Reads the token (at most 256 bytes) from one unidirectional stream.
        ///
        /// Closes the connection with `CLOSE_ACCEPTED` if the token matches, and with
        /// `CLOSE_DENIED` otherwise.
        async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
            let mut stream = connection.accept_uni().await?;
            let token = stream.read_to_end(256).await.anyerr()?;
            let remote_id = connection.remote_id();
            if tokens_match(&token, &self.token) {
                self.allowed_remotes
                    .lock()
                    .expect("poisoned")
                    .insert(remote_id);
                connection.close(CLOSE_ACCEPTED.into(), b"accepted");
            } else {
                connection.close(CLOSE_DENIED.into(), b"rejected");
            }
            Ok(())
        }
    }

    /// Best-effort constant-time equality for secret tokens.
    ///
    /// Every byte is examined even after a mismatch, so the running time does not reveal how
    /// long a correctly guessed prefix is. Only a length mismatch returns early; the token
    /// length is not treated as secret.
    fn tokens_match(presented: &[u8], expected: &[u8]) -> bool {
        presented.len() == expected.len()
            && presented
                .iter()
                .zip(expected)
                .fold(0u8, |diff, (a, b)| diff | (a ^ b))
                == 0
    }
}