forma_render/cpu/buffer/layout/
slice_cache.rs

1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// IMPORTANT: Upon any code-related modification to this file, please ensure that all commented-out
16//            tests that start with `fails_` actually fail to compile *independently* from one
17//            another.
18
19use std::{
20    fmt, hint,
21    marker::PhantomData,
22    mem,
23    ops::{Bound, Deref, DerefMut, RangeBounds},
24    ptr::NonNull,
25    slice,
26};
27
28use crossbeam_utils::atomic::AtomicCell;
29
// `SliceCache` is virtually identical to a `Vec<Range<usize>>` whose ranges are statically
// guaranteed not to overlap or overflow the initial `slice` that the object has been made with.
//
// This is achieved by forcing the user to produce `Span`s in a closure provided to the constructor
// from a root `Span` that cannot escape the closure.
//
// In `SliceCache::access`, we make sure that the slice is at least as long as the `len` passed to
// `SliceCache::new` and then save its pointer to the global `ROOT`, which is then guarded until
// the `Ref` is dropped.
39
/// A thin [`NonNull`] wrapper that opts into [`Send`] so that the pointer can be stored in the
/// global `ROOT` cell.
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(transparent)]
struct SendNonNull<T> {
    ptr: NonNull<T>,
}

// SAFETY: the wrapper itself never dereferences `ptr`; it only carries the address around.
// NOTE(review): soundness therefore rests on every reader of the wrapped pointer only
// dereferencing it while the borrow it was created from is still alive.
unsafe impl<T> Send for SendNonNull<T> {}

impl<T> From<NonNull<T>> for SendNonNull<T> {
    /// Wraps `ptr` without modifying it.
    fn from(ptr: NonNull<T>) -> Self {
        SendNonNull { ptr }
    }
}
53
// Base pointer of the slice currently lent out through `SliceCache::access`, or `None` when no
// `Ref` is alive. It is claimed with a compare-exchange in `SliceCache::access`, read back by the
// `Deref`/`DerefMut` implementations of `Slice` and reset to `None` when the `Ref` is dropped.
static ROOT: AtomicCell<Option<SendNonNull<()>>> = AtomicCell::new(None);
55
/// A [`prim@slice`] wrapper produced by [`SliceCache::access`].
///
/// Rather than a pointer, a `Slice` stores an element `offset` and a `len`; the pointer the
/// offset is applied to is read from the global `ROOT` every time the `Slice` is dereferenced.
#[repr(C)]
pub struct Slice<'a, T> {
    offset: isize,
    len: usize,
    _phantom: PhantomData<&'a mut [T]>,
}

// SAFETY: a `Slice<'_, T>` stands in for a `&mut [T]`, which is `Send` whenever `T: Send`.
unsafe impl<T: Send> Send for Slice<'_, T> {}

