use std::sync::Arc;
use vyre::ir::{DataType, Ident};
/// Describes one tensor argument by name, element dtype, and shape.
///
/// The shape is held as `Arc<[u32]>`, so cloning a `TensorRef` shares the
/// dimension list instead of copying it.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct TensorRef {
/// Identifier of the tensor; `check_unique_names` compares these as strings.
pub name: Ident,
/// Element data type of the tensor.
pub dtype: DataType,
/// Dimension sizes, one entry per axis (e.g. `[rows, cols]` for the 2-D builders).
pub shape: Arc<[u32]>,
}
impl TensorRef {
    /// Builds a tensor reference from a name, an element dtype, and a shape.
    ///
    /// `name` accepts anything convertible into an [`Ident`]; `shape` is
    /// converted into a shared `Arc<[u32]>` slice.
    #[must_use]
    pub fn new(name: impl Into<Ident>, dtype: DataType, shape: Vec<u32>) -> Self {
        Self {
            name: name.into(),
            dtype,
            shape: Arc::from(shape),
        }
    }

    /// Shorthand for a rank-1 `U32` tensor with `len` elements.
    #[must_use]
    pub fn u32_1d(name: impl Into<Ident>, len: u32) -> Self {
        Self::new(name, DataType::U32, vec![len])
    }

    /// Shorthand for a rank-1 `F32` tensor with `len` elements.
    #[must_use]
    pub fn f32_1d(name: impl Into<Ident>, len: u32) -> Self {
        Self::new(name, DataType::F32, vec![len])
    }

    /// Shorthand for a rank-2 `U32` tensor with shape `[rows, cols]`.
    #[must_use]
    pub fn u32_2d(name: impl Into<Ident>, rows: u32, cols: u32) -> Self {
        Self::new(name, DataType::U32, vec![rows, cols])
    }

    /// Shorthand for a rank-1 `F16` tensor with `len` elements.
    #[must_use]
    pub fn f16_1d(name: impl Into<Ident>, len: u32) -> Self {
        Self::new(name, DataType::F16, vec![len])
    }

    /// Shorthand for a rank-2 `F16` tensor with shape `[rows, cols]`.
    #[must_use]
    pub fn f16_2d(name: impl Into<Ident>, rows: u32, cols: u32) -> Self {
        Self::new(name, DataType::F16, vec![rows, cols])
    }

    /// Shorthand for a rank-2 `F32` tensor with shape `[rows, cols]`.
    #[must_use]
    pub fn f32_2d(name: impl Into<Ident>, rows: u32, cols: u32) -> Self {
        Self::new(name, DataType::F32, vec![rows, cols])
    }

    /// Returns the product of all dimensions, or `None` when that product
    /// does not fit in a `u32`.
    ///
    /// An empty shape yields `Some(1)` (the empty product). A shape that
    /// contains a zero dimension yields `Some(0)` wherever the zero sits.
    #[must_use]
    pub fn element_count(&self) -> Option<u32> {
        // The true product is 0 as soon as any dimension is 0. Without this
        // guard the left-to-right fold would report overflow for shapes such
        // as `[1 << 20, 1 << 20, 0]` while accepting `[0, 1 << 20, 1 << 20]`.
        if self.shape.contains(&0) {
            return Some(0);
        }
        self.shape
            .iter()
            .try_fold(1u32, |acc, &dim| acc.checked_mul(dim))
    }

