use std::collections::HashSet;
use crate::error::CompileError;
use crate::infer_peer_classes::infer_peer_classes;
use crate::{
analyze_wire_edges, derive_wire_deadlines, expand_ops, inline_for_partition,
insert_async_deadlines, insert_backoff_gate_rx, insert_backoff_gate_tx, insert_dedup_gate_rx,
insert_peer_health_gate_rx, insert_peer_health_gate_tx, partition_by_wire_ops, resolve_slots,
synthesize_wire_recvs, validate, validate_bootstrap_composition, validate_runtime_complete,
verify_no_dangling_calls,
};
use bb_dsl::recorded::RecordedModule;
use bb_ir::proto::onnx::{FunctionProto, GraphProto, ModelProto};
/// Names of the compiler passes this pipeline looks up in its `enabled` set,
/// listed in the order in which `run_pipeline_with_options` / `process_target`
/// invoke them.
///
/// NOTE(review): `partition_by_wire_ops` is listed here, but the visible
/// pipeline runs it unconditionally (it is never checked against `enabled`);
/// confirm whether it is meant to be toggleable.
pub const CANONICAL_PASS_NAMES: &[&str] = &[
"inline_for_partition",
"derive_wire_deadlines",
"validate",
"expand_ops",
"type_solver",
"infer_peer_classes",
"synthesize_wire_recvs",
"partition_by_wire_ops",
"resolve_slots",
"analyze_wire_edges",
"insert_dedup_gate_rx",
"insert_peer_health_gate_rx",
"insert_backoff_gate_rx",
"insert_peer_health_gate_tx",
"insert_backoff_gate_tx",
"insert_async_deadlines",
"validate_runtime_complete",
];
/// Runs the full compiler pipeline on a recorded DSL module and returns one
/// [`ModelProto`] per produced partition.
///
/// Module-level work happens on a scratch model whose function table is the
/// recorded root function followed by its DSL sub-functions:
///
/// 1. `inline_for_partition` and `derive_wire_deadlines` run, each only when
///    its name is in `enabled`.
/// 2. The seam checks `bb_ir::verify::types`, `bb_ir::verify::function_calls`
///    and `validate_bootstrap_composition` always run.
/// 3. The root function goes through `process_target`. Every other function
///    in the table is attached to each produced model.
///
/// # Parameters
/// - `_module_name`: currently unused; partitions are named after the root
///   function instead.
/// - `per_hop_budget_ns`: forwarded to `derive_wire_deadlines`.
/// - `strict_types`: forwarded to the per-target type solver.
///
/// # Errors
/// Propagates any pass error. Returns `CompileError::Internal` when:
/// - a seam verifier fails,
/// - the function table is empty after the module-level passes, or
/// - no partitions are produced.
pub(crate) fn run_pipeline_with_options(
    recorded: RecordedModule,
    _module_name: String,
    enabled: &HashSet<String>,
    per_hop_budget_ns: u64,
    strict_types: bool,
) -> Result<Vec<ModelProto>, CompileError> {
    let RecordedModule {
        function,
        sub_functions: dsl_sub_functions,
    } = recorded;
    let on = |name: &str| enabled.contains(name);

    let mut temp = ModelProto::default();
    temp.functions.push(function);
    temp.functions.extend(dsl_sub_functions);
    if on("inline_for_partition") {
        inline_for_partition(&mut temp)?;
    }
    if on("derive_wire_deadlines") {
        derive_wire_deadlines(&mut temp, per_hop_budget_ns)?;
    }
    bb_ir::verify::types(&temp).map_err(|e| CompileError::Internal {
        detail: format!("verify::types failed at frontend seam: {e}"),
    })?;
    bb_ir::verify::function_calls(&temp).map_err(|e| CompileError::Internal {
        detail: format!("verify::function_calls failed at frontend seam: {e}"),
    })?;

    let root_name = temp
        .functions
        .first()
        .map(|root| root.name.clone())
        .ok_or_else(|| CompileError::Internal {
            detail: "compiler received an empty function table".into(),
        })?;
    validate_bootstrap_composition(&temp, &root_name)?;

    // Split the table by value: the root is compiled per target and the rest is
    // shared by every partition. Moving out of `temp` avoids cloning every
    // shared function only to drop the originals a moment later.
    let mut functions = temp.functions.into_iter();
    let root = functions
        .next()
        .expect("function table is non-empty: checked above");
    let shared_functions: Vec<FunctionProto> = functions.collect();

    let models = process_target(root, &root_name, enabled, &shared_functions, strict_types)?;
    if models.is_empty() {
        return Err(CompileError::Internal {
            detail: "compiler produced no partitions - recorded function was empty".into(),
        });
    }
    Ok(models)
}
/// Compiles one target function into one [`ModelProto`] per role partition.
///
/// Graph-level passes (`validate`, `expand_ops`, `type_solver`,
/// `infer_peer_classes`, `synthesize_wire_recvs`) run on a graph view of
/// `target_function`, each only when its name is in `enabled`. When both the
/// type solver and `synthesize_wire_recvs` run, the solver is re-run and the
/// synthesized wire edges are type-checked with `check_wire_edge_types`.
///
/// After the graph-level passes:
/// - The resulting nodes are merged back into the function.
/// - The function is partitioned with `partition_by_wire_ops`, which is not
///   gated by `enabled`.
/// - Each role's sub-graph goes through the enabled rx/tx gate passes and
///   the runtime checks.
///
/// Every partition becomes a model named by `split_partition` (using
/// `target_name`), followed by a copy of `shared_functions`. Each model is
/// checked with `verify_no_dangling_calls`.
///
/// # Errors
/// Propagates the first error from any pass or verifier. `validate` errors
/// are wrapped in `CompileError::Validation`.
fn process_target(
    target_function: FunctionProto,
    target_name: &str,
    enabled: &HashSet<String>,
    shared_functions: &[FunctionProto],
    strict_types: bool,
) -> Result<Vec<ModelProto>, CompileError> {
    let on = |name: &str| enabled.contains(name);

    // `validate` only reads the graph, so the same view serves both validation
    // and the mutating graph-level passes.
    let mut graph = function_to_graph_view(&target_function);
    if on("validate") {
        validate(&graph).map_err(CompileError::Validation)?;
    }
    if on("expand_ops") {
        expand_ops(&mut graph)?;
    }
    let type_solver_ran = on("type_solver");
    if type_solver_ran {
        run_type_solver(&mut graph, strict_types)?;
    }
    if on("infer_peer_classes") {
        infer_peer_classes(&mut graph)?;
    }
    if on("synthesize_wire_recvs") {
        synthesize_wire_recvs(&mut graph)?;
        if type_solver_ran {
            // Newly synthesized Recv nodes need types before their edges can be
            // compared against the Send values they came from.
            let fresh_solution = run_type_solver(&mut graph, strict_types)?;
            check_wire_edge_types(&graph, &fresh_solution)?;
        }
    }

    let target_function = merge_graph_into_function(target_function, graph);
    let view = function_to_graph_view(&target_function);
    let mut analysis = partition_by_wire_ops(&view)?;
    if on("resolve_slots") {
        resolve_slots(&target_function)?;
    }
    if on("analyze_wire_edges") {
        for sub_graph in analysis.per_role.values_mut() {
            analyze_wire_edges(sub_graph, &analysis.wire_edges)?;
        }
    }

    let mut models: Vec<ModelProto> = Vec::new();
    for (role, mut sub_graph) in analysis.per_role {
        if on("insert_dedup_gate_rx") {
            insert_dedup_gate_rx(&mut sub_graph)?;
        }
        if on("insert_peer_health_gate_rx") {
            insert_peer_health_gate_rx(&mut sub_graph)?;
        }
        if on("insert_backoff_gate_rx") {
            insert_backoff_gate_rx(&mut sub_graph)?;
        }
        if on("insert_peer_health_gate_tx") {
            insert_peer_health_gate_tx(&mut sub_graph)?;
        }
        if on("insert_backoff_gate_tx") {
            insert_backoff_gate_tx(&mut sub_graph)?;
        }
        if on("insert_async_deadlines") {
            insert_async_deadlines(&mut sub_graph)?;
        }
        if on("validate_runtime_complete") {
            validate_runtime_complete(&sub_graph)?;
        }
        // `role` is not needed after this call, so it is moved rather than cloned.
        let (composite_name, mut partition_function) =
            split_partition(&target_function, role, &sub_graph, target_name);
        partition_function.name = composite_name;
        let model = wrap_as_model(partition_function, shared_functions.to_vec());
        verify_no_dangling_calls(&model)?;
        models.push(model);
    }
    Ok(models)
}
use bb_ir::proto::function_to_graph_view;
/// Returns `function` with its node list replaced by the nodes of `graph`.
///
/// Only `node` is transferred: every other field of `function` is kept as-is
/// and every other field of `graph` is discarded.
///
/// NOTE(review): `run_type_solver` appears to write inferred types into the
/// graph's value info, which is not copied back here; confirm that downstream
/// views do not rely on it.
fn merge_graph_into_function(function: FunctionProto, graph: GraphProto) -> FunctionProto {
    let mut merged = function;
    merged.node = graph.node;
    merged
}
/// Builds the function for one role partition of `base`.
///
/// Returns `(composite_name, function)`:
/// - `function` is a copy of `base` whose nodes are replaced by those of
///   `sub_graph`. Its `name` is left untouched; the caller assigns it.
/// - `composite_name` is `"{base_name}#{hash:016x}"`, where `hash` is
///   `hash_node_bodies` over the partition's nodes.
///
/// `base_name` is chosen as follows:
/// - `module_name` for the `"@default"` and `SELF_CLASS` roles;
/// - `role` unchanged when it equals `module_name` or starts with
///   `"{module_name}_"`;
/// - `"{module_name}_{role}"` otherwise.
fn split_partition(
    base: &FunctionProto,
    role: String,
    sub_graph: &GraphProto,
    module_name: &str,
) -> (String, FunctionProto) {
    let mut function = base.clone();
    function.node = sub_graph.node.clone();

    // `role` already carries the module prefix when it is exactly `module_name`
    // or `module_name` followed by `_`. `strip_prefix` tests this without
    // allocating a `"{module_name}_"` string per partition.
    let already_prefixed = matches!(
        role.strip_prefix(module_name),
        Some(rest) if rest.is_empty() || rest.starts_with('_')
    );
    let base_name = if role == "@default" || role == bb_ir::peer_class::SELF_CLASS {
        module_name.to_string()
    } else if already_prefixed {
        role
    } else {
        format!("{module_name}_{role}")
    };

    let content_hash = crate::function_dedup::hash_node_bodies(&sub_graph.node);
    let composite_name = format!("{base_name}#{content_hash:016x}");
    (composite_name, function)
}
fn wrap_as_model(function: FunctionProto, hoisted: Vec<FunctionProto>) -> ModelProto {
let mut functions = Vec::with_capacity(1 + hoisted.len());
functions.push(function);
functions.extend(hoisted);
ModelProto {
functions,
..Default::default()
}
}
/// Solves the types of `graph`, applies the solution back onto `graph`, and
/// returns it.
///
/// The solver is seeded from the graph's existing value info. With
/// `strict = true` it uses `solve_strict`; otherwise it uses `solve`. Solver
/// errors are converted into [`CompileError`].
///
/// NOTE(review): the op-declaration lookup handed to the solver always returns
/// `None`, so no atomic op declarations are available to it; confirm this is
/// intended.
fn run_type_solver(
    graph: &mut GraphProto,
    strict: bool,
) -> Result<crate::type_solver::TypeSolution, CompileError> {
    let no_op_decls =
        |_: &str, _: &str| -> Option<&'static bb_ir::atomic::AtomicOpDecl> { None };
    let mut solver = crate::type_solver::TypeSolver::from_graph(graph, no_op_decls)?;
    solver.seed_from_value_info(graph);

    let solution = if strict {
        solver.solve_strict()?
    } else {
        solver.solve()?
    };

    crate::type_solver::TypeSolver::apply_solution_to_value_info(graph, &solution);
    Ok(solution)
}
/// Checks that each synthesized wire `Recv` has the same concrete type as the
/// `Send` value it was synthesized from.
///
/// A node is inspected only if all of the following hold:
/// - it is a `Recv` in the `ai.bytesandbrains.wire` domain;
/// - it carries an `ai.bytesandbrains.synthesized_from_send` metadata entry,
///   whose value names the source value;
/// - it has a non-empty first output.
///
/// Nodes where either value has no type in `solution`, or either type is not
/// concrete, are skipped.
///
/// # Errors
/// Returns [`CompileError::IncompatibleStorageOnEdge`] for the first inspected
/// node whose source and destination type nodes are not the same object.
///
/// NOTE(review): the comparison is pointer identity (`std::ptr::eq`), which
/// presumes the solver hands out one shared node per concrete type; confirm
/// against `TypeSolution::type_of`.
pub(crate) fn check_wire_edge_types(
    graph: &GraphProto,
    solution: &crate::type_solver::TypeSolution,
) -> Result<(), CompileError> {
    const SYNTHESIZED_FROM_KEY: &str = "ai.bytesandbrains.synthesized_from_send";
    const WIRE_DOMAIN: &str = "ai.bytesandbrains.wire";

    let wire_recvs = graph
        .node
        .iter()
        .filter(|n| n.domain == WIRE_DOMAIN && n.op_type == "Recv");

    for recv in wire_recvs {
        let src_val = match recv
            .metadata_props
            .iter()
            .find(|p| p.key == SYNTHESIZED_FROM_KEY)
        {
            Some(prop) => prop.value.as_str(),
            None => continue,
        };
        let dst_val = match recv.output.first() {
            Some(out) if !out.is_empty() => out,
            _ => continue,
        };
        let actual = match solution.type_of(src_val) {
            Some(ty) => ty,
            None => continue,
        };
        let expected = match solution.type_of(dst_val) {
            Some(ty) => ty,
            None => continue,
        };

        // Only fully resolved types can be compared meaningfully.
        let comparable = actual.is_concrete() && expected.is_concrete();
        if comparable && !std::ptr::eq(actual, expected) {
            return Err(CompileError::IncompatibleStorageOnEdge {
                src: src_val.to_string(),
                dst: dst_val.to_string(),
                expected_id: expected.id,
                actual_id: actual.id,
            });
        }
    }
    Ok(())
}