1use std::collections::HashSet;
17
18use crate::artifact::BindingSpec;
19use crate::error::CompileError;
20use crate::refine_polymorphic_value_info::refine_polymorphic_value_info;
21use crate::resolve_component_dependencies::resolve_component_dependencies;
22use crate::runner::{run_pipeline_with_options, CANONICAL_PASS_NAMES};
23use crate::validate_all_slots_bound::validate_all_slots_bound;
24use bb_dsl::recorded::RecordedModule;
25use bb_ir::proto::onnx::ModelProto;
26
27fn merge_partitions_into_one(partitions: Vec<ModelProto>) -> Result<ModelProto, CompileError> {
31 let mut iter = partitions.into_iter();
32 let Some(mut head) = iter.next() else {
33 return Ok(ModelProto::default());
34 };
35 for next in iter {
36 for fn_b in &next.functions {
38 if head.functions.iter().any(|fn_a| fn_a.name == fn_b.name) {
39 return Err(CompileError::Internal {
40 detail: format!(
41 "duplicate function name after partition merge: {}",
42 fn_b.name
43 ),
44 });
45 }
46 }
47 head.functions.extend(next.functions);
48 head.metadata_props.extend(next.metadata_props);
49 }
50 Ok(head)
51}
52
/// Failure reported by a custom compiler stage.
#[derive(Debug)]
pub enum PassError {
    /// Free-form failure description. Its `Display` output is exactly this text.
    Custom(String),
}

impl std::fmt::Display for PassError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PassError::Custom(message) => f.write_str(message),
        }
    }
}

