1use crate::codegen::{CompiledFunction, JitCompiler};
26use crate::ir::{Graph, Node, Op};
27use crate::optimize::{OptimizationPass, Optimizer};
28use crate::trace::{trace, TracedValue, Tracer};
29use crate::{JitError, JitResult};
30use std::collections::HashMap;
31use std::sync::Mutex;
32
/// Compilation mode selecting how aggressively the optimizer works.
///
/// Naming mirrors `torch.compile`'s `mode` argument — presumably intentional.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Mode {
    /// Balanced compilation using only the default optimization passes.
    #[default]
    Default,
    /// Minimize per-call overhead (stored; not consulted in this module).
    ReduceOverhead,
    /// Spend extra effort on optimization: selecting this via
    /// `CompileConfig::mode` appends fusion and algebraic-simplification
    /// passes to the pass list.
    MaxAutotune,
}
53
/// Execution backend used when a graph is compiled.
///
/// Only `Eager` is distinguished by `CompiledModel::from_graph` in this
/// module: it skips JIT compilation. `AOT` and `ONNX` currently behave like
/// `Default` here — TODO confirm whether they are handled elsewhere.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Backend {
    /// JIT-compile via the default `JitCompiler`.
    #[default]
    Default,
    /// Skip JIT compilation; graphs are interpreted.
    Eager,
    /// Ahead-of-time compilation (not differentiated in this module).
    AOT,
    /// ONNX export backend (not differentiated in this module).
    ONNX,
}
72
/// Configuration controlling how a graph is optimized and compiled.
///
/// The option names mirror `torch.compile`'s keyword arguments
/// (`mode`, `fullgraph`, `dynamic`, `disable`) — presumably intentional.
#[derive(Debug, Clone)]
pub struct CompileConfig {
    /// Optimization aggressiveness (see [`Mode`]).
    pub mode: Mode,
    /// Execution backend (see [`Backend`]).
    pub backend: Backend,
    /// Whether the whole function must be captured as one graph.
    /// NOTE(review): stored but not consulted anywhere in this module —
    /// confirm it is read by the tracer/compiler elsewhere.
    pub fullgraph: bool,
    /// Whether dynamic shapes are permitted.
    /// NOTE(review): stored but not consulted anywhere in this module.
    pub dynamic: bool,
    /// When true, JIT compilation is skipped and graphs are interpreted.
    pub disable: bool,
    /// Optimization passes applied, in order, before compilation.
    pub passes: Vec<OptimizationPass>,
}
93
94impl Default for CompileConfig {
95 fn default() -> Self {
96 Self {
97 mode: Mode::Default,
98 backend: Backend::Default,
99 fullgraph: false,
100 dynamic: false,
101 disable: false,
102 passes: vec![
103 OptimizationPass::ConstantFolding,
104 OptimizationPass::DeadCodeElimination,
105 OptimizationPass::CommonSubexpressionElimination,
106 ],
107 }
108 }
109}
110
111impl CompileConfig {
112 pub fn new() -> Self {
114 Self::default()
115 }
116
117 pub fn mode(mut self, mode: Mode) -> Self {
119 self.mode = mode;
120 if mode == Mode::MaxAutotune {
121 self.passes.push(OptimizationPass::ElementwiseFusion);
123 self.passes.push(OptimizationPass::AlgebraicSimplification);
124 }
125 self
126 }
127
128 pub fn backend(mut self, backend: Backend) -> Self {
130 self.backend = backend;
131 self
132 }
133
134 pub fn fullgraph(mut self, fullgraph: bool) -> Self {
136 self.fullgraph = fullgraph;
137 self
138 }
139
140 pub fn dynamic(mut self, dynamic: bool) -> Self {
142 self.dynamic = dynamic;
143 self
144 }
145
146 pub fn disable(mut self, disable: bool) -> Self {
148 self.disable = disable;
149 self
150 }
151
152 pub fn add_pass(mut self, pass: OptimizationPass) -> Self {
154 self.passes.push(pass);
155 self
156 }
157}
158
/// A traced graph that has been optimized and, when possible, JIT-compiled.
pub struct CompiledModel {
    /// The original graph as captured by tracing, before optimization.
    graph: Graph,
    /// The graph after the configured optimization passes were applied.
    optimized_graph: Graph,
    /// Native code produced by the JIT backend; `None` when compilation was
    /// disabled, the backend is `Eager`, or compilation failed.
    compiled_fn: Option<CompiledFunction>,
    /// The configuration this model was built with.
    config: CompileConfig,
    /// Input names captured from the original graph at construction time.
    input_names: Vec<String>,
    /// Output names captured from the original graph at construction time.
    output_names: Vec<String>,
}
180
impl CompiledModel {
    /// Optimize `graph` according to `config` and attempt to JIT-compile it.
    ///
    /// JIT failures are swallowed (`.ok()`): the model then falls back to
    /// interpretation in [`Self::run`]. Currently always returns `Ok`.
    pub fn from_graph(graph: Graph, config: CompileConfig) -> JitResult<Self> {
        // Build the optimizer pipeline from the configured passes, in order.
        let mut optimizer = Optimizer::new();
        for pass in &config.passes {
            optimizer.add_pass(*pass);
        }
        let optimized_graph = optimizer.optimize(graph.clone());

        // Compile unless explicitly disabled or the eager backend was chosen.
        // A failed compile yields `None` (interpreter fallback), not an error.
        let compiled_fn = if !config.disable && config.backend != Backend::Eager {
            let compiler = JitCompiler::new();
            compiler.compile(&optimized_graph).ok()
        } else {
            None
        };

        // Record names from the *original* graph so the public interface is
        // stable even if optimization rewrites node structure.
        let input_names: Vec<String> = graph.inputs().keys().cloned().collect();
        let output_names: Vec<String> = graph.outputs().keys().cloned().collect();

        Ok(Self {
            graph,
            optimized_graph,
            compiled_fn,
            config,
            input_names,
            output_names,
        })
    }

