use bytes::Bytes;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::Notify;
use tokio::sync::Semaphore;
use crate::error::{Error, Result};
use crate::transport::h3::native::data_frame_encoded_len;
/// One item queued on a tunnel's outbound channel by [`H3Tunnel::send_bytes`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct H3TunnelOutbound {
/// Payload bytes to send; empty when the item only carries `fin` (see `H3Tunnel::close_send`).
pub bytes: Bytes,
/// `true` when the caller asked for the send side to be finished with this item.
pub fin: bool,
}
/// Inbound event delivered to the tunnel owner through [`H3Tunnel::recv_event`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum H3TunnelEvent {
/// A chunk of inbound payload bytes.
Data(Bytes),
/// End of the inbound stream; `H3Tunnel::recv_bytes` maps this to `None`.
EndStream,
/// The tunnel was reset; the string is the reason that `H3Tunnel::recv_bytes` puts in its error.
Reset(String),
/// The tunnel was closed by a GOAWAY carrying `id`; `H3Tunnel::recv_bytes` turns this into an error.
GoAway { id: u64 },
}
/// Upper bound that `H3TunnelCredit::new` applies to the outbound byte budget.
/// Keeping budgets at or below `u32::MAX` is what makes the `as u32` permit-count casts lossless.
pub(crate) const MAX_TUNNEL_OUTBOUND_BYTE_BUDGET: usize = u32::MAX as usize;
/// Upper bound that `H3TunnelCredit::new` applies to the inbound byte budget.
pub(crate) const MAX_TUNNEL_INBOUND_BYTE_BUDGET: usize = u32::MAX as usize;
/// Point-in-time snapshot of a tunnel's byte budgets, as returned by `H3Tunnel::capacity`.
///
/// All fields are zero (the `Default`) for tunnels created without credit tracking.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct H3TunnelCapacity {
/// Configured outbound byte budget.
pub outbound_budget: usize,
/// Outbound permits currently available, clamped to `outbound_budget`.
pub outbound_available_bytes: usize,
/// `outbound_budget - outbound_available_bytes` (saturating).
pub outbound_pending_bytes: usize,
/// Configured inbound byte budget.
pub inbound_budget: usize,
/// Inbound permits currently available, clamped to `inbound_budget`.
pub inbound_available_bytes: usize,
/// `inbound_budget - inbound_available_bytes` (saturating).
pub inbound_pending_bytes: usize,
}
/// Shared byte-credit state for one tunnel: one semaphore per direction, where each permit
/// stands for one byte of budget.
#[derive(Debug)]
pub(crate) struct H3TunnelCredit {
/// Running total added by `H3Tunnel::release_recv_bytes` (encoded DATA frame lengths) and
/// reset to zero by `take_released_recv_bytes`.
released_recv_bytes: AtomicUsize,
/// Signalled by `H3Tunnel` after each received event.
/// NOTE(review): presumably wakes the connection driver task — confirm against the driver.
driver_notify: Arc<Notify>,
/// Outbound credit; `H3Tunnel::send_bytes` acquires and forgets permits from it.
send_semaphore: Arc<Semaphore>,
/// Outbound budget after clamping to `MAX_TUNNEL_OUTBOUND_BYTE_BUDGET`.
send_budget: usize,
/// Inbound credit; reserved by `try_reserve_inbound_bytes`, returned by `release_inbound_bytes`.
recv_semaphore: Arc<Semaphore>,
/// Inbound budget after clamping to `MAX_TUNNEL_INBOUND_BYTE_BUDGET`.
recv_budget: usize,
}
impl H3TunnelCredit {
    /// Builds the shared credit state for one tunnel.
    ///
    /// Both budgets are clamped to their `MAX_TUNNEL_*_BYTE_BUDGET` constants, and each
    /// semaphore starts with exactly its (clamped) budget in permits.
    pub(crate) fn new(
        driver_notify: Arc<Notify>,
        send_budget: usize,
        recv_budget: usize,
    ) -> Arc<Self> {
        let send_budget = MAX_TUNNEL_OUTBOUND_BYTE_BUDGET.min(send_budget);
        let recv_budget = MAX_TUNNEL_INBOUND_BYTE_BUDGET.min(recv_budget);
        let send_semaphore = Arc::new(Semaphore::new(send_budget));
        let recv_semaphore = Arc::new(Semaphore::new(recv_budget));
        Arc::new(Self {
            released_recv_bytes: AtomicUsize::new(0),
            driver_notify,
            send_semaphore,
            send_budget,
            recv_semaphore,
            recv_budget,
        })
    }

