#![allow(unused)]
use crate::{
Map, Set, ZyxError,
backend::{BufferId, DeviceId, PoolId, ProgramId},
graph::{Node, search::EGraph},
runtime::Runtime,
shape::{Dim, UAxis},
tensor::TensorId,
};
use std::collections::BTreeMap;
use std::hash::BuildHasherDefault;
/// A compiled form of a graph: a list of [`CompiledNode`] operations plus a table of backend
/// buffer ids.
///
/// NOTE(review): presumably this is the value produced by e-graph extraction and stored in the
/// runtime's graph cache, and [`BufferSlot`] values index into `buffer_slots` — confirm against
/// the code that builds and executes it.
// NOTE(review): `dead_code` is presumably allowed because nothing reads these fields yet — confirm
// and drop the attribute once the graph is actually executed.
#[allow(dead_code)]
pub struct CompiledGraph {
/// Operations that make up the compiled graph.
pub nodes: Vec<CompiledNode>,
/// Backend buffer ids associated with this graph.
pub buffer_slots: Vec<BufferId>,
}
/// Newtype over a `u32` that [`CompiledNode`] variants use to name the buffer they operate on.
///
/// NOTE(review): presumably an index into `CompiledGraph::buffer_slots` — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct BufferSlot(pub u32);
/// One operation stored in a [`CompiledGraph`].
///
/// The descriptions below follow the variant and field names only; the code that executes these
/// operations is not part of this file section, so treat the semantics as assumptions to confirm.
#[derive(Debug, Clone)]
pub enum CompiledNode {
/// Presumably allocates a buffer in a memory pool and binds it to a slot — confirm.
Allocate {
/// Memory pool the operation refers to.
pool: PoolId,
/// Requested size. NOTE(review): unit (bytes vs. elements) is not visible here — confirm.
size: Dim,
/// Slot naming the buffer.
slot: BufferSlot,
},
/// Presumably releases the buffer bound to a slot in a memory pool — confirm.
Deallocate {
/// Memory pool the operation refers to.
pool: PoolId,
/// Slot naming the buffer.
slot: BufferSlot,
},
/// Presumably copies the contents of one buffer into another, possibly across pools — confirm.
CopyMemory {
/// Pool of the source buffer.
src_pool: PoolId,
/// Slot naming the source buffer.
src_buffer: BufferSlot,
/// Pool of the destination buffer.
dst_pool: PoolId,
/// Slot naming the destination buffer.
dst_buffer: BufferSlot,
},
/// Presumably launches a compiled program on a device — confirm.
LaunchProgram {
/// Program the operation refers to.
program: ProgramId,
/// Device the operation refers to.
device: DeviceId,
},
}
/// A compacted copy of a graph, used as the lookup key of the runtime's graph cache.
///
/// `Runtime::launch_or_store_graph_with_order` builds it by walking the evaluation order and
/// renumbering every visited tensor to its position in that order, so `nodes[i]` is the node with
/// the new id `TensorId::from(i)` and all operand ids inside `nodes` use the new numbering.
/// The `Eq` and `Hash` derives are what make the struct usable as a map key.
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub struct CachedGraph {
/// Nodes in evaluation order, with operands rewritten to the compacted ids.
pub nodes: Vec<Node>,
/// Shapes keyed by compacted id. Entries are recorded for nodes whose kind carries a shape;
/// this includes `Leaf`, `Expand`, `Permute`, `Reshape`, `Pad` and `Reduce` nodes.
pub shapes: BTreeMap<TensorId, Box<[Dim]>>,
/// Padding keyed by compacted id, recorded for `Pad` nodes.
/// NOTE(review): presumably one `(left, right)` pair per dimension — confirm.
pub paddings: BTreeMap<TensorId, Box<[(i64, i64)]>>,
/// Axes keyed by compacted id, recorded for `Permute` and `Reduce` nodes.
pub axes: BTreeMap<TensorId, Box<[UAxis]>>,
}
impl Runtime {
    /// Compacts the scheduled part of the graph into a [`CachedGraph`] key and makes sure a
    /// compiled graph for that key is present in `self.graph_cache`.
    ///
    /// Every tensor in `order` is renumbered to its position in `order`, and operands are
    /// rewritten through that renumbering. Tensors listed in `realized_nodes` are stored as
    /// `Node::Leaf` with their dtype instead of their original operation. Shapes, paddings and
    /// axes are copied from `self.graph` for the node kinds that need them.
    ///
    /// On a cache hit the function returns immediately. On a miss it builds an [`EGraph`] from
    /// the key, saturates it, extracts the result and stores it under the key. Despite the name,
    /// nothing is launched here yet.
    ///
    /// # Parameters
    /// * `_rcs` - currently unused.
    /// * `realized_nodes` - tensors that are stored in the key as leaves.
    /// * `order` - tensors to visit; every operand of a non-realized node must appear earlier in
    ///   this slice than the node that uses it.
    /// * `_to_eval` - currently unused.
    ///
    /// # Errors
    /// Currently never returns an error; the `Result` is kept for the callers' signature.
    ///
    /// # Panics
    /// * If a non-realized node is `Node::Custom` (`todo!()`).
    /// * If an operand of a non-realized node has not been visited earlier in `order`
    ///   (indexing into the id map).
    #[allow(clippy::unnecessary_wraps)]
    pub(crate) fn launch_or_store_graph_with_order(
        &mut self,
        _rcs: &Map<TensorId, u32>,
        realized_nodes: &Set<TensorId>,
        order: &[TensorId],
        _to_eval: &Set<TensorId>,
    ) -> Result<(), ZyxError> {
        let mut compacted = CachedGraph {
            nodes: Vec::with_capacity(order.len()),
            shapes: BTreeMap::new(),
            paddings: BTreeMap::new(),
            axes: BTreeMap::new(),
        };
        // Maps original tensor ids to their dense position in `order`.
        let mut id_map: Map<TensorId, TensorId> =
            Map::with_capacity_and_hasher(order.len(), BuildHasherDefault::new());

        for (i, &nid) in order.iter().enumerate() {
            let new_id = TensorId::from(i);
            let node = &self.graph[nid];
            let is_realized = realized_nodes.contains(&nid);

            let reindexed = if is_realized {
                // A realized tensor is an input of the compacted graph regardless of which
                // operation originally produced it.
                Node::Leaf { dtype: self.graph.dtype(nid) }
            } else {
                match node {
                    Node::Const { value } => Node::Const { value: *value },
                    Node::Leaf { dtype } => Node::Leaf { dtype: *dtype },
                    Node::Expand { x } => Node::Expand { x: id_map[x] },
                    Node::Permute { x } => Node::Permute { x: id_map[x] },
                    Node::Reshape { x } => Node::Reshape { x: id_map[x] },
                    Node::Pad { x } => Node::Pad { x: id_map[x] },
                    Node::Reduce { x, rop } => Node::Reduce { x: id_map[x], rop: *rop },
                    Node::Cast { x, dtype } => Node::Cast { x: id_map[x], dtype: *dtype },
                    Node::Unary { x, uop } => Node::Unary { x: id_map[x], uop: *uop },
                    Node::Binary { x, y, bop } => {
                        Node::Binary { x: id_map[x], y: id_map[y], bop: *bop }
                    }
                    Node::Custom { .. } => todo!(),
                }
            };
            compacted.nodes.push(reindexed);
            id_map.insert(nid, new_id);

            // Realized nodes are stored as leaves, so they need a shape even when the original
            // node kind (e.g. `Unary`, `Binary`) is not one that normally gets one. Without it,
            // two graphs differing only in the shape of such an input would share a cache key.
            if is_realized
                || matches!(
                    node,
                    Node::Leaf { .. }
                        | Node::Expand { .. }
                        | Node::Permute { .. }
                        | Node::Reshape { .. }
                        | Node::Pad { .. }
                        | Node::Reduce { .. }
                )
            {
                compacted.shapes.insert(new_id, self.graph.shape(nid).into());
            }
            // NOTE(review): paddings and axes are still keyed off the original node kind, so a
            // realized `Pad`/`Permute`/`Reduce` keeps its entry even though it is stored as a
            // leaf. This only makes the key more specific; confirm whether it is intended.
            if matches!(node, Node::Pad { .. }) {
                compacted.paddings.insert(new_id, self.graph.padding(nid).into());
            }
            if matches!(node, Node::Permute { .. } | Node::Reduce { .. }) {
                compacted.axes.insert(new_id, self.graph.axes(nid).into());
            }
        }

        // Already compiled: nothing more to do.
        if self.graph_cache.get(&compacted).is_some() {
            return Ok(());
        }

        // The e-graph lives in its own scope so that its borrow of `compacted` has ended before
        // the key is moved into the cache, which avoids cloning the whole key.
        let compiled_graph = {
            let mut egraph = EGraph::new(&compacted);
            egraph.saturate();
            egraph.extract()
        };
        self.graph_cache.insert(compacted, compiled_graph);
        Ok(())
    }
}