Skip to main content

bb_compiler/
expand_ops.rs

//! `expand_ops` — materialize op-variant choices.
//!
//! Each `(domain, op_type)` that needs expansion has a matching arm
//! in `lookup_expansion` returning the `ExpandFn` to apply. A plain
//! `match` is used because the compiler runs at build time on a single
//! thread — a table behind a sync primitive would be overkill, and
//! a `match` makes the catalog trivially auditable. After a node is
//! expanded, `expand_ops` stamps `EXPANDED_KEY = "true"` on it so the
//! pass is idempotent.

10use crate::error::CompileError;
11use bb_ir::proto::onnx::{
12    attribute_proto::AttributeType, AttributeProto, GraphProto, NodeProto, StringStringEntryProto,
13};
14
/// Metadata key that marks a node as already expanded.
///
/// `expand_ops` writes it into the node's `metadata_props` (with the
/// value `"true"`) after a successful expansion, and skips any node
/// that already carries the key. This keeps the pass idempotent.
pub const EXPANDED_KEY: &str = "ai.bytesandbrains.expanded";

/// Operator domain whose `Interval` and `Constant` ops have expansions.
const SYSCALL_DOMAIN: &str = "ai.bytesandbrains.syscall";

/// `period_ns` used for an `Interval` node that does not specify one:
/// one second, written out as 10^3 * 10^3 * 10^3 nanoseconds.
const INTERVAL_DEFAULT_PERIOD_NS: i64 = 1_000 * 1_000 * 1_000;

24/// Per-op expansion function. Mutates the node in place; returns
25/// `Err` only on malformed input the compiler can't recover from.
26pub type ExpandFn = fn(&mut NodeProto) -> Result<(), CompileError>;
27
28/// Resolve a `(domain, op_type)` to its expansion function, or
29/// `None` when no expansion applies (most ops fall here - the
30/// pass is a no-op for them).
31fn lookup_expansion(domain: &str, op_type: &str) -> Option<ExpandFn> {
32    match (domain, op_type) {
33        (SYSCALL_DOMAIN, "Interval") => Some(expand_interval),
34        (SYSCALL_DOMAIN, "Constant") => Some(expand_constant),
35        _ => None,
36    }
37}
38
39/// Expand ops in-place per the static expansion registry. Pure.
40pub fn expand_ops(graph: &mut GraphProto) -> Result<(), CompileError> {
41    for node in graph.node.iter_mut() {
42        if metadata_value(node, EXPANDED_KEY).is_some() {
43            continue;
44        }
45        let Some(expand_fn) = lookup_expansion(&node.domain, &node.op_type) else {
46            continue;
47        };
48        expand_fn(node)?;
49        set_metadata(&mut node.metadata_props, EXPANDED_KEY, "true");
50    }
51    Ok(())
52}
53
54fn expand_interval(node: &mut NodeProto) -> Result<(), CompileError> {
55    if node.attribute.iter().any(|a| a.name == "period_ns") {
56        return Ok(());
57    }
58    node.attribute.push(AttributeProto {
59        name: "period_ns".to_string(),
60        r#type: AttributeType::Int as i32,
61        i: INTERVAL_DEFAULT_PERIOD_NS,
62        ..Default::default()
63    });
64    Ok(())
65}
66
67/// `Constant` expansion per COMPILER.md ยง5.2: every `Constant` node
68/// MUST carry a `value` attribute of type `TENSOR`. The expansion
69/// validates that requirement so downstream dispatch never sees a
70/// mis-shaped `Constant`. Nodes that already carry a non-empty
71/// `value` attribute are accepted; everything else is rejected with
72/// `CompileError::ExpansionFailed`.
73fn expand_constant(node: &mut NodeProto) -> Result<(), CompileError> {
74    let value_attr = node.attribute.iter().find(|a| a.name == "value");
75    let Some(attr) = value_attr else {
76        return Err(CompileError::ExpansionFailed {
77            domain: node.domain.clone(),
78            op_type: node.op_type.clone(),
79            reason: "Constant node is missing the required `value` attribute".into(),
80        });
81    };
82    if attr.r#type != AttributeType::Tensor as i32 {
83        return Err(CompileError::ExpansionFailed {
84            domain: node.domain.clone(),
85            op_type: node.op_type.clone(),
86            reason: format!(
87                "Constant `value` attribute must be TENSOR (got type tag {})",
88                attr.r#type
89            ),
90        });
91    }
92    if attr.t.is_none() {
93        return Err(CompileError::ExpansionFailed {
94            domain: node.domain.clone(),
95            op_type: node.op_type.clone(),
96            reason: "Constant `value` attribute carries no TensorProto payload".into(),
97        });
98    }
99    Ok(())
100}
101
102fn metadata_value(node: &NodeProto, key: &str) -> Option<String> {
103    node.metadata_props
104        .iter()
105        .find(|p| p.key == key)
106        .map(|p| p.value.clone())
107}
108
109fn set_metadata(props: &mut Vec<StringStringEntryProto>, key: &str, value: &str) {
110    if let Some(existing) = props.iter_mut().find(|p| p.key == key) {
111        existing.value = value.to_string();
112        return;
113    }
114    props.push(StringStringEntryProto {
115        key: key.to_string(),
116        value: value.to_string(),
117    });
118}
119