use super::error::QuicError;
use crate::cx::Cx;
use std::future::{Future, poll_fn};
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::Poll;
#[derive(Debug, Default)]
pub struct StreamTracker {
closing: AtomicBool,
}
impl StreamTracker {
#[must_use]
pub fn new() -> Arc<Self> {
Arc::new(Self {
closing: AtomicBool::new(false),
})
}
pub fn mark_closing(&self) {
self.closing.store(true, Ordering::Release);
}
pub fn is_closing(&self) -> bool {
self.closing.load(Ordering::Acquire)
}
}
#[derive(Debug)]
pub struct SendStream {
inner: quinn::SendStream,
tracker: Arc<StreamTracker>,
reset_code: u32,
}
impl SendStream {
pub(crate) fn new(inner: quinn::SendStream, tracker: &Arc<StreamTracker>) -> Self {
Self {
inner,
tracker: Arc::clone(tracker),
reset_code: 0,
}
}
#[must_use]
pub fn id(&self) -> quinn::StreamId {
self.inner.id()
}
pub fn set_reset_code(&mut self, code: u32) {
self.reset_code = code;
}
pub async fn write(&mut self, cx: &Cx, data: &[u8]) -> Result<usize, QuicError> {
check_stream_operation(cx, &self.tracker)?;
wait_result_with_cx(cx, self.inner.write(data)).await
}
pub async fn write_all(&mut self, cx: &Cx, data: &[u8]) -> Result<(), QuicError> {
check_stream_operation(cx, &self.tracker)?;
wait_result_with_cx(cx, self.inner.write_all(data)).await
}
pub async fn finish(&mut self, cx: &Cx) -> Result<(), QuicError> {
check_stream_operation(cx, &self.tracker)?;
self.inner.finish().map_err(QuicError::from)
}
pub fn reset(&mut self, code: u32) {
self.inner.reset(code.into()).ok();
}
#[must_use]
pub fn inner(&self) -> &quinn::SendStream {
&self.inner
}
pub fn inner_mut(&mut self) -> &mut quinn::SendStream {
&mut self.inner
}
}
impl Drop for SendStream {
fn drop(&mut self) {
if self.tracker.is_closing() {
self.inner.reset(self.reset_code.into()).ok();
}
}
}
#[derive(Debug)]
pub struct RecvStream {
inner: quinn::RecvStream,
tracker: Arc<StreamTracker>,
stop_code: u32,
}
impl RecvStream {
pub(crate) fn new(inner: quinn::RecvStream, tracker: &Arc<StreamTracker>) -> Self {
Self {
inner,
tracker: Arc::clone(tracker),
stop_code: 0,
}
}
#[must_use]
pub fn id(&self) -> quinn::StreamId {
self.inner.id()
}
pub fn set_stop_code(&mut self, code: u32) {
self.stop_code = code;
}
pub async fn read(&mut self, cx: &Cx, buf: &mut [u8]) -> Result<Option<usize>, QuicError> {
check_stream_operation(cx, &self.tracker)?;
wait_result_with_cx(cx, self.inner.read(buf)).await
}
pub async fn read_exact(&mut self, cx: &Cx, buf: &mut [u8]) -> Result<(), QuicError> {
check_stream_operation(cx, &self.tracker)?;
wait_result_with_cx(cx, self.inner.read_exact(buf)).await
}
pub async fn read_to_end(&mut self, cx: &Cx, limit: usize) -> Result<Vec<u8>, QuicError> {
check_stream_operation(cx, &self.tracker)?;
wait_result_with_cx(cx, self.inner.read_to_end(limit)).await
}
pub fn stop(&mut self, code: u32) {
self.inner.stop(code.into()).ok();
}
#[must_use]
pub fn inner(&self) -> &quinn::RecvStream {
&self.inner
}
pub fn inner_mut(&mut self) -> &mut quinn::RecvStream {
&mut self.inner
}
}
impl Drop for RecvStream {
fn drop(&mut self) {
if self.tracker.is_closing() {
self.inner.stop(self.stop_code.into()).ok();
}
}
}
fn check_stream_operation(cx: &Cx, tracker: &StreamTracker) -> Result<(), QuicError> {
cx.checkpoint()?;
if tracker.is_closing() {
return Err(QuicError::StreamClosed);
}
Ok(())
}
async fn wait_result_with_cx<T, E, F>(cx: &Cx, future: F) -> Result<T, QuicError>
where
E: Into<QuicError>,
F: Future<Output = Result<T, E>>,
{
let mut future = std::pin::pin!(future);
poll_fn(|poll_cx| {
if let Err(err) = cx.checkpoint() {
return Poll::Ready(Err(QuicError::from(err)));
}
future
.as_mut()
.poll(poll_cx)
.map(|result| result.map_err(Into::into))
})
.await
}
#[cfg(test)]
mod tests {
#![allow(
clippy::pedantic,
clippy::nursery,
clippy::expect_fun_call,
clippy::map_unwrap_or,
clippy::cast_possible_wrap,
clippy::future_not_send
)]
use super::*;
use std::task::Context;
fn noop_waker() -> std::task::Waker {
std::task::Waker::noop().clone()
}
struct PendingOnce {
polled: bool,
}
impl Future for PendingOnce {
type Output = Result<(), QuicError>;
fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.polled {
Poll::Ready(Ok(()))
} else {
self.polled = true;
Poll::Pending
}
}
}
#[test]
fn tracker_initially_not_closing() {
let tracker = StreamTracker::new();
assert!(!tracker.is_closing());
}
#[test]
fn tracker_mark_closing() {
let tracker = StreamTracker::new();
tracker.mark_closing();
assert!(tracker.is_closing());
}
#[test]
fn tracker_mark_closing_idempotent() {
let tracker = StreamTracker::new();
tracker.mark_closing();
tracker.mark_closing();
assert!(tracker.is_closing());
}
#[test]
fn tracker_shared_across_arcs() {
let tracker = StreamTracker::new();
let tracker2 = Arc::clone(&tracker);
assert!(!tracker2.is_closing());
tracker.mark_closing();
assert!(tracker2.is_closing());
}
#[test]
fn tracker_default() {
let tracker = StreamTracker::default();
assert!(!tracker.closing.load(Ordering::Acquire));
}
#[test]
fn tracker_debug() {
let tracker = StreamTracker::new();
let debug = format!("{tracker:?}");
assert!(debug.contains("StreamTracker"));
}
#[test]
fn wait_result_with_cx_returns_cancelled_when_context_is_cancelled_between_polls() {
let cx = Cx::for_testing();
let mut future = std::pin::pin!(wait_result_with_cx(&cx, PendingOnce { polled: false }));
let waker = noop_waker();
let mut poll_cx = Context::from_waker(&waker);
assert!(matches!(future.as_mut().poll(&mut poll_cx), Poll::Pending));
cx.set_cancel_requested(true);
let cancelled = matches!(
future.as_mut().poll(&mut poll_cx),
Poll::Ready(Err(QuicError::Cancelled))
);
assert!(
cancelled,
"future should return cancelled after Cx cancellation"
);
}
#[test]
fn stream_operation_guard_observes_cancellation_before_quinn_call() {
let cx = Cx::for_testing();
let tracker = StreamTracker::new();
cx.set_cancel_requested(true);
let err = check_stream_operation(&cx, &tracker).expect_err("cancelled Cx must fail");
assert!(matches!(err, QuicError::Cancelled));
}
#[test]
fn stream_operation_guard_rejects_closing_tracker() {
let cx = Cx::for_testing();
let tracker = StreamTracker::new();
tracker.mark_closing();
let err = check_stream_operation(&cx, &tracker).expect_err("closing tracker must fail");
assert!(matches!(err, QuicError::StreamClosed));
}
}