    /// Names of the inputs this model expects in [`Self::run`].
    pub fn input_names(&self) -> &[String] {
        &self.input_names
    }

    /// Names of the outputs this model produces from [`Self::run`].
    pub fn output_names(&self) -> &[String] {
        &self.output_names
    }

    /// The original, unoptimized graph.
    pub fn graph(&self) -> &Graph {
        &self.graph
    }

    /// The graph after optimization passes ran.
    pub fn optimized_graph(&self) -> &Graph {
        &self.optimized_graph
    }

    /// Whether JIT compilation produced a native function.
    pub fn is_compiled(&self) -> bool {
        self.compiled_fn.is_some()
    }

    /// Summary statistics about this compilation.
    pub fn stats(&self) -> CompileStats {
        CompileStats {
            original_ops: self.graph.len(),
            optimized_ops: self.optimized_graph.len(),
            is_compiled: self.compiled_fn.is_some(),
            passes_applied: self.config.passes.len(),
        }
    }

    /// Execute the model on `inputs`, keyed by input name.
    ///
    /// Returns `JitError::InputNotFound` if any declared input is missing.
    ///
    /// NOTE(review): `compiled_fn` is never invoked here — execution always
    /// goes through the interpreter. Confirm whether the JIT path is wired
    /// up elsewhere or still pending.
    pub fn run(&self, inputs: &HashMap<String, Vec<f32>>) -> JitResult<HashMap<String, Vec<f32>>> {
        // Validate that every declared input was supplied before executing.
        for name in &self.input_names {
            if !inputs.contains_key(name) {
                return Err(JitError::InputNotFound(name.clone()));
            }
        }

        self.interpret(inputs)
    }

    /// Interpret the optimized graph node-by-node.
    ///
    /// Intermediate results live in a single map keyed either by input name
    /// or by the synthetic key `node_{id}`. Assumes user input names never
    /// collide with that `node_{id}` scheme — TODO confirm upstream naming
    /// rules.
    fn interpret(
        &self,
        inputs: &HashMap<String, Vec<f32>>,
    ) -> JitResult<HashMap<String, Vec<f32>>> {
        let mut values: HashMap<String, Vec<f32>> = HashMap::new();

        // Seed the value map with the user-supplied inputs.
        for (name, data) in inputs {
            values.insert(name.clone(), data.clone());
        }

        // Execute nodes in the order `nodes()` yields them; presumably a
        // topological order so operands exist before use — TODO confirm.
        for node in self.optimized_graph.nodes() {
            let result = self.execute_node(node, &values)?;
            let key = format!("node_{}", node.id.index());
            values.insert(key, result);
        }

        // Collect declared outputs. Outputs whose node id is unknown to the
        // optimized graph, or whose value was never produced, are silently
        // omitted from the result map.
        let mut outputs = HashMap::new();
        for name in &self.output_names {
            if let Some(node_id) = self.optimized_graph.output(name) {
                let key = format!("node_{}", node_id.index());
                if let Some(val) = values.get(&key) {
                    outputs.insert(name.clone(), val.clone());
                }
            }
        }

        Ok(outputs)
    }

