//! tract-onnx-opl 0.19.2
//!
//! Tiny, no-nonsense, self-contained TensorFlow and ONNX inference.
//!
//! Documentation
use rand::distributions::uniform::SampleUniform;
use rand::prelude::Distribution;
use rand::rngs::SmallRng;
use rand::SeedableRng;
use rand_distr::num_traits::Float;
use rand_distr::StandardNormal;
use tract_nnef::internal::*;
use tract_nnef::ser::{array, tdims};
use tract_nnef::tract_core::trivial_op_state_freeeze;

/// Hooks the random-tensor operator into an NNEF [`Registry`].
///
/// Declares the `tract_onnx_random` primitive, whose invocations are turned back into
/// [`Random`] ops by `load`. Also installs `dump`, which writes [`Random`] nodes out as
/// that same primitive.
pub fn register(registry: &mut Registry) {
    let arguments = [
        TypeName::String.named("datum_type"),
        TypeName::Integer.array().named("shape"),
        TypeName::String.named("dist"),
        TypeName::Scalar.array().named("parameters"),
        TypeName::Integer.named("seed"),
    ];
    let results = [("output", TypeName::Scalar.tensor())];
    registry.register_primitive("tract_onnx_random", &arguments, &results, load);
    registry.register_dumper(TypeId::of::<Random>(), dump);
}

/// NNEF loader for `tract_onnx_random`.
///
/// Rebuilds a [`Random`] op from the invocation's named arguments and wires it into the
/// model with no inputs.
///
/// # Errors
/// Fails when an argument is missing or malformed, when `parameters` does not hold
/// exactly two entries, or when `dist` is neither `"normal"` nor `"uniform"`.
fn load(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
    let datum_type: DatumType =
        invocation.named_arg_as::<String>(builder, "datum_type")?.parse()?;
    let shape: TVec<TDim> = invocation.named_arg_as(builder, "shape")?;
    let fact = datum_type.fact(&shape);
    let dist_name: String = invocation.named_arg_as(builder, "dist")?;
    let params: TVec<Arc<Tensor>> = invocation.named_arg_as(builder, "parameters")?;
    // Both distributions take exactly two parameters: (mean, dev) or (low, high).
    let (first, second) = match &*params {
        [a, b] => (a.clone(), b.clone()),
        _ => bail!("Random expect two parameters"),
    };
    let dist = match dist_name.as_str() {
        "normal" => Dist::Normal { mean: first, dev: second },
        "uniform" => Dist::Uniform { low: first, high: second },
        other => bail!("Unexpected distribution {}", other),
    };
    // The seed is optional; without it the generator is seeded from entropy at runtime.
    let seed = invocation.get_named_arg_as(builder, "seed")?;
    builder.wire(Random { fact, dist, seed }, &[])
}

/// NNEF dumper for [`Random`], the inverse of `load`.
///
/// Serialises the node as a `tract_onnx_random` invocation. Distribution parameters are
/// written as f32, whatever the op's output datum type is. `seed` is emitted only when
/// the op carries one.
fn dump(_ast: &mut IntoAst, node: &TypedNode) -> TractResult<Option<Arc<RValue>>> {
    let op = node.op_as::<Random>().context("wrong op")?;
    let (dist_name, first, second) = match &op.dist {
        Dist::Uniform { low, high } => ("uniform", low, high),
        Dist::Normal { mean, dev } => ("normal", mean, dev),
    };
    let mut named = vec![
        ("datum_type", string(format!("{:?}", op.fact.datum_type))),
        ("shape", tdims(&op.fact.shape)),
    ];
    named.extend(op.seed.map(|seed| ("seed", numeric(seed))));
    named.push(("dist", string(dist_name)));
    let parameters = array(&[
        numeric(first.cast_to_scalar::<f32>()?),
        numeric(second.cast_to_scalar::<f32>()?),
    ]);
    named.push(("parameters", parameters));
    Ok(Some(invocation("tract_onnx_random", &[], &named)))
}

/// Probability distribution sampled by a [`Random`] op.
///
/// Parameters are kept as tensors and cast to the output element type when the op is
/// evaluated.
#[derive(Clone, Debug, Hash)]
pub enum Dist {
    /// Uniform distribution between `low` and `high`.
    Uniform {
        low: Arc<Tensor>,
        high: Arc<Tensor>,
    },
    /// Gaussian distribution with mean `mean` and standard deviation `dev`.
    Normal {
        mean: Arc<Tensor>,
        dev: Arc<Tensor>,
    },
}

