use crate::codegen::{CompiledFunction, JitCompiler};
use crate::ir::{Graph, Node, Op};
use crate::optimize::{OptimizationPass, Optimizer};
use crate::trace::{TracedValue, Tracer, trace};
use crate::{JitError, JitResult};
use std::collections::HashMap;
use std::sync::Mutex;
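/// How much optimization effort to spend when compiling. Selecting
/// `MaxAutotune` additionally enables the elementwise-fusion and
/// algebraic-simplification passes.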
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Mode {
#[default]
Default,
ReduceOverhead,
MaxAutotune,
}
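/// Code-generation backend. `Eager` skips JIT compilation and runs the graph
/// through the interpreter instead.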
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Backend {
#[default]
Default,
Eager,
AOT,
ONNX,
}
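/// Builder-style configuration for compilation: mode, backend, graph-capture
/// flags, and the list of optimization passes to apply.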
#[derive(Debug, Clone)]
pub struct CompileConfig {
pub mode: Mode,
pub backend: Backend,
pub fullgraph: bool,
pub dynamic: bool,
pub disable: bool,
pub passes: Vec<OptimizationPass>,
}
impl Default for CompileConfig {
fn default() -> Self {
Self {
mode: Mode::Default,
backend: Backend::Default,
fullgraph: false,
dynamic: false,
disable: false,
passes: vec![
OptimizationPass::ConstantFolding,
OptimizationPass::DeadCodeElimination,
OptimizationPass::CommonSubexpressionElimination,
],
}
}
}
impl CompileConfig {
pub fn new() -> Self {
Self::default()
}
pub fn mode(mut self, mode: Mode) -> Self {
self.mode = mode;
if mode == Mode::MaxAutotune {
self.passes.push(OptimizationPass::ElementwiseFusion);
self.passes.push(OptimizationPass::AlgebraicSimplification);
}
self
}
pub fn backend(mut self, backend: Backend) -> Self {
self.backend = backend;
self
}
pub fn fullgraph(mut self, fullgraph: bool) -> Self {
self.fullgraph = fullgraph;
self
}
pub fn dynamic(mut self, dynamic: bool) -> Self {
self.dynamic = dynamic;
self
}
pub fn disable(mut self, disable: bool) -> Self {
self.disable = disable;
self
}
pub fn add_pass(mut self, pass: OptimizationPass) -> Self {
self.passes.push(pass);
self
}
}
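/// A traced graph together with its optimized form and, when code generation
/// succeeded, a JIT-compiled function. When no compiled function is available,
/// `run` falls back to a reference interpreter.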
pub struct CompiledModel {
graph: Graph,
optimized_graph: Graph,
compiled_fn: Option<CompiledFunction>,
config: CompileConfig,
input_names: Vec<String>,
output_names: Vec<String>,
}
impl CompiledModel {
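    /// Optimizes `graph` with the configured passes and attempts to JIT-compile
    /// the optimized graph.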
pub fn from_graph(graph: Graph, config: CompileConfig) -> JitResult<Self> {
let mut optimizer = Optimizer::new();
for pass in &config.passes {
optimizer.add_pass(*pass);
}
let optimized_graph = optimizer.optimize(graph.clone());
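        // Compile to native code unless compilation is disabled or the eager
        // backend was requested; codegen failures are swallowed by `.ok()` so
        // the model can still fall back to the interpreter.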
let compiled_fn = if !config.disable && config.backend != Backend::Eager {
let compiler = JitCompiler::new();
compiler.compile(&optimized_graph).ok()
} else {
None
};
let input_names: Vec<String> = graph.inputs().keys().cloned().collect();
let output_names: Vec<String> = graph.outputs().keys().cloned().collect();
Ok(Self {
graph,
optimized_graph,
compiled_fn,
config,
input_names,
output_names,
})
}
pub fn input_names(&self) -> &[String] {
&self.input_names
}
pub fn output_names(&self) -> &[String] {
&self.output_names
}
pub fn graph(&self) -> &Graph {
&self.graph
}
pub fn optimized_graph(&self) -> &Graph {
&self.optimized_graph
}
pub fn is_compiled(&self) -> bool {
self.compiled_fn.is_some()
}
pub fn stats(&self) -> CompileStats {
CompileStats {
original_ops: self.graph.len(),
optimized_ops: self.optimized_graph.len(),
is_compiled: self.compiled_fn.is_some(),
passes_applied: self.config.passes.len(),
}
}
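    /// Runs the model on named inputs, using the JIT-compiled function when
    /// one is available and the interpreter otherwise.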
pub fn run(&self, inputs: &HashMap<String, Vec<f32>>) -> JitResult<HashMap<String, Vec<f32>>> {
for name in &self.input_names {
if !inputs.contains_key(name) {
return Err(JitError::InputNotFound(name.clone()));
}
}
if let Some(ref compiled) = self.compiled_fn {
let input_pairs: Vec<(String, Vec<f32>)> = self
.input_names
.iter()
.map(|name| (name.clone(), inputs[name].clone()))
.collect();
let input_refs: Vec<(&str, &[f32])> = input_pairs
.iter()
.map(|(name, data)| (name.as_str(), data.as_slice()))
.collect();
let flat_result = compiled.run(&input_refs)?;
let mut outputs = HashMap::new();
let mut offset = 0;
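            // The compiled function returns a single flat buffer covering all
            // outputs; split it evenly across the remaining outputs (this
            // assumes every output has the same length).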
for name in &self.output_names {
let remaining = flat_result.len() - offset;
let size = remaining / (self.output_names.len() - outputs.len()).max(1);
outputs.insert(name.clone(), flat_result[offset..offset + size].to_vec());
offset += size;
}
Ok(outputs)
} else {
self.interpret(inputs)
}
}
fn interpret(
&self,
inputs: &HashMap<String, Vec<f32>>,
) -> JitResult<HashMap<String, Vec<f32>>> {
let mut values: HashMap<String, Vec<f32>> = HashMap::new();
for (name, data) in inputs {
values.insert(name.clone(), data.clone());
}
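        // Visit nodes in graph order (assumed topological, so operands are
        // computed before their uses) and key each result by node id.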
for node in self.optimized_graph.nodes() {
let result = self.execute_node(node, &values)?;
let key = format!("node_{}", node.id.index());
values.insert(key, result);
}
let mut outputs = HashMap::new();
for name in &self.output_names {
if let Some(node_id) = self.optimized_graph.output(name) {
let key = format!("node_{}", node_id.index());
if let Some(val) = values.get(&key) {
outputs.insert(name.clone(), val.clone());
}
}
}
Ok(outputs)
}
fn execute_node(&self, node: &Node, values: &HashMap<String, Vec<f32>>) -> JitResult<Vec<f32>> {
match &node.op {
Op::Input { name } => values
.get(name)
.cloned()
.ok_or_else(|| JitError::InputNotFound(name.clone())),
Op::Output { input, .. } => {
let key = format!("node_{}", input.index());
values
.get(&key)
.cloned()
.ok_or(JitError::InputNotFound(key))
}
Op::Constant { value } => Ok(vec![*value as f32]),
Op::Add { lhs, rhs } => {
let a = self.get_node_value(*lhs, values)?;
let b = self.get_node_value(*rhs, values)?;
Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
}
Op::Sub { lhs, rhs } => {
let a = self.get_node_value(*lhs, values)?;
let b = self.get_node_value(*rhs, values)?;
Ok(a.iter().zip(b.iter()).map(|(x, y)| x - y).collect())
}
Op::Mul { lhs, rhs } => {
let a = self.get_node_value(*lhs, values)?;
let b = self.get_node_value(*rhs, values)?;
Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).collect())
}
Op::Div { lhs, rhs } => {
let a = self.get_node_value(*lhs, values)?;
let b = self.get_node_value(*rhs, values)?;
Ok(a.iter().zip(b.iter()).map(|(x, y)| x / y).collect())
}
Op::Neg { input } => {
let a = self.get_node_value(*input, values)?;
Ok(a.iter().map(|x| -x).collect())
}
Op::Exp { input } => {
let a = self.get_node_value(*input, values)?;
Ok(a.iter().map(|x| x.exp()).collect())
}
Op::Log { input } => {
let a = self.get_node_value(*input, values)?;
Ok(a.iter().map(|x| x.ln()).collect())
}
Op::Sqrt { input } => {
let a = self.get_node_value(*input, values)?;
Ok(a.iter().map(|x| x.sqrt()).collect())
}
Op::Relu { input } => {
let a = self.get_node_value(*input, values)?;
Ok(a.iter().map(|x| x.max(0.0)).collect())
}
Op::Sigmoid { input } => {
let a = self.get_node_value(*input, values)?;
Ok(a.iter().map(|x| 1.0 / (1.0 + (-x).exp())).collect())
}
Op::Tanh { input } => {
let a = self.get_node_value(*input, values)?;
Ok(a.iter().map(|x| x.tanh()).collect())
}
            other => Err(JitError::UnsupportedOp(format!(
                "CompiledModel::execute_node: operation {:?} not implemented",
                other
            ))),
}
}
fn get_node_value(
&self,
node_id: crate::ir::NodeId,
values: &HashMap<String, Vec<f32>>,
) -> JitResult<Vec<f32>> {
let node = self.optimized_graph.node(node_id);
if let Op::Input { name } = &node.op {
return values
.get(name)
.cloned()
.ok_or_else(|| JitError::InputNotFound(name.clone()));
}
let key = format!("node_{}", node_id.index());
values
.get(&key)
.cloned()
.ok_or(JitError::InputNotFound(key))
}
}
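/// Summary of a compilation: node counts before and after optimization,
/// whether native code was produced, and how many passes were configured.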
#[derive(Debug, Clone)]
pub struct CompileStats {
pub original_ops: usize,
pub optimized_ops: usize,
pub is_compiled: bool,
pub passes_applied: usize,
}
impl CompileStats {
pub fn optimization_ratio(&self) -> f32 {
if self.original_ops == 0 {
1.0
} else {
self.optimized_ops as f32 / self.original_ops as f32
}
}
}
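/// Compiles a graph with the default configuration.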
pub fn compile_graph(graph: Graph) -> JitResult<CompiledModel> {
CompiledModel::from_graph(graph, CompileConfig::default())
}
pub fn compile_graph_with_config(graph: Graph, config: CompileConfig) -> JitResult<CompiledModel> {
CompiledModel::from_graph(graph, config)
}
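/// Traces `f` with a fresh `Tracer` and compiles the resulting graph with the
/// default configuration.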
pub fn compile_fn<F>(f: F) -> JitResult<CompiledModel>
where
F: FnOnce(&Tracer) -> TracedValue,
{
let graph = trace(f);
compile_graph(graph)
}
pub fn compile_fn_with_config<F>(f: F, config: CompileConfig) -> JitResult<CompiledModel>
where
F: FnOnce(&Tracer) -> TracedValue,
{
let graph = trace(f);
compile_graph_with_config(graph, config)
}
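/// Wraps a traceable closure and defers tracing and compilation until the
/// first call to `run`; the resulting `CompiledModel` is cached behind a
/// mutex and reused by later calls.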
pub struct LazyCompiled<F> {
func: F,
compiled: Mutex<Option<CompiledModel>>,
config: CompileConfig,
}
impl<F> LazyCompiled<F>
where
F: Fn(&Tracer) -> TracedValue,
{
pub fn new(func: F) -> Self {
Self {
func,
compiled: Mutex::new(None),
config: CompileConfig::default(),
}
}
pub fn with_config(func: F, config: CompileConfig) -> Self {
Self {
func,
compiled: Mutex::new(None),
config,
}
}
pub fn run(&self, inputs: &HashMap<String, Vec<f32>>) -> JitResult<HashMap<String, Vec<f32>>> {
let mut compiled = self.compiled.lock().unwrap();
if compiled.is_none() {
let graph = trace(&self.func);
*compiled = Some(CompiledModel::from_graph(graph, self.config.clone())?);
}
compiled.as_ref().unwrap().run(inputs)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_compile_config_default() {
let config = CompileConfig::default();
assert_eq!(config.mode, Mode::Default);
assert!(!config.fullgraph);
assert!(!config.disable);
}
#[test]
fn test_compile_config_builder() {
let config = CompileConfig::new()
.mode(Mode::MaxAutotune)
.fullgraph(true)
.dynamic(true);
assert_eq!(config.mode, Mode::MaxAutotune);
assert!(config.fullgraph);
assert!(config.dynamic);
}
#[test]
fn test_compile_simple_graph() {
let graph = trace(|t| {
let x = t.input("x", &[2]);
let y = x.relu();
t.output("y", y)
});
let compiled = compile_graph(graph).unwrap();
assert!(compiled.input_names().contains(&"x".to_string()));
}
#[test]
fn test_compile_stats() {
let graph = trace(|t| {
let x = t.input("x", &[2]);
let y = x.relu();
t.output("y", y)
});
let compiled = compile_graph(graph).unwrap();
let stats = compiled.stats();
assert!(stats.original_ops > 0);
assert!(stats.passes_applied > 0);
}
#[test]
fn test_mode_enum() {
assert_eq!(Mode::default(), Mode::Default);
assert_ne!(Mode::MaxAutotune, Mode::ReduceOverhead);
}
#[test]
fn test_backend_enum() {
assert_eq!(Backend::default(), Backend::Default);
}
#[test]
fn test_compiled_model_run() {
let graph = trace(|t| {
let x = t.input("x", &[2]);
let y = x.relu();
t.output("y", y)
});
        let config = CompileConfig::new().disable(true);
        let compiled = compile_graph_with_config(graph, config).unwrap();
let mut inputs = HashMap::new();
inputs.insert("x".to_string(), vec![-1.0, 2.0]);
let outputs = compiled.run(&inputs).unwrap();
let y = outputs.get("y").unwrap();
        assert_eq!(y, &vec![0.0, 2.0]);
    }
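    // Two extra smoke tests, kept as minimal sketches: they reuse only the
    // tracer/compile APIs exercised above and force the interpreter path via
    // disable(true) so the expected values do not depend on codegen.
    #[test]
    fn test_compile_fn_runs() {
        let compiled = compile_fn_with_config(
            |t| {
                let x = t.input("x", &[2]);
                t.output("y", x.relu())
            },
            CompileConfig::new().disable(true),
        )
        .unwrap();
        let mut inputs = HashMap::new();
        inputs.insert("x".to_string(), vec![-3.0, 4.0]);
        let outputs = compiled.run(&inputs).unwrap();
        assert_eq!(outputs.get("y").unwrap(), &vec![0.0, 4.0]);
    }
    #[test]
    fn test_lazy_compiled_caches_and_runs() {
        let lazy = LazyCompiled::with_config(
            |t| {
                let x = t.input("x", &[2]);
                t.output("y", x.relu())
            },
            CompileConfig::new().disable(true),
        );
        let mut inputs = HashMap::new();
        inputs.insert("x".to_string(), vec![-1.0, 5.0]);
        // The first call traces and compiles; the second reuses the cached model.
        let first = lazy.run(&inputs).unwrap();
        let second = lazy.run(&inputs).unwrap();
        assert_eq!(first.get("y").unwrap(), &vec![0.0, 5.0]);
        assert_eq!(first, second);
    }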
}