use std::sync::Arc;
use crate::types::ConstValue;
use crate::{Op, UOp};
/// Proposes simpler candidate expressions to try in place of `uop` while
/// shrinking a failing test case.
///
/// Candidates, in order:
/// * `Binary(_, lhs, rhs)` — `lhs`, `rhs`, then the integer constants `0` and `1`.
/// * `Unary(_, src)` — `src`, then the integer constant `0`.
/// * `Const` — whatever `shrink_const_value` yields for the stored value.
/// * `Ternary(_, a, b, c)` — `b`, then `c`, then `a`.
/// * any other op — no candidates.
///
/// Constants created here are given `uop`'s dtype.
pub fn shrink_uop(uop: &Arc<UOp>) -> Vec<Arc<UOp>> {
    // NOTE(review): these are always `ConstValue::Int` even when `uop`'s dtype
    // is not an integer type (e.g. a float binary op or a comparison), and
    // operand candidates may carry a different dtype than `uop` itself —
    // confirm downstream consumers tolerate such candidates.
    let int_const = |value| UOp::const_(uop.dtype().clone(), ConstValue::Int(value));

    match uop.op() {
        Op::Binary(_, lhs, rhs) => vec![
            Arc::clone(lhs),
            Arc::clone(rhs),
            int_const(0),
            int_const(1),
        ],
        Op::Unary(_, src) => vec![Arc::clone(src), int_const(0)],
        Op::Const(cv) => shrink_const_value(&cv.0, &uop.dtype()),
        Op::Ternary(_, a, b, c) => vec![Arc::clone(b), Arc::clone(c), Arc::clone(a)],
        _ => Vec::new(),
    }
}
/// Returns smaller constants (all built with `dtype`) to try in place of `cv`.
///
/// For non-zero numeric values the candidates are, in order:
/// 1. the unit value with the same sign (`1`, `-1`, `±1.0`) — only when |v| > 1,
/// 2. half the value — only when |v| > 10,
/// 3. zero.
///
/// NaN and infinite floats skip steps 1 and 2 and shrink only to `0.0`.
/// `Bool(true)` shrinks to `Bool(false)`. Zero values, `Bool(false)`, and any
/// other variant produce no candidates.
fn shrink_const_value(cv: &ConstValue, dtype: &morok_dtype::DType) -> Vec<Arc<UOp>> {
    // Materialise the optional unit / half candidates followed by the
    // mandatory final candidate, preserving that order.
    let build = |unit: Option<ConstValue>, half: Option<ConstValue>, last: ConstValue| {
        unit.into_iter()
            .chain(half)
            .chain(Some(last))
            .map(|value| UOp::const_(dtype.clone(), value))
            .collect::<Vec<_>>()
    };

    match cv {
        ConstValue::Int(v) if *v != 0 => {
            let v = *v;
            let (unit, beyond_unit, beyond_ten) = if v > 0 {
                (1, v > 1, v > 10)
            } else {
                (-1, v < -1, v < -10)
            };
            build(
                beyond_unit.then(|| ConstValue::Int(unit)),
                beyond_ten.then(|| ConstValue::Int(v / 2)),
                ConstValue::Int(0),
            )
        }
        ConstValue::UInt(v) if *v != 0 => {
            let v = *v;
            build(
                (v > 1).then(|| ConstValue::UInt(1)),
                (v > 10).then(|| ConstValue::UInt(v / 2)),
                ConstValue::UInt(0),
            )
        }
        ConstValue::Float(v) if *v != 0.0 => {
            let v = *v;
            let finite = v.is_finite();
            let magnitude = v.abs();
            build(
                (finite && magnitude > 1.0).then(|| ConstValue::Float(v.signum())),
                (finite && magnitude > 10.0).then(|| ConstValue::Float(v / 2.0)),
                ConstValue::Float(0.0),
            )
        }
        ConstValue::Bool(true) => build(None, None, ConstValue::Bool(false)),
        _ => Vec::new(),
    }
}
/// Returns the length of the longest chain of `Unary` / `Binary` / `Ternary`
/// operations from `uop` down to a leaf.
///
/// Every other op kind (constants, variables, ...) is treated as a leaf of
/// depth 0, and its sources, if any, are not visited.
///
/// Sub-expressions shared through `Arc::clone` are evaluated only once: each
/// node's depth is cached by pointer identity. The cost is therefore linear
/// in the number of distinct nodes rather than in the number of root-to-leaf
/// paths, which is exponential for a DAG such as `a = a + a` repeated k times.
pub fn uop_depth(uop: &Arc<UOp>) -> usize {
    fn depth(uop: &Arc<UOp>, memo: &mut std::collections::HashMap<*const UOp, usize>) -> usize {
        // The pointer is only an identity key; every node stays alive for the
        // whole traversal because the root `Arc` owns it transitively.
        let key = Arc::as_ptr(uop);
        if let Some(&cached) = memo.get(&key) {
            return cached;
        }
        let d = match uop.op() {
            Op::Binary(_, lhs, rhs) => 1 + depth(lhs, memo).max(depth(rhs, memo)),
            Op::Unary(_, src) => 1 + depth(src, memo),
            Op::Ternary(_, a, b, c) => {
                1 + depth(a, memo).max(depth(b, memo)).max(depth(c, memo))
            }
            _ => 0,
        };
        memo.insert(key, d);
        d
    }

    depth(uop, &mut std::collections::HashMap::new())
}
/// Counts the `Unary` / `Binary` / `Ternary` operations in the expression
/// rooted at `uop`.
///
/// A sub-expression reachable along several paths is counted once per path,
/// i.e. the result is the operation count of the expression unfolded into a
/// tree. Other op kinds contribute 0, and their sources, if any, are not
/// visited.
///
/// Each distinct node's count is computed once and cached by pointer
/// identity, so the cost is linear in the number of distinct nodes even when
/// the unfolded tree is exponentially large. In that case the count saturates
/// at `usize::MAX` instead of overflowing.
pub fn uop_op_count(uop: &Arc<UOp>) -> usize {
    fn count(uop: &Arc<UOp>, memo: &mut std::collections::HashMap<*const UOp, usize>) -> usize {
        // The pointer is only an identity key; every node stays alive for the
        // whole traversal because the root `Arc` owns it transitively.
        let key = Arc::as_ptr(uop);
        if let Some(&cached) = memo.get(&key) {
            return cached;
        }
        let total = match uop.op() {
            Op::Binary(_, lhs, rhs) => 1usize
                .saturating_add(count(lhs, memo))
                .saturating_add(count(rhs, memo)),
            Op::Unary(_, src) => 1usize.saturating_add(count(src, memo)),
            Op::Ternary(_, a, b, c) => 1usize
                .saturating_add(count(a, memo))
                .saturating_add(count(b, memo))
                .saturating_add(count(c, memo)),
            _ => 0,
        };
        memo.insert(key, total);
        total
    }

    count(uop, &mut std::collections::HashMap::new())
}
#[cfg(test)]
mod tests {
    use super::*;
    use morok_dtype::DType;

