use crate::{QuicTransportBidi, QuicTransportReceive, QuicTransportSend, Transport};
use atomic_waker::AtomicWaker;
use futures_lite::{AsyncRead, AsyncWrite};
use std::{
borrow::Cow,
collections::HashMap,
io,
net::SocketAddr,
pin::Pin,
sync::{
Arc, RwLock,
atomic::{AtomicBool, AtomicI32, Ordering},
},
task::{Context, Poll},
time::Duration,
};
use trillium_http::Priority;
/// Maps an HTTP/3 [`Priority`] onto a single transport-level stream priority value.
///
/// The urgency fills the upper bits and the incremental flag the lowest bit. The sum is
/// negated so that a smaller urgency (more urgent) produces a larger number. Within one
/// urgency level, an incremental stream maps to one less than a non-incremental stream.
///
/// NOTE(review): this assumes the transport sends streams with larger priority values
/// first — confirm against the `QuicTransportSend::set_priority` implementations.
pub(crate) fn transport_priority(priority: Priority) -> i32 {
    let urgency = i32::from(priority.urgency());
    let incremental = i32::from(priority.is_incremental());
    -(urgency * 2 + incremental)
}
/// Shared cell holding one stream's current transport priority, plus a waker for the task
/// that writes to that stream.
///
/// The value starts at `0` (the `Default`). It is written by `PriorityRegistry` and read by
/// `PrioritizedStream`. A value originating from a PRIORITY_UPDATE permanently takes
/// precedence over an initial priority (see `store_initial`).
#[derive(Debug, Default)]
pub(crate) struct PrioritySlot {
    /// Current target transport priority.
    priority: AtomicI32,
    /// Set once any PRIORITY_UPDATE value has been stored; afterwards initial values are ignored.
    update_seen: AtomicBool,
    /// Waker of the task most recently registered via `register`; woken when `priority` changes.
    waker: AtomicWaker,
}

impl PrioritySlot {
    /// Stores `priority` only if no PRIORITY_UPDATE has been stored yet.
    ///
    /// The check-then-store here is not atomic on its own. In this file it is only reached
    /// while `PriorityRegistry` holds its write lock, which serializes it against
    /// `store_update`.
    fn store_initial(&self, priority: i32) {
        let superseded = self.update_seen.load(Ordering::Relaxed);
        if superseded {
            return;
        }
        self.store_and_wake(priority);
    }

    /// Stores a PRIORITY_UPDATE value and marks the slot so that later initial values are
    /// ignored.
    fn store_update(&self, priority: i32) {
        self.update_seen.store(true, Ordering::Relaxed);
        self.store_and_wake(priority);
    }

    /// Replaces the stored priority. Wakes the registered task only if the value actually
    /// changed, so repeated identical signals cause no spurious wakeups.
    fn store_and_wake(&self, priority: i32) {
        let previous = self.priority.swap(priority, Ordering::Relaxed);
        if previous == priority {
            return;
        }
        self.waker.wake();
    }

    /// Records the current task's waker so that a later priority change wakes it.
    ///
    /// Callers register first and then `load`, following `AtomicWaker`'s
    /// register-then-check pattern, so a concurrent change is not missed.
    fn register(&self, cx: &Context<'_>) {
        let task_waker = cx.waker();
        self.waker.register(task_waker);
    }

    /// Returns the current target transport priority.
    fn load(&self) -> i32 {
        self.priority.load(Ordering::Relaxed)
    }
}
/// Maximum number of distinct stream ids whose PRIORITY_UPDATE values may be buffered while
/// no live slot exists for them. Once full, updates for new ids are dropped, but an id that
/// is already buffered can still be overwritten.
const MAX_PENDING_PRIORITY_UPDATES: usize = 128;
/// State guarded by `PriorityRegistry`'s lock.
#[derive(Debug, Default)]
struct Streams {
    /// Slots of currently registered streams, keyed by stream id.
    live: HashMap<u64, Arc<PrioritySlot>>,
    /// Latest PRIORITY_UPDATE value seen for a stream id that had no live slot. Consumed by
    /// `PriorityRegistry::register`.
    ///
    /// NOTE(review): entries are only removed when the id registers. An update for an id
    /// that never registers (for example one that arrives after `deregister`) therefore
    /// occupies capacity for the life of the registry — confirm this is acceptable.
    pending: HashMap<u64, i32>,
}
/// Routes HTTP/3 priority signals, keyed by stream id, to the [`PrioritySlot`] of each
/// registered stream.
///
/// A signal for an id with no live slot is handled by its kind:
/// - PRIORITY_UPDATE values are buffered, up to [`MAX_PENDING_PRIORITY_UPDATES`] distinct
///   ids, and applied when that id registers.
/// - Initial priorities are dropped.
///
/// Clones share the same table.
#[derive(Clone, Debug, Default)]
pub(crate) struct PriorityRegistry(Arc<RwLock<Streams>>);

