use crate::{
shape::shape::Shape, strides::strides::Strides, strides::strides_utils::shape_to_strides,
};
/// Pairs a [`Shape`] with a [`Strides`] value.
///
/// `Layout::new` asserts that both components have the same length; `ndim`
/// reports the length of the shape component.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct Layout {
    /// The shape component; its length is reported by `Layout::ndim`.
    pub(crate) shape: Shape,
    /// The strides component; `Layout::new` requires it to have the same
    /// length as `shape`.
    pub(crate) strides: Strides,
}
impl Layout {
    /// Creates a layout from anything convertible into a [`Shape`] and a [`Strides`].
    ///
    /// # Panics
    ///
    /// Panics if the converted shape and strides do not have the same length.
    pub fn new<A: Into<Shape>, B: Into<Strides>>(shape: A, strides: B) -> Self {
        let shape = shape.into();
        let strides = strides.into();
        // A stride is needed for every dimension; a mismatch would make later
        // per-dimension lookups inconsistent, so fail fast with a clear message.
        assert_eq!(
            shape.len(),
            strides.len(),
            "Layout::new: shape and strides must have the same number of dimensions"
        );
        Layout { shape, strides }
    }

    /// Returns a reference to the shape component.
    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    /// Replaces the shape component.
    ///
    /// This does not re-check that the new shape has the same length as the
    /// current strides; callers updating both should keep them consistent.
    pub fn set_shape(&mut self, shape: Shape) {
        self.shape = shape;
    }

    /// Returns a reference to the strides component.
    pub fn strides(&self) -> &Strides {
        &self.strides
    }

    /// Replaces the strides component.
    ///
    /// This does not re-check that the new strides have the same length as the
    /// current shape; callers updating both should keep them consistent.
    pub fn set_strides(&mut self, strides: Strides) {
        self.strides = strides;
    }

    /// Returns the number of dimensions, i.e. the length of the shape.
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }
}
impl From<Shape> for Layout {
    /// Takes ownership of `shape` and pairs it with strides computed from it by
    /// `shape_to_strides` (presumably contiguous strides — confirm in that helper).
    fn from(shape: Shape) -> Self {
        Self {
            // Evaluated first, while `shape` is still only borrowed.
            strides: shape_to_strides(&shape),
            shape,
        }
    }
}
impl From<&Shape> for Layout {
    /// Clones `shape` and pairs it with strides computed from it by
    /// `shape_to_strides`.
    fn from(shape: &Shape) -> Self {
        Self {
            strides: shape_to_strides(shape),
            shape: shape.clone(),
        }
    }
}
impl From<(Shape, Strides)> for Layout {
fn from((shape, strides): (Shape, Strides)) -> Self {
Layout { shape, strides }
}
}
impl From<(Shape, Vec<i64>)> for Layout {
fn from((shape, strides): (Shape, Vec<i64>)) -> Self {
Layout {
shape,
strides: strides.into(),
}
}
}
impl From<(&Shape, Vec<i64>)> for Layout {
fn from((shape, strides): (&Shape, Vec<i64>)) -> Self {
Layout {
shape: shape.into(),
strides: strides.into(),
}
}
}
impl From<(&Shape, &[i64])> for Layout {
fn from((shape, strides): (&Shape, &[i64])) -> Self {
Layout {
shape: shape.into(),
strides: strides.into(),
}
}
}
impl From<&(Shape, Strides)> for Layout {
fn from((shape, strides): &(Shape, Strides)) -> Self {
Layout {
shape: shape.clone(),
strides: strides.clone(),
}
}
}
impl From<&Layout> for Layout {
fn from(layout: &Layout) -> Self {
Layout {
shape: layout.shape.clone(),
strides: layout.strides.clone(),
}
}
}
impl From<(&Shape, &Strides)> for Layout {
fn from((shape, strides): (&Shape, &Strides)) -> Self {
Layout {
shape: shape.clone(),
strides: strides.clone(),
}
}
}