    /// Returns the accumulated released-receive byte count and resets it to zero.
    pub(crate) fn take_released_recv_bytes(&self) -> usize {
        let counter = &self.released_recv_bytes;
        counter.swap(0, Ordering::Relaxed)
    }

    /// Returns `bytes` of outbound credit; a single call adds at most `send_budget` permits.
    pub(crate) fn release_send_bytes(&self, bytes: usize) {
        if bytes > 0 {
            self.send_semaphore
                .add_permits(self.send_budget.min(bytes));
        }
    }

    /// Tries to reserve inbound credit for `bytes` without waiting.
    ///
    /// Zero bytes always succeed. Otherwise at most `recv_budget` permits are requested,
    /// and the reservation is kept (the permit is forgotten) when it succeeds.
    pub(crate) fn try_reserve_inbound_bytes(&self, bytes: usize) -> bool {
        if bytes == 0 {
            return true;
        }
        // Lossless: `recv_budget` was clamped to `u32::MAX` in `new`.
        let wanted = self.recv_budget.min(bytes) as u32;
        if let Ok(reservation) = self.recv_semaphore.try_acquire_many(wanted) {
            reservation.forget();
            true
        } else {
            false
        }
    }

    /// Returns `bytes` of inbound credit; a single call adds at most `recv_budget` permits.
    pub(crate) fn release_inbound_bytes(&self, bytes: usize) {
        if bytes > 0 {
            let returned = self.recv_budget.min(bytes);
            self.recv_semaphore.add_permits(returned);
        }
    }

    /// Reports whether at least one inbound permit is currently available.
    pub(crate) fn has_inbound_capacity(&self) -> bool {
        self.recv_semaphore.available_permits() != 0
    }

    /// Takes a snapshot of both directions' budgets.
    ///
    /// Available counts are clamped to the budget, so the pending figures never underflow.
    pub(crate) fn capacity(&self) -> H3TunnelCapacity {
        let outbound_budget = self.send_budget;
        let inbound_budget = self.recv_budget;
        let outbound_available_bytes =
            outbound_budget.min(self.send_semaphore.available_permits());
        let inbound_available_bytes =
            inbound_budget.min(self.recv_semaphore.available_permits());
        H3TunnelCapacity {
            outbound_budget,
            outbound_available_bytes,
            outbound_pending_bytes: outbound_budget.saturating_sub(outbound_available_bytes),
            inbound_budget,
            inbound_available_bytes,
            inbound_pending_bytes: inbound_budget.saturating_sub(inbound_available_bytes),
        }
    }

    /// Test-only view of the raw (unclamped) outbound permit count.
    #[cfg(test)]
    pub(crate) fn available_send_permits(&self) -> usize {
        self.send_semaphore.available_permits()
    }

