pub mod evaluators;
use crate::onnx::convert::OnnxError;
use crate::protos::onnx::{ModelProto, NodeProto, TensorProto, TensorProto_DataType};
use std::collections::{HashMap, HashSet};
/// Decoded element storage for a constant tensor, one variant per supported
/// ONNX element type. Elements are kept flat; the shape lives alongside
/// (see `ConstantTensor::shape`).
#[derive(Clone, Debug)]
pub enum TensorData {
    /// ONNX `INT64` elements.
    Int64(Vec<i64>),
    /// ONNX `INT32` elements.
    Int32(Vec<i32>),
    /// ONNX `FLOAT` (32-bit) elements.
    Float32(Vec<f32>),
    /// ONNX `DOUBLE` (64-bit float) elements.
    Float64(Vec<f64>),
    /// ONNX `UINT8` elements.
    UInt8(Vec<u8>),
    /// ONNX `INT8` elements.
    Int8(Vec<i8>),
}
impl TensorData {
pub fn len(&self) -> usize {
match self {
TensorData::Int64(v) => v.len(),
TensorData::Int32(v) => v.len(),
TensorData::Float32(v) => v.len(),
TensorData::Float64(v) => v.len(),
TensorData::UInt8(v) => v.len(),
TensorData::Int8(v) => v.len(),
}
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn data_type(&self) -> TensorProto_DataType {
match self {
TensorData::Int64(_) => TensorProto_DataType::Int64,
TensorData::Int32(_) => TensorProto_DataType::Int32,
TensorData::Float32(_) => TensorProto_DataType::Float,
TensorData::Float64(_) => TensorProto_DataType::Double,
TensorData::UInt8(_) => TensorProto_DataType::Uint8,
TensorData::Int8(_) => TensorProto_DataType::Int8,
}
}
pub fn to_bytes(&self) -> Vec<u8> {
match self {
TensorData::Int64(v) => v.iter().flat_map(|&x| x.to_le_bytes()).collect(),
TensorData::Int32(v) => v.iter().flat_map(|&x| x.to_le_bytes()).collect(),
TensorData::Float32(v) => v.iter().flat_map(|&x| x.to_le_bytes()).collect(),
TensorData::Float64(v) => v.iter().flat_map(|&x| x.to_le_bytes()).collect(),
TensorData::UInt8(v) => v.clone(),
TensorData::Int8(v) => v.iter().map(|&x| x as u8).collect(),
}
}
pub fn from_tensor_proto(tensor: &TensorProto) -> Result<Self, OnnxError> {
let raw_data = tensor.raw_data.as_slice();
let data_type = tensor.data_type;
if !raw_data.is_empty() {
match data_type {
x if x == TensorProto_DataType::Int64 as i32 => {
let values = raw_data
.chunks_exact(8)
.map(|c| {
i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
})
.collect();
Ok(TensorData::Int64(values))
}
x if x == TensorProto_DataType::Int32 as i32 => {
let values = raw_data
.chunks_exact(4)
.map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect();
Ok(TensorData::Int32(values))
}
x if x == TensorProto_DataType::Float as i32 => {
let values = raw_data
.chunks_exact(4)
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect();
Ok(TensorData::Float32(values))
}
x if x == TensorProto_DataType::Double as i32 => {
let values = raw_data
.chunks_exact(8)
.map(|c| {
f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
})
.collect();
Ok(TensorData::Float64(values))
}
x if x == TensorProto_DataType::Uint8 as i32 => {
Ok(TensorData::UInt8(raw_data.to_vec()))
}
x if x == TensorProto_DataType::Int8 as i32 => Ok(TensorData::Int8(
raw_data.iter().map(|&x| x as i8).collect(),
)),
_ => Err(OnnxError::TypeConversion(
webnn_onnx_utils::error::ConversionError::UnsupportedOnnxDataType(data_type),
)),
}
} else {
match data_type {
x if x == TensorProto_DataType::Int64 as i32 => {
Ok(TensorData::Int64(tensor.int64_data.as_slice().to_vec()))
}
x if x == TensorProto_DataType::Int32 as i32 => {
Ok(TensorData::Int32(tensor.int32_data.as_slice().to_vec()))
}
x if x == TensorProto_DataType::Float as i32 => {
Ok(TensorData::Float32(tensor.float_data.as_slice().to_vec()))
}
x if x == TensorProto_DataType::Double as i32 => {
Ok(TensorData::Float64(tensor.double_data.as_slice().to_vec()))
}
_ => Err(OnnxError::TypeConversion(
webnn_onnx_utils::error::ConversionError::UnsupportedOnnxDataType(data_type),
)),
}
}
}
}
/// A constant tensor known at conversion time: decoded element data plus the
/// ONNX dimensions and the raw ONNX `data_type` tag it was read from (or will be
/// written with).
#[derive(Clone, Debug)]
pub struct ConstantTensor {
    /// Flat element storage.
    pub data: TensorData,
    /// ONNX dimensions; an empty vector denotes a scalar (see `numel`).
    pub shape: Vec<i64>,
    /// ONNX element-type tag, stored independently of `data`'s variant and
    /// copied verbatim into `TensorProto::data_type` on export.
    pub data_type: i32,
}
impl ConstantTensor {
    /// Builds a constant from an ONNX `TensorProto`, copying its decoded data,
    /// its `dims` and its `data_type` tag.
    ///
    /// # Errors
    /// Propagates the error from `TensorData::from_tensor_proto` for unsupported
    /// element types.
    pub fn from_tensor_proto(tensor: &TensorProto) -> Result<Self, OnnxError> {
        Ok(Self {
            data: TensorData::from_tensor_proto(tensor)?,
            shape: tensor.dims.as_slice().to_vec(),
            data_type: tensor.data_type,
        })
    }

