use std::collections::{BinaryHeap, HashMap};
use std::sync::Arc;
use morok_ir::UOp;
use morok_ir::op::Op;
use morok_ir::types::ConstValue;
use morok_ir::uop::core::UOpKey;
/// Base scheduling priorities handed out by `get_priority`.
///
/// Among nodes with the same run count, `linearize` places nodes with a smaller value
/// earlier in its output (subject to data dependencies).
mod priority {
    /// Priority of every op without a dedicated entry below.
    pub const DEFAULT: i32 = 0;

    /// Device-less `Param` nodes.
    pub const PARAM: i32 = -20;
    /// `DefineVar` nodes.
    pub const DEFINE_VAR: i32 = -19;
    /// `DefineLocal` nodes.
    pub const DEFINE_LOCAL: i32 = -18;
    /// `DefineReg` nodes.
    pub const DEFINE_REG: i32 = -17;
    /// `End` nodes.
    pub const END: i32 = -5;
    /// `Load` nodes.
    pub const LOAD: i32 = -1;

    /// `Store` nodes.
    pub const STORE: i32 = 1;
    /// `Range` nodes.
    pub const RANGE: i32 = 5;
}
/// Scheduling key used by [`linearize`].
///
/// The derived `Ord` compares the fields lexicographically in declaration order, so the
/// field order below is significant. `linearize` sorts nodes ascending by this key to
/// build its ideal order, and its max-heap pops the largest ready key first before the
/// result is reversed, so among ready nodes a larger key is placed later in the output.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct OrderKey {
    /// Estimated execution count from `compute_run_count`; compared first.
    run_count: u64,
    /// Base priority from `get_priority` (one of the `priority` constants).
    priority: i32,
    /// Optional op argument from `get_priority` that orders ops of equal priority
    /// (e.g. a `Param` slot).
    arg_value: Option<i64>,
    /// Derived from the node's position in `linearize`'s sorted order; it is 0 until
    /// `linearize` fills it in after sorting.
    ideal_pos: usize,
    /// Node id; final tie-breaker.
    id: u64,
}
/// Linearizes the graph rooted at `sink` into an execution order.
///
/// This is a topological sort steered by a priority key:
///
/// 1. Every node reachable from `sink` gets an [`OrderKey`]. The key holds its estimated
///    run count ([`compute_run_count`]), its op priority and optional argument
///    ([`get_priority`]), and its id as the final tie-breaker.
/// 2. Sorting all nodes ascending by that key gives the *ideal* order. A node's index in
///    that order is stored as `ideal_pos`.
/// 3. The graph is then walked bottom-up from `sink` with a max-heap. A node becomes
///    ready once every consumer edge pointing at it has been emitted, and the ready node
///    with the largest key is emitted first. Reversing the emitted list yields an order in
///    which each source precedes its consumers while staying as close to the ideal order
///    as the dependencies allow.
///
/// The returned vector ends with `sink`.
pub fn linearize(sink: Arc<UOp>) -> Vec<Arc<UOp>> {
    let nodes = sink.toposort();
    if nodes.is_empty() {
        return vec![sink];
    }

    // Key used if a node is missing from `priorities`. This is not expected, because
    // every key is built from `nodes` below.
    let fallback_key = |id: u64| OrderKey {
        run_count: 0,
        priority: priority::DEFAULT,
        arg_value: None,
        ideal_pos: 0,
        id,
    };

    // NOTE(review): clippy's `mutable_key_type` fires on `UOpKey`. These maps assume its
    // Hash/Eq do not depend on mutable state — confirm.
    #[allow(clippy::mutable_key_type)]
    let mut out_degree: HashMap<UOpKey, usize> = HashMap::new();
    #[allow(clippy::mutable_key_type)]
    let mut priorities: HashMap<UOpKey, OrderKey> = HashMap::new();
    let mut id_to_uop: HashMap<u64, Arc<UOp>> = HashMap::with_capacity(nodes.len());

    for u in nodes.iter().rev() {
        id_to_uop.insert(u.id, u.clone());
        // Count one per consumer *edge*: a node that uses the same source twice counts
        // twice, matching the once-per-edge decrement in the scheduling loop below.
        for src in u.op().sources() {
            *out_degree.entry(UOpKey(src.clone())).or_default() += 1;
        }
        let run_count = compute_run_count(u);
        let (base_priority, arg_value) = get_priority(u);
        priorities.insert(
            UOpKey(u.clone()),
            OrderKey { run_count, priority: base_priority, arg_value, ideal_pos: 0, id: u.id },
        );
    }

    // Ideal order: ascending by key. `ideal_pos` must grow in the same direction as the
    // other key fields. The heap pops the largest key first and the emitted list is
    // reversed, so a larger `ideal_pos` has to mean "later". (BinaryHeap is already a
    // max-heap; no negation is needed.)
    let mut ideal_order: Vec<Arc<UOp>> = nodes.to_vec();
    ideal_order.sort_by_cached_key(|u| {
        priorities.get(&UOpKey(u.clone())).cloned().unwrap_or_else(|| fallback_key(u.id))
    });
    for (pos, u) in ideal_order.iter().enumerate() {
        if let Some(order_key) = priorities.get_mut(&UOpKey(u.clone())) {
            order_key.ideal_pos = pos;
        }
    }

    let mut heap: BinaryHeap<OrderKey> = BinaryHeap::new();
    heap.push(
        priorities.get(&UOpKey(sink.clone())).cloned().unwrap_or_else(|| fallback_key(sink.id)),
    );

    let mut result = Vec::with_capacity(nodes.len());
    let mut visited: std::collections::HashSet<u64> = std::collections::HashSet::new();
    while let Some(order_key) = heap.pop() {
        // `insert` returns false when this id was already emitted.
        if !visited.insert(order_key.id) {
            continue;
        }
        let u = match id_to_uop.get(&order_key.id) {
            Some(uop) => uop.clone(),
            None => continue,
        };
        result.push(u.clone());
        for src in u.op().sources() {
            let src_key = UOpKey(src.clone());
            if let Some(remaining) = out_degree.get_mut(&src_key) {
                *remaining = remaining.saturating_sub(1);
                // All consumers of `src` have been emitted: it is now ready.
                if *remaining == 0 && !visited.contains(&src.id) {
                    if let Some(src_order_key) = priorities.get(&src_key) {
                        heap.push(src_order_key.clone());
                    }
                }
            }
        }
    }
    result.reverse();
    result
}
/// Estimates how many times `uop` executes. The estimate is the product of `vmax + 1`
/// over every range in scope at `uop`, or 1 when no range is in scope.
///
/// The result is only used to order nodes in [`linearize`], where larger counts are
/// placed later. The arithmetic therefore saturates instead of overflowing:
/// - A factor or product that does not fit in `u64` is clamped to `u64::MAX`.
/// - A range whose `vmax + 1` is negative contributes 0.
/// - A bound that is neither `Int` nor `UInt` contributes a factor of 1.
fn compute_run_count(uop: &Arc<UOp>) -> u64 {
    use morok_ir::uop::cached_property::CachedProperty;
    use morok_ir::uop::properties::InScopeRangesProperty;
    #[allow(clippy::mutable_key_type)]
    let in_scope = InScopeRangesProperty::get(uop);
    // An empty scope folds to the initial value 1.
    in_scope
        .iter()
        .map(|key| match key.0.vmax() {
            // `vmax` can be as large as i64::MAX for symbolic bounds; `+ 1` must not
            // overflow, and a negative trip count means the range is empty.
            ConstValue::Int(v) => u64::try_from(v.saturating_add(1)).unwrap_or(0),
            ConstValue::UInt(v) => v.saturating_add(1),
            _ => 1,
        })
        // Several large symbolic ranges nested together can exceed u64.
        .fold(1u64, u64::saturating_mul)
}
/// Returns the base priority of `uop`, plus an optional argument that orders ops sharing
/// that priority.
///
/// Only device-less `Param`s (keyed by slot) and `DefineVar`s (keyed by a hash of the
/// name) carry an argument; every other op returns `None`. Ops without a dedicated
/// priority, constants included, get `priority::DEFAULT`.
///
/// NOTE(review): `DefaultHasher`'s algorithm is not guaranteed to be stable across Rust
/// releases. The relative order of several `DefineVar`s can therefore change with the
/// toolchain, and it is not alphabetical by name.
fn get_priority(uop: &Arc<UOp>) -> (i32, Option<i64>) {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let op = uop.op();

    // Ops whose ordering among peers is refined by an argument.
    if let Op::Param { slot, device: None, .. } = op {
        return (priority::PARAM, Some(*slot as i64));
    }
    if let Op::DefineVar { name, .. } = op {
        let mut state = DefaultHasher::new();
        name.hash(&mut state);
        return (priority::DEFINE_VAR, Some(state.finish() as i64));
    }

    // Everything else: a fixed priority and no argument.
    let base = match op {
        Op::DefineLocal(_) => priority::DEFINE_LOCAL,
        Op::DefineReg { .. } => priority::DEFINE_REG,
        Op::Range { .. } => priority::RANGE,
        Op::End { .. } => priority::END,
        Op::Load { .. } => priority::LOAD,
        Op::Store { .. } => priority::STORE,
        _ => priority::DEFAULT,
    };
    (base, None)
}
#[cfg(test)]
mod tests {
    use super::*;
    use morok_dtype::DType;
    use morok_ir::types::ConstValue;
    use smallvec::smallvec;

