use std::{
env,
fs::{self, create_dir_all},
path::{Path, PathBuf},
};
use crate::{burn::graph::BurnGraph, format_tokens, logger::init_log};
use onnx_ir::{OnnxGraphBuilder, ir::OnnxGraph};
/// How the generated model obtains its weights.
///
/// The chosen value is forwarded, together with the path of the generated
/// `.bpk` file, to the Burn code generator (`BurnGraph::with_burnpack`), which
/// defines what each variant actually does. The per-variant notes below are
/// inferred from the names only — confirm against `with_burnpack`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum LoadStrategy {
    /// The default strategy; presumably loads weights from the `.bpk` file at runtime.
    #[default]
    File,
    /// Presumably embeds the weights into the compiled artifact.
    Embedded,
    /// Presumably loads weights from a caller-supplied byte buffer.
    Bytes,
    /// Presumably emits no weight-loading code at all.
    None,
}
/// Builder-style configuration for converting ONNX model files into Burn source code.
///
/// Configure it through the chained setter methods, then call `run_from_script`
/// (from a Cargo build script) or `run_from_cli` to perform the conversion.
#[derive(Debug)]
pub struct ModelGen {
    /// Directory the generated files are written to; in build-script mode it is
    /// joined onto Cargo's `$OUT_DIR`. Must be set before running.
    out_dir: Option<PathBuf>,
    /// ONNX files to convert, in the order they were added.
    inputs: Vec<PathBuf>,
    /// When `true`, a pretty-printed `Debug` dump of each parsed ONNX graph is
    /// written alongside the generated code.
    development: bool,
    /// Weight-loading strategy forwarded to the Burn code generator.
    load_strategy: LoadStrategy,
    /// Whether `OnnxGraphBuilder` simplification is applied while parsing.
    simplify: bool,
    /// Whether partitioning is enabled on the generated Burn graph.
    partition: bool,
}
impl Default for ModelGen {
    /// Returns a generator with no output directory, no inputs, development
    /// mode off, the default [`LoadStrategy`], and both graph simplification
    /// and partitioning enabled.
    fn default() -> Self {
        let load_strategy = LoadStrategy::default();
        Self {
            inputs: Vec::new(),
            out_dir: None,
            load_strategy,
            development: false,
            partition: true,
            simplify: true,
        }
    }
}
impl ModelGen {
    /// Creates a generator with the default settings (see the `Default` impl).
    ///
    /// Also calls `init_log()`. Its result is deliberately discarded, so a
    /// failure to initialize logging never aborts code generation (presumably
    /// this tolerates a logger that is already installed — confirm).
    pub fn new() -> Self {
        init_log().ok();
        Self::default()
    }

    /// Sets the directory the generated files are written to.
    ///
    /// From a build script ([`ModelGen::run_from_script`]) the path is joined
    /// onto Cargo's `OUT_DIR`. From the CLI ([`ModelGen::run_from_cli`]) it is
    /// used as given.
    pub fn out_dir(&mut self, out_dir: &str) -> &mut Self {
        self.out_dir = Some(Path::new(out_dir).into());
        self
    }

    /// Adds an ONNX model file to convert. Call once per model.
    pub fn input(&mut self, input: &str) -> &mut Self {
        self.inputs.push(input.into());
        self
    }

    /// Enables or disables development mode.
    ///
    /// In development mode a pretty-printed `Debug` dump of each parsed ONNX
    /// graph is additionally written as `<name>.onnx.txt`.
    pub fn development(&mut self, development: bool) -> &mut Self {
        self.development = development;
        self
    }

    /// Sets the weight-loading strategy forwarded to the Burn code generator
    /// together with each model's `.bpk` path.
    pub fn load_strategy(&mut self, strategy: LoadStrategy) -> &mut Self {
        self.load_strategy = strategy;
        self
    }

    /// Controls whether `OnnxGraphBuilder` simplification is applied while
    /// parsing (enabled by default).
    pub fn simplify(&mut self, simplify: bool) -> &mut Self {
        self.simplify = simplify;
        self
    }

    /// Controls whether partitioning is enabled on the generated Burn graph
    /// (enabled by default).
    pub fn partition(&mut self, partition: bool) -> &mut Self {
        self.partition = partition;
        self
    }

    /// Converts every input. Intended to be called from a Cargo build script.
    ///
    /// Output is written under `$OUT_DIR/<out_dir>`.
    ///
    /// # Panics
    ///
    /// Panics if `OUT_DIR` or `out_dir` is not set, or if any conversion step
    /// fails (directory creation, ONNX parsing, file writes).
    pub fn run_from_script(&self) {
        self.run(true);
    }

    /// Converts every input. Intended to be called from a command-line tool.
    ///
    /// Output is written to `out_dir` exactly as configured.
    ///
    /// # Panics
    ///
    /// Panics if `out_dir` is not set, or if any conversion step fails
    /// (directory creation, ONNX parsing, file writes).
    pub fn run_from_cli(&self) {
        self.run(false);
    }

