use crate::{
instruction::Instruction,
layer::execution::LayerExecution,
model::{graph_model::GraphModel, layer_connection::LayerId},
tensor::TensorDesc,
utils::error::VKMLError,
};
use std::collections::{HashMap, HashSet};
pub type OperationId = usize;
pub type TensorId = usize;
pub struct DependencyGraph {
pub predecessors: Vec<Vec<OperationId>>,
pub successors: Vec<Vec<OperationId>>,
pub topological_order: Vec<OperationId>,
}
pub struct TensorGraph {
pub tensor_descs: Vec<TensorDesc>,
pub operations: Vec<Box<dyn Instruction>>,
pub input_tensor_ids: Vec<TensorId>,
pub output_tensor_ids: Vec<TensorId>,
pub tensor_to_layer: Vec<Option<LayerId>>,
pub operation_to_layer: Vec<LayerId>,
pub memory_requirements: usize,
}
impl TensorGraph {
pub fn from_graph_model(model: &GraphModel) -> Result<Self, VKMLError> {
if model.verified.is_none() {
return Err(VKMLError::TensorGraph("Model not verified".into()));
}
let execution_order = &model.verified.as_ref().unwrap().execution_order;
let mut tensor_descs: Vec<TensorDesc> = Vec::new();
let mut operations: Vec<Box<dyn Instruction>> = Vec::new();
let mut tensor_to_layer_map = Vec::new();
let mut operation_to_layer_map = Vec::new();
let mut global_tensor_map: HashMap<(LayerId, usize), TensorId> = HashMap::new();
let mut layer_executions: HashMap<LayerId, LayerExecution> = HashMap::new();
let mut memory_requirements = 0;
for &layer_id in execution_order {
let layer_wrapper = model.layers.get(&layer_id).ok_or_else(|| {
VKMLError::TensorGraph(format!("Layer {} not found in model", layer_id))
})?;
let input_descs: Vec<TensorDesc> = layer_wrapper
.input_connections
.iter()
.map(|conn| {
let src_layer_id = conn.get_layerid();
let src_output_idx = conn.get_outputidx();
let src_exec = layer_executions.get(&src_layer_id).ok_or_else(|| {
VKMLError::TensorGraph(format!(
"Source LayerExecution for {} not found when building layer {}",
src_layer_id, layer_id
))
})?;
let src_local_tensor_idx = src_exec.outputs[src_output_idx];
Ok(src_exec.tensors[src_local_tensor_idx].clone())
})
.collect::<Result<Vec<TensorDesc>, VKMLError>>()?;
let input_desc_refs: Vec<&TensorDesc> = input_descs.iter().collect();
let layer_exec = layer_wrapper
.layer
.build_layer_exec(model.batch_size, &input_desc_refs)?;
layer_executions.insert(layer_id, layer_exec);
}
let mut latest_producer_op_for_tensor: Vec<Option<OperationId>> = Vec::new();
for &layer_id in execution_order {
{
let layer_exec = layer_executions.get(&layer_id).unwrap();
for (local_idx, local_tensor_desc) in layer_exec.tensors.iter().enumerate() {
if !layer_exec.input_mappings.contains_key(&local_idx) {
let global_tensor_id = tensor_descs.len();
global_tensor_map.insert((layer_id, local_idx), global_tensor_id);
memory_requirements += local_tensor_desc.size_in_bytes();
tensor_descs.push(local_tensor_desc.clone());
tensor_to_layer_map.push(Some(layer_id));
if global_tensor_id >= latest_producer_op_for_tensor.len() {
latest_producer_op_for_tensor.resize(global_tensor_id + 1, None);
}
}
}
for (local_idx, (input_conn_idx, _output_idx_in_conn)) in &layer_exec.input_mappings
{
let input_connection =
&model.layers[&layer_id].input_connections[*input_conn_idx];
let src_layer_id = input_connection.get_layerid();
let src_local_output_idx =
layer_executions[&src_layer_id].outputs[input_connection.get_outputidx()];
let global_src_tensor_id =
global_tensor_map[&(src_layer_id, src_local_output_idx)];
global_tensor_map.insert((layer_id, *local_idx), global_src_tensor_id);
}
}
let layer_exec_mut = layer_executions.get_mut(&layer_id).unwrap();
for mut instruction in layer_exec_mut.instructions.drain(..) {
let global_op_id = operations.len();
let global_inputs: Vec<TensorId> = instruction
.get_input_tensor_ids()
.iter()
.map(|&local_id| global_tensor_map[&(layer_id, local_id)])
.collect();
let global_outputs: Vec<TensorId> = instruction
.get_output_tensor_ids()
.iter()
.map(|&local_id| global_tensor_map[&(layer_id, local_id)])
.collect();
instruction.remap_tensor_ids(&global_inputs, &global_outputs);
operations.push(instruction);
operation_to_layer_map.push(layer_id);
for &output_global_id in &global_outputs {
if output_global_id >= latest_producer_op_for_tensor.len() {
latest_producer_op_for_tensor.resize(output_global_id + 1, None);
}
latest_producer_op_for_tensor[output_global_id] = Some(global_op_id);
}
}
}
let mut input_tensors_model = Vec::new();
for &layer_id in &model.verified.as_ref().unwrap().entry_points {
let layer_exec = layer_executions.get(&layer_id).unwrap();
for (local_idx, _) in layer_exec.tensors.iter().enumerate() {
if !layer_exec.input_mappings.contains_key(&local_idx) {
let global_id = global_tensor_map[&(layer_id, local_idx)];
input_tensors_model.push(global_id);
}
}
}
input_tensors_model.sort_unstable();
let mut output_tensors_model = Vec::new();
let mut seen_outputs = HashSet::new();
for &layer_id in &model.verified.as_ref().unwrap().exit_points {
let layer_exec = layer_executions.get(&layer_id).unwrap();
for &local_output_idx in &layer_exec.outputs {
if let Some(global_id) = global_tensor_map.get(&(layer_id, local_output_idx))
&& seen_outputs.insert(*global_id)
{
output_tensors_model.push(*global_id);
}
}
}
output_tensors_model.sort_unstable();
Ok(TensorGraph {
tensor_descs,
operations,
input_tensor_ids: input_tensors_model,
output_tensor_ids: output_tensors_model,
tensor_to_layer: tensor_to_layer_map,
operation_to_layer: operation_to_layer_map,
memory_requirements,
})
}
pub fn dependency_graph(&self) -> DependencyGraph {
let num_ops = self.operations.len();
let mut predecessors: Vec<Vec<OperationId>> = vec![Vec::new(); num_ops];
let mut successors: Vec<Vec<OperationId>> = vec![Vec::new(); num_ops];
let mut in_degree: Vec<usize> = vec![0; num_ops];
let mut tensor_to_producers: Vec<Vec<OperationId>> =
vec![Vec::new(); self.tensor_descs.len()];
for (op_id, op) in self.operations.iter().enumerate() {
let outputs = op.get_output_tensor_ids();
for &t in &outputs {
if t >= tensor_to_producers.len() {
tensor_to_producers.resize_with(t + 1, Vec::new);
}
tensor_to_producers[t].push(op_id);
}
}
let mut seen_stamp: Vec<u32> = vec![0; num_ops];
let mut stamp: u32 = 1;
for (curr_op, op) in self.operations.iter().enumerate() {
stamp = stamp.wrapping_add(1);
if stamp == 0 {
seen_stamp.fill(0);
stamp = 1;
}
let inputs = op.get_input_tensor_ids();
for &t in &inputs {
if t >= tensor_to_producers.len() {
continue;
}
for &pred_op in &tensor_to_producers[t] {
if pred_op == curr_op {
continue;
}
if seen_stamp[pred_op] == stamp {
continue;
}
seen_stamp[pred_op] = stamp;
predecessors[curr_op].push(pred_op);
successors[pred_op].push(curr_op);
in_degree[curr_op] += 1;
}
}
}
let mut queue: Vec<OperationId> = Vec::with_capacity(num_ops);
for (op, °) in in_degree.iter().enumerate().take(num_ops) {
if deg == 0 {
queue.push(op);
}
}
let mut ordered: Vec<OperationId> = Vec::with_capacity(num_ops);
let mut head = 0;
while head < queue.len() {
let op = queue[head];
head += 1;
ordered.push(op);
for &succ in &successors[op] {
in_degree[succ] = in_degree[succ].saturating_sub(1);
if in_degree[succ] == 0 {
queue.push(succ);
}
}
}
if ordered.len() < num_ops {
eprintln!(
"Could not schedule all operations: {}/{}.",
ordered.len(),
num_ops
);
}
DependencyGraph {
predecessors,
successors,
topological_order: ordered,
}
}
pub fn get_instruction_or_panic(&self, idx: usize) -> &dyn Instruction {
self.operations
.get(idx)
.map(|boxed| boxed.as_ref())
.unwrap_or_else(|| panic!("Instruction index {} is out of bounds", idx))
}
pub fn get_tensor_producers(&self, tensor_id: usize) -> Vec<usize> {
self.operations
.iter()
.enumerate()
.filter_map(|(op_idx, op)| {
if op.get_output_tensor_ids().contains(&tensor_id) {
Some(op_idx)
} else {
None
}
})
.collect()
}
pub fn get_tensor_consumers(&self, tensor_id: usize) -> Vec<usize> {
self.operations
.iter()
.enumerate()
.filter_map(|(op_idx, op)| {
if op.get_input_tensor_ids().contains(&tensor_id) {
Some(op_idx)
} else {
None
}
})
.collect()
}
pub fn get_operation_inputs(&self, op_idx: usize) -> Vec<usize> {
if op_idx < self.operations.len() {
self.operations[op_idx].get_input_tensor_ids()
} else {
Vec::new()
}
}
pub fn get_operation_outputs(&self, op_idx: usize) -> Vec<usize> {
if op_idx < self.operations.len() {
self.operations[op_idx].get_output_tensor_ids()
} else {
Vec::new()
}
}
pub fn get_input_tensor_ids(&self) -> &[TensorId] {
&self.input_tensor_ids
}
pub fn get_output_tensor_ids(&self) -> &[TensorId] {
&self.output_tensor_ids
}
}