    /// Evaluate a single node against the values computed so far.
    ///
    /// Binary ops combine operands elementwise via `zip`, which truncates to
    /// the shorter operand — there is no shape check here. Ops without an
    /// explicit arm fall through to a zero-filled buffer of the node's shape.
    fn execute_node(&self, node: &Node, values: &HashMap<String, Vec<f32>>) -> JitResult<Vec<f32>> {
        match &node.op {
            // Inputs read directly from the user-supplied map by name.
            Op::Input { name } => values
                .get(name)
                .cloned()
                .ok_or_else(|| JitError::InputNotFound(name.clone())),
            // Outputs forward the value of their producing node.
            Op::Output { input, .. } => {
                let key = format!("node_{}", input.index());
                values
                    .get(&key)
                    .cloned()
                    .ok_or_else(|| JitError::InputNotFound(key))
            }
            // Constants become single-element buffers; `value` is narrowed to
            // f32 here (presumably stored as f64 — TODO confirm).
            Op::Constant { value } => Ok(vec![*value as f32]),
            Op::Add { lhs, rhs } => {
                let a = self.get_node_value(*lhs, values)?;
                let b = self.get_node_value(*rhs, values)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
            }
            Op::Sub { lhs, rhs } => {
                let a = self.get_node_value(*lhs, values)?;
                let b = self.get_node_value(*rhs, values)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x - y).collect())
            }
            Op::Mul { lhs, rhs } => {
                let a = self.get_node_value(*lhs, values)?;
                let b = self.get_node_value(*rhs, values)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).collect())
            }
            // Division by zero follows IEEE-754 f32 semantics (inf/NaN).
            Op::Div { lhs, rhs } => {
                let a = self.get_node_value(*lhs, values)?;
                let b = self.get_node_value(*rhs, values)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x / y).collect())
            }
            Op::Neg { input } => {
                let a = self.get_node_value(*input, values)?;
                Ok(a.iter().map(|x| -x).collect())
            }
            Op::Exp { input } => {
                let a = self.get_node_value(*input, values)?;
                Ok(a.iter().map(|x| x.exp()).collect())
            }
            Op::Log { input } => {
                let a = self.get_node_value(*input, values)?;
                Ok(a.iter().map(|x| x.ln()).collect())
            }
            Op::Sqrt { input } => {
                let a = self.get_node_value(*input, values)?;
                Ok(a.iter().map(|x| x.sqrt()).collect())
            }
            // relu(x) = max(x, 0)
            Op::Relu { input } => {
                let a = self.get_node_value(*input, values)?;
                Ok(a.iter().map(|x| x.max(0.0)).collect())
            }
            // sigmoid(x) = 1 / (1 + e^-x)
            Op::Sigmoid { input } => {
                let a = self.get_node_value(*input, values)?;
                Ok(a.iter().map(|x| 1.0 / (1.0 + (-x).exp())).collect())
            }
            Op::Tanh { input } => {
                let a = self.get_node_value(*input, values)?;
                Ok(a.iter().map(|x| x.tanh()).collect())
            }
            // Unsupported ops yield a zero-filled buffer rather than an error.
            _ => {
                let numel = node.shape.numel();
                Ok(vec![0.0; numel])
            }
        }
    }

    /// Look up a node's computed value: inputs by name, everything else by
    /// the synthetic `node_{id}` key.
    fn get_node_value(
        &self,
        node_id: crate::ir::NodeId,
        values: &HashMap<String, Vec<f32>>,
    ) -> JitResult<Vec<f32>> {
        let node = self.optimized_graph.node(node_id);
        if let Op::Input { name } = &node.op {
            return values
                .get(name)
                .cloned()
                .ok_or_else(|| JitError::InputNotFound(name.clone()));
        }

        let key = format!("node_{}", node_id.index());
        values
            .get(&key)
            .cloned()
            // InputNotFound is reused here for a missing intermediate value.
            .ok_or_else(|| JitError::InputNotFound(key))
    }
}
389
/// Summary statistics describing one compilation.
#[derive(Debug, Clone)]
pub struct CompileStats {
    /// Node count of the graph before optimization.
    pub original_ops: usize,
    /// Node count after the optimization passes ran.
    pub optimized_ops: usize,
    /// Whether a JIT-compiled function was produced.
    pub is_compiled: bool,
    /// Number of optimization passes that were configured.
    pub passes_applied: usize,
}

