disjoint_borrow/
lib.rs

1#![no_std]
2
3//! Disjoint borrows of slices.
4//!
5//! Provides the [`DisjointSlice`](struct.DisjointSlice.html) type, allowing disjoint borrows over
6//! slices by adding runtime checks. Immutable borrows are allowed to intersect with other immutable
7//! borrows, while mutable borrows may not intersect with any borrows.
8//!
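//! Borrow tracking is implemented as a type-level list. This has the advantage that no allocation
//! is necessary, but it also means the number of simultaneously live borrows is fixed at compile
//! time: every borrow adds a layer to the `DisjointSlice` type, so that number cannot depend on
//! runtime values.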
11//!
12//! No-std compatible.
13//!
14//! # Example
15//!
16//! ```
17//! use disjoint_borrow::DisjointSlice;
18//!
19//! let mut array = [1, 2, 3, 4, 5];
20//! let mut ds = DisjointSlice::new(&mut array);
21//! let (mut ds, mut a) = ds.get_mut(0..2);
22//! let (_, mut b) = ds.get_mut(3..5);
23//!
24//! a[0] *= -1;
25//! b[1] *= -1;
26//!
27//! assert_eq!(a, &[-1, 2]);
28//! assert_eq!(b, &[4, -5]);
29//! ```
30
31use core::{
32    fmt,
33    marker::PhantomData,
34    ops::{Bound, Range, RangeBounds},
35};
36
/// A set of ranges.
///
/// This trait should generally not be implemented outside the `disjoint_borrow` crate, but may be
/// referenced in order to be generic over the number of disjoint borrows.
///
/// # Safety
///
/// `DisjointSlice` hands out a borrow only when the relevant sets report no intersection with the
/// requested range. An implementation must therefore return `true` from `intersects` whenever
/// `range` intersects any range in the set; a false negative can produce aliasing mutable
/// references, which is undefined behavior.
pub unsafe trait RangeSet {
    /// Returns `true` if `range` intersects any range in the set.
    #[doc(hidden)]
    fn intersects(&self, range: &Range<usize>) -> bool;

    /// Writes the ranges in the set as a `", "`-separated list.
    ///
    /// `first` is `true` when nothing has been written yet, so no leading separator is needed.
    /// The default writes nothing, which suits a set with no ranges.
    #[doc(hidden)]
    fn fmt(&self, _formatter: &mut fmt::Formatter, _first: bool) -> Result<(), fmt::Error> {
        Ok(())
    }
}
50
51unsafe impl<T: RangeSet> RangeSet for &T {
52    #[inline]
53    fn intersects(&self, range: &Range<usize>) -> bool {
54        (*self).intersects(range)
55    }
56
57    #[inline]
58    fn fmt(&self, formatter: &mut fmt::Formatter, first: bool) -> Result<(), fmt::Error> {
59        (*self).fmt(formatter, first)
60    }
61}
62
63unsafe impl RangeSet for () {
64    #[inline]
65    fn intersects(&self, _: &Range<usize>) -> bool {
66        false
67    }
68}
69
/// A "link" in the borrow chain. Every time a new range is borrowed, a new link is added.
#[derive(Clone, Debug)]
pub struct RangeLink<T> {
    /// The range recorded by this link.
    range: Range<usize>,
    /// The remainder of the chain (within this crate, a reference to the previous set of
    /// borrows).
    next: T,
}
76
77unsafe impl<T> RangeSet for RangeLink<T>
78where
79    T: RangeSet,
80{
81    #[inline]
82    fn intersects(&self, range: &Range<usize>) -> bool {
83        (self.range.end > range.start && self.range.start < range.end)
84            || self.next.intersects(range)
85    }
86
87    #[inline]
88    fn fmt(&self, formatter: &mut fmt::Formatter, first: bool) -> Result<(), fmt::Error> {
89        if !first {
90            formatter.write_str(", ")?;
91        }
92
93        fmt::Debug::fmt(&self.range, formatter)?;
94        self.next.fmt(formatter, false)
95    }
96}
97
/// An error when retrieving a slice.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Error {
    /// An error caused by indices being out of range for the given slice.
    InvalidIndex,

    /// An error caused by two intersecting non-compatible borrows.
    BorrowIntersection,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        f.write_str(match *self {
            Error::InvalidIndex => "index out of range for slice",
            Error::BorrowIntersection => "range intersects an incompatible borrow",
        })
    }
}
107
/// A slice that allows disjoint borrows over its elements. Mutable borrows may not intersect any
/// other borrows, but immutable borrows may intersect other immutable borrows.
///
/// Functions that borrow slices return a tuple. Its first element is a new `DisjointSlice` and its
/// second element is the borrowed slice. The returned `DisjointSlice` can be used to borrow
/// further slices.
///
/// See the crate documentation for more information.
pub struct DisjointSlice<'a, T, Borrows, MutBorrows> {
    /// Pointer to the first element of the underlying slice.
    ptr: *mut T,
    /// Number of elements in the underlying slice.
    len: usize,
    /// Ranges currently borrowed immutably.
    borrows: Borrows,
    /// Ranges currently borrowed mutably.
    borrows_mut: MutBorrows,
    /// Ties `ptr` to the exclusive borrow of the underlying slice for `'a`.
    phantom: PhantomData<&'a mut T>,
}
123
124impl<'a, T> DisjointSlice<'a, T, (), ()> {
125    /// Creates a new `DistjointSlice` from a mutable slice.
126    pub fn new(slice: &'a mut [T]) -> Self {
127        DisjointSlice {
128            ptr: slice.as_mut_ptr(),
129            len: slice.len(),
130            borrows: (),
131            borrows_mut: (),
132            phantom: PhantomData,
133        }
134    }
135}
136
137impl<'a, T, Borrows, MutBorrows> DisjointSlice<'a, T, Borrows, MutBorrows>
138where
139    Borrows: RangeSet,
140    MutBorrows: RangeSet,
141{
142    /// Gets the length of the underlying slice.
143    pub fn len(&self) -> usize {
144        self.len
145    }
146
147    /// Returns `true` if the underlying slice is empty.
148    pub fn is_empty(&self) -> bool {
149        self.len == 0
150    }
151
152    /// Rertrieves an immutable subslice, panicking if the range is out of range of the slice or it
153    /// intersects with any other mutable borrowed slices.
154    #[inline]
155    pub fn get<'b, R>(
156        &'b mut self,
157        index: R,
158    ) -> (
159        DisjointSlice<'b, T, RangeLink<&'b Borrows>, &'b MutBorrows>,
160        &'b [T],
161    )
162    where
163        R: RangeBounds<usize>,
164    {
165        let range = self.range(index);
166
167        assert!(range.start <= self.len, "Range start out of range");
168        assert!(range.end <= self.len, "Range end out of range");
169        assert!(
170            !self.borrows_mut.intersects(&range),
171            "Range intersects with mutable borrows"
172        );
173
174        let len = range.end.saturating_sub(range.start);
175
176        let slice = unsafe { core::slice::from_raw_parts(self.ptr.add(range.start), len) };
177
178        (
179            DisjointSlice {
180                ptr: self.ptr,
181                len: self.len,
182                borrows: RangeLink {
183                    range,
184                    next: &self.borrows,
185                },
186                borrows_mut: &self.borrows_mut,
187                phantom: self.phantom,
188            },
189            slice,
190        )
191    }
192
193    /// Retrieves an immutable subslice, returning `Ok` if successfull or `Err` if the range
194    /// is out of range of the slice or intersects with a mutable borrow.
195    #[inline]
196    pub fn try_get<'b, R>(
197        &'b mut self,
198        range: R,
199    ) -> Result<
200        (
201            DisjointSlice<'b, T, RangeLink<&'b Borrows>, &'b MutBorrows>,
202            &'b [T],
203        ),
204        Error,
205    >
206    where
207        R: RangeBounds<usize>,
208    {
209        self.try_get_range(self.range(range))
210    }
211
212    #[inline]
213    fn try_get_range<'b>(
214        &'b mut self,
215        range: Range<usize>,
216    ) -> Result<
217        (
218            DisjointSlice<'b, T, RangeLink<&'b Borrows>, &'b MutBorrows>,
219            &'b [T],
220        ),
221        Error,
222    > {
223        if range.start > self.len {
224            return Err(Error::InvalidIndex);
225        }
226
227        if range.end > self.len {
228            return Err(Error::InvalidIndex);
229        }
230
231        if self.borrows_mut.intersects(&range) {
232            return Err(Error::BorrowIntersection);
233        }
234
235        let len = range.end.saturating_sub(range.start);
236
237        let slice = unsafe { core::slice::from_raw_parts(self.ptr.add(range.start), len) };
238
239        Ok((
240            DisjointSlice {
241                ptr: self.ptr,
242                len: self.len,
243                borrows: RangeLink {
244                    range,
245                    next: &self.borrows,
246                },
247                borrows_mut: &self.borrows_mut,
248                phantom: self.phantom,
249            },
250            slice,
251        ))
252    }
253
254    /// Rertrieves a mutable subslice, panicking if the range is out of range of the slice or it
255    /// intersects with any other immuitable or mutable borrowed slices.
256    #[inline]
257    pub fn get_mut<'b, R>(
258        &'b mut self,
259        index: R,
260    ) -> (
261        DisjointSlice<'b, T, &'b Borrows, RangeLink<&'b MutBorrows>>,
262        &'b mut [T],
263    )
264    where
265        R: RangeBounds<usize>,
266    {
267        let range = self.range(index);
268
269        assert!(range.start <= self.len, "Range start out of range");
270        assert!(range.end <= self.len, "Range end out of range");
271        assert!(
272            !self.borrows.intersects(&range),
273            "Range intersects with immutable borrows"
274        );
275        assert!(
276            !self.borrows_mut.intersects(&range),
277            "Range intersects with mutable borrows"
278        );
279
280        let len = range.end.saturating_sub(range.start);
281
282        let slice = unsafe { core::slice::from_raw_parts_mut(self.ptr.add(range.start), len) };
283
284        (
285            DisjointSlice {
286                ptr: self.ptr,
287                len: self.len,
288                borrows: &self.borrows,
289                borrows_mut: RangeLink {
290                    range,
291                    next: &self.borrows_mut,
292                },
293                phantom: self.phantom,
294            },
295            slice,
296        )
297    }
298
299    /// Retrieves an immutable subslice, returning `Ok` if successfull or `Err` if the range
300    /// is out of range of the slice or intersects with any other immutable or mutable borrow.
301    #[inline]
302    pub fn try_get_mut<'b, R>(
303        &'b mut self,
304        range: R,
305    ) -> Result<
306        (
307            DisjointSlice<'b, T, &'b Borrows, RangeLink<&'b MutBorrows>>,
308            &'b mut [T],
309        ),
310        Error,
311    >
312    where
313        R: RangeBounds<usize>,
314    {
315        self.try_get_range_mut(self.range(range))
316    }
317
318    #[inline]
319    fn try_get_range_mut<'b>(
320        &'b mut self,
321        range: Range<usize>,
322    ) -> Result<
323        (
324            DisjointSlice<'b, T, &'b Borrows, RangeLink<&'b MutBorrows>>,
325            &'b mut [T],
326        ),
327        Error,
328    > {
329        if range.start > self.len {
330            return Err(Error::InvalidIndex);
331        }
332
333        if range.end > self.len {
334            return Err(Error::InvalidIndex);
335        }
336
337        if self.borrows.intersects(&range) {
338            return Err(Error::BorrowIntersection);
339        }
340
341        if self.borrows_mut.intersects(&range) {
342            return Err(Error::BorrowIntersection);
343        }
344
345        let len = range.end.saturating_sub(range.start);
346
347        let slice = unsafe { core::slice::from_raw_parts_mut(self.ptr.add(range.start), len) };
348
349        Ok((
350            DisjointSlice {
351                ptr: self.ptr,
352                len: self.len,
353                borrows: &self.borrows,
354                borrows_mut: RangeLink {
355                    range,
356                    next: &self.borrows_mut,
357                },
358                phantom: self.phantom,
359            },
360            slice,
361        ))
362    }
363
364    #[inline]
365    fn range<R>(&self, range: R) -> Range<usize>
366    where
367        R: RangeBounds<usize>,
368    {
369        let lo = match range.start_bound() {
370            Bound::Included(&n) => n,
371            Bound::Excluded(&n) => n + 1,
372            Bound::Unbounded => 0,
373        };
374
375        let hi = match range.end_bound() {
376            Bound::Included(&n) => n + 1,
377            Bound::Excluded(&n) => n,
378            Bound::Unbounded => self.len,
379        };
380
381        (lo..hi)
382    }
383}
384
385impl<'a, T, Borrows, MutBorrows> fmt::Debug for DisjointSlice<'a, T, Borrows, MutBorrows>
386where
387    Borrows: RangeSet,
388    MutBorrows: RangeSet,
389{
390    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
391        write!(f, "DisjointSlice{{ borrows: [")?;
392        self.borrows.fmt(f, true)?;
393        write!(f, "], borrows_mut: [")?;
394        self.borrows_mut.fmt(f, true)?;
395        write!(f, "]}}")
396    }
397}
398
399unsafe impl<'a, T, Borrows, MutBorrows> Send for DisjointSlice<'a, T, Borrows, MutBorrows> where
400    T: Send
401{
402}
403
404unsafe impl<'a, T, Borrows, MutBorrows> Sync for DisjointSlice<'a, T, Borrows, MutBorrows> where
405    T: Sync
406{
407}
408
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic() {
        let mut array = [1, 2, 3, 4, 5];
        let mut ds = DisjointSlice::new(&mut array[..]);
        let (mut ds, a) = ds.get(0..1);
        let (mut ds, b) = ds.get(1..2);
        let (mut ds, c) = ds.get_mut(2..3);
        let (_, d) = ds.get_mut(3..5);

        assert_eq!(a, &[1]);
        assert_eq!(b, &[2]);
        assert_eq!(c, &[3]);
        assert_eq!(d, &[4, 5]);
    }

    #[test]
    fn slice() {
        let mut array = [1, 2, 3, 4, 5];
        let mut sl = DisjointSlice::new(&mut array[..]);
        let (mut sl, a) = sl.get_mut(..3);
        let (_, b) = sl.get_mut(3..);

        assert_eq!(a, &[1, 2, 3]);
        assert_eq!(b, &[4, 5]);
    }

    #[test]
    #[should_panic]
    fn get_after_mut() {
        let mut array = [1, 2, 3, 4, 5];
        let mut sl = DisjointSlice::new(&mut array[..]);
        let (mut sl, _) = sl.get_mut(..3);
        sl.get(2..);
    }

    #[test]
    fn scoped() {
        let mut array = [1, 2, 3, 4, 5];
        let mut sl = DisjointSlice::new(&mut array[..]);
        let (mut sl, a) = sl.get_mut(..3);

        {
            sl.get_mut(3..);
        }

        let (_, b) = sl.get_mut(3..);

        assert_eq!(a, &[1, 2, 3]);
        assert_eq!(b, &[4, 5]);
    }

    #[test]
    #[should_panic]
    fn out_of_range() {
        let mut array = [1, 2, 3, 4, 5];
        let mut sl = DisjointSlice::new(&mut array[..]);
        sl.get(6..10);
    }

    #[test]
    fn backward_range() {
        let mut array = [1, 2, 3, 4, 5];
        let mut sl = DisjointSlice::new(&mut array[..]);
        let (_, a) = sl.get(4..3);
        assert_eq!(a, &[]);
    }

    #[test]
    fn one_past_end() {
        let mut array = [1, 2, 3, 4, 5];
        let mut sl = DisjointSlice::new(&mut array[..]);
        let (_, a) = sl.get(5..5);
        assert_eq!(a, &[]);
    }

    /// Edge case: does an empty range with start inside another range count as an intersection?
    /// Currently, we say that it does.
    #[test]
    #[should_panic]
    fn empty_intersect() {
        let mut array = [1, 2, 3];
        let mut ds = DisjointSlice::new(&mut array[..]);
        let (mut ds, _) = ds.get(0..3);
        ds.get_mut(1..1);
    }

    #[test]
    #[should_panic]
    fn backward_out_of_range() {
        let mut array = [1, 2, 3, 4, 5];
        let mut sl = DisjointSlice::new(&mut array[..]);
        sl.get(10..2);
    }

    #[test]
    fn try_get_out_of_range() {
        let mut array = [1, 2, 3, 4, 5];
        let mut ds = DisjointSlice::new(&mut array[..]);
        assert_eq!(ds.try_get(6..10).unwrap_err(), Error::InvalidIndex);
    }

    #[test]
    fn try_get_intersect() {
        let mut array = [1, 2, 3, 4, 5];
        let mut ds = DisjointSlice::new(&mut array[..]);
        let (mut ds, _) = ds.get_mut(1..3);
        assert_eq!(ds.try_get(0..2).unwrap_err(), Error::BorrowIntersection);
    }

    #[test]
    fn try_get_mut_out_of_range() {
        let mut array = [1, 2, 3, 4, 5];
        let mut ds = DisjointSlice::new(&mut array[..]);
        assert_eq!(ds.try_get_mut(6..10).unwrap_err(), Error::InvalidIndex);
    }

    #[test]
    fn try_get_mut_intersect() {
        let mut array = [1, 2, 3, 4, 5];
        let mut ds = DisjointSlice::new(&mut array[..]);
        let (mut ds, _) = ds.get(1..3);
        assert_eq!(ds.try_get_mut(0..2).unwrap_err(), Error::BorrowIntersection);
    }

    /// Immutable borrows may overlap each other (the crate's headline feature), including when
    /// requested with an inclusive range.
    #[test]
    fn overlapping_immutable() {
        let mut array = [1, 2, 3, 4, 5];
        let mut ds = DisjointSlice::new(&mut array[..]);
        let (mut ds, a) = ds.get(0..3);
        let (_, b) = ds.get(1..=3);

        assert_eq!(a, &[1, 2, 3]);
        assert_eq!(b, &[2, 3, 4]);
    }

    #[test]
    #[should_panic]
    fn get_mut_after_get() {
        let mut array = [1, 2, 3, 4, 5];
        let mut sl = DisjointSlice::new(&mut array[..]);
        let (mut sl, _) = sl.get(..3);
        sl.get_mut(2..);
    }

    #[test]
    fn try_get_success() {
        let mut array = [1, 2, 3, 4, 5];
        let mut ds = DisjointSlice::new(&mut array[..]);
        let (mut ds, a) = ds.try_get(..2).unwrap();
        let (_, b) = ds.try_get_mut(2..).unwrap();
        b[0] = 30;

        assert_eq!(a, &[1, 2]);
        assert_eq!(b, &[30, 4, 5]);
    }

    #[test]
    fn len_and_is_empty() {
        let mut empty: [i32; 0] = [];
        let ds = DisjointSlice::new(&mut empty[..]);
        assert_eq!(ds.len(), 0);
        assert!(ds.is_empty());

        let mut array = [1, 2, 3];
        let ds = DisjointSlice::new(&mut array[..]);
        assert_eq!(ds.len(), 3);
        assert!(!ds.is_empty());
    }
}