#![allow(unused)]
use crate::{
DType, Map, Set, ZyxError,
backend::{BufferId, DeviceId, PoolId, ProgramId},
graph::{Node, search::EGraph},
runtime::Runtime,
shape::{Dim, UAxis},
tensor::TensorId,
};
use std::collections::BTreeMap;
use std::hash::BuildHasherDefault;
/// A compiled graph: a flat list of [`CompiledNode`] steps plus a table of backend buffer ids.
///
/// NOTE(review): `buffer_slots` is presumably the table that [`BufferSlot`] values index
/// into; confirm against the code that builds and executes this struct.
#[allow(dead_code)]
pub struct CompiledGraph {
/// Steps of the compiled graph, kept in a vector (presumably in execution order; confirm).
pub nodes: Vec<CompiledNode>,
/// Backend buffer ids associated with this compiled graph.
pub buffer_slots: Vec<BufferId>,
}
/// Newtype around a `u32` that names a buffer inside [`CompiledNode`] steps.
///
/// NOTE(review): presumably an index into `CompiledGraph::buffer_slots`; confirm with the
/// code that resolves slots to `BufferId`s.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct BufferSlot(pub u32);
/// One step stored in `CompiledGraph::nodes`.
///
/// The per-variant descriptions below are inferred from variant and field names only; the
/// code that executes these steps is not part of this file section, so confirm details there.
#[derive(Debug, Clone)]
pub enum CompiledNode {
/// Allocation step: a memory pool, a size and the slot the buffer is tracked under.
Allocate {
pool: PoolId,
// NOTE(review): unit of `size` (elements vs. bytes) is not visible here; confirm.
size: Dim,
slot: BufferSlot,
},
/// Deallocation step: the slot to release and the pool it belongs to.
Deallocate {
pool: PoolId,
slot: BufferSlot,
},
/// Copy step between a source (pool, slot) pair and a destination (pool, slot) pair.
CopyMemory {
src_pool: PoolId,
src_buffer: BufferSlot,
dst_pool: PoolId,
dst_buffer: BufferSlot,
},
/// Program launch step: which program to run and on which device.
LaunchProgram {
program: ProgramId,
device: DeviceId,
},
}
/// Self-contained, densely indexed copy of a graph fragment.
///
/// Node parameters and the keys of the side tables are indices into `nodes`. The type derives
/// `Hash` and `Eq` and is used as the lookup key of the runtime's `graph_cache`, so every field
/// participates in deciding whether two graphs are "the same".
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub struct CachedGraph {
/// Nodes of the graph; a node's position in this vector is its id.
pub nodes: Vec<Node>,
/// Shapes keyed by node index. Only some nodes have an entry; `CachedGraph::shape` falls
/// back to walking a node's first parameter when there is none.
pub shapes: BTreeMap<usize, Box<[Dim]>>,
/// Paddings keyed by node index, stored as pairs of `i64`
/// (presumably (left, right) per dimension; confirm against `Graph::padding`).
pub paddings: BTreeMap<usize, Box<[(i64, i64)]>>,
/// Axes keyed by node index.
pub axes: BTreeMap<usize, Box<[UAxis]>>,
}
impl CachedGraph {
    /// Returns the shape of the node at index `tensor_id`.
    ///
    /// A node without an entry in `shapes` takes the shape of its first parameter, so the
    /// lookup follows `param1()` links until it reaches a node with a recorded shape. A
    /// `Node::Const` without a recorded shape is reported as `[1]`.
    ///
    /// # Panics
    ///
    /// Panics if a visited index is out of bounds for `nodes`, or if no shape is found after
    /// visiting more nodes than the graph contains (which can only happen if the parameter
    /// links form a cycle).
    pub(super) fn shape(&self, tensor_id: usize) -> &[Dim] {
        // Walk with a separate cursor so the panic message can name the node that was
        // actually requested rather than wherever the walk stopped.
        let mut current = tensor_id;
        // A parameter chain cannot visit more distinct nodes than the graph holds; the
        // inclusive bound guarantees at least one lookup.
        for _ in 0..=self.nodes.len() {
            if let Some(shape) = self.shapes.get(&current) {
                return shape;
            }
            if let Node::Const { .. } = self.nodes[current] {
                return &[1];
            }
            current = self.nodes[current].param1().into();
        }
        panic!("Shape of {tensor_id:?} could not be found. This is internal bug.")
    }