    /// Test-only view of the raw (unclamped) inbound permit count.
    #[cfg(test)]
    pub(crate) fn available_inbound_permits(&self) -> usize {
        self.recv_semaphore.available_permits()
    }
}
/// Inbound event source for a tunnel, wrapping either flavour of tokio mpsc receiver.
#[derive(Debug)]
enum H3TunnelInboundReceiver {
/// Bounded channel; used by `H3Tunnel::new` (no credit tracking).
Bounded(mpsc::Receiver<Result<H3TunnelEvent>>),
/// Unbounded channel; used by `H3Tunnel::new_with_credit`.
Unbounded(mpsc::UnboundedReceiver<Result<H3TunnelEvent>>),
}
impl H3TunnelInboundReceiver {
async fn recv(&mut self) -> Option<Result<H3TunnelEvent>> {
match self {
Self::Bounded(rx) => rx.recv().await,
Self::Unbounded(rx) => rx.recv().await,
}
}
}
/// Handle for one HTTP/3 tunnel: queues outbound bytes and receives inbound events.
#[derive(Debug)]
pub struct H3Tunnel {
/// Queue of outbound items; `send_bytes` pushes onto it.
outbound_tx: mpsc::UnboundedSender<H3TunnelOutbound>,
/// Source of inbound events read by `recv_event`.
inbound_rx: H3TunnelInboundReceiver,
/// Byte-credit state; `None` for tunnels built with `H3Tunnel::new`.
credit: Option<Arc<H3TunnelCredit>>,
}
impl H3Tunnel {
pub fn new(
outbound_tx: mpsc::UnboundedSender<H3TunnelOutbound>,
inbound_rx: mpsc::Receiver<Result<H3TunnelEvent>>,
) -> Self {
Self {
outbound_tx,
inbound_rx: H3TunnelInboundReceiver::Bounded(inbound_rx),
credit: None,
}
}
pub(crate) fn new_with_credit(
outbound_tx: mpsc::UnboundedSender<H3TunnelOutbound>,
inbound_rx: mpsc::UnboundedReceiver<Result<H3TunnelEvent>>,
credit: Arc<H3TunnelCredit>,
) -> Self {
Self {
outbound_tx,
inbound_rx: H3TunnelInboundReceiver::Unbounded(inbound_rx),
credit: Some(credit),
}
}
pub async fn send_bytes(&self, bytes: Bytes, fin: bool) -> Result<()> {
if !bytes.is_empty() {
if let Some(credit) = self.credit.as_ref() {
let to_acquire = bytes.len().min(credit.send_budget);
let permit = credit
.send_semaphore
.acquire_many(to_acquire as u32)
.await
.map_err(|_| Error::HttpProtocol("H3 tunnel outbound credit closed".into()))?;
permit.forget();
}
}
self.outbound_tx
.send(H3TunnelOutbound { bytes, fin })
.map_err(|_| Error::HttpProtocol("H3 tunnel outbound channel closed".into()))
}
pub async fn close_send(&self) -> Result<()> {
self.send_bytes(Bytes::new(), true).await
}
pub fn capacity(&self) -> H3TunnelCapacity {
self.credit
.as_ref()
.map(|credit| credit.capacity())
.unwrap_or_default()
}
pub async fn recv_event(&mut self) -> Option<Result<H3TunnelEvent>> {
let event = self.inbound_rx.recv().await?;
if let Ok(H3TunnelEvent::Data(bytes)) = &event {
self.release_recv_bytes(bytes.len());
} else if let Some(credit) = self.credit.as_ref() {
credit.driver_notify.notify_one();
}
Some(event)
}
pub async fn recv_bytes(&mut self) -> Option<Result<Bytes>> {
match self.recv_event().await? {
Ok(H3TunnelEvent::Data(bytes)) => Some(Ok(bytes)),
Ok(H3TunnelEvent::EndStream) => None,
Ok(H3TunnelEvent::Reset(reason)) => Some(Err(Error::HttpProtocol(format!(
"H3 tunnel reset: {reason}"
)))),
Ok(H3TunnelEvent::GoAway { id }) => Some(Err(Error::HttpProtocol(format!(
"H3 tunnel closed by GOAWAY id={id}"
)))),
Err(err) => Some(Err(err)),
}
}
fn release_recv_bytes(&self, released: usize) {
let Some(credit) = self.credit.as_ref() else {
return;
};
if released > 0 {
credit.release_inbound_bytes(released);
credit
.released_recv_bytes
.fetch_add(data_frame_encoded_len(released), Ordering::Relaxed);
}
credit.driver_notify.notify_one();
}
}
// Unit tests for the tunnel's byte-credit accounting and its send/receive paths.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex as TokioMutex;
use tokio::time::{sleep, timeout};
// Test helper that drains the outbound channel and hands send credit back in chunks,
// recording the largest in-flight byte count it observes.
struct OutboundDrainer {
outbound_rx: TokioMutex<mpsc::UnboundedReceiver<H3TunnelOutbound>>,
credit: Arc<H3TunnelCredit>,
// Maximum number of bytes released per step.
chunk_size: usize,
// Delay applied before each step; zero disables the sleep.
per_chunk_delay: Duration,
// Highest `budget - available_send_permits()` seen when an item was dequeued.
peak_in_flight: Arc<AtomicUsize>,
}
impl OutboundDrainer {
fn new(
outbound_rx: mpsc::UnboundedReceiver<H3TunnelOutbound>,
credit: Arc<H3TunnelCredit>,
chunk_size: usize,
per_chunk_delay: Duration,
) -> Arc<Self> {
Arc::new(Self {
outbound_rx: TokioMutex::new(outbound_rx),
credit,
chunk_size,
per_chunk_delay,
peak_in_flight: Arc::new(AtomicUsize::new(0)),
})
}
// Drains until an empty `fin` item arrives or the channel closes, returning every item seen.
async fn run(self: Arc<Self>) -> Vec<H3TunnelOutbound> {
let mut collected = Vec::new();
let mut rx = self.outbound_rx.lock().await;
while let Some(outbound) = rx.recv().await {
let budget = self.credit.send_budget;
// Mirrors `send_bytes`, which acquires at most `send_budget` permits per call.
let acquired = outbound.bytes.len().min(budget);
let in_flight = budget - self.credit.available_send_permits();
self.peak_in_flight
.fetch_max(in_flight, AtomicOrdering::SeqCst);
let mut released = 0usize;
let total = outbound.bytes.len();
// Empty items consumed no credit, so there is nothing to release for them.
if total == 0 {
collected.push(outbound.clone());
if outbound.fin {
return collected;
}
continue;
}
let mut offset = 0usize;
while offset < total {
let chunk = self.chunk_size.min(total - offset);
if !self.per_chunk_delay.is_zero() {
sleep(self.per_chunk_delay).await;
}
// Never release more than was acquired, even when the payload exceeds the budget.
let release_now = chunk.min(acquired.saturating_sub(released));
if release_now > 0 {
self.credit.release_send_bytes(release_now);
released = released.saturating_add(release_now);
}
offset += chunk;
}
// Return whatever part of the acquisition the chunk loop did not cover.
if released < acquired {
self.credit.release_send_bytes(acquired - released);
}
collected.push(outbound);
}
collected
}
}
// Builds a credit-tracked tunnel using `budget` for both directions. The inbound sender is
// dropped when this function returns, so these tunnels never receive inbound events.
fn make_tunnel(
budget: usize,
) -> (
H3Tunnel,
mpsc::UnboundedReceiver<H3TunnelOutbound>,
Arc<H3TunnelCredit>,
) {
let (outbound_tx, outbound_rx) = mpsc::unbounded_channel::<H3TunnelOutbound>();
let (_inbound_tx, inbound_rx) = mpsc::unbounded_channel::<Result<H3TunnelEvent>>();
let credit = H3TunnelCredit::new(Arc::new(Notify::new()), budget, budget);
let tunnel = H3Tunnel::new_with_credit(outbound_tx, inbound_rx, credit.clone());
(tunnel, outbound_rx, credit)
}
// A payload four times the budget must still go through, because `send_bytes` acquires at
// most `send_budget` permits per call.
#[tokio::test(flavor = "current_thread", start_paused = true)]
async fn send_larger_than_budget_blocks_until_consumer_drains() {
let budget = 64 * 1024;
let (tunnel, outbound_rx, credit) = make_tunnel(budget);
// Consume every send permit up front and forget it so it is not returned on drop.
let prefill = credit
.send_semaphore
.clone()
.try_acquire_many_owned(budget as u32)
.expect("must reserve every permit before the producer starts");
std::mem::forget(prefill);
assert_eq!(credit.available_send_permits(), 0);
let payload_size = 4 * budget;
let payload = Bytes::from(vec![0x42u8; payload_size]);
let follow_up = Bytes::from(vec![0x21u8; budget / 2]);
let drainer = OutboundDrainer::new(
outbound_rx,
credit.clone(),
budget / 8,
Duration::from_millis(1),
);
let drainer_handle = {
let drainer = drainer.clone();
tokio::spawn(async move { drainer.run().await })
};
let tunnel = Arc::new(tunnel);
let producer = {
let tunnel = tunnel.clone();
let payload = payload.clone();
let follow_up = follow_up.clone();
tokio::spawn(async move {
tunnel
.send_bytes(payload, false)
.await
.expect("oversized send_bytes must complete once credit is released");
tunnel
.send_bytes(follow_up, false)
.await
.expect("follow-up send_bytes must complete");
tunnel
.send_bytes(Bytes::new(), true)
.await
.expect("close_send must complete even after credit is drained");
})
};
// Hand back the pre-reserved budget so the producer can make progress.
credit.release_send_bytes(budget);
timeout(Duration::from_secs(5), producer)
.await
.expect("producer must not deadlock when sending more than budget")
.expect("producer task panicked");
let collected = drainer_handle.await.expect("drainer task did not panic");
let total_collected: usize = collected.iter().map(|o| o.bytes.len()).sum();
assert_eq!(
total_collected,
payload_size + follow_up.len(),
"drainer must have observed every byte the producer queued"
);
assert!(
collected.last().expect("at least one outbound").fin,
"last drained outbound must carry the producer's close_send fin"
);
}
// Two interleaved producers sharing one tunnel must never push the observed in-flight byte
// count above the configured budget.
#[tokio::test(flavor = "current_thread", start_paused = true)]
async fn two_producers_respect_total_byte_budget() {
let budget = 8 * 1024;
let (tunnel, outbound_rx, credit) = make_tunnel(budget);
let drainer =
OutboundDrainer::new(outbound_rx, credit.clone(), 512, Duration::from_millis(2));
let peak_in_flight = drainer.peak_in_flight.clone();
let drainer_handle = {
let drainer = drainer.clone();
tokio::spawn(async move { drainer.run().await })
};
let tunnel_ref = &tunnel;
let producer_a = async move {
for _ in 0..6 {
tunnel_ref
.send_bytes(Bytes::from(vec![1u8; 2 * 1024]), false)
.await
.expect("producer A send_bytes");
}
};
let producer_b = async move {
for _ in 0..6 {
tunnel_ref
.send_bytes(Bytes::from(vec![2u8; 2 * 1024]), false)
.await
.expect("producer B send_bytes");
}
};
tokio::join!(producer_a, producer_b);
// The empty fin item makes the drainer return.
tunnel
.send_bytes(Bytes::new(), true)
.await
.expect("final fin send");
drainer_handle.await.expect("drainer did not panic");
let observed_peak = peak_in_flight.load(AtomicOrdering::SeqCst);
assert!(
observed_peak <= budget,
"peak in-flight bytes {observed_peak} must not exceed the configured budget {budget}",
);
assert!(
observed_peak >= 2 * 1024,
"peak in-flight should be at least one full producer chunk (was {observed_peak})",
);
}
// `close_send` sends an empty payload, which must bypass the credit semaphore entirely.
#[tokio::test(start_paused = false)]
async fn close_send_works_when_budget_is_exhausted() {
let budget = 4 * 1024;
let (tunnel, mut outbound_rx, credit) = make_tunnel(budget);
let drained = credit
.send_semaphore
.clone()
.try_acquire_many_owned(budget as u32)
.expect("must reserve every permit");
std::mem::forget(drained);
assert_eq!(credit.available_send_permits(), 0);
timeout(Duration::from_secs(2), tunnel.close_send())
.await
.expect("close_send must not block on the credit semaphore when budget is exhausted")
.expect("close_send returned an error");
let queued = outbound_rx
.recv()
.await
.expect("close_send must enqueue an outbound with fin");
assert!(queued.bytes.is_empty(), "close_send must send empty bytes");
assert!(queued.fin, "close_send must mark the outbound as fin");
assert!(outbound_rx.try_recv().is_err());
}
// A single oversized release must add no more than `send_budget` permits.
#[test]
fn release_send_bytes_is_capped_at_send_budget() {
let budget = 16 * 1024;
let credit = H3TunnelCredit::new(Arc::new(Notify::new()), budget, budget);
let permit = credit
.send_semaphore
.clone()
.try_acquire_many_owned(budget as u32)
.expect("reserve every permit");
std::mem::forget(permit);
assert_eq!(credit.available_send_permits(), 0);
credit.release_send_bytes(4 * budget);
assert_eq!(credit.available_send_permits(), budget);
}
// `capacity()` must report budget, available and pending bytes for both directions.
#[test]
fn capacity_snapshot_reports_tunnel_backpressure_budgets() {
let budget = 16 * 1024;
let (tunnel, _outbound_rx, credit) = make_tunnel(budget);
let send_permit = credit
.send_semaphore
.clone()
.try_acquire_many_owned(4 * 1024)
.expect("reserve outbound permits");
std::mem::forget(send_permit);
assert!(credit.try_reserve_inbound_bytes(2 * 1024));
let capacity = tunnel.capacity();
assert_eq!(capacity.outbound_budget, budget);
assert_eq!(capacity.outbound_available_bytes, 12 * 1024);
assert_eq!(capacity.outbound_pending_bytes, 4 * 1024);
assert_eq!(capacity.inbound_budget, budget);
assert_eq!(capacity.inbound_available_bytes, 14 * 1024);
assert_eq!(capacity.inbound_pending_bytes, 2 * 1024);
}
// Receiving a `Data` event must return the payload's inbound permits and record the
// encoded DATA frame length (payload plus framing overhead) as released.
#[tokio::test]
async fn recv_event_releases_encoded_data_frame_credit() {
let (_outbound_tx, outbound_rx) = mpsc::unbounded_channel();
// The outbound side is not exercised here.
drop(outbound_rx);
let (inbound_tx, inbound_rx) = mpsc::unbounded_channel();
let credit = H3TunnelCredit::new(Arc::new(Notify::new()), 1024, 1024);
let mut tunnel = H3Tunnel::new_with_credit(_outbound_tx, inbound_rx, credit.clone());
assert!(credit.try_reserve_inbound_bytes(64));
inbound_tx
.send(Ok(H3TunnelEvent::Data(Bytes::from(vec![0x42; 64]))))
.expect("queue inbound data");
let event = tunnel.recv_event().await.expect("inbound event");
assert!(matches!(event, Ok(H3TunnelEvent::Data(bytes)) if bytes.len() == 64));
assert_eq!(
credit.take_released_recv_bytes(),
67,
"64 payload bytes must release DATA frame type + two-byte length overhead"
);
assert_eq!(credit.available_inbound_permits(), 1024);
}
}