impl CompileStats {
    /// Ratio of optimized to original op count; lower means more was
    /// removed. An empty original graph reports 1.0 (no change).
    pub fn optimization_ratio(&self) -> f32 {
        match self.original_ops {
            0 => 1.0,
            n => self.optimized_ops as f32 / n as f32,
        }
    }
}
413
/// Compile `graph` with the default [`CompileConfig`].
pub fn compile_graph(graph: Graph) -> JitResult<CompiledModel> {
    CompiledModel::from_graph(graph, CompileConfig::default())
}
433
/// Compile `graph` with an explicit [`CompileConfig`].
pub fn compile_graph_with_config(graph: Graph, config: CompileConfig) -> JitResult<CompiledModel> {
    CompiledModel::from_graph(graph, config)
}
438
/// Trace the closure `f` into a graph and compile it with default settings.
///
/// The closure receives a [`Tracer`] and must return the traced output value.
pub fn compile_fn<F>(f: F) -> JitResult<CompiledModel>
where
    F: FnOnce(&Tracer) -> TracedValue,
{
    let graph = trace(f);
    compile_graph(graph)
}
456
/// Trace the closure `f` into a graph and compile it with `config`.
pub fn compile_fn_with_config<F>(f: F, config: CompileConfig) -> JitResult<CompiledModel>
where
    F: FnOnce(&Tracer) -> TracedValue,
{
    let graph = trace(f);
    compile_graph_with_config(graph, config)
}
465
/// A function that is traced and compiled lazily on its first invocation;
/// the compiled model is cached for subsequent calls.
pub struct LazyCompiled<F> {
    /// The traceable closure; invoked by `trace` exactly once, on first run.
    func: F,
    /// Lazily-initialized compiled model, mutex-guarded for thread safety.
    compiled: Mutex<Option<CompiledModel>>,
    /// Configuration used when compilation eventually happens.
    config: CompileConfig,
}
478
479impl<F> LazyCompiled<F>
480where
481 F: Fn(&Tracer) -> TracedValue,
482{
483 pub fn new(func: F) -> Self {
485 Self {
486 func,
487 compiled: Mutex::new(None),
488 config: CompileConfig::default(),
489 }
490 }
491
492 pub fn with_config(func: F, config: CompileConfig) -> Self {
494 Self {
495 func,
496 compiled: Mutex::new(None),
497 config,
498 }
499 }
500
501 pub fn run(&self, inputs: &HashMap<String, Vec<f32>>) -> JitResult<HashMap<String, Vec<f32>>> {
503 let mut compiled = self.compiled.lock().unwrap();
504
505 if compiled.is_none() {
506 let graph = trace(&self.func);
507 *compiled = Some(CompiledModel::from_graph(graph, self.config.clone())?);
508 }
509
510 compiled.as_ref().unwrap().run(inputs)
511 }
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a tiny traced graph computing `y = relu(x)` over a length-2 input.
    fn relu_graph() -> Graph {
        trace(|t| {
            let x = t.input("x", &[2]);
            let y = x.relu();
            t.output("y", y)
        })
    }

    #[test]
    fn test_compile_config_default() {
        let cfg = CompileConfig::default();
        assert_eq!(cfg.mode, Mode::Default);
        assert!(!cfg.fullgraph);
        assert!(!cfg.disable);
    }

    #[test]
    fn test_compile_config_builder() {
        let cfg = CompileConfig::new()
            .mode(Mode::MaxAutotune)
            .fullgraph(true)
            .dynamic(true);

        assert_eq!(cfg.mode, Mode::MaxAutotune);
        assert!(cfg.fullgraph);
        assert!(cfg.dynamic);
    }

    #[test]
    fn test_compile_simple_graph() {
        let model = compile_graph(relu_graph()).unwrap();
        assert!(model.input_names().contains(&"x".to_string()));
    }

    #[test]
    fn test_compile_stats() {
        let stats = compile_graph(relu_graph()).unwrap().stats();

        assert!(stats.original_ops > 0);
        assert!(stats.passes_applied > 0);
    }

    #[test]
    fn test_mode_enum() {
        assert_eq!(Mode::default(), Mode::Default);
        assert_ne!(Mode::MaxAutotune, Mode::ReduceOverhead);
    }

    #[test]
    fn test_backend_enum() {
        assert_eq!(Backend::default(), Backend::Default);
    }

    #[test]
    fn test_compiled_model_run() {
        // Disable the JIT so this exercises the interpreter path.
        let model =
            compile_graph_with_config(relu_graph(), CompileConfig::new().disable(true)).unwrap();

        let mut inputs = HashMap::new();
        inputs.insert("x".to_string(), vec![-1.0, 2.0]);

        let outputs = model.run(&inputs).unwrap();
        assert_eq!(outputs.get("y").unwrap(), &vec![0.0, 2.0]);
    }
}