// edgefirst-decoder 0.15.1
//
// ML model output decoding for YOLO and ModelPack object detection and
// segmentation. See the crate documentation for details.
// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
// SPDX-License-Identifier: Apache-2.0

//! Bridge between `TensorDyn` (the HAL's fundamental type-erased tensor) and
//! the decoder's internal `ArrayViewD`/`ArrayViewDQuantized` representations.
//!
//! This module maps `TensorDyn` outputs into memory and converts them into
//! ndarray views that the existing decode methods consume.

use edgefirst_tensor::{TensorDyn, TensorMap, TensorMapTrait, TensorTrait};
use ndarray::ArrayViewD;

use super::ArrayViewDQuantized;
use crate::DecoderError;

/// Mapped tensor outputs, grouped by dtype category.
///
/// The `TensorMap` values borrow from the original `TensorDyn` tensors. The
/// ndarray views created from these maps must not outlive the `MappedOutputs`.
/// Mapped tensor outputs, grouped by dtype category.
///
/// Produced by [`map_tensors`]; exactly one variant is chosen for a given set
/// of outputs based on the first tensor's dtype.
///
/// The `TensorMap` values borrow from the original `TensorDyn` tensors. The
/// ndarray views created from these maps must not outlive the `MappedOutputs`.
pub(super) enum MappedOutputs {
    /// All outputs are integer types (u8, i8, u16, i16, u32, i32).
    Quantized(Vec<QuantizedMap>),
    /// All outputs are f32.
    Float32(Vec<TensorMap<f32>>),
    /// All outputs are f64.
    Float64(Vec<TensorMap<f64>>),
}

/// A mapped quantized tensor preserving the concrete integer type.
/// A mapped quantized tensor preserving the concrete integer type.
///
/// Each variant wraps the `TensorMap` for one supported integer dtype so the
/// element type survives until [`QuantizedMap::as_view`] rebuilds a typed view.
pub(super) enum QuantizedMap {
    /// Unsigned 8-bit tensor map.
    U8(TensorMap<u8>),
    /// Signed 8-bit tensor map.
    I8(TensorMap<i8>),
    /// Unsigned 16-bit tensor map.
    U16(TensorMap<u16>),
    /// Signed 16-bit tensor map.
    I16(TensorMap<i16>),
    /// Unsigned 32-bit tensor map.
    U32(TensorMap<u32>),
    /// Signed 32-bit tensor map.
    I32(TensorMap<i32>),
}

impl QuantizedMap {
    /// Create an `ArrayViewDQuantized` borrowing from the mapped data.
    pub(super) fn as_view(&self) -> Result<ArrayViewDQuantized<'_>, DecoderError> {
        macro_rules! make_view {
            ($map:expr, $variant:ident) => {{
                let shape = $map.shape().to_vec();
                let slice = $map.as_slice();
                ArrayViewD::from_shape(shape.as_slice(), slice)
                    .map(|v| ArrayViewDQuantized::$variant(v))
                    .map_err(|e| DecoderError::InvalidConfig(format!("tensor shape: {e}")))
            }};
        }
        match self {
            Self::U8(m) => make_view!(m, UInt8),
            Self::I8(m) => make_view!(m, Int8),
            Self::U16(m) => make_view!(m, UInt16),
            Self::I16(m) => make_view!(m, Int16),
            Self::U32(m) => make_view!(m, UInt32),
            Self::I32(m) => make_view!(m, Int32),
        }
    }
}

/// Map `TensorDyn` outputs into memory, detecting whether they are quantized
/// (integer) or floating-point.
///
/// All integer types (u8, i8, u16, i16, u32, i32) are grouped as quantized.
/// Float types (f32, f64) are grouped by precision. Mixed float/integer
/// inputs are an error, except that i32 tensors mixed with f32 are allowed
/// (some models produce mixed outputs where shape/count tensors are i32).
///
/// # Errors
///
/// Returns `DecoderError::InvalidConfig` if:
/// - The output slice is empty
/// - Tensor memory mapping fails
/// - Tensor types are mixed in an unsupported way
/// - An unsupported dtype is encountered (u64, i64, f16)
/// Map `TensorDyn` outputs into memory, detecting whether they are quantized
/// (integer) or floating-point.
///
/// All integer types (u8, i8, u16, i16, u32, i32) are grouped as quantized.
/// Float types (f32, f64) are grouped by precision. Mixed float/integer
/// inputs are an error, except that i32 tensors mixed with f32 are allowed
/// (some models produce mixed outputs where shape/count tensors are i32).
///
/// # Errors
///
/// Returns `DecoderError::InvalidConfig` if:
/// - The output slice is empty
/// - Tensor memory mapping fails
/// - Tensor types are mixed in an unsupported way
/// - An unsupported dtype is encountered (u64, i64, f16)
pub(super) fn map_tensors(outputs: &[&TensorDyn]) -> Result<MappedOutputs, DecoderError> {
    // The first tensor's dtype decides the category for the whole batch.
    let first = outputs
        .first()
        .ok_or_else(|| DecoderError::InvalidConfig("no outputs".to_string()))?;
    let first_dtype = first.dtype();

    match first_dtype {
        edgefirst_tensor::DType::F32 | edgefirst_tensor::DType::F64 => {
            map_float_tensors(outputs, first_dtype)
        }
        _ => map_quantized_tensors(outputs),
    }
}

