Skip to main content

bb_ops/aggregators/fedavg/
mod.rs

1//! `FedAvg<B>` — federated-averaging aggregator. Composes the
2//! reduction from the bound backend's `Mul` + `Add` primitives so
3//! the 30-op floor stays unchanged. Aggregate emits the cumulative
4//! `num_samples` for hierarchical weighting.
5//!
6//! Trust model: contributions are assumed finite. NaN/Inf
7//! propagates per IEEE 754 and poisons the round; defenses belong
8//! at the contribution boundary (signed Codec, attesting
9//! PeerSelector, secure-aggregation protocol).
10
11use std::any::Any;
12use std::collections::BTreeMap;
13use std::marker::PhantomData;
14
15use serde::{Deserialize, Serialize};
16
17#[cfg(feature = "cpu-backend")]
18use bb_ir::component::ErasedComponent;
19use bb_ir::component::{AnyComponent, DependencyDecl, RestoreError};
20use bb_ir::ids::PeerId;
21use bb_ir::proto::onnx::TensorProto;
22use bb_ir::tensor::Tensor;
23use bb_ir::types::common_relations::NO_RELATIONS;
24use bb_runtime::atomic::{AtomicOpDecl, AtomicOpKind, AtomicOpsetDecl, DispatchResult};
25use bb_runtime::bus::{OpError, OpErrorKind};
26use bb_runtime::completion::{CompletionHandle, ContractResponse};
27use bb_runtime::concrete::{ComponentPackage, ConcreteComponent};
28use bb_runtime::contracts::{Aggregator as AggregatorContract, Backend};
29use bb_runtime::roles::AggregatorRuntime;
30use bb_runtime::runtime::RuntimeResourceRef;
31use bb_runtime::slot_value::SlotValue;
32
33/// Sample-count metadata for FedAvg contributions / aggregates.
34#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
35pub struct FedAvgMeta {
36    /// Single-layer: local batch size. Hierarchical: subtree sum,
37    /// used to weight this aggregate's contribution at parent level.
38    pub num_samples: u64,
39}
40
/// ONNX `DataType::FLOAT` discriminant. Used as the `data_type` of the
/// per-peer weight `TensorProto` built during aggregation.
const ONNX_FLOAT: i32 = 1;
43
/// Returns `<B::Tensor as Storage>::TYPE`, the type node describing the
/// element tensor of `FedAvg<B>`.
///
/// Compiled only under the `cpu-backend` feature, which gates the
/// inventory registration that calls it.
#[cfg(feature = "cpu-backend")]
fn fedavg_element_type<B>() -> &'static bb_ir::types::TypeNode
where
    B: Backend,
{
    use bb_ir::types::Storage;

    <B::Tensor as Storage>::TYPE
}
50
/// FedAvg aggregator: weighted average where weights are
/// `num_samples / total_samples`. Generic over `B: Backend` so the
/// reduction composes from the backend's `Mul` + `Add` primitives
/// without bumping the 30-op floor.
///
/// The buffer is a `BTreeMap<PeerId, (B::Tensor, u64)>` so the
/// reduction walk has a deterministic iteration order (ascending
/// `PeerId`, per its `Ord` impl). The engine's
/// `dispatch_atomic(&mut self, ...)` contract gives the aggregator
/// exclusive access for the duration of each call.
///
/// Both fields are `#[serde(skip)]`: a snapshot captures the
/// aggregator's structural identity, not the per-round transient
/// contribution state, so a restored instance always starts with an
/// empty buffer.
#[derive(Debug, Serialize, Deserialize)]
pub struct FedAvg<B: Backend> {
    /// Per-round buffer keyed by source `PeerId`; each value is the
    /// contributed tensor paired with its `num_samples`. Duplicate
    /// contributions from the same peer in the same round REPLACE
    /// the prior entry, so a buggy or malicious peer cannot double
    /// its weight by contributing twice. `BTreeMap` is the ordered
    /// peer-id walk the spec's determinism guarantee relies on (the
    /// reduction's f32 accumulation order is the BTree's key order,
    /// not a hash-map's runtime-randomized order).
    #[serde(skip)]
    buffer: BTreeMap<PeerId, (B::Tensor, u64)>,
    /// Zero-sized marker carrying the backend type parameter; holds
    /// no data.
    #[serde(skip)]
    _backend: PhantomData<B>,
}
77
78impl<B: Backend> Default for FedAvg<B> {
79    fn default() -> Self {
80        Self {
81            buffer: BTreeMap::new(),
82            _backend: PhantomData,
83        }
84    }
85}
86
87impl<B: Backend> Clone for FedAvg<B> {
88    fn clone(&self) -> Self {
89        // Cloning returns a fresh empty buffer — snapshots restore
90        // via `restore` (the universal `ConcreteComponent` path),
91        // not through `Clone`. Clone is only used to satisfy the
92        // framework's `T: Clone` bounds.
93        Self::default()
94    }
95}
96
97impl<B: Backend> AggregatorContract for FedAvg<B>
98where
99    B: 'static,
100    B::Tensor: Tensor,
101{
102    type Element = B::Tensor;
103    type Error = OpError;
104    type Metadata = FedAvgMeta;
105
106    fn contribute(
107        &mut self,
108        _ctx: &mut RuntimeResourceRef<'_>,
109        src: PeerId,
110        tensor: &Self::Element,
111        metadata: FedAvgMeta,
112        _completion: CompletionHandle<(), Self::Error>,
113    ) -> ContractResponse<(), Self::Error> {
114        // Reject zero-sample contributions: an `n=0` entry would
115        // contribute zero weight to the reduction yet still
116        // displace a real same-peer contribution from the buffer.
117        if metadata.num_samples == 0 {
118            return ContractResponse::Now(Err(OpError {
119                detail: "FedAvg::contribute: num_samples = 0 — degenerate weight".into(),
120                ..Default::default()
121            }));
122        }
123        // Keying on src prevents a peer from doubling its weight by
124        // contributing twice in one round: the second entry replaces
125        // the first rather than landing alongside it.
126        self.buffer
127            .insert(src, (tensor.clone(), metadata.num_samples));
128        ContractResponse::Now(Ok(()))
129    }
130
131    fn aggregate(
132        &mut self,
133        ctx: &mut RuntimeResourceRef<'_>,
134        _completion: CompletionHandle<(Box<Self::Element>, FedAvgMeta), Self::Error>,
135    ) -> ContractResponse<(Box<Self::Element>, FedAvgMeta), Self::Error> {
136        let backend = match ctx.dependency::<B>("backend") {
137            Ok(b) => b,
138            Err(e) => {
139                return ContractResponse::Now(Err(OpError {
140                    detail: format!("FedAvg::aggregate: backend lookup failed: {e}"),
141                    ..Default::default()
142                }));
143            }
144        };
145
146        let entries: Vec<(B::Tensor, u64)> =
147            std::mem::take(&mut self.buffer).into_values().collect();
148        if entries.is_empty() {
149            return ContractResponse::Now(Err(OpError {
150                detail: "FedAvg::aggregate: empty buffer — no contributions to reduce".into(),
151                ..Default::default()
152            }));
153        }
154
155        let total_samples: u64 = entries.iter().map(|(_, n)| *n).sum();
156        if total_samples == 0 {
157            return ContractResponse::Now(Err(OpError {
158                detail: "FedAvg::aggregate: total_samples = 0".into(),
159                ..Default::default()
160            }));
161        }
162        let total_f = total_samples as f32;
163
164        // Determine the canonical output shape from the first
165        // contribution; later contributions of mismatched shape will
166        // be rejected by the backend's elementwise kernel. The shape
167        // also drives the per-peer weight tensor's construction —
168        // `CpuBackend`'s `Mul` requires same-shape inputs (full
169        // NumPy broadcasting isn't implemented yet), so the weight
170        // is materialized at the canonical shape rather than as a
171        // length-1 broadcast scalar.
172        let canonical_dims: Vec<i64> = entries[0].0.dims().to_vec();
173        let canonical_len: usize = canonical_dims
174            .iter()
175            .map(|d| (*d).max(0) as usize)
176            .product();
177
178        let mut acc: Option<B::Tensor> = None;
179        for (tensor, n) in &entries {
180            let w = (*n as f32) / total_f;
181            let weight_proto = TensorProto {
182                data_type: ONNX_FLOAT,
183                dims: canonical_dims.clone(),
184                float_data: vec![w; canonical_len],
185                ..Default::default()
186            };
187            let weight = match backend.constant(weight_proto) {
188                Ok(t) => t,
189                Err(e) => {
190                    return ContractResponse::Now(Err(OpError {
191                        detail: format!("FedAvg::aggregate: backend.constant failed: {e}"),
192                        ..Default::default()
193                    }));
194                }
195            };
196            let scaled = match backend.mul(tensor, &weight) {
197                Ok(t) => t,
198                Err(e) => {
199                    return ContractResponse::Now(Err(OpError {
200                        detail: format!("FedAvg::aggregate: backend.mul failed: {e}"),
201                        ..Default::default()
202                    }));
203                }
204            };
205            acc = Some(match acc {
206                None => scaled,
207                Some(prev) => match backend.add(&prev, &scaled) {
208                    Ok(t) => t,
209                    Err(e) => {
210                        return ContractResponse::Now(Err(OpError {
211                            detail: format!("FedAvg::aggregate: backend.add failed: {e}"),
212                            ..Default::default()
213                        }));
214                    }
215                },
216            });
217        }
218
219        let params = acc.expect("entries non-empty implies acc populated");
220        ContractResponse::Now(Ok((
221            Box::new(params),
222            FedAvgMeta {
223                num_samples: total_samples,
224            },
225        )))
226    }
227}
228
229// ─── Manual ConcreteComponent + AnyComponent + role plumbing ──────
230//
231// `bb_derive::Aggregator` does not handle generic structs; the
232// inventory submissions below cover every monomorphization the
233// framework needs (currently `FedAvg<CpuBackend>` when the
234// `cpu-backend` feature is on). Generic-impl support in the derive
235// is out of scope for the dep-injection milestone.
236
237impl<B: Backend> ConcreteComponent for FedAvg<B>
238where
239    B: 'static + Default,
240{
241    const TYPE_NAME: &'static str = "FedAvg";
242    const PACKAGE: ComponentPackage = ComponentPackage::Framework;
243    const DEPENDENCIES: &'static [DependencyDecl] = &[DependencyDecl {
244        role: "Backend",
245        slot: "backend",
246    }];
247
248    type Config = ();
249    type Error = std::convert::Infallible;
250
251    fn new(_: &Self::Config) -> Result<Self, Self::Error> {
252        Ok(Self::default())
253    }
254
255    fn serialize(&self) -> Vec<u8> {
256        bincode::serialize(self).expect("FedAvg serialize — bincode infallible on Default state")
257    }
258
259    fn restore(bytes: &[u8]) -> Result<Self, RestoreError> {
260        bincode::deserialize(bytes).map_err(RestoreError::Malformed)
261    }
262}
263
264impl<B: Backend + 'static> AnyComponent for FedAvg<B> {
265    fn as_any(&self) -> &dyn Any {
266        self
267    }
268    fn as_any_mut(&mut self) -> &mut dyn Any {
269        self
270    }
271}
272
/// Atomic opset for the `FedAvg<B>` aggregator. Names align with
/// the canonical Aggregator role surface
/// (`emit_aggregator_arms` in `bb-derive`).
///
/// Both ops are declared `Immediate`, with empty input/output port
/// lists and no type relations. The names must match the `op_type`
/// strings handled by `dispatch_atomic`.
static FEDAVG_ATOMIC_OPS: &[AtomicOpDecl] = &[
    // Buffers one peer's tensor + sample count for the current round.
    AtomicOpDecl {
        name: "Contribute",
        inputs: &[],
        outputs: &[],
        kind: AtomicOpKind::Immediate,
        type_relations: NO_RELATIONS,
    },
    // Reduces the buffered contributions to a weighted average.
    AtomicOpDecl {
        name: "Aggregate",
        inputs: &[],
        outputs: &[],
        kind: AtomicOpKind::Immediate,
        type_relations: NO_RELATIONS,
    },
];
292
293impl<B> AggregatorRuntime for FedAvg<B>
294where
295    B: Backend + 'static + Default,
296    B::Tensor: Tensor,
297{
298    type Error = OpError;
299
300    fn atomic_opset(&self) -> AtomicOpsetDecl {
301        AtomicOpsetDecl {
302            domain: "ai.bytesandbrains.role.aggregator",
303            version: 1,
304            ops: FEDAVG_ATOMIC_OPS,
305        }
306    }
307
308    fn dispatch_atomic(
309        &mut self,
310        op_type: &str,
311        inputs: &[(&str, &dyn SlotValue)],
312        ctx: &mut RuntimeResourceRef<'_>,
313    ) -> Result<DispatchResult, Self::Error> {
314        match op_type {
315            "Contribute" => {
316                // Borrow the boxed tensor through the SlotValue ref;
317                // `contribute` takes `&Self::Element` so no owned copy
318                // is needed at the dispatch boundary. The downstream
319                // buffer insertion in `contribute` is the single
320                // remaining tensor copy per contribution.
321                let tensor_ref: &B::Tensor = match inputs
322                    .first()
323                    .and_then(|(_, v)| v.as_any().downcast_ref::<Box<B::Tensor>>())
324                {
325                    Some(b) => b,
326                    None => {
327                        return Err(OpError {
328                            kind: OpErrorKind::TypeMismatch,
329                            reason: "input_type_mismatch",
330                            detail: format!(
331                                "FedAvg::Contribute input 0 expected `Box<{}>`",
332                                std::any::type_name::<B::Tensor>(),
333                            ),
334                        });
335                    }
336                };
337                let metadata = match inputs
338                    .get(1)
339                    .and_then(|(_, v)| v.as_any().downcast_ref::<FedAvgMeta>())
340                {
341                    Some(m) => m.clone(),
342                    None => {
343                        return Err(OpError {
344                            kind: OpErrorKind::TypeMismatch,
345                            reason: "input_type_mismatch",
346                            detail: "FedAvg::Contribute input 1 expected `FedAvgMeta`".into(),
347                        });
348                    }
349                };
350                let src = match ctx.current.inbound.src_peer {
351                    Some(p) => p,
352                    None => {
353                        return Err(OpError {
354                            detail: "FedAvg::Contribute: envelope_src_peer is None — wire envelope did not carry src_peer multihash bytes".into(),
355                            ..Default::default()
356                        });
357                    }
358                };
359                let completion = ctx.open_completion::<(), OpError>();
360                let cmd_id = completion.cmd_id();
361                match <Self as AggregatorContract>::contribute(
362                    self, ctx, src, tensor_ref, metadata, completion,
363                ) {
364                    ContractResponse::Now(Ok(())) => Ok(DispatchResult::Immediate(Vec::new())),
365                    ContractResponse::Now(Err(e)) => Err(OpError {
366                        detail: format!("{e}"),
367                        ..Default::default()
368                    }),
369                    ContractResponse::Later => Ok(DispatchResult::Async(cmd_id)),
370                }
371            }
372            "Aggregate" => {
373                let completion = ctx.open_completion::<(Box<B::Tensor>, FedAvgMeta), OpError>();
374                let cmd_id = completion.cmd_id();
375                match <Self as AggregatorContract>::aggregate(self, ctx, completion) {
376                    ContractResponse::Now(Ok((params, metadata))) => {
377                        Ok(DispatchResult::Immediate(vec![
378                            ("params".to_string(), Box::new(params) as Box<dyn SlotValue>),
379                            (
380                                "metadata".to_string(),
381                                Box::new(metadata) as Box<dyn SlotValue>,
382                            ),
383                        ]))
384                    }
385                    ContractResponse::Now(Err(e)) => Err(OpError {
386                        detail: format!("{e}"),
387                        ..Default::default()
388                    }),
389                    ContractResponse::Later => Ok(DispatchResult::Async(cmd_id)),
390                }
391            }
392            other => Err(OpError {
393                detail: format!("FedAvg::dispatch_atomic: unknown op_type `{other}`"),
394                ..Default::default()
395            }),
396        }
397    }
398}
399
400// ─── Inventory submissions — `FedAvg<CpuBackend>` monomorphization ─
401//
402// Inventory carriers can only register concrete monomorphizations;
403// each supported backend submits its own block (the `cpu-backend`
404// feature is the only one shipping today).
405
/// The `FedAvg` monomorphization over `CpuBackend`, used by the
/// `cpu-backend` inventory registrations in this file.
#[cfg(feature = "cpu-backend")]
type FedAvgCpu = FedAvg<crate::backends::cpu::CpuBackend>;
408
/// Inventory `serialize_fn`: downcasts the erased component to
/// `FedAvg<CpuBackend>` and serializes it.
///
/// # Panics
///
/// Panics if `erased` is not a `FedAvg<CpuBackend>`.
#[cfg(feature = "cpu-backend")]
#[doc(hidden)]
fn __fedavg_cpu_serialize(erased: &dyn ErasedComponent) -> Vec<u8> {
    let as_any: &dyn Any = erased;
    let component = as_any
        .downcast_ref::<FedAvgCpu>()
        .expect("inventory downcast: FedAvg<CpuBackend>");
    // Named via the trait: serde's `Serialize::serialize` is also in scope.
    ConcreteComponent::serialize(component)
}
418
/// Inventory `restore_fn`: rebuilds a `FedAvg<CpuBackend>` from its
/// serialized bytes and returns it type-erased.
///
/// # Errors
///
/// Forwards the `RestoreError` from `ConcreteComponent::restore`.
#[cfg(feature = "cpu-backend")]
#[doc(hidden)]
fn __fedavg_cpu_restore(bytes: &[u8]) -> Result<Box<dyn ErasedComponent>, RestoreError> {
    let restored = <FedAvgCpu as ConcreteComponent>::restore(bytes)?;
    Ok(Box::new(restored) as Box<dyn ErasedComponent>)
}
425
/// Inventory `construct_fn`: builds a `FedAvg<CpuBackend>` from a
/// type-erased config and returns it type-erased.
///
/// # Errors
///
/// Returns a `ConstructError` when `cfg` is not the unit config `()`,
/// or when `ConcreteComponent::new` fails.
#[cfg(feature = "cpu-backend")]
#[doc(hidden)]
fn __fedavg_cpu_construct(
    cfg: &dyn Any,
) -> Result<Box<dyn ErasedComponent>, bb_runtime::concrete::ConstructError> {
    use bb_runtime::concrete::ConstructError;

    let Some(unit_cfg) = cfg.downcast_ref::<()>() else {
        return Err(ConstructError {
            type_name: "FedAvg",
            detail: "config type mismatch: expected `()`".into(),
        });
    };

    match <FedAvgCpu as ConcreteComponent>::new(unit_cfg) {
        Ok(component) => Ok(Box::new(component) as Box<dyn ErasedComponent>),
        Err(e) => Err(ConstructError {
            type_name: "FedAvg",
            detail: format!("{e}"),
        }),
    }
}
444
/// Inventory `type_node_fn`: the element type node for
/// `FedAvg<CpuBackend>`. A non-generic wrapper around
/// `fedavg_element_type`, registered by name in the
/// `StorageTypeEntry` submission.
#[cfg(feature = "cpu-backend")]
#[doc(hidden)]
fn __fedavg_cpu_element_type_node() -> &'static bb_ir::types::TypeNode {
    fedavg_element_type::<crate::backends::cpu::CpuBackend>()
}
450
// Registers `FedAvg<CpuBackend>` as a concrete component under the
// type name "FedAvg", wiring up its serialize / restore / construct
// entry points and its declared dependencies.
#[cfg(feature = "cpu-backend")]
inventory::submit! {
    bb_runtime::registry::ConcreteComponentRegistration {
        type_name: "FedAvg",
        package: ComponentPackage::Framework,
        serialize_fn: __fedavg_cpu_serialize,
        restore_fn: __fedavg_cpu_restore,
        construct_fn: __fedavg_cpu_construct,
        dependencies: <FedAvgCpu as ConcreteComponent>::DEPENDENCIES,
    }
}
462
// Binds the "FedAvg" type name to the Aggregator role.
#[cfg(feature = "cpu-backend")]
inventory::submit! {
    bb_runtime::registry::ComponentRoleBinding {
        type_name: "FedAvg",
        role: bb_runtime::registry::ComponentRole::Aggregator,
    }
}
470
// Registers the hook that installs the aggregator dispatcher for
// `FedAvg<CpuBackend>` on an engine.
#[cfg(feature = "cpu-backend")]
inventory::submit! {
    bb_runtime::registry::DispatcherRegistration {
        type_name: "FedAvg",
        role: bb_runtime::registry::ComponentRole::Aggregator,
        register_fn: |engine: &mut bb_runtime::engine::Engine| {
            engine.register_aggregator_dispatcher::<FedAvgCpu>();
        },
    }
}
481
// Publishes the storage type of the aggregator's `element` port for
// `FedAvg<CpuBackend>`; the type node is resolved through
// `__fedavg_cpu_element_type_node`.
#[cfg(feature = "cpu-backend")]
inventory::submit! {
    bb_runtime::registry::StorageTypeEntry {
        concrete_type_name: <FedAvgCpu as ConcreteComponent>::TYPE_NAME,
        role_runtime: "AggregatorRuntime",
        port: "element",
        type_node_fn: __fedavg_cpu_element_type_node,
    }
}
491