pub mod dtype;
pub mod index_conversion;
pub mod indexing;
pub mod quantization;
pub mod shape;
pub mod slice;
pub use dtype::*;
pub use index_conversion::*;
pub use indexing::*;
pub use quantization::*;
pub use shape::*;
pub use slice::*;
use alloc::{vec, vec::Vec};
/// Returns `true` when `strides` describe a row-major contiguous layout for
/// `shape`.
///
/// An empty `shape` (rank-0) is always considered contiguous. Each stride is
/// compared against the expected suffix product of the trailing dimensions;
/// if `strides` is shorter than `shape`, only the overlapping prefix is
/// checked (zip stops at the shorter side).
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    // Expected row-major strides: suffix products of the shape, computed
    // back-to-front and then reversed into dimension order.
    let mut expected = Vec::with_capacity(shape.len());
    let mut acc = 1;
    for &dim in shape.iter().rev() {
        expected.push(acc);
        acc *= dim;
    }
    expected.reverse();
    expected.iter().zip(strides).all(|(want, got)| want == got)
}
/// Computes the row-major (C-order) contiguous strides for `shape`.
///
/// The last dimension gets stride 1; each earlier dimension's stride is the
/// product of all later dimensions. Returns an empty vector for an empty
/// shape.
pub fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut acc = 1usize;
    // Walk the shape back-to-front, emitting the running suffix product,
    // then flip the result into dimension order.
    let mut strides: Vec<usize> = shape
        .iter()
        .rev()
        .map(|&dim| {
            let stride = acc;
            acc *= dim;
            stride
        })
        .collect();
    strides.reverse();
    strides
}
/// Concrete action a caller should take to realize a reshape, produced by
/// [`reshape_action`] / [`ReshapeAnalysis::action`].
#[derive(Debug)]
pub enum ReshapeAction {
    /// The reshape is expressible as a pure stride change: apply these
    /// strides alongside the new shape, leaving the data untouched.
    UpdateStrides {
        strides: Vec<usize>,
    },
    /// No stride formula applies; the layout must be recomputed
    /// (presumably by materializing/copying the data — confirm with call
    /// sites outside this file).
    Recompute,
    /// The new shape equals the old one; nothing needs to change.
    NoChange,
}
/// Classification of a reshape request, produced by [`reshape_analysis`].
#[derive(Debug)]
pub enum ReshapeAnalysis {
    /// The source layout is row-major contiguous, so strides for the new
    /// shape can be derived directly.
    IsContiguous,
    /// Fallback: no special structure detected; strides cannot be derived.
    HighlyPermutated,
    /// The new shape is the old shape with extra leading size-1 axes
    /// (a broadcast-style rank increase).
    Broadcasted,
    /// The new shape has lower rank than the old one and the layout is not
    /// known to be contiguous.
    SmallerRank,
    /// Old and new shapes are identical.
    NoChange,
}
impl ReshapeAnalysis {
    /// Translates this analysis into the concrete [`ReshapeAction`] for a
    /// reshape from `shape`/`strides` to `shape_new`.
    fn action(self, shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
        match self {
            Self::NoChange => ReshapeAction::NoChange,
            Self::HighlyPermutated | Self::SmallerRank => ReshapeAction::Recompute,
            // Contiguous data: fresh row-major strides for the new shape.
            Self::IsContiguous => ReshapeAction::UpdateStrides {
                strides: contiguous_strides(shape_new),
            },
            Self::Broadcasted => {
                // Number of size-1 axes prepended by the reshape; the
                // Broadcasted analysis guarantees shape_new is at least as
                // long as shape.
                let n_new_batch = shape_new.len() - shape.len();
                let num_elems: usize = shape.iter().product();
                ReshapeAction::UpdateStrides {
                    strides: broadcast_strides(n_new_batch, shape.len(), num_elems, strides),
                }
            }
        }
    }
}
/// Analyzes a reshape from `shape`/`strides` to `shape_new` and returns the
/// [`ReshapeAction`] the caller should perform.
pub fn reshape_action(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {
    let analysis = reshape_analysis(shape, Some(strides), shape_new);
    analysis.action(shape, strides, shape_new)
}
/// Builds strides for a tensor broadcast by prepending `n_new_batch` axes.
///
/// The output has length `rank_prev + n_new_batch`. The leading
/// `n_new_batch` entries (and any trailing slots not covered by `strides`)
/// are filled with `num_elems`; the original `strides` are copied starting
/// at position `n_new_batch`.
///
/// # Panics
/// Panics if `strides` is longer than `rank_prev` (the copy would run past
/// the end of the output).
pub fn broadcast_strides(
    n_new_batch: usize,
    rank_prev: usize,
    num_elems: usize,
    strides: &[usize],
) -> Vec<usize> {
    let mut strides_new = vec![num_elems; rank_prev + n_new_batch];
    strides_new[n_new_batch..n_new_batch + strides.len()].copy_from_slice(strides);
    strides_new
}
/// Classifies a reshape from `shape` (with optional `strides`) to
/// `shape_new`.
///
/// Decision order:
/// - `shape_new` has lower rank: [`ReshapeAnalysis::IsContiguous`] when the
///   layout is known contiguous, otherwise [`ReshapeAnalysis::SmallerRank`].
/// - `shape_new` has higher rank: [`ReshapeAnalysis::Broadcasted`] when it is
///   exactly `shape` with leading size-1 axes prepended.
/// - Equal rank: [`ReshapeAnalysis::NoChange`] when the shapes are equal,
///   [`ReshapeAnalysis::IsContiguous`] when the layout is known contiguous.
/// - Anything else: [`ReshapeAnalysis::HighlyPermutated`].
///
/// When `strides` is `None`, contiguity cannot be established and is treated
/// as false.
pub fn reshape_analysis(
    shape: &[usize],
    strides: Option<&[usize]>,
    shape_new: &[usize],
) -> ReshapeAnalysis {
    // Contiguity is only decidable when strides are actually known.
    let contiguous = || strides.map_or(false, |s| is_contiguous(shape, s));

    let rank = shape.len();
    let rank_new = shape_new.len();

    if rank_new < rank {
        return if contiguous() {
            ReshapeAnalysis::IsContiguous
        } else {
            ReshapeAnalysis::SmallerRank
        };
    }

    let n_new_batch = rank_new - rank;
    if n_new_batch > 0 {
        // Pure broadcast: the old shape is the suffix and every prepended
        // axis has size 1. Checked without allocating a temporary vec.
        if shape == &shape_new[n_new_batch..]
            && shape_new[..n_new_batch].iter().all(|&dim| dim == 1)
        {
            return ReshapeAnalysis::Broadcasted;
        }
    } else if shape == shape_new {
        return ReshapeAnalysis::NoChange;
    } else if contiguous() {
        return ReshapeAnalysis::IsContiguous;
    }

    ReshapeAnalysis::HighlyPermutated
}