use crate::{Error, Result};
/// A single dimension of a [`Shape`].
///
/// `PartialEq`/`Eq`/`Hash` are derived, so equality is purely structural: two dims
/// compare equal only when they are the same variant with identical payloads.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Dim {
    /// A size known exactly; the only variant for which `Dim::value` returns `Some`.
    Static(usize),
    /// A size identified only by a symbol name (displayed as the bare name).
    Symbolic(String),
    /// A named size constrained to a range, displayed as `name∈[min..max]`.
    /// NOTE(review): `min`/`max` appear to be inclusive bounds — confirm.
    Bounded {
        name: String,
        min: usize,
        max: usize,
    },
    /// A named size displayed with a trailing `?`; presumably only known once the
    /// underlying data is available — TODO confirm intended semantics.
    DataDependent(String),
}
impl Dim {
    /// Returns the concrete size of a [`Dim::Static`] dimension.
    ///
    /// Every other variant yields `None`, including a [`Dim::Bounded`] whose range has
    /// collapsed to a single value; use [`Dim::proves_equal`] for that kind of reasoning.
    pub fn value(&self) -> Option<usize> {
        match self {
            Self::Static(v) => Some(*v),
            _ => None,
        }
    }

    /// Returns the static size, or `default` when the dimension is not [`Dim::Static`].
    pub fn value_or(&self, default: usize) -> usize {
        self.value().unwrap_or(default)
    }

    /// Returns `true` only when both dimensions are guaranteed to have the same size.
    ///
    /// This is a conservative check: `false` means "not proven", not "different".
    ///
    /// - Dims pinned to a single value (a `Static`, or a `Bounded` with `min == max`)
    ///   are equal iff those values are equal.
    /// - Two `Symbolic` dims are equal when they share a name.
    /// - Two `Bounded` dims are equal when name and both bounds match.
    /// - A `Static` value that merely lies inside a wider `Bounded` range proves
    ///   nothing, because the bounded dim may take any other value in that range.
    /// - `DataDependent` dims, and all other variant pairings, are never proven equal.
    pub fn proves_equal(&self, other: &Self) -> bool {
        // A fully determined size on both sides settles the question outright.
        if let (Some(lhs), Some(rhs)) = (self.exact_value(), other.exact_value()) {
            return lhs == rhs;
        }
        match (self, other) {
            (Self::Symbolic(lhs), Self::Symbolic(rhs)) => lhs == rhs,
            // Derived equality compares name, min and max together.
            (Self::Bounded { .. }, Self::Bounded { .. }) => self == other,
            _ => false,
        }
    }

    /// The single size this dimension can take, if it is fully determined.
    fn exact_value(&self) -> Option<usize> {
        match self {
            Self::Static(v) => Some(*v),
            Self::Bounded { min, max, .. } if min == max => Some(*min),
            _ => None,
        }
    }
}
/// The shape of a value, as an ordered list of dimensions.
///
/// An empty list is the scalar shape: `Shape::scalar` builds it, and it is displayed
/// as `scalar`. Equality and hashing are derived and therefore structural.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Shape {
    /// Dimensions in order; the number of entries is the shape's rank.
    pub dims: Vec<Dim>,
}
impl Shape {
pub fn new(dims: impl Into<Vec<Dim>>) -> Self {
Self { dims: dims.into() }
}
pub fn scalar() -> Self {
Self { dims: Vec::new() }
}
pub fn rank(&self) -> usize {
self.dims.len()
}
pub fn first_dim(&self) -> Option<&Dim> {
self.dims.first()
}
pub fn last_dim(&self) -> Option<&Dim> {
self.dims.last()
}
pub fn dim(&self, index: usize) -> Option<&Dim> {
self.dims.get(index)
}
pub fn ensure_rank(&self, rank: usize) -> Result<()> {
if self.rank() == rank {
Ok(())
} else {
Err(Error::shape(format!(
"shape {self:?} has rank {}, expected {rank}",
self.rank()
)))
}
}
pub fn ensure_same(&self, other: &Self) -> Result<()> {
if self == other {
Ok(())
} else {
Err(Error::shape(format!(
"shape mismatch: left={self:?}, right={other:?}"
)))
}
}
pub fn ensure_dim_proves_equal(
&self,
lhs_index: usize,
other: &Self,
rhs_index: usize,
) -> Result<()> {
let lhs = self.dims.get(lhs_index).ok_or_else(|| {
Error::shape(format!(
"missing left dimension {lhs_index} for shape {self:?}"
))
})?;
let rhs = other.dims.get(rhs_index).ok_or_else(|| {
Error::shape(format!(
"missing right dimension {rhs_index} for shape {other:?}"
))
})?;
if lhs.proves_equal(rhs) {
Ok(())
} else {
Err(Error::shape(format!(
"dimension equality is not proven: left={lhs:?}, right={rhs:?}"
)))
}
}
}
impl From<Vec<usize>> for Shape {
fn from(value: Vec<usize>) -> Self {
Self {
dims: value.into_iter().map(Dim::Static).collect(),
}
}
}
impl std::fmt::Display for Dim {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Static(value) => write!(f, "{value}"),
Self::Symbolic(name) => write!(f, "{name}"),
Self::Bounded { name, min, max } => write!(f, "{name}∈[{min}..{max}]"),
Self::DataDependent(name) => write!(f, "{name}?"),
}
}
}
/// Renders a shape as `scalar` when it has no dimensions, otherwise as a bracketed,
/// comma-separated list of its dimensions, e.g. `[2, batch, 7]`.
impl std::fmt::Display for Shape {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dims = self.dims.iter();
        let Some(first) = dims.next() else {
            return f.write_str("scalar");
        };
        // Stream each dim straight into the formatter instead of building a
        // String per dim, a Vec of them, and a joined String on every call.
        write!(f, "[{first}")?;
        for dim in dims {
            write!(f, ", {dim}")?;
        }
        f.write_str("]")
    }
}