use std::fmt;
/// An ordered list of dimension extents.
///
/// An empty list is a scalar, a single extent is a vector and two extents are
/// a matrix; longer lists are accepted as well. The derived `Default` is the
/// empty list, i.e. the scalar shape.
#[derive(Clone, PartialEq, Eq, Hash, Default)]
pub struct Shape(Vec<usize>);

impl Shape {
    /// Creates the shape with no dimensions.
    pub fn scalar() -> Self {
        Shape(Vec::new())
    }

    /// Creates a one-dimensional shape of extent `n`.
    pub fn vector(n: usize) -> Self {
        Shape(vec![n])
    }

    /// Creates a two-dimensional shape with `m` rows and `n` columns.
    pub fn matrix(m: usize, n: usize) -> Self {
        Shape(vec![m, n])
    }

    /// Creates a shape from anything convertible into a `Vec<usize>`.
    ///
    /// The extents are stored as given; no validation is performed.
    pub fn from_dims(dims: impl Into<Vec<usize>>) -> Self {
        Shape(dims.into())
    }

    /// Returns the number of elements, i.e. the product of all extents.
    ///
    /// A scalar has exactly one element (the empty product is 1), and a shape
    /// containing any zero extent has zero elements.
    pub fn size(&self) -> usize {
        // No clamping: `product()` over an empty iterator already yields 1,
        // and forcing a minimum of 1 would misreport empty shapes.
        self.0.iter().product()
    }

    /// Returns the number of dimensions.
    pub fn ndim(&self) -> usize {
        self.0.len()
    }

    /// Returns the extents as a slice, in the order they were given.
    pub fn dims(&self) -> &[usize] {
        &self.0
    }

    /// Returns `true` when the shape has no dimensions.
    pub fn is_scalar(&self) -> bool {
        self.0.is_empty()
    }

    /// Returns `true` when the shape has exactly one dimension.
    pub fn is_vector(&self) -> bool {
        self.0.len() == 1
    }

    /// Returns `true` when the shape has exactly two dimensions.
    pub fn is_matrix(&self) -> bool {
        self.0.len() == 2
    }

    /// Returns the first extent, or `1` for a scalar.
    pub fn rows(&self) -> usize {
        self.0.first().copied().unwrap_or(1)
    }

    /// Returns the second extent, or `1` when there are fewer than two
    /// dimensions.
    pub fn cols(&self) -> usize {
        self.0.get(1).copied().unwrap_or(1)
    }

    /// Returns the shape with its dimensions in reverse order.
    ///
    /// A vector of extent `n` is the one exception: it becomes the `1 x n`
    /// matrix shape instead of staying one-dimensional.
    pub fn transpose(&self) -> Self {
        match self.0.as_slice() {
            [n] => Shape::matrix(1, *n),
            // Covers scalars (nothing to reverse), matrices and higher ranks.
            dims => Shape(dims.iter().rev().copied().collect()),
        }
    }

    /// Computes the shape produced by broadcasting `self` against `other`.
    ///
    /// The shapes are aligned on their trailing dimensions, with the shorter
    /// one padded on the left by extents of `1`. A pair of aligned extents is
    /// compatible when they are equal or when one of them is `1`, in which
    /// case the other extent is used.
    ///
    /// Returns `None` as soon as an incompatible pair is found.
    pub fn broadcast(&self, other: &Shape) -> Option<Shape> {
        let rank = self.ndim().max(other.ndim());
        self.left_padded(rank)
            .zip(other.left_padded(rank))
            .map(|(a, b)| match (a, b) {
                _ if a == b => Some(a),
                (1, extent) | (extent, 1) => Some(extent),
                _ => None,
            })
            // Collecting into `Option<Vec<_>>` stops at the first `None`.
            .collect::<Option<Vec<usize>>>()
            .map(Shape)
    }

