use bb_dsl::graph::{attr_float, attr_graph, attr_int, attr_ints, attr_tensor, kv, Graph};
use bb_dsl::output::Output;
use bb_ir::proto::onnx::{AttributeProto, GraphProto, NodeProto, TensorProto};
use bb_ir::types::{TYPE_TENSOR, TYPE_TENSOR_F32, TYPE_TRIGGER};
/// Zero-sized handle used to record tensor ops that are dispatched to the
/// `BackendRuntime` slot of a graph.
#[derive(Clone, Copy, Debug, Default)]
pub struct BackendSlot;
impl BackendSlot {
fn record_op(
&self,
g: &mut Graph,
op_type: &str,
input_names: Vec<String>,
n_outputs: usize,
attribute: Vec<AttributeProto>,
) -> Vec<Output> {
let slot_id = g.register_generic(self, "BackendRuntime");
let output_names: Vec<String> = (0..n_outputs).map(|_| g.next_site_name()).collect();
g.push_node(NodeProto {
op_type: op_type.into(),
domain: "ai.onnx".into(),
input: input_names,
output: output_names.clone(),
attribute,
metadata_props: vec![
kv("ai.bytesandbrains.required_trait", "BackendRuntime"),
kv("ai.bytesandbrains.slot_id", &slot_id.to_string()),
],
..Default::default()
});
for name in &output_names {
g.declare_value_info(name, &TYPE_TENSOR_F32);
}
output_names
.into_iter()
.map(|n| Output::new(n, &TYPE_TENSOR_F32))
.collect()
}
fn record_one(
&self,
g: &mut Graph,
op_type: &str,
input_names: Vec<String>,
attribute: Vec<AttributeProto>,
) -> Output {
self.record_op(g, op_type, input_names, 1, attribute)
.into_iter()
.next()
.expect("record_op with n_outputs=1")
}
pub fn zeros(&self, g: &mut Graph, dims: Vec<i64>) -> Output {
self.record_one(g, "Zeros", vec![], vec![attr_ints("dims", dims)])
}
pub fn ones(&self, g: &mut Graph, dims: Vec<i64>) -> Output {
self.record_one(g, "Ones", vec![], vec![attr_ints("dims", dims)])
}
pub fn constant(&self, g: &mut Graph, value: TensorProto) -> Output {
self.record_one(g, "Constant", vec![], vec![attr_tensor("value", value)])
}
pub fn add(&self, g: &mut Graph, a: Output, b: Output) -> Output {
self.record_one(g, "Add", vec![a.name, b.name], vec![])
}
pub fn sub(&self, g: &mut Graph, a: Output, b: Output) -> Output {
self.record_one(g, "Sub", vec![a.name, b.name], vec![])
}
pub fn mul(&self, g: &mut Graph, a: Output, b: Output) -> Output {
self.record_one(g, "Mul", vec![a.name, b.name], vec![])
}
pub fn div(&self, g: &mut Graph, a: Output, b: Output) -> Output {
self.record_one(g, "Div", vec![a.name, b.name], vec![])
}
pub fn neg(&self, g: &mut Graph, t: Output) -> Output {
self.record_one(g, "Neg", vec![t.name], vec![])
}
pub fn abs(&self, g: &mut Graph, t: Output) -> Output {
self.record_one(g, "Abs", vec![t.name], vec![])
}
pub fn sqrt(&self, g: &mut Graph, t: Output) -> Output {
self.record_one(g, "Sqrt", vec![t.name], vec![])
}
pub fn exp(&self, g: &mut Graph, t: Output) -> Output {
self.record_one(g, "Exp", vec![t.name], vec![])
}
pub fn log(&self, g: &mut Graph, t: Output) -> Output {
self.record_one(g, "Log", vec![t.name], vec![])
}
pub fn pow(&self, g: &mut Graph, a: Output, b: Output) -> Output {
self.record_one(g, "Pow", vec![a.name, b.name], vec![])
}
pub fn matmul(&self, g: &mut Graph, a: Output, b: Output) -> Output {
self.record_one(g, "MatMul", vec![a.name, b.name], vec![])
}
#[allow(clippy::too_many_arguments)]
pub fn gemm(
&self,
g: &mut Graph,
a: Output,
b: Output,
c: Option<Output>,
alpha: f32,
beta: f32,
trans_a: bool,
trans_b: bool,
) -> Output {
let mut inputs = vec![a.name, b.name];
if let Some(c) = c {
inputs.push(c.name);
}
self.record_one(
g,
"Gemm",
inputs,
vec![
attr_float("alpha", alpha),
attr_float("beta", beta),
attr_int("transA", trans_a as i64),
attr_int("transB", trans_b as i64),
],
)
}
pub fn dot(&self, g: &mut Graph, a: Output, b: Output) -> Output {
self.record_one(g, "Dot", vec![a.name, b.name], vec![])
}
pub fn relu(&self, g: &mut Graph, t: Output) -> Output {
self.record_one(g, "Relu", vec![t.name], vec![])
}
pub fn sigmoid(&self, g: &mut Graph, t: Output) -> Output {
self.record_one(g, "Sigmoid", vec![t.name], vec![])
}
pub fn tanh(&self, g: &mut Graph, t: Output) -> Output {
self.record_one(g, "Tanh", vec![t.name], vec![])
}
pub fn softmax(&self, g: &mut Graph, t: Output, axis: i64) -> Output {
self.record_one(g, "Softmax", vec![t.name], vec![attr_int("axis", axis)])
}
pub fn leaky_relu(&self, g: &mut Graph, t: Output, alpha: f32) -> Output {
self.record_one(
g,
"LeakyRelu",
vec![t.name],
vec![attr_float("alpha", alpha)],
)
}
pub fn gelu(&self, g: &mut Graph, t: Output) -> Output {
self.record_one(g, "Gelu", vec![t.name], vec![])
}
pub fn reshape(&self, g: &mut Graph, t: Output, dims: Vec<i64>) -> Output {
self.record_one(g, "Reshape", vec![t.name], vec![attr_ints("dims", dims)])
}
pub fn transpose(&self, g: &mut Graph, t: Output, perm: Option<Vec<i64>>) -> Output {
let attrs = match perm {
Some(p) => vec![attr_ints("perm", p)],
None => vec![],
};
self.record_one(g, "Transpose", vec![t.name], attrs)
}
pub fn concat(&self, g: &mut Graph, tensors: Vec<Output>, axis: i64) -> Output {
let inputs = tensors.into_iter().map(|t| t.name).collect();
self.record_one(g, "Concat", inputs, vec![attr_int("axis", axis)])
}
pub fn split(&self, g: &mut Graph, t: Output, axis: i64, sizes: Vec<i64>) -> Vec<Output> {
let n = sizes.len();
self.record_op(
g,
"Split",
vec![t.name],
n,
vec![attr_int("axis", axis), attr_ints("split", sizes)],
)
}
pub fn slice(
&self,
g: &mut Graph,
t: Output,
starts: Vec<i64>,
ends: Vec<i64>,
axes: Option<Vec<i64>>,
steps: Option<Vec<i64>>,
) -> Output {
let mut attrs = vec![attr_ints("starts", starts), attr_ints("ends", ends)];
if let Some(a) = axes {
attrs.push(attr_ints("axes", a));
}
if let Some(s) = steps {
attrs.push(attr_ints("steps", s));
}
self.record_one(g, "Slice", vec![t.name], attrs)
}
pub fn squeeze(&self, g: &mut Graph, t: Output, axes: Option<Vec<i64>>) -> Output {
let attrs = match axes {
Some(a) => vec![attr_ints("axes", a)],
None => vec![],
};
self.record_one(g, "Squeeze", vec![t.name], attrs)
}
pub fn unsqueeze(&self, g: &mut Graph, t: Output, axes: Vec<i64>) -> Output {
self.record_one(g, "Unsqueeze", vec![t.name], vec![attr_ints("axes", axes)])
}
pub fn identity(&self, g: &mut Graph, t: Output) -> Output {
self.record_one(g, "Identity", vec![t.name], vec![])
}
pub fn cast(&self, g: &mut Graph, t: Output, to_elem_type: i32) -> Output {
self.record_one(
g,
"Cast",
vec![t.name],
vec![attr_int("to", to_elem_type as i64)],
)
}
fn reduce(
&self,
g: &mut Graph,
op_type: &str,
t: Output,
axes: Option<Vec<i64>>,
keepdims: bool,
) -> Output {
let mut attrs = vec![attr_int("keepdims", keepdims as i64)];
if let Some(a) = axes {
attrs.push(attr_ints("axes", a));
}
self.record_one(g, op_type, vec![t.name], attrs)
}
pub fn reduce_sum(
&self,
g: &mut Graph,
t: Output,
axes: Option<Vec<i64>>,
keepdims: bool,
) -> Output {
self.reduce(g, "ReduceSum", t, axes, keepdims)
}
pub fn reduce_mean(
&self,
g: &mut Graph,
t: Output,
axes: Option<Vec<i64>>,
keepdims: bool,
) -> Output {
self.reduce(g, "ReduceMean", t, axes, keepdims)
}
pub fn reduce_max(
&self,
g: &mut Graph,
t: Output,
axes: Option<Vec<i64>>,
keepdims: bool,
) -> Output {
self.reduce(g, "ReduceMax", t, axes, keepdims)
}
pub fn reduce_min(
&self,
g: &mut Graph,
t: Output,
axes: Option<Vec<i64>>,
keepdims: bool,
) -> Output {
self.reduce(g, "ReduceMin", t, axes, keepdims)
}
pub fn equal(&self, g: &mut Graph, a: Output, b: Output) -> Output {
self.record_one(g, "Equal", vec![a.name, b.name], vec![])
}
pub fn greater(&self, g: &mut Graph, a: Output, b: Output) -> Output {
self.record_one(g, "Greater", vec![a.name, b.name], vec![])
}
pub fn less(&self, g: &mut Graph, a: Output, b: Output) -> Output {
self.record_one(g, "Less", vec![a.name, b.name], vec![])
}
#[allow(clippy::too_many_arguments)]
pub fn batch_normalization(
&self,
g: &mut Graph,
input: Output,
scale: Output,
bias: Output,
mean: Output,
variance: Output,
epsilon: f32,
momentum: f32,
) -> Output {
self.record_one(
g,
"BatchNormalization",
vec![input.name, scale.name, bias.name, mean.name, variance.name],
vec![
attr_float("epsilon", epsilon),
attr_float("momentum", momentum),
],
)
}
pub fn layer_normalization(
&self,
g: &mut Graph,
input: Output,
scale: Output,
bias: Option<Output>,
axis: i64,
epsilon: f32,
) -> Output {
let mut inputs = vec![input.name, scale.name];
if let Some(b) = bias {
inputs.push(b.name);
}
self.record_one(
g,
"LayerNormalization",
inputs,
vec![attr_int("axis", axis), attr_float("epsilon", epsilon)],
)
}
#[allow(clippy::too_many_arguments)]
pub fn conv(
&self,
g: &mut Graph,
input: Output,
weight: Output,
bias: Option<Output>,
kernel_shape: Vec<i64>,
strides: Vec<i64>,
pads: Vec<i64>,
dilations: Vec<i64>,
group: i64,
) -> Output {
let mut inputs = vec![input.name, weight.name];
if let Some(b) = bias {
inputs.push(b.name);
}
self.record_one(
g,
"Conv",
inputs,
vec![
attr_ints("kernel_shape", kernel_shape),
attr_ints("strides", strides),
attr_ints("pads", pads),
attr_ints("dilations", dilations),
attr_int("group", group),
],
)
}
pub fn max_pool(
&self,
g: &mut Graph,
input: Output,
kernel_shape: Vec<i64>,
strides: Vec<i64>,
pads: Vec<i64>,
) -> Output {
self.record_one(
g,
"MaxPool",
vec![input.name],
vec![
attr_ints("kernel_shape", kernel_shape),
attr_ints("strides", strides),
attr_ints("pads", pads),
],
)
}
pub fn average_pool(
&self,
g: &mut Graph,
input: Output,
kernel_shape: Vec<i64>,
strides: Vec<i64>,
pads: Vec<i64>,
count_include_pad: bool,
) -> Output {
self.record_one(
g,
"AveragePool",
vec![input.name],
vec![
attr_ints("kernel_shape", kernel_shape),
attr_ints("strides", strides),
attr_ints("pads", pads),
attr_int("count_include_pad", count_include_pad as i64),
],
)
}
pub fn global_average_pool(&self, g: &mut Graph, input: Output) -> Output {
self.record_one(g, "GlobalAveragePool", vec![input.name], vec![])
}
pub fn gather(&self, g: &mut Graph, data: Output, indices: Output, axis: i64) -> Output {
self.record_one(
g,
"Gather",
vec![data.name, indices.name],
vec![attr_int("axis", axis)],
)
}
pub fn scatter(
&self,
g: &mut Graph,
data: Output,
indices: Output,
updates: Output,
axis: i64,
) -> Output {
self.record_one(
g,
"Scatter",
vec![data.name, indices.name, updates.name],
vec![attr_int("axis", axis)],
)
}
pub fn if_op(
&self,
g: &mut Graph,
cond: Output,
then_branch: GraphProto,
else_branch: GraphProto,
n_outputs: usize,
) -> Vec<Output> {
self.record_op(
g,
"If",
vec![cond.name],
n_outputs,
vec![
attr_graph("then_branch", then_branch),
attr_graph("else_branch", else_branch),
],
)
}
pub fn loop_op(
&self,
g: &mut Graph,
max_trip_count: Option<Output>,
cond: Option<Output>,
body: GraphProto,
initial: Vec<Output>,
n_outputs: usize,
) -> Vec<Output> {
let mut inputs = vec![
max_trip_count.map(|o| o.name).unwrap_or_default(),
cond.map(|o| o.name).unwrap_or_default(),
];
inputs.extend(initial.into_iter().map(|o| o.name));
self.record_op(g, "Loop", inputs, n_outputs, vec![attr_graph("body", body)])
}
}
/// Appends one role-op node to `g`, bound to the slot that `placeholder`
/// registers under `required_trait`.
///
/// The node lives in `role_domain`. It is tagged with the
/// `ai.bytesandbrains.required_trait` and `ai.bytesandbrains.slot_id` metadata
/// entries. `n_outputs` fresh site names are allocated, and each one is declared
/// with the generic `TYPE_TENSOR`. The matching `Output` handles are returned
/// in allocation order.
// Eight parameters, each a distinct piece of node metadata; bundling them into
// a struct would only move the verbosity to every call site.
#[allow(clippy::too_many_arguments)]
fn record_role_op<P: 'static>(
    g: &mut Graph,
    placeholder: &P,
    required_trait: &'static str,
    role_domain: &'static str,
    op_type: &str,
    input_names: Vec<String>,
    n_outputs: usize,
    attribute: Vec<AttributeProto>,
) -> Vec<Output> {
    let slot_id = g.register_generic(placeholder, required_trait);
    let site_names: Vec<String> = std::iter::repeat_with(|| g.next_site_name())
        .take(n_outputs)
        .collect();

    let slot_label = slot_id.to_string();
    let metadata_props = vec![
        kv("ai.bytesandbrains.required_trait", required_trait),
        kv("ai.bytesandbrains.slot_id", &slot_label),
    ];
    let node = NodeProto {
        op_type: op_type.into(),
        domain: role_domain.into(),
        input: input_names,
        output: site_names.clone(),
        attribute,
        metadata_props,
        ..Default::default()
    };
    g.push_node(node);

    for site in site_names.iter() {
        g.declare_value_info(site, &TYPE_TENSOR);
    }
    site_names
        .into_iter()
        .map(|site| Output::new(site, &TYPE_TENSOR))
        .collect()
}
/// Single-output convenience wrapper around `record_role_op`.
///
/// It forwards every argument with `n_outputs = 1` and returns the sole
/// handle. The `expect` can only fire if `record_role_op` produced no outputs,
/// which it does not when one output is requested.
// Seven parameters is within clippy's default `too_many_arguments` threshold,
// so no `allow` is needed here (compare `slice`/`average_pool`).
fn record_role_op_one<P: 'static>(
    g: &mut Graph,
    placeholder: &P,
    required_trait: &'static str,
    role_domain: &'static str,
    op_type: &str,
    input_names: Vec<String>,
    attribute: Vec<AttributeProto>,
) -> Output {
    let mut outputs = record_role_op(
        g,
        placeholder,
        required_trait,
        role_domain,
        op_type,
        input_names,
        1,
        attribute,
    )
    .into_iter();
    outputs.next().expect("record_role_op with n_outputs=1")
}
/// Zero-sized handle for recording ops against the `ModelRuntime` role slot.
#[derive(Clone, Copy, Debug, Default)]
pub struct ModelSlot;
impl ModelSlot {
    /// Records an attribute-free `ModelRuntime` op in the
    /// `ai.bytesandbrains.role.model` domain and returns its single output.
    fn record_model_op(&self, g: &mut Graph, op_type: &str, inputs: Vec<String>) -> Output {
        record_role_op_one(
            g,
            self,
            "ModelRuntime",
            "ai.bytesandbrains.role.model",
            op_type,
            inputs,
            Vec::new(),
        )
    }

