use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use bytes::Bytes;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::time::timeout;
use super::{BlobManager, BlobReceiverState};
use crate::ids::PeerMessageId;
use crate::ids::{AureliaError, ErrorId};
use crate::peering::ring_buffer::{OutboundRingBuffer, TryPushAvailable};
/// Asynchronous reader over the receive ring of one blob stream; see its
/// [`AsyncRead`] implementation.
///
/// Chunks are fetched one at a time by a boxed read future and copied out to
/// callers incrementally, possibly across several `poll_read` calls.
pub(crate) struct BlobReceiverStream {
    // Manager queried for the receive ring of `stream_id` and used for teardown.
    blob: std::sync::Arc<BlobManager>,
    stream_id: PeerMessageId,
    // Shared receiver state: error slot, accepted/completed flags, notifier and timeouts.
    receiver: std::sync::Arc<BlobReceiverState>,
    // Chunk currently being copied out to callers.
    current: Option<Bytes>,
    // Number of bytes of `current` already copied out.
    offset: usize,
    // Read started by `start_read`, kept here across `Poll::Pending` returns.
    pending: Option<Pin<Box<dyn Future<Output = ReadOutcome> + Send>>>,
    // Runtime on which `Drop` spawns the asynchronous teardown.
    runtime_handle: tokio::runtime::Handle,
}
/// Result of one asynchronous read started by `BlobReceiverStream::start_read`.
enum ReadOutcome {
    /// A chunk taken from the receive ring.
    Chunk(Bytes),
    /// The receiver is marked completed and no further chunks are available.
    Eof,
    /// The receiver failed: either a stored error or an idle timeout.
    Error(AureliaError),
}
/// Performs one non-waiting read attempt for `stream_id`.
///
/// Returns `Some(outcome)` when there is something to report: a stored
/// receiver error, a chunk taken from the receive ring (the receiver's
/// notifier is pinged afterwards), or end-of-stream. Returns `None` when
/// nothing is ready yet and the caller should wait for a notification.
///
/// End-of-stream with a ring present also removes the receive stream and
/// records its completion with `receiver.completion_ttl`.
async fn try_receiver_read(
    blob: &BlobManager,
    receiver: &BlobReceiverState,
    stream_id: PeerMessageId,
) -> Option<ReadOutcome> {
    use std::sync::atomic::Ordering;

    if let Some(err) = receiver.error.lock().await.clone() {
        return Some(ReadOutcome::Error(err));
    }

    let Some(ring) = blob.recv_ring(stream_id).await else {
        // No ring is registered. Check the error slot again, since it may have
        // been filled after the first check, before reporting completion.
        if let Some(err) = receiver.error.lock().await.clone() {
            return Some(ReadOutcome::Error(err));
        }
        return receiver
            .completed
            .load(Ordering::SeqCst)
            .then_some(ReadOutcome::Eof);
    };

    if let Some(chunk) = ring.take_next().await {
        receiver.notify.notify_waiters();
        return Some(ReadOutcome::Chunk(chunk));
    }

    // The ring is only inspected once the receiver reports completion.
    if !receiver.completed.load(Ordering::SeqCst) || !ring.is_complete().await {
        return None;
    }

    let _ = blob.remove_recv_stream(stream_id).await;
    blob.note_recv_complete(stream_id, receiver.completion_ttl)
        .await;
    Some(ReadOutcome::Eof)
}
async fn fail_receiver_idle(
blob: &BlobManager,
receiver: &BlobReceiverState,
stream_id: PeerMessageId,
) -> ReadOutcome {
let err = AureliaError::new(ErrorId::BlobStreamIdleTimeout);
receiver.fail(err.clone()).await;
let _ = blob.remove_recv_stream(stream_id).await;
ReadOutcome::Error(err)
}
impl BlobReceiverStream {
    /// Creates a reader for `stream_id` with no buffered chunk and no read in
    /// flight.
    ///
    /// `runtime_handle` is retained so that dropping the stream can spawn its
    /// asynchronous teardown.
    pub(crate) fn new(
        blob: std::sync::Arc<BlobManager>,
        stream_id: PeerMessageId,
        receiver: std::sync::Arc<BlobReceiverState>,
        runtime_handle: tokio::runtime::Handle,
    ) -> Self {
        Self {
            blob,
            receiver,
            stream_id,
            runtime_handle,
            current: None,
            pending: None,
            offset: 0,
        }
    }

