use std::cell::RefCell;
use std::collections::HashMap;
use crate::tensor::{DType, StorageHandle, Tensor};
/// Index identifying one value inside a recorded trace. `JitTracer` hands
/// these out sequentially, starting at 0.
pub type VarId = usize;
/// One entry of a recorded op sequence.
///
/// Every `VarId` payload names a value of the trace.
/// NOTE(review): the tracer in this file calls the caller-supplied
/// constructor as `(dst, src)` for unary ops and `(dst, lhs, rhs)` for binary
/// ops, so the first id is presumably the destination — confirm at the call
/// sites that build these variants.
#[derive(Debug, Clone)]
pub enum FusedOp {
    /// A value that enters the trace from outside.
    Input(VarId),

    // --- binary ops ---
    /// Addition of two traced values.
    Add(VarId, VarId, VarId),
    /// Subtraction of two traced values.
    Sub(VarId, VarId, VarId),
    /// Multiplication of two traced values.
    Mul(VarId, VarId, VarId),

    // --- unary ops ---
    /// ReLU activation.
    Relu(VarId, VarId),
    /// Sigmoid activation.
    Sigmoid(VarId, VarId),
    /// Tanh activation.
    Tanh(VarId, VarId),
    /// GELU activation.
    Gelu(VarId, VarId),

    // --- unary ops with an `f32` parameter ---
    /// Leaky ReLU; the `f32` is presumably the negative-side slope (confirm).
    LeakyRelu(VarId, VarId, f32),
    /// Scaling; the `f32` is presumably the multiplier (confirm).
    Scale(VarId, VarId, f32),
    /// Negation.
    Neg(VarId, VarId),

    /// Marks a traced value as a result of the trace.
    Output(VarId),
}
/// Operand-free counterpart of `FusedOp` that can be hashed and compared.
///
/// Variants mirror `FusedOp` one-to-one. Where the op carries an `f32`
/// parameter, the tag stores that parameter's raw bit pattern as a `u32`
/// (see `FusedOp::tag`), which is what lets the type derive `Eq` and `Hash`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum FusedOpTag {
    /// Tag of `FusedOp::Input`.
    Input,
    /// Tag of `FusedOp::Add`.
    Add,
    /// Tag of `FusedOp::Sub`.
    Sub,
    /// Tag of `FusedOp::Mul`.
    Mul,
    /// Tag of `FusedOp::Relu`.
    Relu,
    /// Tag of `FusedOp::Sigmoid`.
    Sigmoid,
    /// Tag of `FusedOp::Tanh`.
    Tanh,
    /// Tag of `FusedOp::Gelu`.
    Gelu,
    /// Tag of `FusedOp::LeakyRelu`; payload is the parameter's `f32` bits.
    LeakyRelu(u32),
    /// Tag of `FusedOp::Scale`; payload is the parameter's `f32` bits.
    Scale(u32),
    /// Tag of `FusedOp::Neg`.
    Neg,
    /// Tag of `FusedOp::Output`.
    Output,
}
impl FusedOp {
pub fn tag(&self) -> FusedOpTag {
match self {
FusedOp::Input(_) => FusedOpTag::Input,
FusedOp::Add(..) => FusedOpTag::Add,
FusedOp::Sub(..) => FusedOpTag::Sub,
FusedOp::Mul(..) => FusedOpTag::Mul,
FusedOp::Relu(..) => FusedOpTag::Relu,
FusedOp::Sigmoid(..) => FusedOpTag::Sigmoid,
FusedOp::Tanh(..) => FusedOpTag::Tanh,
FusedOp::Gelu(..) => FusedOpTag::Gelu,
FusedOp::LeakyRelu(_, _, a) => FusedOpTag::LeakyRelu(a.to_bits()),
FusedOp::Scale(_, _, s) => FusedOpTag::Scale(s.to_bits()),
FusedOp::Neg(..) => FusedOpTag::Neg,
FusedOp::Output(_) => FusedOpTag::Output,
}
}
}
/// A finished trace, as produced by `JitTracer::into_block`.
///
/// The descriptions below reflect how the tracer in this file fills the
/// fields; all of them are public, so other code may build a block
/// differently.
pub struct FusionBlock {
    /// Recorded ops, in the order they were traced.
    pub ops: Vec<FusedOp>,
    /// Number of inputs the tracer registered.
    pub num_inputs: usize,
    /// Number of values marked as outputs.
    pub num_outputs: usize,
    /// Element count taken from the first registered input.
    pub numel: usize,
    /// Dtype taken from the first registered input.
    pub dtype: DType,
    /// Storage of each registered input, in registration order.
    pub input_storages: Vec<StorageHandle>,
    /// Storage of each marked output, in marking order.
    pub output_storages: Vec<StorageHandle>,
}
/// Accumulates ops while a trace is being recorded.
pub struct JitTracer {
    /// Ops recorded so far, in trace order.
    ops: Vec<FusedOp>,
    /// Maps a storage's `ptr_id()` to the variable that represents it.
    tensor_map: HashMap<usize, VarId>,
    /// Storage of every registered input, in registration order.
    input_storages: Vec<StorageHandle>,
    /// Storage of every value marked as output, in marking order.
    output_storages: Vec<StorageHandle>,
    /// Id the next call to `fresh_var` will return.
    next_var: VarId,
    /// How many inputs have been registered.
    num_inputs: usize,
    /// How many outputs have been marked.
    num_outputs: usize,
    /// Element count of the first registered input (0 until one exists).
    numel: usize,
    /// Dtype of the first registered input (`DType::F32` until one exists).
    dtype: DType,
}
impl JitTracer {
    /// Builds a tracer with nothing recorded: no ops, no known storages,
    /// variable counter at 0, `numel` 0 and dtype `F32` as placeholders.
    fn new() -> Self {
        Self {
            next_var: 0,
            num_inputs: 0,
            num_outputs: 0,
            numel: 0,
            dtype: DType::F32,
            ops: Vec::new(),
            tensor_map: HashMap::new(),
            input_storages: Vec::new(),
            output_storages: Vec::new(),
        }
    }

