use super::super::types::{ExecutorResult, RawOutputs};
use crate::runtime_adapter::AdapterError;
use ndarray::{ArrayD, IxDyn};
use std::collections::HashMap;
/// Collapses a tensor-map output into a single class id.
///
/// The tensor examined is whichever one the map's value iterator yields first;
/// it is handed to [`argmax_token`] and the resulting index is wrapped in
/// `RawOutputs::ClassId`.
///
/// `_dim` is accepted but not used by this step.
///
/// # Errors
///
/// Returns `AdapterError::InvalidInput` when `data` is not a
/// `RawOutputs::TensorMap` or when the map holds no tensors. Any error from
/// [`argmax_token`] is propagated unchanged.
pub fn argmax_step(data: RawOutputs, _dim: Option<usize>) -> ExecutorResult<RawOutputs> {
    match data {
        RawOutputs::TensorMap(outputs) => match outputs.values().next() {
            Some(logits) => argmax_token(logits).map(RawOutputs::ClassId),
            None => Err(AdapterError::InvalidInput(
                "No outputs to apply argmax".to_string(),
            )),
        },
        _ => Err(AdapterError::InvalidInput(
            "Argmax requires tensor map".to_string(),
        )),
    }
}
/// Applies softmax in place to every tensor in a tensor-map output.
///
/// Each tensor is passed to `apply_softmax` together with `dim`; the (mutated)
/// map is then returned as `RawOutputs::TensorMap`.
///
/// # Errors
///
/// Returns `AdapterError::InvalidInput` when `data` is not a
/// `RawOutputs::TensorMap`. The first error reported by `apply_softmax` aborts
/// the step and is propagated.
pub fn softmax_step(data: RawOutputs, dim: Option<usize>) -> ExecutorResult<RawOutputs> {
    if let RawOutputs::TensorMap(mut outputs) = data {
        for tensor in outputs.values_mut() {
            apply_softmax(tensor, dim)?;
        }
        Ok(RawOutputs::TensorMap(outputs))
    } else {
        Err(AdapterError::InvalidInput(
            "Softmax requires tensor map".to_string(),
        ))
    }
}
/// Extracts the highest-scoring predictions from the first tensor in the map.
///
/// The result is a map with a single entry, `"topk"`, holding a 1-D tensor of
/// interleaved pairs: `[index_0, score_0, index_1, score_1, ...]`. Indices are
/// stored as `f32` because the tensor is `f32`-typed.
///
/// When `k` is larger than the number of scores available, fewer than `k`
/// pairs are returned; the output tensor is sized from the pairs actually
/// produced rather than from `k`.
///
/// # Errors
///
/// Returns `AdapterError::InvalidInput` when `data` is not a
/// `RawOutputs::TensorMap`, when the map is empty, or when the output tensor
/// cannot be built. Errors from `top_k_predictions` are propagated.
pub fn topk_step(data: RawOutputs, k: usize, dim: Option<usize>) -> ExecutorResult<RawOutputs> {
    let tensor_map = match data {
        RawOutputs::TensorMap(map) => map,
        _ => {
            return Err(AdapterError::InvalidInput(
                "TopK requires tensor map".to_string(),
            ))
        }
    };
    let tensor = tensor_map
        .values()
        .next()
        .ok_or_else(|| AdapterError::InvalidInput("No outputs for TopK".to_string()))?;
    let top_k_results = top_k_predictions(tensor, k, dim)?;

    // Size the buffer from the results, not from `k`: `k` may exceed the number
    // of classes, and `k * 2` could overflow or over-allocate for a huge `k`.
    let mut flattened = Vec::with_capacity(top_k_results.len() * 2);
    for (idx, score) in top_k_results {
        flattened.push(idx as f32);
        flattened.push(score);
    }

    // The declared shape must equal the element count or construction fails.
    let flattened_len = flattened.len();
    let topk_tensor = ArrayD::from_shape_vec(IxDyn(&[flattened_len]), flattened).map_err(|e| {
        AdapterError::InvalidInput(format!("Failed to create TopK tensor: {:?}", e))
    })?;

    let mut result_map = HashMap::new();
    result_map.insert("topk".to_string(), topk_tensor);
    Ok(RawOutputs::TensorMap(result_map))
}
/// Thresholds the first tensor in the map using a strict `>` comparison.
///
/// Two output modes are available:
///
/// * `return_indices == true`: the result map has one entry,
///   `"threshold_indices"`, a 1-D tensor holding the flat positions (as `f32`)
///   of every element greater than `threshold`.
/// * `return_indices == false`: the result map has one entry,
///   `"threshold_mask"`, a tensor with the input's shape containing `1.0`
///   where the element exceeds `threshold` and `0.0` elsewhere.
///
/// # Errors
///
/// Returns `AdapterError::InvalidInput` when `data` is not a
/// `RawOutputs::TensorMap`, when the map is empty, when the tensor cannot be
/// viewed as a contiguous slice, or when the result tensor cannot be built.
pub fn threshold_step(
    data: RawOutputs,
    threshold: f32,
    return_indices: bool,
) -> ExecutorResult<RawOutputs> {
    let outputs = if let RawOutputs::TensorMap(map) = data {
        map
    } else {
        return Err(AdapterError::InvalidInput(
            "Threshold requires tensor map".to_string(),
        ));
    };

    let source = match outputs.values().next() {
        Some(tensor) => tensor,
        None => {
            return Err(AdapterError::InvalidInput(
                "No outputs for Threshold".to_string(),
            ))
        }
    };

    let scores = match source.as_slice() {
        Some(flat) => flat,
        None => {
            return Err(AdapterError::InvalidInput(
                "Tensor is not contiguous for Threshold".to_string(),
            ))
        }
    };

    // Build the (key, tensor) pair for the requested mode, then insert once.
    let (key, produced) = if return_indices {
        let mut hits: Vec<f32> = Vec::new();
        for (position, &score) in scores.iter().enumerate() {
            if score > threshold {
                hits.push(position as f32);
            }
        }
        let hit_count = hits.len();
        let tensor = ArrayD::from_shape_vec(IxDyn(&[hit_count]), hits).map_err(|e| {
            AdapterError::InvalidInput(format!("Failed to create threshold tensor: {:?}", e))
        })?;
        ("threshold_indices", tensor)
    } else {
        let mut mask: Vec<f32> = Vec::with_capacity(scores.len());
        for &score in scores {
            mask.push(if score > threshold { 1.0 } else { 0.0 });
        }
        let tensor = ArrayD::from_shape_vec(IxDyn(source.shape()), mask).map_err(|e| {
            AdapterError::InvalidInput(format!("Failed to create threshold mask: {:?}", e))
        })?;
        ("threshold_mask", tensor)
    };

    let mut result_map = HashMap::new();
    result_map.insert(key.to_string(), produced);
    Ok(RawOutputs::TensorMap(result_map))
}
/// Mean-pools the first tensor in the map over its sequence axis.
///
/// The input must be a 3-D tensor, interpreted as
/// `[batch, seq_len, hidden_size]`, and `dim` must be `1`. The result map has
/// one entry, `"sentence_embedding"`, a `[batch, hidden_size]` tensor where
/// each value is the arithmetic mean over the `seq_len` axis.
///
/// # Errors
///
/// Returns `AdapterError::InvalidInput` when:
///
/// * `data` is not a `RawOutputs::TensorMap`, or the map is empty;
/// * the tensor is not 3-D;
/// * `dim` is anything other than `1`;
/// * the sequence axis has length zero (the mean would be `0 / 0`, i.e. NaN).
pub fn meanpool_step(data: RawOutputs, dim: usize) -> ExecutorResult<RawOutputs> {
    let tensor_map = match data {
        RawOutputs::TensorMap(map) => map,
        _ => {
            return Err(AdapterError::InvalidInput(
                "MeanPool requires tensor map".to_string(),
            ))
        }
    };
    let tensor = tensor_map
        .values()
        .next()
        .ok_or_else(|| AdapterError::InvalidInput("No outputs for MeanPool".to_string()))?;

    let shape = tensor.shape();
    if shape.len() != 3 {
        return Err(AdapterError::InvalidInput(format!(
            "MeanPool expects 3D tensor [batch, seq_len, hidden_size], got {:?}",
            shape
        )));
    }
    let (batch_size, seq_len, hidden_size) = (shape[0], shape[1], shape[2]);

    if dim != 1 {
        return Err(AdapterError::InvalidInput(format!(
            "MeanPool only supports pooling over dim=1 (sequence), got dim={}",
            dim
        )));
    }

    // Refuse to average over nothing instead of silently emitting NaN values.
    if seq_len == 0 {
        return Err(AdapterError::InvalidInput(format!(
            "MeanPool cannot pool over an empty sequence dimension, got shape {:?}",
            shape
        )));
    }

    let divisor = seq_len as f32;
    let mut pooled = ArrayD::<f32>::zeros(IxDyn(&[batch_size, hidden_size]));
    for b in 0..batch_size {
        for h in 0..hidden_size {
            // Accumulate in sequence order so results are reproducible.
            let mut sum = 0.0f32;
            for s in 0..seq_len {
                sum += tensor[IxDyn(&[b, s, h])];
            }
            pooled[IxDyn(&[b, h])] = sum / divisor;
        }
    }

    let mut result_map = HashMap::new();
    result_map.insert("sentence_embedding".to_string(), pooled);
    Ok(RawOutputs::TensorMap(result_map))
}
/// Reverses a mean/std normalization on every tensor in the map, in place.
///
/// Each element is rewritten as `value * std[c] + mean[c]`, where the channel
/// `c` is the element's flat position modulo `mean.len()`. A single-element
/// `mean`/`std` therefore applies the same scale and shift to every value.
///
/// # Errors
///
/// Returns `AdapterError::InvalidInput` when `mean` and `std` differ in
/// length, when they are empty, when `data` is not a `RawOutputs::TensorMap`,
/// or when any tensor cannot be viewed as a contiguous mutable slice.
pub fn denormalize_step(data: RawOutputs, mean: &[f32], std: &[f32]) -> ExecutorResult<RawOutputs> {
    let channels = mean.len();
    if channels != std.len() {
        return Err(AdapterError::InvalidInput(format!(
            "Denormalize mean length ({}) must match std length ({})",
            channels,
            std.len()
        )));
    }
    if channels == 0 {
        return Err(AdapterError::InvalidInput(
            "Denormalize requires at least one mean/std value".to_string(),
        ));
    }

    let mut outputs = if let RawOutputs::TensorMap(map) = data {
        map
    } else {
        return Err(AdapterError::InvalidInput(
            "Denormalize requires tensor map input".to_string(),
        ));
    };

    for (output_name, tensor) in outputs.iter_mut() {
        let flat = match tensor.as_slice_mut() {
            Some(values) => values,
            None => {
                return Err(AdapterError::InvalidInput(format!(
                    "Denormalize requires a contiguous tensor (output \"{}\" is non-contiguous)",
                    output_name
                )))
            }
        };
        // Walking the data in groups of `channels` lines each value up with the
        // statistics at (flat index % channels); a short final group simply
        // uses the leading channels.
        for group in flat.chunks_mut(channels) {
            for ((value, &scale), &shift) in group.iter_mut().zip(std).zip(mean) {
                *value = (*value * scale) + shift;
            }
        }
    }

    Ok(RawOutputs::TensorMap(outputs))
}
/// Returns the index of the largest logit for the position of interest.
///
/// Which scores are examined depends on the tensor rank:
///
/// * 3-D `[d0, seq_len, vocab]`: the **last** sequence position of the first
///   `d0` entry.
/// * 2-D `[d0, vocab]`: the first row.
/// * 1-D `[vocab]`: the whole tensor.
///
/// When several scores compare equal to the maximum, the highest index among
/// them is returned. An empty score slice yields `0`.
///
/// # Errors
///
/// Returns `AdapterError::InvalidInput` when the tensor is not contiguous, has
/// a rank other than 1, 2 or 3, has an empty sequence axis (3-D), or does not
/// contain the expected row of scores (for example when a leading dimension
/// is zero).
pub fn argmax_token(logits: &ArrayD<f32>) -> ExecutorResult<usize> {
    let shape = logits.shape();
    let data = logits
        .as_slice()
        .ok_or_else(|| AdapterError::InvalidInput("Logits tensor is not contiguous".to_string()))?;

    // Reported when the requested row of scores is missing from the data.
    let missing_row = || {
        AdapterError::InvalidInput(format!(
            "Logits tensor has no scores to select from, got shape {:?}",
            shape
        ))
    };

    let scores: &[f32] = match shape.len() {
        3 => {
            let seq_len = shape[1];
            let vocab_size = shape[2];
            // Start of the final sequence position within the first entry.
            let last_position = seq_len.checked_sub(1).ok_or_else(|| {
                AdapterError::InvalidInput(format!(
                    "Logits tensor has an empty sequence dimension, got shape {:?}",
                    shape
                ))
            })?;
            let start_idx = last_position * vocab_size;
            data.get(start_idx..start_idx + vocab_size)
                .ok_or_else(missing_row)?
        }
        2 => data.get(..shape[1]).ok_or_else(missing_row)?,
        1 => data,
        _ => {
            return Err(AdapterError::InvalidInput(format!(
                "Unexpected logits shape: {:?}",
                shape
            )))
        }
    };

    Ok(argmax_index(scores))
}

