1use std::{
2 pin::Pin,
3 sync::{Arc, RwLock},
4 task::{Context, Poll},
5};
6
7use futures::{Stream, stream::FusedStream};
8use log::trace;
9
10use crate::fork::Fork;
11
12pub struct CloneStream<BaseStream>
29where
30 BaseStream: Stream<Item: Clone>,
31{
32 pub(crate) fork: Arc<RwLock<Fork<BaseStream>>>,
33 pub id: usize,
35}
36
37impl<BaseStream> From<Fork<BaseStream>> for CloneStream<BaseStream>
38where
39 BaseStream: Stream<Item: Clone>,
40{
41 fn from(mut fork: Fork<BaseStream>) -> Self {
42 let id = fork
43 .clone_registry
44 .register()
45 .expect("Failed to register initial clone");
46
47 Self {
48 id,
49 fork: Arc::new(RwLock::new(fork)),
50 }
51 }
52}
53
54impl<BaseStream> Clone for CloneStream<BaseStream>
55where
56 BaseStream: Stream<Item: Clone>,
57{
58 fn clone(&self) -> Self {
68 let mut fork = self.fork.write().expect("Fork lock poisoned during clone");
69 let clone_id = fork
70 .clone_registry
71 .register()
72 .expect("Failed to register clone - clone limit exceeded");
73 drop(fork);
74
75 Self {
76 fork: self.fork.clone(),
77 id: clone_id,
78 }
79 }
80}
81
82impl<BaseStream> Stream for CloneStream<BaseStream>
83where
84 BaseStream: Stream<Item: Clone>,
85{
86 type Item = BaseStream::Item;
87
88 fn poll_next(self: Pin<&mut Self>, current_task: &mut Context) -> Poll<Option<Self::Item>> {
89 trace!("Polling next item for clone {}.", self.id);
90 let waker = current_task.waker();
91 let mut fork = self
92 .fork
93 .write()
94 .expect("Fork lock poisoned during poll_next");
95 fork.poll_clone(self.id, waker)
96 }
97
98 fn size_hint(&self) -> (usize, Option<usize>) {
99 let fork = self
100 .fork
101 .read()
102 .expect("Fork lock poisoned during size_hint");
103 let (lower, upper) = fork.size_hint();
104 let n_cached = fork.remaining_queued_items(self.id);
105 drop(fork);
106 (lower + n_cached, upper.map(|u| u + n_cached))
107 }
108}
109
110impl<BaseStream> FusedStream for CloneStream<BaseStream>
111where
112 BaseStream: FusedStream<Item: Clone>,
113{
114 fn is_terminated(&self) -> bool {
120 let fork = self
121 .fork
122 .read()
123 .expect("Fork lock poisoned during is_terminated");
124 fork.is_terminated() && fork.remaining_queued_items(self.id) == 0
125 }
126}
127
128impl<BaseStream> Drop for CloneStream<BaseStream>
129where
130 BaseStream: Stream<Item: Clone>,
131{
132 fn drop(&mut self) {
133 if let Ok(mut fork) = self.fork.try_write() {
134 fork.unregister(self.id);
135 } else {
136 log::warn!(
137 "Failed to acquire lock during clone drop for clone {}",
138 self.id
139 );
140 }
141 }
142}
143
144impl<BaseStream> CloneStream<BaseStream>
145where
146 BaseStream: Stream<Item: Clone>,
147{
148 #[must_use]
170 pub fn n_queued_items(&self) -> usize {
171 trace!("Getting the number of queued items for clone {}.", self.id);
172 self.fork
173 .read()
174 .expect("Fork lock poisoned during n_queued_items")
175 .remaining_queued_items(self.id)
176 }
177}