Skip to main content

bb_compiler/
driver.rs

//! `Compiler` — single compile entry point. The canonical pipeline
//! runs once per compile target; user stages fire afterwards, once per
//! emitted partition. `compile` returns one merged `ModelProto`.
//!
//! ```ignore
//! use bytesandbrains::compiler::Compiler;
//!
//! let model = my_module.build()?;
//!
//! let installables = Compiler::default()
//!     .push_back_stage(MyStage::new())
//!     .without_stage("optimize")
//!     .compile(model)?;
//! ```
15
16use std::collections::HashSet;
17
18use crate::artifact::BindingSpec;
19use crate::error::CompileError;
20use crate::refine_polymorphic_value_info::refine_polymorphic_value_info;
21use crate::resolve_component_dependencies::resolve_component_dependencies;
22use crate::runner::{run_pipeline_with_options, CANONICAL_PASS_NAMES};
23use crate::validate_all_slots_bound::validate_all_slots_bound;
24use bb_dsl::recorded::RecordedModule;
25use bb_ir::proto::onnx::ModelProto;
26
27/// Concatenate partition `functions[]` into one `ModelProto`.
28/// First partition's non-functions fields win; `metadata_props`
29/// concatenates (later entries shadow on duplicate keys).
30fn merge_partitions_into_one(partitions: Vec<ModelProto>) -> Result<ModelProto, CompileError> {
31    let mut iter = partitions.into_iter();
32    let Some(mut head) = iter.next() else {
33        return Ok(ModelProto::default());
34    };
35    for next in iter {
36        // Content-hash suffixes make collisions vanishingly rare.
37        for fn_b in &next.functions {
38            if head.functions.iter().any(|fn_a| fn_a.name == fn_b.name) {
39                return Err(CompileError::Internal {
40                    detail: format!(
41                        "duplicate function name after partition merge: {}",
42                        fn_b.name
43                    ),
44                });
45            }
46        }
47        head.functions.extend(next.functions);
48        head.metadata_props.extend(next.metadata_props);
49    }
50    Ok(head)
51}
52
/// Failure reported by a user-supplied [`CompilerStage`].
///
/// The compiler wraps it in a `CompileError::Internal` whose detail
/// names the failing stage.
#[derive(Debug)]
pub enum PassError {
    /// Free-form error message from the user's stage.
    Custom(String),
}

impl std::fmt::Display for PassError {
    /// Writes the carried message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let PassError::Custom(message) = self;
        f.write_str(message)
    }
}

impl std::error::Error for PassError {}
69
/// User-supplied compiler stage.
///
/// Stages run after the canonical pipeline, in list order. Each stage is
/// invoked once per emitted partition, and all partitions finish one
/// stage before the next stage starts. An `Err` aborts compilation as
/// `CompileError::Internal`, with the stage name in the detail.
pub trait CompilerStage: Send + Sync {
    /// Stage identifier, used by `Compiler::without_stage` and in
    /// failure messages.
    ///
    /// Uniqueness is not enforced. `without_stage` removes every user
    /// stage with a matching name. A name equal to a canonical pass name
    /// makes `without_stage` disable that pass instead of removing the
    /// stage.
    fn name(&self) -> &'static str;