    /// Returns the dtype of the node at index `tensor_id`.
    ///
    /// `Const`, `Leaf` and `Cast` nodes carry their dtype directly, and a `Binary` node whose
    /// operation `returns_bool()` yields `DType::Bool`. Every other node takes the dtype of
    /// its first parameter, so the lookup walks parameter links until one of those is found.
    ///
    /// # Panics
    ///
    /// Panics if a visited index is out of bounds for `nodes`, if a visited node has no
    /// parameters, or if no dtype is found after visiting more nodes than the graph contains.
    pub(super) fn dtype(&self, tensor_id: usize) -> DType {
        // Separate cursor: keep `tensor_id` intact for the panic message.
        let mut current = tensor_id;
        for _ in 0..=self.nodes.len() {
            match self.nodes[current] {
                Node::Const { value } => return value.dtype(),
                Node::Leaf { dtype } | Node::Cast { dtype, .. } => return dtype,
                Node::Binary { bop, .. } if bop.returns_bool() => return DType::Bool,
                _ => {
                    current = self.nodes[current].parameters().next().unwrap().into();
                }
            }
        }
        panic!("DType of {tensor_id:?} could not be found. This is internal bug.")
    }
}
impl Runtime {
    /// Builds a compacted [`CachedGraph`] for the nodes in `order` and makes sure a compiled
    /// form of it is present in `self.graph_cache`.
    ///
    /// Nodes are copied in the sequence given by `order` and renumbered `0..order.len()`.
    /// Every node in `realized_nodes` is emitted as a `Node::Leaf` carrying its dtype,
    /// regardless of what it was in the runtime graph. If the resulting graph is not in the
    /// cache yet, it is run through an [`EGraph`] (`saturate` then `extract`) and the result
    /// is stored under the compacted graph as key.
    ///
    /// NOTE(review): nothing is launched here yet. On both a cache hit and a fresh compile the
    /// function just returns `Ok(())`; `_rcs` and `_to_eval` are currently unused.
    ///
    /// # Errors
    ///
    /// Currently never returns an error.
    ///
    /// # Panics
    ///
    /// Panics if a node refers to a parameter that does not appear earlier in `order`, and on
    /// `Node::Custom` (not implemented).
    #[allow(clippy::unnecessary_wraps)]
    pub(crate) fn launch_or_store_graph_with_order(
        &mut self,
        _rcs: &Map<TensorId, u32>,
        realized_nodes: &Set<TensorId>,
        order: &[TensorId],
        _to_eval: &Set<TensorId>,
    ) -> Result<(), ZyxError> {
        let mut compacted = CachedGraph {
            nodes: Vec::with_capacity(order.len()),
            shapes: BTreeMap::new(),
            paddings: BTreeMap::new(),
            axes: BTreeMap::new(),
        };
        // Maps runtime tensor ids to their dense index in `compacted.nodes`.
        let mut id_map: Map<TensorId, usize> =
            Map::with_capacity_and_hasher(order.len(), BuildHasherDefault::new());

        for (new_id, &nid) in order.iter().enumerate() {
            let node = &self.graph[nid];
            let reindexed = if realized_nodes.contains(&nid) {
                // Already realized: its inputs are irrelevant, treat it as an input buffer.
                Node::Leaf { dtype: self.graph.dtype(nid) }
            } else {
                match node {
                    Node::Const { value } => Node::Const { value: *value },
                    Node::Leaf { dtype } => Node::Leaf { dtype: *dtype },
                    Node::Expand { x } => Node::Expand { x: id_map[x].into() },
                    Node::Permute { x } => Node::Permute { x: id_map[x].into() },
                    Node::Reshape { x } => Node::Reshape { x: id_map[x].into() },
                    Node::Pad { x } => Node::Pad { x: id_map[x].into() },
                    Node::Reduce { x, rop } => Node::Reduce { x: id_map[x].into(), rop: *rop },
                    Node::Cast { x, dtype } => Node::Cast { x: id_map[x].into(), dtype: *dtype },
                    Node::Unary { x, uop } => Node::Unary { x: id_map[x].into(), uop: *uop },
                    Node::Binary { x, y, bop } => {
                        Node::Binary { x: id_map[x].into(), y: id_map[y].into(), bop: *bop }
                    }
                    Node::Custom { .. } => todo!(),
                }
            };

            // Decide which metadata to record from the node that is actually emitted, not
            // from the original one. A realized node is emitted as a `Leaf` and therefore
            // always needs its own shape (it has no parameter to inherit one from), and it
            // must not drag along padding/axes of the operation it used to be; otherwise the
            // cache key would either miss the leaf's shape or contain stale data.
            let records_shape = matches!(
                reindexed,
                Node::Leaf { .. }
                    | Node::Expand { .. }
                    | Node::Permute { .. }
                    | Node::Reshape { .. }
                    | Node::Pad { .. }
                    | Node::Reduce { .. }
            );
            let records_padding = matches!(reindexed, Node::Pad { .. });
            let records_axes = matches!(reindexed, Node::Permute { .. } | Node::Reduce { .. });

            compacted.nodes.push(reindexed);
            id_map.insert(nid, new_id);

            if records_shape {
                compacted.shapes.insert(new_id, self.graph.shape(nid).into());
            }
            if records_padding {
                compacted.paddings.insert(new_id, self.graph.padding(nid).into());
            }
            if records_axes {
                compacted.axes.insert(new_id, self.graph.axes(nid).into());
            }
        }

        // Compile only on a cache miss.
        if self.graph_cache.get(&compacted).is_none() {
            let mut egraph = EGraph::new(&compacted);
            egraph.saturate();
            let compiled_graph = egraph.extract();
            // `compacted` is cloned for the key because `egraph` was built from a borrow of it.
            self.graph_cache.insert(compacted.clone(), compiled_graph);
        }

        Ok(())
    }
}