Skip to main content

dynomite/cluster/capability/
registry.rs

1//! Type-erased capability registry plus the [`Capability`] trait.
2//!
3//! Capabilities advertise typed values, but the registry must
4//! treat them uniformly so it can encode an advertisement,
5//! receive a peer ad, and look up negotiated values by name. The
6//! erasure is handled internally; the public surface stays
7//! generic over the user's [`Capability::Value`] type.
8
9use std::any::Any;
10use std::collections::HashMap;
11
12use parking_lot::RwLock;
13
14use crate::cluster::capability::negotiator::{negotiate_with_floor, NegotiatedCapabilities};
15
16/// Magic prefix marking an encoded [`CapabilityAd`] blob.
17///
18/// The magic plus the trailing version byte form the four-byte header;
19/// the magic is stable, and format changes bump the version byte.
20const CAP_AD_MAGIC: [u8; 3] = *b"CAP";
21
22/// Format version embedded in encoded [`CapabilityAd`] blobs.
23const CAP_AD_VERSION: u8 = 1;
24
/// Contract implemented by every negotiable capability.
///
/// `Value` is the typed form of one supported value. Values cross
/// the registry boundary as bytes produced by
/// [`Capability::encode_value`] and parsed back by
/// [`Capability::decode_value`], so no third-party serialisation
/// dependency is required.
///
/// # Examples
///
/// ```
/// use dynomite::cluster::capability::Capability;
///
/// struct Bool;
/// impl Capability for Bool {
///     type Value = bool;
///     fn name(&self) -> &'static str { "feature" }
///     fn supported_values(&self) -> Vec<bool> { vec![false, true] }
///     fn merge(&self, peer: &[bool]) -> Option<bool> {
///         if peer.contains(&true) { Some(true) }
///         else if peer.contains(&false) { Some(false) }
///         else { None }
///     }
///     fn encode_value(&self, v: &bool) -> Vec<u8> { vec![u8::from(*v)] }
///     fn decode_value(&self, b: &[u8]) -> Option<bool> {
///         match b { [0] => Some(false), [1] => Some(true), _ => None }
///     }
/// }
/// ```
pub trait Capability: Any + Send + Sync + 'static {
    /// Typed value negotiated for this capability.
    type Value: Clone + Eq + Send + Sync + 'static;

    /// Name used to identify the capability on the wire. Must be
    /// ASCII and must not change between releases.
    fn name(&self) -> &'static str;

    /// Values this node supports, listed from least preferred to
    /// most preferred. The first element doubles as the "floor"
    /// that applies when negotiation finds no overlap.
    fn supported_values(&self) -> Vec<Self::Value>;

    /// Picks the highest local value that `peer_supports` also
    /// contains, or `None` when the two sides share nothing. What
    /// "highest" means is decided by the implementation.
    fn merge(&self, peer_supports: &[Self::Value]) -> Option<Self::Value>;

    /// Turns a value into a stable byte sequence for the
    /// on-the-wire advertisement.
    fn encode_value(&self, value: &Self::Value) -> Vec<u8>;

    /// Parses bytes produced by [`Capability::encode_value`].
    /// Returning `None` makes the registry drop the malformed
    /// value while merging a peer ad.
    fn decode_value(&self, bytes: &[u8]) -> Option<Self::Value>;
}
79
80/// Errors produced while decoding a [`CapabilityAd`] blob.
81#[derive(Debug, thiserror::Error)]
82pub enum CapabilityCodecError {
83    /// Buffer ended mid-record.
84    #[error("capability advertisement truncated")]
85    Truncated,
86    /// The leading magic / version did not match.
87    #[error("capability advertisement: invalid magic or version")]
88    BadMagic,
89    /// A capability name contained a non-ASCII byte.
90    #[error("capability advertisement: non-ASCII capability name")]
91    NonAsciiName,
92    /// The encoded entry count exceeded the safety bound.
93    #[error("capability advertisement: too many entries ({0})")]
94    TooManyEntries(usize),
95}
96
/// A single record of a [`CapabilityAd`]: the capability's name
/// together with the opaque, capability-defined value blobs that
/// the advertising peer supports.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CapabilityAdEntry {
    name: String,
    supported: Vec<Vec<u8>>,
}

