1use crate::codegen::{CompiledFunction, JitCompiler};
26use crate::ir::{Graph, Node, Op};
27use crate::optimize::{OptimizationPass, Optimizer};
28use crate::trace::{trace, TracedValue, Tracer};
29use crate::{JitError, JitResult};
30use std::collections::HashMap;
31use std::sync::Mutex;
32
/// How aggressively a graph should be optimized during compilation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Mode {
    /// Baseline behaviour: only the default optimization passes.
    Default,
    /// Mode selector intended to reduce per-call overhead.
    ///
    /// NOTE(review): `CompileConfig::mode` does not add any passes for this
    /// variant; it is currently only recorded.
    ReduceOverhead,
    /// Most aggressive mode; `CompileConfig::mode` enables extra fusion and
    /// simplification passes when this is selected.
    MaxAutotune,
}

impl Default for Mode {
    /// Returns [`Mode::Default`].
    fn default() -> Self {
        Mode::Default
    }
}
53
/// Target used to execute a compiled graph.
// The all-caps acronym variants are part of the public API; renaming them to
// `Aot`/`Onnx` would break callers, so the clippy lint is silenced instead.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Backend {
    /// The default JIT backend.
    Default,
    /// No JIT compilation: `CompiledModel::from_graph` skips the compiler.
    Eager,
    /// Ahead-of-time backend selector.
    AOT,
    /// ONNX backend selector.
    ONNX,
}