// SAFETY: a `Slice<'_, T>` stands in for a `&mut [T]`, which is `Sync` whenever `T: Sync`.
unsafe impl<T: Sync> Sync for Slice<'_, T> {}
69
70impl<'a, T> Deref for Slice<'a, T> {
71    type Target = [T];
72
73    #[inline]
74    fn deref(&self) -> &'a Self::Target {
75        let root: NonNull<T> = ROOT.load().unwrap().ptr.cast();
76
77        // `Slice`s should only be dereferences when tainted with the `'s` lifetime from the
78        // `SliceCache::access` method. This ensures that the slice that results from derefrencing
79        // here will also be constrained by the same lifetime.
80        //
81        // This also expects the `ROOT` pointer to be correctly set up in `SliceCache::access`.
82        unsafe { slice::from_raw_parts(root.as_ptr().offset(self.offset), self.len) }
83    }
84}
85
86impl<'a, T> DerefMut for Slice<'a, T> {
87    #[inline]
88    fn deref_mut(&mut self) -> &'a mut Self::Target {
89        let root: NonNull<T> = ROOT.load().unwrap().ptr.cast();
90
91        // `Slice`s should only be dereferences when tainted with the `'s` lifetime from the
92        // `SliceCache::access` method. This ensures that the slice that results from derefrencing
93        // here will also be constrained by the same lifetime.
94        //
95        // This also expects the `ROOT` pointer to be correctly set up in `SliceCache::access`.
96        unsafe { slice::from_raw_parts_mut(root.as_ptr().offset(self.offset), self.len) }
97    }
98}
99
100impl<T: fmt::Debug> fmt::Debug for Slice<'_, T> {
101    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102        (**self).fmt(f)
103    }
104}
105
/// A marker produced by [`SliceCache`] that ensures that all resulting `Span`s will be mutually
/// non-overlapping.
///
/// A `Span` only records an element offset and a length; it does not point to any memory by
/// itself.
///
/// # Examples
///
/// ```
/// # use forma_render::cpu::buffer::layout::SliceCache;
/// let _cache = SliceCache::new(4, |span| {
///     Box::new([span])
/// });
/// ```
// `#[repr(transparent)]` is relied upon by `SliceCache::new`, which transmutes the boxed `Span`s
// into boxed `Slice<'static, ()>`s.
#[repr(transparent)]
pub struct Span<'a>(Slice<'a, ()>);
119
120impl<'a> Span<'a> {
121    fn from_slice(slice: &Slice<'a, ()>) -> Self {
122        Self(Slice {
123            offset: slice.offset,
124            len: slice.len,
125            _phantom: PhantomData,
126        })
127    }
128
129    /// cache span at `mid`. Analogous to [`slice::split_at`].
130    ///
131    /// # Examples
132    ///
133    /// ```
134    /// # use forma_render::cpu::buffer::layout::SliceCache;
135    /// let _cache = SliceCache::new(4, |span| {
136    ///     let (left, right) = span.split_at(2);
137    ///     Box::new([left, right])
138    /// });
139    /// ```
140    #[inline]
141    pub fn slice<R: RangeBounds<usize>>(&self, range: R) -> Option<Self> {
142        let start = match range.start_bound() {
143            Bound::Included(&i) => i,
144            Bound::Excluded(&i) => i + 1,
145            Bound::Unbounded => 0,
146        };
147        let end = match range.end_bound() {
148            Bound::Included(&i) => i + 1,
149            Bound::Excluded(&i) => i,
150            Bound::Unbounded => self.0.len,
151        };
152
153        (start <= end && end <= self.0.len).then_some(Span(Slice {
154            offset: self.0.offset + start as isize,
155            len: end,
156            ..self.0
157        }))
158    }
159
160    /// cache span at `mid`. Analogous to [`slice::split_at`].
161    ///
162    /// # Examples
163    ///
164    /// ```
165    /// # use forma_render::cpu::buffer::layout::SliceCache;
166    /// let _cache = SliceCache::new(4, |span| {
167    ///     let (left, right) = span.split_at(2);
168    ///     Box::new([left, right])
169    /// });
170    /// ```
171    #[inline]
172    pub fn split_at(&self, mid: usize) -> (Self, Self) {
173        assert!(mid <= self.0.len);
174
175        (
176            Span(Slice { len: mid, ..self.0 }),
177            Span(Slice {
178                offset: self.0.offset + mid as isize,
179                len: self.0.len - mid,
180                ..self.0
181            }),
182        )
183    }
184
185    /// Returns an [Iterator](Chunks) over `chunk_size` elements of thr slice. Analogous to [`slice::chunks`].
186    ///
187    /// # Examples
188    ///
189    /// ```
190    /// # use forma_render::cpu::buffer::layout::SliceCache;
191    /// let _cache = SliceCache::new(4, |span| {
192    ///     span.chunks(2).collect()
193    /// });
194    /// ```
195    #[inline]
196    pub fn chunks(self, chunk_size: usize) -> Chunks<'a> {
197        Chunks {
198            slice: self.0,
199            size: chunk_size,
200        }
201    }
202}
203
204impl fmt::Debug for Span<'_> {
205    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
206        self.0.offset.fmt(f)?;
207        write!(f, "..")?;
208        self.0.len.fmt(f)?;
209
210        Ok(())
211    }
212}
213
/// An [iterator](std::iter::Iterator) returned by [`Span::chunks`].
///
/// Yields consecutive [`Span`]s of at most `size` elements until the remaining region is empty.
pub struct Chunks<'a> {
    // Region that has not been handed out yet.
    slice: Slice<'a, ()>,
    // Maximum number of elements per yielded span.
    size: usize,
}
219
220impl<'a> Iterator for Chunks<'a> {
221    type Item = Span<'a>;
222
223    #[inline]
224    fn next(&mut self) -> Option<Self::Item> {
225        (self.slice.len > 0).then(|| {
226            let span = Span(Slice {
227                len: self.size.min(self.slice.len),
228                ..self.slice
229            });
230
231            self.slice.offset += self.size as isize;
232            self.slice.len = self.slice.len.saturating_sub(self.size);
233
234            span
235        })
236    }
237}
238
239impl fmt::Debug for Chunks<'_> {
240    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
241        f.debug_struct("Chunks")
242            .field("span", &Span::from_slice(&self.slice))
243            .field("size", &self.size)
244            .finish()
245    }
246}
247
/// A [reference] wrapper returned by [`SliceCache::access`].
///
/// It wraps a mutable reference and dereferences to the wrapped value.
#[repr(transparent)]
#[derive(Debug)]
pub struct Ref<'a, T: ?Sized>(&'a mut T);

