1use std::any::TypeId;
13use std::collections::HashMap;
14
15use crate::output::Output;
16use bb_ir::proto::onnx::tensor_proto::DataType as DT;
17use bb_ir::proto::onnx::{
18 attribute_proto, type_proto, AttributeProto, FunctionProto, NodeProto, StringStringEntryProto,
19 TensorShapeProto, TypeProto, ValueInfoProto,
20};
21use bb_ir::types::TypeNode;
22
23use crate::recorded::RecordedModule;
24
/// Node metadata key under which `Graph::push_node` records the `_`-joined
/// module scope that was active when the node was emitted.
const MODULE_INSTANCE_KEY: &str = "ai.bytesandbrains.module_instance";
28
29fn upsert_metadata(props: &mut Vec<StringStringEntryProto>, key: &str, value: &str) {
30 if let Some(entry) = props.iter_mut().find(|p| p.key == key) {
31 entry.value = value.to_string();
32 } else {
33 props.push(StringStringEntryProto {
34 key: key.to_string(),
35 value: value.to_string(),
36 });
37 }
38}
39
40pub struct Graph {
42 function: FunctionProto,
44
45 site_counter: u64,
47
48 instance_for_pointer: HashMap<(TypeId, *const ()), u32>,
53 next_instance_id: u32,
54
55 module_scope: Vec<String>,
59
60 sub_functions: Vec<FunctionProto>,
64
65 recording_target: Vec<Option<usize>>,
67
68 has_seen_function: bool,
72
73 pending_errors: Vec<crate::module::BuildError>,
75
76 mode_stack: Vec<RecordingMode>,
80
81 formal_binding_stack: Vec<HashMap<String, Output>>,
85
86 named_output_types: HashMap<(usize, String), (Output, &'static TypeNode)>,
91}
92
/// Controls how `Graph::input` propagates a formal input declaration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecordingMode {
    /// The input is declared on every function on the recording stack and on
    /// the root function. This is the mode when no mode has been pushed.
    Open,
    /// The input is declared only on the innermost recording target. This is
    /// pushed by `Graph::with_function` while a sub-function body runs.
    Sealed,
}
101
102impl Graph {
103 pub fn new() -> Self {
106 Self {
107 function: FunctionProto::default(),
108 site_counter: 0,
109 instance_for_pointer: HashMap::new(),
110 next_instance_id: 0,
111 module_scope: Vec::new(),
112 sub_functions: Vec::new(),
113 recording_target: Vec::new(),
114 has_seen_function: false,
115 pending_errors: Vec::new(),
116 mode_stack: Vec::new(),
117 named_output_types: HashMap::new(),
118 formal_binding_stack: Vec::new(),
119 }
120 }
121
122 fn current_mode(&self) -> RecordingMode {
124 self.mode_stack
125 .last()
126 .copied()
127 .unwrap_or(RecordingMode::Open)
128 }
129
130 pub fn take_pending_errors(&mut self) -> Vec<crate::module::BuildError> {
132 std::mem::take(&mut self.pending_errors)
133 }
134
135 pub fn output(&mut self, name: &str, handle: Output) {
138 let target_idx = self
139 .recording_target
140 .last()
141 .and_then(|t| *t)
142 .unwrap_or(usize::MAX);
143 let key = (target_idx, name.to_string());
144 if self.named_output_types.contains_key(&key) {
145 return;
146 }
147 let type_node = handle.type_node;
148
149 self.push_node(NodeProto {
152 op_type: bb_ir::syscall_ids::OP_PASS_THROUGH.into(),
153 domain: bb_ir::syscall_ids::SYSCALL_DOMAIN.into(),
154 input: vec![handle.name.clone()],
155 output: vec![name.to_string()],
156 ..Default::default()
157 });
158
159 let function: &mut FunctionProto = match target_idx {
160 usize::MAX => &mut self.function,
161 idx => &mut self.sub_functions[idx],
162 };
163 if function.output.iter().all(|n| n != name) {
164 function.output.push(name.to_string());
165 function
166 .value_info
167 .push(type_meta_to_value_info(name, type_node));
168 }
169 let registered = Output::new(name.to_string(), type_node);
170 self.named_output_types.insert(key, (registered, type_node));
171 }
172
173 pub fn net_out(&mut self, name: &str, peers: Output, value: Output) {
178 let value_type = value.type_node;
179 let port_name = name.to_string();
180 let handle_name = self.next_site_name();
181
182 let target_idx = self
183 .recording_target
184 .last()
185 .and_then(|t| *t)
186 .unwrap_or(usize::MAX);
187 let key = (target_idx, port_name.clone());
188 let already_registered = self.named_output_types.contains_key(&key);
189
190 self.push_node(NodeProto {
191 op_type: bb_ir::syscall_ids::OP_WIRE_SEND.into(),
192 domain: bb_ir::syscall_ids::WIRE_DOMAIN.into(),
193 input: vec![value.name.clone(), peers.name],
194 output: vec![port_name.clone(), handle_name.clone()],
195 ..Default::default()
196 });
197 self.declare_value_info(&port_name, value_type);
198 self.declare_value_info(&handle_name, &bb_ir::types::TYPE_WIRE_REQ_ID);
199
200 if !already_registered {
201 let function: &mut FunctionProto = match target_idx {
202 usize::MAX => &mut self.function,
203 idx => &mut self.sub_functions[idx],
204 };
205 if function.output.iter().all(|n| n != &port_name) {
206 function.output.push(port_name.clone());
207 }
208 let handle = Output::new(port_name.clone(), value_type);
209 self.named_output_types.insert(key, (handle, value_type));
210 }
211 }
212
213 pub fn bundle(&mut self, parts: &[Output]) -> Output {
217 assert!(
218 !parts.is_empty(),
219 "Graph::bundle: parts slice is empty; need >= 1 child Output",
220 );
221 let bundle_name = self.next_site_name();
222 let inputs: Vec<String> = parts.iter().map(|p| p.name.clone()).collect();
223
224 let child_count = parts.len();
225 let child_types = parts
226 .iter()
227 .map(|p| p.type_node.denotation)
228 .collect::<Vec<_>>()
229 .join(",");
230
231 self.push_node(NodeProto {
232 op_type: "Bundle".into(),
233 domain: "ai.bytesandbrains.composite".into(),
234 input: inputs,
235 output: vec![bundle_name.clone()],
236 attribute: vec![
237 attr_int(
238 "ai.bytesandbrains.composite.child_count",
239 child_count as i64,
240 ),
241 attr_string("ai.bytesandbrains.composite.child_types", &child_types),
242 ],
243 ..Default::default()
244 });
245 self.declare_value_info(&bundle_name, &bb_ir::types::TYPE_COMPOSITE);
246 Output::new(bundle_name, &bb_ir::types::TYPE_COMPOSITE)
247 }
248
249 pub fn unbundle(&mut self, composite: Output, part_types: &[&'static TypeNode]) -> Vec<Output> {
261 assert!(
262 !part_types.is_empty(),
263 "Graph::unbundle: part_types slice is empty; need >= 1 declared child type",
264 );
265 let child_count = part_types.len();
266 let port_names: Vec<String> = (0..child_count).map(|_| self.next_site_name()).collect();
267 let child_types = part_types
268 .iter()
269 .map(|t| t.denotation)
270 .collect::<Vec<_>>()
271 .join(",");
272
273 self.push_node(NodeProto {
274 op_type: "Unbundle".into(),
275 domain: "ai.bytesandbrains.composite".into(),
276 input: vec![composite.name],
277 output: port_names.clone(),
278 attribute: vec![
279 attr_int(
280 "ai.bytesandbrains.composite.child_count",
281 child_count as i64,
282 ),
283 attr_string("ai.bytesandbrains.composite.child_types", &child_types),
284 ],
285 ..Default::default()
286 });
287 for (port_name, type_node) in port_names.iter().zip(part_types.iter()) {
288 self.declare_value_info(port_name, type_node);
289 }
290 port_names
291 .into_iter()
292 .zip(part_types.iter())
293 .map(|(name, t)| Output::new(name, t))
294 .collect()
295 }
296
297 pub fn lookup_output(&self, name: &str) -> Option<Output> {
302 let target_idx = self
303 .recording_target
304 .last()
305 .and_then(|t| *t)
306 .unwrap_or(usize::MAX);
307 self.named_output_types
308 .get(&(target_idx, name.to_string()))
309 .map(|(h, _)| h.clone())
310 }
311
312 pub fn record_build_error(&mut self, err: crate::module::BuildError) {
317 self.pending_errors.push(err);
318 }
319
320 fn current_function_mut(&mut self) -> &mut FunctionProto {
324 match self.recording_target.last() {
325 Some(Some(idx)) => &mut self.sub_functions[*idx],
326 _ => &mut self.function,
327 }
328 }
329
330 pub fn finish(self) -> RecordedModule {
336 RecordedModule {
337 function: self.function,
338 sub_functions: self.sub_functions,
339 }
340 }
341
342 pub fn register_generic<T: 'static>(
346 &mut self,
347 instance: &T,
348 _required_trait: &'static str,
349 ) -> u32 {
350 let key = (TypeId::of::<T>(), (instance as *const T).cast::<()>());
351 if let Some(&id) = self.instance_for_pointer.get(&key) {
352 return id;
353 }
354 let id = self.next_instance_id;
355 self.next_instance_id += 1;
356 self.instance_for_pointer.insert(key, id);
357 self.current_function_mut()
358 .attribute
359 .push(format!("__slot_{id}"));
360 id
361 }
362
363 pub fn input(&mut self, name: &str) -> Output {
369 let bound_type = self
372 .formal_binding_stack
373 .last()
374 .and_then(|m| m.get(name))
375 .map(|h| h.type_node);
376
377 let build_vi = |name: &str| match bound_type {
378 Some(type_node) => type_meta_to_value_info(name, type_node),
379 None => opaque_value_info(name),
380 };
381
382 let active_targets: Vec<Option<usize>> = match self.current_mode() {
385 RecordingMode::Sealed => match self.recording_target.last() {
386 Some(slot) => vec![*slot],
387 None => Vec::new(),
388 },
389 RecordingMode::Open => self.recording_target.to_vec(),
390 };
391
392 let mut seen_root = false;
393 let touch_root = matches!(self.current_mode(), RecordingMode::Open);
394 for target in active_targets
395 .iter()
396 .chain(std::iter::once(&None).take(if touch_root { 1 } else { 0 }))
397 {
398 let function: &mut FunctionProto = match target {
399 Some(idx) => &mut self.sub_functions[*idx],
400 None => {
401 if seen_root {
402 continue;
403 }
404 seen_root = true;
405 &mut self.function
406 }
407 };
408 if function.input.iter().all(|n| n != name) {
409 function.input.push(name.to_string());
410 function.value_info.push(build_vi(name));
411 }
412 }
413
414 Output::new(name.to_string(), &bb_ir::types::TYPE_BYTES)
415 }
416
417 pub fn next_site_name(&mut self) -> String {
420 let n = self.site_counter;
421 self.site_counter += 1;
422 format!("v{n}")
423 }
424
425 pub fn declare_value_info(&mut self, name: &str, type_node: &'static bb_ir::types::TypeNode) {
428 let function = self.current_function_mut();
429 if function.value_info.iter().any(|v| v.name == name) {
430 return;
431 }
432 function
433 .value_info
434 .push(type_meta_to_value_info(name, type_node));
435 }
436
437 pub fn push_node(&mut self, mut node: NodeProto) {
441 if !self.module_scope.is_empty() {
442 let prefix = self.module_scope.join("_");
443 let existing = node
444 .metadata_props
445 .iter()
446 .find(|p| p.key == MODULE_INSTANCE_KEY)
447 .map(|p| p.value.clone());
448 let combined = match existing {
449 Some(inner) if !inner.is_empty() => format!("{prefix}_{inner}"),
450 _ => prefix,
451 };
452 upsert_metadata(&mut node.metadata_props, MODULE_INSTANCE_KEY, &combined);
453 }
454 self.current_function_mut().node.push(node);
455 }
456
457 pub fn with_function<F>(
468 &mut self,
469 name: &str,
470 bindings: &[(String, Output)],
471 body: F,
472 ) -> Vec<(String, String)>
473 where
474 F: FnOnce(&mut Graph),
475 {
476 let is_top_level_wrap = !self.has_seen_function
478 && self.recording_target.is_empty()
479 && self.function.node.is_empty()
480 && self.function.input.is_empty()
481 && self.function.attribute_proto.is_empty();
482
483 self.has_seen_function = true;
484
485 if is_top_level_wrap {
486 self.function.name = name.to_string();
488 self.module_scope.push(name.to_string());
489 let depth = self.module_scope.len();
490 body(self);
491 debug_assert_eq!(
492 self.module_scope.len(),
493 depth,
494 "with_function body must not mutate the scope stack",
495 );
496 self.module_scope.pop();
497 return Vec::new();
498 }
499
500 let target_idx = if let Some(idx) = self.sub_functions.iter().position(|f| f.name == name) {
504 idx
505 } else {
506 let new_idx = self.sub_functions.len();
507 self.sub_functions.push(FunctionProto {
508 name: name.to_string(),
509 ..Default::default()
510 });
511 new_idx
512 };
513
514 let is_duplicate = target_idx + 1 != self.sub_functions.len();
515 let recording_idx = if is_duplicate {
516 let scratch_idx = self.sub_functions.len();
517 self.sub_functions.push(FunctionProto::default());
518 scratch_idx
519 } else {
520 target_idx
521 };
522
523 let binding_map: HashMap<String, Output> = bindings
525 .iter()
526 .map(|(name, h)| (name.clone(), h.clone()))
527 .collect();
528 self.formal_binding_stack.push(binding_map);
529
530 self.recording_target.push(Some(recording_idx));
531 self.module_scope.push(name.to_string());
532 self.mode_stack.push(RecordingMode::Sealed);
536 let depth = self.module_scope.len();
537 body(self);
538 debug_assert_eq!(
539 self.module_scope.len(),
540 depth,
541 "with_function body must not mutate the scope stack",
542 );
543 self.mode_stack.pop();
544 self.module_scope.pop();
545 self.recording_target.pop();
546 self.formal_binding_stack.pop();
547
548 let recorded_outputs: Vec<String> = self.sub_functions[recording_idx].output.clone();
555
556 if is_duplicate {
557 self.sub_functions.pop();
558 }
559
560 let final_name = self.sub_functions[target_idx].name.clone();
566 let call_inputs: Vec<String> = bindings.iter().map(|(_, h)| h.name.clone()).collect();
567 let call_outputs: Vec<String> = (0..recorded_outputs.len())
568 .map(|_| self.next_site_name())
569 .collect();
570 let call = NodeProto {
571 op_type: final_name,
572 domain: "ai.bytesandbrains.module".into(),
573 input: call_inputs,
574 output: call_outputs.clone(),
575 ..Default::default()
576 };
577 self.push_node(call);
578
579 recorded_outputs.into_iter().zip(call_outputs).collect()
580 }
581
582 pub fn function(&self) -> &FunctionProto {
586 &self.function
587 }
588
589 #[cfg(test)]
593 pub(crate) fn sub_functions_for_test(&self) -> &[FunctionProto] {
594 &self.sub_functions
595 }
596}
597
598impl Default for Graph {
599 fn default() -> Self {
600 Self::new()
601 }
602}
603
604fn opaque_value_info(name: &str) -> bb_ir::proto::onnx::ValueInfoProto {
618 type_meta_to_value_info(name, &bb_ir::types::TYPE_BYTES)
619}
620
621fn type_meta_to_value_info(
622 name: &str,
623 type_node: &'static TypeNode,
624) -> bb_ir::proto::onnx::ValueInfoProto {
625 let value = if let Some(elem_type) = tensor_elem_from_denotation(type_node.denotation) {
626 type_proto::Value::TensorType(type_proto::Tensor {
631 elem_type,
632 shape: Some(TensorShapeProto::default()),
633 })
634 } else {
635 type_proto::Value::OpaqueType(type_proto::Opaque {
636 domain: "ai.bytesandbrains".into(),
637 name: type_node.denotation.into(),
638 })
639 };
640
641 ValueInfoProto {
642 name: name.to_string(),
643 r#type: Some(TypeProto {
644 value: Some(value),
645 denotation: type_node.denotation.into(),
646 }),
647 ..Default::default()
648 }
649}
650
651fn tensor_elem_from_denotation(denotation: &str) -> Option<i32> {
652 Some(match denotation {
653 "ai.bytesandbrains.tensor.f32" => DT::Float as i32,
654 "ai.bytesandbrains.tensor.f64" => DT::Double as i32,
655 "ai.bytesandbrains.tensor.i32" => DT::Int32 as i32,
656 "ai.bytesandbrains.tensor.i64" => DT::Int64 as i32,
657 "ai.bytesandbrains.tensor.bool" => DT::Bool as i32,
658 _ if denotation.starts_with("ai.bytesandbrains.tensor.") => DT::Undefined as i32,
659 _ => return None,
660 })
661}
662
663pub fn kv(key: &str, value: &str) -> StringStringEntryProto {
666 StringStringEntryProto {
667 key: key.to_string(),
668 value: value.to_string(),
669 }
670}
671
672pub fn attr_int(name: &str, value: i64) -> AttributeProto {
676 AttributeProto {
677 name: name.to_string(),
678 r#type: attribute_proto::AttributeType::Int as i32,
679 i: value,
680 ..Default::default()
681 }
682}
683
684pub fn attr_float(name: &str, value: f32) -> AttributeProto {
687 AttributeProto {
688 name: name.to_string(),
689 r#type: attribute_proto::AttributeType::Float as i32,
690 f: value,
691 ..Default::default()
692 }
693}
694
695pub fn attr_ints(name: &str, values: Vec<i64>) -> AttributeProto {
698 AttributeProto {
699 name: name.to_string(),
700 r#type: attribute_proto::AttributeType::Ints as i32,
701 ints: values,
702 ..Default::default()
703 }
704}
705
706pub fn attr_graph(name: &str, value: bb_ir::proto::onnx::GraphProto) -> AttributeProto {
709 AttributeProto {
710 name: name.to_string(),
711 r#type: attribute_proto::AttributeType::Graph as i32,
712 g: Some(value),
713 ..Default::default()
714 }
715}
716
717pub fn attr_string(name: &str, value: &str) -> AttributeProto {
723 AttributeProto {
724 name: name.to_string(),
725 r#type: attribute_proto::AttributeType::String as i32,
726 s: value.as_bytes().to_vec(),
727 ..Default::default()
728 }
729}
730
731pub fn attr_tensor(name: &str, value: bb_ir::proto::onnx::TensorProto) -> AttributeProto {
734 AttributeProto {
735 name: name.to_string(),
736 r#type: attribute_proto::AttributeType::Tensor as i32,
737 t: Some(value),
738 ..Default::default()
739 }
740}
741