impl CapabilityAdEntry {
    /// Assemble an entry from a name and value blobs that are
    /// already encoded.
    #[must_use]
    pub fn new(name: String, supported: Vec<Vec<u8>>) -> Self {
        CapabilityAdEntry { name, supported }
    }

    /// Name of the advertised capability.
    #[must_use]
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Encoded values the advertiser supports, in the
    /// advertiser's preference order.
    #[must_use]
    pub fn supported(&self) -> &[Vec<u8>] {
        self.supported.as_slice()
    }
}
126
127/// On-the-wire advertisement built by [`CapabilityRegistry::local_advertise`].
128#[derive(Clone, Debug, Default, Eq, PartialEq)]
129pub struct CapabilityAd {
130    entries: Vec<CapabilityAdEntry>,
131}
132
133/// Maximum number of entries we will encode into or decode from
134/// a single advertisement. Intentionally far above what we ever
135/// expect to use; the bound exists to reject garbage payloads.
136const CAP_AD_MAX_ENTRIES: usize = 1024;
137
138/// Maximum byte length of a single value blob inside an entry.
139const CAP_AD_MAX_VALUE_LEN: usize = 16 * 1024;
140
141/// Maximum byte length of a single capability name.
142const CAP_AD_MAX_NAME_LEN: usize = 256;
143
144impl CapabilityAd {
145    /// Construct an empty advertisement.
146    #[must_use]
147    pub fn new() -> Self {
148        Self::default()
149    }
150
151    /// Build an advertisement from pre-shaped entries.
152    #[must_use]
153    pub fn from_entries(entries: Vec<CapabilityAdEntry>) -> Self {
154        Self { entries }
155    }
156
157    /// Read-only view of the advertised entries.
158    #[must_use]
159    pub fn entries(&self) -> &[CapabilityAdEntry] {
160        &self.entries
161    }
162
163    /// Serialise the advertisement to a length-prefixed byte
164    /// stream. The encoding is stable, ASCII-clean for capability
165    /// names, and uses only the standard library.
166    ///
167    /// # Examples
168    ///
169    /// ```
170    /// use dynomite::cluster::capability::{CapabilityAd, CapabilityAdEntry};
171    /// let ad = CapabilityAd::from_entries(vec![
172    ///     CapabilityAdEntry::new("framing".into(), vec![vec![1, 0, 0, 0]]),
173    /// ]);
174    /// let bytes = ad.encode();
175    /// let back = CapabilityAd::decode(&bytes).unwrap();
176    /// assert_eq!(back, ad);
177    /// ```
178    #[must_use]
179    pub fn encode(&self) -> Vec<u8> {
180        let mut out = Vec::with_capacity(8 + self.entries.len() * 32);
181        out.extend_from_slice(&CAP_AD_MAGIC);
182        out.push(CAP_AD_VERSION);
183        let count = u32::try_from(self.entries.len()).unwrap_or(u32::MAX);
184        out.extend_from_slice(&count.to_le_bytes());
185        for entry in &self.entries {
186            let name_bytes = entry.name.as_bytes();
187            let name_len = u16::try_from(name_bytes.len()).unwrap_or(u16::MAX);
188            out.extend_from_slice(&name_len.to_le_bytes());
189            out.extend_from_slice(name_bytes);
190            let val_count = u16::try_from(entry.supported.len()).unwrap_or(u16::MAX);
191            out.extend_from_slice(&val_count.to_le_bytes());
192            for value in &entry.supported {
193                let vlen = u32::try_from(value.len()).unwrap_or(u32::MAX);
194                out.extend_from_slice(&vlen.to_le_bytes());
195                out.extend_from_slice(value);
196            }
197        }
198        out
199    }
200
201    /// Inverse of [`CapabilityAd::encode`]. Rejects malformed or
202    /// truncated input with a typed error.
203    pub fn decode(mut bytes: &[u8]) -> Result<Self, CapabilityCodecError> {
204        if bytes.len() < CAP_AD_MAGIC.len() + 1 + 4 {
205            return Err(CapabilityCodecError::Truncated);
206        }
207        if bytes[..CAP_AD_MAGIC.len()] != CAP_AD_MAGIC {
208            return Err(CapabilityCodecError::BadMagic);
209        }
210        bytes = &bytes[CAP_AD_MAGIC.len()..];
211        if bytes[0] != CAP_AD_VERSION {
212            return Err(CapabilityCodecError::BadMagic);
213        }
214        bytes = &bytes[1..];
215        let count = read_u32(&mut bytes)?;
216        let count_us = usize::try_from(count).unwrap_or(usize::MAX);
217        if count_us > CAP_AD_MAX_ENTRIES {
218            return Err(CapabilityCodecError::TooManyEntries(count_us));
219        }
220        let mut entries = Vec::with_capacity(count_us);
221        for _ in 0..count_us {
222            let name_len = read_u16(&mut bytes)? as usize;
223            if name_len > CAP_AD_MAX_NAME_LEN {
224                return Err(CapabilityCodecError::TooManyEntries(name_len));
225            }
226            let name_bytes = read_slice(&mut bytes, name_len)?;
227            if !name_bytes.is_ascii() {
228                return Err(CapabilityCodecError::NonAsciiName);
229            }
230            // Safe because we just checked is_ascii: ASCII bytes
231            // are valid UTF-8 by construction.
232            let name = std::str::from_utf8(name_bytes)
233                .map_err(|_| CapabilityCodecError::NonAsciiName)?
234                .to_string();
235            let val_count = read_u16(&mut bytes)? as usize;
236            let mut supported = Vec::with_capacity(val_count);
237            for _ in 0..val_count {
238                let vlen = read_u32(&mut bytes)? as usize;
239                if vlen > CAP_AD_MAX_VALUE_LEN {
240                    return Err(CapabilityCodecError::TooManyEntries(vlen));
241                }
242                let vbytes = read_slice(&mut bytes, vlen)?;
243                supported.push(vbytes.to_vec());
244            }
245            entries.push(CapabilityAdEntry::new(name, supported));
246        }
247        Ok(Self { entries })
248    }
249}
250
251fn read_slice<'a>(cur: &mut &'a [u8], len: usize) -> Result<&'a [u8], CapabilityCodecError> {
252    if cur.len() < len {
253        return Err(CapabilityCodecError::Truncated);
254    }
255    let (head, tail) = cur.split_at(len);
256    *cur = tail;
257    Ok(head)
258}
259
260fn read_u16(cur: &mut &[u8]) -> Result<u16, CapabilityCodecError> {
261    let bytes = read_slice(cur, 2)?;
262    let arr: [u8; 2] = bytes.try_into().expect("invariant: read_slice(2)");
263    Ok(u16::from_le_bytes(arr))
264}
265
266fn read_u32(cur: &mut &[u8]) -> Result<u32, CapabilityCodecError> {
267    let bytes = read_slice(cur, 4)?;
268    let arr: [u8; 4] = bytes.try_into().expect("invariant: read_slice(4)");
269    Ok(u32::from_le_bytes(arr))
270}
271
272/// Type-erased merge closure stored alongside each registered
273/// capability. Defined as a type alias so the `Slot` definition
274/// stays readable.
275type MergeFn = Box<dyn Fn(&[Vec<u8>]) -> Option<Vec<u8>> + Send + Sync>;
276
277/// Slot stored in the registry for one registered capability.
278pub(crate) struct Slot {
279    /// The original boxed cap, kept type-erased so callers can
280    /// downcast back via [`Capability::decode_value`] inside
281    /// [`CapabilityRegistry::current`].
282    cap: Box<dyn Any + Send + Sync>,
283    /// Locally supported values pre-encoded for fast ad
284    /// generation.
285    supported_bytes: Vec<Vec<u8>>,
286    /// Floor value, pre-encoded. Used when negotiation finds no
287    /// overlap.
288    floor_bytes: Vec<u8>,
289    /// Type-erased merge: takes peer-supplied byte blobs, picks
290    /// the highest common value, returns it pre-encoded.
291    merge: MergeFn,
292}
293
294/// Per-node registry that owns capability instances, generates
295/// the local advertisement, and stores the most recently
296/// negotiated value for each capability.
297pub struct CapabilityRegistry {
298    slots: HashMap<&'static str, Slot>,
299    negotiated: RwLock<HashMap<String, Vec<u8>>>,
300}
301
302impl Default for CapabilityRegistry {
303    fn default() -> Self {
304        Self::new()
305    }
306}
307
308impl std::fmt::Debug for CapabilityRegistry {
309    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310        f.debug_struct("CapabilityRegistry")
311            .field("registered", &self.slots.keys().collect::<Vec<_>>())
312            .finish_non_exhaustive()
313    }
314}
315
316impl CapabilityRegistry {
317    /// Construct an empty registry.
318    #[must_use]
319    pub fn new() -> Self {
320        Self {
321            slots: HashMap::new(),
322            negotiated: RwLock::new(HashMap::new()),
323        }
324    }
325
326    /// Register a capability. Re-registering a capability with
327    /// the same name replaces the previous entry; the cluster
328    /// layer never registers two caps with the same name, so
329    /// this only matters for tests.
330    pub fn register<C: Capability>(&mut self, cap: C) {
331        let name = cap.name();
332        assert!(name.is_ascii(), "capability name must be ASCII: {name:?}");
333        let supported_bytes: Vec<Vec<u8>> = cap
334            .supported_values()
335            .iter()
336            .map(|v| cap.encode_value(v))
337            .collect();
338        // Floor: the first element of supported_values, the
339        // lowest-preferred local value. Documented in the trait
340        // contract.
341        let floor_bytes = supported_bytes
342            .first()
343            .cloned()
344            .expect("capability must declare at least one supported value");
345        // Build a type-erased merge closure that decodes peer
346        // blobs, calls the typed merge, and re-encodes.
347        let cap_arc = std::sync::Arc::new(cap);
348        let cap_for_merge = cap_arc.clone();
349        let merge: MergeFn = Box::new(move |peer_blobs: &[Vec<u8>]| {
350            let peer: Vec<C::Value> = peer_blobs
351                .iter()
352                .filter_map(|b| cap_for_merge.decode_value(b))
353                .collect();
354            cap_for_merge
355                .merge(&peer)
356                .map(|v| cap_for_merge.encode_value(&v))
357        });
358        let cap_any: Box<dyn Any + Send + Sync> = Box::new(cap_arc);
359        self.slots.insert(
360            name,
361            Slot {
362                cap: cap_any,
363                supported_bytes,
364                floor_bytes,
365                merge,
366            },
367        );
368        // Drop any stale negotiated entry for this name so
369        // re-registration starts from the floor.
370        self.negotiated.write().remove(name);
371    }
372
373    /// Build the advertisement to ship to peers.
374    #[must_use]
375    pub fn local_advertise(&self) -> CapabilityAd {
376        let mut entries: Vec<CapabilityAdEntry> = self
377            .slots
378            .iter()
379            .map(|(name, slot)| {
380                CapabilityAdEntry::new((*name).to_string(), slot.supported_bytes.clone())
381            })
382            .collect();
383        // Stable order: alphabetical by name. The HashMap
384        // iteration order is otherwise nondeterministic, which
385        // would make test assertions and gossip diffs flaky.
386        entries.sort_by(|a, b| a.name().cmp(b.name()));
387        CapabilityAd::from_entries(entries)
388    }
389
390    /// Resolve `peer_ad` against the locally registered caps.
391    ///
392    /// Returns a [`NegotiatedCapabilities`] keyed by capability
393    /// name. The registry also caches each negotiated value so
394    /// later calls to [`CapabilityRegistry::current`] reflect the
395    /// most recent negotiation.
396    ///
397    /// Capabilities present in `peer_ad` but not registered
398    /// locally fall through silently: the negotiator can only
399    /// pick a value for capabilities both sides know about.
400    pub fn negotiate(&self, peer_ad: &CapabilityAd) -> NegotiatedCapabilities {
401        let result = negotiate_with_floor(self, peer_ad);
402        // Cache the negotiated bytes so `current()` sees them.
403        let mut neg = self.negotiated.write();
404        for (name, value) in result.iter() {
405            neg.insert(name.clone(), value.clone());
406        }
407        result
408    }
409
410    /// Return the currently active value for the named
411    /// capability, decoded with the registered cap's
412    /// [`Capability::decode_value`].
413    ///
414    /// Returns `None` when no capability with that name is
415    /// registered, when the type parameter `C` does not match
416    /// the registered cap, or when the stored bytes fail to
417    /// decode (which would be a registry bug).
418    ///
419    /// Before any negotiation has happened the floor value
420    /// (lowest-preference local value) is returned.
421    pub fn current<C: Capability>(&self, name: &str) -> Option<C::Value> {
422        let slot = self.slots.get(name)?;
423        let cap_arc = slot.cap.downcast_ref::<std::sync::Arc<C>>()?;
424        let neg = self.negotiated.read();
425        let bytes: &[u8] = neg
426            .get(name)
427            .map_or(slot.floor_bytes.as_slice(), Vec::as_slice);
428        cap_arc.decode_value(bytes)
429    }
430
431    /// Number of registered capabilities. Useful in tests and
432    /// diagnostics.
433    #[must_use]
434    pub fn len(&self) -> usize {
435        self.slots.len()
436    }
437
438    /// True when no capabilities have been registered.
439    #[must_use]
440    pub fn is_empty(&self) -> bool {
441        self.slots.is_empty()
442    }
443
444    /// Internal slot accessor used by the negotiator.
445    pub(crate) fn slots_for_negotiation(&self) -> &HashMap<&'static str, Slot> {
446        &self.slots
447    }
448}
449
450impl Slot {
451    pub(crate) fn floor_bytes(&self) -> &[u8] {
452        &self.floor_bytes
453    }
454
455    pub(crate) fn merge_bytes(&self, peer: &[Vec<u8>]) -> Option<Vec<u8>> {
456        (self.merge)(peer)
457    }
458}
459
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `u32`-valued capability used to drive the registry
    /// in tests.
    struct U32Cap {
        name: &'static str,
        supported: Vec<u32>,
    }

    impl Capability for U32Cap {
        type Value = u32;

        fn name(&self) -> &'static str {
            self.name
        }

        fn supported_values(&self) -> Vec<u32> {
            self.supported.clone()
        }

        fn merge(&self, peer: &[u32]) -> Option<u32> {
            // Largest local value that the peer also lists.
            self.supported
                .iter()
                .copied()
                .filter(|candidate| peer.contains(candidate))
                .max()
        }

        fn encode_value(&self, v: &u32) -> Vec<u8> {
            Vec::from(v.to_le_bytes())
        }

        fn decode_value(&self, b: &[u8]) -> Option<u32> {
            let raw: [u8; 4] = b.try_into().ok()?;
            Some(u32::from_le_bytes(raw))
        }
    }

    #[test]
    fn ad_round_trips() {
        let mut registry = CapabilityRegistry::new();
        let framing = U32Cap {
            name: "framing",
            supported: vec![1, 2],
        };
        let aae = U32Cap {
            name: "aae",
            supported: vec![1],
        };
        registry.register(framing);
        registry.register(aae);

        let advertised = registry.local_advertise();
        let wire = advertised.encode();
        let decoded = CapabilityAd::decode(&wire).expect("decode");
        assert_eq!(decoded, advertised);
    }

    #[test]
    fn ad_decode_rejects_bad_magic() {
        let zeroes = [0u8; 16];
        let outcome = CapabilityAd::decode(&zeroes);
        assert!(matches!(
            outcome.unwrap_err(),
            CapabilityCodecError::BadMagic
        ));
    }

    #[test]
    fn ad_decode_rejects_truncated() {
        let outcome = CapabilityAd::decode(&[]);
        assert!(matches!(
            outcome.unwrap_err(),
            CapabilityCodecError::Truncated
        ));
    }
}
519}