use std::sync::Arc;
use smallvec::{SmallVec, smallvec};
use snafu::ensure;
use crate::{ConstValue, Op, Result, SInt, UOp, error::*};
pub type Shape = SmallVec<[SInt; 4]>;
/// Returns `true` when every dimension of `shape` reports `is_const()`.
///
/// An empty shape has no non-constant dimension and is therefore static.
pub fn is_static(shape: &Shape) -> bool {
    // Equivalent to "all are const": there is no dimension that is not const.
    !shape.iter().any(|dim| !dim.is_const())
}
/// Converts `shape` into a vector of concrete `usize` dimensions.
///
/// Returns `None` when the shape is not static (see [`is_static`]), and
/// `Some(dims)` otherwise. An empty shape yields `Some` of an empty vector.
///
/// The previous implementation used `bool::then_some`, whose argument is
/// evaluated eagerly; the inner `as_const().unwrap()` therefore ran even for
/// non-static shapes and panicked instead of returning `None`.
pub fn to_static(shape: &Shape) -> Option<SmallVec<[usize; 4]>> {
    if !is_static(shape) {
        return None;
    }
    // Collecting `Option<usize>` items into `Option<SmallVec<_>>` stops at the
    // first `None`, so no `unwrap` (and no panic path) is needed.
    shape.iter().map(|dim| dim.as_const()).collect()
}
pub fn validate_shape(shape: &[isize]) -> Result<SmallVec<[usize; 4]>> {
ensure!(shape.iter().all(|&s| s >= 0), ReshapeNegativeDimensionSnafu { shape });
Ok(shape.iter().map(|&s| s as usize).collect())
}
/// Returns `true` when `lhs` and `rhs` have the same length and every pair of
/// corresponding dimensions compares equal.
pub fn shapes_equal(lhs: &Shape, rhs: &Shape) -> bool {
    lhs.as_slice() == rhs.as_slice()
}
/// Returns `true` when `shapes` is non-empty and every entry equals the first
/// one according to [`shapes_equal`].
///
/// An empty slice yields `false`.
pub fn all_shapes_equal(shapes: &[Shape]) -> bool {
    match shapes.first() {
        None => false,
        Some(reference) => shapes.iter().all(|candidate| shapes_equal(candidate, reference)),
    }
}
/// Pads every shape with leading `1` dimensions so that all of them have the
/// same number of dimensions as the longest input shape.
///
/// The original dimensions of each shape are kept, in order, after the
/// padding. An empty input slice produces an empty vector.
pub fn align_shapes_left(shapes: &[Shape]) -> Vec<Shape> {
    // `max()` is `None` only when there are no shapes at all.
    let Some(max_dims) = shapes.iter().map(|s| s.len()).max() else {
        return Vec::new();
    };

    let mut out = Vec::with_capacity(shapes.len());
    for shape in shapes {
        let mut aligned: Shape = SmallVec::with_capacity(max_dims);
        // One leading `1` for each missing dimension.
        for _ in shape.len()..max_dims {
            aligned.push(SInt::from(1));
        }
        aligned.extend(shape.iter().cloned());
        out.push(aligned);
    }
    out
}
/// Reports whether two shapes of equal rank can be broadcast together.
///
/// Returns `false` if the shapes have a different number of dimensions.
/// Otherwise each pair of dimensions is checked:
///
/// * both constant: compatible when they are equal or either one is `1`;
/// * at least one non-constant: always accepted here, since the pair cannot
///   be rejected from constant values alone.
///
/// The original code had two separate branches for the non-constant case
/// (`l == r` and the final `else`) that both returned `true`; they are merged
/// into a single arm with identical results.
pub fn can_broadcast(lhs: &Shape, rhs: &Shape) -> bool {
    if lhs.len() != rhs.len() {
        return false;
    }
    lhs.iter().zip(rhs.iter()).all(|(l, r)| match (l.as_const(), r.as_const()) {
        (Some(lv), Some(rv)) => lv == rv || lv == 1 || rv == 1,
        // A non-constant dimension is involved: accept.
        _ => true,
    })
}
pub fn broadcast_shape(lhs: &Shape, rhs: &Shape) -> Result<Shape> {
use crate::error::BroadcastShapeMismatchSnafu;
use snafu::ensure;
ensure!(lhs.len() == rhs.len(), BroadcastShapeMismatchSnafu { lhs: lhs.clone(), rhs: rhs.clone() });
let mut result = SmallVec::with_capacity(lhs.len());
for (l, r) in lhs.iter().zip(rhs.iter()) {
if l == r {
result.push(l.clone());
} else if let (Some(lv), Some(rv)) = (l.as_const(), r.as_const()) {
if lv == 1 {
result.push(r.clone());
} else if rv == 1 || lv == rv {
result.push(l.clone());
} else {
return BroadcastShapeMismatchSnafu { lhs: lhs.clone(), rhs: rhs.clone() }.fail();
}
} else {
result.push(crate::sint_max(&[l.clone(), r.clone()]));
}
}
Ok(result)
}
/// Broadcasts any number of shapes into a single result shape.
///
/// The shapes are first brought to a common rank with [`align_shapes_left`]
/// and then folded pairwise, left to right, through [`broadcast_shape`].
/// An empty input produces an empty shape.
///
/// # Errors
///
/// Propagates the first error returned by [`broadcast_shape`].
pub fn broadcast_shapes(shapes: &[Shape]) -> Result<Shape> {
    let aligned = align_shapes_left(shapes);
    // `aligned` is empty exactly when `shapes` is empty.
    let Some((first, rest)) = aligned.split_first() else {
        return Ok(SmallVec::new());
    };
    rest.iter().try_fold(first.clone(), |acc, shape| broadcast_shape(&acc, shape))
}
pub fn to_vec_usize(shape: &Shape) -> Result<Vec<usize>> {
shape
.iter()
.map(|dim| {
dim.as_const().ok_or_else(|| Error::SymbolicShapeUnsupported { operation: "shape conversion".to_string() })
})
.collect()
}
pub fn to_vec_isize(shape: &Shape) -> Result<Vec<isize>> {
shape
.iter()
.map(|dim| {
dim.as_const()
.map(|v| v as isize)
.ok_or_else(|| Error::SymbolicShapeUnsupported { operation: "shape conversion".to_string() })
})
.collect()
}
fn extract_shape_from_uop(shape_uop: &Arc<UOp>) -> Option<Shape> {
match shape_uop.op() {
Op::Vectorize { elements } => Some(elements.into_iter().cloned().map(SInt::from).collect()),
Op::Const(const_hash) => match const_hash.0 {
ConstValue::Int(v) if v >= 0 => Some(smallvec![SInt::from(v as usize)]),
ConstValue::UInt(v) => Some(smallvec![SInt::from(v as usize)]),
_ => None,
},
Op::VConst { values } => {
let mut dims = SmallVec::with_capacity(values.len());
for val in values {
match val {
ConstValue::Int(v) if *v >= 0 => dims.push(SInt::from(*v as usize)),
ConstValue::UInt(v) => dims.push(SInt::from(*v as usize)),
_ => return None,
}
}
Some(dims)
}
_ => None,
}
}
/// Extracts per-dimension `(begin, end)` pairs from two shape-encoding UOps.
///
/// Returns `None` when either UOp cannot be decoded by
/// [`extract_shape_from_uop`] or when the two decoded lists differ in length.
fn extract_ranges_from_uops(begins_uop: &Arc<UOp>, ends_uop: &Arc<UOp>) -> Option<Vec<(SInt, SInt)>> {
    let begins = extract_shape_from_uop(begins_uop)?;
    let ends = extract_shape_from_uop(ends_uop)?;
    // Pair the lists up only when they describe the same number of dimensions.
    (begins.len() == ends.len()).then(|| begins.into_iter().zip(ends).collect())
}
/// Encodes `shape` as a UOp.
///
/// An empty shape becomes an empty `VConst` of dtype `Index`; otherwise each
/// dimension is converted with `to_uop(DType::Index)` and the results are
/// combined with `UOp::vectorize`.
pub fn shape_to_uop(shape: &Shape) -> Arc<UOp> {
    use morok_dtype::DType;
    use smallvec::SmallVec;

    match shape.as_slice() {
        [] => UOp::vconst(vec![], DType::Index),
        dims => {
            let mut elements: SmallVec<[Arc<UOp>; 4]> = SmallVec::with_capacity(dims.len());
            for dim in dims {
                elements.push(dim.to_uop(DType::Index));
            }
            UOp::vectorize(elements)
        }
    }
}
/// Encodes `(begin, end)` pairs as two vectorized UOps: one holding all the
/// begins and one holding all the ends, each converted with
/// `to_uop(DType::Index)`.
///
/// # Panics
///
/// Panics if `ranges` is empty.
pub fn ranges_to_uops(ranges: &[(SInt, SInt)]) -> (Arc<UOp>, Arc<UOp>) {
    use morok_dtype::DType;
    use smallvec::SmallVec;

    assert!(!ranges.is_empty(), "ranges_to_uops does not support empty ranges (scalars); handle at callsite");

    // Two separate passes: all begins are converted before any end.
    let mut begins: SmallVec<[Arc<UOp>; 4]> = SmallVec::with_capacity(ranges.len());
    for (begin, _) in ranges {
        begins.push(begin.to_uop(DType::Index));
    }
    let mut ends: SmallVec<[Arc<UOp>; 4]> = SmallVec::with_capacity(ranges.len());
    for (_, end) in ranges {
        ends.push(end.to_uop(DType::Index));
    }

    let begins_uop = UOp::vectorize(begins);
    let ends_uop = UOp::vectorize(ends);
    (begins_uop, ends_uop)
}
/// Infers the shape of `uop` from its operation and the shapes of its sources.
///
/// Returns `Ok(Some(shape))` when a shape can be derived, `Ok(None)` when the
/// op has no shape according to the rules below, and `Err(..)` when the
/// operands are inconsistent.
///
/// # Errors
///
/// * `BufferDefRequiresPtrDType` when a `DefineLocal` does not have a pointer dtype.
/// * `BinaryShapeMismatch` / `TernaryBranchShapeMismatch` when both operand
///   shapes are known and differ.
/// * `VoidTypeInOp` when `Permute`, `Pad`, `Shrink` or `ReduceAxis` has a source
///   without a shape, or when `Pad`/`Shrink` ranges cannot be extracted.
/// * Any error returned by a source's `shape()` call is propagated.
///
/// # Panics
///
/// The `Permute` arm indexes the source shape with each axis (`src_shape[i]`),
/// which panics if an axis is out of bounds for the source shape.
pub fn infer_shape_from_op(uop: &UOp) -> crate::Result<Option<Shape>> {
Ok(match uop.op() {
// A single constant has an empty shape (no dimensions).
Op::Const(_) => Some(SmallVec::new()),
Op::VConst { .. } => None,
Op::Unique(_) | Op::Device(_) | Op::Noop | Op::Invalid => None,
// DefineLocal: one dimension taken from the pointer dtype's size.
Op::DefineLocal(_id) => {
use morok_dtype::DType;
match uop.dtype() {
DType::Ptr { size: Some(s), .. } => Some(smallvec![SInt::from(s)]),
DType::Ptr { size: None, .. } => {
// No size recorded: the single dimension is built from the index constant -1.
let neg_one = UOp::index_const(-1);
Some(smallvec![SInt::from(neg_one)])
}
dtype => {
return crate::error::BufferDefRequiresPtrDTypeSnafu { op: "DefineLocal", dtype: dtype.clone() }
.fail();
}
}
}
// Unary ops keep the shape of their input.
Op::Unary(_, input) => input.shape()?.cloned(),
Op::Binary(op, lhs, rhs) => {
match (lhs.shape()?, rhs.shape()?) {
// Both shapes known but different: error.
(Some(lhs_shape), Some(rhs_shape)) if !shapes_equal(lhs_shape, rhs_shape) => {
return BinaryShapeMismatchSnafu {
op: *op,
lhs: Box::new(lhs_shape.clone()),
rhs: Box::new(rhs_shape.clone()),
}
.fail();
}
// Otherwise use whichever shape is known (lhs first).
(Some(s), _) | (_, Some(s)) => Some(s.clone()),
(None, None) => None, }
}
// Only the two value branches are compared; the condition's shape is not inspected.
Op::Ternary(_, _condition, true_val, false_val) => {
let true_shape = true_val.shape()?;
let false_shape = false_val.shape()?;
match (true_shape, false_shape) {
(Some(ts), Some(fs)) if !shapes_equal(ts, fs) => {
return crate::error::TernaryBranchShapeMismatchSnafu {
true_branch: Box::new(ts.clone()),
false_branch: Box::new(fs.clone()),
}
.fail();
}
(Some(s), _) | (_, Some(s)) => Some(s.clone()),
(None, None) => None,
}
}
Op::Cast { src, .. } => src.shape()?.cloned(),
Op::BitCast { src, dtype } => {
let src_shape = src.shape()?;
match src_shape {
Some(shape) if !shape.is_empty() => {
let src_bytes = src.dtype().bytes();
let dst_bytes = dtype.bytes();
if src_bytes == dst_bytes {
Some(shape.clone())
} else {
// Byte sizes differ: rescale the last dimension as last * src_bytes / dst_bytes.
let mut new_shape = shape.clone();
let last = new_shape.last().unwrap().clone();
let new_last = (last * SInt::Const(src_bytes)) / SInt::Const(dst_bytes);
*new_shape.last_mut().unwrap() = new_last;
Some(new_shape)
}
}
// Missing or empty source shape is passed through unchanged.
other => other.cloned(),
}
}
Op::Vectorize { .. } => None,
Op::Gep { .. } => Some(SmallVec::new()),
// Reshape: the shape is read from the `new_shape` operand; `None` if it cannot be extracted.
Op::Reshape { new_shape, .. } => {
extract_shape_from_uop(new_shape)
}
Op::Permute { axes, src } => {
let src_shape = src.shape()?.ok_or_else(|| crate::Error::VoidTypeInOp)?;
// Output dimension k is the source dimension at index axes[k].
Some(axes.iter().map(|&i| src_shape[i].clone()).collect())
}
// Expand: same extraction as Reshape.
Op::Expand { new_shape, .. } => {
extract_shape_from_uop(new_shape)
}
Op::Pad { src, begin_pads, end_pads } => {
let src_shape = src.shape()?.ok_or_else(|| crate::Error::VoidTypeInOp)?;
let ranges = extract_ranges_from_uops(begin_pads, end_pads).ok_or_else(|| crate::Error::VoidTypeInOp)?;
// Different number of pad pairs than source dimensions: no shape.
if src_shape.len() != ranges.len() {
return Ok(None);
}
// Each padded dimension is dim + begin + end.
Some(
src_shape
.iter()
.zip(ranges.iter())
.map(|(dim, (begin, end))| Ok(dim + begin + end))
.collect::<crate::Result<Shape>>()?,
)
}
Op::Shrink { src, begins, ends } => {
let src_shape = src.shape()?.ok_or_else(|| crate::Error::VoidTypeInOp)?;
let ranges = extract_ranges_from_uops(begins, ends).ok_or_else(|| crate::Error::VoidTypeInOp)?;
// Different number of ranges than source dimensions: no shape.
if src_shape.len() != ranges.len() {
return Ok(None);
}
Some(
ranges
.iter()
.zip(src_shape.iter())
.map(|((begin, end), dim)| {
// Range covers the whole dimension: reuse the source dimension as-is.
if begin.as_const() == Some(0) && end == dim {
return Ok(dim.clone());
}
// Otherwise the dimension is end - begin.
Ok(end - begin)
})
.collect::<crate::Result<Shape>>()?,
)
}
Op::Flip { src, .. } => {
src.shape()?.cloned()
}
Op::Multi { src, .. } => {
src.shape()?.cloned()
}
Op::ReduceAxis { axes, src, .. } => {
let src_shape = src.shape()?.ok_or_else(|| crate::Error::VoidTypeInOp)?;
// Dimensions listed in `axes` become 1; the others are kept.
Some(
src_shape
.iter()
.enumerate()
.map(|(i, dim)| if axes.contains(&i) { SInt::from(1) } else { dim.clone() })
.collect(),
)
}
Op::Reduce { .. } => {
None
}
Op::AllReduce { src, .. } => {
src.shape()?.cloned()
}
// Buffer-like ops have a single dimension equal to their `size`.
Op::Buffer { size, .. } => Some(smallvec![SInt::from(*size)]),
Op::Param { size, .. } => Some(smallvec![SInt::from(*size)]),
Op::BufferView { size, .. } => Some(smallvec![SInt::from(*size)]),
Op::Copy { src, .. } => src.shape()?.cloned(),
// MStack takes the shape of its first buffer, if there is one.
Op::MStack { buffers } => match buffers.first() {
Some(b) => b.shape()?.cloned(),
None => None,
},
// Bufferize: one dimension per entry of `ranges`.
Op::Bufferize { ranges, .. } => {
let mut dims: Shape = SmallVec::new();
for range in ranges.iter() {
match range.op() {
Op::Range { end, .. } => {
// A non-negative Int or a UInt constant `end` gives a constant dimension.
if let Op::Const(val) = end.op() {
match val.0 {
ConstValue::Int(v) if v >= 0 => {
dims.push(SInt::Const(v as usize));
continue;
}
ConstValue::UInt(v) => {
dims.push(SInt::Const(v as usize));
continue;
}
_ => {}
}
}
// Any other `end` is kept as a symbolic dimension.
dims.push(SInt::Symbolic(end.clone()));
}
// A range entry that is itself a constant contributes value + 1.
Op::Const(val) => {
match val.0 {
ConstValue::Int(v) if v >= 0 => {
dims.push(SInt::Const((v + 1) as usize)); }
ConstValue::UInt(v) => {
dims.push(SInt::Const((v + 1) as usize)); }
// Negative Int or a constant of another kind: no shape.
_ => return Ok(None), }
}
// Any other op is used directly as a symbolic dimension.
_ => {
dims.push(SInt::Symbolic(range.clone()));
}
}
}
Some(dims)
}
Op::Index { .. } | Op::Load { .. } | Op::Store { .. } => None,
Op::If { .. } | Op::EndIf { .. } | Op::Range { .. } | Op::Barrier { .. } => None,
Op::End { computation, .. } => computation.shape()?.cloned(),
Op::MSelect { buffer, .. } => buffer.shape()?.cloned(),
Op::Special { .. } => None,
Op::DefineVar { .. } => Some(SmallVec::new()),
Op::Bind { value, .. } => value.shape()?.cloned(),
Op::DefineReg { size, .. } => Some(smallvec![SInt::from(*size)]),
Op::Wmma { .. } | Op::Contract { .. } | Op::Unroll { .. } => {
None
}
Op::Kernel { .. } => None,
Op::Assign { target, .. } => target.shape()?.cloned(),
// Pass-through ops: same shape as their source.
Op::Detach { src } | Op::Contiguous { src, .. } | Op::ContiguousBackward { src } | Op::Precast { src } => {
src.shape()?.cloned()
}
Op::After { passthrough, .. } => passthrough.shape()?.cloned(),
Op::Custom { .. } | Op::CustomI { .. } => None,
Op::Sink { .. } => None,
// Group takes the shape of its first source, if there is one.
Op::Group { sources } => match sources.first() {
Some(src) => src.shape()?.cloned(),
None => None,
},
Op::PointerIndex { .. } => Some(smallvec![]),
Op::Cat { .. } | Op::PtrCat { .. } => None,
})
}