use alloc::{
format,
string::{
String,
ToString,
},
vec,
vec::Vec,
};
use core::{
fmt::{
Display,
Formatter,
},
panic::Location,
};
use crate::contracts::{
StackEnvironment,
expressions::{
DimExpr,
ExprDisplayAdapter,
MatchResult,
},
shape_view::ShapeView,
};
/// Pattern for one term of a [`ShapeContract`].
///
/// Every variant carries an optional `label_id`: an index into the owning
/// contract's `index` name table. When a non-ellipsis term has a label, the
/// dimension size it matches is bound to that label the first time and
/// checked against the existing binding afterwards.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DimMatcher<'a> {
    /// Matches any single dimension; rendered as `_`.
    Any {
        /// Optional label id (position in the contract's name table).
        label_id: Option<usize>,
    },
    /// Absorbs zero or more dimensions; rendered as `...`. A contract built
    /// with `ShapeContract::new` may contain at most one. Its label is only
    /// used for display; the absorbed dimensions are skipped during matching.
    Ellipsis {
        /// Optional label id (position in the contract's name table).
        label_id: Option<usize>,
    },
    /// Matches a single dimension against a [`DimExpr`].
    Expr {
        /// Optional label id (position in the contract's name table).
        label_id: Option<usize>,
        /// Expression the dimension size must satisfy.
        expr: DimExpr<'a>,
    },
}
impl<'a> DimMatcher<'a> {
pub const fn any() -> Self {
DimMatcher::Any { label_id: None }
}
pub const fn ellipsis() -> Self {
DimMatcher::Ellipsis { label_id: None }
}
pub const fn expr(expr: DimExpr<'a>) -> Self {
DimMatcher::Expr {
label_id: None,
expr,
}
}
pub const fn label_id(&self) -> Option<usize> {
match self {
DimMatcher::Any { label_id } => *label_id,
DimMatcher::Ellipsis { label_id } => *label_id,
DimMatcher::Expr { label_id, .. } => *label_id,
}
}
pub const fn with_label_id(
self,
label_id: Option<usize>,
) -> Self {
match self {
DimMatcher::Any { .. } => DimMatcher::Any { label_id },
DimMatcher::Ellipsis { .. } => DimMatcher::Ellipsis { label_id },
DimMatcher::Expr { expr, .. } => DimMatcher::Expr { label_id, expr },
}
}
}
/// Renders a [`DimMatcher`] using a contract's name table, e.g. `name=_`,
/// `...`, or `name=<expr>`.
pub struct MatcherDisplayAdapter<'a> {
    /// Name table used to render label ids and expression parameters.
    index: &'a [&'a str],
    /// The matcher being rendered.
    matcher: &'a DimMatcher<'a>,
}
impl<'a> Display for MatcherDisplayAdapter<'a> {
    // Writes an optional `label=` prefix, then `_`, `...`, or the expression.
    fn fmt(
        &self,
        f: &mut Formatter<'_>,
    ) -> core::fmt::Result {
        let names = self.index;
        if let Some(id) = self.matcher.label_id() {
            write!(f, "{}=", names[id])?;
        }
        match self.matcher {
            DimMatcher::Any { .. } => f.write_str("_"),
            DimMatcher::Ellipsis { .. } => f.write_str("..."),
            DimMatcher::Expr { expr, .. } => {
                let rendered = ExprDisplayAdapter { index: names, expr };
                write!(f, "{}", rendered)
            }
        }
    }
}
/// A shape pattern: per-dimension matchers plus the name table that their
/// label ids and expression parameters index into.
///
/// Prefer `ShapeContract::new`, which computes `ellipsis_pos` and rejects
/// patterns with more than one ellipsis; building the struct literally
/// bypasses that check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShapeContract<'a> {
    /// Parameter names; a parameter's id is its position in this slice.
    pub index: &'a [&'a str],
    /// One matcher per pattern term, in dimension order.
    pub terms: &'a [DimMatcher<'a>],
    /// Position of the ellipsis term within `terms`, if there is one.
    pub ellipsis_pos: Option<usize>,
}
impl Display for ShapeContract<'_> {
    // Renders the pattern as `[t0, t1, ...]`, each term via `MatcherDisplayAdapter`.
    fn fmt(
        &self,
        f: &mut Formatter<'_>,
    ) -> core::fmt::Result {
        f.write_str("[")?;
        let mut separator = "";
        for matcher in self.terms {
            let term = MatcherDisplayAdapter {
                index: self.index,
                matcher,
            };
            write!(f, "{separator}{term}")?;
            separator = ", ";
        }
        f.write_str("]")
    }
}
impl<'a> ShapeContract<'a> {
    /// Builds a contract from a parameter name table and per-dimension
    /// matchers, recording the position of the (optional, single) ellipsis.
    ///
    /// Being `const`, this can initialize a `static` contract.
    ///
    /// # Panics
    ///
    /// Panics if `terms` contains more than one [`DimMatcher::Ellipsis`]; in a
    /// `const`/`static` initializer this surfaces as a compile-time error.
    pub const fn new(
        index: &'a [&'a str],
        terms: &'a [DimMatcher<'a>],
    ) -> Self {
        // Iterators are not usable in `const fn`, hence the manual index loop.
        let mut i = 0;
        let mut ellipsis_pos: Option<usize> = None;
        while i < terms.len() {
            if matches!(terms[i], DimMatcher::Ellipsis { .. }) {
                match ellipsis_pos {
                    Some(_) => panic!("Multiple ellipses in pattern"),
                    None => ellipsis_pos = Some(i),
                }
            }
            i += 1;
        }
        ShapeContract {
            index,
            terms,
            ellipsis_pos,
        }
    }

    /// Returns the parameter id (position in `self.index`) of `key`, or `None`
    /// if the contract does not name it.
    pub fn maybe_key_to_index(
        &self,
        key: &str,
    ) -> Option<usize> {
        self.index.iter().position(|&s| s == key)
    }

    /// Checks `shape` against the contract, with `env` supplying known
    /// parameter values.
    ///
    /// # Panics
    ///
    /// Panics if `env` names a key the contract does not index, or if `shape`
    /// does not match; mismatch messages include the caller's file and line.
    #[track_caller]
    pub fn assert_shape<'b, S>(
        &'a self,
        shape: S,
        env: StackEnvironment<'a>,
    ) where
        S: Into<ShapeView<'b>>,
    {
        let shape = shape.into();
        match self._loc_try_assert_shape(&shape, env, Location::caller()) {
            Ok(()) => (),
            Err(msg) => panic!("{}", msg),
        }
    }

    /// Non-panicking form of [`ShapeContract::assert_shape`].
    ///
    /// # Errors
    ///
    /// Returns the diagnostic `assert_shape` would panic with.
    #[track_caller]
    pub fn try_assert_shape<'b, S>(
        &'a self,
        shape: S,
        env: StackEnvironment<'a>,
    ) -> Result<(), String>
    where
        S: Into<ShapeView<'b>>,
    {
        let sv = shape.into();
        self._loc_try_assert_shape(&sv, env, Location::caller())
    }

    /// Shared implementation of the assert entry points, with the caller
    /// location passed explicitly.
    fn _loc_try_assert_shape(
        &'a self,
        shape: &ShapeView,
        env: StackEnvironment<'a>,
        loc: &Location<'a>,
    ) -> Result<(), String> {
        let mut scratch = self.bind_env(env)?;
        self.format_resolve(shape, scratch.as_mut_slice(), loc)
    }

    /// Builds one binding slot per entry of `self.index`, seeded from `env`.
    /// Slots not named in `env` start as `None`.
    ///
    /// # Errors
    ///
    /// Returns an error if `env` names a key the contract does not index.
    fn bind_env(
        &self,
        env: StackEnvironment<'a>,
    ) -> Result<Vec<Option<isize>>, String> {
        let mut scratch: Vec<Option<isize>> = vec![None; self.index.len()];
        for (k, v) in env.iter() {
            let v = *v as isize;
            match self.maybe_key_to_index(k) {
                Some(param_id) => scratch[param_id] = Some(v),
                None => {
                    return Err(
                        format!("The key \"{k}\" is not indexed in the contract:\n{self}")
                            .to_string(),
                    );
                }
            }
        }
        Ok(scratch)
    }

    /// Checks `shape` and returns the resolved values of `keys`, in order.
    ///
    /// # Panics
    ///
    /// Panics if a key in `keys` or `env` is not indexed, if `shape` does not
    /// match, or if a requested key is left unbound or resolves to a negative
    /// value.
    #[must_use]
    #[track_caller]
    pub fn unpack_shape<'b, S, const K: usize>(
        &'a self,
        shape: S,
        keys: &[&'a str; K],
        env: StackEnvironment<'a>,
    ) -> [usize; K]
    where
        S: Into<ShapeView<'b>>,
    {
        let sv: ShapeView = shape.into();
        self._loc_unpack_shape(&sv, keys, env, Location::caller())
    }

    /// Panicking wrapper over `_loc_try_unpack_shape`.
    fn _loc_unpack_shape<const K: usize>(
        &'a self,
        shape: &ShapeView,
        keys: &[&'a str; K],
        env: StackEnvironment<'a>,
        loc: &Location<'a>,
    ) -> [usize; K] {
        match self._loc_try_unpack_shape(shape, keys, env, loc) {
            Ok(values) => values,
            Err(msg) => panic!("{msg}"),
        }
    }

    /// Non-panicking form of [`ShapeContract::unpack_shape`].
    ///
    /// # Errors
    ///
    /// Returns an error if a key in `keys` or `env` is not indexed, if `shape`
    /// does not match, or if a requested key is left unbound or resolves to a
    /// negative value.
    #[track_caller]
    pub fn try_unpack_shape<'b, S, const K: usize>(
        &'a self,
        shape: S,
        keys: &[&'a str; K],
        env: StackEnvironment<'a>,
    ) -> Result<[usize; K], String>
    where
        S: Into<ShapeView<'b>>,
    {
        let sv: ShapeView = shape.into();
        self._loc_try_unpack_shape(&sv, keys, env, Location::caller())
    }

    /// Shared implementation of the unpack entry points.
    fn _loc_try_unpack_shape<const K: usize>(
        &'a self,
        shape: &ShapeView,
        keys: &[&'a str; K],
        env: StackEnvironment<'a>,
        loc: &Location<'a>,
    ) -> Result<[usize; K], String> {
        let selection = self.try_keys_to_selection(keys)?;
        let mut scratch = self.bind_env(env)?;
        let selected = self._loc_try_select(shape, &selection, scratch.as_mut_slice(), loc)?;
        let mut result = [0usize; K];
        for ((slot, &value), &param_id) in result
            .iter_mut()
            .zip(selected.iter())
            .zip(selection.iter())
        {
            // Results are reported as `usize`; refuse to silently wrap a negative value.
            *slot = usize::try_from(value).map_err(|_| {
                format!(
                    "at {file}:{line}: Shape Error\n The key \"{key}\" resolved to a negative value {value}.\nContract:\n {self}",
                    key = self.index[param_id],
                    file = loc.file(),
                    line = loc.line(),
                )
            })?;
        }
        Ok(result)
    }

    /// Maps each key to its parameter id (position in `self.index`).
    ///
    /// # Errors
    ///
    /// Returns an error naming the first key the contract does not index.
    fn try_keys_to_selection<const D: usize>(
        &self,
        keys: &[&str; D],
    ) -> Result<[usize; D], String> {
        let mut selection = [0; D];
        for (slot, key) in selection.iter_mut().zip(keys.iter()) {
            *slot = self.maybe_key_to_index(key).ok_or_else(|| {
                format!("The key \"{key}\" is not indexed in the contract:\n{self}")
            })?;
        }
        Ok(selection)
    }

    /// Maps each key to its parameter id (position in `self.index`).
    ///
    /// # Panics
    ///
    /// Panics if any key is not indexed by the contract.
    #[track_caller]
    pub fn expect_keys_to_selection<const D: usize>(
        &'a self,
        keys: &[&'a str; D],
    ) -> [usize; D] {
        match self.try_keys_to_selection(keys) {
            Ok(selection) => selection,
            Err(msg) => panic!("{msg}"),
        }
    }

    /// Resolves `shape` into `env`, then reads out the values of the
    /// parameters listed in `selection`.
    ///
    /// # Errors
    ///
    /// Returns the resolution diagnostic on mismatch, or an error if a selected
    /// parameter is still unbound after resolution.
    fn _loc_try_select<const K: usize>(
        &'a self,
        shape: &ShapeView,
        selection: &[usize; K],
        env: &mut [Option<isize>],
        loc: &Location<'a>,
    ) -> Result<[isize; K], String> {
        assert_eq!(env.len(), self.index.len());
        self.format_resolve(shape, env, loc)?;
        let mut out = [0; K];
        for (slot, &param_id) in out.iter_mut().zip(selection.iter()) {
            // A parameter can stay unbound, e.g. when no term binds it and
            // `env` does not supply it.
            *slot = env[param_id].ok_or_else(|| {
                format!(
                    "at {file}:{line}: Shape Error\n The key \"{key}\" is not bound by the shape or the environment.\nContract:\n {self}",
                    key = self.index[param_id],
                    file = loc.file(),
                    line = loc.line(),
                )
            })?;
        }
        Ok(out)
    }

    /// Runs `_resolve` and, on failure, decorates the message with the caller
    /// location, the actual shape, the contract, and the bindings established
    /// before the failure.
    pub(crate) fn format_resolve(
        &'a self,
        shape: &ShapeView,
        env: &mut [Option<isize>],
        location: &Location,
    ) -> Result<(), String> {
        match self._resolve(shape.as_ref(), env) {
            Ok(()) => Ok(()),
            Err(msg) => {
                // Report only the parameters that were actually bound.
                let bindings = self
                    .index
                    .iter()
                    .zip(env.iter())
                    .filter_map(|(key, value)| value.map(|value| format!("\"{key}\": {value}")))
                    .collect::<Vec<_>>()
                    .join(", ");
                Err(format!(
                    "at {file}:{line}: Shape Error\n {msg}\nActual:\n {shape:?}\nContract:\n {self}\nBindings:\n {{{bindings}}}",
                    file = location.file(),
                    line = location.line(),
                    shape = shape.as_ref(),
                ))
            }
        }
    }

    /// Matches `shape` against the contract terms, recording bindings in `env`
    /// (one slot per entry of `self.index`).
    ///
    /// Dimensions absorbed by the ellipsis are skipped. Terms are processed
    /// left to right in one pass, so a binding made by a term is visible only
    /// to the terms after it.
    ///
    /// # Errors
    ///
    /// Returns a message describing the first rank or value mismatch.
    pub fn _resolve(
        &'a self,
        shape: &[usize],
        env: &mut [Option<isize>],
    ) -> Result<(), String> {
        let fail_at = |shape_idx: usize, term_idx: usize, msg: &str| -> String {
            format!(
                "{} !~ {} :: {msg}",
                shape[shape_idx],
                MatcherDisplayAdapter {
                    index: self.index,
                    matcher: &self.terms[term_idx]
                }
            )
        };
        let (e_start, e_size) = self.try_ellipsis_split(shape.len())?;
        for (shape_idx, &dim_size) in shape.iter().enumerate() {
            let dim_size = dim_size as isize;
            // Map the shape position onto its term, skipping the ellipsis range.
            let term_idx = if shape_idx < e_start {
                shape_idx
            } else if shape_idx < (e_start + e_size) {
                continue;
            } else {
                shape_idx + 1 - e_size
            };
            let matcher = &self.terms[term_idx];
            if let Some(label_id) = matcher.label_id() {
                match env[label_id] {
                    Some(value) => {
                        if value != dim_size {
                            return Err(fail_at(shape_idx, term_idx, "Value MissMatch."));
                        }
                    }
                    None => {
                        env[label_id] = Some(dim_size);
                    }
                }
            }
            let expr = match matcher {
                DimMatcher::Any { .. } => continue,
                DimMatcher::Expr { expr, .. } => expr,
                DimMatcher::Ellipsis { .. } => {
                    unreachable!("Ellipsis should have been handled before")
                }
            };
            match expr.try_match(dim_size, env) {
                Ok(MatchResult::Match) => continue,
                Ok(MatchResult::Conflict) => {
                    return Err(fail_at(shape_idx, term_idx, "Value MissMatch."));
                }
                Ok(MatchResult::ParamConstraint { id, value }) => {
                    env[id] = Some(value);
                }
                Err(msg) => return Err(fail_at(shape_idx, term_idx, msg)),
            }
        }
        Ok(())
    }

    /// Returns `(start, len)` of the shape range absorbed by the ellipsis.
    ///
    /// Without an ellipsis the rank must equal the term count; the returned
    /// range is empty and starts at `terms.len()`, so no dimension is skipped.
    ///
    /// # Errors
    ///
    /// Returns an error if the rank cannot fit the non-ellipsis terms.
    fn try_ellipsis_split(
        &self,
        rank: usize,
    ) -> Result<(usize, usize), String> {
        let k = self.terms.len();
        match self.ellipsis_pos {
            None => {
                if rank != k {
                    Err(format!("Shape rank {rank} != pattern dim count {k}",))
                } else {
                    Ok((k, 0))
                }
            }
            Some(pos) => {
                // `ellipsis_pos` is `Some` only when `terms` is non-empty, so this cannot underflow.
                let non_ellipsis_terms = k - 1;
                if rank < non_ellipsis_terms {
                    return Err(format!(
                        "Shape rank {rank} < non-ellipsis pattern term count {non_ellipsis_terms}",
                    ));
                }
                Ok((pos, rank - non_ellipsis_terms))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contracts::DimExpr;

    /// Exercises a leading `_`, an ellipsis, product terms, a cube, and a plain
    /// parameter, through both `assert_shape` and `unpack_shape`.
    #[test]
    fn test_unpack_shape() {
        static CONTRACT: ShapeContract = ShapeContract::new(
            &["b", "h", "w", "p", "z", "c"],
            &[
                DimMatcher::any(),
                DimMatcher::expr(DimExpr::Param { id: 0 }),
                DimMatcher::ellipsis(),
                DimMatcher::expr(DimExpr::Prod {
                    children: &[DimExpr::Param { id: 1 }, DimExpr::Param { id: 3 }],
                }),
                DimMatcher::expr(DimExpr::Prod {
                    children: &[DimExpr::Param { id: 2 }, DimExpr::Param { id: 3 }],
                }),
                DimMatcher::expr(DimExpr::Pow {
                    base: &DimExpr::Param { id: 4 },
                    exp: 3,
                }),
                DimMatcher::expr(DimExpr::Param { id: 5 }),
            ],
        );

        let (b, h, w, p, z, c) = (2, 3, 2, 4, 4, 5);
        // `12` fills the `_` slot; `1, 2, 3` are absorbed by the ellipsis.
        let shape = [12, b, 1, 2, 3, h * p, w * p, z * z * z, c];
        let env = [("p", p), ("c", c)];

        CONTRACT.assert_shape(&shape, &env);

        let unpacked = CONTRACT.unpack_shape(&shape, &["b", "h", "w", "z"], &env);
        assert_eq!(unpacked, [b, h, w, z]);
    }
}