    /// Yields the extents preceded by enough `1`s to reach `rank` entries.
    ///
    /// `rank` must be at least `self.ndim()`.
    fn left_padded(&self, rank: usize) -> impl Iterator<Item = usize> + '_ {
        std::iter::repeat(1)
            .take(rank - self.ndim())
            .chain(self.0.iter().copied())
    }

    /// Computes the result shape of multiplying `self` by `other`.
    ///
    /// Supported combinations, each requiring the shared extent `k` to match:
    ///
    /// * `(m, k) x (k, n)` gives `(m, n)`
    /// * `(m, k) x (k,)` gives `(m,)`
    /// * `(k,) x (k, n)` gives `(n,)`
    /// * `(k,) x (k,)` gives the scalar shape
    ///
    /// Returns `None` for mismatched extents and for any operand that is a
    /// scalar or has more than two dimensions.
    pub fn matmul(&self, other: &Shape) -> Option<Shape> {
        match (self.dims(), other.dims()) {
            (&[m, k], &[k2, n]) if k == k2 => Some(Shape::matrix(m, n)),
            (&[m, k], &[k2]) if k == k2 => Some(Shape::vector(m)),
            (&[k], &[k2, n]) if k == k2 => Some(Shape::vector(n)),
            (&[k], &[k2]) if k == k2 => Some(Shape::scalar()),
            _ => None,
        }
    }
}
impl fmt::Debug for Shape {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Shape({:?})", self.0)
}
}
/// User-facing rendering as a parenthesised tuple of extents.
///
/// A scalar prints as `()`, a vector as `(n,)` with a trailing comma, and any
/// shape with two or more dimensions as a comma-separated list such as
/// `(3, 4)` or `(2, 3, 4)`.
impl fmt::Display for Shape {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.0.as_slice() {
            [] => f.write_str("()"),
            [n] => write!(f, "({},)", n),
            // Print every extent, not just the first two, so shapes with
            // three or more dimensions are not silently truncated.
            [first, rest @ ..] => {
                write!(f, "({}", first)?;
                for extent in rest {
                    write!(f, ", {}", extent)?;
                }
                f.write_str(")")
            }
        }
    }
}
impl From<()> for Shape {
fn from(_: ()) -> Self {
Shape::scalar()
}
}
impl From<usize> for Shape {
fn from(n: usize) -> Self {
Shape::vector(n)
}
}
impl From<(usize,)> for Shape {
fn from((n,): (usize,)) -> Self {
Shape::vector(n)
}
}
impl From<(usize, usize)> for Shape {
fn from((m, n): (usize, usize)) -> Self {
Shape::matrix(m, n)
}
}
impl From<Vec<usize>> for Shape {
fn from(dims: Vec<usize>) -> Self {
Shape(dims)
}
}
impl From<&[usize]> for Shape {
fn from(dims: &[usize]) -> Self {
Shape(dims.to_vec())
}
}
/// Unit tests for `Shape`: constructors, accessors, shape algebra, formatting
/// and conversions.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar() {
        let s = Shape::scalar();
        assert!(s.is_scalar());
        assert!(!s.is_vector());
        assert!(!s.is_matrix());
        assert_eq!(s.size(), 1);
        assert_eq!(s.ndim(), 0);
        assert_eq!(s.rows(), 1);
        assert_eq!(s.cols(), 1);
        assert_eq!(Shape::default(), s);
    }

    #[test]
    fn test_vector() {
        let s = Shape::vector(5);
        assert!(s.is_vector());
        assert!(!s.is_scalar());
        assert_eq!(s.size(), 5);
        assert_eq!(s.ndim(), 1);
        assert_eq!(s.rows(), 5);
        assert_eq!(s.cols(), 1);
        assert_eq!(s.dims(), &[5]);
    }

    #[test]
    fn test_matrix() {
        let s = Shape::matrix(3, 4);
        assert!(s.is_matrix());
        assert!(!s.is_vector());
        assert_eq!(s.size(), 12);
        assert_eq!(s.ndim(), 2);
        assert_eq!(s.rows(), 3);
        assert_eq!(s.cols(), 4);
        assert_eq!(s.dims(), &[3, 4]);
    }

    #[test]
    fn test_transpose() {
        assert_eq!(Shape::scalar().transpose(), Shape::scalar());
        assert_eq!(Shape::vector(3).transpose(), Shape::matrix(1, 3));
        assert_eq!(Shape::matrix(3, 4).transpose(), Shape::matrix(4, 3));
        // Higher ranks reverse the full dimension list.
        assert_eq!(
            Shape::from_dims(vec![2, 3, 4]).transpose(),
            Shape::from_dims(vec![4, 3, 2])
        );
    }

    #[test]
    fn test_broadcast() {
        assert_eq!(
            Shape::vector(3).broadcast(&Shape::vector(3)),
            Some(Shape::vector(3))
        );
        assert_eq!(
            Shape::scalar().broadcast(&Shape::matrix(3, 4)),
            Some(Shape::matrix(3, 4))
        );
        assert_eq!(
            Shape::vector(4).broadcast(&Shape::matrix(3, 4)),
            Some(Shape::matrix(3, 4))
        );
        // An extent of 1 stretches to match the other operand.
        assert_eq!(
            Shape::matrix(3, 1).broadcast(&Shape::vector(4)),
            Some(Shape::matrix(3, 4))
        );
        assert_eq!(Shape::vector(3).broadcast(&Shape::vector(4)), None);
        assert_eq!(Shape::matrix(2, 3).broadcast(&Shape::matrix(3, 3)), None);
    }

    #[test]
    fn test_matmul() {
        assert_eq!(
            Shape::matrix(3, 4).matmul(&Shape::matrix(4, 5)),
            Some(Shape::matrix(3, 5))
        );
        assert_eq!(
            Shape::matrix(3, 4).matmul(&Shape::vector(4)),
            Some(Shape::vector(3))
        );
        assert_eq!(
            Shape::vector(3).matmul(&Shape::matrix(3, 5)),
            Some(Shape::vector(5))
        );
        assert_eq!(
            Shape::vector(3).matmul(&Shape::vector(3)),
            Some(Shape::scalar())
        );
        assert_eq!(Shape::matrix(3, 4).matmul(&Shape::vector(3)), None);
        assert_eq!(Shape::matrix(3, 4).matmul(&Shape::matrix(3, 4)), None);
        // Scalars are not valid matmul operands.
        assert_eq!(Shape::scalar().matmul(&Shape::vector(3)), None);
    }

    #[test]
    fn test_formatting() {
        assert_eq!(Shape::scalar().to_string(), "()");
        assert_eq!(Shape::vector(5).to_string(), "(5,)");
        assert_eq!(Shape::matrix(3, 4).to_string(), "(3, 4)");
        assert_eq!(format!("{:?}", Shape::matrix(3, 4)), "Shape([3, 4])");
    }

    #[test]
    fn test_conversions() {
        let from_unit: Shape = ().into();
        assert_eq!(from_unit, Shape::scalar());

        let from_len: Shape = 5.into();
        assert_eq!(from_len, Shape::vector(5));

        let from_one_tuple: Shape = (5,).into();
        assert_eq!(from_one_tuple, Shape::vector(5));

        let from_pair: Shape = (3, 4).into();
        assert_eq!(from_pair, Shape::matrix(3, 4));

        let from_vec: Shape = vec![2usize, 3, 4].into();
        assert_eq!(from_vec.dims(), &[2, 3, 4]);

        let extents = [3usize, 4];
        let from_slice: Shape = (&extents[..]).into();
        assert_eq!(from_slice, Shape::matrix(3, 4));
    }
}