    /// Stores a fresh boxed read future in `self.pending`.
    ///
    /// While the receiver is not yet accepted, the future waits on the
    /// receiver's notifier. Once accepted, it calls `try_receiver_read`. If
    /// that yields nothing, it registers a `notified()` future, tries once
    /// more, and then waits.
    ///
    /// Registering before the second attempt is meant to catch a notification
    /// sent between that attempt and the wait. This assumes `notify` is a
    /// `tokio::sync::Notify`, whose `Notified` futures observe
    /// `notify_waiters` from creation. NOTE(review): confirm the type.
    ///
    /// Each individual wait is bounded by `idle_timeout`. On expiry the
    /// receiver is failed through `fail_receiver_idle`.
    fn start_read(&mut self) {
        use std::sync::atomic::Ordering;

        let blob = std::sync::Arc::clone(&self.blob);
        let receiver = std::sync::Arc::clone(&self.receiver);
        let stream_id = self.stream_id;
        let idle_timeout = receiver.idle_timeout;

        let read = async move {
            loop {
                if receiver.accepted.load(Ordering::SeqCst) {
                    if let Some(outcome) = try_receiver_read(&blob, &receiver, stream_id).await {
                        return outcome;
                    }
                    let notified = receiver.notify.notified();
                    tokio::pin!(notified);
                    // Second attempt after registering, so a wakeup racing the
                    // first attempt is not missed.
                    if let Some(outcome) = try_receiver_read(&blob, &receiver, stream_id).await {
                        return outcome;
                    }
                    if timeout(idle_timeout, &mut notified).await.is_err() {
                        return fail_receiver_idle(&blob, &receiver, stream_id).await;
                    }
                } else {
                    let notified = receiver.notify.notified();
                    tokio::pin!(notified);
                    // Acceptance may have landed while registering; restart the loop.
                    if receiver.accepted.load(Ordering::SeqCst) {
                        continue;
                    }
                    if timeout(idle_timeout, &mut notified).await.is_err() {
                        return fail_receiver_idle(&blob, &receiver, stream_id).await;
                    }
                }
            }
        };

        self.pending = Some(Box::pin(read));
    }
}
impl AsyncRead for BlobReceiverStream {
    /// Copies bytes from the current chunk into `buf`. When no chunk is
    /// buffered, it fetches the next one through the read future created by
    /// `start_read`.
    ///
    /// Under tokio's `AsyncRead` contract, `Ready(Ok(()))` with nothing
    /// written to `buf` means EOF (or a zero-capacity `buf`). For that reason:
    /// * a zero-capacity `buf` returns immediately without starting a read;
    /// * an empty chunk, should the ring ever yield one, is discarded and the
    ///   next chunk is fetched, rather than being reported as a spurious EOF.
    ///
    /// # Errors
    /// A receiver failure (stored error or idle timeout) is surfaced as an
    /// `std::io::Error` carrying the error's display text.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }
        let this = &mut *self;
        loop {
            if let Some(chunk) = this.current.as_ref() {
                let start = this.offset;
                let end = (start + buf.remaining()).min(chunk.len());
                let copied = end.saturating_sub(start);
                if copied > 0 {
                    buf.put_slice(&chunk[start..end]);
                }
                let chunk_len = chunk.len();
                this.offset = start.saturating_add(copied);
                if this.offset >= chunk_len {
                    this.current = None;
                    this.offset = 0;
                }
                if copied > 0 {
                    return Poll::Ready(Ok(()));
                }
                // Nothing was copied even though `buf` has room, so the chunk
                // was empty and has just been cleared: fetch the next one
                // instead of returning an EOF-looking result.
                continue;
            }

            if this.pending.is_none() {
                this.start_read();
            }
            let Some(pending) = this.pending.as_mut() else {
                return Poll::Ready(Err(std::io::Error::other(
                    "blob receiver missing read future",
                )));
            };
            // On `Pending`, the read future stays in `this.pending` for the next poll.
            let outcome = std::task::ready!(pending.as_mut().poll(cx));
            this.pending = None;
            match outcome {
                ReadOutcome::Chunk(chunk) => this.current = Some(chunk),
                ReadOutcome::Eof => return Poll::Ready(Ok(())),
                ReadOutcome::Error(err) => {
                    return Poll::Ready(Err(std::io::Error::other(err.to_string())));
                }
            }
        }
    }
}
impl Drop for BlobReceiverStream {
    /// Spawns asynchronous teardown on the stored runtime handle.
    ///
    /// Unless the receiver is already marked completed, the task:
    /// 1. removes the receive stream;
    /// 2. drops the pending request for `stream_id`;
    /// 3. fails the receiver with `ErrorId::PeerUnavailable`.
    fn drop(&mut self) {
        use std::sync::atomic::Ordering;

        let blob = std::sync::Arc::clone(&self.blob);
        let receiver = std::sync::Arc::clone(&self.receiver);
        let stream_id = self.stream_id;
        self.runtime_handle.spawn(async move {
            if !receiver.completed.load(Ordering::SeqCst) {
                let err = AureliaError::new(ErrorId::PeerUnavailable);
                let _ = blob.remove_recv_stream(stream_id).await;
                blob.drop_pending_request(stream_id).await;
                receiver.fail(err).await;
            }
        });
    }
}
/// Asynchronous writer that pushes bytes into a blob's outbound ring buffer;
/// see its [`AsyncWrite`] implementation.
pub(crate) struct BlobSenderStream {
    // Sender state shared with the futures stored in `pending` and with the
    // cleanup task spawned by `Drop`.
    state: std::sync::Arc<tokio::sync::Mutex<BlobSenderState>>,
    // Outstanding capacity/flush/shutdown wait, resumed on the next poll.
    pending: Option<PendingOp>,
    // Runtime on which `Drop` spawns the asynchronous cleanup.
    runtime_handle: tokio::runtime::Handle,
}
/// Mutable state of a [`BlobSenderStream`], guarded by a tokio mutex.
struct BlobSenderState {
    // Manager that is nudged via `notify_work` and unregisters the stream on cleanup.
    blob: std::sync::Arc<BlobManager>,
    stream_id: PeerMessageId,
    // Ring buffer the writer pushes bytes into.
    ring: std::sync::Arc<OutboundRingBuffer>,
    // Used to compute each capacity/flush/completion deadline and passed to `seal`.
    send_timeout: Duration,
    // Set by `cleanup` once the outbound stream has been unregistered.
    closed: bool,
}
/// Which operation a [`PendingOp`] is waiting on.
#[derive(Clone, Copy)]
enum PendingKind {
    /// Waiting for ring capacity after a push found the ring full.
    Capacity,
    /// Waiting for in-flight data to drain.
    Flush,
    /// Sealing the ring and waiting for it to complete, then cleaning up.
    Shutdown,
}
/// A boxed sender-side wait tagged with the kind of operation it performs.
struct PendingOp {
    kind: PendingKind,
    // Resolves to `Ok(())` when the operation finishes, or the error that ended it.
    future: Pin<Box<dyn Future<Output = Result<(), AureliaError>> + Send>>,
}
impl BlobSenderStream {
pub(crate) fn new(
blob: std::sync::Arc<BlobManager>,
stream_id: PeerMessageId,
ring: std::sync::Arc<OutboundRingBuffer>,
send_timeout: Duration,
runtime_handle: tokio::runtime::Handle,
) -> Self {
let state = BlobSenderState {
blob,
stream_id,
ring,
send_timeout,
closed: false,
};
Self {
state: std::sync::Arc::new(tokio::sync::Mutex::new(state)),
pending: None,
runtime_handle,
}
}
fn start_capacity_wait(&mut self) {
let state = std::sync::Arc::clone(&self.state);
let future = async move {
let state = state.lock().await;
if state.closed {
return Err(AureliaError::new(ErrorId::PeerUnavailable));
}
let deadline = tokio::time::Instant::now() + state.send_timeout;
state.ring.wait_for_capacity(deadline).await
};
self.pending = Some(PendingOp {
kind: PendingKind::Capacity,
future: Box::pin(future),
});
}
fn start_flush(&mut self) {
let state = std::sync::Arc::clone(&self.state);
let future = async move {
let state = state.lock().await;
if state.closed {
return Ok(());
}
let deadline = tokio::time::Instant::now() + state.send_timeout;
state.ring.wait_for_inflight_drain(deadline).await
};
self.pending = Some(PendingOp {
kind: PendingKind::Flush,
future: Box::pin(future),
});
}
fn start_shutdown(&mut self) {
let state = std::sync::Arc::clone(&self.state);
let future = async move {
let mut state = state.lock().await;
if state.closed {
return Ok(());
}
state.ring.seal(state.send_timeout).await?;
state.blob.notify_work();
let deadline = tokio::time::Instant::now() + state.send_timeout;
if let Err(err) = state.ring.wait_for_complete(deadline).await {
state.cleanup().await;
return Err(err);
}
state.cleanup().await;
Ok(())
};
self.pending = Some(PendingOp {
kind: PendingKind::Shutdown,
future: Box::pin(future),
});
}
}
impl BlobSenderState {
    /// Unregisters the outbound stream from the manager and marks the state
    /// closed.
    ///
    /// `closed` is set only after the unregistration finishes. After the first
    /// completed call, later calls do nothing.
    async fn cleanup(&mut self) {
        if self.closed {
            return;
        }
        self.blob.unregister_outbound_stream(self.stream_id).await;
        self.closed = true;
    }
}
impl AsyncWrite for BlobSenderStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
loop {
if let Some(pending) = self.pending.as_mut() {
match pending.future.as_mut().poll(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(result) => {
let kind = pending.kind;
self.pending = None;
if let Err(err) = result {
return Poll::Ready(Err(std::io::Error::other(err.to_string())));
}
if let PendingKind::Capacity = kind {
continue;
}
continue;
}
}
}
if buf.is_empty() {
return Poll::Ready(Ok(0));
}
let state = match self.state.try_lock() {
Ok(state) => state,
Err(_) => {
cx.waker().wake_by_ref();
return Poll::Pending;
}
};
if state.closed {
return Poll::Ready(Err(std::io::Error::other(
AureliaError::new(ErrorId::PeerUnavailable).to_string(),
)));
}
let blob = std::sync::Arc::clone(&state.blob);
let result = state.ring.try_push_available(buf, || blob.notify_work());
drop(state);
match result {
Ok(TryPushAvailable::Accepted { bytes }) => {
blob.notify_work();
return Poll::Ready(Ok(bytes));
}
Ok(TryPushAvailable::Full) => {
self.start_capacity_wait();
}
Ok(TryPushAvailable::Busy) => {
cx.waker().wake_by_ref();
return Poll::Pending;
}
Err(err) => return Poll::Ready(Err(std::io::Error::other(err.to_string()))),
}
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
loop {
if let Some(pending) = self.pending.as_mut() {
match pending.future.as_mut().poll(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(result) => {
let kind = pending.kind;
self.pending = None;
if let Err(err) = result {
return Poll::Ready(Err(std::io::Error::other(err.to_string())));
}
match kind {
PendingKind::Flush | PendingKind::Shutdown => {
return Poll::Ready(Ok(()));
}
PendingKind::Capacity => continue,
}
}
}
}
self.start_flush();
}
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
loop {
if let Some(pending) = self.pending.as_mut() {
match pending.future.as_mut().poll(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(result) => {
let kind = pending.kind;
self.pending = None;
if let Err(err) = result {
return Poll::Ready(Err(std::io::Error::other(err.to_string())));
}
if matches!(kind, PendingKind::Shutdown) {
return Poll::Ready(Ok(()));
}
continue;
}
}
}
self.start_shutdown();
}
}
}
impl Drop for BlobSenderStream {
    /// Spawns a task on the stored runtime handle that locks the sender state
    /// and runs `cleanup`.
    ///
    /// `cleanup` does nothing if the state is already closed, so no extra
    /// check is needed here.
    fn drop(&mut self) {
        let shared = std::sync::Arc::clone(&self.state);
        self.runtime_handle.spawn(async move {
            let mut guard = shared.lock().await;
            guard.cleanup().await;
        });
    }
}