    /// Encodes this constant as a `TensorProto` called `name`, with the elements
    /// serialized into `raw_data`; every other field keeps its default value.
    pub fn to_tensor_proto(&self, name: &str) -> TensorProto {
        let raw_data = self.data.to_bytes();
        TensorProto {
            name: String::from(name),
            data_type: self.data_type,
            dims: self.shape.clone(),
            raw_data,
            ..Default::default()
        }
    }

    /// Element count implied by `shape`. A scalar (empty shape) yields 1 because
    /// the empty product is 1. Dimensions are not validated.
    pub fn numel(&self) -> i64 {
        self.shape.iter().product()
    }
}
/// State shared by constant evaluators during one folding pass.
#[derive(Debug)]
pub struct ConstantFoldingContext<'a> {
    /// Values currently known to be constant, keyed by tensor name. Seeded from
    /// decodable initializers and extended as nodes are folded.
    pub constants: HashMap<String, ConstantTensor>,
    /// The graph's initializers, keyed by name, as passed to `new`.
    pub initializers: &'a HashMap<String, &'a TensorProto>,
}
impl<'a> ConstantFoldingContext<'a> {
    /// Creates a context whose constants are every initializer carrying inline
    /// data (`raw_data` or one of the typed repeated fields) that decodes
    /// successfully.
    ///
    /// Initializers with no inline data are skipped. Decoding failures are logged
    /// and skipped. Currently this never returns `Err`.
    pub fn new(initializers: &'a HashMap<String, &'a TensorProto>) -> Result<Self, OnnxError> {
        let mut constants = HashMap::new();
        for (name, tensor) in initializers {
            let has_inline_data = !(tensor.raw_data.as_slice().is_empty()
                && tensor.int64_data.as_slice().is_empty()
                && tensor.int32_data.as_slice().is_empty()
                && tensor.float_data.as_slice().is_empty()
                && tensor.double_data.as_slice().is_empty());
            if !has_inline_data {
                continue;
            }
            match ConstantTensor::from_tensor_proto(tensor) {
                Ok(decoded) => {
                    constants.insert(name.clone(), decoded);
                }
                Err(e) => {
                    crate::debug_println!(
                        "Warning: Failed to parse initializer '{}': {}",
                        name,
                        e
                    );
                }
            }
        }
        Ok(Self {
            constants,
            initializers,
        })
    }

    /// `true` if `name` currently has a known constant value.
    pub fn is_constant(&self, name: &str) -> bool {
        self.get_constant(name).is_some()
    }

    /// The known constant value for `name`, if any.
    pub fn get_constant(&self, name: &str) -> Option<&ConstantTensor> {
        self.constants.get(name)
    }