impl Default for Backend {
    /// Returns [`Backend::Default`].
    fn default() -> Self {
        Backend::Default
    }
}
72
73#[derive(Debug, Clone)]
79pub struct CompileConfig {
80 pub mode: Mode,
82 pub backend: Backend,
84 pub fullgraph: bool,
86 pub dynamic: bool,
88 pub disable: bool,
90 pub passes: Vec<OptimizationPass>,
92}
93
94impl Default for CompileConfig {
95 fn default() -> Self {
96 Self {
97 mode: Mode::Default,
98 backend: Backend::Default,
99 fullgraph: false,
100 dynamic: false,
101 disable: false,
102 passes: vec![
103 OptimizationPass::ConstantFolding,
104 OptimizationPass::DeadCodeElimination,
105 OptimizationPass::CommonSubexpressionElimination,
106 ],
107 }
108 }
109}
110
111impl CompileConfig {
112 pub fn new() -> Self {
114 Self::default()
115 }
116
117 pub fn mode(mut self, mode: Mode) -> Self {
119 self.mode = mode;
120 if mode == Mode::MaxAutotune {
121 self.passes.push(OptimizationPass::ElementwiseFusion);
123 self.passes.push(OptimizationPass::AlgebraicSimplification);
124 }
125 self
126 }
127
128 pub fn backend(mut self, backend: Backend) -> Self {
130 self.backend = backend;
131 self
132 }
133
134 pub fn fullgraph(mut self, fullgraph: bool) -> Self {
136 self.fullgraph = fullgraph;
137 self
138 }
139
140 pub fn dynamic(mut self, dynamic: bool) -> Self {
142 self.dynamic = dynamic;
143 self
144 }
145
146 pub fn disable(mut self, disable: bool) -> Self {
148 self.disable = disable;
149 self
150 }
151
152 pub fn add_pass(mut self, pass: OptimizationPass) -> Self {
154 self.passes.push(pass);
155 self
156 }
157}
158
159pub struct CompiledModel {
167 graph: Graph,
169 optimized_graph: Graph,
171 compiled_fn: Option<CompiledFunction>,
173 config: CompileConfig,
175 input_names: Vec<String>,
177 output_names: Vec<String>,
179}
180
181impl CompiledModel {
182 pub fn from_graph(graph: Graph, config: CompileConfig) -> JitResult<Self> {
184 let mut optimizer = Optimizer::new();
186 for pass in &config.passes {
187 optimizer.add_pass(*pass);
188 }
189 let optimized_graph = optimizer.optimize(graph.clone());
190
191 let compiled_fn = if !config.disable && config.backend != Backend::Eager {
193 let compiler = JitCompiler::new();
194 compiler.compile(&optimized_graph).ok()
195 } else {
196 None
197 };
198
199 let input_names: Vec<String> = graph.inputs().keys().cloned().collect();
200 let output_names: Vec<String> = graph.outputs().keys().cloned().collect();
201
202 Ok(Self {
203 graph,
204 optimized_graph,
205 compiled_fn,
206 config,
207 input_names,
208 output_names,
209 })
210 }
211
212 pub fn input_names(&self) -> &[String] {
214 &self.input_names
215 }
216
217 pub fn output_names(&self) -> &[String] {
219 &self.output_names
220 }
221
222 pub fn graph(&self) -> &Graph {
224 &self.graph
225 }
226
227 pub fn optimized_graph(&self) -> &Graph {
229 &self.optimized_graph
230 }
231
232 pub fn is_compiled(&self) -> bool {
234 self.compiled_fn.is_some()
235 }
236
237 pub fn stats(&self) -> CompileStats {
239 CompileStats {
240 original_ops: self.graph.len(),
241 optimized_ops: self.optimized_graph.len(),
242 is_compiled: self.compiled_fn.is_some(),
243 passes_applied: self.config.passes.len(),
244 }
245 }
246
247 pub fn run(&self, inputs: &HashMap<String, Vec<f32>>) -> JitResult<HashMap<String, Vec<f32>>> {
249 for name in &self.input_names {
251 if !inputs.contains_key(name) {
252 return Err(JitError::InputNotFound(name.clone()));
253 }
254 }
255
256 self.interpret(inputs)
258 }
259
260 fn interpret(&self, inputs: &HashMap<String, Vec<f32>>) -> JitResult<HashMap<String, Vec<f32>>> {
262 let mut values: HashMap<String, Vec<f32>> = HashMap::new();
264
265 for (name, data) in inputs {
267 values.insert(name.clone(), data.clone());
268 }
269
270 for node in self.optimized_graph.nodes() {
271 let result = self.execute_node(node, &values)?;
272 let key = format!("node_{}", node.id.index());
274 values.insert(key, result);
275 }
276
277 let mut outputs = HashMap::new();
279 for name in &self.output_names {
280 if let Some(node_id) = self.optimized_graph.output(name) {
282 let key = format!("node_{}", node_id.index());
283 if let Some(val) = values.get(&key) {
284 outputs.insert(name.clone(), val.clone());
285 }
286 }
287 }
288
289 Ok(outputs)
290 }
291
292 fn execute_node(
294 &self,
295 node: &Node,
296 values: &HashMap<String, Vec<f32>>,
297 ) -> JitResult<Vec<f32>> {
298 match &node.op {
299 Op::Input { name } => {
300 values
301 .get(name)
302 .cloned()
303 .ok_or_else(|| JitError::InputNotFound(name.clone()))
304 }
305 Op::Output { input, .. } => {
306 let key = format!("node_{}", input.index());
307 values
308 .get(&key)
309 .cloned()
310 .ok_or_else(|| JitError::InputNotFound(key))
311 }
312 Op::Constant { value } => Ok(vec![*value as f32]),
313 Op::Add { lhs, rhs } => {
314 let a = self.get_node_value(*lhs, values)?;
315 let b = self.get_node_value(*rhs, values)?;
316 Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
317 }
318 Op::Sub { lhs, rhs } => {
319 let a = self.get_node_value(*lhs, values)?;
320 let b = self.get_node_value(*rhs, values)?;
321 Ok(a.iter().zip(b.iter()).map(|(x, y)| x - y).collect())
322 }
323 Op::Mul { lhs, rhs } => {
324 let a = self.get_node_value(*lhs, values)?;
325 let b = self.get_node_value(*rhs, values)?;
326 Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).collect())
327 }
328 Op::Div { lhs, rhs } => {
329 let a = self.get_node_value(*lhs, values)?;
330 let b = self.get_node_value(*rhs, values)?;
331 Ok(a.iter().zip(b.iter()).map(|(x, y)| x / y).collect())
332 }
333 Op::Neg { input } => {
334 let a = self.get_node_value(*input, values)?;
335 Ok(a.iter().map(|x| -x).collect())
336 }
337 Op::Exp { input } => {
338 let a = self.get_node_value(*input, values)?;
339 Ok(a.iter().map(|x| x.exp()).collect())
340 }
341 Op::Log { input } => {
342 let a = self.get_node_value(*input, values)?;
343 Ok(a.iter().map(|x| x.ln()).collect())
344 }
345 Op::Sqrt { input } => {
346 let a = self.get_node_value(*input, values)?;
347 Ok(a.iter().map(|x| x.sqrt()).collect())
348 }
349 Op::Relu { input } => {
350 let a = self.get_node_value(*input, values)?;
351 Ok(a.iter().map(|x| x.max(0.0)).collect())
352 }
353 Op::Sigmoid { input } => {
354 let a = self.get_node_value(*input, values)?;
355 Ok(a.iter().map(|x| 1.0 / (1.0 + (-x).exp())).collect())
356 }
357 Op::Tanh { input } => {
358 let a = self.get_node_value(*input, values)?;
359 Ok(a.iter().map(|x| x.tanh()).collect())
360 }
361 _ => {
362 let numel = node.shape.numel();
364 Ok(vec![0.0; numel])
365 }
366 }
367 }
368
369 fn get_node_value(
371 &self,
372 node_id: crate::ir::NodeId,
373 values: &HashMap<String, Vec<f32>>,
374 ) -> JitResult<Vec<f32>> {
375 let node = self.optimized_graph.node(node_id);
377 if let Op::Input { name } = &node.op {
378 return values
379 .get(name)
380 .cloned()
381 .ok_or_else(|| JitError::InputNotFound(name.clone()));
382 }
383
384 let key = format!("node_{}", node_id.index());
386 values
387 .get(&key)
388 .cloned()
389 .ok_or_else(|| JitError::InputNotFound(key))
390 }
391}
392
/// Summary of what compilation did to a graph.
#[derive(Clone, Debug)]
pub struct CompileStats {
    /// Size of the graph before optimization, as reported by `Graph::len`.
    pub original_ops: usize,
    /// Size of the graph after optimization, as reported by `Graph::len`.
    pub optimized_ops: usize,
    /// Whether JIT compilation produced a compiled function.
    pub is_compiled: bool,
    /// Number of optimization passes configured for the compilation.
    pub passes_applied: usize,
}

