1use bb_ir::proto::onnx::ModelProto;
6
7use crate::graph::Graph;
8use crate::output::Output;
9
/// Errors produced while recording a [`Module`] into a standalone model.
///
/// Returned by [`Module::build`], which also forwards the first pending error the graph
/// reports after recording either the body or the bootstrap phase.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BuildError {
    /// The recorded module body contains neither nodes nor sub-functions.
    EmptyModule,

    /// An output port was referenced but no producer for it was recorded.
    MissingOutputPort {
        /// Name of the referenced output port.
        name: String,
    },
}

impl std::fmt::Display for BuildError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EmptyModule => write!(f, "Module::build: recorded body is empty"),
            Self::MissingOutputPort { name } => write!(
                f,
                "Module::build: output(`{name}`) referenced but no producer recorded",
            ),
        }
    }
}

impl std::error::Error for BuildError {}
36
37pub struct ModuleCall<'a, M: ?Sized + Module> {
42 module: &'a M,
43 bound_inputs: std::vec::Vec<(&'static str, crate::output::Output)>,
44}
45
46impl<M: ?Sized + Module> ModuleCall<'_, M> {
47 pub fn input(mut self, name: &'static str, handle: crate::output::Output) -> Self {
49 self.bound_inputs.push((name, handle));
50 self
51 }
52
53 pub fn build(self, g: &mut crate::graph::Graph) -> ModuleOutputs<'_> {
61 let bindings: std::vec::Vec<(String, crate::output::Output)> = self
62 .bound_inputs
63 .iter()
64 .map(|(name, h)| ((*name).to_string(), h.clone()))
65 .collect();
66 let outputs = self.module.op(g, &bindings);
67 ModuleOutputs { graph: g, outputs }
68 }
69
70 pub fn bootstrap(self, g: &mut crate::graph::Graph) -> ModuleOutputs<'_> {
74 let bindings: std::vec::Vec<(String, crate::output::Output)> = self
75 .bound_inputs
76 .iter()
77 .map(|(name, h)| ((*name).to_string(), h.clone()))
78 .collect();
79 let name = format!("{}__bootstrap", self.module.name());
80 let outputs = g.with_function(&name, &bindings, |g| self.module.bootstrap(g));
81 ModuleOutputs { graph: g, outputs }
82 }
83}
84
85pub struct ModuleOutputs<'a> {
87 graph: &'a mut crate::graph::Graph,
88 outputs: Vec<(String, String)>,
91}
92
93impl ModuleOutputs<'_> {
94 pub fn output(&self, name: &'static str) -> crate::output::Output {
99 if let Some(call_out) = self
100 .outputs
101 .iter()
102 .find(|(port, _)| port.as_str() == name)
103 .map(|(_, call_name)| call_name.clone())
104 {
105 return crate::output::Output::new(call_out, &bb_ir::types::TYPE_BYTES);
106 }
107 self.graph.lookup_output(name).unwrap_or_else(|| {
108 crate::output::Output::new(name.to_string(), &bb_ir::types::TYPE_BYTES)
109 })
110 }
111}
112
113pub trait Module {
118 fn name(&self) -> &str;
120
121 fn body(&self, g: &mut Graph);
124
125 fn bootstrap(&self, _g: &mut Graph) {}
129
130 fn op(&self, g: &mut Graph, bindings: &[(String, Output)]) -> Vec<(String, String)> {
133 g.with_function(self.name(), bindings, |g| self.body(g))
134 }
135
136 fn call(&self) -> ModuleCall<'_, Self> {
138 ModuleCall {
139 module: self,
140 bound_inputs: std::vec::Vec::new(),
141 }
142 }
143
144 fn build(self) -> Result<ModelProto, BuildError>
149 where
150 Self: Sized,
151 {
152 let mut body_g = Graph::new();
153 let bindings: Vec<(String, Output)> = Vec::new();
154 let _ = self.op(&mut body_g, &bindings);
155 let mut pending = body_g.take_pending_errors();
156 if !pending.is_empty() {
157 return Err(pending.remove(0));
158 }
159 let body_recorded = body_g.finish();
160 if body_recorded.function.node.is_empty() && body_recorded.sub_functions.is_empty() {
161 return Err(BuildError::EmptyModule);
162 }
163
164 let body_name = self.name().to_string();
165 let mut boot_g = Graph::new();
166 boot_g.with_function(&format!("{body_name}__bootstrap"), &[], |g| {
167 self.bootstrap(g);
168 });
169 let mut boot_pending = boot_g.take_pending_errors();
170 if !boot_pending.is_empty() {
171 return Err(boot_pending.remove(0));
172 }
173 let boot_recorded = boot_g.finish();
174
175 let mut functions = Vec::with_capacity(
176 1 + body_recorded.sub_functions.len() + boot_recorded.sub_functions.len() + 1,
177 );
178 let mut body_fn = body_recorded.function;
179 bb_ir::keys::stamp_function_module_phase(&mut body_fn, bb_ir::keys::MODULE_PHASE_BODY);
180 functions.push(body_fn);
181 functions.extend(body_recorded.sub_functions);
182 if !boot_recorded.function.node.is_empty() {
183 let mut boot_fn = boot_recorded.function;
184 bb_ir::keys::stamp_function_module_phase(
185 &mut boot_fn,
186 bb_ir::keys::MODULE_PHASE_BOOTSTRAP,
187 );
188 functions.push(boot_fn);
189 functions.extend(boot_recorded.sub_functions);
190 }
191
192 Ok(ModelProto {
193 functions,
194 ..Default::default()
195 })
196 }
197}
198