use crate::error::CompileError;
use crate::partition_by_wire_ops::WIRE_DOMAIN;
use bb_ir::keys::CHAIN_DEPTH_KEY;
use bb_ir::proto::onnx::{attribute_proto, AttributeProto, ModelProto};
use bb_ir::syscall_ids::ATTR_DEADLINE_NS;
/// `op_type` of the wire-domain node that `derive_wire_deadlines` stamps with a deadline;
/// nodes in `WIRE_DOMAIN` with any other `op_type` are skipped.
const SEND_OP: &str = "Send";
/// Stamps a deadline attribute onto every wire-domain `Send` node inside `model.functions`.
///
/// For each matching node the deadline is `chain_depth * per_hop_budget_ns`. `chain_depth`
/// is read from the node's chain-depth metadata entry and defaults to `1` when the entry is
/// missing or unparseable. The product saturates at `i64::MAX` instead of wrapping, so an
/// oversized depth or budget can never yield a negative deadline.
///
/// Only nodes of `model.functions` are visited; this function does not look at any other
/// node list of the model.
///
/// Returns the number of nodes that were stamped.
///
/// # Errors
///
/// The current implementation never returns `Err`. The `Result` return type is kept for
/// callers that propagate `CompileError`.
pub fn derive_wire_deadlines(
    model: &mut ModelProto,
    per_hop_budget_ns: u64,
) -> Result<usize, CompileError> {
    let mut stamp_count = 0usize;
    for func in model.functions.iter_mut() {
        for node in func.node.iter_mut() {
            if node.domain != WIRE_DOMAIN || node.op_type != SEND_OP {
                continue;
            }
            let chain_depth = read_chain_depth(node).unwrap_or(1);
            // Multiply in unsigned space and clamp before narrowing. Casting each operand
            // with `as i64` first would wrap values above i64::MAX to negatives, and
            // `saturating_mul` would then produce a negative or wrongly-saturated deadline.
            // `i64::MAX as u64` is lossless, and after `min` the final `as i64` is too.
            let deadline_ns = chain_depth
                .saturating_mul(per_hop_budget_ns)
                .min(i64::MAX as u64) as i64;
            upsert_deadline(node, deadline_ns);
            stamp_count += 1;
        }
    }
    Ok(stamp_count)
}
/// Looks up the chain depth recorded in a node's metadata.
///
/// Only the first `metadata_props` entry whose key equals `CHAIN_DEPTH_KEY` is consulted.
/// Returns `None` when no such entry exists or when its value does not parse as a `u64`.
fn read_chain_depth(node: &bb_ir::proto::onnx::NodeProto) -> Option<u64> {
    let depth_entry = node
        .metadata_props
        .iter()
        .find(|entry| entry.key == CHAIN_DEPTH_KEY)?;
    depth_entry.value.parse::<u64>().ok()
}
/// Sets the node's `ATTR_DEADLINE_NS` attribute to `deadline_ns` as an INT attribute.
///
/// If an attribute with that name already exists, the first such attribute is updated in
/// place. Both its integer value and its `type` discriminator are set, so the stored value
/// is always tagged as INT. Otherwise a new INT attribute is appended.
fn upsert_deadline(node: &mut bb_ir::proto::onnx::NodeProto, deadline_ns: i64) {
    let int_type = attribute_proto::AttributeType::Int as i32;
    if let Some(existing) = node
        .attribute
        .iter_mut()
        .find(|a| a.name == ATTR_DEADLINE_NS)
    {
        existing.i = deadline_ns;
        // An attribute created elsewhere may carry another (or undefined) type. Readers use
        // `type` to decide which value field to trust, so keep it consistent with `i`.
        existing.r#type = int_type;
    } else {
        node.attribute.push(AttributeProto {
            name: ATTR_DEADLINE_NS.to_string(),
            i: deadline_ns,
            r#type: int_type,
            ..Default::default()
        });
    }
}