use std::fmt;
use std::sync::Arc;
use crate::UOp;
/// An integer (used as a dimension, per the conversion error messages in this file)
/// that is either concrete, symbolic, or still waiting to be inferred.
///
/// Equality and hashing are implemented by hand in this file: symbolic values are
/// compared and hashed through the `id` field of their `UOp`.
#[derive(Debug, Clone)]
pub enum SInt {
/// A concrete, known value.
Const(usize),
/// A value represented by a shared `UOp` node.
Symbolic(Arc<UOp>),
/// Placeholder produced from `-1` by the signed-integer `From` impls and displayed
/// as `-1`; the arithmetic helpers in this file panic when they receive it.
Infer,
}
impl PartialEq for SInt {
    /// Two values are equal when they are the same variant and, for `Const`, hold the
    /// same number or, for `Symbolic`, wrap `UOp`s with the same `id`.
    fn eq(&self, other: &Self) -> bool {
        match self {
            SInt::Const(lhs) => matches!(other, SInt::Const(rhs) if lhs == rhs),
            SInt::Symbolic(lhs) => matches!(other, SInt::Symbolic(rhs) if lhs.id == rhs.id),
            SInt::Infer => matches!(other, SInt::Infer),
        }
    }
}
// Marker impl so `SInt` can be used where a total equality is required.
// NOTE(review): this assumes equality on `UOp::id` is reflexive — confirm against the `id` type.
impl Eq for SInt {}
impl std::hash::Hash for SInt {
    /// Feeds the variant discriminant followed by the variant payload into `state`.
    ///
    /// Symbolic values contribute only their `UOp` id, mirroring the `PartialEq` impl.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // The discriminant goes first so different variants start from different state.
        std::hash::Hash::hash(&std::mem::discriminant(self), state);
        match self {
            SInt::Const(v) => std::hash::Hash::hash(v, state),
            SInt::Symbolic(uop) => std::hash::Hash::hash(&uop.id, state),
            SInt::Infer => {}
        }
    }
}
impl fmt::Display for SInt {
    /// Writes the number for `Const`, the fixed text `<symbolic>` for `Symbolic`,
    /// and `-1` for `Infer`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SInt::Const(v) => f.write_fmt(format_args!("{v}")),
            SInt::Symbolic(_) => f.write_str("<symbolic>"),
            SInt::Infer => f.write_str("-1"),
        }
    }
}
impl SInt {
    /// Returns `true` for [`SInt::Const`].
    pub fn is_const(&self) -> bool {
        matches!(self, SInt::Const(_))
    }

    /// Returns `true` for [`SInt::Infer`].
    pub fn is_infer(&self) -> bool {
        matches!(self, SInt::Infer)
    }

    /// Returns `true` for [`SInt::Symbolic`].
    pub fn is_symbolic(&self) -> bool {
        matches!(self, SInt::Symbolic(_))
    }

    /// Returns the concrete value, or `None` for symbolic / to-be-inferred values.
    pub fn as_const(&self) -> Option<usize> {
        match self {
            SInt::Const(v) => Some(*v),
            SInt::Symbolic(_) | SInt::Infer => None,
        }
    }

    /// Returns the wrapped `UOp`, or `None` for concrete / to-be-inferred values.
    pub fn as_symbolic(&self) -> Option<&Arc<UOp>> {
        match self {
            SInt::Symbolic(uop) => Some(uop),
            SInt::Const(_) | SInt::Infer => None,
        }
    }

    /// Converts this value into a `UOp` of the requested `dtype`.
    ///
    /// Constants become integer constant `UOp`s; symbolic values are returned as-is
    /// when their dtype already matches and cast to `dtype` otherwise.
    ///
    /// # Panics
    ///
    /// Panics on [`SInt::Infer`], and on a constant that does not fit in `i64`.
    pub fn to_uop(&self, dtype: morok_dtype::DType) -> Arc<UOp> {
        match self {
            SInt::Const(v) => {
                // A plain `as i64` cast would silently wrap values above `i64::MAX`
                // into negative constants, so convert with an explicit check.
                let value = i64::try_from(*v).expect("SInt::Const value does not fit in i64");
                UOp::const_(dtype, crate::ConstValue::Int(value))
            }
            SInt::Symbolic(uop) => {
                if uop.dtype() != dtype {
                    uop.cast(dtype)
                } else {
                    uop.clone()
                }
            }
            SInt::Infer => panic!("cannot convert SInt::Infer to UOp — resolve -1 first"),
        }
    }

