Skip to main content

bb_dsl/
module.rs

1//! The `Module` trait. `Module::build()` produces one pre-compile
2//! `ModelProto` where `functions[0]` is the top-level body and
3//! `functions[1..]` are composed sub-Modules (deduped by `name()`).
4
5use bb_ir::proto::onnx::ModelProto;
6
7use crate::graph::Graph;
8use crate::output::Output;
9
/// Recording-time errors. Compile-time errors come from the compiler.
///
/// Besides `Debug`, the type derives `Clone`, `PartialEq` and `Eq` so that
/// callers and tests can duplicate and compare errors directly (for example
/// with `assert_eq!`) instead of matching on the formatted message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BuildError {
    /// Body recorded zero NodeProtos.
    EmptyModule,

    /// `output()` referenced a port with no recorded producer.
    MissingOutputPort {
        /// The port name as supplied by the user.
        name: String,
    },
}

/// Human-readable rendering; every message is prefixed with `Module::build:`.
impl std::fmt::Display for BuildError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Static message: nothing to interpolate, so write it directly.
            Self::EmptyModule => f.write_str("Module::build: recorded body is empty"),
            Self::MissingOutputPort { name } => write!(
                f,
                "Module::build: output(`{name}`) referenced but no producer recorded",
            ),
        }
    }
}

