use super::error::QuicError;
use crate::cx::Cx;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
/// Shared, thread-safe "closing" flag for a group of QUIC stream wrappers.
///
/// [`SendStream`] and [`RecvStream`] each hold an `Arc` to a tracker. Once the
/// flag is set they refuse further I/O and reset/stop themselves on drop. The
/// flag can only be set, never cleared.
#[derive(Debug, Default)]
pub struct StreamTracker {
    /// `true` once [`StreamTracker::mark_closing`] has been called.
    closing: AtomicBool,
}

impl StreamTracker {
    /// Creates a tracker in the open (not closing) state.
    ///
    /// The tracker comes pre-wrapped in an [`Arc`] so it can be handed to many streams.
    #[must_use]
    pub fn new() -> Arc<Self> {
        // The derived `Default` initialises the flag to `false`, i.e. still open.
        Arc::new(Self::default())
    }

    /// Sets the closing flag. Calling this repeatedly has no further effect.
    pub fn mark_closing(&self) {
        // Release store; pairs with the Acquire load in `is_closing`.
        self.closing.store(true, Ordering::Release);
    }

    /// Returns `true` once [`StreamTracker::mark_closing`] has been called on this tracker.
    pub fn is_closing(&self) -> bool {
        self.closing.load(Ordering::Acquire)
    }
}
/// Outgoing half of a QUIC stream that honours a shared [`StreamTracker`].
///
/// Writes are rejected with [`QuicError::StreamClosed`] once the tracker is closing.
/// Dropping the stream while the tracker is closing resets it with the code set via
/// [`SendStream::set_reset_code`] (initially `0`).
#[derive(Debug)]
pub struct SendStream {
    /// Underlying quinn stream that performs the actual I/O.
    inner: quinn::SendStream,
    /// Shared closing flag, consulted before writes and on drop.
    tracker: Arc<StreamTracker>,
    /// Error code passed to `reset` when dropped while the tracker is closing.
    reset_code: u32,
}

impl SendStream {
    /// Wraps `inner` and keeps a shared handle to `tracker`. The drop-time reset code starts at `0`.
    pub(crate) fn new(inner: quinn::SendStream, tracker: &Arc<StreamTracker>) -> Self {
        let tracker = Arc::clone(tracker);
        Self {
            inner,
            tracker,
            reset_code: 0,
        }
    }

    /// Returns the quinn identifier of the underlying stream.
    #[must_use]
    pub fn id(&self) -> quinn::StreamId {
        self.inner.id()
    }

    /// Chooses the error code `Drop` uses for its reset when the tracker is closing.
    pub fn set_reset_code(&mut self, code: u32) {
        self.reset_code = code;
    }

    /// Precondition shared by every write.
    ///
    /// Checks the cancellation checkpoint first, then the tracker.
    fn ensure_writable(&self, cx: &Cx) -> Result<(), QuicError> {
        cx.checkpoint()?;
        if self.tracker.is_closing() {
            Err(QuicError::StreamClosed)
        } else {
            Ok(())
        }
    }

    /// Writes a prefix of `data` and returns how many bytes quinn accepted.
    ///
    /// # Errors
    /// - Fails if the `cx` checkpoint fails.
    /// - Returns [`QuicError::StreamClosed`] when the tracker is closing.
    /// - Otherwise propagates the converted quinn write error.
    pub async fn write(&mut self, cx: &Cx, data: &[u8]) -> Result<usize, QuicError> {
        self.ensure_writable(cx)?;
        self.inner.write(data).await.map_err(Into::into)
    }

    /// Writes all of `data`.
    ///
    /// # Errors
    /// - Fails if the `cx` checkpoint fails.
    /// - Returns [`QuicError::StreamClosed`] when the tracker is closing.
    /// - Otherwise propagates the converted quinn write error.
    pub async fn write_all(&mut self, cx: &Cx, data: &[u8]) -> Result<(), QuicError> {
        self.ensure_writable(cx)?;
        self.inner.write_all(data).await.map_err(Into::into)
    }

    /// Finishes the sending side of the stream via quinn's `finish`.
    ///
    /// This does not consult a `Cx` or the tracker. The body never awaits; the
    /// `async` signature is presumably kept so existing `.await` call sites compile.
    ///
    /// # Errors
    /// Propagates the converted error from quinn's `finish`.
    pub async fn finish(&mut self) -> Result<(), QuicError> {
        self.inner.finish().map_err(Into::into)
    }

    /// Resets the stream with `code`. Best-effort: any error from quinn is discarded.
    pub fn reset(&mut self, code: u32) {
        let _ = self.inner.reset(code.into());
    }

    /// Borrows the wrapped quinn stream.
    #[must_use]
    pub fn inner(&self) -> &quinn::SendStream {
        &self.inner
    }

    /// Mutably borrows the wrapped quinn stream.
    ///
    /// Writes made through this handle bypass the tracker/`Cx` checks.
    pub fn inner_mut(&mut self) -> &mut quinn::SendStream {
        &mut self.inner
    }
}

impl Drop for SendStream {
    /// Resets the stream with the configured code if the tracker is closing.
    ///
    /// Otherwise nothing is done here, and the quinn stream's own drop logic applies.
    fn drop(&mut self) {
        if !self.tracker.is_closing() {
            return;
        }
        let _ = self.inner.reset(self.reset_code.into());
    }
}
/// Incoming half of a QUIC stream that honours a shared [`StreamTracker`].
///
/// Reads are rejected with [`QuicError::StreamClosed`] once the tracker is closing.
/// Dropping the stream while the tracker is closing calls `stop` with the code set
/// via [`RecvStream::set_stop_code`] (initially `0`).
#[derive(Debug)]
pub struct RecvStream {
    /// Underlying quinn stream that performs the actual I/O.
    inner: quinn::RecvStream,
    /// Shared closing flag, consulted before reads and on drop.
    tracker: Arc<StreamTracker>,
    /// Error code passed to `stop` when dropped while the tracker is closing.
    stop_code: u32,
}