impl CompileStats {
    /// Ratio of optimized size to original size.
    ///
    /// Values below `1.0` mean the graph shrank. An empty original graph
    /// reports `1.0` rather than dividing by zero.
    pub fn optimization_ratio(&self) -> f32 {
        match self.original_ops {
            0 => 1.0,
            before => self.optimized_ops as f32 / before as f32,
        }
    }
}
416
417pub fn compile_graph(graph: Graph) -> JitResult<CompiledModel> {
434 CompiledModel::from_graph(graph, CompileConfig::default())
435}
436
437pub fn compile_graph_with_config(graph: Graph, config: CompileConfig) -> JitResult<CompiledModel> {
439 CompiledModel::from_graph(graph, config)
440}
441
442pub fn compile_fn<F>(f: F) -> JitResult<CompiledModel>
453where
454 F: FnOnce(&Tracer) -> TracedValue,
455{
456 let graph = trace(f);
457 compile_graph(graph)
458}
459
460pub fn compile_fn_with_config<F>(f: F, config: CompileConfig) -> JitResult<CompiledModel>
462where
463 F: FnOnce(&Tracer) -> TracedValue,
464{
465 let graph = trace(f);
466 compile_graph_with_config(graph, config)
467}
468
469pub struct LazyCompiled<F> {
477 func: F,
478 compiled: Mutex<Option<CompiledModel>>,
479 config: CompileConfig,
480}
481
482impl<F> LazyCompiled<F>
483where
484 F: Fn(&Tracer) -> TracedValue,
485{
486 pub fn new(func: F) -> Self {
488 Self {
489 func,
490 compiled: Mutex::new(None),
491 config: CompileConfig::default(),
492 }
493 }
494
495 pub fn with_config(func: F, config: CompileConfig) -> Self {
497 Self {
498 func,
499 compiled: Mutex::new(None),
500 config,
501 }
502 }
503
504 pub fn run(&self, inputs: &HashMap<String, Vec<f32>>) -> JitResult<HashMap<String, Vec<f32>>> {
506 let mut compiled = self.compiled.lock().unwrap();
507
508 if compiled.is_none() {
509 let graph = trace(&self.func);
510 *compiled = Some(CompiledModel::from_graph(graph, self.config.clone())?);
511 }
512
513 compiled.as_ref().unwrap().run(inputs)
514 }
515}
516
#[cfg(test)]
mod tests {
    use super::*;

    /// Traces `y = relu(x)` for a two-element input `x`.
    fn relu_graph() -> Graph {
        trace(|t| {
            let x = t.input("x", &[2]);
            t.output("y", x.relu())
        })
    }

    #[test]
    fn test_compile_config_default() {
        let cfg = CompileConfig::default();
        assert_eq!(cfg.mode, Mode::Default);
        assert!(!cfg.fullgraph);
        assert!(!cfg.disable);
    }

    #[test]
    fn test_compile_config_builder() {
        let cfg = CompileConfig::new()
            .mode(Mode::MaxAutotune)
            .fullgraph(true)
            .dynamic(true);

        assert_eq!(cfg.mode, Mode::MaxAutotune);
        assert!(cfg.fullgraph);
        assert!(cfg.dynamic);
    }

    #[test]
    fn test_compile_simple_graph() {
        let model = compile_graph(relu_graph()).unwrap();
        assert!(model.input_names().iter().any(|name| name == "x"));
    }

    #[test]
    fn test_compile_stats() {
        let stats = compile_graph(relu_graph()).unwrap().stats();
        assert!(stats.original_ops > 0);
        assert!(stats.passes_applied > 0);
    }

    #[test]
    fn test_mode_enum() {
        assert_eq!(Mode::default(), Mode::Default);
        assert_ne!(Mode::MaxAutotune, Mode::ReduceOverhead);
    }

    #[test]
    fn test_backend_enum() {
        assert_eq!(Backend::default(), Backend::Default);
    }

    #[test]
    fn test_compiled_model_run() {
        // Disable JIT so the interpreter path is exercised.
        let model =
            compile_graph_with_config(relu_graph(), CompileConfig::new().disable(true)).unwrap();

        let inputs: HashMap<String, Vec<f32>> =
            std::iter::once(("x".to_string(), vec![-1.0, 2.0])).collect();

        let outputs = model.run(&inputs).unwrap();
        assert_eq!(outputs.get("y"), Some(&vec![0.0, 2.0]));
    }
}