impl<T: ?Sized> Ref<'_, T> {
    /// Returns a mutable reference to the wrapped value.
    ///
    /// The returned reference is only tied to the borrow of `self`, not to the lifetime parameter
    /// of the `Ref`. Tying it to the latter would keep the `Ref` mutably borrowed for as long as
    /// it exists, allowing `get` to be called only once.
    pub fn get(&mut self) -> &mut T {
        self.0
    }
}
258
259impl<T: ?Sized> Deref for Ref<'_, T> {
260    type Target = T;
261
262    #[inline]
263    fn deref(&self) -> &Self::Target {
264        self.0
265    }
266}
267
268impl<T: ?Sized> DerefMut for Ref<'_, T> {
269    #[inline]
270    fn deref_mut(&mut self) -> &mut Self::Target {
271        self.0
272    }
273}
274
impl<T: ?Sized> Drop for Ref<'_, T> {
    // Resets the global `ROOT` to `None`, which lets the compare-exchange in a pending or future
    // `SliceCache::access` call succeed.
    #[inline]
    fn drop(&mut self) {
        ROOT.store(None);
    }
}
281
/// A cache of non-overlapping mutable sub-slices that enforces lifetimes dynamically.
///
/// This type is useful when you have to give up on the mutable reference to a slice but need
/// a way to cache mutable sub-slices deriving from it.
///
/// # Examples
///
/// ```
/// # use forma_render::cpu::buffer::layout::SliceCache;
/// let mut array = [1, 2, 3];
///
/// let mut cache = SliceCache::new(3, |span| {
///     let (left, right) = span.split_at(1);
///     Box::new([left, right])
/// });
///
/// for slice in cache.access(&mut array).unwrap().iter_mut() {
///     for val in slice.iter_mut() {
///         *val += 1;
///     }
/// }
///
/// assert_eq!(array, [2, 3, 4]);
/// ```
pub struct SliceCache {
    // Minimum length a slice must have to be accepted by `access`.
    len: usize,
    // Type-erased `(offset, len)` pairs returned by the closure passed to `new`.
    slices: Box<[Slice<'static, ()>]>,
}
310
311impl SliceCache {
312    /// Creates a new slice cache by storing sub-spans created from a root passed to the closure
313    /// `f`. `len` is the minimum slice length that can then be passed to [`access`](Self::access).
314    ///
315    /// # Examples
316    ///
317    /// ```
318    /// # use forma_render::cpu::buffer::layout::SliceCache;
319    /// let _cache = SliceCache::new(3, |span| {
320    ///     let (left, right) = span.split_at(1);
321    ///     // All returned sub-spans stem from the span passed above.
322    ///     Box::new([left, right])
323    /// });
324    /// ```
325    #[inline]
326    pub fn new<F>(len: usize, f: F) -> Self
327    where
328        F: Fn(Span<'_>) -> Box<[Span<'_>]> + 'static,
329    {
330        let span = Span(Slice {
331            offset: 0,
332            len,
333            _phantom: PhantomData,
334        });
335
336        // `Span<'_>` is transparent over `Slice<'_, ()>`. Since the `'_` above is used just to
337        // trap the span inside the closure, transmuting to `Slice<'static, ()>` does not make any
338        // difference.
339        Self {
340            len,
341            slices: unsafe { mem::transmute(f(span)) },
342        }
343    }
344
345    /// Accesses the `slice` by returning all the sub-slices equivalent to the previously created
346    /// [spans](Span) in the closure passed to [`new`](Self::new).
347    ///
348    /// If the `slice` does not have a length at least as large as the one passed to
349    /// [`new`](Self::new), this function returns `None`.
350    ///
351    /// Note: this method should not be called concurrently with any other `access` calls since it
352    /// will wait for the previously returned [`Ref`] to be dropped.
353    ///
354    /// # Examples
355    ///
356    /// ```
357    /// # use forma_render::cpu::buffer::layout::SliceCache;
358    /// let mut array = [1, 2, 3];
359    ///
360    /// let mut cache = SliceCache::new(3, |span| {
361    ///     let (left, right) = span.split_at(1);
362    ///     Box::new([left, right])
363    /// });
364    ///
365    /// let copy = array;
366    /// let skipped_one = cache.access(&mut array).unwrap();
367    ///
368    /// assert_eq!(&*skipped_one[1], &copy[1..]);
369    /// ```
370    #[inline]
371    pub fn access<'c, 's, T>(&'c mut self, slice: &'s mut [T]) -> Option<Ref<'c, [Slice<'s, T>]>> {
372        if slice.len() >= self.len {
373            while ROOT
374                .compare_exchange(
375                    None,
376                    Some(NonNull::new(slice.as_mut_ptr()).unwrap().cast().into()),
377                )
378                .is_err()
379            {
380                // This spin lock here is mostly for being able to run tests in parallel. Being
381                // able to render to `forma::Composition`s in parallel is currently not supported
382                // and might poor performance due to priority inversion.
383                hint::spin_loop();
384            }
385
386            // Generic `Slice<'static, ()>` are transmuted to `Slice<'s, T>`, enforcing the
387            // original `slice`'s lifetime. Since slices are simply pairs of `(offset, len)`,
388            // transmuting `()` to `T` relies on the `ROOT` being set up above with the correct pointer.
389            return Some(unsafe { mem::transmute(&mut *self.slices) });
390        }
391
392        None
393    }
394
395    #[cfg(test)]
396    fn try_access<'c, 's, T>(&'c mut self, slice: &'s mut [T]) -> Option<Ref<'c, [Slice<'s, T>]>> {
397        if slice.len() >= self.len
398            && ROOT
399                .compare_exchange(
400                    None,
401                    Some(NonNull::new(slice.as_mut_ptr()).unwrap().cast().into()),
402                )
403                .is_ok()
404        {
405            // Generic `Slice<'static, ()>` are transmuted to `Slice<'s, T>`, enforcing the
406            // original `slice`'s lifetime. Since slices are simply pairs of `(offset, len)`,
407            // transmuting `()` to `T` relies on the `ROOT` being set up above with the correct pointer.
408            return Some(unsafe { mem::transmute(&mut *self.slices) });
409        }
410
411        None
412    }
413}
414
415impl fmt::Debug for SliceCache {
416    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
417        f.debug_list()
418            .entries(self.slices.iter().map(Span::from_slice))
419            .finish()
420    }
421}
422
#[cfg(test)]
mod tests {
    use super::*;

