use crate::graph::{Graph, TensorID};
use crate::Float;
use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
/// Direction in which a [`Schedule`] visits the graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScheduleDirection {
    /// Topological order: every node appears after all of its inputs.
    Forward,
    /// Reverse topological order (e.g. for backward passes).
    Reverse,
}
/// A single operation slot in a [`Schedule`].
#[derive(Debug, Clone)]
pub struct ScheduledOp {
    /// Id of the graph node this op computes.
    pub node_id: TensorID,
    /// The node's topological rank as recorded on the graph node.
    pub topo_rank: usize,
    /// Display name of the op, or "source" for input/leaf nodes.
    pub op_name: String,
    /// Ids of the node's incoming (input) nodes, in graph order.
    pub inputs: Vec<TensorID>,
    /// Per-op memory estimate; currently always 0 (not yet modeled).
    pub estimated_memory: usize,
}
/// An ordered list of operations produced by one of the schedulers below.
#[derive(Debug, Clone)]
pub struct Schedule {
    /// Operations in execution order.
    pub ops: Vec<ScheduledOp>,
    /// Whether `ops` is in forward (topological) or reverse order.
    pub direction: ScheduleDirection,
    /// High-water mark of the scheduler's memory model; 0 when the
    /// scheduler does not model memory (forward/reverse schedules).
    pub peak_memory_estimate: usize,
    /// Number of entries in `ops`.
    pub total_ops: usize,
}
impl Schedule {
pub fn empty(direction: ScheduleDirection) -> Self {
Self {
ops: Vec::new(),
direction,
peak_memory_estimate: 0,
total_ops: 0,
}
}
}
/// Builds child/parent adjacency for every node in `graph`.
///
/// Returns `(children, parents)` indexed by node id: `children[p]` is the
/// deduplicated set of nodes that consume `p`, while `parents[c]` lists
/// the in-range inputs of `c` (duplicates preserved). Parent ids outside
/// the node table are skipped in both structures.
fn build_adjacency<F: Float>(graph: &Graph<F>) -> (Vec<HashSet<TensorID>>, Vec<Vec<TensorID>>) {
    let nodes = graph.node_set.borrow();
    let node_count = nodes.len();
    let mut children: Vec<HashSet<TensorID>> = vec![HashSet::new(); node_count];
    let mut parents: Vec<Vec<TensorID>> = vec![Vec::new(); node_count];
    for node in nodes.iter() {
        for parent_id in node.incoming_nodes.iter().map(|inc| inc.id) {
            // Ignore references to ids outside the node table.
            if parent_id < node_count {
                children[parent_id].insert(node.id);
                parents[node.id].push(parent_id);
            }
        }
    }
    (children, parents)
}
fn compute_in_degree<F: Float>(graph: &Graph<F>) -> Vec<usize> {
let nodes = graph.node_set.borrow();
let n = nodes.len();
let mut in_deg = vec![0usize; n];
for node in nodes.iter() {
in_deg[node.id] = node.incoming_nodes.len();
}
in_deg
}
/// Produces a forward (topological) execution order for `graph` using
/// Kahn's algorithm.
///
/// Ready nodes are drained from a min-heap keyed by `(topo_rank, id)`,
/// so ties between independent nodes break deterministically.
pub fn forward_schedule<F: Float>(graph: &Graph<F>) -> Schedule {
    if graph.node_set.borrow().is_empty() {
        return Schedule::empty(ScheduleDirection::Forward);
    }
    // Adjacency and in-degrees take their own short-lived borrows, so the
    // node table must not be borrowed while they run.
    let (children, _parents) = build_adjacency(graph);
    let mut in_deg = compute_in_degree(graph);
    let nodes = graph.node_set.borrow();
    let n = nodes.len();
    // std's BinaryHeap is a max-heap; Reverse flips it into a min-heap.
    let mut ready = BinaryHeap::new();
    for (id, &deg) in in_deg.iter().enumerate() {
        if deg == 0 {
            ready.push(std::cmp::Reverse((nodes[id].topo_rank, id)));
        }
    }
    let mut ops = Vec::with_capacity(n);
    while let Some(std::cmp::Reverse((rank, nid))) = ready.pop() {
        let node = &nodes[nid];
        let op_name = match node.op.as_ref() {
            Some(op) => op.name().to_owned(),
            None => "source".to_owned(),
        };
        ops.push(ScheduledOp {
            node_id: nid,
            topo_rank: rank,
            op_name,
            inputs: node.incoming_nodes.iter().map(|inc| inc.id).collect(),
            estimated_memory: 0,
        });
        // Scheduling this node may make its consumers ready.
        for &child in &children[nid] {
            in_deg[child] = in_deg[child].saturating_sub(1);
            if in_deg[child] == 0 {
                ready.push(std::cmp::Reverse((nodes[child].topo_rank, child)));
            }
        }
    }
    let total = ops.len();
    Schedule {
        ops,
        direction: ScheduleDirection::Forward,
        peak_memory_estimate: 0,
        total_ops: total,
    }
}
pub fn reverse_schedule<F: Float>(graph: &Graph<F>) -> Schedule {
let mut fwd = forward_schedule(graph);
fwd.ops.reverse();
fwd.direction = ScheduleDirection::Reverse;
fwd
}
/// Greedy scheduler that tries to minimize the number of live tensors.
///
/// At each step it picks the ready node that would release the most
/// inputs (inputs whose remaining consumer count is 1), breaking ties by
/// the smaller `topo_rank`. Memory is modeled as one unit per live
/// tensor; `peak_memory_estimate` records the high-water mark of that
/// model. The result is a forward-direction schedule.
pub fn memory_optimal_schedule<F: Float>(graph: &Graph<F>) -> Schedule {
    let nodes_ref = graph.node_set.borrow();
    let n = nodes_ref.len();
    if n == 0 {
        return Schedule::empty(ScheduleDirection::Forward);
    }
    // Snapshot (topo_rank, op_name, inputs) so the RefCell borrow can be
    // released before build_adjacency/compute_in_degree re-borrow it.
    let node_data: Vec<(usize, String, Vec<TensorID>)> = nodes_ref
        .iter()
        .map(|nd| {
            let op_name = nd
                .op
                .as_ref()
                .map(|o| o.name().to_owned())
                .unwrap_or_else(|| "source".to_owned());
            let inputs: Vec<TensorID> = nd.incoming_nodes.iter().map(|inc| inc.id).collect();
            (nd.topo_rank, op_name, inputs)
        })
        .collect();
    drop(nodes_ref);
    let (children, _parents) = build_adjacency(graph);
    let mut in_deg = compute_in_degree(graph);
    // ref_count[p] = number of distinct consumers of p not yet scheduled
    // (children sets are deduplicated, so each consumer counts once).
    let mut ref_count: Vec<usize> = children.iter().map(|c| c.len()).collect();
    let mut ready: Vec<TensorID> = (0..n).filter(|&id| in_deg[id] == 0).collect();
    let mut ops = Vec::with_capacity(n);
    let mut live_tensors: HashSet<TensorID> = HashSet::new();
    let mut peak_memory: usize = 0;
    let mut current_memory: usize = 0;
    // Uniform cost model: every tensor occupies one unit.
    let tensor_unit = 1usize;
    while !ready.is_empty() {
        // Pick the ready node whose execution frees the most inputs;
        // ties go to the lower topological rank for determinism.
        let best_idx = {
            let mut best = 0usize;
            let mut best_score = 0usize;
            for (idx, &nid) in ready.iter().enumerate() {
                let score = node_data[nid]
                    .2
                    .iter()
                    .filter(|&&pid| pid < n && ref_count[pid] == 1)
                    .count();
                if score > best_score
                    || (score == best_score && node_data[nid].0 < node_data[ready[best]].0)
                {
                    best = idx;
                    best_score = score;
                }
            }
            best
        };
        let nid = ready.swap_remove(best_idx);
        let (rank, ref op_name, ref inputs) = node_data[nid];
        live_tensors.insert(nid);
        current_memory += tensor_unit;
        // Decrement each distinct input once per consumer: `ref_count`
        // was seeded from deduplicated children sets, so a duplicated
        // entry in `inputs` must not be double-counted — doing so would
        // free tensors early and corrupt the peak-memory estimate.
        let mut released: HashSet<TensorID> = HashSet::with_capacity(inputs.len());
        for &pid in inputs {
            if pid < n && released.insert(pid) {
                ref_count[pid] = ref_count[pid].saturating_sub(1);
                // Last consumer scheduled: the input is no longer live.
                if ref_count[pid] == 0 && live_tensors.remove(&pid) {
                    current_memory = current_memory.saturating_sub(tensor_unit);
                }
            }
        }
        if current_memory > peak_memory {
            peak_memory = current_memory;
        }
        ops.push(ScheduledOp {
            node_id: nid,
            topo_rank: rank,
            op_name: op_name.clone(),
            inputs: inputs.clone(),
            estimated_memory: 0,
        });
        for &child in &children[nid] {
            in_deg[child] = in_deg[child].saturating_sub(1);
            if in_deg[child] == 0 {
                ready.push(child);
            }
        }
    }
    let total = ops.len();
    Schedule {
        ops,
        direction: ScheduleDirection::Forward,
        peak_memory_estimate: peak_memory,
        total_ops: total,
    }
}
/// Returns the longest-path depth of every node; source nodes have depth 0.
pub fn compute_depth<F: Float>(graph: &Graph<F>) -> Vec<usize> {
    let nodes = graph.node_set.borrow();
    let n = nodes.len();
    let mut order: Vec<TensorID> = (0..n).collect();
    order.sort_by_key(|&id| nodes[id].topo_rank);
    let mut depth = vec![0usize; n];
    // Visit in topological-rank order so every parent's depth is final
    // before its children read it.
    for &id in &order {
        depth[id] = nodes[id]
            .incoming_nodes
            .iter()
            .map(|inc| inc.id)
            .filter(|&pid| pid < n)
            .map(|pid| depth[pid] + 1)
            .max()
            .unwrap_or(0);
    }
    depth
}
/// Checks that `schedule` covers every graph node exactly once and, for
/// forward schedules, that every node appears after all of its inputs.
pub fn validate_schedule<F: Float>(graph: &Graph<F>, schedule: &Schedule) -> Result<(), String> {
    let n = graph.node_set.borrow().len();
    let mut seen: HashSet<TensorID> = HashSet::with_capacity(n);
    for op in &schedule.ops {
        if !seen.insert(op.node_id) {
            return Err(format!("Duplicate node {} in schedule", op.node_id));
        }
    }
    if seen.len() != n {
        return Err(format!(
            "Schedule contains {} nodes but graph has {}",
            seen.len(),
            n
        ));
    }
    // Only forward schedules promise producer-before-consumer ordering.
    if schedule.direction == ScheduleDirection::Forward {
        let position: HashMap<TensorID, usize> = schedule
            .ops
            .iter()
            .enumerate()
            .map(|(pos, op)| (op.node_id, pos))
            .collect();
        for op in &schedule.ops {
            let my_pos = position.get(&op.node_id).copied().unwrap_or(usize::MAX);
            for &inp in &op.inputs {
                // An input missing from the schedule maps to usize::MAX
                // and is therefore also reported as a violation.
                let inp_pos = position.get(&inp).copied().unwrap_or(usize::MAX);
                if inp_pos >= my_pos {
                    return Err(format!(
                        "Dependency violation: node {} at position {} depends on node {} at position {}",
                        op.node_id, my_pos, inp, inp_pos
                    ));
                }
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::AsGraph;
    use crate::tensor_ops as T;
    use crate::VariableEnvironment;

    // A small chain (a + b, then * ones) must yield a non-empty, valid
    // forward schedule.
    #[test]
    fn test_forward_schedule_linear_chain() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let a = T::zeros(&[2, 2], ctx);
            let b = T::ones(&[2, 2], ctx);
            let c = a + b;
            let _ = c * T::ones(&[2, 2], ctx);
            let sched = forward_schedule(ctx.as_graph());
            assert!(sched.total_ops > 0);
            assert_eq!(sched.direction, ScheduleDirection::Forward);
            assert!(validate_schedule(ctx.as_graph(), &sched).is_ok());
        });
    }

    // reverse_schedule must flip the direction flag and keep all ops.
    #[test]
    fn test_reverse_schedule() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let a = T::zeros(&[3], ctx);
            let b = T::ones(&[3], ctx);
            let _ = a + b;
            let sched = reverse_schedule(ctx.as_graph());
            assert_eq!(sched.direction, ScheduleDirection::Reverse);
            assert!(sched.total_ops > 0);
        });
    }

    // Diamond-shaped graph (a feeds both c and d) through the greedy
    // memory scheduler must still produce a dependency-valid order.
    #[test]
    fn test_memory_optimal_schedule() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let a = T::zeros(&[4, 4], ctx);
            let b = T::ones(&[4, 4], ctx);
            let c = a + b;
            let d = a * b;
            let _ = c + d;
            let sched = memory_optimal_schedule(ctx.as_graph());
            assert!(sched.total_ops > 0);
            assert!(validate_schedule(ctx.as_graph(), &sched).is_ok());
        });
    }

    // An empty graph schedules to zero ops rather than panicking.
    #[test]
    fn test_empty_graph_schedule() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let sched = forward_schedule(ctx.as_graph());
            assert_eq!(sched.total_ops, 0);
        });
    }

    // Depth must strictly increase along the chain a -> c -> d.
    #[test]
    fn test_compute_depth() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let a = T::zeros(&[2], ctx);
            let b = T::ones(&[2], ctx);
            let c = a + b;
            let d = c * T::ones(&[2], ctx);
            let _ = d;
            let depths = compute_depth(ctx.as_graph());
            assert!(depths[c.id] > depths[a.id], "c should be deeper than a");
            assert!(depths[d.id] > depths[c.id], "d should be deeper than c");
        });
    }

    // Dropping an op from an otherwise-valid schedule must be rejected
    // by the node-count check in validate_schedule.
    #[test]
    fn test_validate_schedule_catches_missing_nodes() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let a = T::zeros(&[2], ctx);
            let _ = T::ones(&[2], ctx);
            let _ = a;
            let mut sched = forward_schedule(ctx.as_graph());
            if !sched.ops.is_empty() {
                sched.ops.pop();
                sched.total_ops = sched.ops.len();
                assert!(validate_schedule(ctx.as_graph(), &sched).is_err());
            }
        });
    }

    // Diamond dependency (a feeds both b and c, which join) must still
    // validate under the plain forward scheduler.
    #[test]
    fn test_diamond_graph_schedule() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let a = T::zeros(&[3], ctx);
            let b = a + T::ones(&[3], ctx);
            let c = a * T::ones(&[3], ctx);
            let _ = b + c;
            let sched = forward_schedule(ctx.as_graph());
            assert!(validate_schedule(ctx.as_graph(), &sched).is_ok());
        });
    }
}