impl RecvStream {
    /// Wraps `inner` and keeps a shared handle to `tracker`. The drop-time stop code starts at `0`.
    pub(crate) fn new(inner: quinn::RecvStream, tracker: &Arc<StreamTracker>) -> Self {
        Self {
            inner,
            tracker: Arc::clone(tracker),
            stop_code: 0,
        }
    }

    /// Returns the quinn identifier of the underlying stream.
    #[must_use]
    pub fn id(&self) -> quinn::StreamId {
        self.inner.id()
    }

    /// Chooses the error code `Drop` passes to `stop` when the tracker is closing.
    pub fn set_stop_code(&mut self, code: u32) {
        self.stop_code = code;
    }

    /// Precondition shared by every read.
    ///
    /// Checks the cancellation checkpoint first, then the tracker.
    fn ensure_readable(&self, cx: &Cx) -> Result<(), QuicError> {
        cx.checkpoint()?;
        if self.tracker.is_closing() {
            return Err(QuicError::StreamClosed);
        }
        Ok(())
    }

    /// Reads into `buf`, returning quinn's result unchanged apart from the error type.
    ///
    /// `Ok(Some(n))` is the number of bytes read. `Ok(None)` is forwarded from quinn
    /// (end of stream, per quinn's documentation).
    ///
    /// # Errors
    /// - Fails if the `cx` checkpoint fails.
    /// - Returns [`QuicError::StreamClosed`] when the tracker is closing.
    /// - Otherwise propagates the converted quinn read error.
    pub async fn read(&mut self, cx: &Cx, buf: &mut [u8]) -> Result<Option<usize>, QuicError> {
        self.ensure_readable(cx)?;
        // The `Option` payload passes through untouched; only the error needs converting.
        self.inner.read(buf).await.map_err(QuicError::from)
    }

    /// Fills `buf` completely.
    ///
    /// # Errors
    /// - Fails if the `cx` checkpoint fails.
    /// - Returns [`QuicError::StreamClosed`] when the tracker is closing.
    /// - Otherwise propagates the converted quinn `read_exact` error.
    pub async fn read_exact(&mut self, cx: &Cx, buf: &mut [u8]) -> Result<(), QuicError> {
        self.ensure_readable(cx)?;
        self.inner.read_exact(buf).await.map_err(QuicError::from)
    }

    /// Reads the rest of the stream into a `Vec`, passing `limit` through to quinn's `read_to_end`.
    ///
    /// # Errors
    /// - Fails if the `cx` checkpoint fails.
    /// - Returns [`QuicError::StreamClosed`] when the tracker is closing.
    /// - Otherwise propagates the converted quinn `read_to_end` error.
    pub async fn read_to_end(&mut self, cx: &Cx, limit: usize) -> Result<Vec<u8>, QuicError> {
        self.ensure_readable(cx)?;
        self.inner.read_to_end(limit).await.map_err(QuicError::from)
    }

    /// Calls quinn's `stop` with `code`. Best-effort: any error from quinn is discarded.
    pub fn stop(&mut self, code: u32) {
        let _ = self.inner.stop(code.into());
    }

    /// Borrows the wrapped quinn stream.
    #[must_use]
    pub fn inner(&self) -> &quinn::RecvStream {
        &self.inner
    }

    /// Mutably borrows the wrapped quinn stream.
    ///
    /// Reads made through this handle bypass the tracker/`Cx` checks.
    pub fn inner_mut(&mut self) -> &mut quinn::RecvStream {
        &mut self.inner
    }
}

impl Drop for RecvStream {
    /// Stops the stream with the configured code if the tracker is closing.
    ///
    /// Otherwise nothing is done here, and the quinn stream's own drop logic applies.
    fn drop(&mut self) {
        if self.tracker.is_closing() {
            // Best-effort during teardown: there is nobody to report a failure to.
            let _ = self.inner.stop(self.stop_code.into());
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A brand-new tracker starts in the open state.
    #[test]
    fn tracker_initially_not_closing() {
        let fresh = StreamTracker::new();
        let closing = fresh.is_closing();
        assert!(!closing);
    }

    /// `mark_closing` flips the flag.
    #[test]
    fn tracker_mark_closing() {
        let t = StreamTracker::new();
        t.mark_closing();
        let closing = t.is_closing();
        assert!(closing);
    }

    /// Repeated `mark_closing` calls leave the flag set.
    #[test]
    fn tracker_mark_closing_idempotent() {
        let t = StreamTracker::new();
        for _ in 0..2 {
            t.mark_closing();
        }
        assert!(t.is_closing());
    }

    /// Every `Arc` handle observes the same underlying flag.
    #[test]
    fn tracker_shared_across_arcs() {
        let writer = StreamTracker::new();
        let observer = Arc::clone(&writer);
        assert!(!observer.is_closing());
        writer.mark_closing();
        assert!(observer.is_closing());
    }

    /// The derived `Default` starts with the raw flag cleared.
    #[test]
    fn tracker_default() {
        let t = StreamTracker::default();
        let raw = t.closing.load(Ordering::Acquire);
        assert!(!raw);
    }

    /// The `Debug` rendering mentions the type name.
    #[test]
    fn tracker_debug() {
        let t = StreamTracker::new();
        let rendered = format!("{t:?}");
        assert!(rendered.contains("StreamTracker"));
    }
}