#![allow(clippy::disallowed_types)]
use iroh::endpoint::Connection;
use serde::Serialize;
use crate::{
ffi::handles::{HandleStore, SessionEntry},
ffi::pumps::{pump_body_to_quic_send, pump_quic_recv_to_body},
parse_node_addr, CoreError, FfiDuplexStream, IrohEndpoint, ALPN_DUPLEX,
};
fn is_connection_closed(err: &iroh::endpoint::ConnectionError) -> bool {
use iroh::endpoint::ConnectionError::*;
matches!(
err,
ApplicationClosed(_) | ConnectionClosed(_) | Reset | TimedOut | LocallyClosed
)
}
/// Outcome of a finished session, as returned by `Session::closed`.
///
/// Serialized with camelCase keys (`closeCode`, `reason`).
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CloseInfo {
    /// Application error code when the connection ended with an application close;
    /// `0` for every other kind of termination (see `parse_connection_error`).
    pub close_code: u64,
    /// Application close reason (decoded lossily as UTF-8), or the connection error's
    /// display text for non-application terminations.
    pub reason: String,
}
/// Handle-based view of a QUIC connection stored in an endpoint's handle store.
///
/// The struct holds only the endpoint and a numeric handle. Every operation looks the
/// connection up again by that handle, so all clones of a `Session` refer to the same
/// stored connection. Once the handle is removed (by `close` or `closed`), operations
/// on any clone fail with `invalid_handle`.
#[derive(Clone)]
pub struct Session {
    // Endpoint owning the handle store in which the connection is registered.
    endpoint: IrohEndpoint,
    // Key of the `SessionEntry` in `endpoint.handles()`.
    handle: u64,
}
impl Session {
    /// Builds a `Session` over an already-registered session handle.
    ///
    /// No lookup happens here. An unknown or removed handle is reported as
    /// `invalid_handle` only when a method that needs the connection is called.
    pub fn from_handle(endpoint: IrohEndpoint, handle: u64) -> Self {
        Self { endpoint, handle }
    }

    /// Returns the numeric handle of this session in the endpoint's handle store.
    pub fn handle(&self) -> u64 {
        self.handle
    }

    /// Dials `remote_node_id` using the duplex ALPN and registers the connection.
    ///
    /// Direct addresses embedded in `remote_node_id` (as parsed by `parse_node_addr`)
    /// are combined with the optional `direct_addrs` hints.
    ///
    /// # Errors
    /// - any error from `parse_node_addr` for a malformed node id;
    /// - `connection_failed` if the QUIC connect fails;
    /// - any error from registering the session in the handle store.
    pub async fn connect(
        endpoint: IrohEndpoint,
        remote_node_id: &str,
        direct_addrs: Option<&[std::net::SocketAddr]>,
    ) -> Result<Self, CoreError> {
        let parsed = parse_node_addr(remote_node_id)?;
        let mut addr = iroh::EndpointAddr::new(parsed.node_id);
        for a in &parsed.direct_addrs {
            addr = addr.with_ip_addr(*a);
        }
        // `Option<&[_]>` flattens to nothing when no extra hints were supplied.
        for a in direct_addrs.into_iter().flatten() {
            addr = addr.with_ip_addr(*a);
        }
        let conn = endpoint
            .raw()
            .connect(addr, ALPN_DUPLEX)
            .await
            .map_err(|e| CoreError::connection_failed(format!("connect session: {e}")))?;
        let handle = endpoint.handles().insert_session(SessionEntry { conn })?;
        Ok(Self { endpoint, handle })
    }

    /// Waits for the next incoming connection, completes its handshake and registers it.
    ///
    /// Returns `Ok(None)` when the endpoint's accept yields no more incoming connections.
    /// Presumably this happens because the endpoint is shutting down; confirm against
    /// `IrohEndpoint::raw`.
    ///
    /// # Errors
    /// `connection_failed` if the handshake fails, or any error from registering the
    /// session in the handle store.
    pub async fn accept(endpoint: IrohEndpoint) -> Result<Option<Self>, CoreError> {
        let incoming = match endpoint.raw().accept().await {
            Some(inc) => inc,
            None => return Ok(None),
        };
        let conn = incoming
            .await
            .map_err(|e| CoreError::connection_failed(format!("accept session: {e}")))?;
        let handle = endpoint.handles().insert_session(SessionEntry { conn })?;
        Ok(Some(Self { endpoint, handle }))
    }

