#![allow(clippy::needless_range_loop)]
use crate::{
compiler_utils::Compiler,
graph_tensor::GraphTensor,
op::{self, InputTensor, Operator},
shape::*,
tensor::Tensor,
};
use std::io::Write;
use colored::Colorize;
use itertools::Itertools;
use petgraph::{stable_graph::StableGraph, visit::EdgeRef, Direction};
use rustc_hash::{FxHashMap, FxHashSet};
use super::compiler_utils::{ToIds, ToIdsMut};
pub type MainGraph = StableGraph<Box<dyn Operator>, Dependency>;
pub use petgraph::stable_graph::NodeIndex;
/// A computation graph together with the tensor values produced while running it.
#[derive(Debug, Default)]
pub struct Graph {
/// Tensor values currently held, keyed by `(node, output index)`.
pub tensors: rustc_hash::FxHashMap<(NodeIndex, u8), Tensor>,
/// Concrete sizes for dynamic dimensions, keyed by the dimension's `char` name
/// (filled by `set_dyn_dim`, handed to the shape trackers during execution).
pub dyn_map: rustc_hash::FxHashMap<char, usize>,
/// The operator graph itself: boxed operators as nodes, [`Dependency`] as edge weights.
pub graph: MainGraph,
/// Nodes whose tensors survive `reset` and are never handed to an operator as owned input.
pub no_delete: rustc_hash::FxHashSet<NodeIndex>,
/// Nodes registered through `retrieve_tensors`.
/// NOTE(review): not read anywhere in this part of the file — presumably consumed by compilers; confirm.
pub to_retrieve: rustc_hash::FxHashSet<NodeIndex>,
/// Cached topological order built by `toposort`: each node paired with its data inputs,
/// given as `((source node, source output index), shape)` and sorted by input position.
#[allow(clippy::type_complexity)]
pub(crate) linearized_graph: Option<Vec<(NodeIndex, Vec<((NodeIndex, u8), ShapeTracker)>)>>,
/// Cached number of outgoing data edges for every `(node, output index)`; built alongside
/// `linearized_graph` and cloned at the start of each execution.
consumers_map: Option<FxHashMap<(NodeIndex, u8), usize>>,
}
/// Edge weight of the [`MainGraph`]: the kind of dependency between two nodes.
#[derive(Debug, Clone, Copy)]
// `Data` carries a `ShapeTracker` while `Schedule` carries nothing, hence the size lint.
#[allow(clippy::large_enum_variant)]
pub enum Dependency {
/// The target node reads one output tensor of the source node.
Data {
/// Position of this tensor among the target node's inputs (`Graph::toposort` sorts by it).
input_order: u8,
/// Which output of the source node is read; together with the source node it forms
/// the `(node, u8)` key used to look up the tensor.
output_order: u8,
/// Shape tracker passed to the consuming operator alongside the tensor.
shape: ShapeTracker,
},
/// An edge that carries no tensor (`as_data` returns `None` for it).
/// NOTE(review): presumably used purely to constrain execution order — confirm.
Schedule,
}
impl Dependency {
    /// Unpacks a data edge into `(input_order, output_order, shape)`.
    ///
    /// Returns `None` for [`Dependency::Schedule`], which carries no tensor.
    pub fn as_data(self) -> Option<(u8, u8, ShapeTracker)> {
        match self {
            Self::Data {
                input_order,
                output_order,
                shape,
            } => Some((input_order, output_order, shape)),
            Self::Schedule => None,
        }
    }