    /// Collapses a symbolic value whose `UOp` is an integer constant into [`SInt::Const`].
    ///
    /// Negative constants, non-integer constants, constants that do not fit in `usize`,
    /// and every non-constant `UOp` are returned unchanged (as a clone of `self`).
    pub fn simplify(&self) -> Self {
        let SInt::Symbolic(uop) = self else {
            return self.clone();
        };
        let crate::Op::Const(const_hash) = uop.op() else {
            return self.clone();
        };
        // `try_from` rejects negative values and, on targets where `usize` is narrower
        // than 64 bits, values that an `as` cast would silently truncate.
        let concrete = match const_hash.0 {
            crate::ConstValue::Int(v) => usize::try_from(v).ok(),
            crate::ConstValue::UInt(v) => usize::try_from(v).ok(),
            _ => None,
        };
        match concrete {
            Some(value) => SInt::Const(value),
            None => self.clone(),
        }
    }

    /// Ceiling division, computed as `(self + rhs - 1) / rhs` with the `SInt` operators.
    ///
    /// # Panics
    ///
    /// Panics whenever one of the underlying `SInt` operators panics (for example when
    /// either operand is [`SInt::Infer`]).
    pub fn ceildiv(&self, rhs: &SInt) -> SInt {
        (self + rhs - 1usize) / rhs
    }

    /// Maximum of two values; concrete when both operands are concrete, symbolic otherwise.
    ///
    /// # Panics
    ///
    /// Panics if either operand is [`SInt::Infer`] or if building the symbolic maximum fails.
    pub fn smax(&self, rhs: &SInt) -> SInt {
        match (self, rhs) {
            (SInt::Infer, _) | (_, SInt::Infer) => {
                panic!("smax on SInt::Infer — resolve -1 before computing")
            }
            (SInt::Const(a), SInt::Const(b)) => SInt::Const(*a.max(b)),
            _ => {
                let a = self.to_uop(morok_dtype::DType::Index);
                let b = rhs.to_uop(morok_dtype::DType::Index);
                SInt::Symbolic(a.try_max(&b).unwrap())
            }
        }
    }

    /// Minimum of two values; concrete when both operands are concrete, symbolic otherwise.
    ///
    /// The symbolic case is expressed through the maximum: `min(a, b) = -max(-a, -b)`.
    ///
    /// # Panics
    ///
    /// Panics if either operand is [`SInt::Infer`] or if building the symbolic maximum fails.
    pub fn smin(&self, rhs: &SInt) -> SInt {
        match (self, rhs) {
            (SInt::Infer, _) | (_, SInt::Infer) => {
                panic!("smin on SInt::Infer — resolve -1 before computing")
            }
            (SInt::Const(a), SInt::Const(b)) => SInt::Const(*a.min(b)),
            _ => {
                let a = self.to_uop(morok_dtype::DType::Index);
                let b = rhs.to_uop(morok_dtype::DType::Index);
                let neg_max = a.neg().try_max(&b.neg()).unwrap();
                SInt::Symbolic(neg_max.neg())
            }
        }
    }
}
/// Implements a binary operator trait for every `SInt` / `&SInt` / `usize` operand pairing.
///
/// Arguments, in order: the operator trait, its method name, the `usize::checked_*`
/// method used when both operands are `SInt::Const`, the fallible `UOp` method used
/// when at least one operand is symbolic, and a short description of what a `None`
/// from the checked method means (it is embedded in the panic message).
///
/// All pairings forward to the `&SInt op &SInt` impl, which panics when either operand
/// is `SInt::Infer`, when the checked constant operation fails, or when the `UOp`
/// method returns an error.
macro_rules! impl_sint_binop {
    ($trait:ident, $method:ident, $checked:ident, $uop_method:ident, $failure:literal) => {
        impl std::ops::$trait for &SInt {
            type Output = SInt;
            fn $method(self, rhs: &SInt) -> SInt {
                match (self, rhs) {
                    (SInt::Infer, _) | (_, SInt::Infer) => {
                        panic!("arithmetic on SInt::Infer — resolve -1 before computing")
                    }
                    // Checked arithmetic: the raw operator would wrap silently on
                    // overflow in release builds and yield a bogus value.
                    (SInt::Const(a), SInt::Const(b)) => match a.$checked(*b) {
                        Some(value) => SInt::Const(value),
                        None => panic!("SInt {}: lhs = {}, rhs = {}", $failure, a, b),
                    },
                    _ => {
                        let a = self.to_uop(morok_dtype::DType::Index);
                        let b = rhs.to_uop(morok_dtype::DType::Index);
                        SInt::Symbolic(a.$uop_method(&b).unwrap())
                    }
                }
            }
        }
        impl std::ops::$trait for SInt {
            type Output = SInt;
            fn $method(self, rhs: SInt) -> SInt { (&self).$method(&rhs) }
        }
        impl std::ops::$trait<&SInt> for SInt {
            type Output = SInt;
            fn $method(self, rhs: &SInt) -> SInt { (&self).$method(rhs) }
        }
        impl std::ops::$trait<SInt> for &SInt {
            type Output = SInt;
            fn $method(self, rhs: SInt) -> SInt { self.$method(&rhs) }
        }
        impl std::ops::$trait<usize> for &SInt {
            type Output = SInt;
            fn $method(self, rhs: usize) -> SInt { self.$method(&SInt::Const(rhs)) }
        }
        impl std::ops::$trait<usize> for SInt {
            type Output = SInt;
            fn $method(self, rhs: usize) -> SInt { (&self).$method(&SInt::Const(rhs)) }
        }
        impl std::ops::$trait<SInt> for usize {
            type Output = SInt;
            fn $method(self, rhs: SInt) -> SInt { (&SInt::Const(self)).$method(&rhs) }
        }
        impl std::ops::$trait<&SInt> for usize {
            type Output = SInt;
            fn $method(self, rhs: &SInt) -> SInt { (&SInt::Const(self)).$method(rhs) }
        }
    };
}