    /// Looks up this session's connection, returning an owned clone of it.
    fn conn(&self) -> Result<Connection, CoreError> {
        self.endpoint
            .handles()
            .lookup_session(self.handle)
            .map(|s| s.conn.clone())
            .ok_or_else(|| CoreError::invalid_handle(self.handle))
    }

    /// Returns the public key (node id) of the remote peer.
    pub fn remote_id(&self) -> Result<iroh::PublicKey, CoreError> {
        self.conn().map(|c| c.remote_id())
    }

    /// Opens a new bidirectional stream and exposes it as a read/write handle pair.
    ///
    /// # Errors
    /// `invalid_handle` if the session is gone, `connection_failed` if the stream
    /// cannot be opened, or any error from registering the body-channel handles.
    pub async fn create_bidi_stream(&self) -> Result<FfiDuplexStream, CoreError> {
        let conn = self.conn()?;
        let (send, recv) = conn
            .open_bi()
            .await
            .map_err(|e| CoreError::connection_failed(format!("open_bi: {e}")))?;
        wrap_bidi_stream(self.endpoint.handles(), send, recv)
    }

    /// Accepts the next bidirectional stream opened by the peer.
    ///
    /// Returns `Ok(None)` once the connection has ended (see `is_connection_closed`).
    pub async fn next_bidi_stream(&self) -> Result<Option<FfiDuplexStream>, CoreError> {
        let conn = self.conn()?;
        match conn.accept_bi().await {
            Ok((send, recv)) => Ok(Some(wrap_bidi_stream(self.endpoint.handles(), send, recv)?)),
            Err(e) if is_connection_closed(&e) => Ok(None),
            Err(e) => Err(CoreError::connection_failed(format!("accept_bi: {e}"))),
        }
    }

    /// Closes the connection with an application `close_code` and `reason`, and
    /// removes the session from the handle store.
    ///
    /// # Errors
    /// - `invalid_input` if `close_code` does not fit in a QUIC VarInt (> 2^62 - 1).
    ///   The session is left registered and open in that case, so the caller can retry.
    /// - `invalid_handle` if the session was already removed.
    pub fn close(&self, close_code: u64, reason: &str) -> Result<(), CoreError> {
        // Validate before touching the handle store. Removing first would let a bad code
        // consume the handle, and the requested close would never be sent.
        let code = iroh::endpoint::VarInt::from_u64(close_code).map_err(|_| {
            CoreError::invalid_input(format!(
                "close_code {close_code} exceeds QUIC VarInt max (2^62 - 1)"
            ))
        })?;
        let entry = self
            .endpoint
            .handles()
            .remove_session(self.handle)
            .ok_or_else(|| CoreError::invalid_handle(self.handle))?;
        entry.conn.close(code, reason.as_bytes());
        Ok(())
    }

    /// Succeeds if the session handle is still registered.
    ///
    /// Sessions created by `connect`/`accept` are registered only after the connection
    /// is established, so there is nothing further to await here.
    pub async fn ready(&self) -> Result<(), CoreError> {
        let _conn = self.conn()?;
        Ok(())
    }

    /// Waits until the connection terminates, then removes the session handle.
    ///
    /// Returns the peer's code and reason for an application close. For any other
    /// termination it returns code `0` and the error's display text.
    pub async fn closed(&self) -> Result<CloseInfo, CoreError> {
        let conn = self.conn()?;
        let err = conn.closed().await;
        // Another caller (e.g. `close`) may already have removed it; that is fine.
        self.endpoint.handles().remove_session(self.handle);
        let (close_code, reason) = parse_connection_error(&err);
        Ok(CloseInfo { close_code, reason })
    }

    /// Opens a unidirectional send stream and returns the write handle feeding it.
    ///
    /// A background task pumps bytes written to the handle into the QUIC stream.
    pub async fn create_uni_stream(&self) -> Result<u64, CoreError> {
        let conn = self.conn()?;
        let send = conn
            .open_uni()
            .await
            .map_err(|e| CoreError::connection_failed(format!("open_uni: {e}")))?;
        let handles = self.endpoint.handles();
        let (send_writer, send_reader) = handles.make_body_channel();
        let write_handle = handles.insert_writer(send_writer)?;
        tokio::spawn(pump_body_to_quic_send(send_reader, send));
        Ok(write_handle)
    }

