1use core::ops::Deref;
2use std::{
3 collections::{BTreeMap, BTreeSet},
4 pin::Pin,
5 sync::Arc,
6 task::{Poll, Wake, Waker},
7};
8
9use futures::Stream;
10use log::trace;
11
12use crate::{
13 error::{CloneStreamError, Result},
14 states::{CloneState, NewStateAndPollResult, StateHandler},
15};
16
/// Default upper bound on the number of clones registered at once (2^16 = 65 536).
const MAX_CLONE_COUNT: usize = 1 << 16;

/// Default upper bound on the number of items held in the shared queue (2^20 = 1 048 576).
const MAX_QUEUE_SIZE: usize = 1 << 20;
22
/// Limits applied to a forked stream and all of its clones.
///
/// Use [`Default`] to get the crate's built-in limits, or construct the
/// struct directly to override them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ForkConfig {
    /// Maximum number of clones that may be registered at the same time.
    /// Registering a clone beyond this limit fails.
    pub max_clone_count: usize,
    /// Maximum number of items that may be buffered in the shared queue.
    /// Reserving a queue slot beyond this limit fails.
    pub max_queue_size: usize,
}
31
32impl Default for ForkConfig {
33 fn default() -> Self {
34 Self {
35 max_clone_count: MAX_CLONE_COUNT,
36 max_queue_size: MAX_QUEUE_SIZE,
37 }
38 }
39}
40
41pub(crate) struct Fork<BaseStream>
42where
43 BaseStream: Stream<Item: Clone>,
44{
45 pub(crate) base_stream: Pin<Box<BaseStream>>,
46 pub(crate) queue: BTreeMap<usize, Option<BaseStream::Item>>,
47 pub(crate) clones: BTreeMap<usize, CloneState>,
48 available_clone_indices: BTreeSet<usize>,
49 pub(crate) next_queue_index: usize,
50 latest_cached_item_index: Option<usize>,
51 config: ForkConfig,
52}
53
54impl<BaseStream> Fork<BaseStream>
55where
56 BaseStream: Stream<Item: Clone>,
57{
58 pub(crate) fn new(base_stream: BaseStream) -> Self {
59 Self::with_config(base_stream, ForkConfig::default())
60 }
61
62 pub(crate) fn with_config(base_stream: BaseStream, config: ForkConfig) -> Self {
63 Self {
64 base_stream: Box::pin(base_stream),
65 clones: BTreeMap::default(),
66 queue: BTreeMap::new(),
67 next_queue_index: 0,
68 latest_cached_item_index: None,
69 available_clone_indices: BTreeSet::new(),
70 config,
71 }
72 }
73
74 pub(crate) fn poll_clone(
75 &mut self,
76 clone_id: usize,
77 clone_waker: &Waker,
78 ) -> Poll<Option<BaseStream::Item>> {
79 trace!("Clone {clone_id} is being polled through the fork.");
80 let current_state = self.clones.remove(&clone_id).unwrap();
81
82 let NewStateAndPollResult {
83 poll_result,
84 new_state,
85 } = current_state.handle(clone_waker, self);
86
87 trace!("Inserting clone {clone_id} back into the fork with state: {new_state:?}.");
88 self.clones.insert(clone_id, new_state);
89 poll_result
90 }
91
92 pub(crate) fn waker(&self, extra_waker: &Waker) -> Waker {
93 let wakers = self
94 .clones
95 .iter()
96 .filter(|(_clone_id, state)| state.should_still_see_base_item())
97 .filter_map(|(_clone_id, state)| state.waker().clone())
98 .chain(std::iter::once(extra_waker.clone()))
99 .collect::<Vec<_>>();
100
101 trace!("Found {} wakers.", wakers.len());
102
103 Waker::from(Arc::new(MultiWaker { wakers }))
104 }
105
106 pub(crate) fn register(&mut self) -> Result<usize> {
108 if let Some(reused_id) = self.available_clone_indices.pop_first() {
109 trace!("Registering clone {reused_id} (reused index).");
110 self.clones.insert(reused_id, CloneState::default());
111 return Ok(reused_id);
112 }
113
114 let next_clone_index = (0..self.config.max_clone_count)
116 .find(|&id| !self.clones.contains_key(&id))
117 .ok_or(CloneStreamError::MaxClonesExceeded {
118 current_count: self.clones.len(),
119 max_allowed: self.config.max_clone_count,
120 })?;
121
122 trace!("Registering clone {next_clone_index} (new index).");
123 self.clones.insert(next_clone_index, CloneState::default());
124 Ok(next_clone_index)
125 }
126
127 fn queue_capacity(&self) -> usize {
129 self.config.max_queue_size.saturating_sub(self.queue.len())
130 }
131
132 pub(crate) fn allocate_queue_index(&mut self) -> Result<usize> {
134 if self.queue_capacity() == 0 {
136 return Err(CloneStreamError::MaxQueueSizeExceeded {
137 max_allowed: self.config.max_queue_size,
138 current_size: self.queue.len(),
139 });
140 }
141
142 let candidate_index = self.next_queue_index;
143 self.next_queue_index = (self.next_queue_index + 1) % self.config.max_queue_size;
144
145 self.latest_cached_item_index = Some(candidate_index);
146
147 Ok(candidate_index)
148 }
149 pub(crate) fn unregister(&mut self, clone_id: usize) {
150 trace!("Unregistering clone {clone_id}.");
151 if self.clones.remove(&clone_id).is_none() {
152 log::warn!("Attempted to unregister clone {clone_id} that was not registered");
153 return;
154 }
155
156 if !self.available_clone_indices.insert(clone_id) {
159 log::warn!("Clone index {clone_id} was already in available pool");
160 }
161
162 self.queue.retain(|item_index, _| {
163 self.clones
164 .values()
165 .any(|state| state.should_still_see_item(*item_index))
166 });
167 }
168
169 pub(crate) fn remaining_queued_items(&self, clone_id: usize) -> usize {
170 self.queue
171 .iter()
172 .filter(|(item_index, _)| {
173 self.clones
174 .get(&clone_id)
175 .unwrap()
176 .should_still_see_item(**item_index)
177 })
178 .count()
179 }
180}
181
182impl<BaseStream> Deref for Fork<BaseStream>
183where
184 BaseStream: Stream<Item: Clone>,
185{
186 type Target = BaseStream;
187
188 fn deref(&self) -> &Self::Target {
189 &self.base_stream
190 }
191}
192
/// A waker that fans a single wake-up out to a list of wakers.
///
/// Built by `Fork::waker`.
#[derive(Debug)]
pub(crate) struct MultiWaker {
    /// Wakers notified, in order, whenever this waker is woken.
    wakers: Vec<Waker>,
}
196
197impl Wake for MultiWaker {
198 fn wake(self: Arc<Self>) {
199 trace!("Waking up all sleeping clones.");
200 self.wakers.iter().for_each(Waker::wake_by_ref);
201 }
202}