impl PriorityRegistry {
    /// Locks the table for mutation, recovering the guard if the lock is poisoned.
    ///
    /// Each critical section below consists of complete `HashMap` operations and slot
    /// stores, so a panic while the lock is held cannot leave `Streams` half-updated. Such a
    /// panic could come from a waker's `wake` run by `PrioritySlot::store_and_wake`, or from
    /// the logger. Priorities are advisory, so carrying on is preferable to making every
    /// later `register`/`deregister`/`apply` on this registry panic as well.
    fn streams(&self) -> std::sync::RwLockWriteGuard<'_, Streams> {
        self.0
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Creates a fresh slot for `stream_id`, makes it live, and returns it.
    ///
    /// Any PRIORITY_UPDATE buffered for this id is applied first. Registering an id that is
    /// already live replaces its slot.
    pub(crate) fn register(&self, stream_id: u64) -> Arc<PrioritySlot> {
        let slot = Arc::<PrioritySlot>::default();
        let mut streams = self.streams();
        if let Some(priority) = streams.pending.remove(&stream_id) {
            log::trace!(
                "H3 stream {stream_id}: applying buffered PRIORITY_UPDATE {priority} on open"
            );
            // Stored as an update so that a later initial priority cannot override it.
            slot.store_update(priority);
        }
        streams.live.insert(stream_id, Arc::clone(&slot));
        slot
    }

    /// Stops routing signals for `stream_id`. Slots already handed out remain usable but
    /// no longer receive updates.
    pub(crate) fn deregister(&self, stream_id: u64) {
        self.streams().live.remove(&stream_id);
    }

    /// Applies a priority signal for `stream_id`.
    ///
    /// `is_update` is `true` for a PRIORITY_UPDATE value and `false` for an initial
    /// priority.
    /// - For a live stream, an update always replaces the slot's value, while an initial
    ///   priority is ignored once any update has been stored.
    /// - For a stream with no live slot, an update is buffered for `register`, unless the
    ///   pending table is full and does not already hold this id. An initial priority is
    ///   dropped.
    pub(crate) fn apply(&self, stream_id: u64, priority: i32, is_update: bool) {
        // The write lock is held across the slot store so that `store_initial`'s
        // check-then-store cannot interleave with a concurrent `store_update`.
        let mut streams = self.streams();
        if let Some(slot) = streams.live.get(&stream_id) {
            if is_update {
                log::trace!("H3 stream {stream_id}: PRIORITY_UPDATE {priority} stored");
                slot.store_update(priority);
            } else {
                log::trace!("H3 stream {stream_id}: initial priority {priority} stored");
                slot.store_initial(priority);
            }
        } else if is_update {
            let at_capacity = streams.pending.len() >= MAX_PENDING_PRIORITY_UPDATES;
            if at_capacity && !streams.pending.contains_key(&stream_id) {
                log::trace!(
                    "H3 stream {stream_id}: dropping PRIORITY_UPDATE {priority} (pending table \
                     full)"
                );
            } else {
                log::trace!(
                    "H3 stream {stream_id}: buffering PRIORITY_UPDATE {priority} (stream not yet \
                     open)"
                );
                streams.pending.insert(stream_id, priority);
            }
        } else {
            log::trace!(
                "H3 stream {stream_id}: dropping initial priority {priority} (no live stream)"
            );
        }
    }
}
/// Send-stream wrapper that mirrors the value in a shared [`PrioritySlot`] onto the wrapped
/// transport stream.
///
/// The slot is read lazily from the `AsyncWrite` methods. `set_priority` is forwarded only
/// when the slot's value differs from the last value this wrapper applied.
#[derive(Debug)]
pub(crate) struct PrioritizedStream<T> {
    /// Wrapped stream; all I/O is delegated to it.
    inner: T,
    /// Shared cell holding the target transport priority.
    slot: Arc<PrioritySlot>,
    /// Stream id; within this file it is used only in trace logging.
    stream_id: u64,
    /// Priority last passed to `inner.set_priority`, or `None` before the first sync.
    applied: Option<i32>,
}