    /// Index of the first node in `order` satisfying `pred`, if any.
    fn position_of(order: &[Arc<UOp>], pred: impl Fn(&Arc<UOp>) -> bool) -> Option<usize> {
        order.iter().position(|u| pred(u))
    }

    #[test]
    fn test_linearize_single_const() {
        let c = UOp::const_(DType::Float32, ConstValue::Float(1.0));
        let order = linearize(UOp::sink(vec![c.clone()]));
        assert_eq!(order.len(), 2);
        assert!(matches!(order[0].op(), Op::Const(_)));
        assert!(matches!(order[1].op(), Op::Sink { .. }));
    }

    #[test]
    fn test_linearize_simple_computation() {
        let lhs = UOp::const_(DType::Float32, ConstValue::Float(1.0));
        let rhs = UOp::const_(DType::Float32, ConstValue::Float(2.0));
        let order = linearize(UOp::sink(vec![lhs.try_add(&rhs).unwrap()]));
        assert_eq!(order.len(), 4);
        // Both operands come first (in either order), then the add, then the sink.
        assert!(order[..2].iter().all(|u| matches!(u.op(), Op::Const(_))));
        assert!(matches!(order[2].op(), Op::Binary(_, _, _)));
        assert!(matches!(order[3].op(), Op::Sink { .. }));
    }