/// Map all outputs as float tensors (f32 or f64).
/// Map all outputs as float tensors (f32 or f64).
///
/// `first_dtype` is guaranteed by the caller to be `F32` or `F64`; any other
/// value is treated as f64 here.
fn map_float_tensors(
    outputs: &[&TensorDyn],
    first_dtype: edgefirst_tensor::DType,
) -> Result<MappedOutputs, DecoderError> {
    match first_dtype {
        edgefirst_tensor::DType::F32 => {
            let mut maps = Vec::with_capacity(outputs.len());
            for &out in outputs {
                match out {
                    TensorDyn::F32(tensor) => {
                        let map = tensor.map().map_err(|e| {
                            DecoderError::InvalidConfig(format!("tensor map failed: {e}"))
                        })?;
                        maps.push(map);
                    }
                    // Some models have mixed f32 + i32 outputs (e.g. count tensors).
                    // Skip i32 tensors silently; the decoder indexes only f32 outputs.
                    TensorDyn::I32(_) => {}
                    other => {
                        return Err(DecoderError::InvalidConfig(format!(
                            "mixed tensor types: expected f32, got {:?}",
                            other.dtype()
                        )));
                    }
                }
            }
            Ok(MappedOutputs::Float32(maps))
        }
        // f64 (the caller only passes F32 or F64 here)
        _ => {
            let mut maps = Vec::with_capacity(outputs.len());
            for &out in outputs {
                let TensorDyn::F64(tensor) = out else {
                    return Err(DecoderError::InvalidConfig(format!(
                        "mixed tensor types: expected f64, got {:?}",
                        out.dtype()
                    )));
                };
                let map = tensor
                    .map()
                    .map_err(|e| DecoderError::InvalidConfig(format!("tensor map failed: {e}")))?;
                maps.push(map);
            }
            Ok(MappedOutputs::Float64(maps))
        }
    }
}

/// Map all outputs as quantized (integer) tensors.
fn map_quantized_tensors(outputs: &[&TensorDyn]) -> Result<MappedOutputs, DecoderError> {
    let mut maps = Vec::with_capacity(outputs.len());
    for &t in outputs {
        let qmap = match t {
            TensorDyn::U8(tensor) => QuantizedMap::U8(
                tensor
                    .map()
                    .map_err(|e| DecoderError::InvalidConfig(format!("tensor map: {e}")))?,
            ),
            TensorDyn::I8(tensor) => QuantizedMap::I8(
                tensor
                    .map()
                    .map_err(|e| DecoderError::InvalidConfig(format!("tensor map: {e}")))?,
            ),
            TensorDyn::U16(tensor) => QuantizedMap::U16(
                tensor
                    .map()
                    .map_err(|e| DecoderError::InvalidConfig(format!("tensor map: {e}")))?,
            ),
            TensorDyn::I16(tensor) => QuantizedMap::I16(
                tensor
                    .map()
                    .map_err(|e| DecoderError::InvalidConfig(format!("tensor map: {e}")))?,
            ),
            TensorDyn::U32(tensor) => QuantizedMap::U32(
                tensor
                    .map()
                    .map_err(|e| DecoderError::InvalidConfig(format!("tensor map: {e}")))?,
            ),
            TensorDyn::I32(tensor) => QuantizedMap::I32(
                tensor
                    .map()
                    .map_err(|e| DecoderError::InvalidConfig(format!("tensor map: {e}")))?,
            ),
            _ => {
                return Err(DecoderError::InvalidConfig(format!(
                    "unsupported tensor dtype for quantized decode: {:?}",
                    t.dtype()
                )));
            }
        };
        maps.push(qmap);
    }
    Ok(MappedOutputs::Quantized(maps))
}

/// Convert a slice of `QuantizedMap` into `ArrayViewDQuantized` views.
/// Convert a slice of `QuantizedMap` into `ArrayViewDQuantized` views.
///
/// # Errors
///
/// Propagates the first `DecoderError` produced by [`QuantizedMap::as_view`].
pub(super) fn quantized_views(
    maps: &[QuantizedMap],
) -> Result<Vec<ArrayViewDQuantized<'_>>, DecoderError> {
    let mut views = Vec::with_capacity(maps.len());
    for map in maps {
        views.push(map.as_view()?);
    }
    Ok(views)
}

/// Convert a slice of `TensorMap<f32>` into `ArrayViewD<f32>` views.
/// Convert a slice of `TensorMap<f32>` into `ArrayViewD<f32>` views.
///
/// # Errors
///
/// Returns `DecoderError::InvalidConfig` if a map's data cannot be reshaped
/// to its reported shape.
pub(super) fn f32_views(maps: &[TensorMap<f32>]) -> Result<Vec<ArrayViewD<'_, f32>>, DecoderError> {
    let mut views = Vec::with_capacity(maps.len());
    for map in maps {
        let dims = map.shape().to_vec();
        let view = ArrayViewD::from_shape(dims.as_slice(), map.as_slice())
            .map_err(|e| DecoderError::InvalidConfig(format!("tensor shape: {e}")))?;
        views.push(view);
    }
    Ok(views)
}

/// Convert a slice of `TensorMap<f64>` into `ArrayViewD<f64>` views.
/// Convert a slice of `TensorMap<f64>` into `ArrayViewD<f64>` views.
///
/// # Errors
///
/// Returns `DecoderError::InvalidConfig` if a map's data cannot be reshaped
/// to its reported shape.
pub(super) fn f64_views(maps: &[TensorMap<f64>]) -> Result<Vec<ArrayViewD<'_, f64>>, DecoderError> {
    let mut views = Vec::with_capacity(maps.len());
    for map in maps {
        let dims = map.shape().to_vec();
        let view = ArrayViewD::from_shape(dims.as_slice(), map.as_slice())
            .map_err(|e| DecoderError::InvalidConfig(format!("tensor shape: {e}")))?;
        views.push(view);
    }
    Ok(views)
}