use crate::error::{NexarError, Result};
use crate::protocol::NexarMessage;
use crate::protocol::codec::encode_message;
use crate::types::{Priority, Rank};
use futures::future::BoxFuture;
/// Optional out-of-band channel for moving raw payloads to a peer.
///
/// When registered on a `PeerConnection` as an `Arc<dyn BulkTransport>` extension,
/// `PeerConnection::send_raw_best_effort` tries `send_bulk` first and falls back
/// to a QUIC raw frame if it returns an error.
pub trait BulkTransport: Send + Sync + 'static {
    /// Sends `data` to the peer over this transport.
    fn send_bulk<'a>(&'a self, data: &'a [u8]) -> BoxFuture<'a, Result<()>>;

    /// Receives a bulk payload from the peer.
    ///
    /// `_expected_size` is presumably the number of bytes the caller expects
    /// (TODO confirm against implementors). The default implementation ignores it
    /// and resolves to a transport error, so send-only transports need not
    /// override this method.
    fn recv_bulk<'a>(&'a self, _expected_size: usize) -> BoxFuture<'a, Result<Vec<u8>>> {
        Box::pin(async {
            let reason = "recv_bulk not supported by this transport";
            Err(NexarError::transport(reason))
        })
    }
}
// Stream tags: the first byte `PeerConnection::send_framed` writes on each stream,
// identifying which kind of frame follows. All integers after the tag are u64
// little-endian.