    /// Records a `Forward` node consuming `input`.
    pub fn forward(&self, g: &mut Graph, input: Output) -> Output {
        self.record_model_op(g, "Forward", vec![input.name])
    }

    /// Records a `Backward` node consuming `grad`.
    pub fn backward(&self, g: &mut Graph, grad: Output) -> Output {
        self.record_model_op(g, "Backward", vec![grad.name])
    }

    /// Records a `ComputeLoss` node with inputs `(input, target)`.
    pub fn compute_loss(&self, g: &mut Graph, input: Output, target: Output) -> Output {
        self.record_model_op(g, "ComputeLoss", vec![input.name, target.name])
    }

    /// Records an `ApplyDelta` node consuming `delta`.
    pub fn apply_delta(&self, g: &mut Graph, delta: Output) -> Output {
        self.record_model_op(g, "ApplyDelta", vec![delta.name])
    }

    /// Records a `LoadParameters` node consuming `params`.
    pub fn load_parameters(&self, g: &mut Graph, params: Output) -> Output {
        self.record_model_op(g, "LoadParameters", vec![params.name])
    }

    /// Records an input-less `Params` node.
    pub fn params(&self, g: &mut Graph) -> Output {
        self.record_model_op(g, "Params", Vec::new())
    }
}
/// Zero-sized handle for recording ops against the `IndexRuntime` role slot.
#[derive(Clone, Copy, Debug, Default)]
pub struct IndexSlot;
impl IndexSlot {
pub fn add(&self, g: &mut Graph, vec: Output) -> Output {
record_role_op_one(
g,
self,
"IndexRuntime",
"ai.bytesandbrains.role.index",
"Add",
vec![vec.name],
vec![],
)
}
pub fn search(&self, g: &mut Graph, query: Output, k: i64) -> Output {
record_role_op_one(
g,
self,
"IndexRuntime",
"ai.bytesandbrains.role.index",
"Search",
vec![query.name],
vec![attr_int("k", k)],
)
}
pub fn remove(&self, g: &mut Graph, id: Output) -> Output {
record_role_op_one(
g,
self,
"IndexRuntime",
"ai.bytesandbrains.role.index",
"Remove",
vec![id.name],
vec![],
)
}
pub fn train(&self, g: &mut Graph, samples: Output) -> Output {
let slot_id = g.register_generic(self, "IndexRuntime");
let out_name = g.next_site_name();
g.push_node(NodeProto {
op_type: "Train".into(),
domain: "ai.bytesandbrains.role.index".into(),
input: vec![samples.name],
output: vec![out_name.clone()],
metadata_props: vec![
kv("ai.bytesandbrains.required_trait", "IndexRuntime"),
kv("ai.bytesandbrains.slot_id", &slot_id.to_string()),
kv("ai.bytesandbrains.index.port", "samples"),
],
..Default::default()
});
g.declare_value_info(&out_name, &TYPE_TRIGGER);
Output::new(out_name, &TYPE_TRIGGER)
}
}
/// Zero-sized handle for recording ops against the `AggregatorRuntime` role slot.
#[derive(Clone, Copy, Debug, Default)]
pub struct AggregatorSlot;
impl AggregatorSlot {
    /// Records a `Contribute` node with inputs `(contribution, metadata)`.
    pub fn contribute(&self, g: &mut Graph, contribution: Output, metadata: Output) -> Output {
        let inputs = vec![contribution.name, metadata.name];
        record_role_op_one(
            g,
            self,
            "AggregatorRuntime",
            "ai.bytesandbrains.role.aggregator",
            "Contribute",
            inputs,
            Vec::new(),
        )
    }

