use {
super::{
LocalNode,
PeerId,
error::{
Cancelled,
InvalidAlpn,
ProtocolViolation,
Success,
UnexpectedClose,
},
},
crate::{
Datum,
Digest,
primitives::{Bytes, Short},
},
core::{fmt, marker::PhantomData},
futures::{FutureExt, SinkExt, StreamExt},
iroh::{EndpointAddr, endpoint::*, protocol::AcceptError as IrohAcceptError},
n0_error::Meta,
std::io,
tokio_util::{
codec::{FramedRead, FramedWrite, LengthDelimitedCodec},
sync::CancellationToken,
},
};
/// An application protocol that can be carried over a [`Link`].
///
/// Each protocol is identified on the wire by its ALPN string.
pub trait Protocol {
  /// ALPN identifier used when connecting for this protocol. Incoming
  /// connections whose negotiated ALPN differs are rejected by
  /// [`Link::accept_with_cancel`].
  const ALPN: &'static [u8];
}
/// A framed, bidirectional message channel for protocol `P`, carried on a
/// single bidirectional stream of an iroh [`Connection`].
///
/// Messages are sent and received as length-delimited frames. Most
/// operations are bounded by the link's [`CancellationToken`].
pub struct Link<P: Protocol> {
  /// The underlying connection. Link operations close it with an
  /// application reason on cancellation, an unexpected end of stream, or
  /// decode failures.
  connection: Connection,
  /// Token that aborts in-flight link operations when cancelled.
  cancel: CancellationToken,
  /// Length-delimited writer over the send half of the stream.
  sender: FramedWrite<SendStream, LengthDelimitedCodec>,
  /// Length-delimited reader over the receive half of the stream.
  receiver: FramedRead<RecvStream, LengthDelimitedCodec>,
  /// Ties the link to its protocol type without storing a value of it.
  _protocol: PhantomData<P>,
}
impl<P: Protocol> Link<P> {
pub fn join(sender: LinkSender<P>, receiver: LinkReceiver<P>) -> Self {
assert_eq!(
sender.connection.stable_id(),
receiver.connection.stable_id(),
"sender and receiver must belong to the same connection",
);
Self {
connection: sender.connection,
sender: sender.sender,
receiver: receiver.receiver,
cancel: sender.cancel,
_protocol: PhantomData,
}
}
pub fn split(self) -> (LinkSender<P>, LinkReceiver<P>) {
let Self {
connection,
sender,
receiver,
cancel,
..
} = self;
(
LinkSender {
connection: connection.clone(),
sender,
cancel: cancel.clone(),
_protocol: PhantomData,
},
LinkReceiver {
connection,
receiver,
cancel,
_protocol: PhantomData,
},
)
}
pub async fn accept_with_cancel(
connection: Connection,
cancel: CancellationToken,
) -> Result<Self, AcceptError> {
if P::ALPN != connection.alpn() {
let alpn = connection.alpn().to_vec();
if let Some(reason) = close_connection(&connection, InvalidAlpn).await {
return Err(reason.into());
}
return Err(AcceptError::InvalidAlpn {
expected: P::ALPN,
received: alpn,
});
}
let Some(accept_result) =
cancel.run_until_cancelled(connection.accept_bi()).await
else {
close_connection(&connection, Cancelled).await;
return Err(AcceptError::Cancelled);
};
let (tx, rx) = accept_result?;
let mut codec = LengthDelimitedCodec::new();
codec.set_max_frame_length(usize::MAX);
let sender = FramedWrite::new(tx, codec.clone());
let receiver = FramedRead::new(rx, codec);
Ok(Self {
connection,
sender,
receiver,
cancel,
_protocol: PhantomData,
})
}
#[allow(unused)]
pub async fn accept(connection: Connection) -> Result<Self, AcceptError> {
Self::accept_with_cancel(connection, CancellationToken::new()).await
}
pub async fn open_with_cancel(
local: &LocalNode,
remote: impl Into<EndpointAddr>,
cancel: CancellationToken,
) -> Result<Self, OpenError> {
let fut = local.endpoint().connect(remote.into(), P::ALPN);
let Some(connection) = cancel.run_until_cancelled(fut).await else {
return Err(OpenError::Cancelled);
};
let connection = connection?;
let Some(open_result) =
cancel.run_until_cancelled(connection.open_bi()).await
else {
close_connection(&connection, Cancelled).await;
return Err(OpenError::Cancelled);
};
let (tx, rx) = match open_result {
Ok(streams) => streams,
Err(err) => {
close_connection(&connection, Cancelled).await;
return Err(err.into());
}
};
let mut codec = LengthDelimitedCodec::new();
codec.set_max_frame_length(usize::MAX);
let sender = FramedWrite::new(tx, codec.clone());
let receiver = FramedRead::new(rx, codec);
Ok(Self {
connection,
sender,
receiver,
cancel,
_protocol: PhantomData,
})
}
#[allow(unused)]
pub async fn open(
local: &LocalNode,
remote: impl Into<EndpointAddr>,
) -> Result<Self, OpenError> {
Self::open_with_cancel(local, remote, CancellationToken::new()).await
}
#[expect(clippy::unused_self)]
pub const fn alpn(&self) -> &[u8] {
P::ALPN
}
pub fn remote_id(&self) -> PeerId {
self.connection.remote_id()
}
pub const fn connection(&self) -> &Connection {
&self.connection
}
pub fn is_cancelled(&self) -> bool {
self.cancel.is_cancelled()
}
pub async fn recv<D: Datum>(&mut self) -> Result<D, RecvError> {
self.recv_with_size().await.map(|(d, _)| d)
}
pub async fn recv_with_size<D: Datum>(
&mut self,
) -> Result<(D, usize), RecvError> {
let Some(frame) =
self.cancel.run_until_cancelled(self.receiver.next()).await
else {
close_connection(&self.connection, Cancelled).await;
return Err(RecvError::Cancelled);
};
let Some(read_result) = frame else {
let Some(reason) =
close_connection(&self.connection, UnexpectedClose).await
else {
return Err(RecvError::closed(UnexpectedClose));
};
return Err(reason.into());
};
let bytes = match read_result {
Ok(bytes) => bytes,
Err(err) => match err.downcast::<ReadError>() {
Ok(read_err) => return Err(RecvError::Io(read_err)),
Err(other_err) => return Err(RecvError::Unknown(other_err)),
},
};
let decoded = match D::decode(&bytes) {
Ok(datum) => datum,
Err(err) => {
close_connection(&self.connection, ProtocolViolation).await;
return Err(RecvError::Decode(Box::new(err)));
}
};
Ok((decoded, bytes.len()))
}
pub async fn send<D: Datum>(
&mut self,
datum: &D,
) -> Result<usize, SendError> {
unsafe {
self
.send_raw(datum.encode().map_err(|e| SendError::Encode(Box::new(e)))?)
.await
}
}
pub async unsafe fn send_raw(
&mut self,
bytes: Bytes,
) -> Result<usize, SendError> {
let msg_len = bytes.len();
let fut = self.sender.send(bytes);
let Some(send_result) = self.cancel.run_until_cancelled(fut).await else {
close_connection(&self.connection, Cancelled).await;
return Err(SendError::Cancelled);
};
match send_result {
Ok(()) => Ok(msg_len),
Err(err) => match err.downcast::<WriteError>() {
Ok(io_err) => Err(SendError::Io(io_err)),
Err(other_err) => Err(SendError::Unknown(other_err)),
},
}
}
pub async fn close(
mut self,
reason: impl Into<ApplicationClose>,
) -> Result<(), CloseError> {
if let Some(reason) = self.connection().close_reason() {
return Err(CloseError::AlreadyClosed(reason));
}
let reason: ApplicationClose = reason.into();
self.connection().close(reason.error_code, &reason.reason);
let _ = self.cancel.run_until_cancelled(self.sender.flush()).await;
let _ = self.cancel.run_until_cancelled(self.sender.close()).await;
let close_result = self
.cancel
.run_until_cancelled(self.connection().closed())
.await;
match close_result {
None => Err(CloseError::Cancelled),
Some(ConnectionError::LocallyClosed) => Ok(()),
Some(reason) => Err(CloseError::UnexpectedReason(reason)),
}
}
pub fn closed(
&self,
) -> impl Future<Output = Result<(), ConnectionError>> + Send + Sync + 'static
{
let cancel = self.cancel.clone();
let connection = self.connection.clone();
async move {
match cancel.run_until_cancelled(connection.closed()).await {
None | Some(ConnectionError::LocallyClosed) => Ok(()),
Some(ConnectionError::ApplicationClosed(reason))
if reason == Success =>
{
Ok(())
}
Some(err) => Err(err),
}
}
.fuse()
}
pub fn replace_cancel_token(&mut self, cancel: CancellationToken) {
self.cancel = cancel;
}
pub fn shared_random(&self, label: impl AsRef<[u8]>) -> Digest {
let mut shared_secret = [0u8; 32];
self
.connection()
.export_keying_material(&mut shared_secret, label.as_ref(), self.alpn())
.expect("exporting keying material should not fail for this buffer len");
Digest::from_bytes(shared_secret)
}
}
async fn close_connection(
connection: &Connection,
reason: impl Into<ApplicationClose>,
) -> Option<ConnectionError> {
let reason = reason.into();
connection.close(reason.error_code, &reason.reason);
match connection.closed().await {
ConnectionError::LocallyClosed => None,
err => Some(err),
}
}
/// Umbrella error for any link operation.
///
/// The `From` conversions in this module map a cancellation from any wrapped
/// operation to [`LinkError::Cancelled`]. An elapsed `tokio` timeout is also
/// mapped to `Cancelled`.
#[derive(Debug, thiserror::Error)]
pub enum LinkError {
  /// Opening an outgoing link failed.
  #[error("{0}")]
  Open(OpenError),
  /// Accepting an incoming link failed.
  #[error("{0}")]
  Accept(AcceptError),
  /// Receiving a message failed.
  #[error("{0}")]
  Recv(RecvError),
  /// Sending a message failed.
  #[error("{0}")]
  Write(SendError),
  /// Closing the link failed.
  #[error("{0}")]
  Close(CloseError),
  /// The operation was cancelled or timed out.
  #[error("Operation cancelled")]
  Cancelled,
}
/// Error returned by [`Link::open`] and [`Link::open_with_cancel`].
#[derive(Debug, thiserror::Error)]
pub enum OpenError {
  /// Connecting failed, or the connection failed while the stream was being
  /// opened.
  #[error("{0}")]
  Io(#[from] ConnectError),
  /// The cancellation token fired before the link was established.
  #[error("Operation cancelled")]
  Cancelled,
}
/// Error returned by [`Link::accept`] and [`Link::accept_with_cancel`].
#[derive(Debug, thiserror::Error)]
pub enum AcceptError {
  /// An iroh-level accept or connection error. Other errors converted into
  /// this type are also wrapped here.
  #[error("{0}")]
  Io(#[from] IrohAcceptError),
  /// The incoming connection negotiated an ALPN other than the protocol's.
  #[error("Invalid ALPN: expected {expected:?}, received {received:?}")]
  InvalidAlpn {
    /// The ALPN the protocol expects.
    expected: &'static [u8],
    /// The ALPN the connection actually negotiated.
    received: Vec<u8>,
  },
  /// The cancellation token fired before a stream was accepted.
  #[error("Operation cancelled")]
  Cancelled,
}
/// Error returned when receiving a message on a link.
#[derive(Debug, thiserror::Error)]
pub enum RecvError {
  /// Reading from the stream failed. Connection closes are also reported
  /// here, as `ReadError::ConnectionLost`.
  #[error("{0}")]
  Io(#[from] ReadError),
  /// The frame payload could not be decoded into the requested type.
  #[error("{0}")]
  Decode(Box<dyn std::error::Error + Send + Sync + 'static>),
  /// An I/O error from the framing layer that is not a [`ReadError`].
  #[error("{0}")]
  Unknown(#[from] io::Error),
  /// The cancellation token fired before a message arrived.
  #[error("Operation cancelled")]
  Cancelled,
}
/// Error returned when sending a message on a link.
#[derive(Debug, thiserror::Error)]
pub enum SendError {
  /// The message could not be encoded. Nothing was sent.
  #[error("{0}")]
  Encode(Box<dyn std::error::Error + Send + Sync + 'static>),
  /// Writing to the stream failed.
  #[error("{0}")]
  Io(#[from] WriteError),
  /// An I/O error from the framing layer that is not a [`WriteError`].
  #[error("{0}")]
  Unknown(#[from] io::Error),
  /// The cancellation token fired before the message was sent.
  #[error("Operation cancelled")]
  Cancelled,
}
/// Error returned by [`Link::close`].
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
pub enum CloseError {
  /// The connection had already closed, for the contained reason, before
  /// `close` was called.
  #[error("{0}")]
  AlreadyClosed(ConnectionError),
  /// After the local close, the connection reported a reason other than
  /// `LocallyClosed`.
  #[error("{0}")]
  UnexpectedReason(ConnectionError),
  /// The cancellation token fired while waiting for the close to complete.
  #[error("Operation cancelled")]
  Cancelled,
}
impl AcceptError {
  /// Returns the application close frame carried by this error, if any.
  ///
  /// A frame is returned when the connection was closed by the application,
  /// either while still connecting or once established.
  pub const fn close_reason(&self) -> Option<&ApplicationClose> {
    let Self::Io(inner) = self else {
      return None;
    };
    match inner {
      IrohAcceptError::Connection {
        source: ConnectionError::ApplicationClosed(reason),
        ..
      }
      | IrohAcceptError::Connecting {
        source:
          ConnectingError::ConnectionError {
            source: ConnectionError::ApplicationClosed(reason),
            ..
          },
        ..
      } => Some(reason),
      _ => None,
    }
  }

  /// Returns `true` if this error is [`AcceptError::Cancelled`].
  pub const fn is_cancelled(&self) -> bool {
    matches!(self, Self::Cancelled)
  }
}
impl From<CloseError> for AcceptError {
  /// Maps a cancelled close to [`AcceptError::Cancelled`]. Any other close
  /// error is wrapped as the source of an iroh accept error.
  fn from(value: CloseError) -> Self {
    if value.is_cancelled() {
      Self::Cancelled
    } else {
      Self::Io(IrohAcceptError::from_err(value))
    }
  }
}
impl From<ApplicationClose> for AcceptError {
fn from(val: ApplicationClose) -> Self {
Self::Io(ConnectionError::ApplicationClosed(val).into())
}
}
impl From<ApplicationClose> for RecvError {
  /// Reports an application close as a lost connection on the read side.
  fn from(close: ApplicationClose) -> Self {
    let lost = ConnectionError::ApplicationClosed(close);
    Self::Io(ReadError::ConnectionLost(lost))
  }
}
impl From<ApplicationClose> for OpenError {
fn from(val: ApplicationClose) -> Self {
Self::Io(ConnectionError::ApplicationClosed(val).into())
}
}
impl RecvError {
  /// Builds the error reported when the connection was closed with `reason`.
  pub fn closed(reason: impl Into<ApplicationClose>) -> Self {
    let close: ApplicationClose = reason.into();
    close.into()
  }

  /// Returns the application close frame if the read failed because the
  /// connection was closed by the application.
  pub const fn close_reason(&self) -> Option<&ApplicationClose> {
    if let Self::Io(ReadError::ConnectionLost(
      ConnectionError::ApplicationClosed(reason),
    )) = self
    {
      Some(reason)
    } else {
      None
    }
  }

  /// Returns `true` if this error is [`RecvError::Cancelled`].
  pub const fn is_cancelled(&self) -> bool {
    matches!(self, Self::Cancelled)
  }
}
impl From<ConnectionError> for RecvError {
  /// Reports a connection error as a lost connection on the read side.
  fn from(lost: ConnectionError) -> Self {
    Self::from(ReadError::ConnectionLost(lost))
  }
}
impl SendError {
  /// Returns the application close frame if the write failed because the
  /// connection was closed by the application.
  pub const fn close_reason(&self) -> Option<&ApplicationClose> {
    if let Self::Io(WriteError::ConnectionLost(
      ConnectionError::ApplicationClosed(reason),
    )) = self
    {
      Some(reason)
    } else {
      None
    }
  }

  /// Returns `true` if this error is [`SendError::Cancelled`].
  pub const fn is_cancelled(&self) -> bool {
    matches!(self, Self::Cancelled)
  }
}
impl From<AcceptError> for IrohAcceptError {
  /// Unwraps an underlying iroh accept error. Any other variant is wrapped
  /// as the source of a new iroh accept error.
  fn from(value: AcceptError) -> Self {
    match value {
      AcceptError::Io(inner) => inner,
      other @ (AcceptError::InvalidAlpn { .. } | AcceptError::Cancelled) => {
        Self::from_err(other)
      }
    }
  }
}
impl CloseError {
  /// Returns the application close frame if the connection closed with an
  /// unexpected application reason.
  ///
  /// Only [`CloseError::UnexpectedReason`] is inspected.
  /// [`CloseError::AlreadyClosed`] always yields `None`.
  pub const fn close_reason(&self) -> Option<&ApplicationClose> {
    if let Self::UnexpectedReason(ConnectionError::ApplicationClosed(reason)) =
      self
    {
      Some(reason)
    } else {
      None
    }
  }

  /// Returns `true` if this error is [`CloseError::Cancelled`].
  pub const fn is_cancelled(&self) -> bool {
    matches!(self, Self::Cancelled)
  }

  /// Returns `true` if the connection was already closed before the close
  /// was attempted.
  pub const fn was_already_closed(&self) -> bool {
    matches!(self, Self::AlreadyClosed(_))
  }
}
impl LinkError {
  /// Returns `true` if this error represents a cancellation, either directly
  /// or through the wrapped operation error.
  pub const fn is_cancelled(&self) -> bool {
    match self {
      Self::Cancelled => true,
      Self::Open(inner) => inner.is_cancelled(),
      Self::Accept(inner) => inner.is_cancelled(),
      Self::Recv(inner) => inner.is_cancelled(),
      Self::Write(inner) => inner.is_cancelled(),
      Self::Close(inner) => inner.is_cancelled(),
    }
  }

  /// Returns the application close frame carried by the wrapped error, if
  /// any.
  pub const fn close_reason(&self) -> Option<&ApplicationClose> {
    match self {
      Self::Cancelled => None,
      Self::Open(inner) => inner.close_reason(),
      Self::Accept(inner) => inner.close_reason(),
      Self::Recv(inner) => inner.close_reason(),
      Self::Write(inner) => inner.close_reason(),
      Self::Close(inner) => inner.close_reason(),
    }
  }
}
impl From<OpenError> for LinkError {
fn from(err: OpenError) -> Self {
match err {
OpenError::Cancelled => Self::Cancelled,
error @ OpenError::Io(_) => Self::Open(error),
}
}
}
impl From<AcceptError> for LinkError {
fn from(err: AcceptError) -> Self {
match err {
AcceptError::Cancelled => Self::Cancelled,
error => Self::Accept(error),
}
}
}
impl From<RecvError> for LinkError {
fn from(err: RecvError) -> Self {
match err {
RecvError::Cancelled => Self::Cancelled,
error => Self::Recv(error),
}
}
}
impl From<SendError> for LinkError {
fn from(err: SendError) -> Self {
match err {
SendError::Cancelled => Self::Cancelled,
error => Self::Write(error),
}
}
}
impl From<CloseError> for LinkError {
fn from(err: CloseError) -> Self {
match err {
CloseError::Cancelled => Self::Cancelled,
error => Self::Close(error),
}
}
}
impl From<ConnectionError> for OpenError {
  /// Wraps a connection error as [`ConnectError::Connection`] with default
  /// metadata.
  fn from(source: ConnectionError) -> Self {
    let meta = Meta::default();
    Self::Io(ConnectError::Connection { source, meta })
  }
}
impl From<tokio::time::error::Elapsed> for LinkError {
fn from(_: tokio::time::error::Elapsed) -> Self {
Self::Cancelled
}
}
impl OpenError {
  /// Returns the application close frame carried by this error, if any.
  ///
  /// A frame is returned when the connection was closed by the application,
  /// either while still connecting or once established.
  pub const fn close_reason(&self) -> Option<&ApplicationClose> {
    let Self::Io(inner) = self else {
      return None;
    };
    match inner {
      ConnectError::Connection {
        source: ConnectionError::ApplicationClosed(reason),
        ..
      }
      | ConnectError::Connecting {
        source:
          ConnectingError::ConnectionError {
            source: ConnectionError::ApplicationClosed(reason),
            ..
          },
        ..
      } => Some(reason),
      _ => None,
    }
  }

  /// Returns `true` if this error is [`OpenError::Cancelled`].
  pub const fn is_cancelled(&self) -> bool {
    matches!(self, Self::Cancelled)
  }
}
impl From<ConnectionError> for AcceptError {
  /// Wraps a connection error as an iroh accept error.
  fn from(conn_err: ConnectionError) -> Self {
    let iroh_err = IrohAcceptError::from(conn_err);
    Self::from(iroh_err)
  }
}
impl<P: Protocol> fmt::Debug for Link<P> {
  /// Shows the ALPN (lossily decoded as UTF-8) and the remote peer id.
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    let alpn = String::from_utf8_lossy(self.alpn());
    let remote = self.remote_id();
    let mut out = f.debug_struct("Link");
    out.field("alpn", &alpn);
    out.field("remote_id", &remote);
    out.finish_non_exhaustive()
  }
}
impl<P: Protocol> fmt::Display for Link<P> {
  /// Renders as `Link<alpn:short-remote-id>`.
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    let alpn = String::from_utf8_lossy(self.alpn());
    let remote = Short(self.remote_id());
    write!(f, "Link<{alpn}:{remote}>")
  }
}
/// The send half of a [`Link`], produced by [`Link::split`].
pub struct LinkSender<P: Protocol> {
  /// Handle to the link's connection. It is closed with an application
  /// reason when a send is cancelled.
  connection: Connection,
  /// Token that aborts in-flight sends when cancelled.
  cancel: CancellationToken,
  /// Length-delimited writer over the send half of the stream.
  sender: FramedWrite<SendStream, LengthDelimitedCodec>,
  /// Ties the half to its protocol type without storing a value of it.
  _protocol: PhantomData<P>,
}
impl<P: Protocol> LinkSender<P> {
  /// Encodes `datum` and sends it as a single length-delimited frame.
  ///
  /// Returns the number of payload bytes sent, excluding the length prefix.
  ///
  /// # Errors
  ///
  /// - [`SendError::Encode`] if `datum` fails to encode. Nothing is sent in
  ///   that case.
  /// - Any error returned by [`LinkSender::send_raw`].
  pub async fn send<D: Datum>(
    &mut self,
    datum: &D,
  ) -> Result<usize, SendError> {
    let bytes = datum.encode().map_err(|e| SendError::Encode(Box::new(e)))?;
    // SAFETY: `bytes` was produced by `Datum::encode`, so it is a
    // well-formed encoded message. That is presumably the contract
    // `send_raw` is `unsafe` to guard (see its `# Safety` section).
    unsafe { self.send_raw(bytes) }.await
  }

  /// Sends `bytes` verbatim as a single length-delimited frame.
  ///
  /// Returns the number of payload bytes sent, excluding the length prefix.
  ///
  /// # Safety
  ///
  /// No memory-safety invariant is visible in this function. NOTE(review):
  /// presumably the caller must guarantee that `bytes` is a valid encoding
  /// of a message for protocol `P` that the peer can decode. Confirm and
  /// state the exact contract here.
  ///
  /// # Errors
  ///
  /// - [`SendError::Cancelled`] if the cancellation token fires first. The
  ///   connection is closed with the `Cancelled` reason.
  /// - [`SendError::Io`] for stream write errors.
  /// - [`SendError::Unknown`] for any other I/O error from the framing layer.
  pub async unsafe fn send_raw(
    &mut self,
    bytes: Bytes,
  ) -> Result<usize, SendError> {
    let msg_len = bytes.len();
    let fut = self.sender.send(bytes);
    let Some(send_result) = self.cancel.run_until_cancelled(fut).await else {
      close_connection(&self.connection, Cancelled).await;
      return Err(SendError::Cancelled);
    };
    match send_result {
      Ok(()) => Ok(msg_len),
      // Recover the typed `WriteError` from the framing layer's `io::Error`.
      Err(err) => match err.downcast::<WriteError>() {
        Ok(io_err) => Err(SendError::Io(io_err)),
        Err(other_err) => Err(SendError::Unknown(other_err)),
      },
    }
  }
}
/// The receive half of a [`Link`], produced by [`Link::split`].
pub struct LinkReceiver<P: Protocol> {
  /// Handle to the link's connection. It is closed with an application
  /// reason on cancellation, an unexpected end of stream, or decode
  /// failures.
  connection: Connection,
  /// Token that aborts in-flight receives when cancelled.
  cancel: CancellationToken,
  /// Length-delimited reader over the receive half of the stream.
  receiver: FramedRead<RecvStream, LengthDelimitedCodec>,
  /// Ties the half to its protocol type without storing a value of it.
  _protocol: PhantomData<P>,
}
impl<P: Protocol> LinkReceiver<P> {
  /// Receives the next frame and decodes it as `D`, discarding its size.
  ///
  /// # Errors
  ///
  /// See [`LinkReceiver::recv_with_size`].
  pub async fn recv<D: Datum>(&mut self) -> Result<D, RecvError> {
    let (datum, _len) = self.recv_with_size::<D>().await?;
    Ok(datum)
  }

  /// Receives the next frame and decodes it as `D`. Also returns the length
  /// in bytes of the frame payload.
  ///
  /// # Errors
  ///
  /// - [`RecvError::Cancelled`] if the cancellation token fires first. The
  ///   connection is closed with the `Cancelled` reason.
  /// - If the framed reader yields no further frames, the connection is
  ///   closed with the `UnexpectedClose` reason and the error reports that
  ///   close. If the connection had already ended for another reason, that
  ///   reason is reported instead.
  /// - [`RecvError::Io`] for stream read errors.
  /// - [`RecvError::Unknown`] for any other I/O error from the framing layer.
  /// - [`RecvError::Decode`] if the payload does not decode as `D`. The
  ///   connection is closed with the `ProtocolViolation` reason.
  pub async fn recv_with_size<D: Datum>(
    &mut self,
  ) -> Result<(D, usize), RecvError> {
    let next_frame = self.receiver.next();
    let polled = self.cancel.run_until_cancelled(next_frame).await;
    let frame = match polled {
      Some(frame) => frame,
      None => {
        close_connection(&self.connection, Cancelled).await;
        return Err(RecvError::Cancelled);
      }
    };

    let payload = match frame {
      Some(Ok(payload)) => payload,
      // The framing layer surfaces stream errors as `io::Error`. Recover
      // the typed `ReadError` when that is what it wraps.
      Some(Err(io_err)) => {
        return Err(match io_err.downcast::<ReadError>() {
          Ok(read_err) => RecvError::Io(read_err),
          Err(other) => RecvError::Unknown(other),
        });
      }
      None => {
        let closed_by =
          close_connection(&self.connection, UnexpectedClose).await;
        return Err(match closed_by {
          Some(reason) => RecvError::from(reason),
          None => RecvError::closed(UnexpectedClose),
        });
      }
    };

    match D::decode(&payload) {
      Ok(datum) => Ok((datum, payload.len())),
      Err(err) => {
        close_connection(&self.connection, ProtocolViolation).await;
        Err(RecvError::Decode(Box::new(err)))
      }
    }
  }
}