    /// Hands out the next unused variable id.
    fn fresh_var(&mut self) -> VarId {
        let id = self.next_var;
        self.next_var = id + 1;
        id
    }

    /// Returns the variable bound to `tensor`'s storage, registering the
    /// tensor as a new trace input if its storage has not been seen yet.
    ///
    /// Registration records a `FusedOp::Input`, keeps a clone of the storage
    /// handle, and — for the very first input only — adopts the tensor's
    /// `numel()` and `dtype()` as the trace-wide element count and dtype.
    pub fn get_or_create_input(&mut self, tensor: &Tensor) -> VarId {
        let key = tensor.storage.ptr_id();
        if let Some(known) = self.tensor_map.get(&key).copied() {
            return known;
        }

        let id = self.fresh_var();
        self.tensor_map.insert(key, id);
        self.ops.push(FusedOp::Input(id));
        self.input_storages.push(tensor.storage.clone());

        let is_first_input = self.num_inputs == 0;
        self.num_inputs += 1;
        if is_first_input {
            self.numel = tensor.numel();
            self.dtype = tensor.dtype();
        }
        id
    }

    /// Records a one-operand op.
    ///
    /// `src` is resolved to a variable (becoming an input if unknown), a
    /// fresh destination variable is allocated, and `make_op(dst, src)`
    /// supplies the op to append. Returns the destination variable together
    /// with a deferred storage handle created for it.
    pub fn record_unary(
        &mut self,
        src: &Tensor,
        make_op: impl FnOnce(VarId, VarId) -> FusedOp,
    ) -> (VarId, StorageHandle) {
        let operand = self.resolve_var(src);
        let result = self.fresh_var();
        let op = make_op(result, operand);
        self.ops.push(op);
        self.bind_deferred(result)
    }

    /// Records a two-operand op.
    ///
    /// `lhs` is resolved before `rhs`; then a fresh destination variable is
    /// allocated and `make_op(dst, lhs, rhs)` supplies the op to append.
    /// Returns the destination variable together with a deferred storage
    /// handle created for it.
    pub fn record_binary(
        &mut self,
        lhs: &Tensor,
        rhs: &Tensor,
        make_op: impl FnOnce(VarId, VarId, VarId) -> FusedOp,
    ) -> (VarId, StorageHandle) {
        let left = self.resolve_var(lhs);
        let right = self.resolve_var(rhs);
        let result = self.fresh_var();
        let op = make_op(result, left, right);
        self.ops.push(op);
        self.bind_deferred(result)
    }

    /// Creates the deferred storage for `dst` (using the trace-wide element
    /// count and dtype) and remembers which variable it belongs to.
    fn bind_deferred(&mut self, dst: VarId) -> (VarId, StorageHandle) {
        let handle = StorageHandle::new_deferred(dst, self.numel, self.dtype);
        self.tensor_map.insert(handle.ptr_id(), dst);
        (dst, handle)
    }