    /// True when `candidates` contains a `Const` node whose value equals `want`.
    fn has_const(candidates: &[Arc<UOp>], want: ConstValue) -> bool {
        candidates
            .iter()
            .any(|u| matches!(u.op(), Op::Const(cv) if cv.0 == want))
    }

    /// Builds `lhs + rhs` with dtype `Int32`.
    fn add(lhs: &Arc<UOp>, rhs: &Arc<UOp>) -> Arc<UOp> {
        UOp::new(
            Op::Binary(crate::types::BinaryOp::Add, Arc::clone(lhs), Arc::clone(rhs)),
            DType::Int32,
        )
    }

    #[test]
    fn test_shrink_const() {
        let big_const = UOp::const_(DType::Int32, ConstValue::Int(100));
        let shrunk = shrink_uop(&big_const);
        assert!(!shrunk.is_empty());
        assert_eq!(shrunk.len(), 3);
        assert!(has_const(&shrunk, ConstValue::Int(1)));
        assert!(has_const(&shrunk, ConstValue::Int(50)));
        assert!(has_const(&shrunk, ConstValue::Int(0)));
    }

    #[test]
    fn test_shrink_negative_const() {
        let neg_const = UOp::const_(DType::Int32, ConstValue::Int(-100));
        let shrunk = shrink_uop(&neg_const);
        assert_eq!(shrunk.len(), 3);
        assert!(has_const(&shrunk, ConstValue::Int(-1)));
        assert!(has_const(&shrunk, ConstValue::Int(-50)));
        assert!(has_const(&shrunk, ConstValue::Int(0)));
    }

    #[test]
    fn test_shrink_zero_const_is_minimal() {
        let zero = UOp::const_(DType::Int32, ConstValue::Int(0));
        assert!(shrink_uop(&zero).is_empty());
    }

    #[test]
    fn test_shrink_binary() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let one = UOp::const_(DType::Int32, ConstValue::Int(1));
        let x_plus_1 = add(&x, &one);
        let shrunk = shrink_uop(&x_plus_1);
        assert_eq!(shrunk.len(), 4);
        assert!(shrunk.iter().any(|u| Arc::ptr_eq(u, &x)));
        assert!(has_const(&shrunk, ConstValue::Int(0)));
        assert!(has_const(&shrunk, ConstValue::Int(1)));
    }

    #[test]
    fn test_uop_depth() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        assert_eq!(uop_depth(&x), 0);
        let one = UOp::const_(DType::Int32, ConstValue::Int(1));
        let x_plus_1 = add(&x, &one);
        assert_eq!(uop_depth(&x_plus_1), 1);

        // Shared sub-expressions: (x + x) + (x + x).
        let doubled = add(&x, &x);
        let quadrupled = add(&doubled, &doubled);
        assert_eq!(uop_depth(&quadrupled), 2);
    }

    #[test]
    fn test_uop_op_count() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        assert_eq!(uop_op_count(&x), 0);
        let one = UOp::const_(DType::Int32, ConstValue::Int(1));
        let x_plus_1 = add(&x, &one);
        assert_eq!(uop_op_count(&x_plus_1), 1);

        // Shared sub-expressions are counted once per path: 1 + 1 + 1.
        let doubled = add(&x, &x);
        let quadrupled = add(&doubled, &doubled);
        assert_eq!(uop_op_count(&quadrupled), 3);
    }
}