bb_ops/aggregators/fedavg/
mod.rs1use std::any::Any;
12use std::collections::BTreeMap;
13use std::marker::PhantomData;
14
15use serde::{Deserialize, Serialize};
16
17#[cfg(feature = "cpu-backend")]
18use bb_ir::component::ErasedComponent;
19use bb_ir::component::{AnyComponent, DependencyDecl, RestoreError};
20use bb_ir::ids::PeerId;
21use bb_ir::proto::onnx::TensorProto;
22use bb_ir::tensor::Tensor;
23use bb_ir::types::common_relations::NO_RELATIONS;
24use bb_runtime::atomic::{AtomicOpDecl, AtomicOpKind, AtomicOpsetDecl, DispatchResult};
25use bb_runtime::bus::{OpError, OpErrorKind};
26use bb_runtime::completion::{CompletionHandle, ContractResponse};
27use bb_runtime::concrete::{ComponentPackage, ConcreteComponent};
28use bb_runtime::contracts::{Aggregator as AggregatorContract, Backend};
29use bb_runtime::roles::AggregatorRuntime;
30use bb_runtime::runtime::RuntimeResourceRef;
31use bb_runtime::slot_value::SlotValue;
32
33#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
35pub struct FedAvgMeta {
36 pub num_samples: u64,
39}
40
/// ONNX `TensorProto.DataType` code for 32-bit floats (`FLOAT`); used as the
/// `data_type` of the weight tensors built in `aggregate`.
const ONNX_FLOAT: i32 = 1;
43
/// Returns the static type node describing backend `B`'s tensor storage,
/// i.e. the `Storage::TYPE` constant of `B::Tensor`.
#[cfg(feature = "cpu-backend")]
fn fedavg_element_type<B>() -> &'static bb_ir::types::TypeNode
where
    B: Backend,
{
    use bb_ir::types::Storage;

    <B::Tensor as Storage>::TYPE
}
50
51#[derive(Debug, Serialize, Deserialize)]
64pub struct FedAvg<B: Backend> {
65 #[serde(skip)]
73 buffer: BTreeMap<PeerId, (B::Tensor, u64)>,
74 #[serde(skip)]
75 _backend: PhantomData<B>,
76}
77
78impl<B: Backend> Default for FedAvg<B> {
79 fn default() -> Self {
80 Self {
81 buffer: BTreeMap::new(),
82 _backend: PhantomData,
83 }
84 }
85}
86
87impl<B: Backend> Clone for FedAvg<B> {
88 fn clone(&self) -> Self {
89 Self::default()
94 }
95}
96
97impl<B: Backend> AggregatorContract for FedAvg<B>
98where
99 B: 'static,
100 B::Tensor: Tensor,
101{
102 type Element = B::Tensor;
103 type Error = OpError;
104 type Metadata = FedAvgMeta;
105
106 fn contribute(
107 &mut self,
108 _ctx: &mut RuntimeResourceRef<'_>,
109 src: PeerId,
110 tensor: &Self::Element,
111 metadata: FedAvgMeta,
112 _completion: CompletionHandle<(), Self::Error>,
113 ) -> ContractResponse<(), Self::Error> {
114 if metadata.num_samples == 0 {
118 return ContractResponse::Now(Err(OpError {
119 detail: "FedAvg::contribute: num_samples = 0 — degenerate weight".into(),
120 ..Default::default()
121 }));
122 }
123 self.buffer
127 .insert(src, (tensor.clone(), metadata.num_samples));
128 ContractResponse::Now(Ok(()))
129 }
130
131 fn aggregate(
132 &mut self,
133 ctx: &mut RuntimeResourceRef<'_>,
134 _completion: CompletionHandle<(Box<Self::Element>, FedAvgMeta), Self::Error>,
135 ) -> ContractResponse<(Box<Self::Element>, FedAvgMeta), Self::Error> {
136 let backend = match ctx.dependency::<B>("backend") {
137 Ok(b) => b,
138 Err(e) => {
139 return ContractResponse::Now(Err(OpError {
140 detail: format!("FedAvg::aggregate: backend lookup failed: {e}"),
141 ..Default::default()
142 }));
143 }
144 };
145
146 let entries: Vec<(B::Tensor, u64)> =
147 std::mem::take(&mut self.buffer).into_values().collect();
148 if entries.is_empty() {
149 return ContractResponse::Now(Err(OpError {
150 detail: "FedAvg::aggregate: empty buffer — no contributions to reduce".into(),
151 ..Default::default()
152 }));
153 }
154
155 let total_samples: u64 = entries.iter().map(|(_, n)| *n).sum();
156 if total_samples == 0 {
157 return ContractResponse::Now(Err(OpError {
158 detail: "FedAvg::aggregate: total_samples = 0".into(),
159 ..Default::default()
160 }));
161 }
162 let total_f = total_samples as f32;
163
164 let canonical_dims: Vec<i64> = entries[0].0.dims().to_vec();
173 let canonical_len: usize = canonical_dims
174 .iter()
175 .map(|d| (*d).max(0) as usize)
176 .product();
177
178 let mut acc: Option<B::Tensor> = None;
179 for (tensor, n) in &entries {
180 let w = (*n as f32) / total_f;
181 let weight_proto = TensorProto {
182 data_type: ONNX_FLOAT,
183 dims: canonical_dims.clone(),
184 float_data: vec![w; canonical_len],
185 ..Default::default()
186 };
187 let weight = match backend.constant(weight_proto) {
188 Ok(t) => t,
189 Err(e) => {
190 return ContractResponse::Now(Err(OpError {
191 detail: format!("FedAvg::aggregate: backend.constant failed: {e}"),
192 ..Default::default()
193 }));
194 }
195 };
196 let scaled = match backend.mul(tensor, &weight) {
197 Ok(t) => t,
198 Err(e) => {
199 return ContractResponse::Now(Err(OpError {
200 detail: format!("FedAvg::aggregate: backend.mul failed: {e}"),
201 ..Default::default()
202 }));
203 }
204 };
205 acc = Some(match acc {
206 None => scaled,
207 Some(prev) => match backend.add(&prev, &scaled) {
208 Ok(t) => t,
209 Err(e) => {
210 return ContractResponse::Now(Err(OpError {
211 detail: format!("FedAvg::aggregate: backend.add failed: {e}"),
212 ..Default::default()
213 }));
214 }
215 },
216 });
217 }
218
219 let params = acc.expect("entries non-empty implies acc populated");
220 ContractResponse::Now(Ok((
221 Box::new(params),
222 FedAvgMeta {
223 num_samples: total_samples,
224 },
225 )))
226 }
227}
228
229impl<B: Backend> ConcreteComponent for FedAvg<B>
238where
239 B: 'static + Default,
240{
241 const TYPE_NAME: &'static str = "FedAvg";
242 const PACKAGE: ComponentPackage = ComponentPackage::Framework;
243 const DEPENDENCIES: &'static [DependencyDecl] = &[DependencyDecl {
244 role: "Backend",
245 slot: "backend",
246 }];
247
248 type Config = ();
249 type Error = std::convert::Infallible;
250
251 fn new(_: &Self::Config) -> Result<Self, Self::Error> {
252 Ok(Self::default())
253 }
254
255 fn serialize(&self) -> Vec<u8> {
256 bincode::serialize(self).expect("FedAvg serialize — bincode infallible on Default state")
257 }
258
259 fn restore(bytes: &[u8]) -> Result<Self, RestoreError> {
260 bincode::deserialize(bytes).map_err(RestoreError::Malformed)
261 }
262}
263
264impl<B: Backend + 'static> AnyComponent for FedAvg<B> {
265 fn as_any(&self) -> &dyn Any {
266 self
267 }
268 fn as_any_mut(&mut self) -> &mut dyn Any {
269 self
270 }
271}
272
273static FEDAVG_ATOMIC_OPS: &[AtomicOpDecl] = &[
277 AtomicOpDecl {
278 name: "Contribute",
279 inputs: &[],
280 outputs: &[],
281 kind: AtomicOpKind::Immediate,
282 type_relations: NO_RELATIONS,
283 },
284 AtomicOpDecl {
285 name: "Aggregate",
286 inputs: &[],
287 outputs: &[],
288 kind: AtomicOpKind::Immediate,
289 type_relations: NO_RELATIONS,
290 },
291];
292
293impl<B> AggregatorRuntime for FedAvg<B>
294where
295 B: Backend + 'static + Default,
296 B::Tensor: Tensor,
297{
298 type Error = OpError;
299
300 fn atomic_opset(&self) -> AtomicOpsetDecl {
301 AtomicOpsetDecl {
302 domain: "ai.bytesandbrains.role.aggregator",
303 version: 1,
304 ops: FEDAVG_ATOMIC_OPS,
305 }
306 }
307
308 fn dispatch_atomic(
309 &mut self,
310 op_type: &str,
311 inputs: &[(&str, &dyn SlotValue)],
312 ctx: &mut RuntimeResourceRef<'_>,
313 ) -> Result<DispatchResult, Self::Error> {
314 match op_type {
315 "Contribute" => {
316 let tensor_ref: &B::Tensor = match inputs
322 .first()
323 .and_then(|(_, v)| v.as_any().downcast_ref::<Box<B::Tensor>>())
324 {
325 Some(b) => b,
326 None => {
327 return Err(OpError {
328 kind: OpErrorKind::TypeMismatch,
329 reason: "input_type_mismatch",
330 detail: format!(
331 "FedAvg::Contribute input 0 expected `Box<{}>`",
332 std::any::type_name::<B::Tensor>(),
333 ),
334 });
335 }
336 };
337 let metadata = match inputs
338 .get(1)
339 .and_then(|(_, v)| v.as_any().downcast_ref::<FedAvgMeta>())
340 {
341 Some(m) => m.clone(),
342 None => {
343 return Err(OpError {
344 kind: OpErrorKind::TypeMismatch,
345 reason: "input_type_mismatch",
346 detail: "FedAvg::Contribute input 1 expected `FedAvgMeta`".into(),
347 });
348 }
349 };
350 let src = match ctx.current.inbound.src_peer {
351 Some(p) => p,
352 None => {
353 return Err(OpError {
354 detail: "FedAvg::Contribute: envelope_src_peer is None — wire envelope did not carry src_peer multihash bytes".into(),
355 ..Default::default()
356 });
357 }
358 };
359 let completion = ctx.open_completion::<(), OpError>();
360 let cmd_id = completion.cmd_id();
361 match <Self as AggregatorContract>::contribute(
362 self, ctx, src, tensor_ref, metadata, completion,
363 ) {
364 ContractResponse::Now(Ok(())) => Ok(DispatchResult::Immediate(Vec::new())),
365 ContractResponse::Now(Err(e)) => Err(OpError {
366 detail: format!("{e}"),
367 ..Default::default()
368 }),
369 ContractResponse::Later => Ok(DispatchResult::Async(cmd_id)),
370 }
371 }
372 "Aggregate" => {
373 let completion = ctx.open_completion::<(Box<B::Tensor>, FedAvgMeta), OpError>();
374 let cmd_id = completion.cmd_id();
375 match <Self as AggregatorContract>::aggregate(self, ctx, completion) {
376 ContractResponse::Now(Ok((params, metadata))) => {
377 Ok(DispatchResult::Immediate(vec![
378 ("params".to_string(), Box::new(params) as Box<dyn SlotValue>),
379 (
380 "metadata".to_string(),
381 Box::new(metadata) as Box<dyn SlotValue>,
382 ),
383 ]))
384 }
385 ContractResponse::Now(Err(e)) => Err(OpError {
386 detail: format!("{e}"),
387 ..Default::default()
388 }),
389 ContractResponse::Later => Ok(DispatchResult::Async(cmd_id)),
390 }
391 }
392 other => Err(OpError {
393 detail: format!("FedAvg::dispatch_atomic: unknown op_type `{other}`"),
394 ..Default::default()
395 }),
396 }
397 }
398}
399
/// `FedAvg` specialised to the CPU backend; the concrete type used by the
/// registry entries in this file.
#[cfg(feature = "cpu-backend")]
type FedAvgCpu = FedAvg<crate::backends::cpu::CpuBackend>;
408
/// Registry hook: serializes a type-erased component that must be a
/// `FedAvg<CpuBackend>`.
///
/// # Panics
///
/// Panics if `erased` is not a `FedAvg<CpuBackend>`.
#[cfg(feature = "cpu-backend")]
#[doc(hidden)]
fn __fedavg_cpu_serialize(erased: &dyn ErasedComponent) -> Vec<u8> {
    let as_any: &dyn Any = erased;
    let downcast: Option<&FedAvgCpu> = as_any.downcast_ref::<FedAvgCpu>();
    let component = downcast.expect("inventory downcast: FedAvg<CpuBackend>");
    <FedAvgCpu as ConcreteComponent>::serialize(component)
}
418
/// Registry hook: restores a `FedAvg<CpuBackend>` from `bytes` and returns it
/// type-erased.
#[cfg(feature = "cpu-backend")]
#[doc(hidden)]
fn __fedavg_cpu_restore(bytes: &[u8]) -> Result<Box<dyn ErasedComponent>, RestoreError> {
    let component = <FedAvgCpu as ConcreteComponent>::restore(bytes)?;
    Ok(Box::new(component) as Box<dyn ErasedComponent>)
}
425
/// Registry hook: constructs a `FedAvg<CpuBackend>` from a type-erased config.
///
/// # Errors
///
/// Returns a `ConstructError` when `cfg` is not the unit type `()`, or when
/// the component constructor reports an error.
#[cfg(feature = "cpu-backend")]
#[doc(hidden)]
fn __fedavg_cpu_construct(
    cfg: &dyn Any,
) -> Result<Box<dyn ErasedComponent>, bb_runtime::concrete::ConstructError> {
    use bb_runtime::concrete::ConstructError;

    let Some(unit_cfg) = cfg.downcast_ref::<()>() else {
        return Err(ConstructError {
            type_name: "FedAvg",
            detail: "config type mismatch: expected `()`".into(),
        });
    };

    <FedAvgCpu as ConcreteComponent>::new(unit_cfg)
        .map(|component| Box::new(component) as Box<dyn ErasedComponent>)
        .map_err(|err| ConstructError {
            type_name: "FedAvg",
            detail: format!("{err}"),
        })
}
444
/// Registry hook: type node of the tensor storage used by the CPU backend.
#[cfg(feature = "cpu-backend")]
#[doc(hidden)]
fn __fedavg_cpu_element_type_node() -> &'static bb_ir::types::TypeNode {
    use crate::backends::cpu::CpuBackend;

    fedavg_element_type::<CpuBackend>()
}
450
// Registers the CPU-backed FedAvg component with its construct, serialize and
// restore hooks and its declared dependencies.
#[cfg(feature = "cpu-backend")]
inventory::submit! {
    bb_runtime::registry::ConcreteComponentRegistration {
        type_name: "FedAvg",
        package: ComponentPackage::Framework,
        dependencies: <FedAvgCpu as ConcreteComponent>::DEPENDENCIES,
        construct_fn: __fedavg_cpu_construct,
        serialize_fn: __fedavg_cpu_serialize,
        restore_fn: __fedavg_cpu_restore,
    }
}
462
// Binds the "FedAvg" component type to the Aggregator role.
#[cfg(feature = "cpu-backend")]
inventory::submit! {
    bb_runtime::registry::ComponentRoleBinding {
        role: bb_runtime::registry::ComponentRole::Aggregator,
        type_name: "FedAvg",
    }
}
470
// Registers the hook that installs the CPU-backed FedAvg aggregator
// dispatcher on an engine.
#[cfg(feature = "cpu-backend")]
inventory::submit! {
    bb_runtime::registry::DispatcherRegistration {
        role: bb_runtime::registry::ComponentRole::Aggregator,
        type_name: "FedAvg",
        register_fn: |target: &mut bb_runtime::engine::Engine| {
            target.register_aggregator_dispatcher::<FedAvgCpu>();
        },
    }
}
481
// Publishes the storage type of the aggregator's "element" port for the
// CPU-backed FedAvg component.
#[cfg(feature = "cpu-backend")]
inventory::submit! {
    bb_runtime::registry::StorageTypeEntry {
        role_runtime: "AggregatorRuntime",
        port: "element",
        concrete_type_name: <FedAvgCpu as ConcreteComponent>::TYPE_NAME,
        type_node_fn: __fedavg_cpu_element_type_node,
    }
}
491