1use bb_dsl::graph::{attr_float, attr_graph, attr_int, attr_ints, attr_tensor, kv, Graph};
15use bb_dsl::output::Output;
16use bb_ir::proto::onnx::{AttributeProto, GraphProto, NodeProto, TensorProto};
17use bb_ir::types::{TYPE_TENSOR, TYPE_TENSOR_F32, TYPE_TRIGGER};
18
/// Stateless, copyable handle for the backend tensor runtime.
///
/// Every op method on this type records a node in the `ai.onnx` domain whose metadata
/// names the `BackendRuntime` trait and the slot id returned by `Graph::register_generic`.
#[derive(Clone, Copy, Debug, Default)]
pub struct BackendSlot;
24
25impl BackendSlot {
26 fn record_op(
31 &self,
32 g: &mut Graph,
33 op_type: &str,
34 input_names: Vec<String>,
35 n_outputs: usize,
36 attribute: Vec<AttributeProto>,
37 ) -> Vec<Output> {
38 let slot_id = g.register_generic(self, "BackendRuntime");
39 let output_names: Vec<String> = (0..n_outputs).map(|_| g.next_site_name()).collect();
40 g.push_node(NodeProto {
41 op_type: op_type.into(),
42 domain: "ai.onnx".into(),
43 input: input_names,
44 output: output_names.clone(),
45 attribute,
46 metadata_props: vec![
47 kv("ai.bytesandbrains.required_trait", "BackendRuntime"),
48 kv("ai.bytesandbrains.slot_id", &slot_id.to_string()),
49 ],
50 ..Default::default()
51 });
52 for name in &output_names {
55 g.declare_value_info(name, &TYPE_TENSOR_F32);
56 }
57 output_names
58 .into_iter()
59 .map(|n| Output::new(n, &TYPE_TENSOR_F32))
60 .collect()
61 }
62
63 fn record_one(
64 &self,
65 g: &mut Graph,
66 op_type: &str,
67 input_names: Vec<String>,
68 attribute: Vec<AttributeProto>,
69 ) -> Output {
70 self.record_op(g, op_type, input_names, 1, attribute)
71 .into_iter()
72 .next()
73 .expect("record_op with n_outputs=1")
74 }
75
76 pub fn zeros(&self, g: &mut Graph, dims: Vec<i64>) -> Output {
80 self.record_one(g, "Zeros", vec![], vec![attr_ints("dims", dims)])
81 }
82
83 pub fn ones(&self, g: &mut Graph, dims: Vec<i64>) -> Output {
85 self.record_one(g, "Ones", vec![], vec![attr_ints("dims", dims)])
86 }
87
88 pub fn constant(&self, g: &mut Graph, value: TensorProto) -> Output {
90 self.record_one(g, "Constant", vec![], vec![attr_tensor("value", value)])
91 }
92
93 pub fn add(&self, g: &mut Graph, a: Output, b: Output) -> Output {
97 self.record_one(g, "Add", vec![a.name, b.name], vec![])
98 }
99
100 pub fn sub(&self, g: &mut Graph, a: Output, b: Output) -> Output {
102 self.record_one(g, "Sub", vec![a.name, b.name], vec![])
103 }
104
105 pub fn mul(&self, g: &mut Graph, a: Output, b: Output) -> Output {
107 self.record_one(g, "Mul", vec![a.name, b.name], vec![])
108 }
109
110 pub fn div(&self, g: &mut Graph, a: Output, b: Output) -> Output {
112 self.record_one(g, "Div", vec![a.name, b.name], vec![])
113 }
114
115 pub fn neg(&self, g: &mut Graph, t: Output) -> Output {
117 self.record_one(g, "Neg", vec![t.name], vec![])
118 }
119
120 pub fn abs(&self, g: &mut Graph, t: Output) -> Output {
122 self.record_one(g, "Abs", vec![t.name], vec![])
123 }
124
125 pub fn sqrt(&self, g: &mut Graph, t: Output) -> Output {
127 self.record_one(g, "Sqrt", vec![t.name], vec![])
128 }
129
130 pub fn exp(&self, g: &mut Graph, t: Output) -> Output {
132 self.record_one(g, "Exp", vec![t.name], vec![])
133 }
134
135 pub fn log(&self, g: &mut Graph, t: Output) -> Output {
137 self.record_one(g, "Log", vec![t.name], vec![])
138 }
139
140 pub fn pow(&self, g: &mut Graph, a: Output, b: Output) -> Output {
142 self.record_one(g, "Pow", vec![a.name, b.name], vec![])
143 }
144
145 pub fn matmul(&self, g: &mut Graph, a: Output, b: Output) -> Output {
149 self.record_one(g, "MatMul", vec![a.name, b.name], vec![])
150 }
151
152 #[allow(clippy::too_many_arguments)]
154 pub fn gemm(
155 &self,
156 g: &mut Graph,
157 a: Output,
158 b: Output,
159 c: Option<Output>,
160 alpha: f32,
161 beta: f32,
162 trans_a: bool,
163 trans_b: bool,
164 ) -> Output {
165 let mut inputs = vec![a.name, b.name];
166 if let Some(c) = c {
167 inputs.push(c.name);
168 }
169 self.record_one(
170 g,
171 "Gemm",
172 inputs,
173 vec![
174 attr_float("alpha", alpha),
175 attr_float("beta", beta),
176 attr_int("transA", trans_a as i64),
177 attr_int("transB", trans_b as i64),
178 ],
179 )
180 }
181
182 pub fn dot(&self, g: &mut Graph, a: Output, b: Output) -> Output {
184 self.record_one(g, "Dot", vec![a.name, b.name], vec![])
185 }
186
187 pub fn relu(&self, g: &mut Graph, t: Output) -> Output {
191 self.record_one(g, "Relu", vec![t.name], vec![])
192 }
193
194 pub fn sigmoid(&self, g: &mut Graph, t: Output) -> Output {
196 self.record_one(g, "Sigmoid", vec![t.name], vec![])
197 }
198
199 pub fn tanh(&self, g: &mut Graph, t: Output) -> Output {
201 self.record_one(g, "Tanh", vec![t.name], vec![])
202 }
203
204 pub fn softmax(&self, g: &mut Graph, t: Output, axis: i64) -> Output {
206 self.record_one(g, "Softmax", vec![t.name], vec![attr_int("axis", axis)])
207 }
208
209 pub fn leaky_relu(&self, g: &mut Graph, t: Output, alpha: f32) -> Output {
211 self.record_one(
212 g,
213 "LeakyRelu",
214 vec![t.name],
215 vec![attr_float("alpha", alpha)],
216 )
217 }
218
219 pub fn gelu(&self, g: &mut Graph, t: Output) -> Output {
221 self.record_one(g, "Gelu", vec![t.name], vec![])
222 }
223
224 pub fn reshape(&self, g: &mut Graph, t: Output, dims: Vec<i64>) -> Output {
228 self.record_one(g, "Reshape", vec![t.name], vec![attr_ints("dims", dims)])
229 }
230
231 pub fn transpose(&self, g: &mut Graph, t: Output, perm: Option<Vec<i64>>) -> Output {
233 let attrs = match perm {
234 Some(p) => vec![attr_ints("perm", p)],
235 None => vec![],
236 };
237 self.record_one(g, "Transpose", vec![t.name], attrs)
238 }
239
240 pub fn concat(&self, g: &mut Graph, tensors: Vec<Output>, axis: i64) -> Output {
242 let inputs = tensors.into_iter().map(|t| t.name).collect();
243 self.record_one(g, "Concat", inputs, vec![attr_int("axis", axis)])
244 }
245
246 pub fn split(&self, g: &mut Graph, t: Output, axis: i64, sizes: Vec<i64>) -> Vec<Output> {
249 let n = sizes.len();
250 self.record_op(
251 g,
252 "Split",
253 vec![t.name],
254 n,
255 vec![attr_int("axis", axis), attr_ints("split", sizes)],
256 )
257 }
258
259 pub fn slice(
261 &self,
262 g: &mut Graph,
263 t: Output,
264 starts: Vec<i64>,
265 ends: Vec<i64>,
266 axes: Option<Vec<i64>>,
267 steps: Option<Vec<i64>>,
268 ) -> Output {
269 let mut attrs = vec![attr_ints("starts", starts), attr_ints("ends", ends)];
270 if let Some(a) = axes {
271 attrs.push(attr_ints("axes", a));
272 }
273 if let Some(s) = steps {
274 attrs.push(attr_ints("steps", s));
275 }
276 self.record_one(g, "Slice", vec![t.name], attrs)
277 }
278
279 pub fn squeeze(&self, g: &mut Graph, t: Output, axes: Option<Vec<i64>>) -> Output {
281 let attrs = match axes {
282 Some(a) => vec![attr_ints("axes", a)],
283 None => vec![],
284 };
285 self.record_one(g, "Squeeze", vec![t.name], attrs)
286 }
287
288 pub fn unsqueeze(&self, g: &mut Graph, t: Output, axes: Vec<i64>) -> Output {
290 self.record_one(g, "Unsqueeze", vec![t.name], vec![attr_ints("axes", axes)])
291 }
292
293 pub fn identity(&self, g: &mut Graph, t: Output) -> Output {
295 self.record_one(g, "Identity", vec![t.name], vec![])
296 }
297
298 pub fn cast(&self, g: &mut Graph, t: Output, to_elem_type: i32) -> Output {
300 self.record_one(
301 g,
302 "Cast",
303 vec![t.name],
304 vec![attr_int("to", to_elem_type as i64)],
305 )
306 }
307
308 fn reduce(
311 &self,
312 g: &mut Graph,
313 op_type: &str,
314 t: Output,
315 axes: Option<Vec<i64>>,
316 keepdims: bool,
317 ) -> Output {
318 let mut attrs = vec![attr_int("keepdims", keepdims as i64)];
319 if let Some(a) = axes {
320 attrs.push(attr_ints("axes", a));
321 }
322 self.record_one(g, op_type, vec![t.name], attrs)
323 }
324
325 pub fn reduce_sum(
327 &self,
328 g: &mut Graph,
329 t: Output,
330 axes: Option<Vec<i64>>,
331 keepdims: bool,
332 ) -> Output {
333 self.reduce(g, "ReduceSum", t, axes, keepdims)
334 }
335
336 pub fn reduce_mean(
338 &self,
339 g: &mut Graph,
340 t: Output,
341 axes: Option<Vec<i64>>,
342 keepdims: bool,
343 ) -> Output {
344 self.reduce(g, "ReduceMean", t, axes, keepdims)
345 }
346
347 pub fn reduce_max(
349 &self,
350 g: &mut Graph,
351 t: Output,
352 axes: Option<Vec<i64>>,
353 keepdims: bool,
354 ) -> Output {
355 self.reduce(g, "ReduceMax", t, axes, keepdims)
356 }
357
358 pub fn reduce_min(
360 &self,
361 g: &mut Graph,
362 t: Output,
363 axes: Option<Vec<i64>>,
364 keepdims: bool,
365 ) -> Output {
366 self.reduce(g, "ReduceMin", t, axes, keepdims)
367 }
368
369 pub fn equal(&self, g: &mut Graph, a: Output, b: Output) -> Output {
373 self.record_one(g, "Equal", vec![a.name, b.name], vec![])
374 }
375
376 pub fn greater(&self, g: &mut Graph, a: Output, b: Output) -> Output {
378 self.record_one(g, "Greater", vec![a.name, b.name], vec![])
379 }
380
381 pub fn less(&self, g: &mut Graph, a: Output, b: Output) -> Output {
383 self.record_one(g, "Less", vec![a.name, b.name], vec![])
384 }
385
386 #[allow(clippy::too_many_arguments)]
390 pub fn batch_normalization(
391 &self,
392 g: &mut Graph,
393 input: Output,
394 scale: Output,
395 bias: Output,
396 mean: Output,
397 variance: Output,
398 epsilon: f32,
399 momentum: f32,
400 ) -> Output {
401 self.record_one(
402 g,
403 "BatchNormalization",
404 vec![input.name, scale.name, bias.name, mean.name, variance.name],
405 vec![
406 attr_float("epsilon", epsilon),
407 attr_float("momentum", momentum),
408 ],
409 )
410 }
411
412 pub fn layer_normalization(
414 &self,
415 g: &mut Graph,
416 input: Output,
417 scale: Output,
418 bias: Option<Output>,
419 axis: i64,
420 epsilon: f32,
421 ) -> Output {
422 let mut inputs = vec![input.name, scale.name];
423 if let Some(b) = bias {
424 inputs.push(b.name);
425 }
426 self.record_one(
427 g,
428 "LayerNormalization",
429 inputs,
430 vec![attr_int("axis", axis), attr_float("epsilon", epsilon)],
431 )
432 }
433
434 #[allow(clippy::too_many_arguments)]
438 pub fn conv(
439 &self,
440 g: &mut Graph,
441 input: Output,
442 weight: Output,
443 bias: Option<Output>,
444 kernel_shape: Vec<i64>,
445 strides: Vec<i64>,
446 pads: Vec<i64>,
447 dilations: Vec<i64>,
448 group: i64,
449 ) -> Output {
450 let mut inputs = vec![input.name, weight.name];
451 if let Some(b) = bias {
452 inputs.push(b.name);
453 }
454 self.record_one(
455 g,
456 "Conv",
457 inputs,
458 vec![
459 attr_ints("kernel_shape", kernel_shape),
460 attr_ints("strides", strides),
461 attr_ints("pads", pads),
462 attr_ints("dilations", dilations),
463 attr_int("group", group),
464 ],
465 )
466 }
467
468 pub fn max_pool(
470 &self,
471 g: &mut Graph,
472 input: Output,
473 kernel_shape: Vec<i64>,
474 strides: Vec<i64>,
475 pads: Vec<i64>,
476 ) -> Output {
477 self.record_one(
478 g,
479 "MaxPool",
480 vec![input.name],
481 vec![
482 attr_ints("kernel_shape", kernel_shape),
483 attr_ints("strides", strides),
484 attr_ints("pads", pads),
485 ],
486 )
487 }
488
489 pub fn average_pool(
491 &self,
492 g: &mut Graph,
493 input: Output,
494 kernel_shape: Vec<i64>,
495 strides: Vec<i64>,
496 pads: Vec<i64>,
497 count_include_pad: bool,
498 ) -> Output {
499 self.record_one(
500 g,
501 "AveragePool",
502 vec![input.name],
503 vec![
504 attr_ints("kernel_shape", kernel_shape),
505 attr_ints("strides", strides),
506 attr_ints("pads", pads),
507 attr_int("count_include_pad", count_include_pad as i64),
508 ],
509 )
510 }
511
512 pub fn global_average_pool(&self, g: &mut Graph, input: Output) -> Output {
514 self.record_one(g, "GlobalAveragePool", vec![input.name], vec![])
515 }
516
517 pub fn gather(&self, g: &mut Graph, data: Output, indices: Output, axis: i64) -> Output {
521 self.record_one(
522 g,
523 "Gather",
524 vec![data.name, indices.name],
525 vec![attr_int("axis", axis)],
526 )
527 }
528
529 pub fn scatter(
531 &self,
532 g: &mut Graph,
533 data: Output,
534 indices: Output,
535 updates: Output,
536 axis: i64,
537 ) -> Output {
538 self.record_one(
539 g,
540 "Scatter",
541 vec![data.name, indices.name, updates.name],
542 vec![attr_int("axis", axis)],
543 )
544 }
545
546 pub fn if_op(
552 &self,
553 g: &mut Graph,
554 cond: Output,
555 then_branch: GraphProto,
556 else_branch: GraphProto,
557 n_outputs: usize,
558 ) -> Vec<Output> {
559 self.record_op(
560 g,
561 "If",
562 vec![cond.name],
563 n_outputs,
564 vec![
565 attr_graph("then_branch", then_branch),
566 attr_graph("else_branch", else_branch),
567 ],
568 )
569 }
570
571 pub fn loop_op(
574 &self,
575 g: &mut Graph,
576 max_trip_count: Option<Output>,
577 cond: Option<Output>,
578 body: GraphProto,
579 initial: Vec<Output>,
580 n_outputs: usize,
581 ) -> Vec<Output> {
582 let mut inputs = vec![
583 max_trip_count.map(|o| o.name).unwrap_or_default(),
584 cond.map(|o| o.name).unwrap_or_default(),
585 ];
586 inputs.extend(initial.into_iter().map(|o| o.name));
587 self.record_op(g, "Loop", inputs, n_outputs, vec![attr_graph("body", body)])
588 }
589}
590
591#[allow(clippy::too_many_arguments)]
601fn record_role_op<P: 'static>(
602 g: &mut Graph,
603 placeholder: &P,
604 required_trait: &'static str,
605 role_domain: &'static str,
606 op_type: &str,
607 input_names: Vec<String>,
608 n_outputs: usize,
609 attribute: Vec<AttributeProto>,
610) -> Vec<Output> {
611 let slot_id = g.register_generic(placeholder, required_trait);
612 let output_names: Vec<String> = (0..n_outputs).map(|_| g.next_site_name()).collect();
613 g.push_node(NodeProto {
614 op_type: op_type.into(),
615 domain: role_domain.into(),
616 input: input_names,
617 output: output_names.clone(),
618 attribute,
619 metadata_props: vec![
620 kv("ai.bytesandbrains.required_trait", required_trait),
621 kv("ai.bytesandbrains.slot_id", &slot_id.to_string()),
622 ],
623 ..Default::default()
624 });
625 for name in &output_names {
626 g.declare_value_info(name, &TYPE_TENSOR);
627 }
628 output_names
629 .into_iter()
630 .map(|n| Output::new(n, &TYPE_TENSOR))
631 .collect()
632}
633
634#[allow(clippy::too_many_arguments)]
635fn record_role_op_one(
636 g: &mut Graph,
637 placeholder: &impl std::any::Any,
638 required_trait: &'static str,
639 role_domain: &'static str,
640 op_type: &str,
641 input_names: Vec<String>,
642 attribute: Vec<AttributeProto>,
643) -> Output {
644 record_role_op(
645 g,
646 placeholder,
647 required_trait,
648 role_domain,
649 op_type,
650 input_names,
651 1,
652 attribute,
653 )
654 .into_iter()
655 .next()
656 .expect("record_role_op with n_outputs=1")
657}
658
/// Stateless, copyable handle for the model role.
///
/// Its methods record nodes in the `ai.bytesandbrains.role.model` domain that require the
/// `ModelRuntime` trait.
#[derive(Clone, Copy, Debug, Default)]
pub struct ModelSlot;
670
671impl ModelSlot {
672 pub fn forward(&self, g: &mut Graph, input: Output) -> Output {
674 record_role_op_one(
675 g,
676 self,
677 "ModelRuntime",
678 "ai.bytesandbrains.role.model",
679 "Forward",
680 vec![input.name],
681 vec![],
682 )
683 }
684
685 pub fn backward(&self, g: &mut Graph, grad: Output) -> Output {
687 record_role_op_one(
688 g,
689 self,
690 "ModelRuntime",
691 "ai.bytesandbrains.role.model",
692 "Backward",
693 vec![grad.name],
694 vec![],
695 )
696 }
697
698 pub fn compute_loss(&self, g: &mut Graph, input: Output, target: Output) -> Output {
700 record_role_op_one(
701 g,
702 self,
703 "ModelRuntime",
704 "ai.bytesandbrains.role.model",
705 "ComputeLoss",
706 vec![input.name, target.name],
707 vec![],
708 )
709 }
710
711 pub fn apply_delta(&self, g: &mut Graph, delta: Output) -> Output {
713 record_role_op_one(
714 g,
715 self,
716 "ModelRuntime",
717 "ai.bytesandbrains.role.model",
718 "ApplyDelta",
719 vec![delta.name],
720 vec![],
721 )
722 }
723
724 pub fn load_parameters(&self, g: &mut Graph, params: Output) -> Output {
726 record_role_op_one(
727 g,
728 self,
729 "ModelRuntime",
730 "ai.bytesandbrains.role.model",
731 "LoadParameters",
732 vec![params.name],
733 vec![],
734 )
735 }
736
737 pub fn params(&self, g: &mut Graph) -> Output {
739 record_role_op_one(
740 g,
741 self,
742 "ModelRuntime",
743 "ai.bytesandbrains.role.model",
744 "Params",
745 vec![],
746 vec![],
747 )
748 }
749}
750
/// Stateless, copyable handle for the index role.
///
/// Its methods record nodes in the `ai.bytesandbrains.role.index` domain that require the
/// `IndexRuntime` trait.
#[derive(Clone, Copy, Debug, Default)]
pub struct IndexSlot;
761
762impl IndexSlot {
763 pub fn add(&self, g: &mut Graph, vec: Output) -> Output {
765 record_role_op_one(
766 g,
767 self,
768 "IndexRuntime",
769 "ai.bytesandbrains.role.index",
770 "Add",
771 vec![vec.name],
772 vec![],
773 )
774 }
775
776 pub fn search(&self, g: &mut Graph, query: Output, k: i64) -> Output {
779 record_role_op_one(
780 g,
781 self,
782 "IndexRuntime",
783 "ai.bytesandbrains.role.index",
784 "Search",
785 vec![query.name],
786 vec![attr_int("k", k)],
787 )
788 }
789
790 pub fn remove(&self, g: &mut Graph, id: Output) -> Output {
792 record_role_op_one(
793 g,
794 self,
795 "IndexRuntime",
796 "ai.bytesandbrains.role.index",
797 "Remove",
798 vec![id.name],
799 vec![],
800 )
801 }
802
803 pub fn train(&self, g: &mut Graph, samples: Output) -> Output {
809 let slot_id = g.register_generic(self, "IndexRuntime");
810 let out_name = g.next_site_name();
811 g.push_node(NodeProto {
812 op_type: "Train".into(),
813 domain: "ai.bytesandbrains.role.index".into(),
814 input: vec![samples.name],
815 output: vec![out_name.clone()],
816 metadata_props: vec![
817 kv("ai.bytesandbrains.required_trait", "IndexRuntime"),
818 kv("ai.bytesandbrains.slot_id", &slot_id.to_string()),
819 kv("ai.bytesandbrains.index.port", "samples"),
820 ],
821 ..Default::default()
822 });
823 g.declare_value_info(&out_name, &TYPE_TRIGGER);
824 Output::new(out_name, &TYPE_TRIGGER)
825 }
826}
827
/// Stateless, copyable handle for the aggregator role.
///
/// Its methods record nodes in the `ai.bytesandbrains.role.aggregator` domain that require
/// the `AggregatorRuntime` trait.
#[derive(Clone, Copy, Debug, Default)]
pub struct AggregatorSlot;
846
847impl AggregatorSlot {
848 pub fn contribute(&self, g: &mut Graph, contribution: Output, metadata: Output) -> Output {
853 record_role_op_one(
854 g,
855 self,
856 "AggregatorRuntime",
857 "ai.bytesandbrains.role.aggregator",
858 "Contribute",
859 vec![contribution.name, metadata.name],
860 vec![],
861 )
862 }
863
864 pub fn aggregate(&self, g: &mut Graph, trigger: Output) -> (Output, Output) {
872 let mut outs = record_role_op(
873 g,
874 self,
875 "AggregatorRuntime",
876 "ai.bytesandbrains.role.aggregator",
877 "Aggregate",
878 vec![trigger.name],
879 2,
880 vec![],
881 );
882 let metadata = outs.pop().expect("two outputs");
883 let params = outs.pop().expect("two outputs");
884 (params, metadata)
885 }
886}
887
/// Stateless, copyable handle for the codec role.
///
/// Its methods record nodes in the `ai.bytesandbrains.role.codec` domain that require the
/// `CodecRuntime` trait.
#[derive(Clone, Copy, Debug, Default)]
pub struct CodecSlot;
897
898impl CodecSlot {
899 pub fn train(&self, g: &mut Graph, samples: Output) -> Output {
907 let slot_id = g.register_generic(self, "CodecRuntime");
908 let out_name = g.next_site_name();
909 g.push_node(NodeProto {
910 op_type: "Train".into(),
911 domain: "ai.bytesandbrains.role.codec".into(),
912 input: vec![samples.name],
913 output: vec![out_name.clone()],
914 metadata_props: vec![
915 kv("ai.bytesandbrains.required_trait", "CodecRuntime"),
916 kv("ai.bytesandbrains.slot_id", &slot_id.to_string()),
917 kv("ai.bytesandbrains.codec.port", "in"),
918 ],
919 ..Default::default()
920 });
921 g.declare_value_info(&out_name, &TYPE_TRIGGER);
922 Output::new(out_name, &TYPE_TRIGGER)
923 }
924
925 pub fn encode(&self, g: &mut Graph, input: Output) -> Output {
930 let slot_id = g.register_generic(self, "CodecRuntime");
931 let out_name = g.next_site_name();
932 g.push_node(NodeProto {
933 op_type: "Encode".into(),
934 domain: "ai.bytesandbrains.role.codec".into(),
935 input: vec![input.name],
936 output: vec![out_name.clone()],
937 metadata_props: vec![
938 kv("ai.bytesandbrains.required_trait", "CodecRuntime"),
939 kv("ai.bytesandbrains.slot_id", &slot_id.to_string()),
940 kv("ai.bytesandbrains.codec.port", "out"),
941 ],
942 ..Default::default()
943 });
944 g.declare_value_info(&out_name, &TYPE_TENSOR);
945 Output::new(out_name, &TYPE_TENSOR)
946 }
947
948 pub fn decode(&self, g: &mut Graph, encoded: Output) -> Output {
951 let slot_id = g.register_generic(self, "CodecRuntime");
952 let out_name = g.next_site_name();
953 g.push_node(NodeProto {
954 op_type: "Decode".into(),
955 domain: "ai.bytesandbrains.role.codec".into(),
956 input: vec![encoded.name],
957 output: vec![out_name.clone()],
958 metadata_props: vec![
959 kv("ai.bytesandbrains.required_trait", "CodecRuntime"),
960 kv("ai.bytesandbrains.slot_id", &slot_id.to_string()),
961 kv("ai.bytesandbrains.codec.port", "in"),
962 ],
963 ..Default::default()
964 });
965 g.declare_value_info(&out_name, &TYPE_TENSOR);
966 Output::new(out_name, &TYPE_TENSOR)
967 }
968}
969
/// Stateless, copyable handle for the data-source role.
///
/// Its methods record nodes in the `ai.bytesandbrains.role.data_source` domain that require
/// the `DataSourceRuntime` trait.
#[derive(Clone, Copy, Debug, Default)]
pub struct DataLoaderSlot;
980
981impl DataLoaderSlot {
982 pub fn next_batch(&self, g: &mut Graph) -> (Output, Output) {
987 let mut outs = record_role_op(
988 g,
989 self,
990 "DataSourceRuntime",
991 "ai.bytesandbrains.role.data_source",
992 "NextBatch",
993 vec![],
994 2,
995 vec![],
996 );
997 let labels = outs.pop().expect("two outputs");
998 let batch = outs.pop().expect("two outputs");
999 (batch, labels)
1000 }
1001
1002 pub fn reset(&self, g: &mut Graph, trigger: Output) -> Output {
1004 record_role_op_one(
1005 g,
1006 self,
1007 "DataSourceRuntime",
1008 "ai.bytesandbrains.role.data_source",
1009 "Reset",
1010 vec![trigger.name],
1011 vec![],
1012 )
1013 }
1014
1015 pub fn on_data_loaded(&self, g: &mut Graph) -> Output {
1018 record_role_op_one(
1019 g,
1020 self,
1021 "DataSourceRuntime",
1022 "ai.bytesandbrains.role.data_source",
1023 "OnDataLoaded",
1024 vec![],
1025 vec![],
1026 )
1027 }
1028}
1029
/// Copyable handle for a peer-selector runtime, scoped to a single peer class.
///
/// `Default` is implemented by hand and binds the slot to `bb_ir::peer_class::SELF_CLASS`.
#[derive(Clone, Copy, Debug)]
pub struct PeerSelectorSlot {
    /// Peer class written onto every node this slot records, under
    /// `bb_ir::peer_class::PEER_CLASS_KEY`.
    pub class: &'static str,
}
1053
1054impl Default for PeerSelectorSlot {
1055 fn default() -> Self {
1056 Self {
1057 class: bb_ir::peer_class::SELF_CLASS,
1058 }
1059 }
1060}
1061
1062impl PeerSelectorSlot {
1063 pub fn of_class(class: &'static str) -> Self {
1068 Self { class }
1069 }
1070
1071 fn record_peer_op(&self, g: &mut Graph, op_type: &str, attrs: Vec<AttributeProto>) -> Output {
1072 let slot_id = g.register_generic(self, "PeerSelectorRuntime");
1073 let out_name = g.next_site_name();
1074 g.push_node(NodeProto {
1075 op_type: op_type.into(),
1076 domain: "ai.bytesandbrains.role.peer_selector".into(),
1077 input: vec![],
1078 output: vec![out_name.clone()],
1079 attribute: attrs,
1080 metadata_props: vec![
1081 kv("ai.bytesandbrains.required_trait", "PeerSelectorRuntime"),
1082 kv("ai.bytesandbrains.slot_id", &slot_id.to_string()),
1083 kv(bb_ir::peer_class::PEER_CLASS_KEY, self.class),
1084 ],
1085 ..Default::default()
1086 });
1087 g.declare_value_info(&out_name, &bb_ir::types::TYPE_PEER_ID);
1088 Output::new(out_name, &bb_ir::types::TYPE_PEER_ID)
1089 }
1090
1091 pub fn sample(&self, g: &mut Graph, n: i64) -> Output {
1093 self.record_peer_op(g, "Sample", vec![attr_int("n", n)])
1094 }
1095
1096 pub fn current_view(&self, g: &mut Graph) -> Output {
1098 self.record_peer_op(g, "CurrentView", vec![])
1099 }
1100}
1101
/// Stateless, copyable handle for the protocol role slot.
///
/// NOTE(review): presumably records protocol-role nodes the way the other slots do — confirm
/// against its impl.
#[derive(Clone, Copy, Debug, Default)]
pub struct ProtocolSlot;
1118