    /// Accepts the next unidirectional stream from the peer and returns a read handle.
    ///
    /// Returns `Ok(None)` once the connection has ended (see `is_connection_closed`).
    pub async fn next_uni_stream(&self) -> Result<Option<u64>, CoreError> {
        let conn = self.conn()?;
        match conn.accept_uni().await {
            Ok(recv) => {
                let handles = self.endpoint.handles();
                let (recv_writer, recv_reader) = handles.make_body_channel();
                let read_handle = handles.insert_reader(recv_reader)?;
                tokio::spawn(pump_quic_recv_to_body(recv, recv_writer));
                Ok(Some(read_handle))
            }
            Err(e) if is_connection_closed(&e) => Ok(None),
            Err(e) => Err(CoreError::connection_failed(format!("accept_uni: {e}"))),
        }
    }

    /// Sends an unreliable datagram containing a copy of `data`.
    ///
    /// # Errors
    /// - `body_too_large` if the datagram exceeds the current maximum size;
    /// - `connection_failed` if the connection has been lost;
    /// - `internal` for other send failures (e.g. datagrams unsupported or disabled).
    pub fn send_datagram(&self, data: &[u8]) -> Result<(), CoreError> {
        let conn = self.conn()?;
        conn.send_datagram(bytes::Bytes::copy_from_slice(data))
            .map_err(|e| match e {
                iroh::endpoint::SendDatagramError::TooLarge => CoreError::body_too_large(
                    "datagram exceeds path MTU; check Session::max_datagram_size()",
                ),
                // A lost connection is a connection failure, as in `recv_datagram`,
                // not an internal error.
                iroh::endpoint::SendDatagramError::ConnectionLost(_) => {
                    CoreError::connection_failed(format!("send_datagram: {e}"))
                }
                _ => CoreError::internal(format!("send_datagram: {e}")),
            })
    }

    /// Receives the next datagram from the peer.
    ///
    /// Returns `Ok(None)` once the connection has ended (see `is_connection_closed`).
    pub async fn recv_datagram(&self) -> Result<Option<Vec<u8>>, CoreError> {
        let conn = self.conn()?;
        match conn.read_datagram().await {
            Ok(data) => Ok(Some(data.to_vec())),
            Err(e) if is_connection_closed(&e) => Ok(None),
            Err(e) => Err(CoreError::connection_failed(format!("recv_datagram: {e}"))),
        }
    }

    /// Returns the connection's current maximum datagram payload size, or `None` if
    /// datagrams cannot be sent.
    pub fn max_datagram_size(&self) -> Result<Option<usize>, CoreError> {
        let conn = self.conn()?;
        Ok(conn.max_datagram_size())
    }
}
/// Registers a QUIC bidirectional stream pair as body-channel handles and starts
/// the pump tasks that connect them to the network.
///
/// Both handles are inserted through one insert guard and committed together. The
/// pump tasks are spawned only after the commit, so a failed insertion spawns nothing
/// and simply drops `send`/`recv`. Abandoning the guard without `commit` presumably
/// rolls back the partial insert; confirm against `HandleStore::insert_guard`.
///
/// Returns the read handle (bytes pumped in from `recv`) and the write handle
/// (bytes pumped out to `send`).
///
/// # Errors
/// Propagates any error from inserting the reader or the writer handle.
fn wrap_bidi_stream(
    handles: &HandleStore,
    send: iroh::endpoint::SendStream,
    recv: iroh::endpoint::RecvStream,
) -> Result<FfiDuplexStream, CoreError> {
    let mut guard = handles.insert_guard();
    let (recv_writer, recv_reader) = handles.make_body_channel();
    let read_handle = guard.insert_reader(recv_reader)?;
    let (send_writer, send_reader) = handles.make_body_channel();
    let write_handle = guard.insert_writer(send_writer)?;
    guard.commit();

    // Start moving data only once both handles are committed, so no task is left
    // running on a channel whose handle was never handed out.
    tokio::spawn(pump_quic_recv_to_body(recv, recv_writer));
    tokio::spawn(pump_body_to_quic_send(send_reader, send));

    Ok(FfiDuplexStream {
        read_handle,
        write_handle,
    })
}
/// Splits a connection-termination error into `(close_code, reason)` for `CloseInfo`.
///
/// An application close yields the peer's error code and its reason bytes, decoded
/// lossily as UTF-8. Every other variant yields code `0` and the error's display text.
fn parse_connection_error(err: &iroh::endpoint::ConnectionError) -> (u64, String) {
    use iroh::endpoint::ConnectionError;

    if let ConnectionError::ApplicationClosed(close) = err {
        let reason = String::from_utf8_lossy(&close.reason).into_owned();
        return (u64::from(close.error_code), reason);
    }
    (0, err.to_string())
}