bb_compiler/
expand_ops.rs1use crate::error::CompileError;
11use bb_ir::proto::onnx::{
12 attribute_proto::AttributeType, AttributeProto, GraphProto, NodeProto, StringStringEntryProto,
13};
14
15pub const EXPANDED_KEY: &str = "ai.bytesandbrains.expanded";
17
18const SYSCALL_DOMAIN: &str = "ai.bytesandbrains.syscall";
19
20const INTERVAL_DEFAULT_PERIOD_NS: i64 = 1_000_000_000;
23
24pub type ExpandFn = fn(&mut NodeProto) -> Result<(), CompileError>;
27
28fn lookup_expansion(domain: &str, op_type: &str) -> Option<ExpandFn> {
32 match (domain, op_type) {
33 (SYSCALL_DOMAIN, "Interval") => Some(expand_interval),
34 (SYSCALL_DOMAIN, "Constant") => Some(expand_constant),
35 _ => None,
36 }
37}
38
39pub fn expand_ops(graph: &mut GraphProto) -> Result<(), CompileError> {
41 for node in graph.node.iter_mut() {
42 if metadata_value(node, EXPANDED_KEY).is_some() {
43 continue;
44 }
45 let Some(expand_fn) = lookup_expansion(&node.domain, &node.op_type) else {
46 continue;
47 };
48 expand_fn(node)?;
49 set_metadata(&mut node.metadata_props, EXPANDED_KEY, "true");
50 }
51 Ok(())
52}
53
54fn expand_interval(node: &mut NodeProto) -> Result<(), CompileError> {
55 if node.attribute.iter().any(|a| a.name == "period_ns") {
56 return Ok(());
57 }
58 node.attribute.push(AttributeProto {
59 name: "period_ns".to_string(),
60 r#type: AttributeType::Int as i32,
61 i: INTERVAL_DEFAULT_PERIOD_NS,
62 ..Default::default()
63 });
64 Ok(())
65}
66
67fn expand_constant(node: &mut NodeProto) -> Result<(), CompileError> {
74 let value_attr = node.attribute.iter().find(|a| a.name == "value");
75 let Some(attr) = value_attr else {
76 return Err(CompileError::ExpansionFailed {
77 domain: node.domain.clone(),
78 op_type: node.op_type.clone(),
79 reason: "Constant node is missing the required `value` attribute".into(),
80 });
81 };
82 if attr.r#type != AttributeType::Tensor as i32 {
83 return Err(CompileError::ExpansionFailed {
84 domain: node.domain.clone(),
85 op_type: node.op_type.clone(),
86 reason: format!(
87 "Constant `value` attribute must be TENSOR (got type tag {})",
88 attr.r#type
89 ),
90 });
91 }
92 if attr.t.is_none() {
93 return Err(CompileError::ExpansionFailed {
94 domain: node.domain.clone(),
95 op_type: node.op_type.clone(),
96 reason: "Constant `value` attribute carries no TensorProto payload".into(),
97 });
98 }
99 Ok(())
100}
101
102fn metadata_value(node: &NodeProto, key: &str) -> Option<String> {
103 node.metadata_props
104 .iter()
105 .find(|p| p.key == key)
106 .map(|p| p.value.clone())
107}
108
109fn set_metadata(props: &mut Vec<StringStringEntryProto>, key: &str, value: &str) {
110 if let Some(existing) = props.iter_mut().find(|p| p.key == key) {
111 existing.value = value.to_string();
112 return;
113 }
114 props.push(StringStringEntryProto {
115 key: key.to_string(),
116 value: value.to_string(),
117 });
118}
119