1use core::ops::Deref;
2use std::{
3 iter,
4 pin::Pin,
5 sync::Arc,
6 task::{Poll, Wake, Waker},
7};
8
9use futures::Stream;
10use log::{debug, trace, warn};
11
12use crate::{registry::CloneRegistry, ring_queue::RingQueue};
13
/// Default upper bound on the number of clones, used by `ForkConfig::default` (2^16).
const MAX_CLONE_COUNT: usize = 1 << 16;

/// Default size limit for the item queue, used by `ForkConfig::default` (2^20).
const MAX_QUEUE_SIZE: usize = 1 << 20;
19
/// Limits applied when building a fork of a base stream.
///
/// `Fork::with_config` hands `max_clone_count` to the clone registry and
/// `max_queue_size` to the item ring queue.
#[derive(Clone, Copy, Debug)]
pub struct ForkConfig {
    /// Limit passed to `CloneRegistry::new`.
    pub max_clone_count: usize,
    /// Limit passed to `RingQueue::new` for the buffered items.
    pub max_queue_size: usize,
}
27
28impl Default for ForkConfig {
29 fn default() -> Self {
30 Self {
31 max_clone_count: MAX_CLONE_COUNT,
32 max_queue_size: MAX_QUEUE_SIZE,
33 }
34 }
35}
36
37pub(crate) struct Fork<BaseStream>
38where
39 BaseStream: Stream<Item: Clone>,
40{
41 pub(crate) base_stream: Pin<Box<BaseStream>>,
42 pub(crate) item_buffer: RingQueue<Option<BaseStream::Item>>,
43 pub(crate) clone_registry: CloneRegistry,
44}
45
46impl<BaseStream> Fork<BaseStream>
47where
48 BaseStream: Stream<Item: Clone>,
49{
50 pub(crate) fn new(base_stream: BaseStream) -> Self {
51 Self::with_config(base_stream, ForkConfig::default())
52 }
53
54 pub(crate) fn with_config(base_stream: BaseStream, config: ForkConfig) -> Self {
55 Self {
56 base_stream: Box::pin(base_stream),
57 clone_registry: CloneRegistry::new(config.max_clone_count),
58 item_buffer: RingQueue::new(config.max_queue_size),
59 }
60 }
61
62 pub(crate) fn poll_clone(
63 &mut self,
64 clone_id: usize,
65 clone_waker: &Waker,
66 ) -> Poll<Option<BaseStream::Item>> {
67 let mut current_state = self.clone_registry.take(clone_id).unwrap();
68 debug!("State of clone {clone_id} is {current_state:?}.");
69
70 let poll_result = current_state.step(clone_id, clone_waker, self);
71
72 debug!("Clone {clone_id} transitioned to {current_state:?}.");
73 self.clone_registry
74 .restore(clone_id, current_state)
75 .expect("Failed to restore clone state - this should never happen as we just took it");
76 poll_result
77 }
78
79 pub(crate) fn waker(&self, extra_waker: &Waker) -> Waker {
80 let clone_wakers = self.clone_registry.collect_wakers_needing_base_item();
81 trace!(
82 "There are {} clone wakers needing base item. Adding one more",
83 clone_wakers.len()
84 );
85 let waker_count = clone_wakers.len() + 1;
86
87 if waker_count == 1 {
89 extra_waker.clone()
90 } else {
91 let all_wakers = clone_wakers
92 .into_iter()
93 .chain(iter::once(extra_waker.clone()))
94 .collect();
95 Waker::from(Arc::new(MultiWaker { wakers: all_wakers }))
96 }
97 }
98
99 pub(crate) fn remaining_queued_items(&self, clone_id: usize) -> usize {
100 (&self.item_buffer)
101 .into_iter()
102 .map(|(item_index, _)| item_index)
103 .filter(|&item_index| self.should_clone_see_item(clone_id, item_index))
104 .count()
105 }
106
107 pub(crate) fn should_clone_see_item(&self, clone_id: usize, queue_item_index: usize) -> bool {
108 if let Some(state) = self.clone_registry.get_clone_state(clone_id) {
109 match state {
110 crate::states::CloneState::AwaitingFirstItem
111 | crate::states::CloneState::AwaitingBaseStream { .. } => true,
112 crate::states::CloneState::AwaitingBaseStreamWithQueueHistory {
113 last_seen_index,
114 ..
115 } => self
116 .item_buffer
117 .is_newer_than(queue_item_index, *last_seen_index),
118 crate::states::CloneState::ProcessingQueue {
119 last_seen_queue_index: unseen_index,
120 } => !self
121 .item_buffer
122 .is_newer_than(queue_item_index, *unseen_index),
123 crate::states::CloneState::BaseStreamReady
124 | crate::states::CloneState::BaseStreamReadyWithQueueHistory => false,
125 }
126 } else {
127 false
128 }
129 }
130
131 pub(crate) fn unregister(&mut self, clone_id: usize) {
132 self.clone_registry.unregister(clone_id);
133 self.cleanup_unneeded_queue_items();
134 }
135
136 fn cleanup_unneeded_queue_items(&mut self) {
137 if self.clone_registry.count() == 0 {
138 self.item_buffer.clear();
139 return;
140 }
141
142 let items_to_remove: Vec<usize> = (&self.item_buffer)
143 .into_iter()
144 .filter_map(|(item_index, _)| {
145 let is_needed = self
146 .clone_registry
147 .iter_active_with_ids()
148 .any(|(clone_id, _)| self.should_clone_see_item(clone_id, item_index));
149 (!is_needed).then_some(item_index)
150 })
151 .collect();
152
153 for item_index in items_to_remove {
154 self.item_buffer.remove(item_index);
155 }
156 }
157}
158
159impl<BaseStream> Deref for Fork<BaseStream>
160where
161 BaseStream: Stream<Item: Clone>,
162{
163 type Target = BaseStream;
164
165 fn deref(&self) -> &Self::Target {
166 &self.base_stream
167 }
168}
169
/// A waker that fans a single wake-up out to a list of underlying wakers.
///
/// `Fork::waker` builds one so that a wake-up reaches both the caller's waker and
/// every clone waiting on the base stream.
pub(crate) struct MultiWaker {
    /// Wakers to notify, woken in vector order.
    wakers: Vec<Waker>,
}
173
174impl Wake for MultiWaker {
175 fn wake(self: Arc<Self>) {
176 warn!("New data arrived in source stream, waking up sleeping clones.");
177 self.wakers.iter().for_each(Waker::wake_by_ref);
178 }
179}