impl<T> PrioritizedStream<T> {
    /// Wraps `inner`. Nothing is applied yet; the first write or flush applies the slot's
    /// current value.
    pub(crate) fn new(inner: T, slot: Arc<PrioritySlot>, stream_id: u64) -> Self {
        let applied = None;
        Self {
            inner,
            slot,
            stream_id,
            applied,
        }
    }
}

impl<T: QuicTransportSend + Unpin> PrioritizedStream<T> {
    /// Pushes the slot's current priority to the transport if it differs from the last
    /// applied value.
    fn sync_priority(&mut self) {
        let target = self.slot.load();
        if self.applied == Some(target) {
            return;
        }
        log::trace!(
            "H3 stream {}: applying transport priority {target} to send stream",
            self.stream_id
        );
        self.applied = Some(target);
        self.inner.set_priority(target);
    }
}
/// Reads pass straight through; prioritization only affects the send side.
impl<T: AsyncRead + Unpin> AsyncRead for PrioritizedStream<T> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // `PrioritizedStream<T>` is `Unpin` whenever `T` is, so unpinning is free.
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_read(cx, buf)
    }
}
/// Write and flush calls first register the task's waker on the slot, so a later priority
/// change wakes the task. They then apply the slot's current priority (if changed) before
/// delegating to the inner stream. Close is delegated without touching the priority.
impl<T: QuicTransportSend + Unpin> AsyncWrite for PrioritizedStream<T> {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        this.slot.register(cx);
        this.sync_priority();
        Pin::new(&mut this.inner).poll_write(cx, buf)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        this.slot.register(cx);
        this.sync_priority();
        Pin::new(&mut this.inner).poll_write_vectored(cx, bufs)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        this.slot.register(cx);
        this.sync_priority();
        Pin::new(&mut this.inner).poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_close(cx)
    }
}
impl<T: QuicTransportReceive + Unpin> QuicTransportReceive for PrioritizedStream<T> {
fn stop(&mut self, code: u64) {
self.inner.stop(code);
}
}
impl<T: QuicTransportSend + Unpin> QuicTransportSend for PrioritizedStream<T> {
fn reset(&mut self, code: u64) {
self.inner.reset(code);
}
fn set_priority(&mut self, priority: i32) {
self.inner.set_priority(priority);
}
}
impl<T: Transport + QuicTransportSend> Transport for PrioritizedStream<T> {
fn set_linger(&mut self, linger: Option<Duration>) -> io::Result<()> {
self.inner.set_linger(linger)
}
fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> {
self.inner.set_nodelay(nodelay)
}
fn set_ip_ttl(&mut self, ttl: u32) -> io::Result<()> {
self.inner.set_ip_ttl(ttl)
}
fn peer_addr(&self) -> io::Result<Option<SocketAddr>> {
self.inner.peer_addr()
}
fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
self.inner.negotiated_alpn()
}
}
impl<T: QuicTransportBidi + Unpin> QuicTransportBidi for PrioritizedStream<T> {}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_lite::{AsyncWriteExt, future::block_on};
    use std::sync::Mutex;

    /// Send stream that accepts every write and records each `set_priority` call.
    struct RecordingSend {
        set_priority_calls: Arc<Mutex<Vec<i32>>>,
    }

    impl AsyncWrite for RecordingSend {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl QuicTransportSend for RecordingSend {
        fn reset(&mut self, _code: u64) {}

        fn set_priority(&mut self, priority: i32) {
            self.set_priority_calls.lock().unwrap().push(priority);
        }
    }

    #[test]
    fn applies_initial_and_updates_only_on_change() {
        let calls = Arc::new(Mutex::new(Vec::new()));
        let registry = PriorityRegistry::default();
        let slot = registry.register(4);
        let mut stream = PrioritizedStream::new(
            RecordingSend {
                set_priority_calls: calls.clone(),
            },
            slot,
            4,
        );
        registry.apply(4, -6, false);
        block_on(stream.write_all(b"a")).unwrap();
        registry.apply(4, -1, true);
        block_on(stream.write_all(b"b")).unwrap();
        block_on(stream.write_all(b"c")).unwrap();
        assert_eq!(*calls.lock().unwrap(), vec![-6, -1]);
    }

    #[test]
    fn routes_by_stream_id_and_drops_initial_without_a_stream() {
        let registry = PriorityRegistry::default();
        let slot4 = registry.register(4);
        let slot8 = registry.register(8);
        registry.apply(4, -6, true);
        assert_eq!(slot4.load(), -6);
        assert_eq!(
            slot8.load(),
            0,
            "a signal for stream 4 must not touch stream 8"
        );

        registry.apply(999, -3, false);
        assert!(
            registry.0.read().unwrap().pending.is_empty(),
            "an initial priority for an unknown stream must not be buffered"
        );
        assert_eq!(
            registry.register(999).load(),
            0,
            "a dropped initial priority must not surface when the stream opens"
        );
    }

    #[test]
    fn buffered_update_survives_open_and_outranks_a_later_initial() {
        let registry = PriorityRegistry::default();
        registry.apply(4, -1, true);
        let slot = registry.register(4);
        assert_eq!(slot.load(), -1);
        registry.apply(4, -6, false);
        assert_eq!(slot.load(), -1);
        registry.apply(4, -8, true);
        assert_eq!(slot.load(), -8);
    }

    #[test]
    fn update_outranks_initial_when_both_target_a_live_stream() {
        let registry = PriorityRegistry::default();
        let slot = registry.register(4);
        registry.apply(4, -1, true);
        registry.apply(4, -6, false);
        assert_eq!(slot.load(), -1);
    }

    #[test]
    fn pending_updates_are_bounded_but_existing_entries_stay_writable() {
        let registry = PriorityRegistry::default();
        let capacity = MAX_PENDING_PRIORITY_UPDATES as u64;
        for index in 0..capacity {
            registry.apply(index * 4, -1, true);
        }

        // The table is full: a new id is dropped...
        registry.apply(capacity * 4, -2, true);
        // ...but an id that is already buffered can still be overwritten.
        registry.apply(0, -3, true);

        assert_eq!(
            registry.0.read().unwrap().pending.len(),
            MAX_PENDING_PRIORITY_UPDATES
        );
        assert_eq!(registry.register(capacity * 4).load(), 0);
        assert_eq!(registry.register(0).load(), -3);
    }

    #[test]
    fn deregistered_stream_no_longer_receives_signals() {
        let registry = PriorityRegistry::default();
        let slot = registry.register(4);
        registry.deregister(4);
        registry.apply(4, -6, false);
        assert_eq!(slot.load(), 0);
    }

    #[test]
    fn changed_value_wakes_registered_task_unchanged_does_not() {
        use std::{
            sync::atomic::AtomicUsize,
            task::{Wake, Waker},
        };

        struct CountingWaker(AtomicUsize);

        impl Wake for CountingWaker {
            fn wake(self: Arc<Self>) {
                self.wake_by_ref();
            }

            fn wake_by_ref(self: &Arc<Self>) {
                self.0.fetch_add(1, Ordering::Relaxed);
            }
        }

        let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
        let waker = Waker::from(counter.clone());
        let cx = Context::from_waker(&waker);
        let wakes = || counter.0.load(Ordering::Relaxed);
        let slot = PrioritySlot::default();
        slot.register(&cx);
        slot.store_update(-6);
        assert_eq!(wakes(), 1);
        slot.register(&cx);
        slot.store_update(-6);
        assert_eq!(wakes(), 1);
        slot.store_update(-1);
        assert_eq!(wakes(), 2);
    }
}