#![allow(clippy::excessive_nesting)] #![allow(unused_variables)]
use crate::compiler::{
CompilationResult, CompilationStats, CompilerConfig, ComputationGraph, GraphNode,
};
use crate::errors::TrustformersError;
use crate::errors::{invalid_format, runtime_error, unsupported_operation};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Instant;
pub struct JitCompiler {
config: CompilerConfig,
backend: Box<dyn JitBackend>,
compilation_cache: Arc<Mutex<HashMap<String, CachedCompilation>>>,
compilation_stats: CompilationStatistics,
}
impl JitCompiler {
pub fn new(config: &CompilerConfig) -> Result<Self, TrustformersError> {
let backend = Self::create_backend(config)?;
Ok(Self {
config: config.clone(),
backend,
compilation_cache: Arc::new(Mutex::new(HashMap::new())),
compilation_stats: CompilationStatistics::new(),
})
}
pub fn update_config(&mut self, config: &CompilerConfig) -> Result<(), TrustformersError> {
self.config = config.clone();
self.backend = Self::create_backend(config)?;
Ok(())
}
fn create_backend(config: &CompilerConfig) -> Result<Box<dyn JitBackend>, TrustformersError> {
#[cfg(feature = "llvm")]
if config.compiler_flags.contains(&"llvm".to_string()) {
return Ok(Box::new(LLVMBackend::new(config)?));
}
#[cfg(feature = "cranelift")]
if config.compiler_flags.contains(&"cranelift".to_string()) {
return Ok(Box::new(CraneliftBackend::new(config)?));
}
Ok(Box::new(InterpreterBackend::new(config)?))
}
pub fn compile(
&mut self,
graph: ComputationGraph,
) -> Result<CompilationResult, TrustformersError> {
let start_time = Instant::now();
let cache_key = self.generate_cache_key(&graph)?;
if self.config.enable_cache {
if let Some(cached) = self.get_cached_compilation(&cache_key)? {
self.compilation_stats.cache_hits += 1;
return Ok(CompilationResult {
compiled_code: cached.compiled_code.clone(),
stats: cached.stats.clone(),
metadata: cached.metadata.clone(),
});
}
}
self.compilation_stats.cache_misses += 1;
graph.validate()?;
let ir = self.generate_ir(&graph)?;
let original_ir_size = ir.instructions.len();
let original_compute_cost = self.calculate_total_compute_cost(&ir);
let original_memory_cost = self.calculate_total_memory_cost(&ir);
let (optimized_ir, optimization_metrics) = self.optimize_ir_with_metrics(ir)?;
let compiled_code = self.backend.compile_ir(optimized_ir)?;
let compilation_time = start_time.elapsed();
let optimized_compute_cost =
self.calculate_total_compute_cost(&optimization_metrics.optimized_ir);
let optimized_memory_cost =
self.calculate_total_memory_cost(&optimization_metrics.optimized_ir);
let performance_gain = if optimized_compute_cost > 0.0 {
original_compute_cost / optimized_compute_cost
} else {
1.0
};
let memory_reduction = if original_memory_cost > 0.0 {
(original_memory_cost - optimized_memory_cost) / original_memory_cost
} else {
0.0
};
let stats = CompilationStats {
compilation_time_ms: compilation_time.as_millis() as u64,
original_ops: graph.nodes.len(),
optimized_ops: optimization_metrics.optimized_ir.instructions.len(),
fused_kernels: optimization_metrics.fused_kernels,
performance_gain,
memory_reduction,
applied_passes: optimization_metrics.applied_passes,
};
let metadata = HashMap::new();
let result = CompilationResult {
compiled_code: compiled_code.clone(),
stats: stats.clone(),
metadata: metadata.clone(),
};
if self.config.enable_cache {
self.cache_compilation(cache_key, compiled_code, stats, metadata)?;
}
self.compilation_stats.compilations += 1;
self.compilation_stats.total_compilation_time += compilation_time;
Ok(result)
}
fn generate_ir(
&self,
graph: &ComputationGraph,
) -> Result<IntermediateRepresentation, TrustformersError> {
let mut ir = IntermediateRepresentation::new();
for node in &graph.nodes {
let instruction = self.node_to_instruction(node)?;
ir.add_instruction(instruction);
}
for edge in &graph.edges {
ir.add_dependency(edge.from, edge.to);
}
Ok(ir)
}
fn node_to_instruction(&self, node: &GraphNode) -> Result<IRInstruction, TrustformersError> {
let opcode = match node.op_type.as_str() {
"MatMul" => IROpcode::MatMul,
"Add" => IROpcode::Add,
"Mul" => IROpcode::Mul,
"ReLU" => IROpcode::ReLU,
"Sigmoid" => IROpcode::Sigmoid,
"Tanh" => IROpcode::Tanh,
"Softmax" => IROpcode::Softmax,
"LayerNorm" => IROpcode::LayerNorm,
"Attention" => IROpcode::Attention,
"Embedding" => IROpcode::Embedding,
"Linear" => IROpcode::Linear,
"Conv2D" => IROpcode::Conv2D,
"Pool2D" => IROpcode::Pool2D,
"Reshape" => IROpcode::Reshape,
"Transpose" => IROpcode::Transpose,
_ => return Err(unsupported_operation("node_compilation", &node.op_type)),
};
Ok(IRInstruction {
id: node.id,
opcode,
inputs: node.input_shapes.clone(),
outputs: node.output_shapes.clone(),
attributes: node.attributes.clone(),
compute_cost: node.compute_cost,
memory_cost: node.memory_cost,
})
}
#[allow(dead_code)]
fn optimize_ir(
&self,
mut ir: IntermediateRepresentation,
) -> Result<IntermediateRepresentation, TrustformersError> {
ir = self.apply_constant_propagation(ir)?;
ir = self.apply_dead_instruction_elimination(ir)?;
ir = self.apply_instruction_scheduling(ir)?;
Ok(ir)
}
fn optimize_ir_with_metrics(
&self,
mut ir: IntermediateRepresentation,
) -> Result<(IntermediateRepresentation, OptimizationMetrics), TrustformersError> {
let mut applied_passes = Vec::new();
let mut fused_kernels = 0;
let (ir_after_cp, cp_fused) = self.apply_constant_propagation_with_metrics(ir)?;
ir = ir_after_cp;
fused_kernels += cp_fused;
applied_passes.push("constant_propagation".to_string());
let (ir_after_die, die_removed) =
self.apply_dead_instruction_elimination_with_metrics(ir)?;
ir = ir_after_die;
applied_passes.push(format!(
"dead_instruction_elimination(removed: {})",
die_removed
));
let (ir_after_sched, sched_reordered) =
self.apply_instruction_scheduling_with_metrics(ir)?;
ir = ir_after_sched;
applied_passes.push(format!(
"instruction_scheduling(reordered: {})",
sched_reordered
));
let (ir_after_fusion, fusion_count) = self.apply_kernel_fusion_with_metrics(ir)?;
ir = ir_after_fusion;
fused_kernels += fusion_count;
applied_passes.push(format!("kernel_fusion(fused: {})", fusion_count));
let metrics = OptimizationMetrics {
optimized_ir: ir.clone(),
fused_kernels,
applied_passes,
};
Ok((ir, metrics))
}
fn apply_constant_propagation(
&self,
mut ir: IntermediateRepresentation,
) -> Result<IntermediateRepresentation, TrustformersError> {
let mut changed = true;
while changed {
changed = false;
for instruction in &mut ir.instructions {
if self.can_evaluate_at_compile_time(instruction) {
instruction.attributes.insert("constant".to_string(), "true".to_string());
changed = true;
}
}
}
Ok(ir)
}
fn apply_dead_instruction_elimination(
&self,
mut ir: IntermediateRepresentation,
) -> Result<IntermediateRepresentation, TrustformersError> {
let mut used = vec![false; ir.instructions.len()];
for (i, instruction) in ir.instructions.iter().enumerate() {
if instruction.attributes.contains_key("output") {
used[i] = true;
}
}
let mut changed = true;
while changed {
changed = false;
for &(from, to) in &ir.dependencies {
if used[to] && !used[from] {
used[from] = true;
changed = true;
}
}
}
ir.instructions.retain(|instruction| used[instruction.id]);
Ok(ir)
}
fn apply_instruction_scheduling(
&self,
ir: IntermediateRepresentation,
) -> Result<IntermediateRepresentation, TrustformersError> {
Ok(ir)
}
fn can_evaluate_at_compile_time(&self, instruction: &IRInstruction) -> bool {
matches!(instruction.opcode, IROpcode::Add | IROpcode::Mul)
&& instruction.attributes.get("all_inputs_constant").is_some_and(|v| v == "true")
}
fn apply_constant_fold_arithmetic(
&self,
instruction: &mut IRInstruction,
) -> Option<(String, bool)> {
if matches!(
instruction.opcode,
IROpcode::Add | IROpcode::Mul | IROpcode::Sub | IROpcode::Div
) {
if let Some(constant_value) = self.evaluate_constant_instruction(instruction) {
instruction
.attributes
.insert("folded_value".to_string(), constant_value.clone());
return Some((constant_value, true));
}
}
None
}
fn generate_cache_key(&self, graph: &ComputationGraph) -> Result<String, TrustformersError> {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
graph.nodes.len().hash(&mut hasher);
graph.edges.len().hash(&mut hasher);
for node in &graph.nodes {
node.op_type.hash(&mut hasher);
node.input_shapes.hash(&mut hasher);
node.output_shapes.hash(&mut hasher);
}
for edge in &graph.edges {
edge.from.hash(&mut hasher);
edge.to.hash(&mut hasher);
edge.shape.hash(&mut hasher);
edge.dtype.hash(&mut hasher);
}
self.config.target_hardware.device_type.hash(&mut hasher);
self.config.target_hardware.compute_units.hash(&mut hasher);
Ok(format!("{:x}", hasher.finish()))
}
fn get_cached_compilation(
&self,
cache_key: &str,
) -> Result<Option<CachedCompilation>, TrustformersError> {
let cache = self
.compilation_cache
.lock()
.map_err(|_| runtime_error("Failed to acquire cache lock"))?;
Ok(cache.get(cache_key).cloned())
}
fn cache_compilation(
&self,
cache_key: String,
compiled_code: Vec<u8>,
stats: CompilationStats,
metadata: HashMap<String, String>,
) -> Result<(), TrustformersError> {
let mut cache = self
.compilation_cache
.lock()
.map_err(|_| runtime_error("Failed to acquire cache lock"))?;
let cached = CachedCompilation {
compiled_code,
stats,
metadata,
timestamp: std::time::SystemTime::now(),
};
cache.insert(cache_key, cached);
Ok(())
}
pub fn clear_cache(&mut self) {
if let Ok(mut cache) = self.compilation_cache.lock() {
cache.clear();
}
}
pub fn cache_size(&self) -> usize {
self.compilation_cache.lock().map(|cache| cache.len()).unwrap_or(0)
}
pub fn get_stats(&self) -> &CompilationStatistics {
&self.compilation_stats
}
pub fn reset_stats(&mut self) {
self.compilation_stats = CompilationStatistics::new();
}
fn calculate_total_compute_cost(&self, ir: &IntermediateRepresentation) -> f64 {
ir.instructions.iter().map(|inst| inst.compute_cost).sum()
}
fn calculate_total_memory_cost(&self, ir: &IntermediateRepresentation) -> f64 {
ir.instructions.iter().map(|inst| inst.memory_cost).sum()
}
fn apply_constant_propagation_with_metrics(
&self,
mut ir: IntermediateRepresentation,
) -> Result<(IntermediateRepresentation, usize), TrustformersError> {
let mut fused_operations = 0;
let mut changed = true;
while changed {
changed = false;
let instructions_to_remove = Vec::new();
for (i, instruction) in ir.instructions.iter_mut().enumerate() {
if !self.can_evaluate_at_compile_time(instruction) {
continue;
}
instruction.attributes.insert("constant".to_string(), "true".to_string());
if let Some((_value, folded)) = self.apply_constant_fold_arithmetic(instruction) {
if folded {
fused_operations += 1;
changed = true;
}
}
}
for i in instructions_to_remove.into_iter().rev() {
ir.instructions.remove(i);
}
}
Ok((ir, fused_operations))
}
fn apply_dead_instruction_elimination_with_metrics(
&self,
mut ir: IntermediateRepresentation,
) -> Result<(IntermediateRepresentation, usize), TrustformersError> {
let original_count = ir.instructions.len();
let mut used = vec![false; ir.instructions.len()];
for (i, instruction) in ir.instructions.iter().enumerate() {
if instruction.attributes.contains_key("output") {
used[i] = true;
}
}
let mut changed = true;
while changed {
changed = false;
for &(from, to) in &ir.dependencies {
if to < used.len() && from < used.len() && used[to] && !used[from] {
used[from] = true;
changed = true;
}
}
}
let mut instruction_id_map = HashMap::new();
let mut new_instructions = Vec::new();
let mut new_id = 0;
for (old_id, instruction) in ir.instructions.into_iter().enumerate() {
if used[old_id] {
instruction_id_map.insert(old_id, new_id);
new_instructions.push(IRInstruction {
id: new_id,
..instruction
});
new_id += 1;
}
}
ir.instructions = new_instructions;
ir.dependencies = ir
.dependencies
.into_iter()
.filter_map(|(from, to)| {
if let (Some(&new_from), Some(&new_to)) =
(instruction_id_map.get(&from), instruction_id_map.get(&to))
{
Some((new_from, new_to))
} else {
None
}
})
.collect();
let removed_count = original_count - ir.instructions.len();
Ok((ir, removed_count))
}
fn apply_instruction_scheduling_with_metrics(
&self,
mut ir: IntermediateRepresentation,
) -> Result<(IntermediateRepresentation, usize), TrustformersError> {
let mut reordered_count = 0;
let mut instruction_depths = vec![0; ir.instructions.len()];
for &(from, to) in &ir.dependencies {
if from < instruction_depths.len() && to < instruction_depths.len() {
instruction_depths[to] = instruction_depths[to].max(instruction_depths[from] + 1);
}
}
let mut instruction_indices: Vec<usize> = (0..ir.instructions.len()).collect();
instruction_indices.sort_by_key(|&i| instruction_depths[i]);
for (new_pos, &old_pos) in instruction_indices.iter().enumerate() {
if new_pos != old_pos {
reordered_count += 1;
}
}
let mut new_instructions = Vec::new();
for &old_index in &instruction_indices {
if old_index < ir.instructions.len() {
new_instructions.push(ir.instructions[old_index].clone());
}
}
for (new_id, instruction) in new_instructions.iter_mut().enumerate() {
instruction.id = new_id;
}
ir.instructions = new_instructions;
Ok((ir, reordered_count))
}
fn apply_kernel_fusion_with_metrics(
&self,
mut ir: IntermediateRepresentation,
) -> Result<(IntermediateRepresentation, usize), TrustformersError> {
let mut fused_count = 0;
let mut i = 0;
while i < ir.instructions.len().saturating_sub(1) {
let can_fuse = self.can_fuse_instructions(&ir.instructions[i], &ir.instructions[i + 1]);
if can_fuse {
let fused_instruction =
self.create_fused_instruction(&ir.instructions[i], &ir.instructions[i + 1])?;
ir.instructions[i] = fused_instruction;
ir.instructions.remove(i + 1);
for j in i + 1..ir.instructions.len() {
ir.instructions[j].id = j;
}
fused_count += 1;
} else {
i += 1;
}
}
Ok((ir, fused_count))
}
fn can_fuse_instructions(&self, inst1: &IRInstruction, inst2: &IRInstruction) -> bool {
match (&inst1.opcode, &inst2.opcode) {
(IROpcode::Add, IROpcode::ReLU) => true,
(IROpcode::MatMul, IROpcode::Add) => true, (IROpcode::ReLU, IROpcode::Add) => true,
(IROpcode::Add, IROpcode::Mul) => true,
_ => false,
}
}
fn create_fused_instruction(
&self,
inst1: &IRInstruction,
inst2: &IRInstruction,
) -> Result<IRInstruction, TrustformersError> {
let mut fused_attributes = inst1.attributes.clone();
fused_attributes
.extend(inst2.attributes.iter().map(|(k, v)| (format!("fused_{}", k), v.clone())));
fused_attributes.insert(
"fused_ops".to_string(),
format!("{:?}+{:?}", inst1.opcode, inst2.opcode),
);
Ok(IRInstruction {
id: inst1.id,
opcode: self.get_fused_opcode(&inst1.opcode, &inst2.opcode),
inputs: inst1.inputs.clone(),
outputs: inst2.outputs.clone(),
attributes: fused_attributes,
compute_cost: inst1.compute_cost + inst2.compute_cost * 0.7, memory_cost: (inst1.memory_cost + inst2.memory_cost) * 0.8, })
}
fn get_fused_opcode(&self, op1: &IROpcode, op2: &IROpcode) -> IROpcode {
match (op1, op2) {
(IROpcode::Add, IROpcode::ReLU) => IROpcode::Custom("AddReLU".to_string()),
(IROpcode::MatMul, IROpcode::Add) => IROpcode::Custom("MatMulBias".to_string()),
(IROpcode::ReLU, IROpcode::Add) => IROpcode::Custom("ReLUAdd".to_string()),
(IROpcode::Add, IROpcode::Mul) => IROpcode::Custom("AddMul".to_string()),
_ => IROpcode::Custom(format!("{:?}_{:?}", op1, op2)),
}
}
fn evaluate_constant_instruction(&self, instruction: &IRInstruction) -> Option<String> {
match instruction.opcode {
IROpcode::Add
if instruction.attributes.contains_key("const_a")
&& instruction.attributes.contains_key("const_b") =>
{
if let (Ok(a), Ok(b)) = (
instruction
.attributes
.get("const_a")
.expect("const_a must exist after contains_key check")
.parse::<f64>(),
instruction
.attributes
.get("const_b")
.expect("const_b must exist after contains_key check")
.parse::<f64>(),
) {
return Some((a + b).to_string());
}
},
IROpcode::Mul
if instruction.attributes.contains_key("const_a")
&& instruction.attributes.contains_key("const_b") =>
{
if let (Ok(a), Ok(b)) = (
instruction
.attributes
.get("const_a")
.expect("const_a must exist after contains_key check")
.parse::<f64>(),
instruction
.attributes
.get("const_b")
.expect("const_b must exist after contains_key check")
.parse::<f64>(),
) {
return Some((a * b).to_string());
}
},
_ => {},
}
None
}
}
#[derive(Debug, Clone)]
struct OptimizationMetrics {
optimized_ir: IntermediateRepresentation,
fused_kernels: usize,
applied_passes: Vec<String>,
}
#[derive(Debug, Clone)]
struct CachedCompilation {
compiled_code: Vec<u8>,
stats: CompilationStats,
metadata: HashMap<String, String>,
#[allow(dead_code)]
timestamp: std::time::SystemTime,
}
#[derive(Debug, Default, Clone)]
pub struct CompilationStatistics {
pub compilations: u64,
pub cache_hits: u64,
pub cache_misses: u64,
pub total_compilation_time: std::time::Duration,
}
impl CompilationStatistics {
pub fn new() -> Self {
Self::default()
}
pub fn cache_hit_rate(&self) -> f64 {
let total = self.cache_hits + self.cache_misses;
if total == 0 {
0.0
} else {
self.cache_hits as f64 / total as f64
}
}
pub fn average_compilation_time(&self) -> std::time::Duration {
if self.compilations == 0 {
std::time::Duration::ZERO
} else {
self.total_compilation_time / self.compilations as u32
}
}
}
#[derive(Debug, Clone)]
pub struct IntermediateRepresentation {
pub instructions: Vec<IRInstruction>,
pub dependencies: Vec<(usize, usize)>,
pub metadata: HashMap<String, String>,
}
impl IntermediateRepresentation {
pub fn new() -> Self {
Self {
instructions: Vec::new(),
dependencies: Vec::new(),
metadata: HashMap::new(),
}
}
pub fn add_instruction(&mut self, instruction: IRInstruction) {
self.instructions.push(instruction);
}
pub fn add_dependency(&mut self, from: usize, to: usize) {
self.dependencies.push((from, to));
}
}
impl Default for IntermediateRepresentation {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone)]
pub struct IRInstruction {
pub id: usize,
pub opcode: IROpcode,
pub inputs: Vec<Vec<usize>>,
pub outputs: Vec<Vec<usize>>,
pub attributes: HashMap<String, String>,
pub compute_cost: f64,
pub memory_cost: f64,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IROpcode {
Add,
Mul,
Sub,
Div,
MatMul,
ReLU,
Sigmoid,
Tanh,
Softmax,
Linear,
LayerNorm,
Attention,
Embedding,
Conv2D,
Conv3D,
Pool2D,
Pool3D,
Reshape,
Transpose,
Concat,
Split,
If,
While,
Call,
Return,
Load,
Store,
Alloc,
Free,
Custom(String),
}
pub trait JitBackend: Send + Sync {
fn compile_ir(&mut self, ir: IntermediateRepresentation) -> Result<Vec<u8>, TrustformersError>;
fn name(&self) -> &str;
fn supported_targets(&self) -> Vec<String>;
fn optimize_ir(
&self,
ir: IntermediateRepresentation,
) -> Result<IntermediateRepresentation, TrustformersError> {
Ok(ir)
}
}
#[cfg(feature = "llvm")]
pub struct LLVMBackend {
#[allow(dead_code)]
config: CompilerConfig,
}
#[cfg(feature = "llvm")]
impl LLVMBackend {
pub fn new(config: &CompilerConfig) -> Result<Self, TrustformersError> {
Ok(Self {
config: config.clone(),
})
}
}
#[cfg(feature = "llvm")]
impl JitBackend for LLVMBackend {
fn compile_ir(
&mut self,
_ir: IntermediateRepresentation,
) -> Result<Vec<u8>, TrustformersError> {
Ok(vec![0x90, 0xc3]) }
fn name(&self) -> &str {
"LLVM"
}
fn supported_targets(&self) -> Vec<String> {
vec![
"x86_64".to_string(),
"aarch64".to_string(),
"arm".to_string(),
]
}
}
#[cfg(feature = "cranelift")]
pub struct CraneliftBackend {
#[allow(dead_code)]
config: CompilerConfig,
}
#[cfg(feature = "cranelift")]
impl CraneliftBackend {
pub fn new(config: &CompilerConfig) -> Result<Self, TrustformersError> {
Ok(Self {
config: config.clone(),
})
}
}
#[cfg(feature = "cranelift")]
impl JitBackend for CraneliftBackend {
fn compile_ir(
&mut self,
_ir: IntermediateRepresentation,
) -> Result<Vec<u8>, TrustformersError> {
Ok(vec![0x90, 0xc3]) }
fn name(&self) -> &str {
"Cranelift"
}
fn supported_targets(&self) -> Vec<String> {
vec!["x86_64".to_string(), "aarch64".to_string()]
}
}
pub struct InterpreterBackend {
#[allow(dead_code)]
config: CompilerConfig,
}
impl InterpreterBackend {
pub fn new(config: &CompilerConfig) -> Result<Self, TrustformersError> {
Ok(Self {
config: config.clone(),
})
}
}
impl JitBackend for InterpreterBackend {
fn compile_ir(&mut self, ir: IntermediateRepresentation) -> Result<Vec<u8>, TrustformersError> {
let serialized = serde_json::to_vec(&SerializableIR::from(ir))
.map_err(|e| invalid_format("json", e.to_string()))?;
Ok(serialized)
}
fn name(&self) -> &str {
"Interpreter"
}
fn supported_targets(&self) -> Vec<String> {
vec!["any".to_string()]
}
}
#[derive(Debug, Serialize, Deserialize)]
struct SerializableIR {
instructions: Vec<SerializableInstruction>,
dependencies: Vec<(usize, usize)>,
metadata: HashMap<String, String>,
}
#[derive(Debug, Serialize, Deserialize)]
struct SerializableInstruction {
id: usize,
opcode: String,
inputs: Vec<Vec<usize>>,
outputs: Vec<Vec<usize>>,
attributes: HashMap<String, String>,
compute_cost: f64,
memory_cost: f64,
}
impl From<IntermediateRepresentation> for SerializableIR {
fn from(ir: IntermediateRepresentation) -> Self {
let instructions = ir
.instructions
.into_iter()
.map(|inst| SerializableInstruction {
id: inst.id,
opcode: format!("{:?}", inst.opcode),
inputs: inst.inputs,
outputs: inst.outputs,
attributes: inst.attributes,
compute_cost: inst.compute_cost,
memory_cost: inst.memory_cost,
})
.collect();
Self {
instructions,
dependencies: ir.dependencies,
metadata: ir.metadata,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::compiler::{CompilerConfig, ComputationGraph};
#[test]
fn test_jit_compiler_creation() {
let config = CompilerConfig::default();
let result = JitCompiler::new(&config);
assert!(result.is_ok());
}
#[test]
fn test_ir_instruction_creation() {
let instruction = IRInstruction {
id: 0,
opcode: IROpcode::MatMul,
inputs: vec![vec![128, 256], vec![256, 512]],
outputs: vec![vec![128, 512]],
attributes: HashMap::new(),
compute_cost: 100.0,
memory_cost: 50.0,
};
assert_eq!(instruction.opcode, IROpcode::MatMul);
assert_eq!(instruction.inputs.len(), 2);
assert_eq!(instruction.outputs.len(), 1);
}
#[test]
fn test_cache_key_generation() {
let config = CompilerConfig::default();
let compiler = JitCompiler::new(&config).expect("operation failed in test");
let graph = ComputationGraph::new();
let cache_key = compiler.generate_cache_key(&graph);
assert!(cache_key.is_ok());
let key1 = cache_key.expect("operation failed in test");
let key2 = compiler.generate_cache_key(&graph).expect("operation failed in test");
assert_eq!(key1, key2); }
#[test]
fn test_compilation_statistics() {
let mut stats = CompilationStatistics::new();
assert_eq!(stats.cache_hit_rate(), 0.0);
stats.cache_hits = 3;
stats.cache_misses = 7;
assert_eq!(stats.cache_hit_rate(), 0.3);
}
#[test]
fn test_ir_opcodes() {
assert_ne!(IROpcode::Add, IROpcode::Mul);
assert_eq!(IROpcode::ReLU, IROpcode::ReLU);
}
#[test]
fn test_interpreter_backend() {
let config = CompilerConfig::default();
let backend = InterpreterBackend::new(&config);
assert!(backend.is_ok());
let backend = backend.expect("operation failed in test");
assert_eq!(backend.name(), "Interpreter");
assert!(!backend.supported_targets().is_empty());
}
}