use std::collections::{HashMap, HashSet};
use bb_ir::keys::{read_function_module_phase, MODULE_PHASE_BOOTSTRAP};
use bb_ir::proto::onnx::ModelProto;
use crate::error::CompileError;
/// Node `domain` that marks an ONNX node as a call to another module function.
/// For such nodes, `op_type` holds the name of the called function, which is
/// looked up among `ModelProto::functions` by name.
const MODULE_CALL_DOMAIN: &str = "ai.bytesandbrains.module";
/// Validates how bootstrap-phase module functions compose, starting from the
/// function named `{target_name}__bootstrap`.
///
/// The check only runs when that function exists in `model.functions` and is
/// tagged with the bootstrap module phase. Otherwise this returns `Ok(())`
/// without inspecting anything. From the root, module calls (nodes in
/// `MODULE_CALL_DOMAIN`) are followed transitively through bootstrap-phase callees.
///
/// # Errors
///
/// - `CompileError::BootstrapCompositionCycle` if the bootstrap call graph
///   reachable from the root contains a cycle.
/// - `CompileError::BootstrapCompositionGap` if a module call names a function
///   that is not defined in the model.
pub fn validate_bootstrap_composition(
    model: &ModelProto,
    target_name: &str,
) -> Result<(), CompileError> {
    // Index functions by name; a later duplicate overwrites an earlier one.
    let mut by_name: HashMap<&str, &bb_ir::proto::onnx::FunctionProto> = HashMap::new();
    for function in &model.functions {
        by_name.insert(function.name.as_str(), function);
    }

    let entry = format!("{target_name}__bootstrap");
    match by_name.get(entry.as_str()) {
        Some(root) if read_function_module_phase(root) == Some(MODULE_PHASE_BOOTSTRAP) => {
            let mut finished = HashSet::new();
            let mut on_stack = Vec::new();
            walk(&entry, &by_name, &mut on_stack, &mut finished)
        }
        // No root, or the root is not a bootstrap-phase function: nothing to validate.
        _ => Ok(()),
    }
}
/// Depth-first traversal of the bootstrap-phase module call graph rooted at `name`.
///
/// `gray` is the current call path: functions entered but not yet finished.
/// `black` holds functions whose reachable bootstrap sub-graph has already been
/// validated, so they are skipped on later visits. Only module calls (nodes whose
/// `domain` is `MODULE_CALL_DOMAIN` with a non-empty `op_type`) are considered.
/// Recursion only continues into callees tagged with the bootstrap module phase.
///
/// # Errors
///
/// - `CompileError::BootstrapCompositionCycle` when `name` is already on the
///   current path. `involves` lists the path from the first occurrence of
///   `name` and ends with `name` again, closing the loop.
/// - `CompileError::BootstrapCompositionGap` when `name` itself, or any module
///   call inside it, refers to a function missing from `by_name`. This applies
///   whatever phase the callee was expected to have.
fn walk(
    name: &str,
    by_name: &HashMap<&str, &bb_ir::proto::onnx::FunctionProto>,
    gray: &mut Vec<String>,
    black: &mut HashSet<String>,
) -> Result<(), CompileError> {
    if black.contains(name) {
        return Ok(());
    }
    if let Some(pos) = gray.iter().position(|n| n == name) {
        // Report the closed loop, e.g. [A, B, A] for A -> B -> A.
        let mut involves: Vec<String> = gray[pos..].to_vec();
        involves.push(name.to_string());
        return Err(CompileError::BootstrapCompositionCycle { involves });
    }
    let Some(&function) = by_name.get(name) else {
        return Err(CompileError::BootstrapCompositionGap {
            caller: gray.last().cloned().unwrap_or_else(|| name.to_string()),
            target: name.to_string(),
        });
    };

    gray.push(name.to_string());
    for node in &function.node {
        if node.domain != MODULE_CALL_DOMAIN || node.op_type.is_empty() {
            continue;
        }
        let target = node.op_type.as_str();
        // A single lookup both detects a missing callee and yields it.
        // This avoids a second hash plus a panicking `by_name[target]` index.
        let Some(&callee) = by_name.get(target) else {
            return Err(CompileError::BootstrapCompositionGap {
                caller: name.to_string(),
                target: target.to_string(),
            });
        };
        // Only bootstrap-phase callees take part in the composition check.
        if read_function_module_phase(callee) == Some(MODULE_PHASE_BOOTSTRAP) {
            walk(target, by_name, gray, black)?;
        }
    }
    gray.pop();
    black.insert(name.to_string());
    Ok(())
}