dynomite/cluster/capability/
registry.rs1use std::any::Any;
10use std::collections::HashMap;
11
12use parking_lot::RwLock;
13
14use crate::cluster::capability::negotiator::{negotiate_with_floor, NegotiatedCapabilities};
15
/// Leading bytes that identify an encoded [`CapabilityAd`].
const CAP_AD_MAGIC: [u8; 3] = [b'C', b'A', b'P'];

/// Wire-format version byte written right after [`CAP_AD_MAGIC`]; the
/// decoder rejects any other value.
const CAP_AD_VERSION: u8 = 1;
24
/// A negotiable feature that a node advertises to its peers.
///
/// An implementor declares the values it understands, how to choose one value
/// given what a peer advertised, and how values are serialized on the wire.
pub trait Capability: Any + Send + Sync + 'static {
    /// In-memory representation of a single capability value.
    type Value: Clone + Eq + Send + Sync + 'static;

    /// Identifier used as this capability's key in advertisements.
    ///
    /// The registry requires this to be ASCII and panics on registration
    /// otherwise.
    fn name(&self) -> &'static str;

    /// Every value this node can speak.
    ///
    /// Must be non-empty: the registry uses the first element as the floor
    /// value reported until a negotiated value is recorded.
    fn supported_values(&self) -> Vec<Self::Value>;

    /// Picks the value to use given the values a peer advertised.
    ///
    /// `peer_supports` only contains the peer values that
    /// [`Capability::decode_value`] managed to decode; undecodable ones are
    /// dropped beforehand. May return `None` when no value can be chosen.
    fn merge(&self, peer_supports: &[Self::Value]) -> Option<Self::Value>;

    /// Serializes one value into its wire bytes.
    fn encode_value(&self, value: &Self::Value) -> Vec<u8>;

    /// Parses wire bytes back into a value, or `None` if `bytes` is not a
    /// valid encoding.
    fn decode_value(&self, bytes: &[u8]) -> Option<Self::Value>;
}
79
80#[derive(Debug, thiserror::Error)]
82pub enum CapabilityCodecError {
83 #[error("capability advertisement truncated")]
85 Truncated,
86 #[error("capability advertisement: invalid magic or version")]
88 BadMagic,
89 #[error("capability advertisement: non-ASCII capability name")]
91 NonAsciiName,
92 #[error("capability advertisement: too many entries ({0})")]
94 TooManyEntries(usize),
95}
96
/// One capability's part of an advertisement: its name plus every value the
/// sender supports, each already encoded to bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CapabilityAdEntry {
    name: String,
    supported: Vec<Vec<u8>>,
}

impl CapabilityAdEntry {
    /// Builds an entry from a capability name and its encoded supported
    /// values.
    #[must_use]
    pub fn new(name: String, supported: Vec<Vec<u8>>) -> Self {
        CapabilityAdEntry { name, supported }
    }

    /// Name of the capability this entry describes.
    #[must_use]
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Encoded supported values, in the order they were supplied.
    #[must_use]
    pub fn supported(&self) -> &[Vec<u8>] {
        self.supported.as_slice()
    }
}
126
127#[derive(Clone, Debug, Default, Eq, PartialEq)]
129pub struct CapabilityAd {
130 entries: Vec<CapabilityAdEntry>,
131}
132
/// Largest entry count the decoder accepts in one advertisement.
const CAP_AD_MAX_ENTRIES: usize = 1 << 10;

/// Largest encoded value, in bytes, the decoder accepts (16 KiB).
const CAP_AD_MAX_VALUE_LEN: usize = 16 << 10;