impl_sint_binop!(Add, add, checked_add, try_add, "addition overflow");
impl_sint_binop!(Mul, mul, checked_mul, try_mul, "multiplication overflow");
impl_sint_binop!(Div, div, checked_div, try_div, "division by zero");
impl std::ops::Sub for &SInt {
    type Output = SInt;

    /// Subtracts `rhs` from `self`.
    ///
    /// Two constants are subtracted directly; any other combination is lowered to
    /// `UOp`s of dtype `Index` and subtracted symbolically.
    ///
    /// # Panics
    ///
    /// Panics if either operand is `SInt::Infer`, if a constant subtraction would go
    /// below zero, or if the symbolic subtraction fails.
    fn sub(self, rhs: &SInt) -> SInt {
        if matches!(self, SInt::Infer) || matches!(rhs, SInt::Infer) {
            panic!("arithmetic on SInt::Infer — resolve -1 before computing");
        }
        if let (SInt::Const(a), SInt::Const(b)) = (self, rhs) {
            assert!(a >= b, "SInt subtraction underflow: {a} - {b} would be negative");
            return SInt::Const(a - b);
        }
        let minuend = self.to_uop(morok_dtype::DType::Index);
        let subtrahend = rhs.to_uop(morok_dtype::DType::Index);
        SInt::Symbolic(minuend.try_sub(&subtrahend).unwrap())
    }
}
impl std::ops::Sub for SInt {
    type Output = SInt;

    /// Owned − owned; forwards to the by-reference subtraction.
    fn sub(self, rhs: SInt) -> SInt {
        &self - &rhs
    }
}
impl std::ops::Sub<&SInt> for SInt {
    type Output = SInt;

    /// Owned − borrowed; forwards to the by-reference subtraction.
    fn sub(self, rhs: &SInt) -> SInt {
        &self - rhs
    }
}
impl std::ops::Sub<SInt> for &SInt {
    type Output = SInt;

    /// Borrowed − owned; forwards to the by-reference subtraction.
    fn sub(self, rhs: SInt) -> SInt {
        self - &rhs
    }
}
impl std::ops::Sub<usize> for &SInt {
    type Output = SInt;

    /// Borrowed − plain integer; the integer is wrapped in `SInt::Const` first.
    fn sub(self, rhs: usize) -> SInt {
        let subtrahend = SInt::Const(rhs);
        self - &subtrahend
    }
}
impl std::ops::Sub<usize> for SInt {
    type Output = SInt;

    /// Owned − plain integer; the integer is wrapped in `SInt::Const` first.
    fn sub(self, rhs: usize) -> SInt {
        let subtrahend = SInt::Const(rhs);
        &self - &subtrahend
    }
}
impl std::ops::Sub<SInt> for usize {
    type Output = SInt;