impl std::error::Error for PassError {}
69
70pub trait CompilerStage: Send + Sync {
72 fn name(&self) -> &'static str;
74
75 fn run(&self, model: &mut ModelProto) -> Result<(), PassError>;
77}
78
/// One slot binding registered on a `Compiler`.
#[derive(Debug, Clone)]
struct CompilerBinding {
    /// Name of the slot being bound.
    slot: String,
    /// Runtime role name the concrete type is bound under (e.g. "BackendRuntime").
    role: String,
    /// `TYPE_NAME` of the concrete component bound to the slot.
    concrete_type_name: String,
}
89
90pub struct Compiler {
92 pub(crate) canonical_disabled: HashSet<String>,
93 pub(crate) stages: Vec<Box<dyn CompilerStage>>,
94 pub(crate) per_hop_budget_ns: u64,
95 pub(crate) target_ir_version: u32,
100 bindings: Vec<CompilerBinding>,
102 binding_storage: Vec<Vec<(&'static str, &'static bb_ir::types::TypeNode)>>,
104 pub(crate) strict_types: bool,
107}
108
109impl Default for Compiler {
110 fn default() -> Self {
111 Self {
112 canonical_disabled: HashSet::new(),
113 stages: Vec::new(),
114 per_hop_budget_ns: bb_ir::syscall_ids::DEFAULT_PER_HOP_BUDGET_NS,
115 target_ir_version: bb_ir::version::FRAMEWORK_IR_VERSION,
116 bindings: Vec::new(),
117 binding_storage: Vec::new(),
118 strict_types: true,
119 }
120 }
121}
122
123impl Compiler {
124 pub fn new() -> Self {
133 Self::default()
134 }
135
136 pub fn with_target_version(mut self, version: u32) -> Self {
140 self.target_ir_version = version;
141 self
142 }
143
144 pub fn with_per_hop_budget_ns(mut self, budget_ns: u64) -> Self {
148 self.per_hop_budget_ns = budget_ns;
149 self
150 }
151
152 pub fn with_permissive_types(mut self) -> Self {
155 self.strict_types = false;
156 self
157 }
158
159 pub fn without_stage(mut self, name: &str) -> Self {
161 if CANONICAL_PASS_NAMES.iter().any(|n| *n == name) {
162 self.canonical_disabled.insert(name.to_string());
163 return self;
164 }
165 self.stages.retain(|s| s.name() != name);
166 self
167 }
168
169 pub fn push_front_stage<S: CompilerStage + 'static>(mut self, stage: S) -> Self {
171 self.stages.insert(0, Box::new(stage));
172 self
173 }
174
175 pub fn push_back_stage<S: CompilerStage + 'static>(mut self, stage: S) -> Self {
177 self.stages.push(Box::new(stage));
178 self
179 }
180
181 pub fn insert_stage<S: CompilerStage + 'static>(mut self, index: usize, stage: S) -> Self {
184 let idx = index.min(self.stages.len());
185 self.stages.insert(idx, Box::new(stage));
186 self
187 }
188
189 pub fn bind_backend<T>(self, slot: impl Into<String>) -> Self
196 where
197 T: bb_runtime::concrete::ConcreteComponent + bb_runtime::roles::BackendRuntime,
198 {
199 self.bind_concrete_with_storage::<T>(slot.into(), "BackendRuntime", &["tensor"])
200 }
201
202 pub fn bind_index<T>(self, slot: impl Into<String>) -> Self
204 where
205 T: bb_runtime::concrete::ConcreteComponent + bb_runtime::roles::IndexRuntime,
206 {
207 self.bind_concrete_with_storage::<T>(slot.into(), "IndexRuntime", &["vector"])
208 }
209
210 pub fn bind_model<T>(self, slot: impl Into<String>) -> Self
212 where
213 T: bb_runtime::concrete::ConcreteComponent + bb_runtime::roles::ModelRuntime,
214 {
215 self.bind_concrete_with_storage::<T>(slot.into(), "ModelRuntime", &["tensor"])
216 }
217
218 pub fn bind_aggregator<T>(self, slot: impl Into<String>) -> Self
220 where
221 T: bb_runtime::concrete::ConcreteComponent + bb_runtime::roles::AggregatorRuntime,
222 {
223 self.bind_concrete_with_storage::<T>(slot.into(), "AggregatorRuntime", &["element"])
224 }
225
226 pub fn bind_codec<T>(self, slot: impl Into<String>) -> Self
228 where
229 T: bb_runtime::concrete::ConcreteComponent + bb_runtime::roles::CodecRuntime,
230 {
231 self.bind_concrete_with_storage::<T>(slot.into(), "CodecRuntime", &["in", "out"])
232 }
233
234 pub fn bind_data_source<T>(self, slot: impl Into<String>) -> Self
236 where
237 T: bb_runtime::concrete::ConcreteComponent + bb_runtime::roles::DataSourceRuntime,
238 {
239 self.bind_concrete_with_storage::<T>(slot.into(), "DataSourceRuntime", &["sample"])
240 }
241
242 pub fn bind_peer_selector<T>(self, slot: impl Into<String>) -> Self
244 where
245 T: bb_runtime::concrete::ConcreteComponent + bb_runtime::roles::PeerSelectorRuntime,
246 {
247 self.bind_concrete_with_storage::<T>(slot.into(), "PeerSelectorRuntime", &[])
249 }
250
251 pub fn bind_protocol<T>(self, slot: impl Into<String>) -> Self
253 where
254 T: bb_runtime::concrete::ConcreteComponent + bb_runtime::roles::ProtocolRuntime,
255 {
256 self.bind_concrete_with_storage::<T>(slot.into(), "ProtocolRuntime", &[])
257 }
258
259 fn bind_concrete_with_storage<T: bb_runtime::concrete::ConcreteComponent>(
264 mut self,
265 slot: String,
266 role_runtime: &'static str,
267 port_names: &[&'static str],
268 ) -> Self {
269 let concrete_type_name = T::TYPE_NAME;
270 let storage_types: Vec<(&'static str, &'static bb_ir::types::TypeNode)> = port_names
271 .iter()
272 .filter_map(|&port| {
273 bb_runtime::registry::lookup_storage_type(concrete_type_name, role_runtime, port)
274 .map(|t| (port, t))
275 })
276 .collect();
277 self.bindings.push(CompilerBinding {
278 slot,
279 role: role_runtime.to_string(),
280 concrete_type_name: concrete_type_name.to_string(),
281 });
282 self.binding_storage.push(storage_types);
283 self
284 }
285
286 #[cfg(test)]
288 pub(crate) fn into_binding_spec(self) -> BindingSpec {
289 let empty_storage: Vec<(&'static str, &'static bb_ir::types::TypeNode)> = Vec::new();
290 let mut spec = BindingSpec::new();
291 for (i, b) in self.bindings.into_iter().enumerate() {
292 let storage = self
293 .binding_storage
294 .get(i)
295 .cloned()
296 .unwrap_or_else(|| empty_storage.clone());
297 spec.push_with_storage(b.slot, b.role, b.concrete_type_name, storage);
298 }
299 spec
300 }
301
302 pub fn compile(self, mut model: ModelProto) -> Result<ModelProto, CompileError> {
308 let mut binding_spec = BindingSpec::new();
311 let empty_storage: Vec<(&'static str, &'static bb_ir::types::TypeNode)> = Vec::new();
312 for (i, b) in self.bindings.iter().enumerate() {
313 let storage = self.binding_storage.get(i).unwrap_or(&empty_storage);
314 binding_spec.push_with_storage(
315 b.slot.clone(),
316 b.role.clone(),
317 b.concrete_type_name.clone(),
318 storage.clone(),
319 );
320 }
321
322 refine_polymorphic_value_info(&mut model, &binding_spec)?;
323
324 let mut models = self.run_pipeline(model)?;
325
326 resolve_component_dependencies(&binding_spec, &mut models)?;
328 validate_all_slots_bound(&binding_spec, &models)?;
329
330 let mut targets_per_model: Vec<String> = Vec::with_capacity(models.len());
331 for partition in &models {
332 let target = partition
333 .functions
334 .first()
335 .map(|f| f.name.clone())
336 .unwrap_or_default();
337 targets_per_model.push(target);
338 }
339 for (partition, target) in models.iter_mut().zip(targets_per_model.iter()) {
340 crate::stamp_compilation_metadata::stamp_compilation_metadata(
341 partition,
342 &binding_spec,
343 target,
344 );
345 }
346
347 merge_partitions_into_one(models)
348 }
349
350 pub fn compile_partitions(&self, model: ModelProto) -> Result<Vec<ModelProto>, CompileError> {
354 self.run_pipeline(model)
355 }
356
357 fn run_pipeline(&self, model: ModelProto) -> Result<Vec<ModelProto>, CompileError> {
358 if model.functions.is_empty() {
359 return Err(CompileError::EmptyFunctionTable);
360 }
361 let stamped: Option<u32> = model
362 .metadata_props
363 .iter()
364 .find(|p| p.key == bb_ir::version::FRAMEWORK_IR_VERSION_KEY)
365 .and_then(|p| p.value.parse::<u32>().ok());
366 if let Some(got) = stamped {
367 if got != self.target_ir_version {
368 return Err(CompileError::IrVersionMismatch {
369 expected: self.target_ir_version,
370 got,
371 });
372 }
373 }
374 let mut iter = model.functions.into_iter();
375 let root = iter.next().expect("non-empty checked above");
376 let module_name = root.name.clone();
377 let sub_functions: Vec<bb_ir::proto::onnx::FunctionProto> = iter.collect();
378 let recorded = RecordedModule {
379 function: root,
380 sub_functions,
381 };
382
383 let enabled: HashSet<String> = CANONICAL_PASS_NAMES
384 .iter()
385 .filter(|n| !self.canonical_disabled.contains(**n))
386 .map(|s| s.to_string())
387 .collect();
388
389 let mut models = run_pipeline_with_options(
390 recorded,
391 module_name,
392 &enabled,
393 self.per_hop_budget_ns,
394 self.strict_types,
395 )?;
396
397 for stage in &self.stages {
398 for model in models.iter_mut() {
399 stage.run(model).map_err(|e| CompileError::Internal {
400 detail: format!("compiler stage `{}` failed: {e}", stage.name()),
401 })?;
402 }
403 }
404
405 Ok(models)
406 }
407}
408