Skip to main content

poulpy_hal/layouts/
scratch.rs

1use std::{marker::PhantomData, ptr::NonNull};
2
3use crate::layouts::Backend;
4
5/// Owned scratch buffer for temporary workspace during polynomial operations.
6///
7/// Operations such as normalization, DFT, and vector-matrix products require
8/// temporary scratch memory. `ScratchOwned` holds a backend-owned buffer that
9/// can be borrowed as a [`ScratchArena`].
10///
11/// The required size for each operation is obtained via the corresponding
12/// `*_tmp_bytes` method on the API trait (e.g.
13/// [`VecZnxNormalizeTmpBytes`](crate::api::VecZnxNormalizeTmpBytes)).
14#[repr(C)]
15pub struct ScratchOwned<B: Backend> {
16    pub data: B::OwnedBuf,
17    pub _phantom: PhantomData<B>,
18}
19
20/// Backend-native scratch arena borrowed from a [`ScratchOwned`].
21///
22/// This arena keeps backend ownership explicit and carves typed temporaries
23/// using the backend's native borrowed buffer view (`B::BufMut<'a>`).
24pub struct ScratchArena<'a, B: Backend> {
25    data: NonNull<B::OwnedBuf>,
26    start: usize,
27    end: usize,
28    _phantom: PhantomData<&'a mut B::OwnedBuf>,
29}
30
31impl<B: Backend> ScratchOwned<B> {
32    /// Borrows this owned scratch buffer as a backend-native arena.
33    pub fn arena(&mut self) -> ScratchArena<'_, B> {
34        ScratchArena {
35            data: NonNull::from(&mut self.data),
36            start: 0,
37            end: B::len_bytes(&self.data),
38            _phantom: PhantomData,
39        }
40    }
41}
42
43impl<'a, B: Backend> ScratchArena<'a, B> {
44    /// Reborrows this arena with a shorter lifetime.
45    pub fn borrow<'b>(&'b mut self) -> ScratchArena<'b, B> {
46        ScratchArena {
47            data: self.data,
48            start: self.start,
49            end: self.end,
50            _phantom: PhantomData,
51        }
52    }
53
54    /// Runs `f` with a shorter-lived reborrow of this arena.
55    ///
56    /// This is useful for nested workspace use where the borrowed arena
57    /// must not leak into the outer function's scratch lifetime.
58    pub fn scope<R>(&mut self, f: impl for<'b> FnOnce(ScratchArena<'b, B>) -> R) -> R {
59        f(self.borrow())
60    }
61
62    /// Applies `f` to this arena through a temporary mutable borrow and returns the advanced arena.
63    ///
64    /// This is useful while migrating callers that still thread scratch by value around newer
65    /// `&mut ScratchArena` APIs.
66    pub fn apply_mut(mut self, f: impl FnOnce(&mut ScratchArena<'a, B>)) -> Self {
67        f(&mut self);
68        self
69    }
70
71    /// Runs `f` on a shorter-lived owned reborrow and commits the returned remainder.
72    ///
73    /// This is useful while migrating APIs from arena-by-value to `&mut ScratchArena`:
74    /// existing helpers can keep their `(result, remainder)` style internally, while the
75    /// outer mutable arena advances to the returned remainder.
76    pub fn consume<R>(&mut self, f: impl for<'b> FnOnce(ScratchArena<'b, B>) -> (R, ScratchArena<'b, B>)) -> R {
77        let arena = ScratchArena {
78            data: self.data,
79            start: self.start,
80            end: self.end,
81            _phantom: PhantomData,
82        };
83        let (res, rem) = f(arena);
84        self.start = rem.start;
85        self.end = rem.end;
86        res
87    }
88    /// Returns the number of aligned bytes that can still be carved out.
89    pub fn available(&self) -> usize {
90        self.end.saturating_sub(align_up::<B>(self.start))
91    }
92
93    /// Splits off `len` aligned bytes from the front of this arena.
94    pub fn split_at(self, len: usize) -> (Self, Self) {
95        let start: usize = align_up::<B>(self.start);
96        let mid: usize = start.checked_add(len).expect("scratch arena split overflow");
97        assert!(
98            mid <= self.end,
99            "Attempted to take {len} from scratch arena with {} aligned bytes left",
100            self.available()
101        );
102        (
103            Self {
104                data: self.data,
105                start,
106                end: mid,
107                _phantom: PhantomData,
108            },
109            Self {
110                data: self.data,
111                start: mid,
112                end: self.end,
113                _phantom: PhantomData,
114            },
115        )
116    }
117
118    /// Splits this arena into `n` disjoint aligned chunks of `len` bytes each.
119    pub fn split(self, n: usize, len: usize) -> (Vec<Self>, Self) {
120        assert!(self.available() >= n * len);
121        let mut arenas: Vec<Self> = Vec::with_capacity(n);
122        let mut arena: Self = self;
123        for _ in 0..n {
124            let (taken, rem) = arena.split_at(len);
125            arena = rem;
126            arenas.push(taken);
127        }
128        (arenas, arena)
129    }
130
131    /// Takes a backend-native mutable region of `len` bytes.
132    pub fn take_region(self, len: usize) -> (B::BufMut<'a>, Self) {
133        let start: usize = align_up::<B>(self.start);
134        let end: usize = start.checked_add(len).expect("scratch arena take overflow");
135        assert!(
136            end <= self.end,
137            "Attempted to take {len} from scratch arena with {} aligned bytes left",
138            self.available()
139        );
140
141        let data: &mut B::OwnedBuf = unsafe {
142            // Safety: `self.data` originates from `ScratchOwned::arena`, which ties
143            // the raw pointer to the lifetime `'a`. Each new arena produced from this
144            // value advances or splits the byte range, so callers can only obtain
145            // disjoint mutable regions from the same backing buffer.
146            self.data.as_ptr().as_mut().expect("scratch arena owner pointer is null")
147        };
148        let region: B::BufMut<'a> = B::region_mut(data, start, len);
149        (
150            region,
151            Self {
152                data: self.data,
153                start: end,
154                end: self.end,
155                _phantom: PhantomData,
156            },
157        )
158    }
159}
160
161#[inline]
162fn align_up<B: Backend>(offset: usize) -> usize {
163    offset.next_multiple_of(B::SCRATCH_ALIGN)
164}