1use std::{
2 cell::{RefCell, UnsafeCell},
3 cmp::Ordering,
4 pin::Pin,
5 task::Context,
6 task::Poll,
7 task::Waker,
8};
9
10use crate::generator::{BundleIterator, BundleStream};
11use chunky_vec::ChunkyVec;
12use futures::{ready, Stream};
13use pin_cell::{PinCell, PinMut};
14
15pub struct Cache<I, R>
16where
17 I: Iterator,
18{
19 iter: RefCell<I>,
20 items: UnsafeCell<ChunkyVec<I::Item>>,
21 res: std::marker::PhantomData<R>,
22}
23
24impl<I, R> Cache<I, R>
25where
26 I: Iterator,
27{
28 pub fn new(iter: I) -> Self {
29 Self {
30 iter: RefCell::new(iter),
31 items: Default::default(),
32 res: std::marker::PhantomData,
33 }
34 }
35
36 pub fn len(&self) -> usize {
37 unsafe {
38 let items = self.items.get();
39 (*items).len()
40 }
41 }
42
43 pub fn get(&self, index: usize) -> Option<&I::Item> {
44 unsafe {
45 let items = self.items.get();
46 (*items).get(index)
47 }
48 }
49
50 pub fn push_get(&self, new_value: I::Item) -> &I::Item {
52 unsafe {
53 let items = self.items.get();
54 (*items).push_get(new_value)
55 }
56 }
57}
58
59impl<I, R> Cache<I, R>
60where
61 I: BundleIterator + Iterator,
62{
63 pub fn prefetch(&self) {
64 self.iter.borrow_mut().prefetch_sync();
65 }
66}
67
68pub struct CacheIter<'a, I, R>
69where
70 I: Iterator,
71{
72 cache: &'a Cache<I, R>,
73 curr: usize,
74}
75
76impl<'a, I, R> Iterator for CacheIter<'a, I, R>
77where
78 I: Iterator,
79{
80 type Item = &'a I::Item;
81
82 fn next(&mut self) -> Option<Self::Item> {
83 let cache_len = self.cache.len();
84 match self.curr.cmp(&cache_len) {
85 Ordering::Less => {
86 self.curr += 1;
88 self.cache.get(self.curr - 1)
89 }
90 Ordering::Equal => {
91 let item = self.cache.iter.borrow_mut().next();
93 self.curr += 1;
94 if let Some(item) = item {
95 Some(self.cache.push_get(item))
96 } else {
97 None
98 }
99 }
100 Ordering::Greater => {
101 None
103 }
104 }
105 }
106}
107
108impl<'a, I, R> IntoIterator for &'a Cache<I, R>
109where
110 I: Iterator,
111{
112 type Item = &'a I::Item;
113 type IntoIter = CacheIter<'a, I, R>;
114
115 fn into_iter(self) -> Self::IntoIter {
116 CacheIter {
117 cache: self,
118 curr: 0,
119 }
120 }
121}
122
123pub struct AsyncCache<S, R>
126where
127 S: Stream,
128{
129 stream: PinCell<S>,
130 items: UnsafeCell<ChunkyVec<S::Item>>,
131 pending_wakes: RefCell<Vec<Waker>>,
134 res: std::marker::PhantomData<R>,
135}
136
137impl<S, R> AsyncCache<S, R>
138where
139 S: Stream,
140{
141 pub fn new(stream: S) -> Self {
142 Self {
143 stream: PinCell::new(stream),
144 items: Default::default(),
145 pending_wakes: Default::default(),
146 res: std::marker::PhantomData,
147 }
148 }
149
150 pub fn len(&self) -> usize {
151 unsafe {
152 let items = self.items.get();
153 (*items).len()
154 }
155 }
156
157 pub fn get(&self, index: usize) -> Poll<Option<&S::Item>> {
158 unsafe {
159 let items = self.items.get();
160 (*items).get(index).into()
161 }
162 }
163
164 pub fn push_get(&self, new_value: S::Item) -> &S::Item {
166 unsafe {
167 let items = self.items.get();
168 (*items).push_get(new_value)
169 }
170 }
171
172 pub fn stream(&self) -> AsyncCacheStream<'_, S, R> {
173 AsyncCacheStream {
174 cache: self,
175 curr: 0,
176 }
177 }
178}
179
180impl<S, R> AsyncCache<S, R>
181where
182 S: BundleStream + Stream,
183{
184 pub async fn prefetch(&self) {
185 let pin = unsafe { Pin::new_unchecked(&self.stream) };
186 unsafe { PinMut::as_mut(&mut pin.borrow_mut()).get_unchecked_mut() }
187 .prefetch_async()
188 .await;
189 }
190}
191
192impl<S, R> AsyncCache<S, R>
193where
194 S: Stream,
195{
196 fn poll_next_item(&self, cx: &mut Context<'_>) -> Poll<Option<S::Item>> {
198 let pin = unsafe { Pin::new_unchecked(&self.stream) };
199 let poll = PinMut::as_mut(&mut pin.borrow_mut()).poll_next(cx);
200 if poll.is_ready() {
201 let wakers = std::mem::take(&mut *self.pending_wakes.borrow_mut());
202 for waker in wakers {
203 waker.wake();
204 }
205 } else {
206 self.pending_wakes.borrow_mut().push(cx.waker().clone());
207 }
208 poll
209 }
210}
211
212pub struct AsyncCacheStream<'a, S, R>
213where
214 S: Stream,
215{
216 cache: &'a AsyncCache<S, R>,
217 curr: usize,
218}
219
220impl<'a, S, R> Stream for AsyncCacheStream<'a, S, R>
221where
222 S: Stream,
223{
224 type Item = &'a S::Item;
225
226 fn poll_next(
227 mut self: std::pin::Pin<&mut Self>,
228 cx: &mut std::task::Context<'_>,
229 ) -> Poll<Option<Self::Item>> {
230 let cache_len = self.cache.len();
231 match self.curr.cmp(&cache_len) {
232 Ordering::Less => {
233 self.curr += 1;
235 self.cache.get(self.curr - 1)
236 }
237 Ordering::Equal => {
238 let item = ready!(self.cache.poll_next_item(cx));
240 self.curr += 1;
241 if let Some(item) = item {
242 Some(self.cache.push_get(item)).into()
243 } else {
244 None.into()
245 }
246 }
247 Ordering::Greater => {
248 None.into()
250 }
251 }
252 }
253}