Skip to main content

bb_ops/placeholders/
mod.rs

//! Role-method dispatch slot placeholders. Each `*Slot` unit
//! struct is a generic slot bound at compile time via
//! `Compiler::new().bind_<role>::<T>("slot")`. DSL methods record
//! `NodeProto`s stamped with `(required_trait, slot_id)` for
//! binding-resolution routing.
//!
//! ```ignore
//! pub struct MyModule {
//!     backend: BackendSlot,       // bind any BackendRuntime
//!     data:    DataLoaderSlot,    // bind any DataSourceRuntime
//! }
//! ```
13
14use bb_dsl::graph::{attr_float, attr_graph, attr_int, attr_ints, attr_tensor, kv, Graph};
15use bb_dsl::output::Output;
16use bb_ir::proto::onnx::{AttributeProto, GraphProto, NodeProto, TensorProto};
17use bb_ir::types::{TYPE_TENSOR, TYPE_TENSOR_F32, TYPE_TRIGGER};
18
/// Generic Backend slot exposing the `ai.onnx v1` DSL catalog
/// (48 recording methods).
///
/// Every output recorded through this slot is typed
/// `&TYPE_TENSOR_F32`. The `BackendSubgraph` carrier is emitted by the
/// compiler itself and is not reachable as a DSL method.
#[derive(Clone, Copy, Debug, Default)]
pub struct BackendSlot;
24
25impl BackendSlot {
26    // --- Private recording helper --------------------------------
27
28    /// Records an `ai.onnx::<op_type>` NodeProto stamped with
29    /// `(required_trait, slot_id)`. Returns `&TYPE_TENSOR_F32` outputs.
30    fn record_op(
31        &self,
32        g: &mut Graph,
33        op_type: &str,
34        input_names: Vec<String>,
35        n_outputs: usize,
36        attribute: Vec<AttributeProto>,
37    ) -> Vec<Output> {
38        let slot_id = g.register_generic(self, "BackendRuntime");
39        let output_names: Vec<String> = (0..n_outputs).map(|_| g.next_site_name()).collect();
40        g.push_node(NodeProto {
41            op_type: op_type.into(),
42            domain: "ai.onnx".into(),
43            input: input_names,
44            output: output_names.clone(),
45            attribute,
46            metadata_props: vec![
47                kv("ai.bytesandbrains.required_trait", "BackendRuntime"),
48                kv("ai.bytesandbrains.slot_id", &slot_id.to_string()),
49            ],
50            ..Default::default()
51        });
52        // Stamp value_info for each output so the type_solver sees a
53        // declared denotation at every recorder site.
54        for name in &output_names {
55            g.declare_value_info(name, &TYPE_TENSOR_F32);
56        }
57        output_names
58            .into_iter()
59            .map(|n| Output::new(n, &TYPE_TENSOR_F32))
60            .collect()
61    }
62
63    fn record_one(
64        &self,
65        g: &mut Graph,
66        op_type: &str,
67        input_names: Vec<String>,
68        attribute: Vec<AttributeProto>,
69    ) -> Output {
70        self.record_op(g, op_type, input_names, 1, attribute)
71            .into_iter()
72            .next()
73            .expect("record_op with n_outputs=1")
74    }
75
76    // --- Creation ------------------------------------------------
77
78    /// `Zeros(dims)` - zero-initialized tensor of given shape.
79    pub fn zeros(&self, g: &mut Graph, dims: Vec<i64>) -> Output {
80        self.record_one(g, "Zeros", vec![], vec![attr_ints("dims", dims)])
81    }
82
83    /// `Ones(dims)` - one-initialized tensor of given shape.
84    pub fn ones(&self, g: &mut Graph, dims: Vec<i64>) -> Output {
85        self.record_one(g, "Ones", vec![], vec![attr_ints("dims", dims)])
86    }
87
88    /// `Constant(value)` - embedded literal tensor.
89    pub fn constant(&self, g: &mut Graph, value: TensorProto) -> Output {
90        self.record_one(g, "Constant", vec![], vec![attr_tensor("value", value)])
91    }
92
93    // --- Element-wise arithmetic ---------------------------------
94
95    /// `Add` - element-wise `a + b`.
96    pub fn add(&self, g: &mut Graph, a: Output, b: Output) -> Output {
97        self.record_one(g, "Add", vec![a.name, b.name], vec![])
98    }
99
100    /// `Sub` - element-wise `a - b`.
101    pub fn sub(&self, g: &mut Graph, a: Output, b: Output) -> Output {
102        self.record_one(g, "Sub", vec![a.name, b.name], vec![])
103    }
104
105    /// `Mul` - element-wise `a * b`.
106    pub fn mul(&self, g: &mut Graph, a: Output, b: Output) -> Output {
107        self.record_one(g, "Mul", vec![a.name, b.name], vec![])
108    }
109
110    /// `Div` - element-wise `a / b`.
111    pub fn div(&self, g: &mut Graph, a: Output, b: Output) -> Output {
112        self.record_one(g, "Div", vec![a.name, b.name], vec![])
113    }
114
115    /// `Neg` - element-wise negation.
116    pub fn neg(&self, g: &mut Graph, t: Output) -> Output {
117        self.record_one(g, "Neg", vec![t.name], vec![])
118    }
119
120    /// `Abs` - element-wise absolute value.
121    pub fn abs(&self, g: &mut Graph, t: Output) -> Output {
122        self.record_one(g, "Abs", vec![t.name], vec![])
123    }
124
125    /// `Sqrt` - element-wise square root.
126    pub fn sqrt(&self, g: &mut Graph, t: Output) -> Output {
127        self.record_one(g, "Sqrt", vec![t.name], vec![])
128    }
129
130    /// `Exp` - element-wise natural exponential.
131    pub fn exp(&self, g: &mut Graph, t: Output) -> Output {
132        self.record_one(g, "Exp", vec![t.name], vec![])
133    }
134
135    /// `Log` - element-wise natural logarithm.
136    pub fn log(&self, g: &mut Graph, t: Output) -> Output {
137        self.record_one(g, "Log", vec![t.name], vec![])
138    }
139
140    /// `Pow` - element-wise `a ** b`.
141    pub fn pow(&self, g: &mut Graph, a: Output, b: Output) -> Output {
142        self.record_one(g, "Pow", vec![a.name, b.name], vec![])
143    }
144
145    // --- Linear algebra ------------------------------------------
146
147    /// `MatMul` - matrix multiplication (canonical example).
148    pub fn matmul(&self, g: &mut Graph, a: Output, b: Output) -> Output {
149        self.record_one(g, "MatMul", vec![a.name, b.name], vec![])
150    }
151
152    /// `Gemm` - `alpha * (a @ b) + beta * c` with optional transpose.
153    #[allow(clippy::too_many_arguments)]
154    pub fn gemm(
155        &self,
156        g: &mut Graph,
157        a: Output,
158        b: Output,
159        c: Option<Output>,
160        alpha: f32,
161        beta: f32,
162        trans_a: bool,
163        trans_b: bool,
164    ) -> Output {
165        let mut inputs = vec![a.name, b.name];
166        if let Some(c) = c {
167            inputs.push(c.name);
168        }
169        self.record_one(
170            g,
171            "Gemm",
172            inputs,
173            vec![
174                attr_float("alpha", alpha),
175                attr_float("beta", beta),
176                attr_int("transA", trans_a as i64),
177                attr_int("transB", trans_b as i64),
178            ],
179        )
180    }
181
182    /// `Dot` - dot product (reduces along last axis for higher rank).
183    pub fn dot(&self, g: &mut Graph, a: Output, b: Output) -> Output {
184        self.record_one(g, "Dot", vec![a.name, b.name], vec![])
185    }
186
187    // --- Activations ---------------------------------------------
188
189    /// `Relu` - `max(0, x)`.
190    pub fn relu(&self, g: &mut Graph, t: Output) -> Output {
191        self.record_one(g, "Relu", vec![t.name], vec![])
192    }
193
194    /// `Sigmoid` - `1 / (1 + exp(-x))`.
195    pub fn sigmoid(&self, g: &mut Graph, t: Output) -> Output {
196        self.record_one(g, "Sigmoid", vec![t.name], vec![])
197    }
198
199    /// `Tanh` - hyperbolic tangent.
200    pub fn tanh(&self, g: &mut Graph, t: Output) -> Output {
201        self.record_one(g, "Tanh", vec![t.name], vec![])
202    }
203
204    /// `Softmax(axis)` - softmax along the given axis.
205    pub fn softmax(&self, g: &mut Graph, t: Output, axis: i64) -> Output {
206        self.record_one(g, "Softmax", vec![t.name], vec![attr_int("axis", axis)])
207    }
208
209    /// `LeakyRelu(alpha)` - `x if x > 0 else alpha * x`.
210    pub fn leaky_relu(&self, g: &mut Graph, t: Output, alpha: f32) -> Output {
211        self.record_one(
212            g,
213            "LeakyRelu",
214            vec![t.name],
215            vec![attr_float("alpha", alpha)],
216        )
217    }
218
219    /// `Gelu` - Gaussian Error Linear Unit.
220    pub fn gelu(&self, g: &mut Graph, t: Output) -> Output {
221        self.record_one(g, "Gelu", vec![t.name], vec![])
222    }
223
224    // --- Shape / structural --------------------------------------
225
226    /// `Reshape(dims)` - reshape to given dims.
227    pub fn reshape(&self, g: &mut Graph, t: Output, dims: Vec<i64>) -> Output {
228        self.record_one(g, "Reshape", vec![t.name], vec![attr_ints("dims", dims)])
229    }
230
231    /// `Transpose(perm)` - `None` reverses all dims.
232    pub fn transpose(&self, g: &mut Graph, t: Output, perm: Option<Vec<i64>>) -> Output {
233        let attrs = match perm {
234            Some(p) => vec![attr_ints("perm", p)],
235            None => vec![],
236        };
237        self.record_one(g, "Transpose", vec![t.name], attrs)
238    }
239
240    /// `Concat(axis)` - concatenate `tensors` along axis.
241    pub fn concat(&self, g: &mut Graph, tensors: Vec<Output>, axis: i64) -> Output {
242        let inputs = tensors.into_iter().map(|t| t.name).collect();
243        self.record_one(g, "Concat", inputs, vec![attr_int("axis", axis)])
244    }
245
246    /// `Split(axis, sizes)` - split into N parts. Returns one
247    /// `Output` per size.
248    pub fn split(&self, g: &mut Graph, t: Output, axis: i64, sizes: Vec<i64>) -> Vec<Output> {
249        let n = sizes.len();
250        self.record_op(
251            g,
252            "Split",
253            vec![t.name],
254            n,
255            vec![attr_int("axis", axis), attr_ints("split", sizes)],
256        )
257    }
258
259    /// `Slice(starts, ends, axes?, steps?)` - NumPy-style slice.
260    pub fn slice(
261        &self,
262        g: &mut Graph,
263        t: Output,
264        starts: Vec<i64>,
265        ends: Vec<i64>,
266        axes: Option<Vec<i64>>,
267        steps: Option<Vec<i64>>,
268    ) -> Output {
269        let mut attrs = vec![attr_ints("starts", starts), attr_ints("ends", ends)];
270        if let Some(a) = axes {
271            attrs.push(attr_ints("axes", a));
272        }
273        if let Some(s) = steps {
274            attrs.push(attr_ints("steps", s));
275        }
276        self.record_one(g, "Slice", vec![t.name], attrs)
277    }
278
279    /// `Squeeze(axes?)` - remove length-1 dimensions.
280    pub fn squeeze(&self, g: &mut Graph, t: Output, axes: Option<Vec<i64>>) -> Output {
281        let attrs = match axes {
282            Some(a) => vec![attr_ints("axes", a)],
283            None => vec![],
284        };
285        self.record_one(g, "Squeeze", vec![t.name], attrs)
286    }
287
288    /// `Unsqueeze(axes)` - insert length-1 dimensions.
289    pub fn unsqueeze(&self, g: &mut Graph, t: Output, axes: Vec<i64>) -> Output {
290        self.record_one(g, "Unsqueeze", vec![t.name], vec![attr_ints("axes", axes)])
291    }
292
293    /// `Identity` - clone pass-through.
294    pub fn identity(&self, g: &mut Graph, t: Output) -> Output {
295        self.record_one(g, "Identity", vec![t.name], vec![])
296    }
297
298    /// `Cast(to)` - cast to the given ONNX `DataType` enum value.
299    pub fn cast(&self, g: &mut Graph, t: Output, to_elem_type: i32) -> Output {
300        self.record_one(
301            g,
302            "Cast",
303            vec![t.name],
304            vec![attr_int("to", to_elem_type as i64)],
305        )
306    }
307
308    // --- Reductions ----------------------------------------------
309
310    fn reduce(
311        &self,
312        g: &mut Graph,
313        op_type: &str,
314        t: Output,
315        axes: Option<Vec<i64>>,
316        keepdims: bool,
317    ) -> Output {
318        let mut attrs = vec![attr_int("keepdims", keepdims as i64)];
319        if let Some(a) = axes {
320            attrs.push(attr_ints("axes", a));
321        }
322        self.record_one(g, op_type, vec![t.name], attrs)
323    }
324
325    /// `ReduceSum(axes?, keepdims)`.
326    pub fn reduce_sum(
327        &self,
328        g: &mut Graph,
329        t: Output,
330        axes: Option<Vec<i64>>,
331        keepdims: bool,
332    ) -> Output {
333        self.reduce(g, "ReduceSum", t, axes, keepdims)
334    }
335
336    /// `ReduceMean(axes?, keepdims)`.
337    pub fn reduce_mean(
338        &self,
339        g: &mut Graph,
340        t: Output,
341        axes: Option<Vec<i64>>,
342        keepdims: bool,
343    ) -> Output {
344        self.reduce(g, "ReduceMean", t, axes, keepdims)
345    }
346
347    /// `ReduceMax(axes?, keepdims)`.
348    pub fn reduce_max(
349        &self,
350        g: &mut Graph,
351        t: Output,
352        axes: Option<Vec<i64>>,
353        keepdims: bool,
354    ) -> Output {
355        self.reduce(g, "ReduceMax", t, axes, keepdims)
356    }
357
358    /// `ReduceMin(axes?, keepdims)`.
359    pub fn reduce_min(
360        &self,
361        g: &mut Graph,
362        t: Output,
363        axes: Option<Vec<i64>>,
364        keepdims: bool,
365    ) -> Output {
366        self.reduce(g, "ReduceMin", t, axes, keepdims)
367    }
368
369    // --- Comparison ----------------------------------------------
370
371    /// `Equal` - element-wise `a == b` (bool tensor).
372    pub fn equal(&self, g: &mut Graph, a: Output, b: Output) -> Output {
373        self.record_one(g, "Equal", vec![a.name, b.name], vec![])
374    }
375
376    /// `Greater` - element-wise `a > b` (bool tensor).
377    pub fn greater(&self, g: &mut Graph, a: Output, b: Output) -> Output {
378        self.record_one(g, "Greater", vec![a.name, b.name], vec![])
379    }
380
381    /// `Less` - element-wise `a < b` (bool tensor).
382    pub fn less(&self, g: &mut Graph, a: Output, b: Output) -> Output {
383        self.record_one(g, "Less", vec![a.name, b.name], vec![])
384    }
385
386    // --- Normalization -------------------------------------------
387
388    /// `BatchNormalization(epsilon, momentum)`.
389    #[allow(clippy::too_many_arguments)]
390    pub fn batch_normalization(
391        &self,
392        g: &mut Graph,
393        input: Output,
394        scale: Output,
395        bias: Output,
396        mean: Output,
397        variance: Output,
398        epsilon: f32,
399        momentum: f32,
400    ) -> Output {
401        self.record_one(
402            g,
403            "BatchNormalization",
404            vec![input.name, scale.name, bias.name, mean.name, variance.name],
405            vec![
406                attr_float("epsilon", epsilon),
407                attr_float("momentum", momentum),
408            ],
409        )
410    }
411
412    /// `LayerNormalization(axis, epsilon)`.
413    pub fn layer_normalization(
414        &self,
415        g: &mut Graph,
416        input: Output,
417        scale: Output,
418        bias: Option<Output>,
419        axis: i64,
420        epsilon: f32,
421    ) -> Output {
422        let mut inputs = vec![input.name, scale.name];
423        if let Some(b) = bias {
424            inputs.push(b.name);
425        }
426        self.record_one(
427            g,
428            "LayerNormalization",
429            inputs,
430            vec![attr_int("axis", axis), attr_float("epsilon", epsilon)],
431        )
432    }
433
434    // --- Conv / Pool ---------------------------------------------
435
436    /// `Conv(kernel_shape, strides, pads, dilations, group)`.
437    #[allow(clippy::too_many_arguments)]
438    pub fn conv(
439        &self,
440        g: &mut Graph,
441        input: Output,
442        weight: Output,
443        bias: Option<Output>,
444        kernel_shape: Vec<i64>,
445        strides: Vec<i64>,
446        pads: Vec<i64>,
447        dilations: Vec<i64>,
448        group: i64,
449    ) -> Output {
450        let mut inputs = vec![input.name, weight.name];
451        if let Some(b) = bias {
452            inputs.push(b.name);
453        }
454        self.record_one(
455            g,
456            "Conv",
457            inputs,
458            vec![
459                attr_ints("kernel_shape", kernel_shape),
460                attr_ints("strides", strides),
461                attr_ints("pads", pads),
462                attr_ints("dilations", dilations),
463                attr_int("group", group),
464            ],
465        )
466    }
467
468    /// `MaxPool(kernel_shape, strides, pads)`.
469    pub fn max_pool(
470        &self,
471        g: &mut Graph,
472        input: Output,
473        kernel_shape: Vec<i64>,
474        strides: Vec<i64>,
475        pads: Vec<i64>,
476    ) -> Output {
477        self.record_one(
478            g,
479            "MaxPool",
480            vec![input.name],
481            vec![
482                attr_ints("kernel_shape", kernel_shape),
483                attr_ints("strides", strides),
484                attr_ints("pads", pads),
485            ],
486        )
487    }
488
489    /// `AveragePool(kernel_shape, strides, pads, count_include_pad)`.
490    pub fn average_pool(
491        &self,
492        g: &mut Graph,
493        input: Output,
494        kernel_shape: Vec<i64>,
495        strides: Vec<i64>,
496        pads: Vec<i64>,
497        count_include_pad: bool,
498    ) -> Output {
499        self.record_one(
500            g,
501            "AveragePool",
502            vec![input.name],
503            vec![
504                attr_ints("kernel_shape", kernel_shape),
505                attr_ints("strides", strides),
506                attr_ints("pads", pads),
507                attr_int("count_include_pad", count_include_pad as i64),
508            ],
509        )
510    }
511
512    /// `GlobalAveragePool` - collapse spatial dims to length 1.
513    pub fn global_average_pool(&self, g: &mut Graph, input: Output) -> Output {
514        self.record_one(g, "GlobalAveragePool", vec![input.name], vec![])
515    }
516
517    // --- Indexing ------------------------------------------------
518
519    /// `Gather(axis)`.
520    pub fn gather(&self, g: &mut Graph, data: Output, indices: Output, axis: i64) -> Output {
521        self.record_one(
522            g,
523            "Gather",
524            vec![data.name, indices.name],
525            vec![attr_int("axis", axis)],
526        )
527    }
528
529    /// `Scatter(axis)`.
530    pub fn scatter(
531        &self,
532        g: &mut Graph,
533        data: Output,
534        indices: Output,
535        updates: Output,
536        axis: i64,
537    ) -> Output {
538        self.record_one(
539            g,
540            "Scatter",
541            vec![data.name, indices.name, updates.name],
542            vec![attr_int("axis", axis)],
543        )
544    }
545
546    // --- Control flow --------------------------------------------
547
548    /// `If(then_branch, else_branch)` - both branches are sub-graphs
549    /// carried on `AttributeProto.g` per IR_AND_DSL.md Part 2 line 80.
550    /// Returns one `Output` per branch output.
551    pub fn if_op(
552        &self,
553        g: &mut Graph,
554        cond: Output,
555        then_branch: GraphProto,
556        else_branch: GraphProto,
557        n_outputs: usize,
558    ) -> Vec<Output> {
559        self.record_op(
560            g,
561            "If",
562            vec![cond.name],
563            n_outputs,
564            vec![
565                attr_graph("then_branch", then_branch),
566                attr_graph("else_branch", else_branch),
567            ],
568        )
569    }
570
571    /// `Loop(body)` - execute `body` until `cond` becomes false or
572    /// `max_trip_count` is reached.
573    pub fn loop_op(
574        &self,
575        g: &mut Graph,
576        max_trip_count: Option<Output>,
577        cond: Option<Output>,
578        body: GraphProto,
579        initial: Vec<Output>,
580        n_outputs: usize,
581    ) -> Vec<Output> {
582        let mut inputs = vec![
583            max_trip_count.map(|o| o.name).unwrap_or_default(),
584            cond.map(|o| o.name).unwrap_or_default(),
585        ];
586        inputs.extend(initial.into_iter().map(|o| o.name));
587        self.record_op(g, "Loop", inputs, n_outputs, vec![attr_graph("body", body)])
588    }
589}
590
591// ---------------------------------------------------------------
592// Role-method recording helper
593// ---------------------------------------------------------------
594
595/// Record a NodeProto for one role-method DSL call against a generic
596/// placeholder. Domain follows the `ai.bytesandbrains.role.<role>`
597/// convention from `docs/IR_AND_DSL.md` §5c; metadata stamps the
598/// `(required_trait, slot_id)` pair so the compiler's
599/// `inline_role_methods` pass can swap in the bound impl's body.
600#[allow(clippy::too_many_arguments)]
601fn record_role_op<P: 'static>(
602    g: &mut Graph,
603    placeholder: &P,
604    required_trait: &'static str,
605    role_domain: &'static str,
606    op_type: &str,
607    input_names: Vec<String>,
608    n_outputs: usize,
609    attribute: Vec<AttributeProto>,
610) -> Vec<Output> {
611    let slot_id = g.register_generic(placeholder, required_trait);
612    let output_names: Vec<String> = (0..n_outputs).map(|_| g.next_site_name()).collect();
613    g.push_node(NodeProto {
614        op_type: op_type.into(),
615        domain: role_domain.into(),
616        input: input_names,
617        output: output_names.clone(),
618        attribute,
619        metadata_props: vec![
620            kv("ai.bytesandbrains.required_trait", required_trait),
621            kv("ai.bytesandbrains.slot_id", &slot_id.to_string()),
622        ],
623        ..Default::default()
624    });
625    for name in &output_names {
626        g.declare_value_info(name, &TYPE_TENSOR);
627    }
628    output_names
629        .into_iter()
630        .map(|n| Output::new(n, &TYPE_TENSOR))
631        .collect()
632}
633
634#[allow(clippy::too_many_arguments)]
635fn record_role_op_one(
636    g: &mut Graph,
637    placeholder: &impl std::any::Any,
638    required_trait: &'static str,
639    role_domain: &'static str,
640    op_type: &str,
641    input_names: Vec<String>,
642    attribute: Vec<AttributeProto>,
643) -> Output {
644    record_role_op(
645        g,
646        placeholder,
647        required_trait,
648        role_domain,
649        op_type,
650        input_names,
651        1,
652        attribute,
653    )
654    .into_iter()
655    .next()
656    .expect("record_role_op with n_outputs=1")
657}
658
659// ---------------------------------------------------------------
660// Model
661// ---------------------------------------------------------------
662
663/// Generic Model slot placeholder. Bind a concrete `ModelRuntime` via
664/// `Node::with_model(...)`. Exposes the six role-method
665/// DSL operations:
666/// `Forward`, `Backward`, `ComputeLoss`, `ApplyDelta`,
667/// `LoadParameters`, `Params`.
668#[derive(Debug, Clone, Copy, Default)]
669pub struct ModelSlot;
670
671impl ModelSlot {
672    /// `Forward(input) -> output` - tensor → tensor forward pass.
673    pub fn forward(&self, g: &mut Graph, input: Output) -> Output {
674        record_role_op_one(
675            g,
676            self,
677            "ModelRuntime",
678            "ai.bytesandbrains.role.model",
679            "Forward",
680            vec![input.name],
681            vec![],
682        )
683    }
684
685    /// `Backward(grad) -> cmd` - accumulate gradients given upstream gradient.
686    pub fn backward(&self, g: &mut Graph, grad: Output) -> Output {
687        record_role_op_one(
688            g,
689            self,
690            "ModelRuntime",
691            "ai.bytesandbrains.role.model",
692            "Backward",
693            vec![grad.name],
694            vec![],
695        )
696    }
697
698    /// `ComputeLoss(input, target) -> loss` - scalar loss score.
699    pub fn compute_loss(&self, g: &mut Graph, input: Output, target: Output) -> Output {
700        record_role_op_one(
701            g,
702            self,
703            "ModelRuntime",
704            "ai.bytesandbrains.role.model",
705            "ComputeLoss",
706            vec![input.name, target.name],
707            vec![],
708        )
709    }
710
711    /// `ApplyDelta(delta) -> cmd` - apply parameter delta in-place.
712    pub fn apply_delta(&self, g: &mut Graph, delta: Output) -> Output {
713        record_role_op_one(
714            g,
715            self,
716            "ModelRuntime",
717            "ai.bytesandbrains.role.model",
718            "ApplyDelta",
719            vec![delta.name],
720            vec![],
721        )
722    }
723
724    /// `LoadParameters(params) -> cmd` - load parameters wholesale.
725    pub fn load_parameters(&self, g: &mut Graph, params: Output) -> Output {
726        record_role_op_one(
727            g,
728            self,
729            "ModelRuntime",
730            "ai.bytesandbrains.role.model",
731            "LoadParameters",
732            vec![params.name],
733            vec![],
734        )
735    }
736
737    /// `Params() -> params` - snapshot the current parameter tensor.
738    pub fn params(&self, g: &mut Graph) -> Output {
739        record_role_op_one(
740            g,
741            self,
742            "ModelRuntime",
743            "ai.bytesandbrains.role.model",
744            "Params",
745            vec![],
746            vec![],
747        )
748    }
749}
750
751// ---------------------------------------------------------------
752// Index
753// ---------------------------------------------------------------
754
755/// Generic Index slot placeholder. Bind a concrete `IndexRuntime` via
756/// `Node::with_index(...)`. Exposes the three role-method DSL
757/// operations: `Add`, `Search`,
758/// `Remove`.
759#[derive(Debug, Clone, Copy, Default)]
760pub struct IndexSlot;
761
762impl IndexSlot {
763    /// `Add(vec) -> cmd` - Shape 2 (stateful insert).
764    pub fn add(&self, g: &mut Graph, vec: Output) -> Output {
765        record_role_op_one(
766            g,
767            self,
768            "IndexRuntime",
769            "ai.bytesandbrains.role.index",
770            "Add",
771            vec![vec.name],
772            vec![],
773        )
774    }
775
776    /// `Search(query, k=...)` - Shape 2 typically, Shape 1 for
777    /// in-memory flat indexes.
778    pub fn search(&self, g: &mut Graph, query: Output, k: i64) -> Output {
779        record_role_op_one(
780            g,
781            self,
782            "IndexRuntime",
783            "ai.bytesandbrains.role.index",
784            "Search",
785            vec![query.name],
786            vec![attr_int("k", k)],
787        )
788    }
789
790    /// `Remove(id) -> cmd` - Shape 2 (stateful delete).
791    pub fn remove(&self, g: &mut Graph, id: Output) -> Output {
792        record_role_op_one(
793            g,
794            self,
795            "IndexRuntime",
796            "ai.bytesandbrains.role.index",
797            "Remove",
798            vec![id.name],
799            vec![],
800        )
801    }
802
803    /// `Train(samples) -> trigger` — fire-and-forget calibration pass.
804    /// The output is `TYPE_TRIGGER` so authors that need to gate body
805    /// ops on training completion can wire the trigger through a
806    /// `bb.barrier` or place the call in `Module::bootstrap` to run
807    /// before body ops fire.
808    pub fn train(&self, g: &mut Graph, samples: Output) -> Output {
809        let slot_id = g.register_generic(self, "IndexRuntime");
810        let out_name = g.next_site_name();
811        g.push_node(NodeProto {
812            op_type: "Train".into(),
813            domain: "ai.bytesandbrains.role.index".into(),
814            input: vec![samples.name],
815            output: vec![out_name.clone()],
816            metadata_props: vec![
817                kv("ai.bytesandbrains.required_trait", "IndexRuntime"),
818                kv("ai.bytesandbrains.slot_id", &slot_id.to_string()),
819                kv("ai.bytesandbrains.index.port", "samples"),
820            ],
821            ..Default::default()
822        });
823        g.declare_value_info(&out_name, &TYPE_TRIGGER);
824        Output::new(out_name, &TYPE_TRIGGER)
825    }
826}
827
828// ---------------------------------------------------------------
829// Aggregator
830// ---------------------------------------------------------------
831
832/// Generic Aggregator slot placeholder. Bind a concrete
833/// `AggregatorRuntime` at compile time via
834/// `Compiler::new().bind_aggregator::<T>("slot")`. Exposes
835/// `Contribute` + `Aggregate`.
836///
837/// Both ops carry an opaque metadata channel alongside the
838/// tensor: `Contribute` takes `(contribution, metadata)`;
839/// `Aggregate` returns `(params, metadata)`. This is the channel
840/// hierarchical aggregation rides on — a child aggregator emits
841/// its summed `num_samples` (or whatever schema the impl uses) so
842/// the parent layer's reduction can weight the child's
843/// contribution correctly.
844#[derive(Debug, Clone, Copy, Default)]
845pub struct AggregatorSlot;
846
847impl AggregatorSlot {
848    /// `Contribute(contribution, metadata) -> cmd` - Shape 2
849    /// (buffer write). `metadata` is impl-defined bytes (e.g. a
850    /// sample count for FedAvg); pass an empty `Output` when the
851    /// impl has no metadata channel.
852    pub fn contribute(&self, g: &mut Graph, contribution: Output, metadata: Output) -> Output {
853        record_role_op_one(
854            g,
855            self,
856            "AggregatorRuntime",
857            "ai.bytesandbrains.role.aggregator",
858            "Contribute",
859            vec![contribution.name, metadata.name],
860            vec![],
861        )
862    }
863
864    /// `Aggregate(trigger) -> (params, metadata)` - Shape 1 (mean /
865    /// weighted sum / replace expressible as `ai.onnx`). The output
866    /// edge fires only when the reduction completes; downstream
867    /// consumers read `params` AND the aggregation's accompanying
868    /// `metadata` (e.g. summed `num_samples` for hierarchical
869    /// FedAvg) directly off the op's two outputs — no separate
870    /// `current_tensor` read needed.
871    pub fn aggregate(&self, g: &mut Graph, trigger: Output) -> (Output, Output) {
872        let mut outs = record_role_op(
873            g,
874            self,
875            "AggregatorRuntime",
876            "ai.bytesandbrains.role.aggregator",
877            "Aggregate",
878            vec![trigger.name],
879            2,
880            vec![],
881        );
882        let metadata = outs.pop().expect("two outputs");
883        let params = outs.pop().expect("two outputs");
884        (params, metadata)
885    }
886}
887
888// ---------------------------------------------------------------
889// Codec
890// ---------------------------------------------------------------
891
/// Generic Codec slot placeholder.
///
/// Embed it in a Module struct as `codec: CodecSlot`, then bind a
/// concrete `CodecRuntime` through
/// `Compiler::new().bind_codec::<T>("slot")…`.
#[derive(Clone, Copy, Debug, Default)]
pub struct CodecSlot;
897
898impl CodecSlot {
899    /// `Train(samples) → trigger` — optional calibration pass.
900    /// Quantizers learn scale/zero-point, PQ codebooks run k-means,
901    /// dtype casts skip the call. The output rides `TYPE_TRIGGER` so
902    /// `Module::bootstrap` (or a downstream `bb.barrier`) can gate
903    /// body ops on training completion. Stamps
904    /// `ai.bytesandbrains.codec.port = "in"` since samples flow at
905    /// the In storage position.
906    pub fn train(&self, g: &mut Graph, samples: Output) -> Output {
907        let slot_id = g.register_generic(self, "CodecRuntime");
908        let out_name = g.next_site_name();
909        g.push_node(NodeProto {
910            op_type: "Train".into(),
911            domain: "ai.bytesandbrains.role.codec".into(),
912            input: vec![samples.name],
913            output: vec![out_name.clone()],
914            metadata_props: vec![
915                kv("ai.bytesandbrains.required_trait", "CodecRuntime"),
916                kv("ai.bytesandbrains.slot_id", &slot_id.to_string()),
917                kv("ai.bytesandbrains.codec.port", "in"),
918            ],
919            ..Default::default()
920        });
921        g.declare_value_info(&out_name, &TYPE_TRIGGER);
922        Output::new(out_name, &TYPE_TRIGGER)
923    }
924
925    /// `Encode(input) → output` — In → Out direction. Stamps
926    /// `ai.bytesandbrains.codec.port = "out"` on the NodeProto so
927    /// the refinement pass knows which port of the bound concrete's
928    /// `<In, Out>` to read for the output denotation.
929    pub fn encode(&self, g: &mut Graph, input: Output) -> Output {
930        let slot_id = g.register_generic(self, "CodecRuntime");
931        let out_name = g.next_site_name();
932        g.push_node(NodeProto {
933            op_type: "Encode".into(),
934            domain: "ai.bytesandbrains.role.codec".into(),
935            input: vec![input.name],
936            output: vec![out_name.clone()],
937            metadata_props: vec![
938                kv("ai.bytesandbrains.required_trait", "CodecRuntime"),
939                kv("ai.bytesandbrains.slot_id", &slot_id.to_string()),
940                kv("ai.bytesandbrains.codec.port", "out"),
941            ],
942            ..Default::default()
943        });
944        g.declare_value_info(&out_name, &TYPE_TENSOR);
945        Output::new(out_name, &TYPE_TENSOR)
946    }
947
948    /// `Decode(encoded) → output` — Out → In direction. Stamps
949    /// `ai.bytesandbrains.codec.port = "in"`.
950    pub fn decode(&self, g: &mut Graph, encoded: Output) -> Output {
951        let slot_id = g.register_generic(self, "CodecRuntime");
952        let out_name = g.next_site_name();
953        g.push_node(NodeProto {
954            op_type: "Decode".into(),
955            domain: "ai.bytesandbrains.role.codec".into(),
956            input: vec![encoded.name],
957            output: vec![out_name.clone()],
958            metadata_props: vec![
959                kv("ai.bytesandbrains.required_trait", "CodecRuntime"),
960                kv("ai.bytesandbrains.slot_id", &slot_id.to_string()),
961                kv("ai.bytesandbrains.codec.port", "in"),
962            ],
963            ..Default::default()
964        });
965        g.declare_value_info(&out_name, &TYPE_TENSOR);
966        Output::new(out_name, &TYPE_TENSOR)
967    }
968}
969
970// ---------------------------------------------------------------
971// DataLoader
972// ---------------------------------------------------------------
973
/// Placeholder for a generic DataSource slot.
///
/// A concrete `DataSourceRuntime` is attached through
/// `Node::with_data_source(...)`. The DSL surface covers `NextBatch`,
/// `Reset` and `OnDataLoaded`, as laid out in `docs/IR_AND_DSL.md` §5c.2.
///
/// NOTE(review): the module-level docs say slots are bound via
/// `Compiler::new().bind_<role>::<T>("slot")`; confirm whether
/// `Node::with_data_source` is still the current binding entry point.
#[derive(Clone, Copy, Debug, Default)]
pub struct DataLoaderSlot;
980
981impl DataLoaderSlot {
982    /// `NextBatch() -> (batch, labels)` - Shape 2 (data source has
983    /// side effects). Returns two `Output` handles; the second is
984    /// optional in spec but always materialized in the DSL surface
985    /// for shape symmetry.
986    pub fn next_batch(&self, g: &mut Graph) -> (Output, Output) {
987        let mut outs = record_role_op(
988            g,
989            self,
990            "DataSourceRuntime",
991            "ai.bytesandbrains.role.data_source",
992            "NextBatch",
993            vec![],
994            2,
995            vec![],
996        );
997        let labels = outs.pop().expect("two outputs");
998        let batch = outs.pop().expect("two outputs");
999        (batch, labels)
1000    }
1001
1002    /// `Reset(trigger) -> trigger` - Shape 2.
1003    pub fn reset(&self, g: &mut Graph, trigger: Output) -> Output {
1004        record_role_op_one(
1005            g,
1006            self,
1007            "DataSourceRuntime",
1008            "ai.bytesandbrains.role.data_source",
1009            "Reset",
1010            vec![trigger.name],
1011            vec![],
1012        )
1013    }
1014
1015    /// `OnDataLoaded() -> trigger` - Shape 2 (one-shot
1016    /// notification).
1017    pub fn on_data_loaded(&self, g: &mut Graph) -> Output {
1018        record_role_op_one(
1019            g,
1020            self,
1021            "DataSourceRuntime",
1022            "ai.bytesandbrains.role.data_source",
1023            "OnDataLoaded",
1024            vec![],
1025            vec![],
1026        )
1027    }
1028}
1029
1030// ---------------------------------------------------------------
1031// PeerSelector
1032// ---------------------------------------------------------------
1033
1034/// Generic PeerSelector slot placeholder. Bind a concrete
1035/// `PeerSelectorRuntime` via `Node::with_peer_selector(...)`.
1036/// Exposes `Sample`, `CurrentView`
1037///
1038/// `class` tags every `Output<PeerId>` this placeholder yields with
1039/// the class of peer it samples from. The compiler's
1040/// `infer_peer_classes` pass reads that tag to attribute downstream
1041/// `wire.send`s to the right destination class - that's what makes
1042/// gossip's self-send case partition correctly (1 class → 1
1043/// partition with both send and synthesized recv).
1044#[derive(Debug, Clone, Copy)]
1045pub struct PeerSelectorSlot {
1046    /// Class identifier the compiler stamps onto every produced
1047    /// `Output<PeerId>` via `peer_class` metadata. Defaults to
1048    /// [`bb_ir::peer_class::SELF_CLASS`] for placeholders
1049    /// constructed via `Default` - that puts samples on the same
1050    /// class as the registering Node, the natural gossip case.
1051    pub class: &'static str,
1052}
1053
1054impl Default for PeerSelectorSlot {
1055    fn default() -> Self {
1056        Self {
1057            class: bb_ir::peer_class::SELF_CLASS,
1058        }
1059    }
1060}
1061
1062impl PeerSelectorSlot {
1063    /// Construct a sampling placeholder bound to the given class. Use
1064    /// `PeerSelectorSlot::of_class("gossip_peer")` to tag samples as
1065    /// "gossip peers" so a downstream `wire.send(payload, neighbor)`
1066    /// puts its `data` output on the same `gossip_peer` partition.
1067    pub fn of_class(class: &'static str) -> Self {
1068        Self { class }
1069    }
1070
1071    fn record_peer_op(&self, g: &mut Graph, op_type: &str, attrs: Vec<AttributeProto>) -> Output {
1072        let slot_id = g.register_generic(self, "PeerSelectorRuntime");
1073        let out_name = g.next_site_name();
1074        g.push_node(NodeProto {
1075            op_type: op_type.into(),
1076            domain: "ai.bytesandbrains.role.peer_selector".into(),
1077            input: vec![],
1078            output: vec![out_name.clone()],
1079            attribute: attrs,
1080            metadata_props: vec![
1081                kv("ai.bytesandbrains.required_trait", "PeerSelectorRuntime"),
1082                kv("ai.bytesandbrains.slot_id", &slot_id.to_string()),
1083                kv(bb_ir::peer_class::PEER_CLASS_KEY, self.class),
1084            ],
1085            ..Default::default()
1086        });
1087        g.declare_value_info(&out_name, &bb_ir::types::TYPE_PEER_ID);
1088        Output::new(out_name, &bb_ir::types::TYPE_PEER_ID)
1089    }
1090
1091    /// `Sample(n) -> peers` - Shape 2 (state-dependent sampling).
1092    pub fn sample(&self, g: &mut Graph, n: i64) -> Output {
1093        self.record_peer_op(g, "Sample", vec![attr_int("n", n)])
1094    }
1095
1096    /// `CurrentView() -> view` - Shape 2 (state read).
1097    pub fn current_view(&self, g: &mut Graph) -> Output {
1098        self.record_peer_op(g, "CurrentView", vec![])
1099    }
1100}
1101
1102// ---------------------------------------------------------------
1103// Protocol
1104// ---------------------------------------------------------------
1105
/// Placeholder for a generic Protocol slot.
///
/// A concrete protocol is attached at compile time through the compiler
/// chain, e.g. `Compiler::new().bind_protocol::<T>("slot").compile(...)`.
///
/// Unlike the other slots, this placeholder offers no DSL ops of its own.
/// Protocols are stateful control-plane runtimes reached through
/// `dispatch_atomic` against each implementation's own atomic opset, and
/// `register_protocol!{}` generates the DSL methods for that opset. The
/// type exists only so a Module can declare a generic Protocol slot.
#[derive(Clone, Copy, Debug, Default)]
pub struct ProtocolSlot;
1113/// (`register_protocol!{}` emits the DSL methods for the impl's own
1114/// atomic opset). The placeholder exists solely so Modules can
1115/// declare a generic Protocol slot.
1116#[derive(Debug, Clone, Copy, Default)]
1117pub struct ProtocolSlot;
1118