    #[test]
    fn test_linearize_with_range() {
        let range = UOp::range(UOp::index_const(10), 0);
        let value = UOp::const_(DType::Float32, ConstValue::Float(1.0));
        let order = linearize(UOp::sink(vec![value.end(smallvec![range.clone()])]));
        let range_pos = position_of(&order, |u| matches!(u.op(), Op::Range { .. }))
            .expect("RANGE missing from linearized order");
        let end_pos = position_of(&order, |u| matches!(u.op(), Op::End { .. }))
            .expect("END missing from linearized order");
        assert!(range_pos < end_pos);
    }

    #[test]
    fn test_linearize_preserves_dependencies() {
        let c = UOp::const_(DType::Float32, ConstValue::Float(1.0));
        let c2 = UOp::const_(DType::Float32, ConstValue::Float(2.0));
        let c3 = UOp::const_(DType::Float32, ConstValue::Float(3.0));
        // `c` is shared by both `a` and `b`.
        let a = c.try_add(&c2).unwrap();
        let b = c.try_add(&c3).unwrap();
        let sum = a.try_add(&b).unwrap();
        let order = linearize(UOp::sink(vec![sum.clone()]));

        let pos = |target: &Arc<UOp>| {
            position_of(&order, |u| Arc::ptr_eq(u, target))
                .expect("node missing from linearized order")
        };
        let (c_pos, a_pos, b_pos, sum_pos) = (pos(&c), pos(&a), pos(&b), pos(&sum));

        assert!(c_pos < a_pos);
        assert!(c_pos < b_pos);
        assert!(a_pos < sum_pos);
        assert!(b_pos < sum_pos);
    }

    #[test]
    #[allow(clippy::assertions_on_constants)] // intentionally pins the relative order of constants
    fn test_priority_ordering() {
        assert!(priority::PARAM < priority::DEFAULT);
        assert!(priority::DEFAULT < priority::RANGE);
        assert!(priority::END < priority::DEFAULT);
        assert!(priority::LOAD < priority::DEFAULT);
        assert!(priority::DEFAULT < priority::STORE);
    }
}