1mod elementwise;
17mod linalg;
18mod shape;
19
20use bb_ir::proto::onnx::{attribute_proto::AttributeType, AttributeProto};
21use bb_runtime::atomic::DispatchResult;
22use bb_runtime::bus::OpError;
23use bb_runtime::slot_value::SlotValue;
24
25use crate::backends::cpu::CpuBackend;
26use crate::backends::cpu::CpuTensor;
27
28fn find_attr<'a>(attrs: &'a [AttributeProto], name: &str) -> Option<&'a AttributeProto> {
30 attrs.iter().find(|a| a.name == name)
31}
32
33fn need_int_attr(op: &str, attrs: &[AttributeProto], name: &str) -> Result<i64, OpError> {
34 let a = find_attr(attrs, name).ok_or_else(|| OpError {
35 detail: format!("{op}: missing `{name}` attribute"),
36 ..Default::default()
37 })?;
38 if a.r#type != AttributeType::Int as i32 {
39 return Err(OpError {
40 detail: format!("{op}: `{name}` attribute must be INT"),
41 ..Default::default()
42 });
43 }
44 Ok(a.i)
45}
46
47fn need_ints_attr(op: &str, attrs: &[AttributeProto], name: &str) -> Result<Vec<i64>, OpError> {
48 let a = find_attr(attrs, name).ok_or_else(|| OpError {
49 detail: format!("{op}: missing `{name}` attribute"),
50 ..Default::default()
51 })?;
52 if a.r#type != AttributeType::Ints as i32 {
53 return Err(OpError {
54 detail: format!("{op}: `{name}` attribute must be INTS"),
55 ..Default::default()
56 });
57 }
58 Ok(a.ints.clone())
59}
60
61fn opt_float_attr(attrs: &[AttributeProto], name: &str, default: f32) -> f32 {
62 find_attr(attrs, name)
63 .filter(|a| a.r#type == AttributeType::Float as i32)
64 .map(|a| a.f)
65 .unwrap_or(default)
66}
67
68fn opt_int_attr(attrs: &[AttributeProto], name: &str, default: i64) -> i64 {
69 find_attr(attrs, name)
70 .filter(|a| a.r#type == AttributeType::Int as i32)
71 .map(|a| a.i)
72 .unwrap_or(default)
73}
74
75fn as_cpu_tensor<'a>(op: &str, role: &str, h: &'a dyn SlotValue) -> Result<&'a CpuTensor, OpError> {
78 h.as_any()
79 .downcast_ref::<CpuTensor>()
80 .ok_or_else(|| OpError {
81 detail: format!("{op}: {role} is not a CpuTensor"),
82 ..Default::default()
83 })
84}
85
86fn need_two_inputs<'a>(
87 op: &str,
88 inputs: &'a [(&str, &dyn SlotValue)],
89) -> Result<(&'a CpuTensor, &'a CpuTensor), OpError> {
90 if inputs.len() < 2 {
91 return Err(OpError {
92 detail: format!("{op}: requires two inputs, got {}", inputs.len()),
93 ..Default::default()
94 });
95 }
96 let a = as_cpu_tensor(op, "input 0", inputs[0].1)?;
97 let b = as_cpu_tensor(op, "input 1", inputs[1].1)?;
98 Ok((a, b))
99}
100
101fn need_one_input<'a>(
102 op: &str,
103 inputs: &'a [(&str, &dyn SlotValue)],
104) -> Result<&'a CpuTensor, OpError> {
105 if inputs.is_empty() {
106 return Err(OpError {
107 detail: format!("{op}: requires one input, got 0"),
108 ..Default::default()
109 });
110 }
111 as_cpu_tensor(op, "input 0", inputs[0].1)
112}
113
114fn out(name: &str, tensor: CpuTensor) -> DispatchResult {
115 DispatchResult::Immediate(vec![(name.to_string(), Box::new(tensor))])
116}
117
118pub fn dispatch(
125 backend: &CpuBackend,
126 op_type: &str,
127 inputs: &[(&str, &dyn SlotValue)],
128 attrs: &[AttributeProto],
129) -> Result<DispatchResult, OpError> {
130 match op_type {
131 "Add" => Ok(out("C", elementwise::add(backend, op_type, inputs)?)),
133 "Sub" => Ok(out("C", elementwise::sub(backend, op_type, inputs)?)),
134 "Mul" => Ok(out("C", elementwise::mul(backend, op_type, inputs)?)),
135 "Div" => Ok(out("C", elementwise::div(backend, op_type, inputs)?)),
136 "Pow" => Ok(out("C", elementwise::pow(backend, op_type, inputs)?)),
137
138 "Neg" => Ok(out("Y", elementwise::neg(backend, op_type, inputs)?)),
140 "Abs" => Ok(out("Y", elementwise::abs(backend, op_type, inputs)?)),
141 "Sqrt" => Ok(out("Y", elementwise::sqrt(backend, op_type, inputs)?)),
142 "Exp" => Ok(out("Y", elementwise::exp(backend, op_type, inputs)?)),
143 "Log" => Ok(out("Y", elementwise::log(backend, op_type, inputs)?)),
144
145 "Relu" => Ok(out("Y", elementwise::relu(backend, op_type, inputs)?)),
147 "Sigmoid" => Ok(out("Y", elementwise::sigmoid(backend, op_type, inputs)?)),
148 "Tanh" => Ok(out("Y", elementwise::tanh(backend, op_type, inputs)?)),
149 "Gelu" => Ok(out("Y", elementwise::gelu(backend, op_type, inputs)?)),
150 "Identity" => Ok(out("Y", elementwise::identity(backend, op_type, inputs)?)),
151 "Softmax" => Ok(out("Y", shape::softmax(backend, op_type, inputs, attrs)?)),
152 "LeakyRelu" => Ok(out(
153 "Y",
154 shape::leaky_relu(backend, op_type, inputs, attrs)?,
155 )),
156
157 "Equal" => Ok(out("C", elementwise::equal(backend, op_type, inputs)?)),
159 "Greater" => Ok(out("C", elementwise::greater(backend, op_type, inputs)?)),
160 "Less" => Ok(out("C", elementwise::less(backend, op_type, inputs)?)),
161
162 "MatMul" => Ok(out("Y", linalg::matmul(backend, op_type, inputs)?)),
164 "Dot" => Ok(out("Y", linalg::dot(backend, op_type, inputs)?)),
165 "Gemm" => Ok(out("Y", shape::gemm(backend, op_type, inputs, attrs)?)),
166
167 "ReduceSum" => Ok(out("Y", linalg::reduce_sum(backend, op_type, inputs)?)),
169 "ReduceMean" => Ok(out("Y", linalg::reduce_mean(backend, op_type, inputs)?)),
170 "ReduceMax" => Ok(out("Y", linalg::reduce_max(backend, op_type, inputs)?)),
171 "ReduceMin" => Ok(out("Y", linalg::reduce_min(backend, op_type, inputs)?)),
172
173 "Reshape" => Ok(out("Y", shape::reshape(backend, op_type, inputs, attrs)?)),
175 "Transpose" => Ok(out("Y", shape::transpose(backend, op_type, inputs, attrs)?)),
176 "Concat" => Ok(out("Y", shape::concat(backend, op_type, inputs, attrs)?)),
177 "Squeeze" => Ok(out("Y", shape::squeeze(backend, op_type, inputs, attrs)?)),
178 "Unsqueeze" => Ok(out("Y", shape::unsqueeze(backend, op_type, inputs, attrs)?)),
179 "Cast" => Ok(out("Y", shape::cast(backend, op_type, inputs, attrs)?)),
180 "Slice" => Ok(out("Y", shape::slice(backend, op_type, inputs, attrs)?)),
181 "Split" => Ok(shape::split(backend, op_type, inputs, attrs)?),
182
183 "Gather" => Ok(out("Y", shape::gather(backend, op_type, inputs, attrs)?)),
185
186 "GlobalAveragePool" => Ok(out(
188 "Y",
189 linalg::global_average_pool(backend, op_type, inputs)?,
190 )),
191
192 "Zeros" => Ok(out("Y", shape::zeros(backend, op_type, attrs)?)),
194 "Ones" => Ok(out("Y", shape::ones(backend, op_type, attrs)?)),
195 "Constant" => Ok(out("Y", shape::constant(backend, op_type, attrs)?)),
196
197 other => Err(OpError {
205 detail: format!("CpuBackend: unsupported op_type '{other}'"),
206 ..Default::default()
207 }),
208 }
209}
210