/// Index of the maximum value in `scores`, or `0` when `scores` is empty.
fn argmax_index(scores: &[f32]) -> usize {
    scores
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(idx, _)| idx)
        .unwrap_or(0)
}
/// Applies softmax in place along one axis of a contiguous tensor.
///
/// `dim` selects the axis; `None` means the last axis. Every 1-D lane along
/// that axis is normalized independently with `softmax_1d`. Tensors of any
/// rank >= 1 are supported, and a tensor with no elements is left untouched.
///
/// # Errors
///
/// Returns `AdapterError::InvalidInput` when the tensor has zero dimensions,
/// when `dim` is not a valid axis, or when the tensor cannot be viewed as a
/// contiguous mutable slice.
fn apply_softmax(tensor: &mut ArrayD<f32>, dim: Option<usize>) -> ExecutorResult<()> {
    let shape = tensor.shape().to_vec();
    let ndim = shape.len();

    // Check rank first: `ndim - 1` below would underflow for a 0-D tensor.
    if ndim == 0 {
        return Err(AdapterError::InvalidInput(
            "Softmax requires a tensor with at least one dimension".to_string(),
        ));
    }

    let dim = dim.unwrap_or(ndim - 1);
    if dim >= ndim {
        return Err(AdapterError::InvalidInput(format!(
            "Softmax dimension {} out of bounds for tensor with {} dimensions",
            dim, ndim
        )));
    }

    let slice = tensor.as_slice_mut().ok_or_else(|| {
        AdapterError::InvalidInput("Tensor is not contiguous, cannot apply softmax".to_string())
    })?;

    // No elements means some axis is zero-length: nothing to normalize, and
    // the chunk sizes computed below would be zero.
    if slice.is_empty() {
        return Ok(());
    }

    let axis_len = shape[dim];
    // Number of elements between consecutive entries along `dim` (row-major).
    let inner: usize = shape[dim + 1..].iter().product();

    if inner == 1 {
        // Lanes are contiguous runs of `axis_len` elements.
        for lane in slice.chunks_mut(axis_len) {
            softmax_1d(lane);
        }
        return Ok(());
    }

    // Strided lanes: copy each into a scratch buffer, normalize, write back.
    let mut lane = vec![0.0f32; axis_len];
    for block in slice.chunks_mut(axis_len * inner) {
        for offset in 0..inner {
            for (dst, src) in lane
                .iter_mut()
                .zip(block.iter().skip(offset).step_by(inner))
            {
                *dst = *src;
            }
            softmax_1d(&mut lane);
            for (src, dst) in lane
                .iter()
                .zip(block.iter_mut().skip(offset).step_by(inner))
            {
                *dst = *src;
            }
        }
    }

    Ok(())
}
/// Normalizes `slice` in place with a max-shifted softmax.
///
/// The largest element is subtracted from every value before exponentiation
/// so that the exponentials cannot overflow; each exponential is then divided
/// by their sum. An empty slice is left unchanged.
fn softmax_1d(slice: &mut [f32]) {
    // Running maximum, starting below every finite value.
    let mut peak = f32::NEG_INFINITY;
    for &value in slice.iter() {
        peak = f32::max(peak, value);
    }

    // Exponentiate the shifted values while accumulating the normalizer.
    let mut total = 0.0f32;
    slice.iter_mut().for_each(|value| {
        let shifted_exp = (*value - peak).exp();
        *value = shifted_exp;
        total += shifted_exp;
    });

    slice.iter_mut().for_each(|value| *value /= total);
}
/// Returns up to `k` `(index, score)` pairs ordered by descending score.
///
/// Supported inputs are contiguous 1-D tensors and 2-D tensors whose first
/// dimension is `1`. Pairs with equal scores keep ascending index order. NaN
/// scores are ordered after every other score, so they are selected only when
/// fewer than `k` non-NaN scores exist. If `k` exceeds the number of scores,
/// all of them are returned.
///
/// `dim` is accepted for signature compatibility but is not consulted; the
/// scores are always read from the single supported row.
///
/// # Errors
///
/// Returns `AdapterError::InvalidInput` when the tensor is not contiguous or
/// has an unsupported shape.
fn top_k_predictions(
    tensor: &ArrayD<f32>,
    k: usize,
    dim: Option<usize>,
) -> ExecutorResult<Vec<(usize, f32)>> {
    // Deliberately ignored: see the doc comment.
    let _ = dim;

    let shape = tensor.shape();
    let values = tensor.as_slice().ok_or_else(|| {
        AdapterError::InvalidInput("Tensor is not contiguous for TopK".to_string())
    })?;
    let class_scores: &[f32] = if shape.len() == 1 {
        values
    } else if shape.len() == 2 && shape[0] == 1 {
        &values[..shape[1]]
    } else {
        return Err(AdapterError::InvalidInput(format!(
            "TopK only supports 1D or 2D (batch=1) tensors, got shape {:?}",
            shape
        )));
    };

    let mut indexed_scores: Vec<(usize, f32)> = class_scores.iter().copied().enumerate().collect();

    // `sort_by` needs a total order. Falling back to `Equal` whenever
    // `partial_cmp` fails is not one once NaN is present, so NaN is ranked
    // explicitly (last) and only comparable values reach `partial_cmp`.
    indexed_scores.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
        (true, true) => std::cmp::Ordering::Equal,
        (true, false) => std::cmp::Ordering::Greater,
        (false, true) => std::cmp::Ordering::Less,
        (false, false) => b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal),
    });

    indexed_scores.truncate(k);
    Ok(indexed_scores)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::preprocessing::tensor::normalize_step;
    use crate::execution::types::PreprocessedData;

    /// Builds a 1-D tensor holding `values`.
    fn vector_tensor(values: Vec<f32>) -> ArrayD<f32> {
        let len = values.len();
        ArrayD::from_shape_vec(IxDyn(&[len]), values).expect("valid shape")
    }

    /// Wraps `tensor` as a tensor-map output with the single key `"output"`.
    fn as_output_map(tensor: ArrayD<f32>) -> RawOutputs {
        let mut map = HashMap::new();
        map.insert("output".to_string(), tensor);
        RawOutputs::TensorMap(map)
    }

    /// Pulls the first tensor out of a tensor-map result, panicking otherwise.
    fn first_tensor(outputs: RawOutputs) -> ArrayD<f32> {
        match outputs {
            RawOutputs::TensorMap(out_map) => {
                out_map.into_iter().next().map(|(_, tensor)| tensor).unwrap()
            }
            _ => panic!("Expected TensorMap output"),
        }
    }

    /// Normalizing and then denormalizing must reproduce the original values.
    #[test]
    fn test_denormalize_step_round_trip() {
        let original = vec![1.0f32, 2.0, 3.0, 4.0];
        let mean = vec![2.5f32];
        let std_vals = vec![1.0f32];

        let normalized = normalize_step(
            PreprocessedData::Tensor(vector_tensor(original.clone())),
            &mean,
            &std_vals,
        )
        .unwrap();
        let norm_tensor = match normalized {
            PreprocessedData::Tensor(t) => t,
            _ => panic!("Expected Tensor"),
        };

        let out =
            first_tensor(denormalize_step(as_output_map(norm_tensor), &mean, &std_vals).unwrap());
        for (actual, expected) in out.iter().zip(original.iter()) {
            assert!(
                (actual - expected).abs() < 1e-5,
                "round-trip failed: expected {}, got {}",
                expected,
                actual
            );
        }
    }

    /// Statistics are applied per channel, cycling over the flat data.
    #[test]
    fn test_denormalize_step_per_channel() {
        let mean = vec![1.0f32, 2.0, 3.0];
        let std_vals = vec![1.0f32, 1.0, 1.0];
        let normalized = vec![-1.0f32, -1.0, -1.0, 2.0, 2.0, 2.0];
        let expected = [0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];

        let input = as_output_map(vector_tensor(normalized));
        let out = first_tensor(denormalize_step(input, &mean, &std_vals).unwrap());
        for (actual, exp) in out.iter().zip(expected.iter()) {
            assert!(
                (actual - exp).abs() < 1e-5,
                "per-channel failed: expected {}, got {}",
                exp,
                actual
            );
        }
    }

    /// A single mean/std pair is applied to every element.
    #[test]
    fn test_denormalize_step_scalar_broadcast() {
        let mean = vec![0.5f32];
        let std_vals = vec![2.0f32];

        let input = as_output_map(vector_tensor(vec![0.0f32; 4]));
        let out = first_tensor(denormalize_step(input, &mean, &std_vals).unwrap());
        for &val in out.iter() {
            assert!(
                (val - 0.5f32).abs() < 1e-5,
                "scalar broadcast failed: expected 0.5, got {}",
                val
            );
        }
    }

    /// Mean and std slices of different lengths are rejected.
    #[test]
    fn test_denormalize_step_shape_mismatch() {
        let mean = vec![0.5f32, 0.5];
        let std_vals = vec![1.0f32];

        let input = as_output_map(vector_tensor(vec![0.0f32; 4]));
        let result = denormalize_step(input, &mean, &std_vals);
        assert!(result.is_err(), "expected error on shape mismatch");
    }
}