use edgefirst_tensor::{TensorDyn, TensorMap, TensorMapTrait, TensorTrait};
use ndarray::ArrayViewD;
use super::ArrayViewDQuantized;
use crate::DecoderError;
/// Model output tensors after `map()`, grouped by element type.
///
/// Built by `map_tensors`: every accepted output ends up in the same variant.
pub(super) enum MappedOutputs {
/// Integer outputs, one `QuantizedMap` per output tensor.
Quantized(Vec<QuantizedMap>),
/// `f32` outputs. `map_float_tensors` skips `I32` outputs on this path, so this `Vec`
/// can be shorter than (and not index-aligned with) the input slice.
Float32(Vec<TensorMap<f32>>),
/// `f64` outputs, one per input tensor.
Float64(Vec<TensorMap<f64>>),
}
/// A mapped integer tensor; the variant records the element type.
///
/// NOTE(review): the name suggests quantized model outputs whose scale/zero-point are
/// applied elsewhere — that handling is not visible in this module; confirm.
pub(super) enum QuantizedMap {
/// Mapped `u8` tensor.
U8(TensorMap<u8>),
/// Mapped `i8` tensor.
I8(TensorMap<i8>),
/// Mapped `u16` tensor.
U16(TensorMap<u16>),
/// Mapped `i16` tensor.
I16(TensorMap<i16>),
/// Mapped `u32` tensor.
U32(TensorMap<u32>),
/// Mapped `i32` tensor.
I32(TensorMap<i32>),
}
impl QuantizedMap {
pub(super) fn as_view(&self) -> Result<ArrayViewDQuantized<'_>, DecoderError> {
macro_rules! make_view {
($map:expr, $variant:ident) => {{
let shape = $map.shape().to_vec();
let slice = $map.as_slice();
ArrayViewD::from_shape(shape.as_slice(), slice)
.map(|v| ArrayViewDQuantized::$variant(v))
.map_err(|e| DecoderError::InvalidConfig(format!("tensor shape: {e}")))
}};
}
match self {
Self::U8(m) => make_view!(m, UInt8),
Self::I8(m) => make_view!(m, Int8),
Self::U16(m) => make_view!(m, UInt16),
Self::I16(m) => make_view!(m, Int16),
Self::U32(m) => make_view!(m, UInt32),
Self::I32(m) => make_view!(m, Int32),
}
}
}
pub(super) fn map_tensors(outputs: &[&TensorDyn]) -> Result<MappedOutputs, DecoderError> {
if outputs.is_empty() {
return Err(DecoderError::InvalidConfig("no outputs".to_string()));
}
let first_dtype = outputs[0].dtype();
let is_float = matches!(
first_dtype,
edgefirst_tensor::DType::F32 | edgefirst_tensor::DType::F64
);
if is_float {
map_float_tensors(outputs, first_dtype)
} else {
map_quantized_tensors(outputs)
}
}
/// Maps every output as an `f32` or `f64` tensor, as selected by `first_dtype`.
///
/// Callers in this module pass only `DType::F32` or `DType::F64`; any value other than
/// `F32` takes the `f64` path.
///
/// # Errors
///
/// Returns [`DecoderError::InvalidConfig`] when an output's element type does not match
/// the selected float type, or when mapping a tensor fails.
fn map_float_tensors(
    outputs: &[&TensorDyn],
    first_dtype: edgefirst_tensor::DType,
) -> Result<MappedOutputs, DecoderError> {
    if first_dtype != edgefirst_tensor::DType::F32 {
        // Double-precision path: every output must be f64; nothing is skipped.
        let mut f64_maps = Vec::with_capacity(outputs.len());
        for &candidate in outputs {
            let tensor = match candidate {
                TensorDyn::F64(tensor) => tensor,
                other => {
                    return Err(DecoderError::InvalidConfig(format!(
                        "mixed tensor types: expected f64, got {:?}",
                        other.dtype()
                    )));
                }
            };
            let mapped = tensor
                .map()
                .map_err(|e| DecoderError::InvalidConfig(format!("tensor map failed: {e}")))?;
            f64_maps.push(mapped);
        }
        return Ok(MappedOutputs::Float64(f64_maps));
    }

    // Single-precision path.
    let mut f32_maps = Vec::with_capacity(outputs.len());
    for &candidate in outputs {
        let tensor = match candidate {
            TensorDyn::F32(tensor) => tensor,
            // I32 outputs are skipped here. NOTE(review): the f64 path rejects them
            // instead, and skipping means `f32_maps` is not index-aligned with `outputs`;
            // confirm both behaviours are intended.
            TensorDyn::I32(_) => continue,
            other => {
                return Err(DecoderError::InvalidConfig(format!(
                    "mixed tensor types: expected f32, got {:?}",
                    other.dtype()
                )));
            }
        };
        let mapped = tensor
            .map()
            .map_err(|e| DecoderError::InvalidConfig(format!("tensor map failed: {e}")))?;
        f32_maps.push(mapped);
    }
    Ok(MappedOutputs::Float32(f32_maps))
}
fn map_quantized_tensors(outputs: &[&TensorDyn]) -> Result<MappedOutputs, DecoderError> {
let mut maps = Vec::with_capacity(outputs.len());
for &t in outputs {
let qmap = match t {
TensorDyn::U8(tensor) => QuantizedMap::U8(
tensor
.map()
.map_err(|e| DecoderError::InvalidConfig(format!("tensor map: {e}")))?,
),
TensorDyn::I8(tensor) => QuantizedMap::I8(
tensor
.map()
.map_err(|e| DecoderError::InvalidConfig(format!("tensor map: {e}")))?,
),
TensorDyn::U16(tensor) => QuantizedMap::U16(
tensor
.map()
.map_err(|e| DecoderError::InvalidConfig(format!("tensor map: {e}")))?,
),
TensorDyn::I16(tensor) => QuantizedMap::I16(
tensor
.map()
.map_err(|e| DecoderError::InvalidConfig(format!("tensor map: {e}")))?,
),
TensorDyn::U32(tensor) => QuantizedMap::U32(
tensor
.map()
.map_err(|e| DecoderError::InvalidConfig(format!("tensor map: {e}")))?,
),
TensorDyn::I32(tensor) => QuantizedMap::I32(
tensor
.map()
.map_err(|e| DecoderError::InvalidConfig(format!("tensor map: {e}")))?,
),
_ => {
return Err(DecoderError::InvalidConfig(format!(
"unsupported tensor dtype for quantized decode: {:?}",
t.dtype()
)));
}
};
maps.push(qmap);
}
Ok(MappedOutputs::Quantized(maps))
}
/// Builds a typed n-dimensional view for each mapped integer tensor, in order.
///
/// # Errors
///
/// Propagates the first [`DecoderError`] returned by [`QuantizedMap::as_view`].
pub(super) fn quantized_views(
    maps: &[QuantizedMap],
) -> Result<Vec<ArrayViewDQuantized<'_>>, DecoderError> {
    let mut views = Vec::with_capacity(maps.len());
    for map in maps {
        views.push(map.as_view()?);
    }
    Ok(views)
}
/// Builds an n-dimensional `f32` view over each mapped tensor, using the tensor's shape.
///
/// # Errors
///
/// Returns [`DecoderError::InvalidConfig`] for the first tensor whose shape `ndarray`
/// rejects for its data slice.
pub(super) fn f32_views(maps: &[TensorMap<f32>]) -> Result<Vec<ArrayViewD<'_, f32>>, DecoderError> {
    let mut views = Vec::with_capacity(maps.len());
    for map in maps {
        let dims = map.shape().to_vec();
        let data = map.as_slice();
        let view = ArrayViewD::from_shape(dims.as_slice(), data)
            .map_err(|e| DecoderError::InvalidConfig(format!("tensor shape: {e}")))?;
        views.push(view);
    }
    Ok(views)
}
/// Builds an n-dimensional `f64` view over each mapped tensor, using the tensor's shape.
///
/// # Errors
///
/// Returns [`DecoderError::InvalidConfig`] for the first tensor whose shape `ndarray`
/// rejects for its data slice.
pub(super) fn f64_views(maps: &[TensorMap<f64>]) -> Result<Vec<ArrayViewD<'_, f64>>, DecoderError> {
    let mut views = Vec::with_capacity(maps.len());
    for map in maps {
        let dims = map.shape().to_vec();
        let data = map.as_slice();
        let view = ArrayViewD::from_shape(dims.as_slice(), data)
            .map_err(|e| DecoderError::InvalidConfig(format!("tensor shape: {e}")))?;
        views.push(view);
    }
    Ok(views)
}