bb_runtime/contracts/aggregator.rs
1//! `bb::Aggregator` — Contract trait for federated aggregators.
2//!
3//! Each method takes a [`CompletionHandle`] AND returns
4//! [`ContractResponse`]. See [`crate::contracts::index`] for the
5//! sync (Now) vs async (Later) semantics.
6//!
7//! ## Shape
8//!
9//! Aggregation is a two-op cycle: `contribute(...)` writes one
10//! peer's update into an in-progress buffer; `aggregate(...)`
11//! reduces the buffer into the current aggregate AND returns it.
12//! There is no separate `current_tensor()` op — `aggregate` is the
13//! one-stop "compute + emit" call.
14//!
15//! ## Metadata channel
16//!
17//! Both `contribute` and `aggregate` carry a **typed** metadata
18//! payload alongside the tensor, defined by the impl as the
19//! associated type [`Aggregator::Metadata`].
20//!
21//! The metadata is transported through the slot table as a typed
22//! Rust value — the framework's slot-value layer (`bb_ir::slot_value`)
23//! holds every value as `Box<dyn SlotValue>` and downcasts to the
24//! concrete type via `Any::downcast_ref`. Bincode/serde fires only
25//! at the wire boundary (`SlotValue::to_wire_bytes`) and at
26//! snapshot time. In-process contribute/aggregate calls see the
27//! typed value directly — no serde overhead.
28//!
//! This is the channel hierarchical aggregation needs: a child
//! `FedAvg` aggregator's `aggregate(...)` emits
//! `(params, FedAvgMeta { num_samples })`; the parent layer's
//! `contribute(...)` receives that pair and uses `num_samples` to
//! weight the child's contribution in the parent reduction. Both
//! halves work with the typed `FedAvgMeta` — only the wire crossing
//! does serde.
36//!
37//! Impls that have no metadata channel set `type Metadata = ();`.
38
39use crate::completion::{CompletionHandle, ContractResponse};
40use crate::runtime::RuntimeResourceRef;
41use bb_ir::ids::PeerId;
42
43/// User-facing Contract trait for a federated/decentralized
44/// aggregator. The derive bridges these methods to the engine's
45/// [`crate::roles::AggregatorRuntime`] trait.
46pub trait Aggregator: Send + Sync {
47 /// Storage element type for the tensors this aggregator
48 /// operates on. Most f32-native aggregators declare
49 /// `type Element = [f32]`.
50 ///
51 /// The bound `?Sized + bb_ir::types::Storage` allows unsized
52 /// slice types like `[f32]` (a `Box<[f32]>` is the owned form
53 /// returned from `aggregate`).
54 type Element: ?Sized + bb_ir::types::Storage;
55
56 /// Library-maker-defined error type.
57 type Error: std::error::Error + std::fmt::Display + Send + Sync + 'static;
58
59 /// Impl-defined metadata that travels alongside the tensor.
60 /// Carried as a typed slot value; serde fires only when the
61 /// value crosses a wire boundary.
62 ///
63 /// For FedAvg: `type Metadata = FedAvgMeta { num_samples: u64 };`.
64 /// For impls with no metadata channel: `type Metadata = ();`.
65 type Metadata: Clone
66 + Default
67 + serde::Serialize
68 + for<'de> serde::Deserialize<'de>
69 + Send
70 + Sync
71 + 'static;
72
73 /// Contribute one peer's update to the in-progress aggregation.
74 /// `ctx` is the per-dispatch runtime surface; impls reach their
75 /// declared `#[depends(...)]` siblings through
76 /// [`RuntimeResourceRef::dependency`]. `tensor` is a reference
77 /// to the element (e.g. `&[f32]` for `Element = [f32]`).
78 /// `metadata` is the typed accompanying data (sample counts for
79 /// FedAvg, weights for weighted sum, round ids, …).
80 /// Default-constructed `Metadata` is valid for impls that don't
81 /// have a real metadata channel.
82 fn contribute(
83 &mut self,
84 ctx: &mut RuntimeResourceRef<'_>,
85 src: PeerId,
86 tensor: &Self::Element,
87 metadata: Self::Metadata,
88 completion: CompletionHandle<(), Self::Error>,
89 ) -> ContractResponse<(), Self::Error>;
90
91 /// Reduce the accumulated contributions and return the result.
92 /// `ctx` carries the runtime surface so the aggregator's
93 /// reduction can resolve `#[depends(...)]` siblings (e.g. the
94 /// `Backend` that supplies the composed weighted-sum).
95 /// Output is `(params, metadata)`:
96 /// - `params`: the aggregated tensor, owned as
97 /// `Box<Self::Element>` (e.g. `Box<[f32]>`). Same allocator
98 /// footprint as a `Vec<f32>` — use `vec.into_boxed_slice()`.
99 /// - `metadata`: typed accompanying data describing the
100 /// aggregation (e.g. summed `num_samples` for hierarchical
101 /// FedAvg).
102 ///
103 /// The output edge fires only when the reduction completes;
104 /// downstream consumers wire directly to the `(params,
105 /// metadata)` outputs — no separate read op needed.
106 #[allow(clippy::type_complexity)]
107 fn aggregate(
108 &mut self,
109 ctx: &mut RuntimeResourceRef<'_>,
110 completion: CompletionHandle<(Box<Self::Element>, Self::Metadata), Self::Error>,
111 ) -> ContractResponse<(Box<Self::Element>, Self::Metadata), Self::Error>;
112}