    /// Records an `Aggregate` node triggered by `trigger`.
    ///
    /// Returns `(params, metadata)`: the node's first and second outputs.
    pub fn aggregate(&self, g: &mut Graph, trigger: Output) -> (Output, Output) {
        let mut produced = record_role_op(
            g,
            self,
            "AggregatorRuntime",
            "ai.bytesandbrains.role.aggregator",
            "Aggregate",
            vec![trigger.name],
            2,
            Vec::new(),
        )
        .into_iter();
        let params = produced.next().expect("two outputs");
        let metadata = produced.next().expect("two outputs");
        (params, metadata)
    }
}
/// Zero-sized handle for recording ops against the `CodecRuntime` role slot.
#[derive(Clone, Copy, Debug, Default)]
pub struct CodecSlot;
impl CodecSlot {
    /// Pushes one single-input, single-output codec node and returns its
    /// output name.
    ///
    /// The node is tagged with `required_trait`, `slot_id` and the given
    /// `ai.bytesandbrains.codec.port` value. Value-info is left to the caller.
    fn record_codec_op(
        &self,
        g: &mut Graph,
        op_type: &str,
        input: Output,
        port: &'static str,
    ) -> String {
        let slot_id = g.register_generic(self, "CodecRuntime");
        let produced = g.next_site_name();
        let metadata_props = vec![
            kv("ai.bytesandbrains.required_trait", "CodecRuntime"),
            kv("ai.bytesandbrains.slot_id", &slot_id.to_string()),
            kv("ai.bytesandbrains.codec.port", port),
        ];
        g.push_node(NodeProto {
            op_type: op_type.into(),
            domain: "ai.bytesandbrains.role.codec".into(),
            input: vec![input.name],
            output: vec![produced.clone()],
            metadata_props,
            ..Default::default()
        });
        produced
    }