// No underlying cause is carried, so the default `source()` (None) is kept.
impl std::error::Error for BuildError {}
36
37/// Fluent call-site for inlining a sub-Module's body. Inlining
38/// rather than `FunctionProto` calls so independent branches inside
39/// the sub-Module run as soon as their inputs are ready (not
40/// blocked on a single CALL barrier).
41pub struct ModuleCall<'a, M: ?Sized + Module> {
42    module: &'a M,
43    bound_inputs: std::vec::Vec<(&'static str, crate::output::Output)>,
44}
45
46impl<M: ?Sized + Module> ModuleCall<'_, M> {
47    /// Bind a named input to a value the caller already produced.
48    pub fn input(mut self, name: &'static str, handle: crate::output::Output) -> Self {
49        self.bound_inputs.push((name, handle));
50        self
51    }
52
53    /// Record the sub-Module's body into `g`. Pull named outputs
54    /// from the returned [`ModuleOutputs`].
55    ///
56    /// ```ignore
57    /// let coord_out = self.coordinator.call().input("incoming", q).build(g);
58    /// let grad = coord_out.output("aggregated_grad");
59    /// ```
60    pub fn build(self, g: &mut crate::graph::Graph) -> ModuleOutputs<'_> {
61        let bindings: std::vec::Vec<(String, crate::output::Output)> = self
62            .bound_inputs
63            .iter()
64            .map(|(name, h)| ((*name).to_string(), h.clone()))
65            .collect();
66        let outputs = self.module.op(g, &bindings);
67        ModuleOutputs { graph: g, outputs }
68    }
69
70    /// Compose a child Module's bootstrap into the parent's
71    /// bootstrap. Emits a CALL to `"<name>__bootstrap"`; body-phase
72    /// ops gate until the child's CallContext drops.
73    pub fn bootstrap(self, g: &mut crate::graph::Graph) -> ModuleOutputs<'_> {
74        let bindings: std::vec::Vec<(String, crate::output::Output)> = self
75            .bound_inputs
76            .iter()
77            .map(|(name, h)| ((*name).to_string(), h.clone()))
78            .collect();
79        let name = format!("{}__bootstrap", self.module.name());
80        let outputs = g.with_function(&name, &bindings, |g| self.module.bootstrap(g));
81        ModuleOutputs { graph: g, outputs }
82    }
83}
84
85/// Named-output handle returned by [`ModuleCall::build`].
86pub struct ModuleOutputs<'a> {
87    graph: &'a mut crate::graph::Graph,
88    /// `(child_port_name, parent_call_output_name)` pairs. Empty
89    /// for top-level wraps (body `g.output` lands in parent scope).
90    outputs: Vec<(String, String)>,
91}
92
93impl ModuleOutputs<'_> {
94    /// Resolve a named output. Returns the CALL NodeProto's
95    /// outer-scope output for sub-function calls, the parent-scope
96    /// `lookup_output` for top-level wraps, or a sentinel that
97    /// surfaces as `BuildError::MissingOutputPort` downstream.
98    pub fn output(&self, name: &'static str) -> crate::output::Output {
99        if let Some(call_out) = self
100            .outputs
101            .iter()
102            .find(|(port, _)| port.as_str() == name)
103            .map(|(_, call_name)| call_name.clone())
104        {
105            return crate::output::Output::new(call_out, &bb_ir::types::TYPE_BYTES);
106        }
107        self.graph.lookup_output(name).unwrap_or_else(|| {
108            crate::output::Output::new(name.to_string(), &bb_ir::types::TYPE_BYTES)
109        })
110    }
111}
112
113/// Unit of composition. Implement `name()` + `body()`; framework
114/// supplies `op()` and `build()`. Body declares inputs via
115/// `g.input("name")` and emits via `g.output("name", value)` (local)
116/// or `g.net_out("name", peers, value)` (network).
117pub trait Module {
118    /// Short stable identifier — becomes `FunctionProto.name`.
119    fn name(&self) -> &str;
120
121    /// Recording logic. Compose child Modules via
122    /// `self.child.call().input(...).build(g).output(...)`.
123    fn body(&self, g: &mut Graph);
124
125    /// Setup recording, run once before the first `body` poll. May
126    /// emit `ContractResponse::Later`; the engine drains every
127    /// outstanding bootstrap completion before activating body ops.
128    fn bootstrap(&self, _g: &mut Graph) {}
129
130    /// Records `body()` into a function scope named `self.name()`.
131    /// Emits a CALL in the outer target. Do not override.
132    fn op(&self, g: &mut Graph, bindings: &[(String, Output)]) -> Vec<(String, String)> {
133        g.with_function(self.name(), bindings, |g| self.body(g))
134    }
135
136    /// Open a fluent call-site that inlines `self`'s body.
137    fn call(&self) -> ModuleCall<'_, Self> {
138        ModuleCall {
139            module: self,
140            bound_inputs: std::vec::Vec::new(),
141        }
142    }
143
144    /// Emit one pre-compile `ModelProto`. Body becomes
145    /// `functions[0]` stamped with `module_phase = "body"`. If
146    /// `bootstrap` recorded any ops it lands as a sibling
147    /// `"<name>__bootstrap"` stamped with `"bootstrap"`.
148    fn build(self) -> Result<ModelProto, BuildError>
149    where
150        Self: Sized,
151    {
152        let mut body_g = Graph::new();
153        let bindings: Vec<(String, Output)> = Vec::new();
154        let _ = self.op(&mut body_g, &bindings);
155        let mut pending = body_g.take_pending_errors();
156        if !pending.is_empty() {
157            return Err(pending.remove(0));
158        }
159        let body_recorded = body_g.finish();
160        if body_recorded.function.node.is_empty() && body_recorded.sub_functions.is_empty() {
161            return Err(BuildError::EmptyModule);
162        }
163
164        let body_name = self.name().to_string();
165        let mut boot_g = Graph::new();
166        boot_g.with_function(&format!("{body_name}__bootstrap"), &[], |g| {
167            self.bootstrap(g);
168        });
169        let mut boot_pending = boot_g.take_pending_errors();
170        if !boot_pending.is_empty() {
171            return Err(boot_pending.remove(0));
172        }
173        let boot_recorded = boot_g.finish();
174
175        let mut functions = Vec::with_capacity(
176            1 + body_recorded.sub_functions.len() + boot_recorded.sub_functions.len() + 1,
177        );
178        let mut body_fn = body_recorded.function;
179        bb_ir::keys::stamp_function_module_phase(&mut body_fn, bb_ir::keys::MODULE_PHASE_BODY);
180        functions.push(body_fn);
181        functions.extend(body_recorded.sub_functions);
182        if !boot_recorded.function.node.is_empty() {
183            let mut boot_fn = boot_recorded.function;
184            bb_ir::keys::stamp_function_module_phase(
185                &mut boot_fn,
186                bb_ir::keys::MODULE_PHASE_BOOTSTRAP,
187            );
188            functions.push(boot_fn);
189            functions.extend(boot_recorded.sub_functions);
190        }
191
192        Ok(ModelProto {
193            functions,
194            ..Default::default()
195        })
196    }
197}
198