use crate::backend::cpu::CpuScalarBackend;
use crate::backend::{Backend, TensorStore};
use crate::ir::SemanticGraph;
use crate::object::{ObjectMeta, Tensor};
use crate::op::Operator;
use crate::planner::ExecutionPlan;
use crate::planner::HeuristicPlanner;
use crate::{Error, Result};
/// Lightweight, copyable reference to a value in a graph under construction.
///
/// A handle only wraps a numeric id; it carries no tensor data itself.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TensorHandle {
    /// Numeric id wrapped by this handle.
    id: usize,
}

impl TensorHandle {
    /// Returns the numeric id wrapped by this handle.
    pub fn id(&self) -> usize {
        // `TensorHandle` is `Copy`, so the field can be copied out of the borrow.
        let Self { id } = *self;
        id
    }
}
/// Incrementally assembles a [`SemanticGraph`] together with the `i64` tensor
/// values registered for its inputs.
pub struct GraphBuilder {
/// Graph that inputs and operator nodes are added to.
graph: SemanticGraph,
/// Tensor values, keyed by the id the graph assigned when they were registered.
store: TensorStore<i64>,
}
impl GraphBuilder {
pub fn input(&mut self, meta: ObjectMeta) -> Result<TensorHandle> {
let id = self.graph.add_input(meta);
Ok(TensorHandle { id })
}
pub fn input_tensor(&mut self, t: Tensor<i64>) -> Result<TensorHandle> {
let id = self.graph.add_input(t.meta.clone());
self.store.insert(id, t);
Ok(TensorHandle { id })
}
pub fn op<O: Operator>(&mut self, op: O, inputs: &[TensorHandle]) -> Result<TensorHandle> {
let ids: Vec<usize> = inputs.iter().map(|h| h.id).collect();
let outputs = self.graph.add_op(op, &ids)?;
if outputs.len() != 1 {
return Err(Error::ir(format!(
"DSL `op` requires exactly one output, got {}",
outputs.len()
)));
}
Ok(TensorHandle { id: outputs[0] })
}
pub fn graph(&self) -> &SemanticGraph {
&self.graph
}
pub fn store(&self) -> &TensorStore<i64> {
&self.store
}
pub fn store_mut(&mut self) -> &mut TensorStore<i64> {
&mut self.store
}
pub fn add(&mut self, lhs: TensorHandle, rhs: TensorHandle) -> Result<TensorHandle> {
self.op(crate::op::AddOp, &[lhs, rhs])
}
pub fn sub(&mut self, lhs: TensorHandle, rhs: TensorHandle) -> Result<TensorHandle> {
self.op(crate::op::SubOp, &[lhs, rhs])
}
pub fn mul(&mut self, lhs: TensorHandle, rhs: TensorHandle) -> Result<TensorHandle> {
self.op(crate::op::MulOp, &[lhs, rhs])
}
pub fn div(&mut self, lhs: TensorHandle, rhs: TensorHandle) -> Result<TensorHandle> {
self.op(crate::op::DivOp, &[lhs, rhs])
}
pub fn reshape(&mut self, input: TensorHandle, target: TensorHandle) -> Result<TensorHandle> {
self.op(crate::op::ReshapeOp, &[input, target])
}
pub fn transpose(&mut self, input: TensorHandle, axes: TensorHandle) -> Result<TensorHandle> {
self.op(crate::op::TransposeOp, &[input, axes])
}
pub fn flatten(&mut self, input: TensorHandle) -> Result<TensorHandle> {
self.op(crate::op::FlattenOp, &[input])
}
pub fn relu(&mut self, input: TensorHandle) -> Result<TensorHandle> {
self.op(crate::op::ReluOp, &[input])
}
pub fn gelu(&mut self, input: TensorHandle) -> Result<TensorHandle> {
self.op(crate::op::GeluOp, &[input])
}
pub fn softmax(&mut self, input: TensorHandle) -> Result<TensorHandle> {
self.op(crate::op::SoftmaxOp, &[input])
}
pub fn layer_norm(
&mut self,
input: TensorHandle,
gamma: TensorHandle,
beta: TensorHandle,
) -> Result<TensorHandle> {
self.op(crate::op::LayerNormOp, &[input, gamma, beta])
}
pub fn gather(
&mut self,
input: TensorHandle,
indices: TensorHandle,
axis: TensorHandle,
) -> Result<TensorHandle> {
self.op(crate::op::GatherOp, &[input, indices, axis])
}
pub fn scatter(
&mut self,
input: TensorHandle,
indices: TensorHandle,
axis: TensorHandle,
) -> Result<TensorHandle> {
self.op(crate::op::ScatterOp, &[input, indices, axis])
}
pub fn sum(&mut self, input: TensorHandle, axis: TensorHandle) -> Result<TensorHandle> {
self.op(crate::op::SumOp, &[input, axis])
}
pub fn mean(&mut self, input: TensorHandle, axis: TensorHandle) -> Result<TensorHandle> {
self.op(crate::op::MeanOp, &[input, axis])
}
pub fn max(&mut self, input: TensorHandle, axis: TensorHandle) -> Result<TensorHandle> {
self.op(crate::op::MaxOp, &[input, axis])
}
pub fn min(&mut self, input: TensorHandle, axis: TensorHandle) -> Result<TensorHandle> {
self.op(crate::op::MinOp, &[input, axis])
}
pub fn matmul(&mut self, lhs: TensorHandle, rhs: TensorHandle) -> Result<TensorHandle> {
self.op(crate::op::MatmulOp, &[lhs, rhs])
}
}
/// Result bundle produced by `run_graph`: the graph that was built, the tensor
/// store after the backend executed the plan, the plan itself, and the id of
/// the handle the build closure returned.
pub struct CompiledGraph {
/// Graph assembled by the build closure.
graph: SemanticGraph,
/// Tensor store that was passed to the backend for execution.
store: TensorStore<i64>,
/// Execution plan the planner produced for `graph`.
plan: ExecutionPlan,
/// Id of the handle returned by the build closure; looked up in `store`
/// by `output_tensor`.
output: usize,
}
impl CompiledGraph {
    /// Borrows the graph that was planned and executed.
    pub fn graph(&self) -> &SemanticGraph {
        let Self { graph, .. } = self;
        graph
    }

