use std::sync::Arc;
use half::bf16;
use proptest::prelude::*;
use morok_dtype::{DType, ScalarDType};
use crate::UOp;
use crate::types::{BinaryOp, ConstValue, TernaryOp, UnaryOp};
use morok_dtype::ScalarDType as Scalar;
/// Scalar dtypes that `const_` has no value generator for (they fall through to its
/// catch-all arm). `arithmetic_sdtype` filters these out so generated dtypes always have
/// a matching constant strategy.
static NON_SUPPORTED: &[Scalar] = &[Scalar::FP8E4M3, Scalar::FP8E5M2, Scalar::Index, Scalar::Void];
/// Returns a strategy producing a [`ConstValue`] suitable for the scalar type `dtype`.
///
/// * Signed integers draw from the full range of the matching Rust integer type and are
///   stored as `ConstValue::Int`.
/// * Unsigned integers are drawn the same way and stored as `ConstValue::UInt`.
/// * `Bool` is stored as `ConstValue::Bool`.
/// * `Float16` and `BFloat16` values are drawn as arbitrary `f32` and rounded through the
///   `half` crate. The stored `f64` is therefore exactly representable in the narrow type.
///   For `Float16`, `f32` inputs outside its range become ±infinity.
/// * `Float32`/`Float64` draw arbitrary values of that width and store them as
///   `ConstValue::Float`.
///
/// # Panics
///
/// Panics for any scalar type without a generator, such as the ones listed in
/// `NON_SUPPORTED`. Use [`arithmetic_sdtype`] to draw only supported types.
pub fn const_(dtype: ScalarDType) -> impl Strategy<Value = ConstValue> {
    use ScalarDType::*;
    match dtype {
        Int8 => any::<i8>().prop_map(|i| ConstValue::Int(i as i64)).boxed(),
        Int16 => any::<i16>().prop_map(|i| ConstValue::Int(i as i64)).boxed(),
        Int32 => any::<i32>().prop_map(|i| ConstValue::Int(i as i64)).boxed(),
        Int64 => any::<i64>().prop_map(ConstValue::Int).boxed(),
        UInt8 => any::<u8>().prop_map(|i| ConstValue::UInt(i as u64)).boxed(),
        UInt16 => any::<u16>().prop_map(|i| ConstValue::UInt(i as u64)).boxed(),
        UInt32 => any::<u32>().prop_map(|i| ConstValue::UInt(i as u64)).boxed(),
        UInt64 => any::<u64>().prop_map(ConstValue::UInt).boxed(),
        Bool => any::<bool>().prop_map(ConstValue::Bool).boxed(),
        Float16 => any::<f32>().prop_map(|i| ConstValue::Float(half::f16::from_f32(i).to_f64())).boxed(),
        BFloat16 => any::<f32>().prop_map(|i| ConstValue::Float(bf16::from_f32(i).to_f64())).boxed(),
        Float32 => any::<f32>().prop_map(|i| ConstValue::Float(i as f64)).boxed(),
        Float64 => any::<f64>().prop_map(ConstValue::Float).boxed(),
        // Reachable through the public API (e.g. `Index`), so report which dtype was
        // requested instead of a bare `unreachable!()`.
        _ => panic!("const_: no constant generator for scalar dtype {:?}", dtype),
    }
}
/// Generates scalar dtypes that `const_` can produce values for.
///
/// This is every dtype from `morok_dtype`'s scalar generator except those in
/// `NON_SUPPORTED`.
pub fn arithmetic_sdtype() -> impl Strategy<Value = ScalarDType> {
    use morok_dtype::test::proptests::generators::scalar_generator;

    let is_supported = |candidate: &ScalarDType| !NON_SUPPORTED.contains(candidate);
    scalar_generator().prop_filter("only supported types", is_supported)
}
/// Generates a supported scalar dtype, wrapped as `DType::Scalar`, paired with a constant
/// value drawn for that same dtype.
pub fn const_pair() -> impl Strategy<Value = (DType, ConstValue)> {
    arithmetic_sdtype().prop_flat_map(|scalar| {
        let pair_with_dtype = move |constant| (DType::Scalar(scalar), constant);
        const_(scalar).prop_map(pair_with_dtype)
    })
}
/// Generates a small signed integer constant in the inclusive range `-10..=10`.
pub fn arb_small_int() -> impl Strategy<Value = ConstValue> {
    const LIMIT: i64 = 10;
    (-LIMIT..=LIMIT).prop_map(ConstValue::Int)
}
/// Generates an arbitrary `i64` constant that is never zero.
///
/// Zero is removed with `prop_filter`. The filter's reason string is what proptest prints
/// when it reports rejected cases.
pub fn nonzero_int() -> impl Strategy<Value = ConstValue> {
    any::<i64>().prop_filter("non zero", |&x| x != 0).prop_map(ConstValue::Int)
}
/// Generates integer `DType`s.
///
/// Delegates to `morok_dtype`'s integer dtype generator and converts each result into a
/// `DType`.
pub fn arb_int_dtype() -> impl Strategy<Value = DType> {
    use morok_dtype::test::proptests::generators::int_dtype;

    int_dtype().prop_map(|generated| -> DType { generated.into() })
}
pub fn arb_float_dtype() -> impl Strategy<Value = DType> {
prop_oneof![Just(DType::Float16), Just(DType::Float32), Just(DType::Float64),]
}
/// A family of numeric dtypes that can be ordered from narrowest to widest.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DTypeFamily {
    /// Signed integers: `Int8`, `Int16`, `Int32`, `Int64`.
    SignedInt,
    /// Unsigned integers: `UInt8`, `UInt16`, `UInt32`, `UInt64`.
    UnsignedInt,
    /// Floats: `Float16`, `Float32`, `Float64` (`BFloat16` is not part of this sequence).
    Float,
}

