1use crate::constants::DEFAULT_CLIENT_DATA_CHANNEL_CAPACITY;
2use crate::protocol::Data;
3use std::fmt::{self, Debug, Formatter};
4use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
5use tokio::sync::{mpsc, Notify};
6use tracing::trace;
7
/// Byte-capacity bookkeeping for one data channel.
///
/// Senders reserve capacity with `DataChannel::wait_for_capacity`, `DataChannel::add_capacity`
/// returns capacity, and `DataChannel::close` shuts the channel down.
pub struct DataChannel {
    /// Channel identifier; shown in this type's `Debug` output.
    pub id: u32,
    /// Set to `true` by `close()`; once set, `wait_for_capacity` fails and `closed()` resolves.
    pub closed: AtomicBool,
    /// Sender for outgoing `Data`; `wait_for_capacity` also gives up when this sender reports
    /// that its channel is closed.
    pub data_tx: mpsc::Sender<Data>,
    /// Bytes of capacity currently available for reservation.
    pub capacity: AtomicU32,
    /// Woken by `add_capacity` and `close` so that waiting tasks re-check state.
    pub notify: Notify,
}
15
16impl DataChannel {
17 pub fn new_client(id: u32, data_tx: mpsc::Sender<Data>) -> DataChannel {
18 Self {
19 id,
20 closed: AtomicBool::new(false),
21 data_tx,
22 capacity: AtomicU32::new(DEFAULT_CLIENT_DATA_CHANNEL_CAPACITY),
23 notify: Notify::new(),
24 }
25 }
26
27 pub fn add_capacity(&self, amount: u32) {
28 trace!("{:?} acked, consumed {} bytes", self, amount);
29 self.capacity.fetch_add(amount, Ordering::SeqCst);
30 self.notify.notify_waiters();
31 }
32
33 pub async fn wait_for_capacity(&self, required: u32) -> Result<(), mpsc::error::SendError<()>> {
34 loop {
35 if self.closed.load(Ordering::SeqCst) {
36 trace!("{:?} channel is closed, cannot wait for capacity", self);
37 return Err(mpsc::error::SendError(()));
38 }
39 let current = self.capacity.load(Ordering::SeqCst);
40 trace!(
41 "{:?} checking capacity: {} bytes available, required: {} bytes",
42 self,
43 current,
44 required
45 );
46 if current >= required {
47 match self.capacity.compare_exchange_weak(
49 current,
50 current - required,
51 Ordering::SeqCst,
52 Ordering::SeqCst,
53 ) {
54 Ok(_) => {
55 trace!(
56 "{:?} has sufficient capacity: {} bytes available, consuming {} bytes",
57 self,
58 current,
59 required
60 );
61 return Ok(());
62 }
63 Err(_) => continue, }
65 }
66
67 trace!(
68 "{:?} insufficient capacity: {} bytes available, waiting for {} more",
69 self,
70 current,
71 required
72 );
73
74 tokio::select! {
75 _ = self.notify.notified() => {
76 trace!("{:?} notified, checking capacity again", self);
77 }
78 _ = self.data_tx.closed() => {
79 return Err(mpsc::error::SendError(()));
80 }
81 }
82 }
83 }
84
85 pub fn close(&self) {
86 trace!("{:?} closing channel", self);
87 self.closed.store(true, Ordering::SeqCst);
88 self.notify.notify_waiters();
89 }
90
91 pub async fn closed(&self) {
92 trace!("{:?} waiting for channel to close", self);
93 while !self.closed.load(Ordering::SeqCst) {
94 self.notify.notified().await;
95 }
96 }
97}
98
99impl Debug for DataChannel {
100 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
101 f.debug_struct("DataChannel")
102 .field("id", &self.id)
103 .field("capacity", &self.capacity)
104 .finish()
105 }
106}