usls 0.1.11

A Rust library integrated with ONNXRuntime, providing a collection of ML models.
Documentation
use aksr::Builder;
use anyhow::Result;
use ndarray::Axis;

use crate::{elapsed_module, Config, Engine, Image, Mask, Ops, Processor, Xs, Y};

/// Depth estimation model wrapper that turns images into grayscale depth masks.
///
/// An instance bundles the inference engine, the image pre-processor and the
/// input geometry read from the engine when the model is constructed.
#[derive(Builder, Debug)]
pub struct DepthPro {
    /// Inference engine created from the model section of the configuration.
    engine: Engine,
    /// Model input height; falls back to 512 when the engine reports none.
    height: usize,
    /// Model input width; falls back to 512 when the engine reports none.
    width: usize,
    /// Batch size reported by the engine.
    batch: usize,
    /// String form of the engine's spec.
    spec: String,
    /// Image pre-processor, configured with the model input width and height.
    processor: Processor,
}

impl DepthPro {
    /// Builds a `DepthPro` model from `config`.
    ///
    /// The engine is created from `config.model` and the processor from
    /// `config.processor`. Batch, height and width are read from the engine;
    /// height and width fall back to 512 when the engine does not provide them.
    /// The processor is configured with the resulting width and height.
    ///
    /// # Errors
    ///
    /// Returns an error if the engine or the processor cannot be built from
    /// the configuration.
    pub fn new(config: Config) -> Result<Self> {
        let engine = Engine::try_from_config(&config.model)?;
        let spec = engine.spec().to_string();
        let (batch, height, width) = (
            engine.batch().opt(),
            engine.try_height().unwrap_or(&512.into()).opt(),
            engine.try_width().unwrap_or(&512.into()).opt(),
        );
        let processor = Processor::try_from_config(&config.processor)?
            .with_image_width(width as _)
            .with_image_height(height as _);

        Ok(Self {
            engine,
            height,
            width,
            batch,
            spec,
            processor,
        })
    }

    /// Runs the image processor over `xs` and converts the result into engine inputs.
    fn preprocess(&mut self, xs: &[Image]) -> Result<Xs> {
        Ok(self.processor.process_images(xs)?.into())
    }

    /// Feeds the prepared inputs to the engine and returns its raw outputs.
    fn inference(&mut self, xs: Xs) -> Result<Xs> {
        self.engine.run(xs)
    }

    /// Converts raw model outputs into one grayscale mask per batch item.
    ///
    /// The `predicted_depth` output is inverted element-wise (`1 / x`), then each
    /// item along axis 0 is min-max normalized to `0..=255` and resized from the
    /// model resolution (`self.width` x `self.height`) to the source image size
    /// recorded by the processor.
    ///
    /// The normalization bounds are taken over finite values only, so a stray
    /// infinity (from a zero depth) or NaN does not collapse the whole mask.
    /// Non-finite pixels saturate: `+inf` maps to 255, `-inf` and NaN map to 0.
    /// When an item has no finite values, or all finite values are equal, the
    /// mask is filled with zeros.
    ///
    /// # Errors
    ///
    /// Returns an error if resizing or mask construction fails.
    ///
    /// # Panics
    ///
    /// Panics if the processor holds fewer transform records than the output
    /// has batch items.
    fn postprocess(&mut self, xs: Xs) -> Result<Vec<Y>> {
        // NOTE(review): `focallength_px` is looked up but never used; the lookup is
        // kept so a missing output is still detected the same way as before.
        let (predicted_depth, _focallength_px) = (&xs["predicted_depth"], &xs["focallength_px"]);
        let predicted_depth = predicted_depth.mapv(|x| 1. / x);

        let mut ys: Vec<Y> = Vec::new();
        for (idx, luma) in predicted_depth.axis_iter(Axis(0)).enumerate() {
            let info = &self.processor.images_transform_info[idx];
            let (h1, w1) = (info.height_src, info.width_src);

            let v = luma.into_owned().into_raw_vec_and_offset().0;

            // Bounds over finite values only: `total_cmp` would otherwise pick an
            // infinity or NaN as a bound and turn every normalized pixel into 0.
            let lo = v
                .iter()
                .copied()
                .filter(|d| d.is_finite())
                .min_by(|a, b| a.total_cmp(b));
            let hi = v
                .iter()
                .copied()
                .filter(|d| d.is_finite())
                .max_by(|a, b| a.total_cmp(b));

            let pixels: Vec<u8> = match (lo, hi) {
                (Some(lo), Some(hi)) if hi > lo => {
                    let range = hi - lo;
                    // The float-to-u8 cast saturates and maps NaN to 0.
                    v.iter()
                        .map(|&d| (((d - lo) / range) * 255.).clamp(0., 255.) as u8)
                        .collect()
                }
                // Empty, all non-finite, or constant input: nothing to normalize.
                _ => vec![0; v.len()],
            };

            let luma = Ops::resize_luma8_u8(
                &pixels,
                self.width as _,
                self.height as _,
                w1 as _,
                h1 as _,
                false,
                "Bilinear",
            )?;
            ys.push(Y::default().with_masks(&[Mask::new(&luma, w1, h1)?]));
        }

        Ok(ys)
    }

    /// Runs the full pipeline on `xs`: preprocess, inference, then postprocess.
    ///
    /// Each stage is wrapped in `elapsed_module!` under the `"DepthPro"` name.
    /// Returns one `Y` per item of the model output, each carrying a single
    /// depth mask sized to its source image.
    ///
    /// # Errors
    ///
    /// Propagates the first error raised by any of the three stages.
    pub fn forward(&mut self, xs: &[Image]) -> Result<Vec<Y>> {
        let inputs = elapsed_module!("DepthPro", "preprocess", self.preprocess(xs)?);
        let outputs = elapsed_module!("DepthPro", "inference", self.inference(inputs)?);
        let ys = elapsed_module!("DepthPro", "postprocess", self.postprocess(outputs)?);

        Ok(ys)
    }
}