    /// Plain integer − owned; the integer is wrapped in `SInt::Const` first.
    fn sub(self, rhs: SInt) -> SInt {
        let minuend = SInt::Const(self);
        &minuend - &rhs
    }
}
impl std::ops::Sub<&SInt> for usize {
    type Output = SInt;

    /// Plain integer − borrowed; the integer is wrapped in `SInt::Const` first.
    fn sub(self, rhs: &SInt) -> SInt {
        let minuend = SInt::Const(self);
        &minuend - rhs
    }
}
impl From<usize> for SInt {
fn from(value: usize) -> Self {
SInt::Const(value)
}
}
impl From<isize> for SInt {
    /// Maps `-1` to `SInt::Infer` and non-negative values to `SInt::Const`.
    ///
    /// # Panics
    ///
    /// Panics on any negative value other than `-1`.
    fn from(value: isize) -> Self {
        match value {
            -1 => SInt::Infer,
            _ => {
                assert!(value >= 0, "negative dimension {value} is invalid (only -1 for inference is allowed)");
                SInt::Const(value as usize)
            }
        }
    }
}
impl From<&isize> for SInt {
fn from(value: &isize) -> Self {
SInt::from(*value)
}
}
impl From<i32> for SInt {
    /// Maps `-1` to `SInt::Infer` and non-negative values to `SInt::Const`.
    ///
    /// # Panics
    ///
    /// Panics on any negative value other than `-1`.
    fn from(value: i32) -> Self {
        match value {
            -1 => SInt::Infer,
            _ => {
                assert!(value >= 0, "negative dimension {value} is invalid (only -1 for inference is allowed)");
                SInt::Const(value as usize)
            }
        }
    }
}
impl From<&i32> for SInt {
fn from(value: &i32) -> Self {
SInt::from(*value)
}
}
impl From<i64> for SInt {
    /// Maps `-1` to `SInt::Infer` and non-negative values to `SInt::Const`.
    ///
    /// # Panics
    ///
    /// Panics on any negative value other than `-1`.
    fn from(value: i64) -> Self {
        match value {
            -1 => SInt::Infer,
            _ => {
                assert!(value >= 0, "negative dimension {value} is invalid (only -1 for inference is allowed)");
                SInt::Const(value as usize)
            }
        }
    }
}
impl From<&i64> for SInt {
fn from(value: &i64) -> Self {
SInt::from(*value)
}
}
impl From<&usize> for SInt {
fn from(value: &usize) -> Self {
SInt::Const(*value)
}
}
impl From<&SInt> for SInt {
fn from(value: &SInt) -> Self {
value.clone()
}
}
impl From<Arc<UOp>> for SInt {
    /// Wraps the `UOp` as `SInt::Symbolic` and immediately runs `simplify` on it.
    fn from(value: Arc<UOp>) -> Self {
        SInt::Symbolic(value).simplify()
    }
}
impl From<&Arc<UOp>> for SInt {
    /// Clones the `Arc` handle and converts it like the owned `Arc<UOp>` conversion.
    fn from(value: &Arc<UOp>) -> Self {
        let shared = Arc::clone(value);
        SInt::from(shared)
    }
}
/// A range argument in one of the forms produced by the `IntoShrinkRange` impls.
// NOTE(review): the two fields of each pair are presumably (begin, end) — the impls
// bind them as `b` and `e` — but how they are interpreted is decided by the consumer,
// which is not visible here.
#[derive(Debug, Clone)]
pub enum ShrinkRange {
/// No range was supplied (produced from `Option::None`).
None,
/// A pair of signed integers (produced from `isize` and `i32` pairs).
Isize(isize, isize),
/// A pair of `SInt` values (produced from `usize` and `SInt` pairs).
Sint(SInt, SInt),
}
/// Conversion of the accepted range spellings into a [`ShrinkRange`].
///
/// This file implements it for pairs of `isize`, `i32`, `usize` and `SInt` (by value
/// and by reference) and for `Option`s of `isize` and `SInt` pairs.
pub trait IntoShrinkRange {
/// Consumes `self` and returns the equivalent [`ShrinkRange`].
fn into_shrink_range(self) -> ShrinkRange;
}
impl IntoShrinkRange for Option<(isize, isize)> {
    /// `Some((b, e))` becomes `ShrinkRange::Isize(b, e)`; `None` becomes `ShrinkRange::None`.
    fn into_shrink_range(self) -> ShrinkRange {
        self.map_or(ShrinkRange::None, |(b, e)| ShrinkRange::Isize(b, e))
    }
}
impl IntoShrinkRange for &Option<(isize, isize)> {
fn into_shrink_range(self) -> ShrinkRange {
(*self).into_shrink_range()
}
}
impl IntoShrinkRange for (isize, isize) {
    /// Wraps the pair as `ShrinkRange::Isize`.
    fn into_shrink_range(self) -> ShrinkRange {
        let (first, second) = self;
        ShrinkRange::Isize(first, second)
    }
}
impl IntoShrinkRange for &(isize, isize) {
    /// Copies the referenced pair into `ShrinkRange::Isize`.
    fn into_shrink_range(self) -> ShrinkRange {
        let &(first, second) = self;
        ShrinkRange::Isize(first, second)
    }
}
impl IntoShrinkRange for (i32, i32) {
    /// Converts both components to `isize` and wraps them as `ShrinkRange::Isize`.
    fn into_shrink_range(self) -> ShrinkRange {
        let (first, second) = self;
        ShrinkRange::Isize(first as isize, second as isize)
    }
}
impl IntoShrinkRange for &(i32, i32) {
    /// Converts both referenced components to `isize` and wraps them as `ShrinkRange::Isize`.
    fn into_shrink_range(self) -> ShrinkRange {
        let &(first, second) = self;
        ShrinkRange::Isize(first as isize, second as isize)
    }
}
impl IntoShrinkRange for (usize, usize) {
    /// Wraps both components in `SInt::Const` and returns `ShrinkRange::Sint`.
    fn into_shrink_range(self) -> ShrinkRange {
        let (first, second) = self;
        ShrinkRange::Sint(SInt::Const(first), SInt::Const(second))
    }
}
impl IntoShrinkRange for &(usize, usize) {
    /// Wraps both referenced components in `SInt::Const` and returns `ShrinkRange::Sint`.
    fn into_shrink_range(self) -> ShrinkRange {
        let &(first, second) = self;
        ShrinkRange::Sint(SInt::Const(first), SInt::Const(second))
    }
}
impl IntoShrinkRange for Option<(SInt, SInt)> {
    /// `Some((b, e))` becomes `ShrinkRange::Sint(b, e)`; `None` becomes `ShrinkRange::None`.
    fn into_shrink_range(self) -> ShrinkRange {
        self.map_or(ShrinkRange::None, |(b, e)| ShrinkRange::Sint(b, e))
    }
}
impl IntoShrinkRange for &Option<(SInt, SInt)> {
fn into_shrink_range(self) -> ShrinkRange {
self.clone().into_shrink_range()
}
}
impl IntoShrinkRange for (SInt, SInt) {
    /// Moves both components into `ShrinkRange::Sint`.
    fn into_shrink_range(self) -> ShrinkRange {
        let (first, second) = self;
        ShrinkRange::Sint(first, second)
    }
}
impl IntoShrinkRange for &(SInt, SInt) {
    /// Clones both referenced components into `ShrinkRange::Sint`.
    fn into_shrink_range(self) -> ShrinkRange {
        let (first, second) = self;
        ShrinkRange::Sint(first.clone(), second.clone())
    }
}
/// Multiplies all `values` together with the `SInt` `*` operator.
///
/// Returns `SInt::Const(1)` for an empty slice. Panics whenever the underlying
/// multiplication panics.
pub fn sint_prod(values: &[SInt]) -> SInt {
    let mut product = SInt::Const(1);
    for factor in values {
        product = &product * factor;
    }
    product
}
/// Reduces `values` with `SInt::smax`, starting from the first element.
///
/// # Panics
///
/// Panics if `values` is empty, or whenever `smax` panics.
pub fn sint_max(values: &[SInt]) -> SInt {
    assert!(!values.is_empty(), "sint_max requires at least one value");
    let mut best = values[0].clone();
    for candidate in &values[1..] {
        best = best.smax(candidate);
    }
    best
}
/// Reduces `values` with `SInt::smin`, starting from the first element.
///
/// # Panics
///
/// Panics if `values` is empty, or whenever `smin` panics.
pub fn sint_min(values: &[SInt]) -> SInt {
    assert!(!values.is_empty(), "sint_min requires at least one value");
    let mut best = values[0].clone();
    for candidate in &values[1..] {
        best = best.smin(candidate);
    }
    best
}