    // Spans produced by `split_at` cover the whole array, so every element is incremented once.
    #[test]
    fn split_at() {
        let mut cache = SliceCache::new(5, |span| {
            let (left, right) = span.split_at(2);
            Box::new([left, right])
        });
        let mut array = [1, 2, 3, 4, 5];

        for slice in cache.access(&mut array).unwrap().iter_mut() {
            for val in slice.iter_mut() {
                *val += 1;
            }
        }

        assert_eq!(array, [2, 3, 4, 5, 6]);
    }

    // Spans produced by `chunks` cover the whole array, including the shorter last chunk.
    #[test]
    fn chunks() {
        let mut cache = SliceCache::new(5, |span| span.chunks(2).collect());
        let mut array = [1, 2, 3, 4, 5];

        for slice in cache.access(&mut array).unwrap().iter_mut() {
            for val in slice.iter_mut() {
                *val += 1;
            }
        }

        assert_eq!(array, [2, 3, 4, 5, 6]);
    }

    // The same cache can be accessed again once the previous `Ref` has been dropped.
    #[test]
    fn ref_twice() {
        let mut cache = SliceCache::new(5, |span| {
            let (left, right) = span.split_at(2);
            Box::new([left, right])
        });
        let mut array = [1, 2, 3, 4, 5];

        for slice in cache.access(&mut array).unwrap().iter_mut() {
            for val in slice.iter_mut() {
                *val += 1;
            }
        }

        for slice in cache.access(&mut array).unwrap().iter_mut() {
            for val in slice.iter_mut() {
                *val += 1;
            }
        }

        assert_eq!(array, [3, 4, 5, 6, 7]);
    }

