clone_stream/
fork.rs

1use core::ops::Deref;
2use std::{
3    iter,
4    pin::Pin,
5    sync::Arc,
6    task::{Poll, Wake, Waker},
7};
8
9use futures::Stream;
10use log::{debug, trace, warn};
11
12use crate::{registry::CloneRegistry, ring_queue::RingQueue};
13
/// Default ceiling on how many clones may be registered at the same time (2^16).
const MAX_CLONE_COUNT: usize = 1 << 16;

/// Default ceiling on how many items may be queued at the same time (2^20).
const MAX_QUEUE_SIZE: usize = 1 << 20;
19
/// Limits applied when forking a base stream into clones.
///
/// Start from `ForkConfig::default()` and override individual fields with struct-update
/// syntax, e.g. `ForkConfig { max_clone_count: 8, ..ForkConfig::default() }`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ForkConfig {
    /// Maximum number of clones that may be registered at the same time.
    pub max_clone_count: usize,
    /// Maximum number of items the shared item queue may hold. Exceeding it is documented as
    /// panicking; the limit itself is enforced by the queue implementation.
    pub max_queue_size: usize,
}
27
28impl Default for ForkConfig {
29    fn default() -> Self {
30        Self {
31            max_clone_count: MAX_CLONE_COUNT,
32            max_queue_size: MAX_QUEUE_SIZE,
33        }
34    }
35}
36
37pub(crate) struct Fork<BaseStream>
38where
39    BaseStream: Stream<Item: Clone>,
40{
41    pub(crate) base_stream: Pin<Box<BaseStream>>,
42    pub(crate) item_buffer: RingQueue<Option<BaseStream::Item>>,
43    pub(crate) clone_registry: CloneRegistry,
44}
45
46impl<BaseStream> Fork<BaseStream>
47where
48    BaseStream: Stream<Item: Clone>,
49{
50    pub(crate) fn new(base_stream: BaseStream) -> Self {
51        Self::with_config(base_stream, ForkConfig::default())
52    }
53
54    pub(crate) fn with_config(base_stream: BaseStream, config: ForkConfig) -> Self {
55        Self {
56            base_stream: Box::pin(base_stream),
57            clone_registry: CloneRegistry::new(config.max_clone_count),
58            item_buffer: RingQueue::new(config.max_queue_size),
59        }
60    }
61
62    pub(crate) fn poll_clone(
63        &mut self,
64        clone_id: usize,
65        clone_waker: &Waker,
66    ) -> Poll<Option<BaseStream::Item>> {
67        let mut current_state = self.clone_registry.take(clone_id).unwrap();
68        debug!("State of clone {clone_id} is {current_state:?}.");
69
70        let poll_result = current_state.step(clone_id, clone_waker, self);
71
72        debug!("Clone {clone_id} transitioned to {current_state:?}.");
73        self.clone_registry
74            .restore(clone_id, current_state)
75            .expect("Failed to restore clone state - this should never happen as we just took it");
76        poll_result
77    }
78
79    pub(crate) fn waker(&self, extra_waker: &Waker) -> Waker {
80        let clone_wakers = self.clone_registry.collect_wakers_needing_base_item();
81        trace!(
82            "There are {} clone wakers needing base item. Adding one more",
83            clone_wakers.len()
84        );
85        let waker_count = clone_wakers.len() + 1;
86
87        // Avoid Arc allocation for single waker
88        if waker_count == 1 {
89            extra_waker.clone()
90        } else {
91            let all_wakers = clone_wakers
92                .into_iter()
93                .chain(iter::once(extra_waker.clone()))
94                .collect();
95            Waker::from(Arc::new(MultiWaker { wakers: all_wakers }))
96        }
97    }
98
99    pub(crate) fn remaining_queued_items(&self, clone_id: usize) -> usize {
100        (&self.item_buffer)
101            .into_iter()
102            .map(|(item_index, _)| item_index)
103            .filter(|&item_index| self.should_clone_see_item(clone_id, item_index))
104            .count()
105    }
106
107    pub(crate) fn should_clone_see_item(&self, clone_id: usize, queue_item_index: usize) -> bool {
108        if let Some(state) = self.clone_registry.get_clone_state(clone_id) {
109            match state {
110                crate::states::CloneState::AwaitingFirstItem
111                | crate::states::CloneState::AwaitingBaseStream { .. } => true,
112                crate::states::CloneState::AwaitingBaseStreamWithQueueHistory {
113                    last_seen_index,
114                    ..
115                } => self
116                    .item_buffer
117                    .is_newer_than(queue_item_index, *last_seen_index),
118                crate::states::CloneState::ProcessingQueue {
119                    last_seen_queue_index: unseen_index,
120                } => !self
121                    .item_buffer
122                    .is_newer_than(queue_item_index, *unseen_index),
123                crate::states::CloneState::BaseStreamReady
124                | crate::states::CloneState::BaseStreamReadyWithQueueHistory => false,
125            }
126        } else {
127            false
128        }
129    }
130
131    pub(crate) fn unregister(&mut self, clone_id: usize) {
132        self.clone_registry.unregister(clone_id);
133        self.cleanup_unneeded_queue_items();
134    }
135
136    fn cleanup_unneeded_queue_items(&mut self) {
137        if self.clone_registry.count() == 0 {
138            self.item_buffer.clear();
139            return;
140        }
141
142        let items_to_remove: Vec<usize> = (&self.item_buffer)
143            .into_iter()
144            .filter_map(|(item_index, _)| {
145                let is_needed = self
146                    .clone_registry
147                    .iter_active_with_ids()
148                    .any(|(clone_id, _)| self.should_clone_see_item(clone_id, item_index));
149                (!is_needed).then_some(item_index)
150            })
151            .collect();
152
153        for item_index in items_to_remove {
154            self.item_buffer.remove(item_index);
155        }
156    }
157}
158
159impl<BaseStream> Deref for Fork<BaseStream>
160where
161    BaseStream: Stream<Item: Clone>,
162{
163    type Target = BaseStream;
164
165    fn deref(&self) -> &Self::Target {
166        &self.base_stream
167    }
168}
169
170pub(crate) struct MultiWaker {
171    wakers: Vec<Waker>,
172}
173
174impl Wake for MultiWaker {
175    fn wake(self: Arc<Self>) {
176        warn!("New data arrived in source stream, waking up sleeping clones.");
177        self.wakers.iter().for_each(Waker::wake_by_ref);
178    }
179}