/// Encoded `NexarMessage`: `tag | len | payload`.
pub(crate) const STREAM_TAG_FRAMED: u8 = 1;
/// Untagged raw bytes: `tag | len | payload`.
pub(crate) const STREAM_TAG_RAW: u8 = 2;
/// Raw bytes carrying a communicator id: `tag | comm_id | len | payload`.
pub(crate) const STREAM_TAG_RAW_COMM: u8 = 3;
/// Raw bytes carrying a caller-chosen u64 tag: `tag | user_tag | len | payload`.
pub(crate) const STREAM_TAG_RAW_TAGGED: u8 = 4;
/// A connection to one remote peer, identified by its rank.
///
/// Outgoing frames are written on streams checked out from `stream_pool`.
/// Optional per-connection extensions (for example an `Arc<dyn BulkTransport>`)
/// live in `extensions` and are looked up by concrete type.
///
/// Fields are dropped in declaration order.
pub struct PeerConnection {
    /// Rank of the remote peer (logged as `peer` by the send paths).
    pub rank: Rank,
    /// Underlying QUIC connection; `PeerConnection::new` gives the pool its own clone.
    pub(crate) conn: quinn::Connection,
    /// Pool from which `send_framed` checks out a stream per frame.
    stream_pool: super::stream_pool::StreamPool,
    /// Type-erased extensions; `extension::<T>()` returns the first entry of type `T`.
    extensions: std::sync::RwLock<Vec<Box<dyn std::any::Any + Send + Sync>>>,
}
/// Passed to `StreamPool::new` as its ready-stream limit (presumably the maximum
/// number of pre-opened streams kept per peer — confirm in `stream_pool`).
const STREAM_POOL_MAX_READY: usize = 8;
impl PeerConnection {
pub fn new(rank: Rank, conn: quinn::Connection) -> Self {
let stream_pool = super::stream_pool::StreamPool::new(conn.clone(), STREAM_POOL_MAX_READY);
Self {
rank,
conn,
stream_pool,
extensions: std::sync::RwLock::new(Vec::new()),
}
}
pub async fn warm_stream_pool(&self) {
let _ = self.stream_pool.refill().await;
}
pub fn add_extension<T: std::any::Any + Send + Sync + 'static>(
&self,
ext: T,
) -> crate::error::Result<()> {
let mut exts = self
.extensions
.write()
.map_err(|_| NexarError::LockPoisoned("extensions"))?;
exts.push(Box::new(ext));
Ok(())
}
pub fn extension<T: std::any::Any + Send + Sync + 'static>(
&self,
) -> crate::error::Result<Option<impl std::ops::Deref<Target = T> + '_>> {
let exts = self
.extensions
.read()
.map_err(|_| NexarError::LockPoisoned("extensions"))?;
let idx = match exts.iter().position(|e| e.downcast_ref::<T>().is_some()) {
Some(idx) => idx,
None => return Ok(None),
};
Ok(Some(ExtensionRef {
guard: exts,
idx,
_marker: std::marker::PhantomData,
}))
}
pub async fn send_message(&self, msg: &NexarMessage, priority: Priority) -> Result<()> {
let buf = encode_message(msg, priority)?;
self.send_framed(STREAM_TAG_FRAMED, &[], &buf).await
}
pub async fn send_raw(&self, data: &[u8]) -> Result<()> {
self.send_framed(STREAM_TAG_RAW, &[], data).await
}
pub async fn send_raw_comm(&self, comm_id: u64, data: &[u8]) -> Result<()> {
self.send_framed(STREAM_TAG_RAW_COMM, &comm_id.to_le_bytes(), data)
.await
}
pub async fn send_raw_best_effort(&self, data: &[u8]) -> Result<()> {
let bulk: Option<std::sync::Arc<dyn BulkTransport>> = self
.extension::<std::sync::Arc<dyn BulkTransport>>()?
.map(|b| std::sync::Arc::clone(&*b));
if let Some(bulk) = bulk {
match bulk.send_bulk(data).await {
Ok(()) => return Ok(()),
Err(e) => {
tracing::warn!(
peer = self.rank,
bytes = data.len(),
error = %e,
"bulk transport send failed, falling back to QUIC"
);
}
}
}
self.send_raw(data).await
}
pub async fn send_raw_tagged(&self, tag: u64, data: &[u8]) -> Result<()> {
self.send_framed(STREAM_TAG_RAW_TAGGED, &tag.to_le_bytes(), data)
.await
}
pub async fn send_raw_tagged_best_effort(&self, tag: u64, data: &[u8]) -> Result<()> {
let tagged_bulk: Option<std::sync::Arc<dyn super::TaggedBulkTransport>> = self
.extension::<std::sync::Arc<dyn super::TaggedBulkTransport>>()?
.map(|b| std::sync::Arc::clone(&*b));
if let Some(bulk) = tagged_bulk {
match bulk.send_bulk_tagged(tag, data).await {
Ok(()) => return Ok(()),
Err(e) => {
tracing::warn!(
peer = self.rank,
tag,
bytes = data.len(),
error = %e,
"tagged bulk transport send failed, falling back to QUIC"
);
}
}
}
self.send_raw_tagged(tag, data).await
}
pub fn remote_addr(&self) -> std::net::SocketAddr {
self.conn.remote_address()
}
async fn send_framed(&self, stream_tag: u8, prefix: &[u8], data: &[u8]) -> Result<()> {
let mut stream = self.stream_pool.checkout().await?;
stream
.write_all(&[stream_tag])
.await
.map_err(|e| NexarError::transport_with_source("write stream tag", e))?;
if !prefix.is_empty() {
stream
.write_all(prefix)
.await
.map_err(|e| NexarError::transport_with_source("write prefix", e))?;
}
stream
.write_all(&(data.len() as u64).to_le_bytes())
.await
.map_err(|e| NexarError::transport_with_source("write length", e))?;
stream
.write_all(data)
.await
.map_err(|e| NexarError::transport_with_source("write payload", e))?;
stream
.finish()
.map_err(|e| NexarError::transport_with_source("finish stream", e))?;
Ok(())
}
}
/// Borrowed view of one extension entry, produced by `PeerConnection::extension`.
///
/// It owns the registry's read guard for its whole lifetime. The entry at `idx`
/// therefore cannot be removed or replaced while the view exists.
struct ExtensionRef<'a, T> {
    /// Read guard over the type-erased extension list.
    guard: std::sync::RwLockReadGuard<'a, Vec<Box<dyn std::any::Any + Send + Sync>>>,
    /// Position of an entry whose concrete type is `T`.
    idx: usize,
    /// Records the target type `T` without storing a value of it.
    _marker: std::marker::PhantomData<T>,
}
impl<T: std::any::Any> std::ops::Deref for ExtensionRef<'_, T> {
    type Target = T;

    /// Downcasts the guarded entry to `T`.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is out of range or the entry is not a `T`. Neither can
    /// happen for values built by `PeerConnection::extension`, which picks an
    /// index whose entry already matched `T`.
    fn deref(&self) -> &T {
        let entry: &(dyn std::any::Any + Send + Sync) = &*self.guard[self.idx];
        entry
            .downcast_ref::<T>()
            .expect("extension type mismatch: index was validated at construction")
    }
}