    /// Records a `Train` node on `samples` (port `"in"`). The output is
    /// declared as a trigger.
    pub fn train(&self, g: &mut Graph, samples: Output) -> Output {
        let produced = self.record_codec_op(g, "Train", samples, "in");
        g.declare_value_info(&produced, &TYPE_TRIGGER);
        Output::new(produced, &TYPE_TRIGGER)
    }

    /// Records an `Encode` node on `input` (port `"out"`). The output is
    /// declared as a generic tensor.
    pub fn encode(&self, g: &mut Graph, input: Output) -> Output {
        let produced = self.record_codec_op(g, "Encode", input, "out");
        g.declare_value_info(&produced, &TYPE_TENSOR);
        Output::new(produced, &TYPE_TENSOR)
    }

    /// Records a `Decode` node on `encoded` (port `"in"`). The output is
    /// declared as a generic tensor.
    pub fn decode(&self, g: &mut Graph, encoded: Output) -> Output {
        let produced = self.record_codec_op(g, "Decode", encoded, "in");
        g.declare_value_info(&produced, &TYPE_TENSOR);
        Output::new(produced, &TYPE_TENSOR)
    }
}
/// Zero-sized handle for recording ops against the `DataSourceRuntime` role slot.
#[derive(Clone, Copy, Debug, Default)]
pub struct DataLoaderSlot;
impl DataLoaderSlot {
    const REQUIRED_TRAIT: &'static str = "DataSourceRuntime";
    const DOMAIN: &'static str = "ai.bytesandbrains.role.data_source";

