use crate::ast::{ConstDecl, Node};
use crate::onnx::convert::OnnxError;
use onnx::onnx::{NodeProto, TensorProto};
use std::collections::HashMap;
pub mod activation;
pub mod conversion;
pub mod elementwise;
pub mod matmul;
pub mod normalization;
pub mod reduction;
pub mod reshape;
pub mod utility;
use activation::ActivationHandler;
use conversion::ConversionHandler;
use elementwise::ElementwiseHandler;
use matmul::MatMulHandler;
use normalization::NormalizationHandler;
use reduction::ReductionHandler;
use reshape::ReshapeHandler;
use utility::UtilityHandler;
/// Borrowed lookup tables handed to every [`OpHandler`] while nodes are converted.
///
/// The context owns nothing; every table is borrowed for `'a`.
pub struct ConversionContext<'a> {
    /// Initializer tensors (presumably keyed by tensor name — confirm with the builder).
    pub initializers: &'a HashMap<String, &'a TensorProto>,
    /// Shape per value (presumably keyed by ONNX value name).
    pub value_shapes: &'a HashMap<String, Vec<i64>>,
    /// Integer contents per value; NOTE(review): looks like values whose data is known at
    /// conversion time — confirm against the code that fills it.
    pub const_values: &'a HashMap<String, Vec<i64>>,
    /// Maps a value name (raw or sanitized) to the identifier to emit for it;
    /// consulted by [`ConversionContext::resolve_input`].
    pub value_ids: &'a HashMap<String, String>,
    /// Data type per value (presumably keyed by ONNX value name).
    pub value_types: &'a HashMap<String, crate::ast::DataType>,
}

impl<'a> ConversionContext<'a> {
    /// Resolves the ONNX value `name` to the identifier that should be emitted for it.
    ///
    /// Resolution order:
    /// 1. `name` exactly as given, looked up in `value_ids`;
    /// 2. the sanitized form of `name`, looked up in `value_ids`;
    /// 3. otherwise the sanitized form itself.
    pub fn resolve_input(&self, name: &str) -> String {
        match self.value_ids.get(name) {
            Some(id) => id.clone(),
            None => {
                let fallback = crate::onnx::convert::sanitize_identifier(name);
                self.value_ids.get(&fallback).cloned().unwrap_or(fallback)
            }
        }
    }
}
/// Everything a handler produced while converting a single ONNX node.
pub struct ConversionResult {
    /// AST nodes emitted for the converted ONNX node.
    pub nodes: Vec<Node>,
    /// Constant declarations emitted alongside `nodes`
    /// (presumably `(name, declaration)` pairs — confirm with the handlers).
    pub consts: Vec<(String, ConstDecl)>,
    /// String-to-string renames reported by the handler; NOTE(review): presumably
    /// ONNX output name → emitted identifier — confirm with the consumer.
    pub output_mappings: HashMap<String, String>,
    /// Data types reported by the handler, keyed by a value name.
    pub output_types: HashMap<String, crate::ast::DataType>,
}

impl ConversionResult {
    /// Builds a result that carries only `nodes`.
    ///
    /// `consts`, `output_mappings` and `output_types` all start empty.
    pub fn new(nodes: Vec<Node>) -> Self {
        let consts = Vec::new();
        let output_mappings = HashMap::new();
        let output_types = HashMap::new();
        ConversionResult {
            nodes,
            consts,
            output_mappings,
            output_types,
        }
    }
}
/// Converter for one family of ONNX operator types.
///
/// [`OpRegistry::convert_node`] asks the registered handlers, in registration order, whether
/// they [`supports`](OpHandler::supports) a node's `op_type`. The node is handed to
/// [`convert`](OpHandler::convert) on the first handler that answers `true`.
pub trait OpHandler {
    /// Returns `true` if this handler can convert nodes whose ONNX op type is `op_type`.
    fn supports(&self, op_type: &str) -> bool;
    /// Converts `node` into a [`ConversionResult`].
    ///
    /// `context` exposes the graph-wide lookup tables (initializers, shapes, constant values,
    /// identifiers, types) that an implementation may consult.
    ///
    /// # Errors
    /// Returns an [`OnnxError`] when the node cannot be converted. The exact conditions are
    /// defined by each implementation.
    fn convert<'a>(
        &self,
        node: &NodeProto,
        context: &ConversionContext<'a>,
    ) -> Result<ConversionResult, OnnxError>;
}
/// Dispatches ONNX nodes to the first registered [`OpHandler`] that supports their op type.
pub struct OpRegistry {
    /// Handlers in priority order: when several support the same op type, the earliest wins.
    handlers: Vec<Box<dyn OpHandler>>,
}

impl OpRegistry {
    /// Builds a registry containing every built-in handler.
    ///
    /// The order of this list is significant. If two handlers claim the same op type,
    /// the one listed first converts it.
    pub fn new() -> Self {
        let handlers: Vec<Box<dyn OpHandler>> = vec![
            Box::new(MatMulHandler),
            Box::new(ElementwiseHandler),
            Box::new(NormalizationHandler),
            Box::new(ReshapeHandler),
            Box::new(ConversionHandler),
            Box::new(UtilityHandler),
            Box::new(ReductionHandler),
            Box::new(ActivationHandler),
        ];
        OpRegistry { handlers }
    }

    /// Converts `node` using the first handler whose [`OpHandler::supports`] accepts its
    /// op type.
    ///
    /// # Errors
    /// - Any error returned by the selected handler is propagated unchanged.
    /// - If no handler supports the op type, returns [`OnnxError::UnsupportedOp`] with the
    ///   op type and the node's name. `"<unnamed>"` is used when the node has no name or
    ///   an empty one.
    pub fn convert_node(
        &self,
        node: &NodeProto,
        context: &ConversionContext<'_>,
    ) -> Result<ConversionResult, OnnxError> {
        let op_type = node.get_op_type();

        if let Some(handler) = self.handlers.iter().find(|h| h.supports(op_type)) {
            return handler.convert(node, context);
        }

        // A name that is set but empty identifies the node no better than a missing one,
        // so both get the placeholder to keep the error message useful.
        let node_name = if node.has_name() && !node.get_name().is_empty() {
            node.get_name().to_string()
        } else {
            "<unnamed>".to_string()
        };
        Err(OnnxError::UnsupportedOp {
            op: op_type.to_string(),
            node: node_name,
        })
    }
}

impl Default for OpRegistry {
    /// Same as [`OpRegistry::new`]: a registry with every built-in handler.
    fn default() -> Self {
        Self::new()
    }
}