    /// Records (or overwrites) the constant value for `name`.
    pub fn add_constant(&mut self, name: String, tensor: ConstantTensor) {
        let _previous = self.constants.insert(name, tensor);
    }
}
/// Outcome of evaluating one batch of constant nodes.
#[derive(Default, Debug)]
pub struct FoldingResult {
    /// Initializers holding the computed outputs of folded nodes.
    pub new_initializers: Vec<TensorProto>,
    /// Indices (into the graph's node list) of nodes that were folded.
    pub nodes_to_remove: HashSet<usize>,
    /// Number of nodes folded in this batch.
    pub nodes_folded: usize,
}
/// Computes the outputs of ONNX nodes whose inputs are known constants, so the
/// node can be replaced by initializers.
pub trait ConstantEvaluator {
    /// Operator name this evaluator handles. Presumably it matches
    /// `NodeProto::op_type`. NOTE(review): not consulted by the folding driver
    /// in this file; `can_evaluate` alone selects the evaluator.
    fn op_type(&self) -> &str;
    /// Returns `true` if `node` can be folded given the constants currently in
    /// `ctx`. The driver uses the first evaluator that returns `true`.
    fn can_evaluate(&self, node: &NodeProto, ctx: &ConstantFoldingContext) -> bool;
    /// Computes the node's outputs. The i-th returned tensor is bound to the
    /// node's i-th output name. An `Err` is logged by the driver and the node is
    /// left in place.
    fn evaluate(
        &self,
        node: &NodeProto,
        ctx: &ConstantFoldingContext,
    ) -> Result<Vec<ConstantTensor>, OnnxError>;
}
/// Builds the folding context for one iteration from the graph's initializers.
///
/// `_model` is currently unused; it is kept so callers need not change if the
/// context later needs more of the model.
fn build_context<'a>(
    _model: &'a ModelProto,
    initializers_map: &'a HashMap<String, &'a TensorProto>,
) -> Result<ConstantFoldingContext<'a>, OnnxError> {
    let context = ConstantFoldingContext::new(initializers_map)?;
    Ok(context)
}
/// Returns, in graph order, the indices of nodes that at least one evaluator can
/// fold given the constants in `ctx`.
///
/// # Panics
/// Panics if `model` has no graph.
fn identify_constant_nodes(
    model: &ModelProto,
    ctx: &ConstantFoldingContext,
    evaluators: &[Box<dyn ConstantEvaluator>],
) -> Result<Vec<usize>, OnnxError> {
    let graph = model.graph.as_ref().unwrap();
    let foldable: Vec<usize> = graph
        .node
        .as_slice()
        .iter()
        .enumerate()
        .filter(|(_, node)| evaluators.iter().any(|ev| ev.can_evaluate(node, ctx)))
        .map(|(idx, _)| idx)
        .collect();
    Ok(foldable)
}
/// Evaluates the nodes at `constant_node_indices` and collects their outputs.
///
/// For each node, the first evaluator whose `can_evaluate` returns `true` is run.
/// A node is marked for removal only when every non-empty output name received a
/// tensor. Otherwise removing it would leave downstream consumers referring to
/// outputs nobody produces. Empty output names (omitted optional outputs) are
/// skipped rather than turned into nameless initializers.
///
/// Each produced tensor becomes a new initializer and is added to `ctx`, so later
/// nodes in the same batch can see it. Evaluator errors are logged and the node
/// is left in place.
///
/// # Panics
/// Panics if `model` has no graph, or if an index is out of range.
fn evaluate_constant_nodes(
    model: &ModelProto,
    constant_node_indices: &[usize],
    ctx: &mut ConstantFoldingContext,
    evaluators: &[Box<dyn ConstantEvaluator>],
) -> Result<FoldingResult, OnnxError> {
    let graph = model.graph.as_ref().unwrap();
    let nodes = graph.node.as_slice();
    let mut result = FoldingResult::default();
    for &idx in constant_node_indices {
        let node = &nodes[idx];
        let evaluator = match evaluators.iter().find(|e| e.can_evaluate(node, ctx)) {
            Some(evaluator) => evaluator,
            None => continue,
        };
        let output_tensors = match evaluator.evaluate(node, ctx) {
            Ok(tensors) => tensors,
            Err(e) => {
                crate::debug_println!(
                    "Warning: Failed to evaluate constant node '{}' ({}): {}",
                    node.name.as_str(),
                    node.op_type.as_str(),
                    e
                );
                continue;
            }
        };

        let output_names = node.output.as_slice();
        let has_unproduced_output = output_names
            .iter()
            .enumerate()
            .any(|(i, name)| !name.is_empty() && i >= output_tensors.len());
        if has_unproduced_output {
            crate::debug_println!(
                "Warning: Constant node '{}' ({}) produced {} of {} outputs; not folding",
                node.name.as_str(),
                node.op_type.as_str(),
                output_tensors.len(),
                output_names.len()
            );
            continue;
        }

        // Extra tensors beyond the declared outputs are ignored by `zip`.
        for (name, tensor) in output_names.iter().zip(output_tensors) {
            if name.is_empty() {
                continue;
            }
            result.new_initializers.push(tensor.to_tensor_proto(name));
            ctx.add_constant(name.to_string(), tensor);
        }
        result.nodes_to_remove.insert(idx);
        result.nodes_folded += 1;
    }
    Ok(result)
}
/// Repeatedly folds constant nodes of `model`'s graph into initializers.
///
/// Each iteration does the following:
/// 1. Rebuilds the constant context from the graph's current initializers.
/// 2. Finds the nodes some evaluator can fold.
/// 3. Evaluates them.
/// 4. Appends their outputs as initializers and removes the folded nodes.
///
/// Iteration stops when nothing more can be folded, or after `MAX_ITERATIONS`
/// passes. Dependency chains deeper than that may be left partially folded.
///
/// Returns the total number of nodes folded. A model without a graph folds
/// nothing and returns `Ok(0)`.
///
/// # Errors
/// Propagates errors from building the context or from the identify/evaluate
/// passes. Individual evaluator failures are logged, not returned.
pub fn fold_constants_in_model(
    model: &mut ModelProto,
    evaluators: &[Box<dyn ConstantEvaluator>],
) -> Result<usize, OnnxError> {
    const MAX_ITERATIONS: usize = 10;

    if model.graph.as_ref().is_none() {
        return Ok(0);
    }

    let mut total_folded = 0;
    for iteration in 0..MAX_ITERATIONS {
        // Scope every immutable borrow of `model` (initializer map, context) so
        // the graph can be mutated once the batch has been evaluated.
        let result = {
            let initializers_map: HashMap<String, &TensorProto> = model
                .graph
                .as_ref()
                .expect("graph presence checked above")
                .initializer
                .as_slice()
                .iter()
                .map(|init| (init.name.as_str().to_string(), init))
                .collect();
            let mut ctx = build_context(model, &initializers_map)?;
            let constant_nodes = identify_constant_nodes(model, &ctx, evaluators)?;
            if constant_nodes.is_empty() {
                break;
            }
            evaluate_constant_nodes(model, &constant_nodes, &mut ctx, evaluators)?
        };
        if result.nodes_folded == 0 {
            break;
        }

        let graph_mut = model.graph.as_mut().expect("graph presence checked above");
        for init in result.new_initializers {
            graph_mut.initializer.push(init);
        }
        // Rebuild the node list without the folded nodes, preserving graph order.
        let nodes = graph_mut.node.as_slice().to_vec();
        graph_mut.node.clear();
        for (idx, node) in nodes.into_iter().enumerate() {
            if !result.nodes_to_remove.contains(&idx) {
                graph_mut.node.push(node);
            }
        }

        total_folded += result.nodes_folded;
        crate::debug_println!(
            "Constant folding iteration {}: {} nodes folded",
            iteration + 1,
            result.nodes_folded
        );
    }
    Ok(total_folded)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_data_len() {
        let ints = TensorData::Int64(vec![1, 2, 3]);
        let floats = TensorData::Float32(vec![1.0, 2.0]);
        assert_eq!(ints.len(), 3);
        assert_eq!(floats.len(), 2);
    }

    #[test]
    fn test_tensor_data_to_bytes() {
        // Each i32 serializes to 4 bytes and each i64 to 8 bytes.
        let cases = [
            (TensorData::Int32(vec![1, 2, 3]), 12usize),
            (TensorData::Int64(vec![1, 2]), 16usize),
        ];
        for (data, expected_len) in cases.iter() {
            assert_eq!(data.to_bytes().len(), *expected_len);
        }
    }

    #[test]
    fn test_constant_tensor_numel() {
        let int64_tag = TensorProto_DataType::Int64 as i32;

        let matrix = ConstantTensor {
            data: TensorData::Int64((1..=6).collect()),
            shape: vec![2, 3],
            data_type: int64_tag,
        };
        assert_eq!(matrix.numel(), 6);

        let scalar = ConstantTensor {
            data: TensorData::Int64(vec![42]),
            shape: Vec::new(),
            data_type: int64_tag,
        };
        assert_eq!(scalar.numel(), 1);
    }
}