    /// Records an input-less `NextBatch` node.
    ///
    /// Returns `(batch, labels)`: the node's first and second outputs.
    pub fn next_batch(&self, g: &mut Graph) -> (Output, Output) {
        let mut produced = record_role_op(
            g,
            self,
            Self::REQUIRED_TRAIT,
            Self::DOMAIN,
            "NextBatch",
            Vec::new(),
            2,
            Vec::new(),
        )
        .into_iter();
        let batch = produced.next().expect("two outputs");
        let labels = produced.next().expect("two outputs");
        (batch, labels)
    }

    /// Records a `Reset` node triggered by `trigger`.
    pub fn reset(&self, g: &mut Graph, trigger: Output) -> Output {
        let inputs = vec![trigger.name];
        record_role_op_one(g, self, Self::REQUIRED_TRAIT, Self::DOMAIN, "Reset", inputs, Vec::new())
    }

    /// Records an input-less `OnDataLoaded` node.
    pub fn on_data_loaded(&self, g: &mut Graph) -> Output {
        record_role_op_one(
            g,
            self,
            Self::REQUIRED_TRAIT,
            Self::DOMAIN,
            "OnDataLoaded",
            Vec::new(),
            Vec::new(),
        )
    }
}
/// Handle for recording ops against the `PeerSelectorRuntime` role slot,
/// scoped to a named peer class.
#[derive(Clone, Copy, Debug)]
pub struct PeerSelectorSlot {
    /// Peer class recorded on every node this selector emits.
    pub class: &'static str,
}
impl Default for PeerSelectorSlot {
fn default() -> Self {
Self {
class: bb_ir::peer_class::SELF_CLASS,
}
}
}
impl PeerSelectorSlot {
pub fn of_class(class: &'static str) -> Self {
Self { class }
}
fn record_peer_op(&self, g: &mut Graph, op_type: &str, attrs: Vec<AttributeProto>) -> Output {
let slot_id = g.register_generic(self, "PeerSelectorRuntime");
let out_name = g.next_site_name();
g.push_node(NodeProto {
op_type: op_type.into(),
domain: "ai.bytesandbrains.role.peer_selector".into(),
input: vec![],
output: vec![out_name.clone()],
attribute: attrs,
metadata_props: vec![
kv("ai.bytesandbrains.required_trait", "PeerSelectorRuntime"),
kv("ai.bytesandbrains.slot_id", &slot_id.to_string()),
kv(bb_ir::peer_class::PEER_CLASS_KEY, self.class),
],
..Default::default()
});
g.declare_value_info(&out_name, &bb_ir::types::TYPE_PEER_ID);
Output::new(out_name, &bb_ir::types::TYPE_PEER_ID)
}
pub fn sample(&self, g: &mut Graph, n: i64) -> Output {
self.record_peer_op(g, "Sample", vec![attr_int("n", n)])
}
pub fn current_view(&self, g: &mut Graph) -> Output {
self.record_peer_op(g, "CurrentView", vec![])
}
}
/// Zero-sized handle representing a protocol role slot.
#[derive(Clone, Copy, Debug, Default)]
pub struct ProtocolSlot;