// fluent_fallback/cache.rs

1use std::{
2    cell::{RefCell, UnsafeCell},
3    cmp::Ordering,
4    pin::Pin,
5    task::Context,
6    task::Poll,
7    task::Waker,
8};
9
10use crate::generator::{BundleIterator, BundleStream};
11use chunky_vec::ChunkyVec;
12use futures::{ready, Stream};
13use pin_cell::{PinCell, PinMut};
14
/// A lazily-populated cache over the items produced by an iterator.
///
/// Items pulled from `iter` are appended to `items`, so multiple
/// [`CacheIter`]s can replay them without re-driving the source iterator.
pub struct Cache<I, R>
where
    I: Iterator,
{
    // Source iterator; `RefCell` allows pulling new items through `&self`.
    iter: RefCell<I>,
    // Items produced so far. `UnsafeCell` + `ChunkyVec` let us hand out
    // `&I::Item` references while still appending; soundness relies on
    // `ChunkyVec` never relocating existing elements as it grows — TODO
    // confirm that guarantee against the chunky_vec crate docs.
    items: UnsafeCell<ChunkyVec<I::Item>>,
    // Marker for the resource type `R`. NOTE(review): `R` is not used by
    // any code visible in this file — presumably consumed by callers.
    res: std::marker::PhantomData<R>,
}
23
24impl<I, R> Cache<I, R>
25where
26    I: Iterator,
27{
28    pub fn new(iter: I) -> Self {
29        Self {
30            iter: RefCell::new(iter),
31            items: Default::default(),
32            res: std::marker::PhantomData,
33        }
34    }
35
36    pub fn len(&self) -> usize {
37        unsafe {
38            let items = self.items.get();
39            (*items).len()
40        }
41    }
42
43    pub fn get(&self, index: usize) -> Option<&I::Item> {
44        unsafe {
45            let items = self.items.get();
46            (*items).get(index)
47        }
48    }
49
50    /// Push, immediately getting a reference to the element
51    pub fn push_get(&self, new_value: I::Item) -> &I::Item {
52        unsafe {
53            let items = self.items.get();
54            (*items).push_get(new_value)
55        }
56    }
57}
58
59impl<I, R> Cache<I, R>
60where
61    I: BundleIterator + Iterator,
62{
63    pub fn prefetch(&self) {
64        self.iter.borrow_mut().prefetch_sync();
65    }
66}
67
/// Iterator over a [`Cache`]: yields references to already-cached items
/// first, then pulls fresh items from the source iterator and caches them.
pub struct CacheIter<'a, I, R>
where
    I: Iterator,
{
    // The cache being iterated.
    cache: &'a Cache<I, R>,
    // Index of the next item to yield.
    curr: usize,
}
75
76impl<'a, I, R> Iterator for CacheIter<'a, I, R>
77where
78    I: Iterator,
79{
80    type Item = &'a I::Item;
81
82    fn next(&mut self) -> Option<Self::Item> {
83        let cache_len = self.cache.len();
84        match self.curr.cmp(&cache_len) {
85            Ordering::Less => {
86                // Cached value
87                self.curr += 1;
88                self.cache.get(self.curr - 1)
89            }
90            Ordering::Equal => {
91                // Get the next item from the iterator
92                let item = self.cache.iter.borrow_mut().next();
93                self.curr += 1;
94                if let Some(item) = item {
95                    Some(self.cache.push_get(item))
96                } else {
97                    None
98                }
99            }
100            Ordering::Greater => {
101                // Ran off the end of the cache
102                None
103            }
104        }
105    }
106}
107
108impl<'a, I, R> IntoIterator for &'a Cache<I, R>
109where
110    I: Iterator,
111{
112    type Item = &'a I::Item;
113    type IntoIter = CacheIter<'a, I, R>;
114
115    fn into_iter(self) -> Self::IntoIter {
116        CacheIter {
117            cache: self,
118            curr: 0,
119        }
120    }
121}
122
123////////////////////////////////////////////////////////////////////////////////
124
/// Asynchronous analogue of [`Cache`]: lazily caches the items of a
/// [`Stream`] so several [`AsyncCacheStream`]s can replay them.
pub struct AsyncCache<S, R>
where
    S: Stream,
{
    // Source stream; `PinCell` provides pinned mutable access via `&self`.
    stream: PinCell<S>,
    // Items produced so far. `UnsafeCell` + `ChunkyVec` let us hand out
    // `&S::Item` references while still appending; relies on `ChunkyVec`
    // keeping element addresses stable as it grows — TODO confirm.
    items: UnsafeCell<ChunkyVec<S::Item>>,
    // TODO: Should probably be an SmallVec<[Waker; 1]> or something? I guess
    // multiple pending wakes are not really all that common.
    // Wakers of tasks that observed `Pending` from the source stream; they
    // are woken when a later poll of the source returns `Ready`.
    pending_wakes: RefCell<Vec<Waker>>,
    // Marker for the resource type `R`. NOTE(review): `R` is not used by
    // any code visible in this file — presumably consumed by callers.
    res: std::marker::PhantomData<R>,
}
136
137impl<S, R> AsyncCache<S, R>
138where
139    S: Stream,
140{
141    pub fn new(stream: S) -> Self {
142        Self {
143            stream: PinCell::new(stream),
144            items: Default::default(),
145            pending_wakes: Default::default(),
146            res: std::marker::PhantomData,
147        }
148    }
149
150    pub fn len(&self) -> usize {
151        unsafe {
152            let items = self.items.get();
153            (*items).len()
154        }
155    }
156
157    pub fn get(&self, index: usize) -> Poll<Option<&S::Item>> {
158        unsafe {
159            let items = self.items.get();
160            (*items).get(index).into()
161        }
162    }
163
164    /// Push, immediately getting a reference to the element
165    pub fn push_get(&self, new_value: S::Item) -> &S::Item {
166        unsafe {
167            let items = self.items.get();
168            (*items).push_get(new_value)
169        }
170    }
171
172    pub fn stream(&self) -> AsyncCacheStream<'_, S, R> {
173        AsyncCacheStream {
174            cache: self,
175            curr: 0,
176        }
177    }
178}
179
impl<S, R> AsyncCache<S, R>
where
    S: BundleStream + Stream,
{
    /// Asks the underlying stream to prefetch, awaiting its completion.
    pub async fn prefetch(&self) {
        // SAFETY: `Pin::new_unchecked` promises the `PinCell` (and the
        // stream inside it) will not be moved while this pin is live; the
        // pin only lives for the duration of this call, and it holds as
        // long as the `AsyncCache` itself is not moved while this future
        // is in flight — NOTE(review): upheld by callers, verify.
        let pin = unsafe { Pin::new_unchecked(&self.stream) };
        // SAFETY: `get_unchecked_mut` requires that the stream is never
        // moved out of the returned `&mut S`; `prefetch_async` only
        // borrows it — TODO confirm against `BundleStream`'s contract.
        unsafe { PinMut::as_mut(&mut pin.borrow_mut()).get_unchecked_mut() }
            .prefetch_async()
            .await;
    }
}
191
impl<S, R> AsyncCache<S, R>
where
    S: Stream,
{
    // Helper function that gets the next value from wrapped stream.
    //
    // On `Ready`, wakes every task that previously observed `Pending`, so
    // other `AsyncCacheStream`s re-poll and see the updated cache state.
    // On `Pending`, parks the current task's waker for that later wake-up.
    fn poll_next_item(&self, cx: &mut Context<'_>) -> Poll<Option<S::Item>> {
        // SAFETY: pinning a shared reference to the `PinCell`; sound as
        // long as the `AsyncCache` is not moved while the pin is live — the
        // pin is confined to this call.
        let pin = unsafe { Pin::new_unchecked(&self.stream) };
        let poll = PinMut::as_mut(&mut pin.borrow_mut()).poll_next(cx);
        if poll.is_ready() {
            // `mem::take` drains the list and releases the `RefCell` borrow
            // before any waker runs, avoiding a re-entrant borrow panic.
            let wakers = std::mem::take(&mut *self.pending_wakes.borrow_mut());
            for waker in wakers {
                waker.wake();
            }
        } else {
            self.pending_wakes.borrow_mut().push(cx.waker().clone());
        }
        poll
    }
}
211
/// A [`Stream`] over an [`AsyncCache`]: yields references to cached items
/// first, then drives the source stream and caches what it produces.
pub struct AsyncCacheStream<'a, S, R>
where
    S: Stream,
{
    // The cache being streamed.
    cache: &'a AsyncCache<S, R>,
    // Index of the next item to yield.
    curr: usize,
}
219
220impl<'a, S, R> Stream for AsyncCacheStream<'a, S, R>
221where
222    S: Stream,
223{
224    type Item = &'a S::Item;
225
226    fn poll_next(
227        mut self: std::pin::Pin<&mut Self>,
228        cx: &mut std::task::Context<'_>,
229    ) -> Poll<Option<Self::Item>> {
230        let cache_len = self.cache.len();
231        match self.curr.cmp(&cache_len) {
232            Ordering::Less => {
233                // Cached value
234                self.curr += 1;
235                self.cache.get(self.curr - 1)
236            }
237            Ordering::Equal => {
238                // Get the next item from the stream
239                let item = ready!(self.cache.poll_next_item(cx));
240                self.curr += 1;
241                if let Some(item) = item {
242                    Some(self.cache.push_get(item)).into()
243                } else {
244                    None.into()
245                }
246            }
247            Ordering::Greater => {
248                // Ran off the end of the cache
249                None.into()
250            }
251        }
252    }
253}