// hpt_common/shape/shape_utils.rs
1use std::panic::Location;
2
3use crate::{
4 error::{base::TensorError, shape::ShapeError},
5 shape::shape::Shape,
6 strides::strides::Strides,
7};
8
9/// Inserts a dimension of size 1 before the specified index in a shape.
10///
11/// The `yield_one_before` function takes an existing shape (a slice of `i64` values) and inserts
12/// a new dimension of size 1 before the specified index `idx`. This is useful in tensor operations
13/// where you need to expand the dimensions of a tensor by adding singleton dimensions, which can
14/// facilitate broadcasting or other dimension-specific operations.
15///
16/// # Parameters
17///
18/// - `shape`: A slice of `i64` representing the original shape of the tensor.
19/// - `idx`: The index before which a new dimension of size 1 will be inserted.
20///
21/// # Returns
22///
23/// - A `Vec<i64>` representing the new shape with the inserted dimension of size 1.
24///
25/// # Examples
26///
27/// ```rust
28/// // Example 1: Insert before the first dimension
29/// let shape = vec![3, 4, 5];
30/// let idx = 0;
31/// let new_shape = yield_one_before(&shape, idx);
32/// assert_eq!(new_shape, vec![1, 3, 4, 5]);
33///
34/// // Example 2: Insert before a middle dimension
35/// let idx = 2;
36/// let new_shape = yield_one_before(&shape, idx);
37/// assert_eq!(new_shape, vec![3, 4, 1, 5]);
38///
39/// // Example 3: Insert before the last dimension
40/// let idx = 2;
41/// let new_shape = yield_one_before(&shape, idx);
42/// assert_eq!(new_shape, vec![3, 4, 1, 5]);
43///
44/// // Example 4: Index out of bounds (appends 1 at the end)
45/// let idx = 5;
46/// let new_shape = yield_one_before(&shape, idx);
47/// assert_eq!(new_shape, vec![3, 4, 5, 1]);
48/// ```
49///
50/// # Notes
51///
52/// - **Index Bounds**: If `idx` is greater than the length of `shape`, the function will append a
53/// dimension of size 1 at the end of the shape.
54/// - **Use Cases**: Adding a singleton dimension is often used to adjust the shape of a tensor for
55/// broadcasting in element-wise operations or to match required input dimensions for certain
56/// functions.
57/// - **Immutability**: The original `shape` slice is not modified; a new `Vec<i64>` is returned.
58///
59/// # Implementation Details
60///
61/// The function works by iterating over the original shape and copying each dimension into a new
62/// vector. When the current index matches `idx`, it inserts a `1` before copying the next dimension.
63///
64/// # See Also
65///
66/// ```rust
67/// fn yield_one_after(shape: &[i64], idx: usize) -> Vec<i64>
68/// ```
pub fn yield_one_before(shape: &[i64], idx: usize) -> Vec<i64> {
    let mut new_shape = Vec::with_capacity(shape.len() + 1);
    for (i, s) in shape.iter().enumerate() {
        // Emit the singleton dimension right before the dimension at `idx`.
        if i == idx {
            new_shape.push(1);
        }
        new_shape.push(*s);
    }
    // Out-of-range `idx` (>= shape.len()) appends the singleton at the end,
    // matching the documented behavior above. (Previously only the exact
    // `idx == shape.len()` case appended; any larger index silently dropped
    // the 1, contradicting Example 4 in the doc comment.)
    if idx >= shape.len() {
        new_shape.push(1);
    }
    new_shape
}
84
/// Returns a copy of `shape` with a dimension of size 1 inserted immediately
/// after position `idx`.
///
/// Useful for adding a singleton dimension to a tensor shape, e.g. to prepare
/// operands for broadcasting.
///
/// # Parameters
///
/// - `shape`: A slice of `i64` representing the original shape of the tensor.
/// - `idx`: The position after which the `1` is inserted.
///
/// # Returns
///
/// - A `Vec<i64>` containing the new shape. The input slice is not modified.
///
/// # Behavior
///
/// - `idx == shape.len() - 1` appends the `1` at the very end.
/// - If `idx >= shape.len()` (including an empty `shape`), no `1` is inserted
///   and a plain copy of `shape` is returned; the function never panics.
///
/// # Examples
///
/// ```rust
/// // Insert after the first dimension
/// assert_eq!(yield_one_after(&[2, 3, 4], 0), vec![2, 1, 3, 4]);
///
/// // Insert after the middle dimension
/// assert_eq!(yield_one_after(&[5, 6, 7], 1), vec![5, 6, 1, 7]);
///
/// // Insert after the last dimension
/// assert_eq!(yield_one_after(&[8, 9], 1), vec![8, 9, 1]);
/// ```
///
/// # See Also
///
/// ```rust
/// fn yield_one_before(shape: &[i64], idx: usize) -> Vec<i64>
/// ```
pub fn yield_one_after(shape: &[i64], idx: usize) -> Vec<i64> {
    let mut out = Vec::with_capacity(shape.len() + 1);
    out.extend_from_slice(shape);
    // Only insert when `idx` names an existing dimension; otherwise the
    // shape is returned unchanged.
    if idx < shape.len() {
        out.insert(idx + 1, 1);
    }
    out
}
166
167/// Pads a shape with ones on the left to reach a specified length.
168///
169/// The `try_pad_shape` function takes an existing shape (a slice of `i64` values) and pads it with
170/// ones on the left side to ensure the shape has the desired length. If the existing shape's length
171/// is already equal to or greater than the desired length, the function returns the shape as is.
172///
173/// This is particularly useful in tensor operations where broadcasting rules require shapes to have
174/// the same number of dimensions.
175///
176/// # Parameters
177///
178/// - `shape`: A slice of `i64` representing the original shape of the tensor.
179/// - `length`: The desired length of the shape after padding.
180///
181/// # Returns
182///
183/// - A `Vec<i64>` representing the new shape, padded with ones on the left if necessary.
184///
185/// # Examples
186///
187/// ```rust
188/// // Example 1: Padding is needed
189/// let shape = vec![3, 4];
190/// let padded_shape = try_pad_shape(&shape, 4);
191/// assert_eq!(padded_shape, vec![1, 1, 3, 4]);
192///
193/// // Example 2: No padding is needed
194/// let shape = vec![2, 3, 4];
195/// let padded_shape = try_pad_shape(&shape, 2);
196/// assert_eq!(padded_shape, vec![2, 3, 4]); // Shape is returned as is
197/// ```
198///
199/// # Notes
200///
201/// - **Left Padding**: The function pads the shape with ones on the left side (i.e., it adds new
202/// dimensions to the beginning of the shape).
203/// - **Use Case**: This is useful for aligning shapes in operations that require input tensors to have
204/// the same number of dimensions, such as broadcasting in tensor computations.
205///
206/// # Implementation Details
207///
208/// - **Length Check**: The function first checks if the desired `length` is less than or equal to the
209/// current length of `shape`. If so, it returns a copy of `shape` as is.
210/// - **Padding Logic**: If padding is needed, it creates a new vector filled with ones of size `length`.
211/// It then copies the original shape's elements into the rightmost positions of this new vector,
212/// effectively padding the left side with ones.
213///
214/// # Edge Cases
215///
216/// - If `length` is zero, the function returns an empty vector.
217/// - If `shape` is empty and `length` is greater than zero, the function returns a vector of ones
218/// with the specified `length`.
219///
220/// # See Also
221///
222/// - Functions that handle shape manipulation and broadcasting in tensor operations.
223///
224/// # Example Usage in Context
225///
226/// ```rust
227/// // Assume we have two tensors with shapes [3, 4] and [4].
228/// // To perform element-wise operations, we need to align their shapes.
229/// let a_shape = vec![3, 4];
230/// let b_shape = vec![4];
231///
232/// // Pad the smaller shape to match the number of dimensions.
233/// let padded_b_shape = try_pad_shape(&b_shape, a_shape.len());
234/// assert_eq!(padded_b_shape, vec![1, 4]);
235///
236/// // Now both shapes have the same number of dimensions and can be broadcast together.
237/// ```
pub fn try_pad_shape(shape: &[i64], length: usize) -> Vec<i64> {
    // Already at (or beyond) the requested rank: hand back an owned copy.
    if shape.len() >= length {
        return shape.to_vec();
    }
    // Left-pad with ones so the original dimensions occupy the rightmost
    // positions, as broadcasting alignment requires.
    let mut padded = vec![1i64; length - shape.len()];
    padded.extend_from_slice(shape);
    padded
}
252
/// Pads the shorter of two shapes with leading ones so both have the same
/// rank; used when preparing operands for matmul broadcasting.
///
/// Possibly this can be made to work in more generic cases, not only matmul.
///
/// Returns `(longer, padded_shorter)`. Note that when the ranks are equal,
/// `b_shape` is treated as the "longer" one, so the result is `(b, a)`.
pub fn compare_and_pad_shapes(a_shape: &[i64], b_shape: &[i64]) -> (Vec<i64>, Vec<i64>) {
    let (longer, shorter) = if a_shape.len() > b_shape.len() {
        (a_shape, b_shape)
    } else {
        (b_shape, a_shape)
    };
    let rank_diff = longer.len() - shorter.len();

    let mut padded = Vec::with_capacity(longer.len());
    padded.resize(rank_diff, 1);
    padded.extend_from_slice(shorter);

    (longer.to_vec(), padded)
}
268
/// Pads the shorter of two (shape, strides) pairs so both operands have the
/// same rank; used when preparing operands for matmul broadcasting.
///
/// Possibly this can be made to work in more generic cases, not only matmul.
///
/// The shorter shape is left-padded with ones and its strides with zeros.
/// Returns `(longer_shape, padded_shorter_shape, longer_strides,
/// padded_shorter_strides)`. When the ranks are equal, the `b` pair is
/// treated as the "longer" one.
pub fn compare_and_pad_shapes_strides(
    a_shape: &[i64],
    b_shape: &[i64],
    a_strides: &[i64],
    b_strides: &[i64],
) -> (Vec<i64>, Vec<i64>, Vec<i64>, Vec<i64>) {
    let (long_shape, short_shape, long_strides, short_strides) =
        if a_shape.len() > b_shape.len() {
            (a_shape, b_shape, a_strides, b_strides)
        } else {
            (b_shape, a_shape, b_strides, a_strides)
        };
    let rank_diff = long_shape.len() - short_shape.len();

    // New leading dims get shape 1 and stride 0.
    let mut shape_padded = vec![1i64; rank_diff];
    shape_padded.extend_from_slice(short_shape);
    let mut strides_padded = vec![0i64; rank_diff];
    strides_padded.extend_from_slice(short_strides);

    (
        long_shape.to_vec(),
        shape_padded,
        long_strides.to_vec(),
        strides_padded,
    )
}
296
297/// Predicts the broadcasted shape resulting from broadcasting two arrays.
298///
299/// The `predict_broadcast_shape` function computes the resulting shape when two arrays with shapes
300/// `a_shape` and `b_shape` are broadcast together. Broadcasting is a technique that allows arrays of
301/// different shapes to be used together in arithmetic operations by "stretching" one or both arrays
302/// so that they have compatible shapes.
303///
304/// # Parameters
305///
306/// - `a_shape`: A slice of `i64` representing the shape of the first array.
307/// - `b_shape`: A slice of `i64` representing the shape of the second array.
308///
309/// # Returns
310///
311/// - `Ok(Shape)`: The resulting broadcasted shape as a `Shape` object if broadcasting is possible.
/// - `Err(TensorError)`: An error if the shapes cannot be broadcast together.
313///
314/// # Broadcasting Rules
315///
316/// The broadcasting rules determine how two arrays of different shapes can be broadcast together:
317///
318/// 1. **Alignment**: The shapes are right-aligned, meaning that the last dimensions are compared first.
319/// If one shape has fewer dimensions, it is left-padded with ones to match the other shape's length.
320///
321/// 2. **Dimension Compatibility**: For each dimension from the last to the first:
322/// - If the dimensions are equal, they are compatible.
323/// - If one of the dimensions is 1, the array in that dimension can be broadcast to match the other dimension.
324/// - If the dimensions are not equal and neither is 1, broadcasting is not possible.
325///
326/// # Example
327///
328/// ```rust
329/// // Assuming Shape and the necessary imports are defined appropriately.
330///
331/// let a_shape = &[8, 1, 6, 1];
332/// let b_shape = &[7, 1, 5];
333///
334/// match predict_broadcast_shape(a_shape, b_shape) {
335/// Ok(result_shape) => {
336/// assert_eq!(result_shape, Shape::from(vec![8, 7, 6, 5]));
337/// println!("Broadcasted shape: {:?}", result_shape);
338/// },
339/// Err(e) => {
340/// println!("Error: {}", e);
341/// },
342/// }
343/// ```
344///
345/// In this example:
346///
347/// - `a_shape` has shape `[8, 1, 6, 1]`.
348/// - `b_shape` has shape `[7, 1, 5]`.
349/// - After padding `b_shape` to `[1, 7, 1, 5]`, the shapes are compared element-wise from the last dimension.
350/// - The resulting broadcasted shape is `[8, 7, 6, 5]`.
351///
352/// # Notes
353///
354/// - The function assumes that shapes are represented as slices of `i64`.
355/// - The function uses a helper function `try_pad_shape` to pad the shorter shape with ones on the left.
356/// - If broadcasting is not possible, the function returns an error indicating the dimension at which the incompatibility occurs.
357///
358/// # Errors
359///
360/// - Returns an error if at any dimension the sizes differ and neither is 1, indicating that broadcasting cannot be performed.
361///
362/// # Implementation Details
363///
364/// - The function first determines which of the two shapes is longer and which is shorter.
365/// - The shorter shape is padded on the left with ones to match the length of the longer shape.
366/// - It then iterates over the dimensions, comparing corresponding dimensions from each shape:
367/// - If the dimensions are equal or one of them is 1, the resulting dimension is set to the maximum of the two.
368/// - If neither condition is met, an error is returned.
369#[track_caller]
370pub fn predict_broadcast_shape(
371 a_shape: &[i64],
372 b_shape: &[i64],
373) -> std::result::Result<Shape, TensorError> {
374 let (longer, shorter) = if a_shape.len() >= b_shape.len() {
375 (a_shape, b_shape)
376 } else {
377 (b_shape, a_shape)
378 };
379
380 let padded_shorter = try_pad_shape(shorter, longer.len());
381 let mut result_shape = vec![0; longer.len()];
382
383 for (i, (&longer_dim, &shorter_dim)) in longer.iter().zip(&padded_shorter).enumerate() {
384 result_shape[i] = if longer_dim == shorter_dim || shorter_dim == 1 {
385 longer_dim
386 } else if longer_dim == 1 {
387 shorter_dim
388 } else {
389 return Err(ShapeError::BroadcastError {
390 message: format!(
391 "broadcast failed at index {}, lhs shape: {:?}, rhs shape: {:?}",
392 i, a_shape, b_shape
393 ),
394 location: Location::caller(),
395 }
396 .into());
397 };
398 }
399
400 Ok(Shape::from(result_shape))
401}
402
403/// Determines the axes along which broadcasting is required to match a desired result shape.
404///
405/// The `get_broadcast_axes_from` function computes the indices of axes along which the input array `a`
406/// needs to be broadcasted to match the target shape `res_shape`. Broadcasting is a method used in
407/// tensor operations to allow arrays of different shapes to be used together in arithmetic operations.
408///
409/// **Note**: This function is adapted from NumPy's broadcasting rules and implementation.
410///
411/// # Parameters
412///
413/// - `a_shape`: A slice of `i64` representing the shape of the input array `a`.
414/// - `res_shape`: A slice of `i64` representing the desired result shape after broadcasting.
415/// - `location`: A `Location` object indicating the source code location for error reporting.
416///
417/// # Returns
418///
419/// - `Ok(Vec<usize>)`: A vector containing the indices of the axes along which broadcasting occurs.
/// - `Err(TensorError)`: An error if broadcasting is not possible due to incompatible shapes.
421///
422/// # Broadcasting Rules
423///
424/// Broadcasting follows specific rules to align arrays of different shapes:
425///
426/// 1. **Left Padding**: If the input array `a_shape` has fewer dimensions than `res_shape`, it is left-padded
427/// with ones to match the number of dimensions of `res_shape`.
428///
429/// 2. **Dimension Compatibility**: For each dimension from the most significant (leftmost) to the least significant
430/// (rightmost):
431/// - If the dimension sizes are equal, no broadcasting is needed for that axis.
432/// - If the dimension size in `a_shape` is 1 and in `res_shape` is greater than 1, broadcasting occurs along that axis.
433/// - If the dimension size in `res_shape` is 1 and in `a_shape` is greater than 1, broadcasting is not possible,
434/// and an error is returned.
435///
436/// 3. **Collecting Broadcast Axes**: The axes where broadcasting occurs are collected and returned.
437///
438/// # Example
439///
440/// ```rust
441/// use anyhow::Result;
442/// // Assuming `get_broadcast_axes_from` and `Location` are defined appropriately
443///
444/// fn main() -> Result<()> {
445/// let a_shape = &[3, 1];
446/// let res_shape = &[3, 4];
///
/// let axes = get_broadcast_axes_from(a_shape, res_shape)?;
450/// assert_eq!(axes, vec![1]);
451///
452/// println!("Broadcast axes: {:?}", axes);
453/// Ok(())
454/// }
455/// ```
456///
457/// In this example:
458///
459/// - The input array has shape `[3, 1]`.
460/// - The desired result shape is `[3, 4]`.
461/// - Broadcasting occurs along axis `1`, so the function returns `vec![1]`.
462///
463/// # Notes
464///
465/// - **Padding Shapes**: If `a_shape` has fewer dimensions than `res_shape`, it is padded on the left with ones
466/// to align the dimensions.
467///
468/// - **Axes Indices**: The axes indices are zero-based and correspond to the dimensions of the padded `a_shape`.
469///
470/// - **Error Handling**: If broadcasting is not possible due to incompatible dimensions, the function returns an error
471/// using `ErrHandler::BroadcastError`, providing detailed information about the mismatch.
472///
473/// - **Implementation Details**:
474/// - The function first calculates the difference in the number of dimensions and pads `a_shape` accordingly.
475/// - It then iterates over the dimensions to identify axes where broadcasting is needed or not possible.
476///
477/// # Errors
478///
479/// - Returns an error if any dimension in `res_shape` is `1` while the corresponding dimension in `a_shape` is
480/// greater than `1`, as broadcasting cannot be performed in this case.
481#[track_caller]
482pub fn get_broadcast_axes_from(
483 a_shape: &[i64],
484 res_shape: &[i64],
485) -> std::result::Result<Vec<usize>, TensorError> {
486 assert!(a_shape.len() <= res_shape.len());
487
488 let padded_a = try_pad_shape(a_shape, res_shape.len());
489
490 let mut axes = Vec::new();
491 let padded_axes = (0..res_shape.len() - a_shape.len()).collect::<Vec<usize>>();
492 for i in padded_axes.iter() {
493 axes.push(*i);
494 }
495
496 for (i, (&res_dim, &a_dim)) in res_shape.iter().zip(&padded_a).enumerate() {
497 if a_dim == 1 && res_dim != 1 && !padded_axes.contains(&i) {
498 axes.push(i);
499 } else if res_dim == 1 && a_dim != 1 {
500 return Err(ShapeError::BroadcastError {
501 message: format!(
502 "broadcast failed at index {}, lhs shape: {:?}, rhs shape: {:?}",
503 i, a_shape, res_shape
504 ),
505 location: Location::caller(),
506 }
507 .into());
508 }
509 }
510
511 Ok(axes)
512}
513
514// This file contains code translated from NumPy (https://github.com/numpy/numpy)
515// Original work Copyright (c) 2005-2025, NumPy Developers
516// Modified work Copyright (c) 2025 hpt Contributors
517//
518// Redistribution and use in source and binary forms, with or without
519// modification, are permitted provided that the following conditions are
520// met:
521
522// * Redistributions of source code must retain the above copyright
523// notice, this list of conditions and the following disclaimer.
524
525// * Redistributions in binary form must reproduce the above
526// copyright notice, this list of conditions and the following
527// disclaimer in the documentation and/or other materials provided
528// with the distribution.
529
530// * Neither the name of the NumPy Developers nor the names of any
531// contributors may be used to endorse or promote products derived
532// from this software without specific prior written permission.
533
534// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
535// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
536// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
537// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
538// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
539// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
540// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
541// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
542// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
543// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
544// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
545//
546// This Rust port is additionally licensed under Apache-2.0 OR MIT
547// See repository root for details
548
549/// Attempt to reshape an array without copying data.
550/// Translated from NumPy's _attempt_nocopy_reshape function.
pub fn is_reshape_possible(
    original_shape: &[i64],
    original_strides: &[i64],
    new_shape: &[i64],
) -> Option<Strides> {
    // Output strides for `new_shape`; filled in chunk by chunk below.
    // NOTE(review): like NumPy's `_attempt_nocopy_reshape`, this assumes the
    // total element counts of the two shapes match — not validated here.
    let mut new_strides = vec![0; new_shape.len()];
    // Copies of the old shape/strides with all size-1 axes squeezed out;
    // only the first `oldnd` entries are meaningful.
    let mut old_strides = vec![0; original_shape.len()];
    let mut old_shape = vec![0; original_shape.len()];

    // [oi, oj) and [ni, nj) bracket the current "chunk" of old/new axes whose
    // element counts are being matched against each other.
    let mut oi = 0;
    let mut oj = 1;
    let mut ni = 0;
    let mut nj = 1;

    // Number of old axes remaining after squeezing.
    let mut oldnd = 0;

    // Squeeze: size-1 axes contribute nothing to the element count and their
    // strides impose no contiguity constraint.
    for i in 0..original_shape.len() {
        if original_shape[i] != 1 {
            old_shape[oldnd] = original_shape[i];
            old_strides[oldnd] = original_strides[i];
            oldnd += 1;
        }
    }

    while ni < new_shape.len() && oi < oldnd {
        let mut np = new_shape[ni];
        let mut op = old_shape[oi];

        // Grow whichever chunk covers fewer elements until both chunks span
        // the same number of elements.
        while np != op {
            if np < op {
                np *= new_shape[nj];
                nj += 1;
            } else {
                op *= old_shape[oj];
                oj += 1;
            }
        }

        // The old axes inside this chunk must be C-contiguous relative to one
        // another (each stride equals inner shape * inner stride); otherwise
        // the reshape would require copying data, so report failure.
        for i in oi..oj - 1 {
            if old_strides[i] != old_shape[i + 1] * old_strides[i + 1] {
                return None;
            }
        }

        // The innermost new axis of the chunk inherits the innermost old
        // stride; the rest build outward in row-major (C) order.
        new_strides[nj - 1] = old_strides[oj - 1];
        for i in (ni + 1..nj).rev() {
            new_strides[i - 1] = new_strides[i] * new_shape[i];
        }

        // Advance both windows to the next chunk.
        ni = nj;
        nj += 1;
        oi = oj;
        oj += 1;
    }

    // Any trailing new axes that were not consumed above reuse the last
    // computed stride (1 when no chunk was processed at all).
    let last_stride = if ni >= 1 { new_strides[ni - 1] } else { 1 };

    for i in ni..new_shape.len() {
        new_strides[i] = last_stride;
    }

    Some(new_strides.into())
}
614
615/// Generates intervals for multi-threaded processing by dividing the outer loop into chunks.
616///
617/// The `mt_intervals` function divides a large outer loop into multiple smaller intervals to be
618/// processed by multiple threads. The function aims to distribute the workload as evenly as possible
619/// among the available threads, handling cases where the total number of iterations is not perfectly
620/// divisible by the number of threads.
621///
622/// # Parameters
623///
624/// - `outer_loop_size`: The total number of iterations in the outer loop.
625/// - `num_threads`: The number of threads to divide the work among.
626///
627/// # Returns
628///
629/// A `Vec` of tuples `(usize, usize)`, where each tuple represents the start (inclusive) and end
630/// (exclusive) indices of the interval assigned to each thread.
631///
632/// # Algorithm Overview
633///
634/// 1. **Calculate Base Workload**: Each thread is assigned at least `outer_loop_size / num_threads` iterations.
635/// 2. **Distribute Remainder**: If `outer_loop_size` is not divisible by `num_threads`, the remaining iterations
636/// (`outer_loop_size % num_threads`) are distributed one by one to the first few threads.
637/// 3. **Calculate Start and End Indices**:
638/// - The `start_index` for each thread `i` is calculated as:
639/// ```
640/// i * (outer_loop_size / num_threads) + min(i, outer_loop_size % num_threads)
641/// ```
642/// - The `end_index` is then calculated by adding the base workload and an extra iteration if the thread
643/// received an extra iteration from the remainder.
644///
645/// # Examples
646///
647/// ```rust
648/// fn main() {
649/// let outer_loop_size = 10;
650/// let num_threads = 3;
651///
652/// let intervals = mt_intervals(outer_loop_size, num_threads);
653///
654/// for (i, (start, end)) in intervals.iter().enumerate() {
655/// println!("Thread {}: Processing indices [{}..{})", i, start, end);
656/// }
657/// }
658/// ```
659///
660/// Output:
661///
662/// ```text
663/// Thread 0: Processing indices [0..4)
664/// Thread 1: Processing indices [4..7)
665/// Thread 2: Processing indices [7..10)
666/// ```
667///
668/// In this example:
669/// - The total number of iterations is 10.
670/// - The number of threads is 3.
671/// - Each thread gets at least `10 / 3 = 3` iterations.
672/// - The remainder is `10 % 3 = 1`. So, the first thread gets one extra iteration.
673///
674/// # Notes
675///
676/// - **Workload Balance**: The function ensures that the workload is distributed as evenly as possible.
677/// - **Integer Division**: Since integer division truncates towards zero, the remainder is used to distribute
678/// the extra iterations.
679/// - **Index Calculation**: The calculation uses `std::cmp::min` to ensure that only the first `remainder` threads
680/// receive the extra iteration.
681///
682/// # Function Definition
683///
684/// ```rust
685/// pub fn mt_intervals(outer_loop_size: usize, num_threads: usize) -> Vec<(usize, usize)> {
686/// let mut intervals = Vec::with_capacity(num_threads);
687/// for i in 0..num_threads {
688/// let start_index = i * (outer_loop_size / num_threads)
689/// + std::cmp::min(i, outer_loop_size % num_threads);
690/// let end_index = start_index
691/// + outer_loop_size / num_threads
692/// + ((i < outer_loop_size % num_threads) as usize);
693/// intervals.push((start_index, end_index));
694/// }
695/// intervals
696/// }
697/// ```
698///
699/// # Unit Tests
700///
701/// Here are some unit tests to verify the correctness of the function:
702///
703/// ```rust
704/// #[cfg(test)]
705/// mod tests {
706/// use super::*;
707///
708/// #[test]
709/// fn test_even_division() {
710/// let intervals = mt_intervals(100, 4);
711/// assert_eq!(intervals.len(), 4);
712/// assert_eq!(intervals[0], (0, 25));
713/// assert_eq!(intervals[1], (25, 50));
714/// assert_eq!(intervals[2], (50, 75));
715/// assert_eq!(intervals[3], (75, 100));
716/// }
717///
718/// #[test]
719/// fn test_uneven_division() {
720/// let intervals = mt_intervals(10, 3);
721/// assert_eq!(intervals.len(), 3);
722/// assert_eq!(intervals[0], (0, 4));
723/// assert_eq!(intervals[1], (4, 7));
724/// assert_eq!(intervals[2], (7, 10));
725/// }
726///
727/// #[test]
728/// fn test_more_threads_than_work() {
729/// let intervals = mt_intervals(5, 10);
730/// assert_eq!(intervals.len(), 10);
731/// assert_eq!(intervals[0], (0, 1));
732/// assert_eq!(intervals[1], (1, 2));
733/// assert_eq!(intervals[2], (2, 3));
734/// assert_eq!(intervals[3], (3, 4));
735/// assert_eq!(intervals[4], (4, 5));
736/// for i in 5..10 {
737/// assert_eq!(intervals[i], (5, 5));
738/// }
739/// }
740///
741/// #[test]
742/// fn test_zero_iterations() {
743/// let intervals = mt_intervals(0, 4);
744/// assert_eq!(intervals.len(), 4);
745/// for &(start, end) in &intervals {
746/// assert_eq!(start, 0);
747/// assert_eq!(end, 0);
748/// }
749/// }
750///
751/// #[test]
752/// fn test_zero_threads() {
753/// let intervals = mt_intervals(10, 0);
754/// assert_eq!(intervals.len(), 0);
755/// }
756/// }
757/// ```
758///
759/// # Caveats
760///
761/// - If `num_threads` is zero, the function will return an empty vector.
762/// - If `outer_loop_size` is zero, all intervals will have start and end indices of zero.
763///
764/// # Performance Considerations
765///
766/// - **Allocation**: The function pre-allocates the vector with capacity `num_threads`.
767/// - **Integer Operations**: The function uses integer division and modulo operations, which are efficient.
768///
769/// # Conclusion
770///
771/// The `mt_intervals` function is useful for dividing work among multiple threads in a balanced way, ensuring that
772/// each thread gets a fair share of the workload, even when the total number of iterations is not perfectly divisible
773/// by the number of threads.
774
pub fn mt_intervals(outer_loop_size: usize, num_threads: usize) -> Vec<(usize, usize)> {
    (0..num_threads)
        .map(|tid| {
            let base = outer_loop_size / num_threads;
            let extra = outer_loop_size % num_threads;
            // The first `extra` threads each take one additional iteration,
            // which shifts every later start index by `min(tid, extra)`.
            let start = tid * base + tid.min(extra);
            let end = start + base + usize::from(tid < extra);
            (start, end)
        })
        .collect()
}
787
788/// Generates intervals for multi-threaded SIMD processing by dividing the outer loop into chunks.
789///
790/// The `mt_intervals_simd` function divides a large outer loop into multiple smaller intervals
791/// to be processed by multiple threads. Each interval is aligned with the SIMD vector size to
792/// optimize performance. This ensures that each thread processes a chunk of data that is a
793/// multiple of the SIMD vector size, which is beneficial for vectorized operations.
794///
795/// # Parameters
796///
797/// - `outer_loop_size`: The total size of the outer loop (number of iterations).
798/// - `num_threads`: The desired number of threads to use for processing.
799/// - `vec_size`: The size of the SIMD vector (number of elements processed in one SIMD operation).
800///
801/// # Returns
802///
803/// A `Vec` of tuples `(usize, usize)`, where each tuple represents the start (inclusive) and
804/// end (exclusive) indices of the interval assigned to a thread.
805///
806/// # Algorithm Overview
807///
808/// 1. **Determine Maximum Threads**: Calculate `max_threads` as `outer_loop_size / vec_size` to
809/// ensure each thread has at least one full SIMD vector's worth of work.
810/// 2. **Adjust Thread Count**: Set `actual_threads` to the minimum of `num_threads` and
811/// `max_threads` to avoid creating more threads than necessary.
812/// 3. **Calculate Base Block Count and Remainder**:
813/// - `base_block_count` is the number of full blocks each thread will process.
814/// - `remainder` is the number of remaining blocks that couldn't be evenly divided.
815/// 4. **Assign Intervals to Threads**:
816/// - Distribute the extra blocks from the remainder among the first `remainder` threads.
817/// - Calculate `start_index` and `end_index` for each thread accordingly.
818///
819/// # Examples
820///
821/// ```rust
822/// fn main() {
823/// let outer_loop_size = 1000;
824/// let num_threads = 4;
825/// let vec_size = 8;
826///
827/// let intervals = mt_intervals_simd(outer_loop_size, num_threads, vec_size);
828///
829/// for (i, (start, end)) in intervals.iter().enumerate() {
830/// println!("Thread {}: Processing indices [{}..{})", i, start, end);
831/// }
832/// }
833/// ```
834///
835/// Output might be:
836///
837/// ```text
838/// Thread 0: Processing indices [0..200)
839/// Thread 1: Processing indices [200..400)
840/// Thread 2: Processing indices [400..600)
841/// Thread 3: Processing indices [600..800)
842/// ```
843///
844/// # Notes
845///
846/// - **Data Alignment**: The function ensures that each interval's size is a multiple of `vec_size`
847/// to maintain data alignment for SIMD operations.
848/// - **Load Balancing**: Extra iterations resulting from the remainder are distributed among the
849/// first few threads to balance the workload.
850///
851/// # Panics
852///
853/// The function does not explicitly panic, but providing a `vec_size` of zero will result in a
854/// division by zero error.
855///
856/// # See Also
857///
858/// - SIMD (Single Instruction, Multiple Data) processing.
859/// - Multi-threading in Rust.
860///
861/// # Caveats
862///
863/// - Ensure that `vec_size` is not zero to avoid division by zero errors.
864/// - The function assumes that `outer_loop_size`, `num_threads`, and `vec_size` are positive integers.
865///
866/// # Performance Considerations
867///
868/// - **Thread Overhead**: Creating too many threads may introduce overhead. The function limits the
869/// number of threads to the maximum useful amount based on `outer_loop_size` and `vec_size`.
870/// - **SIMD Efficiency**: Aligning intervals to `vec_size` improves SIMD efficiency by preventing
871/// partial vector loads and stores.
872///
873/// # Conclusion
874///
875/// The `mt_intervals_simd` function is useful for parallelizing loops in applications that benefit
876/// from both multi-threading and SIMD vectorization. By carefully dividing the work into appropriately
877/// sized intervals, it helps maximize performance on modern CPUs.
pub fn mt_intervals_simd(
    outer_loop_size: usize,
    num_threads: usize,
    vec_size: usize,
) -> Vec<(usize, usize)> {
    assert!(vec_size > 0, "vec_size must be greater than zero");
    assert!(num_threads > 0, "num_threads must be greater than zero");

    // Largest prefix of the loop that is a whole number of SIMD vectors.
    let aligned_size = (outer_loop_size / vec_size) * vec_size;
    // Scalar tail: fewer than `vec_size` iterations.
    let remainder = outer_loop_size - aligned_size;

    let mut intervals = Vec::with_capacity(num_threads);

    if aligned_size == 0 {
        // Not even one full vector of work: thread 0 takes the (possibly
        // empty) scalar tail and every other thread gets an empty interval.
        intervals.push((0, remainder));
        intervals.resize(num_threads, (0, 0));
        return intervals;
    }

    // Distribute whole vector blocks as evenly as possible; the first
    // `extra_blocks` threads receive one block more than the rest.
    let total_vec_blocks = aligned_size / vec_size;
    let base_blocks_per_thread = total_vec_blocks / num_threads;
    let extra_blocks = total_vec_blocks % num_threads;

    let mut start = 0;
    for tid in 0..num_threads {
        let blocks = base_blocks_per_thread + usize::from(tid < extra_blocks);
        let end = start + blocks * vec_size;
        intervals.push((start, end));
        start = end;
    }

    // Append the scalar tail to the last thread's interval.
    if remainder > 0 {
        if let Some(last) = intervals.last_mut() {
            last.1 += remainder;
        }
    }

    // The loop above always pushes exactly `num_threads` intervals, so no
    // padding pass is needed. (The original code carried two extra padding
    // branches here that were provably unreachable, plus a redundant
    // `num_threads >= 1` check already guaranteed by the assert.)
    intervals
}