impl DTypeFamily {
    /// Returns every dtype in this family, ordered from narrowest to widest.
    pub fn widening_sequence(&self) -> Vec<DType> {
        match self {
            Self::SignedInt => vec![DType::Int8, DType::Int16, DType::Int32, DType::Int64],
            Self::UnsignedInt => vec![DType::UInt8, DType::UInt16, DType::UInt32, DType::UInt64],
            Self::Float => vec![DType::Float16, DType::Float32, DType::Float64],
        }
    }

    /// Returns the narrowest dtype in this family (the first element of
    /// [`Self::widening_sequence`]).
    pub fn narrowest(&self) -> DType {
        // Take ownership of the first element instead of indexing and cloning it.
        self.widening_sequence()
            .into_iter()
            .next()
            .expect("every widening sequence is non-empty")
    }

    /// Returns the widest dtype in this family (the last element of
    /// [`Self::widening_sequence`]).
    pub fn widest(&self) -> DType {
        self.widening_sequence()
            .pop()
            .expect("every widening sequence is non-empty")
    }
}
pub fn arb_dtype_family() -> impl Strategy<Value = DTypeFamily> {
prop_oneof![Just(DTypeFamily::SignedInt), Just(DTypeFamily::UnsignedInt), Just(DTypeFamily::Float),]
}
pub fn arb_binary_op() -> impl Strategy<Value = BinaryOp> {
prop_oneof![
5 => Just(BinaryOp::Add),
5 => Just(BinaryOp::Mul),
4 => Just(BinaryOp::Sub),
2 => Just(BinaryOp::Idiv),
2 => Just(BinaryOp::Mod),
3 => Just(BinaryOp::Max),
1 => Just(BinaryOp::Pow),
3 => Just(BinaryOp::Lt),
3 => Just(BinaryOp::Eq),
3 => Just(BinaryOp::Ne),
2 => Just(BinaryOp::And),
2 => Just(BinaryOp::Or),
1 => Just(BinaryOp::Xor),
]
}
pub fn arb_arithmetic_binary_op() -> impl Strategy<Value = BinaryOp> {
prop_oneof![Just(BinaryOp::Add), Just(BinaryOp::Mul), Just(BinaryOp::Sub), Just(BinaryOp::Max),]
}
pub fn arb_commutative_binary_op() -> impl Strategy<Value = BinaryOp> {
prop_oneof![
Just(BinaryOp::Add),
Just(BinaryOp::Mul),
Just(BinaryOp::Eq),
Just(BinaryOp::Ne),
Just(BinaryOp::And),
Just(BinaryOp::Or),
Just(BinaryOp::Xor),
Just(BinaryOp::Max),
]
}
pub fn arb_associative_binary_op() -> impl Strategy<Value = BinaryOp> {
prop_oneof![Just(BinaryOp::Add), Just(BinaryOp::Mul), Just(BinaryOp::And), Just(BinaryOp::Or), Just(BinaryOp::Max),]
}
pub fn arb_unary_op() -> impl Strategy<Value = UnaryOp> {
prop_oneof![
Just(UnaryOp::Neg),
Just(UnaryOp::Sqrt),
Just(UnaryOp::Exp2),
Just(UnaryOp::Log2),
Just(UnaryOp::Sin),
Just(UnaryOp::Reciprocal),
Just(UnaryOp::Trunc),
]
}
pub fn arb_ternary_op() -> impl Strategy<Value = TernaryOp> {
prop_oneof![Just(TernaryOp::Where), Just(TernaryOp::MulAcc),]
}
/// Generates a constant `UOp` of `dtype` whose value comes from `const_`.
///
/// # Panics
///
/// Panics immediately if `dtype.scalar()` returns `None`. Also panics when `const_` has no
/// generator for the resulting scalar type.
pub fn arb_const_uop(dtype: DType) -> impl Strategy<Value = Arc<UOp>> {
    let scalar = dtype.scalar().unwrap();
    const_(scalar).prop_map(move |value| UOp::const_(dtype.clone(), value))
}
/// Generates a variable `UOp` of `dtype`.
///
/// The name is a single lowercase ASCII letter. The bounds are `0..=max`, where `max` is
/// drawn from `1..100`.
pub fn arb_var_uop(dtype: DType) -> impl Strategy<Value = Arc<UOp>> {
    let name_pattern = "[a-z]";
    let upper_bounds = 1i64..100;
    (name_pattern, upper_bounds)
        .prop_map(move |(var_name, upper)| UOp::var(var_name, dtype.clone(), 0, upper))
}
pub fn arb_simple_uop(dtype: DType) -> impl Strategy<Value = Arc<UOp>> {
prop_oneof![arb_const_uop(dtype.clone()), arb_var_uop(dtype),]
}
/// Generates a random arithmetic expression tree over `dtype`.
///
/// Leaves come from [`arb_simple_uop`], i.e. a constant or a variable. Interior nodes are
/// either a binary `Add`/`Mul`/`Sub`/`Max` of two subtrees or the negation of one subtree.
/// `depth` is passed to `prop_recursive` as the maximum recursion depth, with `depth * 4`
/// as the desired tree size and 3 as the expected branch size.
///
/// # Panics
///
/// Panics (while generating) if any of the `try_*` constructors returns an error.
pub fn arb_arithmetic_tree(dtype: DType, depth: usize) -> impl Strategy<Value = Arc<UOp>> {
    let depth = depth as u32;
    let leaf = arb_simple_uop(dtype);
    leaf.prop_recursive(depth, depth * 4, 3, |inner| {
        prop_oneof![
            (arb_arithmetic_binary_op(), inner.clone(), inner.clone()).prop_map(|(op, lhs, rhs)| {
                match op {
                    BinaryOp::Add => lhs.try_add(&rhs).unwrap(),
                    BinaryOp::Mul => lhs.try_mul(&rhs).unwrap(),
                    BinaryOp::Sub => lhs.try_sub(&rhs).unwrap(),
                    BinaryOp::Max => lhs.try_max(&rhs).unwrap(),
                    _ => unreachable!("arb_arithmetic_binary_op only generates Add, Mul, Sub, Max"),
                }
            }),
            inner.prop_map(|src| src.neg()),
        ]
    })
}
/// Generates an arithmetic tree over `dtype`.
///
/// A recursion depth is first drawn uniformly from `0..=max_depth` and then passed to
/// `arb_arithmetic_tree`.
pub fn arb_arithmetic_tree_up_to(dtype: DType, max_depth: usize) -> impl Strategy<Value = Arc<UOp>> {
    let depths = 0..=max_depth;
    depths.prop_flat_map(move |chosen| arb_arithmetic_tree(dtype.clone(), chosen))
}
pub fn arb_bounded_const(dtype: DType) -> impl Strategy<Value = Arc<UOp>> {
use morok_dtype::ScalarDType::*;
(-100i64..=100).prop_map(move |v| {
let cv = match dtype.scalar().unwrap() {
Int8 | Int16 | Int32 | Int64 | Index => ConstValue::Int(v),
UInt8 | UInt16 | UInt32 | UInt64 => ConstValue::UInt(v.unsigned_abs()),
_ => ConstValue::Int(v),
};
UOp::const_(dtype.clone(), cv)
})
}
pub fn arb_simple_uop_bounded(dtype: DType) -> impl Strategy<Value = Arc<UOp>> {
prop_oneof![arb_bounded_const(dtype.clone()), arb_var_uop(dtype),]
}
pub fn arb_arithmetic_tree_bounded(dtype: DType, depth: usize) -> impl Strategy<Value = Arc<UOp>> {
let leaf = arb_simple_uop_bounded(dtype.clone());
leaf.prop_recursive(depth as u32, depth as u32 * 4, 3, move |inner| {
prop_oneof![
(arb_arithmetic_binary_op(), inner.clone(), inner.clone()).prop_map(move |(op, lhs, rhs)| {
match op {
BinaryOp::Add => lhs.try_add(&rhs).unwrap(),
BinaryOp::Mul => lhs.try_mul(&rhs).unwrap(),
BinaryOp::Sub => lhs.try_sub(&rhs).unwrap(),
BinaryOp::Max => lhs.try_max(&rhs).unwrap(),
_ => unreachable!("arb_arithmetic_binary_op only generates Add, Mul, Sub, Max"),
}
}),
inner.clone().prop_map(move |src| src.neg()),
]
})
}
/// Generates a bounded-constant arithmetic tree over `dtype`.
///
/// A recursion depth is first drawn uniformly from `0..=max_depth` and then passed to
/// `arb_arithmetic_tree_bounded`.
pub fn arb_arithmetic_tree_bounded_up_to(dtype: DType, max_depth: usize) -> impl Strategy<Value = Arc<UOp>> {
    let depths = 0..=max_depth;
    depths.prop_flat_map(move |chosen| arb_arithmetic_tree_bounded(dtype.clone(), chosen))
}
/// A small UOp graph paired with the result a simplifier is expected to reduce it to.
///
/// Each variant stores the operand `x` and the `dtype` used for any constant it needs.
#[derive(Debug, Clone)]
pub enum KnownPropertyGraph {
    /// `x + 0`; expected result is `x`.
    AddZero { x: Arc<UOp>, dtype: DType },
    /// `x * 1`; expected result is `x`.
    MulOne { x: Arc<UOp>, dtype: DType },
    /// `x - 0`; expected result is `x`.
    SubZero { x: Arc<UOp>, dtype: DType },
    /// `x * 0`; expected result is the constant zero of `dtype`.
    MulZero { x: Arc<UOp>, dtype: DType },
    /// `x - x`; expected result is the constant zero of `dtype`.
    SubSelf { x: Arc<UOp>, dtype: DType },
    /// `x + x`; no single expected result is asserted.
    AddSelf { x: Arc<UOp>, dtype: DType },
}