    /// Returns `true` when this edge is a [`Dependency::Schedule`] edge.
    pub fn is_schedule(&self) -> bool {
        match self {
            Self::Schedule => true,
            Self::Data { .. } => false,
        }
    }
}
impl Graph {
pub fn new() -> Graph {
Graph::default()
}
/// Takes the tensor stored for output `ind` of node `id` out of the graph.
///
/// Note that this *removes* the entry: a second call with the same arguments
/// returns `None` until the tensor is produced or set again.
pub fn get_tensor(&mut self, id: NodeIndex, ind: u8) -> Option<Tensor> {
    let key = (id, ind);
    self.tensors.remove(&key)
}
/// Borrows the tensor stored for output `ind` of node `id`, leaving it in the graph.
///
/// Returns `None` when no such tensor is currently held.
pub fn get_tensor_ref(&self, id: NodeIndex, ind: u8) -> Option<&Tensor> {
    let key = (id, ind);
    self.tensors.get(&key)
}
/// Discards the stored value of output `0` for each of the given nodes.
///
/// Only output index `0` is removed; other outputs of the same node are left untouched.
pub fn drop_tensors<T: ToIds>(&mut self, tensors: T) {
    tensors.to_ids().into_iter().for_each(|node| {
        self.tensors.remove(&(node, 0));
    });
}
/// Marks the given nodes as "do not delete".
///
/// Tensors of such nodes survive `reset` and are only ever lent (never moved)
/// to the operators that consume them.
pub fn keep_tensors<T: ToIds>(&mut self, tensors: T) {
    self.no_delete.extend(tensors.to_ids());
}
/// Marks the given nodes as outputs to retrieve.
///
/// Each node is added to both `no_delete` (so its tensor is kept after execution)
/// and `to_retrieve`.
pub fn retrieve_tensors<T: ToIds>(&mut self, tensors: T) {
    tensors.to_ids().into_iter().for_each(|node| {
        self.no_delete.insert(node);
        self.to_retrieve.insert(node);
    });
}
/// Stores `tensor` as output `ind` of node `id`, replacing any value already held there.
pub fn set_tensor(&mut self, id: NodeIndex, ind: u8, tensor: Tensor) {
    let key = (id, ind);
    self.tensors.insert(key, tensor);
}
/// Binds the dynamic dimension named `dimension` to the concrete size `val`,
/// overwriting any earlier binding.
pub fn set_dyn_dim(&mut self, dimension: char, val: usize) {
    let _previous = self.dyn_map.insert(dimension, val);
}
/// Adds a new input tensor with the default name `"Tensor"`.
///
/// Shorthand for `named_tensor("Tensor")`.
pub fn tensor<S: Shape>(&mut self) -> GraphTensor<S> {
    self.named_tensor::<S>("Tensor")
}
/// Adds a new input node labelled `"{name} Load"` and returns a handle to it.
///
/// The node's function panics when invoked, so a value has to be supplied for the
/// tensor (e.g. with `set_tensor`) before execution reaches it; the executors skip
/// nodes whose output `0` is already stored.
pub fn named_tensor<S: Shape>(&mut self, name: &str) -> GraphTensor<S> {
    let loader = op::Function(
        format!("{name} Load"),
        Box::new(|_| panic!("You must set a value for this tensor!")),
    );
    let id = self.graph.add_node(Box::new(loader));
    GraphTensor {
        id,
        graph_ref: self,
        shape: S::to_tracker(),
        _phantom: Default::default(),
    }
}
/// Runs `compiler` over this graph and then rebuilds the cached execution order.
///
/// `remap` is forwarded unchanged to the compiler.
/// NOTE(review): presumably the ids in `remap` are updated when the compiler replaces nodes — confirm against `Compiler`.
pub fn compile<T: ToIdsMut, C: Compiler>(&mut self, compiler: C, remap: T) {
compiler.compile(self, remap);
// The compiler is handed the graph itself, so refresh the linearized order and consumer counts.
self.toposort();
}
/// Rebuilds the cached execution plan.
///
/// Stores the nodes in topological order, each with its data inputs
/// `((source node, source output index), shape)` sorted by input position, and
/// then refreshes the consumer counts.
///
/// # Panics
/// Panics if the graph contains a cycle.
pub(crate) fn toposort(&mut self) {
    let order = petgraph::algo::toposort(&self.graph, None).unwrap();
    let mut linearized = Vec::with_capacity(order.len());
    for node in order {
        let inputs = self
            .graph
            .edges_directed(node, Direction::Incoming)
            // Schedule edges carry no tensor and are skipped here.
            .filter_map(|edge| {
                let (input_order, output_order, shape) = edge.weight().as_data()?;
                Some((input_order, (edge.source(), output_order), shape))
            })
            .sorted_by_key(|(input_order, _, _)| *input_order)
            .map(|(_, source, shape)| (source, shape))
            .collect::<Vec<_>>();
        linearized.push((node, inputs));
    }
    self.linearized_graph = Some(linearized);
    self.create_remaining_consumers_map();
}
/// Exchanges the stored values (output `0`) of tensors `a` and `b`.
///
/// If only one of the two currently holds a value, that value simply moves to the
/// other node; if neither does, nothing happens.
pub fn swap_tensors<A: Shape, B: Shape>(&mut self, a: GraphTensor<A>, b: GraphTensor<B>) {
    let (key_a, key_b) = ((a.id, 0), (b.id, 0));
    // Take both values out first so neither insert can clobber the other.
    let from_a = self.tensors.remove(&key_a);
    let from_b = self.tensors.remove(&key_b);
    if let Some(value) = from_a {
        self.tensors.insert(key_b, value);
    }
    if let Some(value) = from_b {
        self.tensors.insert(key_a, value);
    }
}
fn create_remaining_consumers_map(&mut self) {
self.consumers_map = Some(
self.graph
.node_indices()
.flat_map(|i| {
self.graph
.edges_directed(i, Direction::Outgoing)
.filter_map(|e| e.weight().as_data().map(|i| (e.source(), i)))
.group_by(|(_, (_, i, _))| *i)
.into_iter()
.map(|(ind, g)| ((i, ind), g.count()))
.collect::<Vec<_>>()
})
.collect(),
);
}
/// Drops every stored tensor whose node is not in `no_delete`.
pub fn reset(&mut self) {
    let keep = &self.no_delete;
    self.tensors.retain(|(node, _), _| keep.contains(node));
}
/// Runs the graph, freeing intermediate tensors as soon as possible.
///
/// Nodes are processed in the cached topological order (built on demand). A node
/// whose output `0` is already stored is skipped. An input is moved into the
/// operator when this is its last remaining consumer and its node is not in
/// `no_delete`; otherwise it is borrowed. Afterwards `reset` drops everything
/// that is not marked `no_delete`.
///
/// # Panics
/// Panics if a required input tensor is missing.
pub fn execute(&mut self) {
    // Build the execution order and consumer counts lazily on first use.
    if self.linearized_graph.is_none() {
        self.toposort();
    }
    let mut pending_uses = self.consumers_map.as_ref().unwrap().clone();
    let mut dim_stack = Vec::new();
    let schedule = self.linearized_graph.as_ref().unwrap();
    for (node, inputs) in schedule {
        let already_computed = self.tensors.contains_key(&(*node, 0));
        if already_computed {
            continue;
        }

        let mut srcs = Vec::new();
        get_source_tensors(
            &self.no_delete,
            &mut self.tensors,
            inputs,
            &pending_uses,
            &mut srcs,
        );
        for (_, tracker) in srcs.iter_mut() {
            tracker.resolve_global_dyn_dims_stack(&self.dyn_map, &mut dim_stack);
        }

        // `process` consumes `srcs`, so no borrowed input is alive when the
        // outputs are written back into the map.
        let outputs = self.graph.node_weight_mut(*node).unwrap().process(srcs);
        self.tensors.extend(
            outputs
                .into_iter()
                .enumerate()
                .map(|(i, tensor)| ((*node, i as u8), tensor)),
        );

        // One fewer consumer is waiting on each of this node's inputs.
        for (source, _) in inputs {
            let count = pending_uses.get_mut(source).unwrap();
            *count -= 1;
        }
    }
    self.reset();
}
/// Runs the graph while keeping every tensor.
///
/// Same traversal as `execute`, but all inputs are borrowed, nothing is freed
/// along the way and `reset` is not called, so every intermediate result stays
/// in `tensors` afterwards.
///
/// # Panics
/// Panics if a required input tensor is missing.
pub fn execute_no_delete(&mut self) {
    if self.linearized_graph.is_none() {
        self.toposort();
    }
    let mut dim_stack = Vec::new();
    let schedule = self.linearized_graph.as_ref().unwrap();
    for (node, inputs) in schedule.iter() {
        if self.tensors.contains_key(&(*node, 0)) {
            continue;
        }

        // Borrow every input; nothing is ever moved out of the map here.
        let mut srcs = Vec::with_capacity(inputs.len());
        for (key, tracker) in inputs.iter() {
            let tensor = self.tensors.get(key).unwrap();
            srcs.push((InputTensor::Borrowed(tensor), *tracker));
        }
        for (_, tracker) in srcs.iter_mut() {
            tracker.resolve_global_dyn_dims_stack(&self.dyn_map, &mut dim_stack);
        }

        let outputs = self.graph.node_weight_mut(*node).unwrap().process(srcs);
        self.tensors.extend(
            outputs
                .into_iter()
                .enumerate()
                .map(|(i, tensor)| ((*node, i as u8), tensor)),
        );
    }
}
pub fn execute_debug(&mut self) {
if self.linearized_graph.is_none() {
self.toposort();
}
let mut dim_stack = Vec::new();
let tensors_ptr = &mut self.tensors as *mut _;
let mut remaining_consumers = self.consumers_map.as_ref().unwrap().clone();
let mut op_times = FxHashMap::default();
println!(
"{:->2$} Executing {:->2$}",
"",
"",
(term_size::dimensions().unwrap().0 - " Executing ".len()) / 2
);
let start = std::time::Instant::now();
for (node, src_ids) in self.linearized_graph.as_ref().unwrap().iter() {
if self.tensors.contains_key(&(*node, 0)) {
continue;
}
let op_name = format!("{:?}", self.graph.node_weight(*node).unwrap());
print!("{}", op_name.bold().bright_green());
let mut srcs = Vec::new();
get_source_tensors(
&self.no_delete,
tensors_ptr,
src_ids,
&remaining_consumers,
&mut srcs,
);
for (_, st) in srcs.iter_mut() {
st.resolve_global_dyn_dims_stack(&self.dyn_map, &mut dim_stack);
}
let mut shapes_string = srcs
.iter()
.map(|(_, s)| {
format!(
"{:?}",
s.shape()
.into_iter()
.map(|i| i.to_usize().unwrap())
.collect::<Vec<_>>()
)
})
.join(", ");
if !shapes_string.is_empty() {
shapes_string = format!(" ({shapes_string})");
}
print!("{shapes_string}");
std::io::stdout().flush().unwrap();
let now = std::time::Instant::now();
let tensors = self.graph.node_weight_mut(*node).unwrap().process(srcs);
let elapsed = now.elapsed();
println!(
"{:.>1$}",
if elapsed.as_secs() > 0 {
format!("{:.2}s", elapsed.as_secs_f32())
} else if elapsed.as_millis() > 0 {
format!("{}ms", elapsed.as_millis())
} else {
format!("{}µs", elapsed.as_micros())
}
.bold(),
term_size::dimensions().unwrap().0 - op_name.len() - shapes_string.len(),
);
for (i, tensor) in tensors.into_iter().enumerate() {
self.tensors.insert((*node, i as u8), tensor);
}
if let Some(t) = op_times.get_mut(&op_name) {
*t += elapsed;
} else {
op_times.insert(op_name, elapsed);
}
for (source, _) in src_ids {
*remaining_consumers.get_mut(source).unwrap() -= 1;
}
}
println!();
println!(
"{:->2$} Total Times {:->2$}",
"",
"",
(term_size::dimensions().unwrap().0 - " Total Times ".len()) / 2
);
for (name, elapsed) in op_times.into_iter().sorted_by(|(_, a), (_, b)| b.cmp(a)) {
print!("{}", name.bold().bright_green());
println!(
"{:.>1$}",
if elapsed.as_secs() > 0 {
format!("{:.2}s", elapsed.as_secs_f32())
} else if elapsed.as_millis() > 0 {
format!("{}ms", elapsed.as_millis())
} else {
format!("{}µs", elapsed.as_micros())
}
.bold(),
term_size::dimensions().unwrap().0 - name.len(),
);
}
println!(
"Total: {}",
if start.elapsed().as_secs() > 0 {
format!("{:.2}s", start.elapsed().as_secs_f32())
} else if start.elapsed().as_millis() > 0 {
format!("{}ms", start.elapsed().as_millis())
} else {
format!("{}µs", start.elapsed().as_micros())
}
.bold()
);
self.reset();
}
}
/// Collects the inputs of one node into `srcs`, in the order given by `src_ids`.
///
/// An input is moved out of `tensors` (`InputTensor::Owned`) when exactly one
/// consumer remains and its node is not in `no_delete`; otherwise it is lent
/// (`InputTensor::Borrowed`).
///
/// `tensors` is a raw pointer so that borrowed entries can coexist with removals
/// from the same map. The borrowed references are not tied to any lifetime the
/// compiler can check: callers must pass a pointer to a live map and must consume
/// `srcs` before touching the referenced entries again.
///
/// # Panics
/// Panics if `tensors` is null, if an id has no entry in `remaining_consumers`,
/// or if a required tensor is missing from the map.
fn get_source_tensors(
    no_delete: &FxHashSet<NodeIndex>,
    tensors: *mut FxHashMap<(NodeIndex, u8), Tensor>,
    src_ids: &[((NodeIndex, u8), ShapeTracker)],
    remaining_consumers: &FxHashMap<(NodeIndex, u8), usize>,
    srcs: &mut Vec<(InputTensor, ShapeTracker)>,
) {
    for (key, tracker) in src_ids {
        let last_use = remaining_consumers[key] == 1;
        let pinned = no_delete.contains(&key.0);
        let input = if last_use && !pinned {
            // SAFETY: callers in this file derive `tensors` from a live
            // `&mut` to the graph's tensor map; null is rejected by `unwrap`.
            let map = unsafe { tensors.as_mut() }.unwrap();
            InputTensor::Owned(map.remove(key).unwrap())
        } else {
            // SAFETY: same pointer validity as above. NOTE(review): the shared
            // reference outlives later removals of *other* keys in this loop;
            // this relies on removal not moving the remaining entries — confirm.
            let map = unsafe { tensors.as_ref() }.unwrap();
            InputTensor::Borrowed(map.get(key).unwrap())
        };
        srcs.push((input, *tracker));
    }
}