    /// Borrows the execution plan.
    pub fn plan(&self) -> &ExecutionPlan {
        let Self { plan, .. } = self;
        plan
    }

    /// Borrows the tensor store.
    pub fn store(&self) -> &TensorStore<i64> {
        let Self { store, .. } = self;
        store
    }

    /// Looks up the output tensor in the store by its id.
    ///
    /// # Errors
    ///
    /// Returns whatever error the store's `get` reports for that id.
    pub fn output_tensor(&self) -> Result<&Tensor<i64>> {
        let id = self.output_id();
        self.store().get(id)
    }

    /// Returns the id of the output value.
    pub fn output_id(&self) -> usize {
        // Only the `Copy` field is bound, so nothing is moved out of `self`.
        let Self { output, .. } = *self;
        output
    }
}
pub fn run_graph<F>(backend: &CpuScalarBackend, f: F) -> Result<CompiledGraph>
where
F: FnOnce(&mut GraphBuilder) -> Result<TensorHandle>,
{
let mut builder = GraphBuilder {
graph: SemanticGraph::new(),
store: TensorStore::new(),
};
let out = f(&mut builder)?;
let plan = HeuristicPlanner::new(backend.capabilities()).plan(&builder.graph)?;
backend.execute_i64(&builder.graph, &plan, &mut builder.store)?;
Ok(CompiledGraph {
graph: builder.graph,
store: builder.store,
plan,
output: out.id,
})
}