clone_stream/
fork.rs

1use core::ops::Deref;
2use std::{
3    collections::{BTreeMap, BTreeSet},
4    pin::Pin,
5    sync::Arc,
6    task::{Poll, Wake, Waker},
7};
8
9use futures::Stream;
10use log::trace;
11
12use crate::{
13    error::{CloneStreamError, Result},
14    states::{CloneState, NewStateAndPollResult, StateHandler},
15};
16
/// Default ceiling on how many clones may be registered at the same time.
const MAX_CLONE_COUNT: usize = 1 << 16;

/// Default ceiling on how many items may sit in the queue at the same time.
const MAX_QUEUE_SIZE: usize = 1 << 20;
22
/// Limits that bound the resources a `Fork` may use.
///
/// [`ForkConfig::default`] yields the crate-wide defaults; set the fields
/// explicitly to tighten or relax them. Configurations can be compared with
/// `==`, which is handy when asserting on them in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ForkConfig {
    /// Exclusive upper bound on the clone ids that can be handed out, and
    /// therefore on the number of simultaneously registered clones.
    pub max_clone_count: usize,
    /// Maximum number of items the queue may hold at once. When the queue is
    /// full, allocating another queue index fails with
    /// `CloneStreamError::MaxQueueSizeExceeded`; how that error is surfaced
    /// is up to the caller of the allocation.
    pub max_queue_size: usize,
}
31
32impl Default for ForkConfig {
33    fn default() -> Self {
34        Self {
35            max_clone_count: MAX_CLONE_COUNT,
36            max_queue_size: MAX_QUEUE_SIZE,
37        }
38    }
39}
40
41pub(crate) struct Fork<BaseStream>
42where
43    BaseStream: Stream<Item: Clone>,
44{
45    pub(crate) base_stream: Pin<Box<BaseStream>>,
46    pub(crate) queue: BTreeMap<usize, Option<BaseStream::Item>>,
47    pub(crate) clones: BTreeMap<usize, CloneState>,
48    available_clone_indices: BTreeSet<usize>,
49    pub(crate) next_queue_index: usize,
50    latest_cached_item_index: Option<usize>,
51    config: ForkConfig,
52}
53
54impl<BaseStream> Fork<BaseStream>
55where
56    BaseStream: Stream<Item: Clone>,
57{
58    pub(crate) fn new(base_stream: BaseStream) -> Self {
59        Self::with_config(base_stream, ForkConfig::default())
60    }
61
62    pub(crate) fn with_config(base_stream: BaseStream, config: ForkConfig) -> Self {
63        Self {
64            base_stream: Box::pin(base_stream),
65            clones: BTreeMap::default(),
66            queue: BTreeMap::new(),
67            next_queue_index: 0,
68            latest_cached_item_index: None,
69            available_clone_indices: BTreeSet::new(),
70            config,
71        }
72    }
73
74    pub(crate) fn poll_clone(
75        &mut self,
76        clone_id: usize,
77        clone_waker: &Waker,
78    ) -> Poll<Option<BaseStream::Item>> {
79        trace!("Clone {clone_id} is being polled through the fork.");
80        let current_state = self.clones.remove(&clone_id).unwrap();
81
82        let NewStateAndPollResult {
83            poll_result,
84            new_state,
85        } = current_state.handle(clone_waker, self);
86
87        trace!("Inserting clone {clone_id} back into the fork with state: {new_state:?}.");
88        self.clones.insert(clone_id, new_state);
89        poll_result
90    }
91
92    pub(crate) fn waker(&self, extra_waker: &Waker) -> Waker {
93        let wakers = self
94            .clones
95            .iter()
96            .filter(|(_clone_id, state)| state.should_still_see_base_item())
97            .filter_map(|(_clone_id, state)| state.waker().clone())
98            .chain(std::iter::once(extra_waker.clone()))
99            .collect::<Vec<_>>();
100
101        trace!("Found {} wakers.", wakers.len());
102
103        Waker::from(Arc::new(MultiWaker { wakers }))
104    }
105
106    /// Register a new clone and return its ID
107    pub(crate) fn register(&mut self) -> Result<usize> {
108        if let Some(reused_id) = self.available_clone_indices.pop_first() {
109            trace!("Registering clone {reused_id} (reused index).");
110            self.clones.insert(reused_id, CloneState::default());
111            return Ok(reused_id);
112        }
113
114        // Derive the next new index by finding the lowest unused index
115        let next_clone_index = (0..self.config.max_clone_count)
116            .find(|&id| !self.clones.contains_key(&id))
117            .ok_or(CloneStreamError::MaxClonesExceeded {
118                current_count: self.clones.len(),
119                max_allowed: self.config.max_clone_count,
120            })?;
121
122        trace!("Registering clone {next_clone_index} (new index).");
123        self.clones.insert(next_clone_index, CloneState::default());
124        Ok(next_clone_index)
125    }
126
127    /// Calculates the remaining capacity in the queue.
128    fn queue_capacity(&self) -> usize {
129        self.config.max_queue_size.saturating_sub(self.queue.len())
130    }
131
132    /// Allocates a new queue index with ring-buffer wrapping.
133    pub(crate) fn allocate_queue_index(&mut self) -> Result<usize> {
134        // Check if we have capacity for more items
135        if self.queue_capacity() == 0 {
136            return Err(CloneStreamError::MaxQueueSizeExceeded {
137                max_allowed: self.config.max_queue_size,
138                current_size: self.queue.len(),
139            });
140        }
141
142        let candidate_index = self.next_queue_index;
143        self.next_queue_index = (self.next_queue_index + 1) % self.config.max_queue_size;
144
145        self.latest_cached_item_index = Some(candidate_index);
146
147        Ok(candidate_index)
148    }
149    pub(crate) fn unregister(&mut self, clone_id: usize) {
150        trace!("Unregistering clone {clone_id}.");
151        if self.clones.remove(&clone_id).is_none() {
152            log::warn!("Attempted to unregister clone {clone_id} that was not registered");
153            return;
154        }
155
156        // Insert the index back to the available pool - BTreeSet handles ordering
157        // automatically
158        if !self.available_clone_indices.insert(clone_id) {
159            log::warn!("Clone index {clone_id} was already in available pool");
160        }
161
162        self.queue.retain(|item_index, _| {
163            self.clones
164                .values()
165                .any(|state| state.should_still_see_item(*item_index))
166        });
167    }
168
169    pub(crate) fn remaining_queued_items(&self, clone_id: usize) -> usize {
170        self.queue
171            .iter()
172            .filter(|(item_index, _)| {
173                self.clones
174                    .get(&clone_id)
175                    .unwrap()
176                    .should_still_see_item(**item_index)
177            })
178            .count()
179    }
180}
181
182impl<BaseStream> Deref for Fork<BaseStream>
183where
184    BaseStream: Stream<Item: Clone>,
185{
186    type Target = BaseStream;
187
188    fn deref(&self) -> &Self::Target {
189        &self.base_stream
190    }
191}
192
193pub(crate) struct MultiWaker {
194    wakers: Vec<Waker>,
195}
196
197impl Wake for MultiWaker {
198    fn wake(self: Arc<Self>) {
199        trace!("Waking up all sleeping clones.");
200        self.wakers.iter().for_each(Waker::wake_by_ref);
201    }
202}