    /// Shared driver.
    ///
    /// Resolves and creates the output directory, then generates one set of
    /// artifacts per input. Each set is named after the input's file stem.
    fn run(&self, is_build_script: bool) {
        log::info!("Starting to convert ONNX to Burn");

        let out_dir = self.get_output_directory(is_build_script);
        log::debug!("Output directory: {out_dir:?}");

        create_dir_all(&out_dir).unwrap_or_else(|e| {
            panic!("Failed to create output directory '{}': {}", out_dir.display(), e)
        });

        for input in &self.inputs {
            // `file_stem` is `None` only for paths without a file name (e.g. `..` or `/`).
            let file_name = input
                .file_stem()
                .unwrap_or_else(|| panic!("Input path '{}' has no file name", input.display()));
            // Extension-less base path shared by every artifact generated for this input.
            let out_file: PathBuf = out_dir.join(file_name);

            log::info!("Converting {input:?}");
            log::debug!("Input file name: {file_name:?}");
            log::debug!("Output file: {out_file:?}");

            self.generate_model(input, out_file);
        }

        log::info!("Finished converting ONNX to Burn");
    }

    /// Resolves the directory generated files are written to.
    ///
    /// In build-script mode the configured `out_dir` is joined onto `$OUT_DIR`.
    /// As with `Path::join`, an absolute `out_dir` replaces `$OUT_DIR` entirely.
    ///
    /// # Panics
    ///
    /// Panics if `OUT_DIR` is unset in build-script mode, or if `out_dir` was
    /// never configured.
    fn get_output_directory(&self, is_build_script: bool) -> PathBuf {
        if is_build_script {
            let cargo_out_dir = env::var("OUT_DIR").expect("OUT_DIR env is not set");
            // Same message as the CLI branch, so a missing `out_dir` is diagnosable.
            let out_dir = self.out_dir.as_ref().expect("out_dir is not set");
            PathBuf::from(cargo_out_dir).join(out_dir)
        } else {
            self.out_dir.as_ref().expect("out_dir is not set").clone()
        }
    }

    /// Generates code for a single model.
    ///
    /// Parses `input`, optionally dumps the parsed graph, generates the Burn
    /// code, and writes it to `<out_file>.rs`.
    ///
    /// `out_file` is the extension-less base path shared by every artifact
    /// for this model: `.rs`, `.bpk`, and `.onnx.txt` in development mode.
    ///
    /// # Panics
    ///
    /// Panics if the ONNX file cannot be parsed or an output file cannot be written.
    fn generate_model(&self, input: &Path, out_file: PathBuf) {
        log::info!("Generating model from {input:?}");
        log::debug!("Development mode: {:?}", self.development);
        log::debug!("Output file: {out_file:?}");

        let graph = OnnxGraphBuilder::new()
            .simplify(self.simplify)
            .parse_file(input)
            .unwrap_or_else(|e| panic!("Failed to parse ONNX file '{}': {}", input.display(), e));

        if self.development {
            self.write_debug_file(&out_file, "onnx.txt", &graph);
        }

        let graph = ParsedOnnxGraph(graph);
        let top_comment = Some(format!("Generated from ONNX {input:?} by burn-onnx"));
        let code = self.generate_burn_graph(graph, &out_file, top_comment);
        let code_str = format_tokens(code);

        let source_code_file = Self::with_appended_extension(&out_file, "rs");
        log::info!("Writing source code to {}", source_code_file.display());
        fs::write(&source_code_file, code_str).unwrap_or_else(|e| {
            panic!("Failed to write source file '{}': {}", source_code_file.display(), e)
        });
        log::info!("Model generated");
    }

    /// Writes `content`'s pretty-printed (`{:#?}`) `Debug` representation to
    /// `<out_file>.<extension>`.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be written.
    fn write_debug_file<T: std::fmt::Debug>(&self, out_file: &Path, extension: &str, content: &T) {
        let debug_content = format!("{content:#?}");
        let debug_file = Self::with_appended_extension(out_file, extension);
        log::debug!("Writing debug file: {debug_file:?}");
        fs::write(&debug_file, debug_content).unwrap_or_else(|e| {
            panic!("Failed to write debug file '{}': {}", debug_file.display(), e)
        });
    }

    /// Converts `graph` into a Burn graph and returns the generated token stream.
    ///
    /// The weights path `<out_file>.bpk` and the configured load strategy,
    /// top comment and partition flag are handed to the code generator.
    fn generate_burn_graph(
        &self,
        graph: ParsedOnnxGraph,
        out_file: &Path,
        top_comment: Option<String>,
    ) -> proc_macro2::TokenStream {
        let bpk_file = Self::with_appended_extension(out_file, "bpk");
        graph
            .into_burn()
            .with_burnpack(bpk_file, self.load_strategy)
            .with_blank_space(true)
            .with_top_comment(top_comment)
            .with_partition(self.partition)
            .codegen()
    }

    /// Returns `path` with `.<extension>` appended to its final component.
    ///
    /// `Path::with_extension` would instead *replace* the part after the last
    /// `.` of a dotted stem. For example, `model.v2.onnx` has stem `model.v2`,
    /// and `with_extension` would produce `model.rs`/`model.bpk`, silently
    /// colliding with the output of `model.v1.onnx`.
    fn with_appended_extension(path: &Path, extension: &str) -> PathBuf {
        let mut file_name = path.as_os_str().to_os_string();
        file_name.push(".");
        file_name.push(extension);
        PathBuf::from(file_name)
    }
}
#[derive(Debug)]
struct ParsedOnnxGraph(OnnxGraph);
impl ParsedOnnxGraph {
pub fn into_burn(self) -> BurnGraph {
let mut graph = BurnGraph::default();
for node in self.0.nodes {
graph.register(node);
}
let input_names: Vec<_> = self
.0
.inputs
.iter()
.map(|input| input.name.clone())
.collect();
let output_names: Vec<_> = self
.0
.outputs
.iter()
.map(|output| output.name.clone())
.collect();
graph.register_input_output(input_names, output_names, &self.0.inputs, &self.0.outputs);
graph
}
}