hpt_common/shape/shape_utils.rs

use std::panic::Location;

use crate::{
    error::{base::TensorError, shape::ShapeError},
    shape::shape::Shape,
    strides::strides::Strides,
};

/// Inserts a dimension of size 1 before the specified index in a shape.
///
/// The `yield_one_before` function takes an existing shape (a slice of `i64` values) and inserts
/// a new dimension of size 1 before the specified index `idx`. This is useful in tensor operations
/// where you need to expand the dimensions of a tensor by adding singleton dimensions, which can
/// facilitate broadcasting or other dimension-specific operations.
///
/// # Parameters
///
/// - `shape`: A slice of `i64` representing the original shape of the tensor.
/// - `idx`: The index before which a new dimension of size 1 will be inserted.
///
/// # Returns
///
/// - A `Vec<i64>` representing the new shape with the inserted dimension of size 1.
///
/// # Examples
///
/// ```rust
/// // Example 1: Insert before the first dimension
/// let shape = vec![3, 4, 5];
/// let idx = 0;
/// let new_shape = yield_one_before(&shape, idx);
/// assert_eq!(new_shape, vec![1, 3, 4, 5]);
///
/// // Example 2: Insert before a middle dimension
/// let idx = 1;
/// let new_shape = yield_one_before(&shape, idx);
/// assert_eq!(new_shape, vec![3, 1, 4, 5]);
///
/// // Example 3: Insert before the last dimension
/// let idx = 2;
/// let new_shape = yield_one_before(&shape, idx);
/// assert_eq!(new_shape, vec![3, 4, 1, 5]);
///
/// // Example 4: Index equal to the shape length (appends 1 at the end)
/// let idx = 3;
/// let new_shape = yield_one_before(&shape, idx);
/// assert_eq!(new_shape, vec![3, 4, 5, 1]);
/// ```
///
/// # Notes
///
/// - **Index Bounds**: If `idx` equals `shape.len()`, a dimension of size 1 is appended at the end
///   of the shape. If `idx` is greater than `shape.len()`, the shape is returned unchanged.
/// - **Use Cases**: Adding a singleton dimension is often used to adjust the shape of a tensor for
///   broadcasting in element-wise operations or to match required input dimensions for certain
///   functions.
/// - **Immutability**: The original `shape` slice is not modified; a new `Vec<i64>` is returned.
///
/// # Implementation Details
///
/// The function works by iterating over the original shape and copying each dimension into a new
/// vector. When the current index matches `idx`, it pushes a `1` before copying that dimension.
///
/// # See Also
///
/// ```rust
/// fn yield_one_after(shape: &[i64], idx: usize) -> Vec<i64>
/// ```
pub fn yield_one_before(shape: &[i64], idx: usize) -> Vec<i64> {
    let mut new_shape = Vec::with_capacity(shape.len() + 1);
    for (i, s) in shape.iter().enumerate() {
        if i == idx {
            new_shape.push(1);
            new_shape.push(*s);
        } else {
            new_shape.push(*s);
        }
    }
    if idx == shape.len() {
        new_shape.push(1);
    }
    new_shape
}

/// Inserts a `1` into a shape vector immediately after a specified index.
///
/// The `yield_one_after` function takes a slice representing the shape of a tensor and an index,
/// and returns a new shape vector where the value `1` is inserted immediately after the specified index.
/// This is useful for reshaping tensors, especially when you need to add a singleton dimension
/// for broadcasting or other tensor operations.
///
/// # Parameters
///
/// - `shape`: A slice of `i64` representing the original shape of the tensor.
/// - `idx`: A `usize` index after which the value `1` will be inserted into the shape.
///
/// # Returns
///
/// - A `Vec<i64>` representing the new shape with the value `1` inserted after the specified index.
///
/// # Examples
///
/// ```rust
/// // Example 1: Inserting after the first dimension
/// let shape = vec![2, 3, 4];
/// let idx = 0;
/// let new_shape = yield_one_after(&shape, idx);
/// assert_eq!(new_shape, vec![2, 1, 3, 4]);
///
/// // Example 2: Inserting after the second dimension
/// let shape = vec![5, 6, 7];
/// let idx = 1;
/// let new_shape = yield_one_after(&shape, idx);
/// assert_eq!(new_shape, vec![5, 6, 1, 7]);
///
/// // Example 3: Inserting after the last dimension
/// let shape = vec![8, 9];
/// let idx = 1;
/// let new_shape = yield_one_after(&shape, idx);
/// assert_eq!(new_shape, vec![8, 9, 1]);
/// ```
///
/// # Notes
///
/// - **Index Bounds**: The `idx` parameter should be less than `shape.len()`.
///   - If `idx` is equal to `shape.len() - 1`, the `1` is appended at the end of the shape vector.
///   - If `idx` is greater than or equal to `shape.len()`, no `1` is inserted and a copy of the
///     original shape is returned.
/// - **Non-mutating**: The function does not modify the original `shape` slice; it returns a new `Vec<i64>`.
///
/// # Use Cases
///
/// - **Adding a Dimension**: Useful when you need to add a singleton dimension to a tensor for operations like broadcasting.
/// - **Reshaping Tensors**: Helps in reshaping tensors to match required dimensions for certain mathematical operations.
///
/// # Edge Cases
///
/// - **Empty Shape**: If the `shape` slice is empty, the loop never executes and an empty vector is returned.
///   ```rust
///   let shape: Vec<i64> = vec![];
///   let idx = 0;
///   let new_shape = yield_one_after(&shape, idx);
///   assert_eq!(new_shape, Vec::<i64>::new()); // No `1` is inserted
///   ```
///
/// # Panics
///
/// - The function does not panic; an out-of-range `idx` simply leaves the shape unchanged.
///
/// # See Also
///
/// ```rust
/// fn yield_one_before(shape: &[i64], idx: usize) -> Vec<i64>
/// ```
pub fn yield_one_after(shape: &[i64], idx: usize) -> Vec<i64> {
    let mut new_shape = Vec::with_capacity(shape.len() + 1);
    for (i, s) in shape.iter().enumerate() {
        if i == idx {
            new_shape.push(*s);
            new_shape.push(1);
        } else {
            new_shape.push(*s);
        }
    }
    new_shape
}

/// Pads a shape with ones on the left to reach a specified length.
///
/// The `try_pad_shape` function takes an existing shape (a slice of `i64` values) and pads it with
/// ones on the left side to ensure the shape has the desired length. If the existing shape's length
/// is already equal to or greater than the desired length, the function returns the shape as is.
///
/// This is particularly useful in tensor operations where broadcasting rules require shapes to have
/// the same number of dimensions.
///
/// # Parameters
///
/// - `shape`: A slice of `i64` representing the original shape of the tensor.
/// - `length`: The desired length of the shape after padding.
///
/// # Returns
///
/// - A `Vec<i64>` representing the new shape, padded with ones on the left if necessary.
///
/// # Examples
///
/// ```rust
/// // Example 1: Padding is needed
/// let shape = vec![3, 4];
/// let padded_shape = try_pad_shape(&shape, 4);
/// assert_eq!(padded_shape, vec![1, 1, 3, 4]);
///
/// // Example 2: No padding is needed
/// let shape = vec![2, 3, 4];
/// let padded_shape = try_pad_shape(&shape, 2);
/// assert_eq!(padded_shape, vec![2, 3, 4]); // Shape is returned as is
/// ```
///
/// # Notes
///
/// - **Left Padding**: The function pads the shape with ones on the left side (i.e., it adds new
///   dimensions to the beginning of the shape).
/// - **Use Case**: This is useful for aligning shapes in operations that require input tensors to have
///   the same number of dimensions, such as broadcasting in tensor computations.
///
/// # Implementation Details
///
/// - **Length Check**: The function first checks if the desired `length` is less than or equal to the
///   current length of `shape`. If so, it returns a copy of `shape` as is.
/// - **Padding Logic**: If padding is needed, it creates a new vector filled with ones of size `length`.
///   It then copies the original shape's elements into the rightmost positions of this new vector,
///   effectively padding the left side with ones.
///
/// # Edge Cases
///
/// - If `length` is less than or equal to `shape.len()` (including zero), the original shape is
///   returned unchanged.
/// - If `shape` is empty and `length` is greater than zero, the function returns a vector of ones
///   with the specified `length`.
///
/// # See Also
///
/// - Functions that handle shape manipulation and broadcasting in tensor operations.
///
/// # Example Usage in Context
///
/// ```rust
/// // Assume we have two tensors with shapes [3, 4] and [4].
/// // To perform element-wise operations, we need to align their shapes.
/// let a_shape = vec![3, 4];
/// let b_shape = vec![4];
///
/// // Pad the smaller shape to match the number of dimensions.
/// let padded_b_shape = try_pad_shape(&b_shape, a_shape.len());
/// assert_eq!(padded_b_shape, vec![1, 4]);
///
/// // Now both shapes have the same number of dimensions and can be broadcast together.
/// ```
pub fn try_pad_shape(shape: &[i64], length: usize) -> Vec<i64> {
    // If the current shape length is already equal or greater, return it as is.
    if length <= shape.len() {
        return shape.to_vec();
    }

    // Otherwise, create a new shape vector with ones and overlay the existing shape on it.
    let mut ret = vec![1; length];
    for (existing, new) in shape.iter().rev().zip(ret.iter_mut().rev()) {
        *new = *existing;
    }

    ret
}

/// Pads the shorter of the two shapes with leading ones so both have the same number of
/// dimensions; this is used to prepare shapes for matmul broadcasting.
///
/// Possibly this could be generalized to cases beyond matmul.
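///
/// # Example
///
/// A minimal sketch of the padding behavior (shapes chosen for illustration):
///
/// ```rust
/// let a_shape = vec![2, 3, 4, 5];
/// let b_shape = vec![4, 5];
/// let (a_padded, b_padded) = compare_and_pad_shapes(&a_shape, &b_shape);
/// assert_eq!(a_padded, vec![2, 3, 4, 5]);
/// assert_eq!(b_padded, vec![1, 1, 4, 5]);
/// ```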
pub fn compare_and_pad_shapes(a_shape: &[i64], b_shape: &[i64]) -> (Vec<i64>, Vec<i64>) {
    let len_diff = i64::abs((a_shape.len() as i64) - (b_shape.len() as i64)) as usize;
    let (longer, shorter) = if a_shape.len() > b_shape.len() {
        (a_shape, b_shape)
    } else {
        (b_shape, a_shape)
    };

    let mut padded_shorter = vec![1; len_diff];
    padded_shorter.extend_from_slice(shorter);
    (longer.to_vec(), padded_shorter)
}

/// Pads the shorter shape (with leading ones) and its strides (with leading zeros) so both
/// operands have the same number of dimensions; this is used to prepare for matmul broadcasting.
///
/// Possibly this could be generalized to cases beyond matmul.
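///
/// # Example
///
/// A minimal sketch of the padding behavior (values chosen for illustration); the padded stride
/// entries are filled with `0`:
///
/// ```rust
/// let (a_s, b_s, a_st, b_st) =
///     compare_and_pad_shapes_strides(&[2, 3, 4], &[4], &[12, 4, 1], &[1]);
/// assert_eq!(a_s, vec![2, 3, 4]);
/// assert_eq!(b_s, vec![1, 1, 4]);
/// assert_eq!(a_st, vec![12, 4, 1]);
/// assert_eq!(b_st, vec![0, 0, 1]);
/// ```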
pub fn compare_and_pad_shapes_strides(
    a_shape: &[i64],
    b_shape: &[i64],
    a_strides: &[i64],
    b_strides: &[i64],
) -> (Vec<i64>, Vec<i64>, Vec<i64>, Vec<i64>) {
    let len_diff = i64::abs((a_shape.len() as i64) - (b_shape.len() as i64)) as usize;
    let (longer, shorter, longer_strides, shorter_strides) = if a_shape.len() > b_shape.len() {
        (a_shape, b_shape, a_strides, b_strides)
    } else {
        (b_shape, a_shape, b_strides, a_strides)
    };

    let mut padded_shorter = vec![1; len_diff];
    let mut padded_shorter_strides = vec![0; len_diff];
    padded_shorter.extend_from_slice(shorter);
    padded_shorter_strides.extend_from_slice(shorter_strides);
    (
        longer.to_vec(),
        padded_shorter,
        longer_strides.to_vec(),
        padded_shorter_strides,
    )
}

/// Predicts the broadcasted shape resulting from broadcasting two arrays.
///
/// The `predict_broadcast_shape` function computes the resulting shape when two arrays with shapes
/// `a_shape` and `b_shape` are broadcast together. Broadcasting is a technique that allows arrays of
/// different shapes to be used together in arithmetic operations by "stretching" one or both arrays
/// so that they have compatible shapes.
///
/// # Parameters
///
/// - `a_shape`: A slice of `i64` representing the shape of the first array.
/// - `b_shape`: A slice of `i64` representing the shape of the second array.
///
/// # Returns
///
/// - `Ok(Shape)`: The resulting broadcasted shape as a `Shape` object if broadcasting is possible.
/// - `Err(TensorError)`: An error if the shapes cannot be broadcast together.
///
/// # Broadcasting Rules
///
/// The broadcasting rules determine how two arrays of different shapes can be broadcast together:
///
/// 1. **Alignment**: The shapes are right-aligned, meaning that the last dimensions are compared first.
///    If one shape has fewer dimensions, it is left-padded with ones to match the other shape's length.
///
/// 2. **Dimension Compatibility**: For each dimension from the last to the first:
///    - If the dimensions are equal, they are compatible.
///    - If one of the dimensions is 1, the array in that dimension can be broadcast to match the other dimension.
///    - If the dimensions are not equal and neither is 1, broadcasting is not possible.
///
/// # Example
///
/// ```rust
/// // Assuming Shape and the necessary imports are defined appropriately.
///
/// let a_shape = &[8, 1, 6, 1];
/// let b_shape = &[7, 1, 5];
///
/// match predict_broadcast_shape(a_shape, b_shape) {
///     Ok(result_shape) => {
///         assert_eq!(result_shape, Shape::from(vec![8, 7, 6, 5]));
///         println!("Broadcasted shape: {:?}", result_shape);
///     },
///     Err(e) => {
///         println!("Error: {}", e);
///     },
/// }
/// ```
///
/// In this example:
///
/// - `a_shape` has shape `[8, 1, 6, 1]`.
/// - `b_shape` has shape `[7, 1, 5]`.
/// - After padding `b_shape` to `[1, 7, 1, 5]`, the shapes are compared element-wise from the last dimension.
/// - The resulting broadcasted shape is `[8, 7, 6, 5]`.
///
/// # Notes
///
/// - The function assumes that shapes are represented as slices of `i64`.
/// - The function uses a helper function `try_pad_shape` to pad the shorter shape with ones on the left.
/// - If broadcasting is not possible, the function returns an error indicating the dimension at which the incompatibility occurs.
///
/// # Errors
///
/// - Returns an error if at any dimension the sizes differ and neither is 1, indicating that broadcasting cannot be performed.
///
/// # Implementation Details
///
/// - The function first determines which of the two shapes is longer and which is shorter.
/// - The shorter shape is padded on the left with ones to match the length of the longer shape.
/// - It then iterates over the dimensions, comparing corresponding dimensions from each shape:
///   - If the dimensions are equal or one of them is 1, the resulting dimension is set to the maximum of the two.
///   - If neither condition is met, an error is returned.
#[track_caller]
pub fn predict_broadcast_shape(
    a_shape: &[i64],
    b_shape: &[i64],
) -> std::result::Result<Shape, TensorError> {
    let (longer, shorter) = if a_shape.len() >= b_shape.len() {
        (a_shape, b_shape)
    } else {
        (b_shape, a_shape)
    };

    let padded_shorter = try_pad_shape(shorter, longer.len());
    let mut result_shape = vec![0; longer.len()];

    for (i, (&longer_dim, &shorter_dim)) in longer.iter().zip(&padded_shorter).enumerate() {
        result_shape[i] = if longer_dim == shorter_dim || shorter_dim == 1 {
            longer_dim
        } else if longer_dim == 1 {
            shorter_dim
        } else {
            return Err(ShapeError::BroadcastError {
                message: format!(
                    "broadcast failed at index {}, lhs shape: {:?}, rhs shape: {:?}",
                    i, a_shape, b_shape
                ),
                location: Location::caller(),
            }
            .into());
        };
    }

    Ok(Shape::from(result_shape))
}

/// Determines the axes along which broadcasting is required to match a desired result shape.
///
/// The `get_broadcast_axes_from` function computes the indices of axes along which the input array `a`
/// needs to be broadcasted to match the target shape `res_shape`. Broadcasting is a method used in
/// tensor operations to allow arrays of different shapes to be used together in arithmetic operations.
///
/// **Note**: This function is adapted from NumPy's broadcasting rules and implementation.
///
/// # Parameters
///
/// - `a_shape`: A slice of `i64` representing the shape of the input array `a`.
/// - `res_shape`: A slice of `i64` representing the desired result shape after broadcasting.
///
/// # Returns
///
/// - `Ok(Vec<usize>)`: A vector containing the indices of the axes along which broadcasting occurs.
/// - `Err(TensorError)`: An error if broadcasting is not possible due to incompatible shapes.
///
/// # Broadcasting Rules
///
/// Broadcasting follows specific rules to align arrays of different shapes:
///
/// 1. **Left Padding**: If the input array `a_shape` has fewer dimensions than `res_shape`, it is left-padded
///    with ones to match the number of dimensions of `res_shape`. The padded (leading) axes are always
///    reported as broadcast axes.
///
/// 2. **Dimension Compatibility**: For each dimension from the most significant (leftmost) to the least significant
///    (rightmost):
///    - If the dimension sizes are equal, no broadcasting is needed for that axis.
///    - If the dimension size in `a_shape` is 1 and in `res_shape` is greater than 1, broadcasting occurs along that axis.
///    - If the dimension size in `res_shape` is 1 and in `a_shape` is greater than 1, broadcasting is not possible,
///      and an error is returned.
///
/// 3. **Collecting Broadcast Axes**: The axes where broadcasting occurs are collected and returned.
///
/// # Example
///
/// ```rust
/// // Assuming `get_broadcast_axes_from` is in scope.
///
/// let a_shape = &[3, 1];
/// let res_shape = &[3, 4];
///
/// let axes = get_broadcast_axes_from(a_shape, res_shape).unwrap();
/// assert_eq!(axes, vec![1]);
///
/// println!("Broadcast axes: {:?}", axes);
/// ```
///
/// In this example:
///
/// - The input array has shape `[3, 1]`.
/// - The desired result shape is `[3, 4]`.
/// - Broadcasting occurs along axis `1`, so the function returns `vec![1]`.
///
/// # Notes
///
/// - **Padding Shapes**: If `a_shape` has fewer dimensions than `res_shape`, it is padded on the left with ones
///   to align the dimensions.
///
/// - **Axes Indices**: The axes indices are zero-based and correspond to the dimensions of the padded `a_shape`.
///
/// - **Error Handling**: If broadcasting is not possible due to incompatible dimensions, the function returns an error
///   using `ShapeError::BroadcastError`, providing detailed information about the mismatch. The caller's source
///   location is captured via `#[track_caller]` for error reporting.
///
/// - **Implementation Details**:
///   - The function first calculates the difference in the number of dimensions and pads `a_shape` accordingly.
///   - It then iterates over the dimensions to identify axes where broadcasting is needed or not possible.
///
/// # Panics
///
/// - Panics if `a_shape` has more dimensions than `res_shape`.
///
/// # Errors
///
/// - Returns an error if any dimension in `res_shape` is `1` while the corresponding dimension in `a_shape` is
///   greater than `1`, as broadcasting cannot be performed in this case.
#[track_caller]
pub fn get_broadcast_axes_from(
    a_shape: &[i64],
    res_shape: &[i64],
) -> std::result::Result<Vec<usize>, TensorError> {
    assert!(a_shape.len() <= res_shape.len());

    let padded_a = try_pad_shape(a_shape, res_shape.len());

    let mut axes = Vec::new();
    let padded_axes = (0..res_shape.len() - a_shape.len()).collect::<Vec<usize>>();
    for i in padded_axes.iter() {
        axes.push(*i);
    }

    for (i, (&res_dim, &a_dim)) in res_shape.iter().zip(&padded_a).enumerate() {
        if a_dim == 1 && res_dim != 1 && !padded_axes.contains(&i) {
            axes.push(i);
        } else if res_dim == 1 && a_dim != 1 {
            return Err(ShapeError::BroadcastError {
                message: format!(
                    "broadcast failed at index {}, lhs shape: {:?}, rhs shape: {:?}",
                    i, a_shape, res_shape
                ),
                location: Location::caller(),
            }
            .into());
        }
    }

    Ok(axes)
}

// This file contains code translated from NumPy (https://github.com/numpy/numpy)
// Original work Copyright (c) 2005-2025, NumPy Developers
// Modified work Copyright (c) 2025 hpt Contributors
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:

//     * Redistributions of source code must retain the above copyright
//        notice, this list of conditions and the following disclaimer.

//     * Redistributions in binary form must reproduce the above
//        copyright notice, this list of conditions and the following
//        disclaimer in the documentation and/or other materials provided
//        with the distribution.

//     * Neither the name of the NumPy Developers nor the names of any
//        contributors may be used to endorse or promote products derived
//        from this software without specific prior written permission.

// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// This Rust port is additionally licensed under Apache-2.0 OR MIT
// See repository root for details

/// Attempt to reshape an array without copying data.
/// Translated from NumPy's _attempt_nocopy_reshape function.
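///
/// Returns `Some(new_strides)` if the data can be viewed with `new_shape` without copying,
/// or `None` if a copy would be required.
///
/// # Example
///
/// A minimal sketch (values chosen for illustration): a contiguous `[2, 3, 4]` tensor can be
/// reshaped to `[6, 4]` without copying, while a transposed (non-contiguous) layout cannot:
///
/// ```rust
/// // Contiguous layout: strides [12, 4, 1]; the no-copy strides for [6, 4] are [4, 1].
/// assert!(is_reshape_possible(&[2, 3, 4], &[12, 4, 1], &[6, 4]).is_some());
///
/// // Transposed layout: shape [3, 4] with strides [1, 3]; flattening would need a copy.
/// assert!(is_reshape_possible(&[3, 4], &[1, 3], &[12]).is_none());
/// ```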
pub fn is_reshape_possible(
    original_shape: &[i64],
    original_strides: &[i64],
    new_shape: &[i64],
) -> Option<Strides> {
    let mut new_strides = vec![0; new_shape.len()];
    let mut old_strides = vec![0; original_shape.len()];
    let mut old_shape = vec![0; original_shape.len()];

    let mut oi = 0;
    let mut oj = 1;
    let mut ni = 0;
    let mut nj = 1;

    let mut oldnd = 0;

    // Drop size-1 dimensions; they contribute nothing to the memory layout.
    for i in 0..original_shape.len() {
        if original_shape[i] != 1 {
            old_shape[oldnd] = original_shape[i];
            old_strides[oldnd] = original_strides[i];
            oldnd += 1;
        }
    }

    // Greedily match groups of old and new dimensions whose element counts agree.
    while ni < new_shape.len() && oi < oldnd {
        let mut np = new_shape[ni];
        let mut op = old_shape[oi];

        while np != op {
            if np < op {
                np *= new_shape[nj];
                nj += 1;
            } else {
                op *= old_shape[oj];
                oj += 1;
            }
        }

        // The old dimensions in this group must be contiguous in memory,
        // otherwise the reshape cannot be expressed as a stride change.
        for i in oi..oj - 1 {
            if old_strides[i] != old_shape[i + 1] * old_strides[i + 1] {
                return None;
            }
        }

        // Build the new strides for this group from the innermost old stride outwards.
        new_strides[nj - 1] = old_strides[oj - 1];
        for i in (ni + 1..nj).rev() {
            new_strides[i - 1] = new_strides[i] * new_shape[i];
        }

        ni = nj;
        nj += 1;
        oi = oj;
        oj += 1;
    }

    // Fill in strides for any remaining (trailing size-1) new dimensions.
    let last_stride = if ni >= 1 { new_strides[ni - 1] } else { 1 };

    for i in ni..new_shape.len() {
        new_strides[i] = last_stride;
    }

    Some(new_strides.into())
}

/// Generates intervals for multi-threaded processing by dividing the outer loop into chunks.
///
/// The `mt_intervals` function divides a large outer loop into multiple smaller intervals to be
/// processed by multiple threads. The function aims to distribute the workload as evenly as possible
/// among the available threads, handling cases where the total number of iterations is not perfectly
/// divisible by the number of threads.
///
/// # Parameters
///
/// - `outer_loop_size`: The total number of iterations in the outer loop.
/// - `num_threads`: The number of threads to divide the work among.
///
/// # Returns
///
/// A `Vec` of tuples `(usize, usize)`, where each tuple represents the start (inclusive) and end
/// (exclusive) indices of the interval assigned to each thread.
///
/// # Algorithm Overview
///
/// 1. **Calculate Base Workload**: Each thread is assigned at least `outer_loop_size / num_threads` iterations.
/// 2. **Distribute Remainder**: If `outer_loop_size` is not divisible by `num_threads`, the remaining iterations
///    (`outer_loop_size % num_threads`) are distributed one by one to the first few threads.
/// 3. **Calculate Start and End Indices**:
///    - The `start_index` for each thread `i` is calculated as:
///      ```
///      i * (outer_loop_size / num_threads) + min(i, outer_loop_size % num_threads)
///      ```
///    - The `end_index` is then calculated by adding the base workload and an extra iteration if the thread
///      received an extra iteration from the remainder.
///
/// # Examples
///
/// ```rust
/// fn main() {
///     let outer_loop_size = 10;
///     let num_threads = 3;
///
///     let intervals = mt_intervals(outer_loop_size, num_threads);
///
///     for (i, (start, end)) in intervals.iter().enumerate() {
///         println!("Thread {}: Processing indices [{}..{})", i, start, end);
///     }
/// }
/// ```
///
/// Output:
///
/// ```text
/// Thread 0: Processing indices [0..4)
/// Thread 1: Processing indices [4..7)
/// Thread 2: Processing indices [7..10)
/// ```
///
/// In this example:
/// - The total number of iterations is 10.
/// - The number of threads is 3.
/// - Each thread gets at least `10 / 3 = 3` iterations.
/// - The remainder is `10 % 3 = 1`. So, the first thread gets one extra iteration.
///
/// # Notes
///
/// - **Workload Balance**: The function ensures that the workload is distributed as evenly as possible.
/// - **Integer Division**: Since integer division truncates towards zero, the remainder is used to distribute
///   the extra iterations.
/// - **Index Calculation**: The calculation uses `std::cmp::min` to ensure that only the first `remainder` threads
///   receive the extra iteration.
///
/// # Function Definition
///
/// ```rust
/// pub fn mt_intervals(outer_loop_size: usize, num_threads: usize) -> Vec<(usize, usize)> {
///     let mut intervals = Vec::with_capacity(num_threads);
///     for i in 0..num_threads {
///         let start_index = i * (outer_loop_size / num_threads)
///             + std::cmp::min(i, outer_loop_size % num_threads);
///         let end_index = start_index
///             + outer_loop_size / num_threads
///             + ((i < outer_loop_size % num_threads) as usize);
///         intervals.push((start_index, end_index));
///     }
///     intervals
/// }
/// ```
///
/// # Unit Tests
///
/// Here are some unit tests to verify the correctness of the function:
///
/// ```rust
/// #[cfg(test)]
/// mod tests {
///     use super::*;
///
///     #[test]
///     fn test_even_division() {
///         let intervals = mt_intervals(100, 4);
///         assert_eq!(intervals.len(), 4);
///         assert_eq!(intervals[0], (0, 25));
///         assert_eq!(intervals[1], (25, 50));
///         assert_eq!(intervals[2], (50, 75));
///         assert_eq!(intervals[3], (75, 100));
///     }
///
///     #[test]
///     fn test_uneven_division() {
///         let intervals = mt_intervals(10, 3);
///         assert_eq!(intervals.len(), 3);
///         assert_eq!(intervals[0], (0, 4));
///         assert_eq!(intervals[1], (4, 7));
///         assert_eq!(intervals[2], (7, 10));
///     }
///
///     #[test]
///     fn test_more_threads_than_work() {
///         let intervals = mt_intervals(5, 10);
///         assert_eq!(intervals.len(), 10);
///         assert_eq!(intervals[0], (0, 1));
///         assert_eq!(intervals[1], (1, 2));
///         assert_eq!(intervals[2], (2, 3));
///         assert_eq!(intervals[3], (3, 4));
///         assert_eq!(intervals[4], (4, 5));
///         for i in 5..10 {
///             assert_eq!(intervals[i], (5, 5));
///         }
///     }
///
///     #[test]
///     fn test_zero_iterations() {
///         let intervals = mt_intervals(0, 4);
///         assert_eq!(intervals.len(), 4);
///         for &(start, end) in &intervals {
///             assert_eq!(start, 0);
///             assert_eq!(end, 0);
///         }
///     }
///
///     #[test]
///     fn test_zero_threads() {
///         let intervals = mt_intervals(10, 0);
///         assert_eq!(intervals.len(), 0);
///     }
/// }
/// ```
///
/// # Caveats
///
/// - If `num_threads` is zero, the function will return an empty vector.
/// - If `outer_loop_size` is zero, all intervals will have start and end indices of zero.
///
/// # Performance Considerations
///
/// - **Allocation**: The function pre-allocates the vector with capacity `num_threads`.
/// - **Integer Operations**: The function uses integer division and modulo operations, which are efficient.
///
/// # Conclusion
///
/// The `mt_intervals` function is useful for dividing work among multiple threads in a balanced way, ensuring that
/// each thread gets a fair share of the workload, even when the total number of iterations is not perfectly divisible
/// by the number of threads.
pub fn mt_intervals(outer_loop_size: usize, num_threads: usize) -> Vec<(usize, usize)> {
    let mut intervals = Vec::with_capacity(num_threads);
    for i in 0..num_threads {
        let start_index =
            i * (outer_loop_size / num_threads) + std::cmp::min(i, outer_loop_size % num_threads);
        let end_index = start_index
            + outer_loop_size / num_threads
            + ((i < outer_loop_size % num_threads) as usize);
        intervals.push((start_index, end_index));
    }
    intervals
}

/// Generates intervals for multi-threaded SIMD processing by dividing the outer loop into chunks.
///
/// The `mt_intervals_simd` function divides a large outer loop into multiple smaller intervals
/// to be processed by multiple threads. Each interval is aligned with the SIMD vector size to
/// optimize performance. This ensures that each thread processes a chunk of data that is a
/// multiple of the SIMD vector size, which is beneficial for vectorized operations.
///
/// # Parameters
///
/// - `outer_loop_size`: The total size of the outer loop (number of iterations).
/// - `num_threads`: The desired number of threads to use for processing.
/// - `vec_size`: The size of the SIMD vector (number of elements processed in one SIMD operation).
///
/// # Returns
///
/// A `Vec` of tuples `(usize, usize)`, where each tuple represents the start (inclusive) and
/// end (exclusive) indices of the interval assigned to a thread.
///
/// # Algorithm Overview
///
/// 1. **Align the Loop Size**: Compute `aligned_size = (outer_loop_size / vec_size) * vec_size`;
///    the difference `outer_loop_size - aligned_size` is the unaligned remainder.
/// 2. **Split into SIMD Blocks**: The aligned portion consists of `aligned_size / vec_size` blocks
///    of `vec_size` elements each.
/// 3. **Distribute Blocks Across Threads**: Each thread receives `total_blocks / num_threads` blocks,
///    and the first `total_blocks % num_threads` threads receive one extra block.
/// 4. **Handle the Remainder**: The unaligned remainder is appended to the last interval; if there
///    is no aligned work at all, the first thread receives the whole remainder and the other
///    threads receive empty intervals.
///
/// # Examples
///
/// ```rust
/// fn main() {
///     let outer_loop_size = 1000;
///     let num_threads = 4;
///     let vec_size = 8;
///
///     let intervals = mt_intervals_simd(outer_loop_size, num_threads, vec_size);
///
///     for (i, (start, end)) in intervals.iter().enumerate() {
///         println!("Thread {}: Processing indices [{}..{})", i, start, end);
///     }
/// }
/// ```
///
/// Output (1000 elements form 125 blocks of 8; the first thread takes the one extra block):
///
/// ```text
/// Thread 0: Processing indices [0..256)
/// Thread 1: Processing indices [256..504)
/// Thread 2: Processing indices [504..752)
/// Thread 3: Processing indices [752..1000)
/// ```
///
/// # Notes
///
/// - **Data Alignment**: Every interval's size is a multiple of `vec_size`, except possibly the
///   last one, which also absorbs the unaligned remainder.
/// - **Load Balancing**: Extra SIMD blocks that cannot be divided evenly are distributed among the
///   first few threads.
///
/// # Panics
///
/// The function panics if `vec_size` or `num_threads` is zero.
///
/// # See Also
///
/// - SIMD (Single Instruction, Multiple Data) processing.
/// - Multi-threading in Rust.
///
/// # Caveats
///
/// - If there is less work than there are threads, the surplus threads receive empty intervals;
///   the returned vector always has `num_threads` entries.
///
/// # Performance Considerations
///
/// - **Thread Overhead**: Creating too many threads may introduce overhead; threads whose interval
///   is empty simply have no work to do.
/// - **SIMD Efficiency**: Aligning intervals to `vec_size` improves SIMD efficiency by preventing
///   partial vector loads and stores.
///
/// # Conclusion
///
/// The `mt_intervals_simd` function is useful for parallelizing loops in applications that benefit
/// from both multi-threading and SIMD vectorization. By carefully dividing the work into appropriately
/// sized intervals, it helps maximize performance on modern CPUs.
pub fn mt_intervals_simd(
    outer_loop_size: usize,
    num_threads: usize,
    vec_size: usize,
) -> Vec<(usize, usize)> {
    assert!(vec_size > 0, "vec_size must be greater than zero");
    assert!(num_threads > 0, "num_threads must be greater than zero");

    let aligned_size = (outer_loop_size / vec_size) * vec_size;
    let remainder = outer_loop_size - aligned_size;

    let mut intervals = Vec::with_capacity(num_threads);

    if aligned_size > 0 {
        let total_vec_blocks = aligned_size / vec_size;
        let base_blocks_per_thread = total_vec_blocks / num_threads;
        let extra_blocks = total_vec_blocks % num_threads;

        let mut start = 0;

        for i in 0..num_threads {
            let mut blocks = base_blocks_per_thread;

            if i < extra_blocks {
                blocks += 1;
            }

            let end = start + blocks * vec_size;
            intervals.push((start, end));
            start = end;
        }

        if remainder > 0 {
            if let Some(last) = intervals.last_mut() {
                *last = (last.0, last.1 + remainder);
            }
        }
    }

    if aligned_size == 0 && remainder > 0 {
        if num_threads >= 1 {
            intervals.push((0, remainder));
            for _ in 1..num_threads {
                intervals.push((0, 0));
            }
        }
    } else if aligned_size > 0 {
        while intervals.len() < num_threads {
            intervals.push((aligned_size, aligned_size));
        }
    } else {
        for _ in intervals.len()..num_threads {
            intervals.push((0, 0));
        }
    }

    intervals
}
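
// Illustrative test sketch for the interval helpers above; the expected tuples follow
// directly from the documented splitting rules (values chosen for illustration).
#[cfg(test)]
mod interval_tests {
    use super::*;

    #[test]
    fn simd_intervals_cover_the_whole_range() {
        // 1000 elements, 4 threads, SIMD width 8: 125 blocks of 8; the first thread
        // takes the extra block (32 blocks), the remaining threads take 31 each.
        let intervals = mt_intervals_simd(1000, 4, 8);
        assert_eq!(intervals, vec![(0, 256), (256, 504), (504, 752), (752, 1000)]);
    }

    #[test]
    fn simd_intervals_append_remainder_to_last_thread() {
        // 10 elements with SIMD width 4: the aligned part is 8 elements, and the
        // trailing 2 elements are appended to the last interval.
        let intervals = mt_intervals_simd(10, 2, 4);
        assert_eq!(intervals, vec![(0, 4), (4, 10)]);
    }
}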