1use crossbeam_queue::ArrayQueue;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
4
5use super::{Backpressure, ChannelRecv, ChannelSend, ChannelStats, CloseBehavior, RecvOutcome};
6
7#[cfg(feature = "metrics")]
8use crate::metrics::MetricsSink;
9#[cfg(feature = "async-channels")]
10use tokio::sync::Notify;
11
12struct BoundedInner<T> {
13 queue: ArrayQueue<T>,
14 closed: AtomicBool,
15 senders: AtomicUsize,
16 receivers: AtomicUsize,
17 enqueued: AtomicU64,
18 dropped: AtomicU64,
19 drained: AtomicU64,
20 depth: AtomicUsize,
21 close_behavior: CloseBehavior,
22 #[cfg(feature = "async-channels")]
23 notify: Arc<Notify>,
24 #[cfg(feature = "metrics")]
25 metrics: Option<Arc<dyn MetricsSink>>,
26}
27
28impl<T> BoundedInner<T> {
29 fn new(capacity: usize, close_behavior: CloseBehavior) -> Self {
30 Self {
31 queue: ArrayQueue::new(capacity),
32 closed: AtomicBool::new(false),
33 senders: AtomicUsize::new(1),
34 receivers: AtomicUsize::new(1),
35 enqueued: AtomicU64::new(0),
36 dropped: AtomicU64::new(0),
37 drained: AtomicU64::new(0),
38 depth: AtomicUsize::new(0),
39 close_behavior,
40 #[cfg(feature = "async-channels")]
41 notify: Arc::new(Notify::new()),
42 #[cfg(feature = "metrics")]
43 metrics: None,
44 }
45 }
46
47 #[cfg(feature = "metrics")]
48 fn new_with_metrics(
49 capacity: usize,
50 close_behavior: CloseBehavior,
51 metrics: Arc<dyn MetricsSink>,
52 ) -> Self {
53 Self {
54 queue: ArrayQueue::new(capacity),
55 closed: AtomicBool::new(false),
56 senders: AtomicUsize::new(1),
57 receivers: AtomicUsize::new(1),
58 enqueued: AtomicU64::new(0),
59 dropped: AtomicU64::new(0),
60 drained: AtomicU64::new(0),
61 depth: AtomicUsize::new(0),
62 close_behavior,
63 #[cfg(feature = "async-channels")]
64 notify: Arc::new(Notify::new()),
65 metrics: Some(metrics),
66 }
67 }
68
69 fn mark_closed(&self) {
70 self.closed.store(true, Ordering::Release);
71 }
72
73 fn try_close(&self) {
74 match self.close_behavior {
75 CloseBehavior::FailFast => {
76 if self.senders.load(Ordering::Acquire) == 0
77 || self.receivers.load(Ordering::Acquire) == 0
78 {
79 self.mark_closed();
80 }
81 }
82 CloseBehavior::DrainUntilSendersDone => {
83 if self.senders.load(Ordering::Acquire) == 0 {
84 self.mark_closed();
85 }
86 }
87 }
88 }
89
90 #[cfg(feature = "metrics")]
91 fn inc(&self, key: &'static str) {
92 if let Some(metrics) = &self.metrics {
93 metrics.increment(key, 1);
94 }
95 }
96}
97
98pub struct BoundedSender<T> {
99 inner: Arc<BoundedInner<T>>,
100}
101
102impl<T> Clone for BoundedSender<T> {
103 fn clone(&self) -> Self {
104 self.inner.senders.fetch_add(1, Ordering::Relaxed);
105 Self {
106 inner: Arc::clone(&self.inner),
107 }
108 }
109}
110
111impl<T> Drop for BoundedSender<T> {
112 fn drop(&mut self) {
113 self.inner.senders.fetch_sub(1, Ordering::Relaxed);
114 self.inner.try_close();
115 }
116}
117
118impl<T> std::fmt::Debug for BoundedSender<T> {
119 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 f.debug_struct("BoundedSender").finish_non_exhaustive()
121 }
122}
123
124pub struct BoundedReceiver<T> {
125 inner: Arc<BoundedInner<T>>,
126}
127
128impl<T> Clone for BoundedReceiver<T> {
129 fn clone(&self) -> Self {
130 self.inner.receivers.fetch_add(1, Ordering::Relaxed);
131 Self {
132 inner: Arc::clone(&self.inner),
133 }
134 }
135}
136
137impl<T> Drop for BoundedReceiver<T> {
138 fn drop(&mut self) {
139 self.inner.receivers.fetch_sub(1, Ordering::Relaxed);
140 self.inner.try_close();
141 }
142}
143
144pub fn bounded<T>(capacity: usize) -> (BoundedSender<T>, BoundedReceiver<T>) {
145 assert!(capacity > 0, "capacity must be greater than zero");
146 let inner = Arc::new(BoundedInner::new(capacity, CloseBehavior::FailFast));
147 (
148 BoundedSender {
149 inner: Arc::clone(&inner),
150 },
151 BoundedReceiver { inner },
152 )
153}
154
155pub fn bounded_with_behavior<T>(
156 capacity: usize,
157 close_behavior: CloseBehavior,
158) -> (BoundedSender<T>, BoundedReceiver<T>) {
159 assert!(capacity > 0, "capacity must be greater than zero");
160 let inner = Arc::new(BoundedInner::new(capacity, close_behavior));
161 (
162 BoundedSender {
163 inner: Arc::clone(&inner),
164 },
165 BoundedReceiver { inner },
166 )
167}
168
/// Creates a [`CloseBehavior::FailFast`] bounded channel that reports its
/// counters to `metrics`.
///
/// # Panics
///
/// Panics if `capacity` is zero.
#[cfg(feature = "metrics")]
pub fn bounded_with_metrics<T>(
    capacity: usize,
    metrics: Arc<dyn MetricsSink>,
) -> (BoundedSender<T>, BoundedReceiver<T>) {
    // Same assertion and construction, with the default policy filled in.
    bounded_with_metrics_and_behavior(capacity, CloseBehavior::FailFast, metrics)
}
187
/// Creates a bounded channel with an explicit close policy that reports its
/// counters to `metrics`.
///
/// # Panics
///
/// Panics if `capacity` is zero.
#[cfg(feature = "metrics")]
pub fn bounded_with_metrics_and_behavior<T>(
    capacity: usize,
    close_behavior: CloseBehavior,
    metrics: Arc<dyn MetricsSink>,
) -> (BoundedSender<T>, BoundedReceiver<T>) {
    assert!(capacity > 0, "capacity must be greater than zero");
    let state = BoundedInner::new_with_metrics(capacity, close_behavior, metrics);
    let shared = Arc::new(state);
    let sender = BoundedSender {
        inner: Arc::clone(&shared),
    };
    let receiver = BoundedReceiver { inner: shared };
    (sender, receiver)
}
207
208impl<T: Send> ChannelSend<T> for BoundedSender<T> {
209 fn send(&self, value: T) -> Backpressure {
210 match self.try_send_owned(value) {
211 Ok(()) => Backpressure::Ok,
212 Err((Backpressure::Closed, _)) => {
213 #[cfg(feature = "metrics")]
214 self.inner.inc("channel.bounded.closed");
215 Backpressure::Closed
216 }
217 Err((Backpressure::Full, _)) => {
218 self.inner.dropped.fetch_add(1, Ordering::Relaxed);
220 #[cfg(feature = "metrics")]
221 self.inner.inc("channel.bounded.full");
222 Backpressure::Full
223 }
224 Err((Backpressure::Ok, _)) => Backpressure::Ok,
225 }
226 }
227}
228
229impl<T> BoundedSender<T> {
230 pub fn try_send_owned(&self, value: T) -> Result<(), (Backpressure, T)> {
232 if self.inner.closed.load(Ordering::Acquire) {
233 return Err((Backpressure::Closed, value));
234 }
235
236 match self.inner.queue.push(value) {
237 Ok(()) => {
238 self.inner.enqueued.fetch_add(1, Ordering::Relaxed);
239 self.inner.depth.fetch_add(1, Ordering::Relaxed);
240 Ok(())
241 }
242 Err(v) => Err((Backpressure::Full, v)),
243 }
244 }
245}
246
247impl<T: Send> ChannelRecv<T> for BoundedReceiver<T> {
248 fn try_recv(&self) -> RecvOutcome<T> {
249 match self.inner.queue.pop() {
250 Some(v) => {
251 self.inner.drained.fetch_add(1, Ordering::Relaxed);
252 self.inner.depth.fetch_sub(1, Ordering::Relaxed);
253 #[cfg(feature = "async-channels")]
254 self.inner.notify.notify_one();
255 RecvOutcome::Data(v)
256 }
257 None if self.inner.closed.load(Ordering::Acquire) => RecvOutcome::Closed,
258 None => RecvOutcome::Empty,
259 }
260 }
261}
262
263impl<T> BoundedReceiver<T> {
264 pub fn stats(&self) -> ChannelStats {
265 ChannelStats {
266 enqueued: self.inner.enqueued.load(Ordering::Relaxed),
267 dropped: self.inner.dropped.load(Ordering::Relaxed),
268 drained: self.inner.drained.load(Ordering::Relaxed),
269 depth: self.inner.depth.load(Ordering::Relaxed),
270 closed: self.inner.closed.load(Ordering::Relaxed),
271 }
272 }
273}
274
#[cfg(feature = "async-channels")]
impl<T> BoundedSender<T> {
    /// Returns a new handle to the channel's shared `Notify`, which receivers
    /// signal after each successful pop.
    pub(crate) fn async_notify(&self) -> Arc<Notify> {
        let notifier = &self.inner.notify;
        Arc::clone(notifier)
    }
}
281
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    /// A full queue rejects further sends and keeps what it already holds.
    #[test]
    fn bounded_capacity_respected() {
        let (sender, receiver) = bounded(1);
        assert_eq!(sender.send(1), Backpressure::Ok);
        assert_eq!(sender.send(2), Backpressure::Full);
        assert_eq!(receiver.try_recv(), RecvOutcome::Data(1));
        assert_eq!(receiver.try_recv(), RecvOutcome::Empty);
    }

    /// Dropping the only sender closes an empty channel.
    #[test]
    fn bounded_closed_when_senders_drop() {
        let (sender, receiver) = bounded::<u8>(1);
        drop(sender);
        assert_eq!(receiver.try_recv(), RecvOutcome::Closed);
    }

    /// Only the first `capacity` values survive, in send order.
    #[test]
    fn bounded_preserves_order_until_capacity() {
        let (sender, receiver) = bounded(3);
        (0..5).for_each(|n| {
            let _ = sender.send(n);
        });
        for expected in 0..3 {
            assert_eq!(receiver.try_recv(), RecvOutcome::Data(expected));
        }
        assert_eq!(receiver.try_recv(), RecvOutcome::Empty);
    }

    /// Cloned receivers share one queue, so each value is seen exactly once.
    #[test]
    fn bounded_multi_consumer_does_not_drop_when_space() {
        let (sender, first_rx) = bounded(4);
        let second_rx = first_rx.clone();
        for n in 0..2 {
            assert_eq!(sender.send(n), Backpressure::Ok);
        }
        assert_eq!(first_rx.try_recv(), RecvOutcome::Data(0));
        assert_eq!(second_rx.try_recv(), RecvOutcome::Data(1));
    }

    proptest! {
        // For any input, draining yields exactly the first `capacity` values.
        #[test]
        fn bounded_keeps_first_cap_messages_in_order(input in proptest::collection::vec(any::<u8>(), 1..20)) {
            let capacity = 5usize;
            let (sender, receiver) = bounded(capacity);
            for &byte in &input {
                let _ = sender.send(byte);
            }
            let drained: Vec<u8> = std::iter::from_fn(|| match receiver.try_recv() {
                RecvOutcome::Data(byte) => Some(byte),
                _ => None,
            })
            .collect();
            let expected: Vec<u8> = input.into_iter().take(capacity).collect();
            prop_assert_eq!(drained, expected);
        }
    }

    /// Four producers and two consumers move 200 values through 8 slots
    /// without losing any.
    #[test]
    fn bounded_mpmc_stress() {
        let (tx, rx) = bounded(8);
        let shared_tx = Arc::new(tx);
        let shared_rx = Arc::new(rx);
        let total = 4usize * 50usize;
        let seen = Arc::new(AtomicUsize::new(0));

        let producers: Vec<_> = (0..4)
            .map(|_| {
                let tx = Arc::clone(&shared_tx);
                thread::spawn(move || {
                    for i in 0..50u32 {
                        // Retry until the queue accepts the value.
                        while tx.send(i) != Backpressure::Ok {
                            thread::yield_now();
                        }
                    }
                })
            })
            .collect();

        let consumers: Vec<_> = (0..2)
            .map(|_| {
                let rx = Arc::clone(&shared_rx);
                let seen = Arc::clone(&seen);
                thread::spawn(move || loop {
                    match rx.try_recv() {
                        RecvOutcome::Data(_) => {
                            seen.fetch_add(1, Ordering::Relaxed);
                        }
                        RecvOutcome::Closed => break,
                        RecvOutcome::Empty if seen.load(Ordering::Relaxed) >= total => break,
                        RecvOutcome::Empty => thread::yield_now(),
                    }
                })
            })
            .collect();

        producers
            .into_iter()
            .for_each(|handle| handle.join().unwrap());
        // The producer threads have finished, so this releases the last
        // reference to the sender.
        drop(shared_tx);

        for handle in consumers {
            handle.join().unwrap();
        }
        assert_eq!(seen.load(Ordering::Relaxed), total);
    }
}
403
#[cfg(all(test, feature = "metrics"))]
mod metric_tests {
    use super::*;
    use crate::metrics::InMemoryMetrics;
    use std::sync::Arc;

    /// `send` reports one `full` and one `closed` event to the metrics sink.
    #[test]
    fn metrics_record_full_and_closed() {
        let recorder = Arc::new(InMemoryMetrics::default());
        let sink: Arc<dyn crate::metrics::MetricsSink> = recorder.clone();
        let (sender, receiver) = bounded_with_metrics(1, sink);

        // The second send hits the one-slot limit.
        assert_eq!(sender.send(1), Backpressure::Ok);
        assert_eq!(sender.send(2), Backpressure::Full);
        assert_eq!(recorder.counter("channel.bounded.full"), 1);

        // Under `FailFast`, losing the last receiver closes the channel.
        drop(receiver);
        assert_eq!(sender.send(3), Backpressure::Closed);
        assert_eq!(recorder.counter("channel.bounded.closed"), 1);
    }
}