clone_stream/
fork.rs

1use core::ops::Deref;
2use std::{
3    iter,
4    pin::Pin,
5    sync::Arc,
6    task::{Poll, Wake, Waker},
7};
8
9use futures::Stream;
10use log::{debug, trace, warn};
11
12use crate::{registry::CloneRegistry, ring_queue::RingQueue};
13
/// Default upper bound on the number of clones registered at the same time (2^16).
const MAX_CLONE_COUNT: usize = 1 << 16;

/// Default upper bound on the number of items queued at the same time (2^20).
const MAX_QUEUE_SIZE: usize = 1 << 20;
19
/// Limits applied when creating a fork.
///
/// `ForkConfig::default()` yields the crate-wide defaults.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ForkConfig {
    /// Maximum number of clones allowed (sizes the clone registry).
    pub max_clone_count: usize,
    /// Maximum number of items the shared queue may hold; the original docs
    /// state that exceeding it panics.
    pub max_queue_size: usize,
}
27
28impl Default for ForkConfig {
29    fn default() -> Self {
30        Self {
31            max_clone_count: MAX_CLONE_COUNT,
32            max_queue_size: MAX_QUEUE_SIZE,
33        }
34    }
35}
36
37pub struct Fork<BaseStream>
38where
39    BaseStream: Stream<Item: Clone>,
40{
41    pub(crate) base_stream: Pin<Box<BaseStream>>,
42    pub(crate) item_buffer: RingQueue<Option<BaseStream::Item>>,
43    pub(crate) clone_registry: CloneRegistry,
44}
45
46impl<BaseStream> Fork<BaseStream>
47where
48    BaseStream: Stream<Item: Clone>,
49{
50    pub(crate) fn new(base_stream: BaseStream) -> Self {
51        Self::with_config(base_stream, ForkConfig::default())
52    }
53
54    pub(crate) fn with_config(base_stream: BaseStream, config: ForkConfig) -> Self {
55        Self {
56            base_stream: Box::pin(base_stream),
57            clone_registry: CloneRegistry::new(config.max_clone_count),
58            item_buffer: RingQueue::new(config.max_queue_size),
59        }
60    }
61
62    pub(crate) fn poll_clone(
63        &mut self,
64        clone_id: usize,
65        clone_waker: &Waker,
66    ) -> Poll<Option<BaseStream::Item>> {
67        let mut current_state = self.clone_registry.take(clone_id).unwrap();
68        debug!("State of clone {clone_id} is {current_state:?}.");
69
70        let poll_result = current_state.step(clone_id, clone_waker, self);
71
72        debug!("Clone {clone_id} transitioned to {current_state:?}.");
73        self.clone_registry
74            .restore(clone_id, current_state)
75            .expect("Failed to restore clone state - this should never happen as we just took it");
76        poll_result
77    }
78
79    pub(crate) fn waker(&self, extra_waker: &Waker) -> Waker {
80        let clone_wakers = self.clone_registry.collect_wakers_needing_base_item();
81        trace!(
82            "There are {} clone wakers needing base item. Adding one more",
83            clone_wakers.len()
84        );
85        let waker_count = clone_wakers.len() + 1;
86
87        // Avoid Arc allocation for single waker
88        if waker_count == 1 {
89            extra_waker.clone()
90        } else {
91            let all_wakers = clone_wakers
92                .into_iter()
93                .chain(iter::once(extra_waker.clone()))
94                .collect();
95            Waker::from(Arc::new(MultiWaker { wakers: all_wakers }))
96        }
97    }
98
99    pub(crate) fn remaining_queued_items(&self, clone_id: usize) -> usize {
100        (&self.item_buffer)
101            .into_iter()
102            .map(|(item_index, _)| item_index)
103            .filter(|&item_index| self.should_clone_see_item(clone_id, item_index))
104            .count()
105    }
106
107    pub(crate) fn should_clone_see_item(&self, clone_id: usize, queue_item_index: usize) -> bool {
108        self.clone_registry
109            .get_clone_state(clone_id)
110            .is_some_and(|state| match state {
111                crate::states::CloneState::AwaitingFirstItem
112                | crate::states::CloneState::AwaitingBaseStream { .. } => true,
113                crate::states::CloneState::AwaitingBaseStreamWithQueueHistory {
114                    last_seen_index,
115                    ..
116                } => self
117                    .item_buffer
118                    .is_newer_than(queue_item_index, *last_seen_index),
119                crate::states::CloneState::ProcessingQueue {
120                    last_seen_queue_index: unseen_index,
121                } => !self
122                    .item_buffer
123                    .is_newer_than(queue_item_index, *unseen_index),
124                crate::states::CloneState::BaseStreamReady
125                | crate::states::CloneState::BaseStreamReadyWithQueueHistory => false,
126            })
127    }
128
129    pub(crate) fn unregister(&mut self, clone_id: usize) {
130        self.clone_registry.unregister(clone_id);
131        self.cleanup_unneeded_queue_items();
132    }
133
134    fn cleanup_unneeded_queue_items(&mut self) {
135        if self.clone_registry.count() == 0 {
136            self.item_buffer.clear();
137            return;
138        }
139
140        let items_to_remove: Vec<usize> = (&self.item_buffer)
141            .into_iter()
142            .filter_map(|(item_index, _)| {
143                let is_needed = self
144                    .clone_registry
145                    .iter_active_with_ids()
146                    .any(|(clone_id, _)| self.should_clone_see_item(clone_id, item_index));
147                (!is_needed).then_some(item_index)
148            })
149            .collect();
150
151        for item_index in items_to_remove {
152            self.item_buffer.remove(item_index);
153        }
154    }
155}
156
157impl<BaseStream> Deref for Fork<BaseStream>
158where
159    BaseStream: Stream<Item: Clone>,
160{
161    type Target = BaseStream;
162
163    fn deref(&self) -> &Self::Target {
164        &self.base_stream
165    }
166}
167
/// Fans a single wake-up out to several wakers.
///
/// It is turned into a `Waker` via `Waker::from(Arc<MultiWaker>)`. Waking that
/// waker wakes every waker stored here.
#[derive(Debug)]
pub struct MultiWaker {
    /// The wakers notified on every wake-up.
    wakers: Vec<Waker>,
}
171
172impl Wake for MultiWaker {
173    fn wake(self: Arc<Self>) {
174        warn!("New data arrived in source stream, waking up sleeping clones.");
175        self.wakers.iter().for_each(Waker::wake_by_ref);
176    }
177}