use std::collections::HashMap;
use crate::artifact::BindingSpec;
use crate::error::CompileError;
use bb_ir::proto::onnx::{ModelProto, NodeProto};
use bb_ir::syscall_ids::{OP_PASS_THROUGH, OP_WIRE_SEND, SYSCALL_DOMAIN, WIRE_DOMAIN};
use bb_ir::types::TypeNode;
/// Narrows the value-info denotations of outputs produced by slot-bound nodes.
///
/// Every node of every function in `model` is inspected for the
/// `ai.bytesandbrains.required_trait` and `ai.bytesandbrains.slot_id`
/// metadata entries. Nodes carrying neither are ignored. For the remaining
/// nodes the bound slot is looked up in `spec` by role and, when both a port
/// name and a storage type are available for it, every output of the node is
/// recorded as a refinement. The recorded refinements are then propagated
/// through value-preserving ops and written back into the model.
///
/// # Errors
///
/// * [`CompileError::MissingRequiredTraitMetadata`] — a node has a slot id but
///   no required-trait entry.
/// * [`CompileError::AmbiguousRoleBinding`] — more than one slot in `spec`
///   declares the node's role.
/// * [`CompileError::InvalidSlotId`] — the slot id entry is absent (reported
///   with an empty value) or does not parse as `u32`.
/// * [`CompileError::UnknownSlotId`] — `spec` has no slot for the role.
/// * Any error returned by `port_name_for_trait`.
pub(crate) fn refine_polymorphic_value_info(
    model: &mut ModelProto,
    spec: &BindingSpec,
) -> Result<(), CompileError> {
    const TRAIT_KEY: &str = "ai.bytesandbrains.required_trait";
    const SLOT_ID_KEY: &str = "ai.bytesandbrains.slot_id";

    let mut narrowings: Vec<(String, &'static TypeNode)> = Vec::new();

    for node in model.functions.iter().flat_map(|f| f.node.iter()) {
        // Classify the node by which of the two metadata entries it carries.
        let (role, slot_id_meta) = match (
            metadata_value(node, TRAIT_KEY),
            metadata_value(node, SLOT_ID_KEY),
        ) {
            (None, None) => continue,
            (None, Some(_)) => {
                return Err(CompileError::MissingRequiredTraitMetadata {
                    node: node.name.clone(),
                });
            }
            (Some(role), slot_id_meta) => (role, slot_id_meta),
        };

        // A role may be declared by at most one slot; this is checked before
        // the slot id itself is validated.
        let candidates: Vec<&str> = spec
            .slots
            .iter()
            .filter(|s| s.role == role)
            .map(|s| s.slot.as_str())
            .collect();
        if candidates.len() > 1 {
            return Err(CompileError::AmbiguousRoleBinding {
                role: role.to_string(),
                slot_names: candidates.iter().map(|s| s.to_string()).collect(),
            });
        }

        let Some(slot_id_text) = slot_id_meta else {
            return Err(CompileError::InvalidSlotId {
                node: node.name.clone(),
                value: String::new(),
            });
        };
        // The parsed id is only used for error reporting below; the slot
        // itself is resolved by role.
        let slot_id = match slot_id_text.parse::<u32>() {
            Ok(id) => id,
            Err(_) => {
                return Err(CompileError::InvalidSlotId {
                    node: node.name.clone(),
                    value: slot_id_text.to_string(),
                });
            }
        };

        let Some(slot) = spec.lookup_by_role(role) else {
            return Err(CompileError::UnknownSlotId {
                node: node.name.clone(),
                slot_id,
            });
        };

        // Roles without a port, or ports without a storage type, leave the
        // node's outputs untouched.
        let Some(port) = port_name_for_trait(role, node)? else {
            continue;
        };
        let Some(narrowed) = slot.storage_type_opt(port) else {
            continue;
        };

        narrowings.extend(node.output.iter().map(|name| (name.clone(), narrowed)));
    }

    propagate_through_value_preserving_ops(model, &mut narrowings);
    apply_refinements(model, &narrowings);
    Ok(())
}
/// Extends `refinements` across ops that forward their input unchanged.
///
/// For every node accepted by `is_value_preserving`, if its first input
/// already has a refinement and its first output does not, the output
/// inherits the input's narrowed type. Passes over the model are repeated
/// until one completes without adding anything, so chains of such ops are
/// followed regardless of node order. Newly discovered entries are appended
/// to `refinements` in discovery order.
fn propagate_through_value_preserving_ops(
    model: &ModelProto,
    refinements: &mut Vec<(String, &'static TypeNode)>,
) {
    if refinements.is_empty() {
        return;
    }

    // Name -> narrowed type index over everything known so far.
    let mut known: HashMap<String, &'static TypeNode> = HashMap::new();
    for (name, narrowed) in refinements.iter() {
        known.insert(name.clone(), *narrowed);
    }

    let mut changed = true;
    while changed {
        changed = false;
        for node in model.functions.iter().flat_map(|f| f.node.iter()) {
            if !is_value_preserving(node) {
                continue;
            }
            // Only the first input/output pair is considered.
            let (Some(source), Some(target)) = (node.input.first(), node.output.first()) else {
                continue;
            };
            if known.contains_key(target) {
                continue;
            }
            let Some(&narrowed) = known.get(source) else {
                continue;
            };
            known.insert(target.clone(), narrowed);
            refinements.push((target.clone(), narrowed));
            changed = true;
        }
    }
}
fn is_value_preserving(node: &NodeProto) -> bool {
matches!(
(node.domain.as_str(), node.op_type.as_str()),
(SYSCALL_DOMAIN, OP_PASS_THROUGH) | (WIRE_DOMAIN, OP_WIRE_SEND)
)
}
/// Maps a required-trait name to the port name used for its storage type.
///
/// Returns `Ok(None)` for `"PeerSelectorRuntime"`. For `"CodecRuntime"` the
/// port is read from the node's `ai.bytesandbrains.codec.port` metadata entry
/// and must be either `"in"` or `"out"`.
///
/// # Errors
///
/// * [`CompileError::MissingCodecPortMetadata`] — a codec node lacks the port
///   metadata entry.
/// * [`CompileError::InvalidCodecPort`] — the codec port entry is neither
///   `"in"` nor `"out"`.
/// * [`CompileError::UnknownRoleRuntime`] — `required_trait` is not one of the
///   names handled here.
fn port_name_for_trait(
    required_trait: &str,
    node: &NodeProto,
) -> Result<Option<&'static str>, CompileError> {
    let port = match required_trait {
        "IndexRuntime" => "vector",
        "AggregatorRuntime" => "element",
        "ModelRuntime" | "BackendRuntime" => "tensor",
        "DataSourceRuntime" => "sample",
        "PeerSelectorRuntime" => return Ok(None),
        "CodecRuntime" => {
            let Some(direction) = metadata_value(node, "ai.bytesandbrains.codec.port") else {
                return Err(CompileError::MissingCodecPortMetadata {
                    node: node.name.clone(),
                });
            };
            match direction {
                "in" => "in",
                "out" => "out",
                other => {
                    return Err(CompileError::InvalidCodecPort {
                        node: node.name.clone(),
                        value: other.to_string(),
                    });
                }
            }
        }
        unknown => {
            return Err(CompileError::UnknownRoleRuntime {
                node: node.name.clone(),
                role: unknown.to_string(),
            });
        }
    };
    Ok(Some(port))
}
fn apply_refinements(model: &mut ModelProto, refinements: &[(String, &'static TypeNode)]) {
if refinements.is_empty() {
return;
}
for function in model.functions.iter_mut() {
for vi in function.value_info.iter_mut() {
if let Some((_, narrowed)) = refinements.iter().find(|(name, _)| *name == vi.name) {
if let Some(ref mut t) = vi.r#type {
t.denotation = narrowed.denotation.to_string();
}
}
}
}
}
/// Looks up `key` among the node's `metadata_props` and returns the value of
/// the first entry whose key matches, or `None` when no entry matches.
fn metadata_value<'a>(node: &'a NodeProto, key: &str) -> Option<&'a str> {
    for prop in node.metadata_props.iter() {
        if prop.key == key {
            return Some(prop.value.as_str());
        }
    }
    None
}