qudit_core/memory.rs
1//! Memory management types and helper functions for the Openqudit library.
2
3use std::mem::size_of;
4
5use aligned_vec::AVec;
6use aligned_vec::CACHELINE_ALIGN;
7use aligned_vec::avec;
8use bytemuck::Zeroable;
9
10/// A trait for types that can be stored in memory buffers.
11pub trait Memorable: Sized + Zeroable + Copy {}
12
13impl<T: Sized + Zeroable + Copy> Memorable for T {}
14
/// An aligned memory buffer.
///
/// The memory buffer is used to store the data for matrices, vectors,
/// and other data structures. This type aliases a single aligned vector
/// ([`AVec`]) whose elements are of type `C`.
///
/// # Type Parameters
///
/// * `C`: The element type stored in the memory buffer.
///
/// # See Also
///
/// - [alloc_zeroed_memory] function to allocate a new memory buffer.
/// - [Memorable] trait for more information on trait bounds.
// The compiler does not enforce bounds written on a type alias and lints
// against them; the lint is silenced so the bound can stay (presumably as
// documentation of the intended element types).
#[allow(type_alias_bounds)]
pub type MemoryBuffer<C: Memorable> = AVec<C>;
34
/// A pointer to an element in a memory buffer.
///
/// The pointer is used to access elements in a memory buffer. This type
/// aliases a raw const pointer (`*const C`) to a single element of type `C`.
///
/// # Type Parameters
///
/// * `C`: The element type stored in the memory buffer.
///
/// # See Also
///
/// - [MemoryBuffer] type alias for more information on memory buffers.
/// - [Memorable] trait for more information on trait bounds.
/// - [MemoryPointerMut] type alias for more information on mutable pointers.
/// - [MemoryPointerNonNull] type alias for more information on non-null pointers.
// The compiler does not enforce bounds written on a type alias and lints
// against them; the lint is silenced so the bound can stay (presumably as
// documentation of the intended element types).
#[allow(type_alias_bounds)]
pub type MemoryPointer<C: Memorable> = *const C;
55
/// A mutable pointer to an element in a memory buffer.
///
/// The mutable pointer is used to access or modify elements in a memory
/// buffer. This type aliases a raw mutable pointer (`*mut C`) to a single
/// element of type `C`.
///
/// # Type Parameters
///
/// * `C`: The element type stored in the memory buffer. The alias itself
///   places no bound on `C`.
///
/// # See Also
///
/// - [MemoryBuffer] type alias for more information on memory buffers.
/// - [Memorable] trait for the bounds placed on buffer element types.
/// - [MemoryPointer] type alias for more information on const pointers.
/// - [MemoryPointerNonNull] type alias for more information on non-null pointers.
pub type MemoryPointerMut<C> = *mut C;
73
/// A non-null pointer to an element in a memory buffer.
///
/// This type aliases [`core::ptr::NonNull`] over a single element of type `C`.
///
/// # Type Parameters
///
/// * `C`: The element type stored in the memory buffer. The alias itself
///   places no bound on `C`.
///
/// # See Also
///
/// - [MemoryBuffer] type alias for more information on memory buffers.
/// - [Memorable] trait for the bounds placed on buffer element types.
/// - [MemoryPointer] type alias for more information on const pointers.
/// - [MemoryPointerMut] type alias for more information on mutable pointers.
pub type MemoryPointerNonNull<C> = core::ptr::NonNull<C>;
88
89/// Allocate a new memory buffer with the given size.
90///
91/// The size is the number of elements to allocate in the memory buffer,
92/// not the number of bytes. The memory buffer is aligned to cachelines,
93/// and the elements will be zeroed.
94///
95/// # Type Parameters
96///
97/// * `C`: The faer entity type for the memory buffer.
98///
99/// # Arguments
100///
101/// * `size` - The number of elements to allocate in the memory buffer.
102///
103/// # Returns
104///
105/// A new memory buffer with the given size.
106///
107/// # Panics
108///
109/// This function will panic if the memory requirement overflows isize.
110///
111/// # Example
112///
113/// ```
114/// use qudit_core::memory::alloc_zeroed_memory;
115/// let size = 10;
116/// let mem = alloc_zeroed_memory::<f32>(size);
117/// assert_eq!(mem.len(), size);
118/// ```
119///
120/// # See Also
121///
122/// - [MemoryBuffer] type alias for more information on memory buffers.
123pub fn alloc_zeroed_memory<C: Memorable>(size: usize) -> MemoryBuffer<C> {
124 let mem_size = size
125 .checked_mul(size_of::<C>())
126 .expect("Memory size overflows usize");
127
128 if mem_size > isize::MAX as usize {
129 panic!("Memory size overflows isize");
130 }
131
132 avec![<C as Zeroable>::zeroed(); size]
133}
134
135/// Calculate column stride for a matrix with given rows and columns.
136///
137/// The column stride is the number of elements between the start of each column
138/// in the matrix. Faer, the library qudit-core builds on, uses column-major
139/// storage for matrices, so the column stride is the major stride for matrices.
140/// Extra padding is added to columns to ensure that columns are aligned to
141/// cachelines.
142///
143/// # Type Parameters
144///
145/// * `C`: The faer entity type for the matrix.
146///
147/// # Arguments
148///
149/// * `nrows` - The number of rows in the matrix.
150///
151/// * `ncols` - The number of columns in the matrix.
152///
153/// # Returns
154///
155/// The column stride for the matrix.
156///
157/// # Panics
158///
159/// This function will panic if any of the arithmetic operations overflow.
160/// This should only happen for extremely large matrices.
161///
162/// # Example
163///
164/// ```
165/// use qudit_core::memory::calc_col_stride;
166/// let nrows = 3;
167/// let ncols = 4;
168/// let col_stride = calc_col_stride::<f32>(nrows, ncols);
169/// assert_eq!(col_stride, 3);
170///
171/// let nrows = 14;
172/// let ncols = 40;
173/// let col_stride = calc_col_stride::<f32>(nrows, ncols);
174/// assert_eq!(col_stride, 16);
175/// ```
176///
177/// # See Also
178///
179/// - [calc_mat_stride] function to calculate the matrix stride.
180///
181/// # Notes
182///
183/// This function assumes that the first element starts at the beginning of
184/// a cacheline but does not guarantee that matrices will end at a cacheline
185/// boundary. If multiple matrices are stored in a single memory buffer, the
186/// memory buffer should be aligned to cachelines, and you should calculate
187/// a separate stride for the matrix dimension. See [calc_mat_stride] for
188/// more information.
189///
190/// Additionally, this function returns a usize, while pointer arithmetic
191/// is performed with isize types. Extra care should be taken to avoid
192/// overflow when using the result of this function in pointer arithmetic.
193///
194/// This function always assumes the row stride is 1. If the row stride is
195/// not 1, the column stride will be incorrect.
196pub fn calc_col_stride<C>(nrows: usize, ncols: usize) -> usize {
197 if nrows == 0 || ncols == 0 {
198 return 0;
199 }
200
201 // let unit_size = size_of::<C::Unit>();
202 let unit_size = size_of::<C>();
203
204 if unit_size == 0 {
205 return 0;
206 }
207
208 if nrows == 1 {
209 return 1;
210 }
211
212 if ncols == 1 {
213 return nrows;
214 }
215
216 if unit_size > CACHELINE_ALIGN {
217 // This shouldn't happen for any reasonable type. If it does, we
218 // can't do anything about it, since the following code won't work:
219 // ```
220 // let col_size = nrows * unit_size;
221 // let remainder = CACHELINE_ALIGN - (col_size % CACHELINE_ALIGN);
222 // return (col_size + remainder) / unit_size; // Due to this division
223 // ```
224 // It's guaranteed that col_size + remainder is not divisible by
225 // unit_size, and appending empty cache lines seems extremely
226 // wasteful (in impl effort and computation).
227 return nrows;
228 }
229
230 let units_per_cache_line = CACHELINE_ALIGN / unit_size;
231 let mat_size = nrows
232 .checked_mul(ncols)
233 .expect("Matrix size overflows usize");
234
235 if mat_size <= units_per_cache_line {
236 // TODO: Pad for SIMD? is it necessary? yes it is
237 // If simd gets compiled into binary, then use a const with cfg
238 // otherwise do runtime checks
239 return nrows;
240 }
241
242 if nrows > units_per_cache_line {
243 let remainder = units_per_cache_line - (nrows % units_per_cache_line);
244 return nrows
245 .checked_add(remainder)
246 .expect("Column stride overflows usize");
247 }
248
249 // We now have the following:
250 // - ncols > 1
251 // - 1 < nrows < units_per_cache_line
252 // - nrows * ncols > units_per_cache_line
253 //
254 // This means that we can potentially fit multiple columns in a single
255 // cache line. We now want to find the number of columns that can fit
256 // in a single cache line, and then pad the columns to ensure that they
257 // are aligned to cachelines.
258 //
259 // We start with an initial packed guess:
260 let mut ncols_per_line = units_per_cache_line / nrows;
261
262 // We need the padding to be consistent across all columns, so we
263 // continue to reduce ncols_per_line until the padding can be made
264 // consistent. This happens when the leftover space in the cache line
265 // (units_per_cache_line - cols_per_line * nrows) can be evenely split
266 // into ncols_per_line pieces. In the worst case, we stop at
267 // ncols_per_line = 1, which is the same as the faer implementation.
268 let mut left_over = units_per_cache_line - (ncols_per_line * nrows);
269 while !left_over.is_multiple_of(ncols_per_line) {
270 ncols_per_line -= 1;
271 left_over = units_per_cache_line - (ncols_per_line * nrows);
272 }
273
274 left_over / ncols_per_line + nrows
275}
276
277/// Calculate matrix stride for a tensor with given rows, columns, and column stride.
278///
279/// The matrix stride is the number of elements between the start of each matrix
280/// in a tensor. Extra padding is added to matrices to ensure that matrices are
281/// aligned to cachelines.
282///
283/// # Type Parameters
284///
285/// * `C`: The faer entity type for the matrix.
286///
287/// # Arguments
288///
289/// * `nrows` - The number of rows in the matrix.
290///
291/// * `ncols` - The number of columns in the matrix.
292///
293/// * `col_stride` - The column stride for the matrix.
294///
295/// # Returns
296///
297/// The matrix stride for the tensor.
298///
299/// # Example
300///
301/// ```
302/// use qudit_core::memory::calc_mat_stride;
303/// let nrows = 3;
304/// let ncols = 4;
305/// let col_stride = 4;
306/// let mat_stride = calc_mat_stride::<f64>(nrows, ncols, col_stride);
307/// assert_eq!(mat_stride, 16);
308/// ```
309///
310/// # See Also
311///
312/// - [calc_col_stride] function to calculate the column stride.
313/// - [MemoryBuffer] type alias for more information on memory buffers.
314/// - [crate::array::Tensor] type for more information on three dimension tensors.
315///
316/// # Notes
317///
318/// This function assumes that the first element starts at the beginning
319/// of a cacheline and will guarantee that matrices will end at a
320/// cacheline boundary.
321///
322/// Additionally, this function returns a usize, while pointer arithmetic
323/// is performed with isize types. Extra care should be taken to avoid
324/// overflow when using the result of this function in pointer arithmetic.
325///
326/// This function always assumes the row stride is 1. If the row stride is
327/// not 1, the column stride will be incorrect.
328pub fn calc_mat_stride<C>(_nrows: usize, ncols: usize, col_stride: usize) -> usize {
329 let packed_mat_size = ncols
330 .checked_mul(col_stride)
331 .expect("Matrix size overflows usize");
332
333 // let unit_size = size_of::<C::Unit>();
334 let unit_size = size_of::<C>();
335
336 if unit_size == 0 {
337 return 0;
338 }
339
340 // TODO: maybe check if all matrices can fit in a single cache line
341 // A U3 of c32 has 4 matrices (one fn, 3 grad) and each one takes
342 // 4 elements (32 bytes). If all packed that could nicely fit into
343 // one 128 byte line or two 64 byte line rather than wasting half
344 // or 3/4 of the space depending on line size. Especially since they
345 // typically used together, performance should be speed up (speculation)
346
347 if unit_size > CACHELINE_ALIGN {
348 // See similar comment in [calc_col_stride].
349 return packed_mat_size;
350 }
351
352 let units_per_cache_line = CACHELINE_ALIGN / unit_size;
353
354 let remainder = units_per_cache_line - (packed_mat_size % units_per_cache_line);
355 if remainder == units_per_cache_line {
356 return packed_mat_size;
357 }
358 packed_mat_size + remainder
359}
360
361/// Calculates a subtensor's stride, given its size. We ensure the subtensor
362/// is aligned to cachelines.
363///
364/// # Arguments
365///
366/// * `packed_subtensor_size` - The number of elements in the packed subtensor.
367///
368/// # Returns
369///
370/// * The subtensor stride, aligned to cachelines.
371///
372/// # Example
373/// ```
374/// use qudit_core::memory::calc_next_stride;
375/// use qudit_core::c64;
376///
377/// let stride = calc_next_stride::<c64>(5);
378/// let expected = 8;
379/// assert_eq!(stride, expected);
380/// ```
381pub fn calc_next_stride<C>(packed_subtensor_size: usize) -> usize {
382 let unit_size = size_of::<C>();
383
384 if unit_size == 0 {
385 return 0;
386 }
387
388 // See similar comment in [calc_col_stride].
389 if unit_size > CACHELINE_ALIGN {
390 return packed_subtensor_size;
391 }
392
393 let units_per_cache_line = CACHELINE_ALIGN / unit_size;
394 let remainder = units_per_cache_line - (packed_subtensor_size % units_per_cache_line);
395
396 // The case where the packed subtensor is already aligned to the cacheline
397 if remainder == units_per_cache_line {
398 return packed_subtensor_size;
399 }
400
401 // Adds padding `remainder` to align with the next cacheline.
402 packed_subtensor_size + remainder
403}
404
/// Unit tests for the stride helpers in this module.
#[cfg(test)]
mod tests {
    use faer::c32;