impl KnownPropertyGraph {
    /// Builds the unsimplified graph for this property.
    ///
    /// # Panics
    ///
    /// Panics if `dtype.scalar()` is `None` or if a `try_*` constructor returns an error.
    pub fn build(&self) -> Arc<UOp> {
        match self {
            Self::AddZero { x, dtype } => x.try_add(&Self::zero_const(dtype)).unwrap(),
            Self::MulOne { x, dtype } => x.try_mul(&Self::one_const(dtype)).unwrap(),
            Self::SubZero { x, dtype } => x.try_sub(&Self::zero_const(dtype)).unwrap(),
            Self::MulZero { x, dtype } => x.try_mul(&Self::zero_const(dtype)).unwrap(),
            Self::SubSelf { x, .. } => x.try_sub(x).unwrap(),
            Self::AddSelf { x, .. } => x.try_add(x).unwrap(),
        }
    }

    /// Returns the graph that [`Self::build`] should simplify to.
    ///
    /// Returns `None` when no single canonical result is asserted (`AddSelf`).
    pub fn expected_result(&self) -> Option<Arc<UOp>> {
        match self {
            Self::AddZero { x, .. } | Self::MulOne { x, .. } | Self::SubZero { x, .. } => {
                Some(Arc::clone(x))
            }
            // Use the same dtype-aware zero as `build` rather than a hard-coded
            // `ConstValue::Int(0)`, which may not match the constant for unsigned dtypes.
            Self::MulZero { dtype, .. } | Self::SubSelf { dtype, .. } => Some(Self::zero_const(dtype)),
            Self::AddSelf { .. } => None,
        }
    }