    /// Mutate one emitted partition model in place.
    fn run(&self, model: &mut ModelProto) -> Result<(), PassError>;
}
78
/// One author-declared binding from `Compiler::bind_<role>::<T>(slot)`.
///
/// Converted into a `BindingSpec` entry when `compile` runs.
#[derive(Clone, Debug)]
struct CompilerBinding {
    /// Slot name matching `#[depends(role = "<slot>")]`.
    slot: String,
    /// Name of the `<Role>Runtime` trait the concrete satisfies, e.g.
    /// `"BackendRuntime"` or `"IndexRuntime"` (PascalCase).
    role: String,
    /// `<T as ConcreteComponent>::TYPE_NAME`.
    concrete_type_name: String,
}
89
/// The single compile entry point.
///
/// Configured through consuming builder methods (stage edits,
/// `bind_<role>` bindings, `with_*` options) and finished by `compile`.
pub struct Compiler {
    /// Canonical pass names disabled via `without_stage`.
    pub(crate) canonical_disabled: HashSet<String>,
    /// User stages, run in order after the canonical pipeline.
    pub(crate) stages: Vec<Box<dyn CompilerStage>>,
    /// Per-hop budget in nanoseconds, forwarded to the canonical pipeline.
    pub(crate) per_hop_budget_ns: u64,
    /// IR version the compiler was built to consume.
    ///
    /// `run_pipeline()` is reached from both `compile` and
    /// `compile_partitions`. Before any pass runs, it compares this value
    /// with the first `FRAMEWORK_IR_VERSION` stamp in the input
    /// `ModelProto.metadata_props`. A mismatch surfaces as
    /// `CompileError::IrVersionMismatch`. A missing stamp, or one that
    /// does not parse as `u32`, is not checked.
    pub(crate) target_ir_version: u32,
    /// Bindings collected by `bind_<role>::<T>` builders.
    bindings: Vec<CompilerBinding>,
    /// Storage `TypeNode`s parallel to `bindings`.
    ///
    /// Entry `i` holds the `(port, type)` pairs looked up for
    /// `bindings[i]`; the two vectors are always pushed together.
    binding_storage: Vec<Vec<(&'static str, &'static bb_ir::types::TypeNode)>>,
    /// `true` = strict TypeSolver; `false` allows `TYPE_ANY`
    /// fall-through (for hand-authored test fixtures).
    pub(crate) strict_types: bool,
}
108
109impl Default for Compiler {
110    fn default() -> Self {
111        Self {
112            canonical_disabled: HashSet::new(),
113            stages: Vec::new(),
114            per_hop_budget_ns: bb_ir::syscall_ids::DEFAULT_PER_HOP_BUDGET_NS,
115            target_ir_version: bb_ir::version::FRAMEWORK_IR_VERSION,
116            bindings: Vec::new(),
117            binding_storage: Vec::new(),
118            strict_types: true,
119        }
120    }
121}
122
123impl Compiler {
124    /// Fresh compiler.
125    ///
126    /// ```ignore
127    /// let artifact = bb::Compiler::new()
128    ///     .bind_backend::<CpuBackend>("compute")
129    ///     .bind_index::<HnswIndex>("primary_index")
130    ///     .compile(module)?;
131    /// ```
132    pub fn new() -> Self {
133        Self::default()
134    }
135
136    /// Override the IR contract version. Mismatch with the input
137    /// model's `FRAMEWORK_IR_VERSION` stamp raises
138    /// `CompileError::IrVersionMismatch`.
139    pub fn with_target_version(mut self, version: u32) -> Self {
140        self.target_ir_version = version;
141        self
142    }
143
144    /// Override the per-hop budget in nanoseconds used by
145    /// `derive_wire_deadlines` when stamping static deadlines on
146    /// wire ops.
147    pub fn with_per_hop_budget_ns(mut self, budget_ns: u64) -> Self {
148        self.per_hop_budget_ns = budget_ns;
149        self
150    }
151
152    /// Permissive type-solver mode for hand-authored NodeProtos.
153    /// Unresolved values pass through as `TYPE_ANY`.
154    pub fn with_permissive_types(mut self) -> Self {
155        self.strict_types = false;
156        self
157    }
158
159    /// Disable a canonical pass by name. No-op for non-canonical.
160    pub fn without_stage(mut self, name: &str) -> Self {
161        if CANONICAL_PASS_NAMES.iter().any(|n| *n == name) {
162            self.canonical_disabled.insert(name.to_string());
163            return self;
164        }
165        self.stages.retain(|s| s.name() != name);
166        self
167    }
168
169    /// Insert a stage at the front of the user-stage list.
170    pub fn push_front_stage<S: CompilerStage + 'static>(mut self, stage: S) -> Self {
171        self.stages.insert(0, Box::new(stage));
172        self
173    }
174
175    /// Insert a stage at the back of the user-stage list.
176    pub fn push_back_stage<S: CompilerStage + 'static>(mut self, stage: S) -> Self {
177        self.stages.push(Box::new(stage));
178        self
179    }
180
181    /// Insert a stage at the supplied index. Index clamped to
182    /// `[0, stages.len()]`.
183    pub fn insert_stage<S: CompilerStage + 'static>(mut self, index: usize, stage: S) -> Self {
184        let idx = index.min(self.stages.len());
185        self.stages.insert(idx, Box::new(stage));
186        self
187    }
188
189    // ─── Binding chain ─────────────────────────────────────────────
190    //
191    // Each `bind_<role>::<T>(slot)` records a binding; the type bound
192    // (`T: ConcreteComponent + <Role>Runtime`) enforces role match.
193
194    /// Bind a `Backend`-role concrete at `slot`.
195    pub fn bind_backend<T>(self, slot: impl Into<String>) -> Self
196    where
197        T: bb_runtime::concrete::ConcreteComponent + bb_runtime::roles::BackendRuntime,
198    {
199        self.bind_concrete_with_storage::<T>(slot.into(), "BackendRuntime", &["tensor"])
200    }
201
202    /// Bind an `Index`-role concrete at `slot`.
203    pub fn bind_index<T>(self, slot: impl Into<String>) -> Self
204    where
205        T: bb_runtime::concrete::ConcreteComponent + bb_runtime::roles::IndexRuntime,
206    {
207        self.bind_concrete_with_storage::<T>(slot.into(), "IndexRuntime", &["vector"])
208    }
209
210    /// Bind a `Model`-role concrete at `slot`.
211    pub fn bind_model<T>(self, slot: impl Into<String>) -> Self
212    where
213        T: bb_runtime::concrete::ConcreteComponent + bb_runtime::roles::ModelRuntime,
214    {
215        self.bind_concrete_with_storage::<T>(slot.into(), "ModelRuntime", &["tensor"])
216    }
217
218    /// Bind an `Aggregator`-role concrete at `slot`.
219    pub fn bind_aggregator<T>(self, slot: impl Into<String>) -> Self
220    where
221        T: bb_runtime::concrete::ConcreteComponent + bb_runtime::roles::AggregatorRuntime,
222    {
223        self.bind_concrete_with_storage::<T>(slot.into(), "AggregatorRuntime", &["element"])
224    }
225
226    /// Bind a `Codec`-role concrete at `slot`.
227    pub fn bind_codec<T>(self, slot: impl Into<String>) -> Self
228    where
229        T: bb_runtime::concrete::ConcreteComponent + bb_runtime::roles::CodecRuntime,
230    {
231        self.bind_concrete_with_storage::<T>(slot.into(), "CodecRuntime", &["in", "out"])
232    }
233
234    /// Bind a `DataSource`-role concrete at `slot`.
235    pub fn bind_data_source<T>(self, slot: impl Into<String>) -> Self
236    where
237        T: bb_runtime::concrete::ConcreteComponent + bb_runtime::roles::DataSourceRuntime,
238    {
239        self.bind_concrete_with_storage::<T>(slot.into(), "DataSourceRuntime", &["sample"])
240    }
241
242    /// Bind a `PeerSelector`-role concrete at `slot`.
243    pub fn bind_peer_selector<T>(self, slot: impl Into<String>) -> Self
244    where
245        T: bb_runtime::concrete::ConcreteComponent + bb_runtime::roles::PeerSelectorRuntime,
246    {
247        // PeerSelector has no Storage-bound associated type.
248        self.bind_concrete_with_storage::<T>(slot.into(), "PeerSelectorRuntime", &[])
249    }
250
251    /// Bind a `Protocol`-role concrete at `slot`.
252    pub fn bind_protocol<T>(self, slot: impl Into<String>) -> Self
253    where
254        T: bb_runtime::concrete::ConcreteComponent + bb_runtime::roles::ProtocolRuntime,
255    {
256        self.bind_concrete_with_storage::<T>(slot.into(), "ProtocolRuntime", &[])
257    }
258
259    /// Look up per-port Storage `TypeNode`s and stamp them on
260    /// the binding. Missing inventory entries (hand-rolled
261    /// `<Role>Runtime` impls) silently omit the port; the type
262    /// solver treats missing as "unconstrained."
263    fn bind_concrete_with_storage<T: bb_runtime::concrete::ConcreteComponent>(
264        mut self,
265        slot: String,
266        role_runtime: &'static str,
267        port_names: &[&'static str],
268    ) -> Self {
269        let concrete_type_name = T::TYPE_NAME;
270        let storage_types: Vec<(&'static str, &'static bb_ir::types::TypeNode)> = port_names
271            .iter()
272            .filter_map(|&port| {
273                bb_runtime::registry::lookup_storage_type(concrete_type_name, role_runtime, port)
274                    .map(|t| (port, t))
275            })
276            .collect();
277        self.bindings.push(CompilerBinding {
278            slot,
279            role: role_runtime.to_string(),
280            concrete_type_name: concrete_type_name.to_string(),
281        });
282        self.binding_storage.push(storage_types);
283        self
284    }
285
286    /// Test-only `BindingSpec` materializer.
287    #[cfg(test)]
288    pub(crate) fn into_binding_spec(self) -> BindingSpec {
289        let empty_storage: Vec<(&'static str, &'static bb_ir::types::TypeNode)> = Vec::new();
290        let mut spec = BindingSpec::new();
291        for (i, b) in self.bindings.into_iter().enumerate() {
292            let storage = self
293                .binding_storage
294                .get(i)
295                .cloned()
296                .unwrap_or_else(|| empty_storage.clone());
297            spec.push_with_storage(b.slot, b.role, b.concrete_type_name, storage);
298        }
299        spec
300    }
301
302    /// Run the canonical pipeline and emit one compiled `ModelProto`.
303    /// Output carries `compiled` passport,
304    /// `binding.<target>.<slot>` entries, and `functions[]`
305    /// (one partition root per target). See `src/install.rs` for
306    /// the install-side parse.
307    pub fn compile(self, mut model: ModelProto) -> Result<ModelProto, CompileError> {
308        // Build BindingSpec before refine so the type solver walks
309        // the narrowed denotations, not the placeholders.
310        let mut binding_spec = BindingSpec::new();
311        let empty_storage: Vec<(&'static str, &'static bb_ir::types::TypeNode)> = Vec::new();
312        for (i, b) in self.bindings.iter().enumerate() {
313            let storage = self.binding_storage.get(i).unwrap_or(&empty_storage);
314            binding_spec.push_with_storage(
315                b.slot.clone(),
316                b.role.clone(),
317                b.concrete_type_name.clone(),
318                storage.clone(),
319            );
320        }
321
322        refine_polymorphic_value_info(&mut model, &binding_spec)?;
323
324        let mut models = self.run_pipeline(model)?;
325
326        // Stamp dep metadata; reject placeholders missing a binding.
327        resolve_component_dependencies(&binding_spec, &mut models)?;
328        validate_all_slots_bound(&binding_spec, &models)?;
329
330        let mut targets_per_model: Vec<String> = Vec::with_capacity(models.len());
331        for partition in &models {
332            let target = partition
333                .functions
334                .first()
335                .map(|f| f.name.clone())
336                .unwrap_or_default();
337            targets_per_model.push(target);
338        }
339        for (partition, target) in models.iter_mut().zip(targets_per_model.iter()) {
340            crate::stamp_compilation_metadata::stamp_compilation_metadata(
341                partition,
342                &binding_spec,
343                target,
344            );
345        }
346
347        merge_partitions_into_one(models)
348    }
349
350    /// Inspection-only: pipeline output as `Vec<ModelProto>`
351    /// without binding validation or passport stamping. Tests only;
352    /// production paths use [`Self::compile`].
353    pub fn compile_partitions(&self, model: ModelProto) -> Result<Vec<ModelProto>, CompileError> {
354        self.run_pipeline(model)
355    }
356
357    fn run_pipeline(&self, model: ModelProto) -> Result<Vec<ModelProto>, CompileError> {
358        if model.functions.is_empty() {
359            return Err(CompileError::EmptyFunctionTable);
360        }
361        let stamped: Option<u32> = model
362            .metadata_props
363            .iter()
364            .find(|p| p.key == bb_ir::version::FRAMEWORK_IR_VERSION_KEY)
365            .and_then(|p| p.value.parse::<u32>().ok());
366        if let Some(got) = stamped {
367            if got != self.target_ir_version {
368                return Err(CompileError::IrVersionMismatch {
369                    expected: self.target_ir_version,
370                    got,
371                });
372            }
373        }
374        let mut iter = model.functions.into_iter();
375        let root = iter.next().expect("non-empty checked above");
376        let module_name = root.name.clone();
377        let sub_functions: Vec<bb_ir::proto::onnx::FunctionProto> = iter.collect();
378        let recorded = RecordedModule {
379            function: root,
380            sub_functions,
381        };
382
383        let enabled: HashSet<String> = CANONICAL_PASS_NAMES
384            .iter()
385            .filter(|n| !self.canonical_disabled.contains(**n))
386            .map(|s| s.to_string())
387            .collect();
388
389        let mut models = run_pipeline_with_options(
390            recorded,
391            module_name,
392            &enabled,
393            self.per_hop_budget_ns,
394            self.strict_types,
395        )?;
396
397        for stage in &self.stages {
398            for model in models.iter_mut() {
399                stage.run(model).map_err(|e| CompileError::Internal {
400                    detail: format!("compiler stage `{}` failed: {e}", stage.name()),
401                })?;
402            }
403        }
404
405        Ok(models)
406    }
407}
408