/// Largest capability name, in bytes, the decoder accepts.
const CAP_AD_MAX_NAME_LEN: usize = 1 << 8;
143
144impl CapabilityAd {
145 #[must_use]
147 pub fn new() -> Self {
148 Self::default()
149 }
150
151 #[must_use]
153 pub fn from_entries(entries: Vec<CapabilityAdEntry>) -> Self {
154 Self { entries }
155 }
156
157 #[must_use]
159 pub fn entries(&self) -> &[CapabilityAdEntry] {
160 &self.entries
161 }
162
163 #[must_use]
179 pub fn encode(&self) -> Vec<u8> {
180 let mut out = Vec::with_capacity(8 + self.entries.len() * 32);
181 out.extend_from_slice(&CAP_AD_MAGIC);
182 out.push(CAP_AD_VERSION);
183 let count = u32::try_from(self.entries.len()).unwrap_or(u32::MAX);
184 out.extend_from_slice(&count.to_le_bytes());
185 for entry in &self.entries {
186 let name_bytes = entry.name.as_bytes();
187 let name_len = u16::try_from(name_bytes.len()).unwrap_or(u16::MAX);
188 out.extend_from_slice(&name_len.to_le_bytes());
189 out.extend_from_slice(name_bytes);
190 let val_count = u16::try_from(entry.supported.len()).unwrap_or(u16::MAX);
191 out.extend_from_slice(&val_count.to_le_bytes());
192 for value in &entry.supported {
193 let vlen = u32::try_from(value.len()).unwrap_or(u32::MAX);
194 out.extend_from_slice(&vlen.to_le_bytes());
195 out.extend_from_slice(value);
196 }
197 }
198 out
199 }
200
201 pub fn decode(mut bytes: &[u8]) -> Result<Self, CapabilityCodecError> {
204 if bytes.len() < CAP_AD_MAGIC.len() + 1 + 4 {
205 return Err(CapabilityCodecError::Truncated);
206 }
207 if bytes[..CAP_AD_MAGIC.len()] != CAP_AD_MAGIC {
208 return Err(CapabilityCodecError::BadMagic);
209 }
210 bytes = &bytes[CAP_AD_MAGIC.len()..];
211 if bytes[0] != CAP_AD_VERSION {
212 return Err(CapabilityCodecError::BadMagic);
213 }
214 bytes = &bytes[1..];
215 let count = read_u32(&mut bytes)?;
216 let count_us = usize::try_from(count).unwrap_or(usize::MAX);
217 if count_us > CAP_AD_MAX_ENTRIES {
218 return Err(CapabilityCodecError::TooManyEntries(count_us));
219 }
220 let mut entries = Vec::with_capacity(count_us);
221 for _ in 0..count_us {
222 let name_len = read_u16(&mut bytes)? as usize;
223 if name_len > CAP_AD_MAX_NAME_LEN {
224 return Err(CapabilityCodecError::TooManyEntries(name_len));
225 }
226 let name_bytes = read_slice(&mut bytes, name_len)?;
227 if !name_bytes.is_ascii() {
228 return Err(CapabilityCodecError::NonAsciiName);
229 }
230 let name = std::str::from_utf8(name_bytes)
233 .map_err(|_| CapabilityCodecError::NonAsciiName)?
234 .to_string();
235 let val_count = read_u16(&mut bytes)? as usize;
236 let mut supported = Vec::with_capacity(val_count);
237 for _ in 0..val_count {
238 let vlen = read_u32(&mut bytes)? as usize;
239 if vlen > CAP_AD_MAX_VALUE_LEN {
240 return Err(CapabilityCodecError::TooManyEntries(vlen));
241 }
242 let vbytes = read_slice(&mut bytes, vlen)?;
243 supported.push(vbytes.to_vec());
244 }
245 entries.push(CapabilityAdEntry::new(name, supported));
246 }
247 Ok(Self { entries })
248 }
249}
250
251fn read_slice<'a>(cur: &mut &'a [u8], len: usize) -> Result<&'a [u8], CapabilityCodecError> {
252 if cur.len() < len {
253 return Err(CapabilityCodecError::Truncated);
254 }
255 let (head, tail) = cur.split_at(len);
256 *cur = tail;
257 Ok(head)
258}
259
260fn read_u16(cur: &mut &[u8]) -> Result<u16, CapabilityCodecError> {
261 let bytes = read_slice(cur, 2)?;
262 let arr: [u8; 2] = bytes.try_into().expect("invariant: read_slice(2)");
263 Ok(u16::from_le_bytes(arr))
264}
265
266fn read_u32(cur: &mut &[u8]) -> Result<u32, CapabilityCodecError> {
267 let bytes = read_slice(cur, 4)?;
268 let arr: [u8; 4] = bytes.try_into().expect("invariant: read_slice(4)");
269 Ok(u32::from_le_bytes(arr))
270}
271
/// Type-erased merge step for one capability. It takes the peer's encoded
/// values and returns the encoded merged value, or `None`.
type MergeFn = Box<dyn Fn(&[Vec<u8>]) -> Option<Vec<u8>> + Send + Sync>;
276
277pub(crate) struct Slot {
279 cap: Box<dyn Any + Send + Sync>,
283 supported_bytes: Vec<Vec<u8>>,
286 floor_bytes: Vec<u8>,
289 merge: MergeFn,
292}
293
294pub struct CapabilityRegistry {
298 slots: HashMap<&'static str, Slot>,
299 negotiated: RwLock<HashMap<String, Vec<u8>>>,
300}
301
302impl Default for CapabilityRegistry {
303 fn default() -> Self {
304 Self::new()
305 }
306}
307
308impl std::fmt::Debug for CapabilityRegistry {
309 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310 f.debug_struct("CapabilityRegistry")
311 .field("registered", &self.slots.keys().collect::<Vec<_>>())
312 .finish_non_exhaustive()
313 }
314}
315
316impl CapabilityRegistry {
317 #[must_use]
319 pub fn new() -> Self {
320 Self {
321 slots: HashMap::new(),
322 negotiated: RwLock::new(HashMap::new()),
323 }
324 }
325
326 pub fn register<C: Capability>(&mut self, cap: C) {
331 let name = cap.name();
332 assert!(name.is_ascii(), "capability name must be ASCII: {name:?}");
333 let supported_bytes: Vec<Vec<u8>> = cap
334 .supported_values()
335 .iter()
336 .map(|v| cap.encode_value(v))
337 .collect();
338 let floor_bytes = supported_bytes
342 .first()
343 .cloned()
344 .expect("capability must declare at least one supported value");
345 let cap_arc = std::sync::Arc::new(cap);
348 let cap_for_merge = cap_arc.clone();
349 let merge: MergeFn = Box::new(move |peer_blobs: &[Vec<u8>]| {
350 let peer: Vec<C::Value> = peer_blobs
351 .iter()
352 .filter_map(|b| cap_for_merge.decode_value(b))
353 .collect();
354 cap_for_merge
355 .merge(&peer)
356 .map(|v| cap_for_merge.encode_value(&v))
357 });
358 let cap_any: Box<dyn Any + Send + Sync> = Box::new(cap_arc);
359 self.slots.insert(
360 name,
361 Slot {
362 cap: cap_any,
363 supported_bytes,
364 floor_bytes,
365 merge,
366 },
367 );
368 self.negotiated.write().remove(name);
371 }
372
373 #[must_use]
375 pub fn local_advertise(&self) -> CapabilityAd {
376 let mut entries: Vec<CapabilityAdEntry> = self
377 .slots
378 .iter()
379 .map(|(name, slot)| {
380 CapabilityAdEntry::new((*name).to_string(), slot.supported_bytes.clone())
381 })
382 .collect();
383 entries.sort_by(|a, b| a.name().cmp(b.name()));
387 CapabilityAd::from_entries(entries)
388 }
389
390 pub fn negotiate(&self, peer_ad: &CapabilityAd) -> NegotiatedCapabilities {
401 let result = negotiate_with_floor(self, peer_ad);
402 let mut neg = self.negotiated.write();
404 for (name, value) in result.iter() {
405 neg.insert(name.clone(), value.clone());
406 }
407 result
408 }
409
410 pub fn current<C: Capability>(&self, name: &str) -> Option<C::Value> {
422 let slot = self.slots.get(name)?;
423 let cap_arc = slot.cap.downcast_ref::<std::sync::Arc<C>>()?;
424 let neg = self.negotiated.read();
425 let bytes: &[u8] = neg
426 .get(name)
427 .map_or(slot.floor_bytes.as_slice(), Vec::as_slice);
428 cap_arc.decode_value(bytes)
429 }
430
431 #[must_use]
434 pub fn len(&self) -> usize {
435 self.slots.len()
436 }
437
438 #[must_use]
440 pub fn is_empty(&self) -> bool {
441 self.slots.is_empty()
442 }
443
444 pub(crate) fn slots_for_negotiation(&self) -> &HashMap<&'static str, Slot> {
446 &self.slots
447 }
448}
449
450impl Slot {
451 pub(crate) fn floor_bytes(&self) -> &[u8] {
452 &self.floor_bytes
453 }
454
455 pub(crate) fn merge_bytes(&self, peer: &[Vec<u8>]) -> Option<Vec<u8>> {
456 (self.merge)(peer)
457 }
458}
459
#[cfg(test)]
mod tests {
    use super::*;