    /// Returns the tensor's name as a string slice.
    #[must_use]
    pub fn name_str(&self) -> &str {
        self.name.as_str()
    }
}
/// Validation failures for `TensorRef` arguments passed to an op.
///
/// Each display message names the offending tensor and ends with a `Fix:`
/// hint describing how to resolve the problem.
#[derive(Debug, Clone, thiserror::Error)]
#[non_exhaustive]
pub enum TensorRefError {
/// The tensor's dtype differs from the dtype the op expects.
#[error(
"TensorRef `{name}` has dtype `{found:?}`; op `{op}` expects `{expected:?}`. Fix: pass a buffer of the correct dtype or cast."
)]
DtypeMismatch {
/// Name of the offending tensor.
name: String,
/// Dtype the tensor actually has.
found: DataType,
/// Dtype the op required.
expected: DataType,
/// Name of the op that performed the check.
op: &'static str,
},
/// The tensor's shape differs from the shape the op expects.
#[error(
"TensorRef `{name}` has shape {found:?}; op `{op}` expects {expected:?}. Fix: reshape or pick a compatible op variant."
)]
ShapeMismatch {
/// Name of the offending tensor.
name: String,
/// Shape the tensor actually has.
found: Vec<u32>,
/// Shape the op required.
expected: Vec<u32>,
/// Name of the op that performed the check.
op: &'static str,
},
/// The same tensor name was supplied for more than one argument of an op.
#[error(
"TensorRef name collision in op `{op}`: `{name}` appears on multiple arguments. Fix: use distinct buffer names per argument."
)]
NameCollision {
/// The name that appears more than once.
name: String,
/// Name of the op that performed the check.
op: &'static str,
},
/// The product of the shape's dimensions does not fit in a `u32`.
///
/// NOTE(review): nothing in the visible part of this file constructs this
/// variant; presumably callers raise it when `TensorRef::element_count`
/// returns `None` — confirm against call sites.
#[error(
"TensorRef `{name}` element-count overflows u32 for shape {shape:?}. Fix: reduce dimensions below the u32 boundary."
)]
ElementCountOverflow {
/// Name of the offending tensor.
name: String,
/// Shape whose element count overflowed.
shape: Vec<u32>,
},
}
/// Verifies that no two tensors in `refs` share a name.
///
/// `op` is the name of the calling op and is only used to label the error.
///
/// # Errors
///
/// Returns [`TensorRefError::NameCollision`] for the first tensor (in slice
/// order) whose name already appeared at an earlier position.
pub fn check_unique_names(refs: &[&TensorRef], op: &'static str) -> Result<(), TensorRefError> {
    // Locate the earliest entry that repeats the name of any entry before it.
    let first_repeat = refs.iter().enumerate().find(|&(position, candidate)| {
        refs[..position]
            .iter()
            .any(|earlier| earlier.name_str() == candidate.name_str())
    });

    match first_repeat {
        Some((_, duplicate)) => Err(TensorRefError::NameCollision {
            name: duplicate.name.as_str().to_string(),
            op,
        }),
        None => Ok(()),
    }
}
/// Verifies that tensor `r` has the element dtype `expected`.
///
/// `op` is the name of the calling op and is only used to label the error.
///
/// # Errors
///
/// Returns [`TensorRefError::DtypeMismatch`] carrying the tensor's name, its
/// actual dtype, and the expected dtype when the two differ.
pub fn check_dtype(
    r: &TensorRef,
    expected: DataType,
    op: &'static str,
) -> Result<(), TensorRefError> {
    let dtype_differs = r.dtype != expected;
    if !dtype_differs {
        return Ok(());
    }

    Err(TensorRefError::DtypeMismatch {
        name: r.name.as_str().to_string(),
        found: r.dtype.clone(),
        expected,
        op,
    })
}
/// Verifies that tensor `r` has exactly the shape `expected`.
///
/// Shapes match only when they have the same rank and every dimension is
/// equal. `op` is the name of the calling op and is only used to label the
/// error.
///
/// # Errors
///
/// Returns [`TensorRefError::ShapeMismatch`] carrying the tensor's name plus
/// owned copies of the actual and expected shapes when they differ.
pub fn check_shape(
    r: &TensorRef,
    expected: &[u32],
    op: &'static str,
) -> Result<(), TensorRefError> {
    let actual: &[u32] = r.shape.as_ref();

    if actual != expected {
        Err(TensorRefError::ShapeMismatch {
            name: r.name.as_str().to_string(),
            found: actual.to_vec(),
            expected: expected.to_vec(),
            op,
        })
    } else {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The 1-D `u32` builder fills in name, dtype, shape, and element count.
    #[test]
    fn u32_1d_builder_produces_expected_fields() {
        let tensor = TensorRef::u32_1d("x", 64);
        assert_eq!(tensor.name.as_str(), "x");
        assert_eq!(tensor.dtype, DataType::U32);
        assert_eq!(tensor.shape.to_vec(), vec![64u32]);
        assert_eq!(tensor.element_count(), Some(64));
    }

    /// 2^20 * 2^20 = 2^40 does not fit in `u32`, so the count is `None`.
    #[test]
    fn element_count_detects_overflow() {
        let side = 1u32 << 20;
        let oversized = TensorRef::new("big", DataType::U32, vec![side, side]);
        assert!(oversized.element_count().is_none());
    }

    /// Two arguments with the same name are rejected.
    #[test]
    fn check_unique_names_catches_collision() {
        let first = TensorRef::u32_1d("x", 4);
        let second = TensorRef::u32_1d("x", 4);
        let outcome = check_unique_names(&[&first, &second], "test");
        assert!(matches!(
            outcome,
            Err(TensorRefError::NameCollision { .. })
        ));
    }

    /// A tensor whose dtype equals the expected dtype passes.
    #[test]
    fn check_dtype_passes_on_match() {
        let tensor = TensorRef::f32_1d("y", 8);
        assert!(check_dtype(&tensor, DataType::F32, "op").is_ok());
    }

    /// A `U32` tensor checked against `F32` yields `DtypeMismatch`.
    #[test]
    fn check_dtype_fails_on_mismatch() {
        let tensor = TensorRef::u32_1d("y", 8);
        let outcome = check_dtype(&tensor, DataType::F32, "op");
        assert!(matches!(
            outcome,
            Err(TensorRefError::DtypeMismatch { .. })
        ));
    }

    /// An identical shape passes; a differing column count yields `ShapeMismatch`.
    #[test]
    fn check_shape_passes_and_fails() {
        let matrix = TensorRef::u32_2d("m", 4, 8);
        assert!(matches!(check_shape(&matrix, &[4, 8], "op"), Ok(())));
        let outcome = check_shape(&matrix, &[4, 16], "op");
        assert!(matches!(
            outcome,
            Err(TensorRefError::ShapeMismatch { .. })
        ));
    }
}