    /// Looks up the variable for `tensor`'s storage, falling back to
    /// registering the tensor as an input when the storage is unknown.
    fn resolve_var(&mut self, tensor: &Tensor) -> VarId {
        let key = tensor.storage.ptr_id();
        match self.tensor_map.get(&key).copied() {
            Some(known) => known,
            None => self.get_or_create_input(tensor),
        }
    }

    /// Appends a `FusedOp::Output` for `var_id` and keeps a clone of
    /// `storage` as the corresponding output storage.
    pub fn mark_output(&mut self, var_id: VarId, storage: &StorageHandle) {
        let marker = FusedOp::Output(var_id);
        self.ops.push(marker);
        self.output_storages.push(storage.clone());
        self.num_outputs += 1;
    }

    /// Consumes the tracer and packages what it recorded as a `FusionBlock`.
    /// The storage-to-variable map and the variable counter are discarded.
    pub fn into_block(self) -> FusionBlock {
        let Self {
            ops,
            input_storages,
            output_storages,
            num_inputs,
            num_outputs,
            numel,
            dtype,
            ..
        } = self;

        FusionBlock {
            ops,
            num_inputs,
            num_outputs,
            numel,
            dtype,
            input_storages,
            output_storages,
        }
    }
}
// Per-thread slot holding the tracer that is currently recording.
// It is `None` whenever no trace is in progress on this thread; `compile`
// installs a tracer here and removes it again once the traced closure returns.
thread_local! {
static JIT_TRACER: RefCell<Option<JitTracer>> = RefCell::new(None);
}
/// Reports whether a tracer is installed on the current thread.
///
/// # Panics
///
/// Panics if the slot is mutably borrowed at the time of the call, which is
/// the case while a closure passed to `with_tracer` is running.
pub fn is_tracing() -> bool {
    JIT_TRACER.with(|slot| {
        let current = slot.borrow();
        current.is_some()
    })
}
/// Runs `f` with mutable access to this thread's active tracer and returns
/// whatever `f` returns.
///
/// The slot stays mutably borrowed for the whole duration of `f`.
///
/// # Panics
///
/// Panics with "JIT tracer not active" when no tracer is installed, and
/// with a `RefCell` borrow error when the slot is already borrowed (for
/// example on a re-entrant call from inside `f`).
pub fn with_tracer<F, R>(f: F) -> R
where
    F: FnOnce(&mut JitTracer) -> R,
{
    JIT_TRACER.with(|slot| {
        let mut guard = slot.borrow_mut();
        let active = guard.as_mut().expect("JIT tracer not active");
        f(active)
    })
}
/// Records the ops performed by `f` on this thread, flushes them as a single
/// `FusionBlock`, and returns the tensor `f` produced.
///
/// A fresh tracer is installed in the thread-local slot before `f` runs
/// (replacing whatever was there) and removed once `f` returns. The returned
/// tensor is resolved to a trace variable — if its storage is unknown to the
/// tracer it is registered as an input — and marked as the block's output.
/// The finished block is then handed to `super::flush_block`.
///
/// If `f` panics, the tracer is removed from the slot while unwinding, so a
/// caller that catches the panic does not keep recording into an abandoned
/// trace.
///
/// # Panics
///
/// * Propagates any panic raised by `f`.
/// * Panics with "JIT tracer lost" if the slot is empty when `f` returns.
///   A nested `compile` call inside `f` causes this: the inner call replaces
///   the outer tracer and leaves the slot empty when it finishes.
pub fn compile<F>(f: F) -> Tensor
where
    F: FnOnce() -> Tensor,
{
    /// Empties the thread's tracer slot when dropped.
    struct ClearTracerOnDrop;

    impl Drop for ClearTracerOnDrop {
        fn drop(&mut self) {
            // This can run during unwinding, so use only the non-panicking
            // accessors: a second panic here would abort the process.
            let stale = JIT_TRACER
                .try_with(|slot| slot.try_borrow_mut().ok().and_then(|mut active| active.take()))
                .ok()
                .flatten();
            // Drop the abandoned tracer only after the slot borrow is released.
            drop(stale);
        }
    }

    JIT_TRACER.with(|slot| *slot.borrow_mut() = Some(JitTracer::new()));
    let reset = ClearTracerOnDrop;

    let result = f();

    let mut tracer = JIT_TRACER.with(|slot| slot.borrow_mut().take().expect("JIT tracer lost"));
    // The slot is empty again; the guard has nothing left to protect.
    drop(reset);

    let result_var = tracer.resolve_var(&result);
    tracer.mark_output(result_var, &result.storage);
    super::flush_block(tracer.into_block());
    result
}