bb_runtime/node/
derivation.rs1use std::collections::BTreeMap;
9
10use bb_ir::proto::onnx::{FunctionProto, ModelProto};
11
/// A generic slot requirement: a numeric slot id paired with the name of the
/// trait required for that slot.
///
/// Instances are produced by `derive_generic_slots` from node metadata, but can
/// also be built directly (including in `const`/`static` contexts) via
/// [`GenericSlotSpec::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct GenericSlotSpec {
    pub(crate) slot_id: u32,
    pub(crate) required_trait: &'static str,
}

impl GenericSlotSpec {
    /// Creates a spec for `slot_id` requiring the trait named `required_trait`.
    ///
    /// This is a `const fn` so spec tables can be declared as constants.
    #[must_use]
    pub const fn new(slot_id: u32, required_trait: &'static str) -> Self {
        Self {
            slot_id,
            required_trait,
        }
    }

    /// Returns the numeric slot identifier.
    #[must_use]
    pub const fn slot_id(&self) -> u32 {
        self.slot_id
    }

    /// Returns the name of the trait required for this slot.
    #[must_use]
    pub const fn required_trait(&self) -> &'static str {
        self.required_trait
    }
}
42
43pub fn derive_generic_slots(model: &ModelProto) -> Vec<GenericSlotSpec> {
47 let mut seen: BTreeMap<u32, &'static str> = BTreeMap::new();
48 for function in &model.functions {
49 derive_generic_slots_in(function, &mut seen);
50 }
51 seen.into_iter()
52 .map(|(slot_id, required_trait)| GenericSlotSpec {
53 slot_id,
54 required_trait,
55 })
56 .collect()
57}
58
59fn derive_generic_slots_in(function: &FunctionProto, seen: &mut BTreeMap<u32, &'static str>) {
60 for node in &function.node {
61 let Some(slot_id) =
62 metadata_value(node, "ai.bytesandbrains.slot_id").and_then(|v| v.parse::<u32>().ok())
63 else {
64 continue;
65 };
66 let Some(rt) = metadata_value(node, "ai.bytesandbrains.required_trait") else {
67 continue;
68 };
69 let static_rt: Option<&'static str> = match rt {
72 "BackendRuntime" => Some("BackendRuntime"),
73 "ModelRuntime" => Some("ModelRuntime"),
74 "IndexRuntime" => Some("IndexRuntime"),
75 "AggregatorRuntime" => Some("AggregatorRuntime"),
76 "CodecRuntime" => Some("CodecRuntime"),
77 "DataSourceRuntime" => Some("DataSourceRuntime"),
78 "PeerSelectorRuntime" => Some("PeerSelectorRuntime"),
79 "ProtocolRuntime" => Some("ProtocolRuntime"),
80 "WireRuntime" => Some("WireRuntime"),
81 _ => None,
82 };
83 if let Some(rt_static) = static_rt {
84 seen.entry(slot_id).or_insert(rt_static);
85 }
86 }
87}
88
89pub fn derive_partition_name(model: &ModelProto) -> &str {
92 model
93 .functions
94 .first()
95 .map(|f| f.name.as_str())
96 .unwrap_or("")
97}
98
99fn metadata_value<'a>(node: &'a bb_ir::proto::onnx::NodeProto, key: &str) -> Option<&'a str> {
100 node.metadata_props
101 .iter()
102 .find(|p| p.key == key)
103 .map(|p| p.value.as_str())
104}
105