    // While a `Ref` from one cache is alive, `try_access` on another cache returns `None`.
    #[test]
    fn access_twice() {
        let mut cache0 = SliceCache::new(5, |span| Box::new([span]));
        let mut cache1 = SliceCache::new(5, |span| Box::new([span]));

        let mut array0 = [1, 2, 3, 4, 5];
        let mut array1 = [1, 2, 3, 4, 5];

        let _slices = cache0.access(&mut array0).unwrap();

        assert!(matches!(cache1.try_access(&mut array1), None));
    }

    // The commented-out tests below are compile-fail checks: each one is expected to be rejected
    // by the compiler when uncommented on its own (see the note at the top of this file).

    // #[test]
    // fn fails_due_to_too_short_lifetime() {
    //     let mut cache = SliceCache::new(16, |span| Box::new([span]));

    //     let slices = {
    //         let mut buffer = [0u8; 16];

    //         let slices = cache.access(&mut buffer).unwrap();
    //         let slice = &mut *slices[0];

    //         slice
    //     };

    //     &slices[0];
    // }

    // #[test]
    // fn fails_due_to_mixed_spans() {
    //     SliceCache::new(16, |span0| {
    //         let (left, right) = span0.split_at(2);

    //         SliceCache::new(4, |span1| {
    //             Box::new([left])
    //         });

    //         Box::new([right])
    //     });
    // }

    // #[test]
    // fn fails_due_to_t_not_being_send() {
    //     use std::rc::Rc;

    //     use rayon::prelude::*;

    //     let mut array = [Rc::new(1), Rc::new(2), Rc::new(3)];

    //     let mut cache = SliceCache::new(3, |span| {
    //         let (left, right) = span.split_at(1);
    //         Box::new([left, right])
    //     });

    //     cache.access(&mut array).unwrap().par_iter_mut().for_each(|slice| {
    //         for val in slice.iter_mut() {
    //             *val += 1;
    //         }
    //     });
    // }

    // #[test]
    // fn fails_to_export_span() {
    //     let mut leaked = None;

    //     let mut cache0 = SliceCache::new(1, |span| {
    //         leaked = Some(span);
    //         Box::new([])
    //     });

    //     let mut cache1 = SliceCache::new(1, |span| {
    //         Box::new([leaked.take().unwrap()])
    //     });
    // }

    // #[test]
    // fn fails_due_to_dropped_slice() {
    //     let mut array = [1, 2, 3];

    //     let mut cache = SliceCache::new(3, |span| {
    //         let (left, right) = span.split_at(1);
    //         Box::new([left, right])
    //     });

    //     let slices = cache.access(&mut array).unwrap();

    //     std::mem::drop(array);

    //     slices[0][0] = 0;
    // }
}