use std::collections::BTreeMap;
use bb_ir::proto::onnx::{FunctionProto, ModelProto};
/// A slot id paired with the name of the trait required for that slot.
///
/// Values are cheap to copy: the id is a `u32` and the trait name is a
/// `'static` string slice.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct GenericSlotSpec {
    /// Numeric identifier of the slot.
    pub(crate) slot_id: u32,
    /// Name of the trait required for this slot.
    pub(crate) required_trait: &'static str,
}

impl GenericSlotSpec {
    /// Builds a spec from a slot id and the name of its required trait.
    pub fn new(slot_id: u32, required_trait: &'static str) -> Self {
        GenericSlotSpec { slot_id, required_trait }
    }

    /// Returns the slot id this spec was built with.
    pub fn slot_id(&self) -> u32 {
        let id = self.slot_id;
        id
    }

    /// Returns the required trait name this spec was built with.
    pub fn required_trait(&self) -> &'static str {
        let trait_name = self.required_trait;
        trait_name
    }
}
/// Collects the generic slot specs declared across every function of `model`.
///
/// Each function in `model.functions` is scanned by
/// `derive_generic_slots_in`, which records at most one trait name per slot
/// id. The result holds one [`GenericSlotSpec`] per recorded slot id, in
/// ascending slot-id order (the iteration order of the backing `BTreeMap`).
/// A model with no functions yields an empty vector.
pub fn derive_generic_slots(model: &ModelProto) -> Vec<GenericSlotSpec> {
    let mut trait_by_slot: BTreeMap<u32, &'static str> = BTreeMap::new();
    model
        .functions
        .iter()
        .for_each(|function| derive_generic_slots_in(function, &mut trait_by_slot));

    // Draining the map in key order keeps the output sorted by slot id.
    let mut specs = Vec::with_capacity(trait_by_slot.len());
    for (slot_id, required_trait) in trait_by_slot {
        specs.push(GenericSlotSpec {
            slot_id,
            required_trait,
        });
    }
    specs
}
/// Scans the nodes of `function` and records slot-id -> trait-name pairs in `seen`.
///
/// A node contributes only when all of the following hold:
/// - its `ai.bytesandbrains.slot_id` metadata value parses as a `u32`;
/// - it has an `ai.bytesandbrains.required_trait` metadata value;
/// - that value is exactly one of the recognized trait names listed below.
///
/// Nodes failing any check are skipped silently. When a slot id is already
/// present in `seen`, the existing entry is kept and the new one is ignored.
fn derive_generic_slots_in(function: &FunctionProto, seen: &mut BTreeMap<u32, &'static str>) {
    // Recognized trait names; looking the metadata value up here swaps the
    // node-borrowed string for an equal `'static` one.
    const KNOWN_TRAITS: [&str; 9] = [
        "BackendRuntime",
        "ModelRuntime",
        "IndexRuntime",
        "AggregatorRuntime",
        "CodecRuntime",
        "DataSourceRuntime",
        "PeerSelectorRuntime",
        "ProtocolRuntime",
        "WireRuntime",
    ];

    for node in &function.node {
        let slot_id = match metadata_value(node, "ai.bytesandbrains.slot_id")
            .map(str::parse::<u32>)
        {
            Some(Ok(id)) => id,
            // Missing or non-numeric slot id: nothing to record for this node.
            _ => continue,
        };

        let recognized = metadata_value(node, "ai.bytesandbrains.required_trait")
            .and_then(|declared| KNOWN_TRAITS.iter().copied().find(|name| *name == declared));

        if let Some(trait_name) = recognized {
            // First occurrence of a slot id wins; later ones leave it unchanged.
            seen.entry(slot_id).or_insert(trait_name);
        }
    }
}
/// Returns the name of the first function in `model.functions`.
///
/// The returned slice borrows from `model`. When the model has no functions,
/// the empty string is returned instead.
pub fn derive_partition_name(model: &ModelProto) -> &str {
    match model.functions.first() {
        Some(first_function) => first_function.name.as_str(),
        None => "",
    }
}
/// Looks up `key` in the metadata properties of `node`.
///
/// Returns the value of the first property whose key equals `key`, borrowed
/// from `node`, or `None` when no property matches.
fn metadata_value<'a>(node: &'a bb_ir::proto::onnx::NodeProto, key: &str) -> Option<&'a str> {
    for prop in node.metadata_props.iter() {
        if prop.key == key {
            return Some(prop.value.as_str());
        }
    }
    None
}