    use super::*;

    // NOTE(review): this test is disabled and the reason is not recorded
    // here — presumably the doctest on `alloc_zeroed_memory` is considered
    // sufficient; confirm before re-enabling or deleting it.
    // #[test]
    // fn test_alloc_zeroed_memory() {
    //     let size = 10;
    //     let mem = alloc_zeroed_memory::<f32>(size);
    //     assert_eq!(mem.len(), size);
    //     for i in 0..size {
    //         assert_eq!(mem[i], 0.0);
    //     }
    // }

    /// Checks `calc_col_stride` on three shapes: a 3x4 `f32` matrix that is
    /// expected to stay unpadded, a 14x40 `f32` matrix whose columns are
    /// expected to be padded to 16, and a 4x4 `c32` matrix that is expected
    /// to stay unpadded.
    #[test]
    fn test_calc_col_stride() {
        let nrows = 3;
        let ncols = 4;
        let col_stride = calc_col_stride::<f32>(nrows, ncols);
        assert_eq!(col_stride, 3);

        let nrows = 14;
        let ncols = 40;
        let col_stride = calc_col_stride::<f32>(nrows, ncols);
        assert_eq!(col_stride, 16);

        let nrows = 4;
        let ncols = 4;
        let col_stride = calc_col_stride::<c32>(nrows, ncols);
        assert_eq!(col_stride, 4);
    }

    /// Checks `calc_mat_stride` for an `f64` matrix with 4 columns and a
    /// column stride of 4 (16 packed elements), expecting a stride of 16.
    #[test]
    fn test_calc_mat_stride() {
        let nrows = 3;
        let ncols = 4;
        let col_stride = 4;
        let mat_stride = calc_mat_stride::<f64>(nrows, ncols, col_stride);
        assert_eq!(mat_stride, 16);
    }
}