    /// Constant `UOp` holding `ConstValue::zero` for the scalar type of `dtype`.
    fn zero_const(dtype: &DType) -> Arc<UOp> {
        let zero = ConstValue::zero(dtype.scalar().unwrap());
        UOp::const_(dtype.clone(), zero)
    }

    /// Constant `UOp` holding `ConstValue::one` for the scalar type of `dtype`.
    fn one_const(dtype: &DType) -> Arc<UOp> {
        let one = ConstValue::one(dtype.scalar().unwrap());
        UOp::const_(dtype.clone(), one)
    }
}
pub fn arb_known_property_graph() -> impl Strategy<Value = KnownPropertyGraph> {
arb_int_dtype()
.prop_flat_map(|dtype| {
arb_var_uop(dtype.clone()).prop_flat_map(move |x| {
let dtype = dtype.clone();
prop_oneof![
Just(KnownPropertyGraph::AddZero { x: Arc::clone(&x), dtype: dtype.clone() }),
Just(KnownPropertyGraph::MulOne { x: Arc::clone(&x), dtype: dtype.clone() }),
Just(KnownPropertyGraph::SubZero { x: Arc::clone(&x), dtype: dtype.clone() }),
Just(KnownPropertyGraph::MulZero { x: Arc::clone(&x), dtype: dtype.clone() }),
Just(KnownPropertyGraph::SubSelf { x: Arc::clone(&x), dtype: dtype.clone() }),
Just(KnownPropertyGraph::AddSelf { x, dtype }),
]
})
})
.boxed()
}