/// Source op that produces a tensor of random samples.
#[derive(Clone, Debug, Hash)]
pub struct Random {
    /// Element type and shape of the produced tensor.
    pub fact: TypedFact,
    /// Distribution the elements are drawn from.
    pub dist: Dist,
    /// Fixed seed for reproducible output. `None` seeds the generator from entropy.
    pub seed: Option<u64>,
}

impl_dyn_hash!(Random);

impl Op for Random {
    fn name(&self) -> Cow<str> {
        "Random".into()
    }

    op_as_typed_op!();
}

impl TypedOp for Random {
    /// The op ignores its inputs; its single output always has the fact chosen at load time.
    fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let output = self.fact.clone();
        Ok(tvec![output])
    }

    as_op!();
}

impl EvalOp for Random {
    /// The op owns a random generator per session, so it cannot be evaluated statelessly.
    fn is_stateless(&self) -> bool {
        false
    }

    /// Creates the session's generator.
    ///
    /// The generator is seeded from `seed` when one is set (reproducible output) and
    /// from entropy otherwise.
    fn state(
        &self,
        _session: &mut SessionState,
        _node_id: usize,
    ) -> TractResult<Option<Box<dyn OpState>>> {
        let rng = match self.seed {
            Some(seed) => SmallRng::seed_from_u64(seed),
            None => SmallRng::from_entropy(),
        };
        let state: Box<dyn OpState> = Box::new(RandomState(rng));
        Ok(Some(state))
    }
}

#[derive(Clone, Debug)]
struct RandomState(SmallRng);

impl OpState for RandomState {
    fn eval(
        &mut self,
        session: &mut SessionState,
        op: &dyn Op,
        _inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        let op = op.downcast_ref::<Random>().context("op and state mismatch")?;
        let mut tensor = unsafe {
            Tensor::uninitialized_dt(
                op.fact.datum_type,
                &op.fact.shape.eval_to_usize(&session.resolved_symbols)?,
            )?
        };
        match &op.dist {
            Dist::Uniform { low, high } => match op.fact.datum_type {
                DatumType::F32 => sample_uniform::<f32>(&mut tensor, &mut self.0, low, high)?,
                DatumType::F64 => sample_uniform::<f64>(&mut tensor, &mut self.0, low, high)?,
                DatumType::F16 => {
                    sample_uniform::<f32>(&mut tensor, &mut self.0, low, high)?;
                    tensor = tensor.cast_to::<f16>()?.into_owned();
                }
                _ => bail!("Random only support float types"),
            },
            Dist::Normal { mean, dev } => match op.fact.datum_type {
                DatumType::F32 => sample_normal::<f32>(&mut tensor, &mut self.0, mean, dev)?,
                DatumType::F64 => sample_normal::<f64>(&mut tensor, &mut self.0, mean, dev)?,
                DatumType::F16 => {
                    sample_uniform::<f32>(&mut tensor, &mut self.0, mean, dev)?;
                    tensor = tensor.cast_to::<f16>()?.into_owned();
                }
                _ => bail!("Random only support float types"),
            },
        }
        Ok(tvec!(tensor.into_tvalue()))
    }
}

trivial_op_state_freeeze!(RandomState);

fn sample_uniform<T: Datum + SampleUniform + Copy>(
    t: &mut Tensor,
    r: &mut SmallRng,
    low: &Tensor,
    high: &Tensor,
) -> TractResult<()> {
    let dist =
        rand::distributions::Uniform::new(low.cast_to_scalar::<T>()?, high.cast_to_scalar::<T>()?);
    t.as_slice_mut::<T>()?.iter_mut().zip(dist.sample_iter(r)).for_each(|(v, r)| *v = r);
    Ok(())
}

/// Fills `t` with samples from a normal distribution with mean `mean` and standard
/// deviation `dev`. Both are cast to `T`.
///
/// # Errors
/// Fails if the parameters cannot be cast to `T` or if `t` does not hold `T` elements.
/// Also fails if `rand_distr` rejects the parameters (e.g. a negative deviation).
fn sample_normal<T: Datum + Float + Copy>(
    t: &mut Tensor,
    r: &mut SmallRng,
    mean: &Tensor,
    dev: &Tensor,
) -> TractResult<()>
where
    StandardNormal: Distribution<T>,
{
    let mu = mean.cast_to_scalar::<T>()?;
    let sigma = dev.cast_to_scalar::<T>()?;
    let normal = rand_distr::Normal::<T>::new(mu, sigma)?;
    for (slot, value) in t.as_slice_mut::<T>()?.iter_mut().zip(normal.sample_iter(r)) {
        *slot = value;
    }
    Ok(())
}