    /// Test capability with little-endian `u32` values. `merge` picks the
    /// largest locally supported value that the peer also supports.
    struct U32Cap {
        name: &'static str,
        supported: Vec<u32>,
    }
    impl Capability for U32Cap {
        type Value = u32;
        fn name(&self) -> &'static str {
            self.name
        }
        fn supported_values(&self) -> Vec<u32> {
            self.supported.clone()
        }
        fn merge(&self, peer: &[u32]) -> Option<u32> {
            self.supported
                .iter()
                .filter(|v| peer.contains(v))
                .max()
                .copied()
        }
        fn encode_value(&self, v: &u32) -> Vec<u8> {
            v.to_le_bytes().to_vec()
        }
        fn decode_value(&self, b: &[u8]) -> Option<u32> {
            <[u8; 4]>::try_from(b).ok().map(u32::from_le_bytes)
        }
    }

    /// Registry holding the two capabilities shared by several tests.
    fn sample_registry() -> CapabilityRegistry {
        let mut reg = CapabilityRegistry::new();
        reg.register(U32Cap {
            name: "framing",
            supported: vec![1, 2],
        });
        reg.register(U32Cap {
            name: "aae",
            supported: vec![1],
        });
        reg
    }

    /// Overwrites the entry-count field of an encoded empty advertisement.
    fn header_with_count(count: u32) -> Vec<u8> {
        let mut bytes = CapabilityAd::new().encode();
        bytes[CAP_AD_MAGIC.len() + 1..].copy_from_slice(&count.to_le_bytes());
        bytes
    }

    #[test]
    fn ad_round_trips() {
        let ad = sample_registry().local_advertise();
        let bytes = ad.encode();
        let back = CapabilityAd::decode(&bytes).expect("decode");
        assert_eq!(back, ad);
    }

    #[test]
    fn empty_ad_round_trips() {
        let ad = CapabilityAd::new();
        let back = CapabilityAd::decode(&ad.encode()).expect("decode");
        assert_eq!(back, ad);
        assert!(back.entries().is_empty());
    }

    #[test]
    fn local_advertise_is_sorted_by_name() {
        let ad = sample_registry().local_advertise();
        let names: Vec<&str> = ad.entries().iter().map(CapabilityAdEntry::name).collect();
        assert_eq!(names, ["aae", "framing"]);
    }

    #[test]
    fn ad_decode_rejects_bad_magic() {
        let err = CapabilityAd::decode(&[0; 16]).unwrap_err();
        assert!(matches!(err, CapabilityCodecError::BadMagic));
    }

    #[test]
    fn ad_decode_rejects_bad_version() {
        let mut bytes = CapabilityAd::new().encode();
        bytes[CAP_AD_MAGIC.len()] = CAP_AD_VERSION.wrapping_add(1);
        let err = CapabilityAd::decode(&bytes).unwrap_err();
        assert!(matches!(err, CapabilityCodecError::BadMagic));
    }

    #[test]
    fn ad_decode_rejects_truncated() {
        let err = CapabilityAd::decode(&[]).unwrap_err();
        assert!(matches!(err, CapabilityCodecError::Truncated));
    }

    #[test]
    fn ad_decode_rejects_missing_entry_body() {
        // The header declares one entry, but no entry bytes follow.
        let err = CapabilityAd::decode(&header_with_count(1)).unwrap_err();
        assert!(matches!(err, CapabilityCodecError::Truncated));
    }

    #[test]
    fn ad_decode_rejects_too_many_entries() {
        let over = CAP_AD_MAX_ENTRIES + 1;
        let bytes = header_with_count(u32::try_from(over).expect("fits in u32"));
        let err = CapabilityAd::decode(&bytes).unwrap_err();
        assert!(matches!(err, CapabilityCodecError::TooManyEntries(n) if n == over));
    }

    #[test]
    fn ad_decode_rejects_oversized_name() {
        let long = "x".repeat(CAP_AD_MAX_NAME_LEN + 1);
        let bytes = CapabilityAd::from_entries(vec![CapabilityAdEntry::new(long, vec![])]).encode();
        assert!(CapabilityAd::decode(&bytes).is_err());
    }

    #[test]
    fn ad_decode_rejects_non_ascii_name() {
        let entry = CapabilityAdEntry::new("caf\u{e9}".to_string(), vec![]);
        let bytes = CapabilityAd::from_entries(vec![entry]).encode();
        let err = CapabilityAd::decode(&bytes).unwrap_err();
        assert!(matches!(err, CapabilityCodecError::NonAsciiName));
    }

    #[test]
    fn current_falls_back_to_floor_before_negotiation() {
        let reg = sample_registry();
        assert_eq!(reg.current::<U32Cap>("framing"), Some(1));
        assert_eq!(reg.current::<U32Cap>("aae"), Some(1));
        assert_eq!(reg.current::<U32Cap>("missing"), None);
    }

    #[test]
    fn reregistering_replaces_slot() {
        let mut reg = sample_registry();
        reg.register(U32Cap {
            name: "framing",
            supported: vec![3],
        });
        assert_eq!(reg.len(), 2);
        assert!(!reg.is_empty());
        assert_eq!(reg.current::<U32Cap>("framing"), Some(3));
        assert!(CapabilityRegistry::new().is_empty());
    }
}