use std::any::Any;
use std::collections::HashMap;
use parking_lot::RwLock;
use crate::cluster::capability::negotiator::{negotiate_with_floor, NegotiatedCapabilities};
/// Leading magic bytes that identify a capability advertisement frame.
const CAP_AD_MAGIC: [u8; 3] = *b"CAP";
/// Wire-format version written by `CapabilityAd::encode`; `CapabilityAd::decode` accepts
/// only this exact value.
const CAP_AD_VERSION: u8 = 1;
/// A negotiable feature: the local node lists the values it supports, advertises them to a
/// peer in encoded form, and merges them with the peer's list to pick an agreed value.
///
/// Implementations are handed to `CapabilityRegistry::register`, which type-erases them;
/// values only cross the wire as the bytes produced by [`Capability::encode_value`].
pub trait Capability: Any + Send + Sync + 'static {
    /// Concrete value type negotiated for this capability.
    type Value: Clone + Eq + Send + Sync + 'static;
    /// Registry key and advertised name; `register` asserts it is ASCII.
    fn name(&self) -> &'static str;
    /// Locally supported values. The registry treats the first element as the floor (the
    /// value reported before anything is negotiated) and panics at registration if the list
    /// is empty.
    fn supported_values(&self) -> Vec<Self::Value>;
    /// Picks the agreed value given the peer's supported values (only those that
    /// `decode_value` accepted), or `None` if no value is acceptable.
    fn merge(&self, peer_supports: &[Self::Value]) -> Option<Self::Value>;
    /// Encodes a value for the advertisement. The registry later decodes these bytes with
    /// `decode_value`, so the pair is expected to round-trip.
    fn encode_value(&self, value: &Self::Value) -> Vec<u8>;
    /// Decodes bytes produced by `encode_value`; returning `None` makes the registry ignore
    /// that peer value when merging.
    fn decode_value(&self, bytes: &[u8]) -> Option<Self::Value>;
}
/// Errors produced while decoding a capability advertisement (see `CapabilityAd::decode`).
#[derive(Debug, thiserror::Error)]
pub enum CapabilityCodecError {
    /// The input ended before a required header field, length prefix, or payload.
    #[error("capability advertisement truncated")]
    Truncated,
    /// The leading magic bytes, or the version byte after them, did not match.
    #[error("capability advertisement: invalid magic or version")]
    BadMagic,
    /// A capability name contained non-ASCII bytes.
    #[error("capability advertisement: non-ASCII capability name")]
    NonAsciiName,
    /// A decoder size limit was exceeded; the payload is the offending count or length.
    /// NOTE(review): despite its name, `decode` also returns this for over-long names and
    /// over-long values, not only for too many entries.
    #[error("capability advertisement: too many entries ({0})")]
    TooManyEntries(usize),
}
/// One advertised capability: its name and the opaque encoded values the sender supports
/// (the registry fills these from `Capability::encode_value`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CapabilityAdEntry {
    name: String,
    supported: Vec<Vec<u8>>,
}

impl CapabilityAdEntry {
    /// Builds an entry from a name and its encoded supported values, stored exactly as given.
    #[must_use]
    pub fn new(name: String, supported: Vec<Vec<u8>>) -> Self {
        CapabilityAdEntry { supported, name }
    }

    /// The advertised capability name.
    #[must_use]
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// The encoded supported values, in the order they were supplied.
    #[must_use]
    pub fn supported(&self) -> &[Vec<u8>] {
        self.supported.as_slice()
    }
}
/// A capability advertisement: the list of entries sent to or received from a peer,
/// serialized with `CapabilityAd::encode` and parsed with `CapabilityAd::decode`.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct CapabilityAd {
    /// Entries in wire order.
    entries: Vec<CapabilityAdEntry>,
}
/// Maximum entry count `CapabilityAd::decode` accepts (`encode` does not enforce it).
const CAP_AD_MAX_ENTRIES: usize = 1024;
/// Maximum length in bytes of a single encoded value accepted by `decode`.
const CAP_AD_MAX_VALUE_LEN: usize = 16 * 1024;
/// Maximum capability-name length in bytes accepted by `decode`.
const CAP_AD_MAX_NAME_LEN: usize = 256;
impl CapabilityAd {
    /// Creates an advertisement with no entries.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Wraps `entries` verbatim; their order is the order written by [`CapabilityAd::encode`].
    #[must_use]
    pub fn from_entries(entries: Vec<CapabilityAdEntry>) -> Self {
        Self { entries }
    }

    /// The entries, in the order they were supplied or decoded.
    #[must_use]
    pub fn entries(&self) -> &[CapabilityAdEntry] {
        &self.entries
    }

    /// Serializes the advertisement. All integers are little-endian:
    ///
    /// ```text
    /// "CAP" | version: u8 | count: u32
    ///   count x ( name_len: u16 | name | val_count: u16
    ///             val_count x ( vlen: u32 | value ) )
    /// ```
    ///
    /// Each length/count prefix saturates at its field width, and exactly as many entries,
    /// values, or bytes as the prefix declares are written. The frame therefore stays
    /// self-consistent even for oversized input: the excess is dropped instead of
    /// desynchronizing every field that follows.
    ///
    /// This does not enforce the decoder's `CAP_AD_MAX_*` limits; an advertisement that
    /// exceeds them encodes fine but is rejected by [`CapabilityAd::decode`].
    #[must_use]
    pub fn encode(&self) -> Vec<u8> {
        let mut out = Vec::with_capacity(8 + self.entries.len() * 32);
        out.extend_from_slice(&CAP_AD_MAGIC);
        out.push(CAP_AD_VERSION);
        let count = u32::try_from(self.entries.len()).unwrap_or(u32::MAX);
        out.extend_from_slice(&count.to_le_bytes());
        let written_entries = usize::try_from(count).unwrap_or(usize::MAX);
        for entry in self.entries.iter().take(written_entries) {
            let name_bytes = entry.name.as_bytes();
            let name_len = u16::try_from(name_bytes.len()).unwrap_or(u16::MAX);
            out.extend_from_slice(&name_len.to_le_bytes());
            out.extend_from_slice(&name_bytes[..usize::from(name_len)]);
            let val_count = u16::try_from(entry.supported.len()).unwrap_or(u16::MAX);
            out.extend_from_slice(&val_count.to_le_bytes());
            for value in entry.supported.iter().take(usize::from(val_count)) {
                let vlen = u32::try_from(value.len()).unwrap_or(u32::MAX);
                out.extend_from_slice(&vlen.to_le_bytes());
                // `vlen <= value.len()`, so the slice is always in bounds.
                let payload = usize::try_from(vlen).map_or(value.as_slice(), |n| &value[..n]);
                out.extend_from_slice(payload);
            }
        }
        out
    }

    /// Parses an advertisement produced by [`CapabilityAd::encode`].
    ///
    /// Bytes after the last declared entry are ignored.
    ///
    /// # Errors
    ///
    /// - [`CapabilityCodecError::Truncated`] if the input ends before a field or payload.
    /// - [`CapabilityCodecError::BadMagic`] if the magic bytes or the version byte mismatch.
    /// - [`CapabilityCodecError::TooManyEntries`] if the entry count, a name length, or a
    ///   value length exceeds its `CAP_AD_MAX_*` limit (the payload is the offending number).
    /// - [`CapabilityCodecError::NonAsciiName`] if a capability name is not ASCII.
    pub fn decode(mut bytes: &[u8]) -> Result<Self, CapabilityCodecError> {
        if bytes.len() < CAP_AD_MAGIC.len() + 1 + 4 {
            return Err(CapabilityCodecError::Truncated);
        }
        if bytes[..CAP_AD_MAGIC.len()] != CAP_AD_MAGIC {
            return Err(CapabilityCodecError::BadMagic);
        }
        bytes = &bytes[CAP_AD_MAGIC.len()..];
        if bytes[0] != CAP_AD_VERSION {
            return Err(CapabilityCodecError::BadMagic);
        }
        bytes = &bytes[1..];
        let count = usize::try_from(read_u32(&mut bytes)?).unwrap_or(usize::MAX);
        if count > CAP_AD_MAX_ENTRIES {
            return Err(CapabilityCodecError::TooManyEntries(count));
        }
        // Every entry occupies at least 4 bytes (name_len + val_count), so never reserve more
        // slots than the remaining input could fill; the declared count is untrusted.
        let mut entries = Vec::with_capacity(count.min(bytes.len() / 4));
        for _ in 0..count {
            entries.push(Self::decode_entry(&mut bytes)?);
        }
        Ok(Self { entries })
    }

    /// Decodes one `name_len | name | val_count | (vlen | value)*` record, advancing `bytes`
    /// past it.
    fn decode_entry(bytes: &mut &[u8]) -> Result<CapabilityAdEntry, CapabilityCodecError> {
        let name_len = usize::from(read_u16(bytes)?);
        if name_len > CAP_AD_MAX_NAME_LEN {
            // There is no dedicated variant; the length is reported via TooManyEntries.
            return Err(CapabilityCodecError::TooManyEntries(name_len));
        }
        let name_bytes = read_slice(bytes, name_len)?;
        if !name_bytes.is_ascii() {
            return Err(CapabilityCodecError::NonAsciiName);
        }
        let name = std::str::from_utf8(name_bytes)
            .map_err(|_| CapabilityCodecError::NonAsciiName)?
            .to_string();
        let val_count = usize::from(read_u16(bytes)?);
        // Each value needs at least its 4-byte length prefix; cap the reservation so a tiny
        // hostile frame cannot force a large up-front allocation.
        let mut supported = Vec::with_capacity(val_count.min(bytes.len() / 4));
        for _ in 0..val_count {
            let vlen = usize::try_from(read_u32(bytes)?).unwrap_or(usize::MAX);
            if vlen > CAP_AD_MAX_VALUE_LEN {
                return Err(CapabilityCodecError::TooManyEntries(vlen));
            }
            supported.push(read_slice(bytes, vlen)?.to_vec());
        }
        Ok(CapabilityAdEntry::new(name, supported))
    }
}
/// Splits the first `len` bytes off the front of `*cur` and advances the cursor past them.
///
/// On failure (`len` exceeds what remains) the cursor is left untouched and
/// [`CapabilityCodecError::Truncated`] is returned.
fn read_slice<'a>(cur: &mut &'a [u8], len: usize) -> Result<&'a [u8], CapabilityCodecError> {
    let remaining: &'a [u8] = *cur;
    match remaining.get(..len) {
        Some(head) => {
            *cur = &remaining[len..];
            Ok(head)
        }
        None => Err(CapabilityCodecError::Truncated),
    }
}
/// Reads a little-endian `u16` from the front of `*cur`, advancing it by two bytes.
fn read_u16(cur: &mut &[u8]) -> Result<u16, CapabilityCodecError> {
    read_slice(cur, 2).map(|raw| {
        let le: [u8; 2] = raw.try_into().expect("invariant: read_slice(2)");
        u16::from_le_bytes(le)
    })
}
/// Reads a little-endian `u32` from the front of `*cur`, advancing it by four bytes.
fn read_u32(cur: &mut &[u8]) -> Result<u32, CapabilityCodecError> {
    read_slice(cur, 4).map(|raw| {
        let le: [u8; 4] = raw.try_into().expect("invariant: read_slice(4)");
        u32::from_le_bytes(le)
    })
}
/// Type-erased merge for one registered capability: takes the peer's encoded supported
/// values and returns the encoded agreed value, or `None` if there is none.
type MergeFn = Box<dyn Fn(&[Vec<u8>]) -> Option<Vec<u8>> + Send + Sync>;
/// Registry record for one capability: the type-erased capability plus encodings computed
/// once at registration.
pub(crate) struct Slot {
    /// The capability as `Arc<C>` boxed into `Any`; `CapabilityRegistry::current` downcasts it.
    cap: Box<dyn Any + Send + Sync>,
    /// Encoded `supported_values()`, in the capability's declared order.
    supported_bytes: Vec<Vec<u8>>,
    /// Encoding of the first supported value; used when no negotiated value is recorded.
    floor_bytes: Vec<u8>,
    /// Decodes peer values, calls `Capability::merge`, and re-encodes the result.
    merge: MergeFn,
}
/// Registry of locally supported capabilities and the values most recently negotiated for them.
pub struct CapabilityRegistry {
    /// Registered capabilities keyed by `Capability::name`.
    slots: HashMap<&'static str, Slot>,
    /// Encoded values recorded by `negotiate`; a missing name means "use the slot's floor".
    /// Kept behind a lock because `negotiate` updates it through `&self`.
    negotiated: RwLock<HashMap<String, Vec<u8>>>,
}
impl Default for CapabilityRegistry {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for CapabilityRegistry {
    /// Shows the registered capability names. They are sorted so the output is deterministic
    /// (raw `HashMap` iteration order is not), matching the order used by `local_advertise`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut registered: Vec<&str> = self.slots.keys().copied().collect();
        registered.sort_unstable();
        f.debug_struct("CapabilityRegistry")
            .field("registered", &registered)
            .finish_non_exhaustive()
    }
}
impl CapabilityRegistry {
    /// Creates a registry with no capabilities and no negotiated values.
    #[must_use]
    pub fn new() -> Self {
        Self {
            slots: HashMap::new(),
            negotiated: RwLock::new(HashMap::new()),
        }
    }

    /// Registers `cap` under [`Capability::name`].
    ///
    /// The capability's supported values are encoded once, here. The first one becomes the
    /// floor that [`CapabilityRegistry::current`] reports until a negotiated value is
    /// recorded. Re-registering a name replaces the previous capability and discards that
    /// name's negotiated value.
    ///
    /// # Panics
    ///
    /// Panics if the capability name is not ASCII, or if `supported_values()` is empty.
    pub fn register<C: Capability>(&mut self, cap: C) {
        let name = cap.name();
        assert!(name.is_ascii(), "capability name must be ASCII: {name:?}");
        let supported_bytes: Vec<Vec<u8>> = cap
            .supported_values()
            .iter()
            .map(|v| cap.encode_value(v))
            .collect();
        let floor_bytes = supported_bytes
            .first()
            .cloned()
            .expect("capability must declare at least one supported value");
        let cap_arc = std::sync::Arc::new(cap);
        // The merge closure and the `Any` slot share one instance of the capability.
        let cap_for_merge = std::sync::Arc::clone(&cap_arc);
        let merge: MergeFn = Box::new(move |peer_blobs: &[Vec<u8>]| {
            // Peer values this side cannot decode are skipped rather than failing the merge.
            let peer: Vec<C::Value> = peer_blobs
                .iter()
                .filter_map(|b| cap_for_merge.decode_value(b))
                .collect();
            cap_for_merge
                .merge(&peer)
                .map(|v| cap_for_merge.encode_value(&v))
        });
        // Stored as `Arc<C>` so that `current::<C>` can downcast it back.
        let cap_any: Box<dyn Any + Send + Sync> = Box::new(cap_arc);
        self.slots.insert(
            name,
            Slot {
                cap: cap_any,
                supported_bytes,
                floor_bytes,
                merge,
            },
        );
        // `&mut self` already guarantees exclusive access, so no runtime locking is needed.
        self.negotiated.get_mut().remove(name);
    }

    /// Builds this node's advertisement: one entry per registered capability carrying its
    /// encoded supported values. Entries are sorted by name so the output does not depend on
    /// `HashMap` iteration order.
    #[must_use]
    pub fn local_advertise(&self) -> CapabilityAd {
        let mut entries: Vec<CapabilityAdEntry> = self
            .slots
            .iter()
            .map(|(name, slot)| {
                CapabilityAdEntry::new((*name).to_string(), slot.supported_bytes.clone())
            })
            .collect();
        // Names are unique map keys, so an unstable sort yields the same order as a stable one.
        entries.sort_unstable_by(|a, b| a.name().cmp(b.name()));
        CapabilityAd::from_entries(entries)
    }

    /// Negotiates against `peer_ad` via `negotiate_with_floor` and records every resulting
    /// value for later [`CapabilityRegistry::current`] lookups. Returns the result.
    ///
    /// Names absent from the result keep whatever value was recorded previously.
    pub fn negotiate(&self, peer_ad: &CapabilityAd) -> NegotiatedCapabilities {
        let result = negotiate_with_floor(self, peer_ad);
        let mut neg = self.negotiated.write();
        for (name, value) in result.iter() {
            neg.insert(name.clone(), value.clone());
        }
        result
    }

    /// Returns capability `name`'s current value decoded as `C::Value`: the recorded
    /// negotiated value if there is one, otherwise the floor.
    ///
    /// Returns `None` if `name` is not registered, if it was registered with a type other
    /// than `C`, or if the stored bytes fail `C::decode_value`.
    pub fn current<C: Capability>(&self, name: &str) -> Option<C::Value> {
        let slot = self.slots.get(name)?;
        let cap_arc = slot.cap.downcast_ref::<std::sync::Arc<C>>()?;
        let neg = self.negotiated.read();
        let bytes: &[u8] = neg
            .get(name)
            .map_or(slot.floor_bytes.as_slice(), Vec::as_slice);
        cap_arc.decode_value(bytes)
    }

    /// Number of registered capabilities.
    #[must_use]
    pub fn len(&self) -> usize {
        self.slots.len()
    }

    /// `true` when no capability is registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.slots.is_empty()
    }

    /// Raw slot map, exposed crate-internally for the negotiation logic.
    pub(crate) fn slots_for_negotiation(&self) -> &HashMap<&'static str, Slot> {
        &self.slots
    }
}
impl Slot {
    /// Encoded floor value: the encoding of the first value the capability declared as
    /// supported when it was registered.
    pub(crate) fn floor_bytes(&self) -> &[u8] {
        self.floor_bytes.as_slice()
    }

    /// Runs this slot's type-erased merge over the peer's encoded supported values.
    /// Returns the encoded agreed value, or `None` when the capability's `merge` finds none.
    pub(crate) fn merge_bytes(&self, peer: &[Vec<u8>]) -> Option<Vec<u8>> {
        let merge_fn = &self.merge;
        merge_fn(peer)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test capability over `u32` values; `merge` picks the highest value both sides support.
    struct U32Cap {
        name: &'static str,
        supported: Vec<u32>,
    }

    impl Capability for U32Cap {
        type Value = u32;
        fn name(&self) -> &'static str {
            self.name
        }
        fn supported_values(&self) -> Vec<u32> {
            self.supported.clone()
        }
        fn merge(&self, peer: &[u32]) -> Option<u32> {
            self.supported
                .iter()
                .filter(|v| peer.contains(v))
                .max()
                .copied()
        }
        fn encode_value(&self, v: &u32) -> Vec<u8> {
            v.to_le_bytes().to_vec()
        }
        fn decode_value(&self, b: &[u8]) -> Option<u32> {
            <[u8; 4]>::try_from(b).ok().map(u32::from_le_bytes)
        }
    }

    /// Builds a valid frame header (magic, current version, entry count) for hand-rolled input.
    fn header(count: u32) -> Vec<u8> {
        let mut out = CAP_AD_MAGIC.to_vec();
        out.push(CAP_AD_VERSION);
        out.extend_from_slice(&count.to_le_bytes());
        out
    }

    #[test]
    fn ad_round_trips() {
        let mut reg = CapabilityRegistry::new();
        reg.register(U32Cap {
            name: "framing",
            supported: vec![1, 2],
        });
        reg.register(U32Cap {
            name: "aae",
            supported: vec![1],
        });
        let ad = reg.local_advertise();
        let bytes = ad.encode();
        let back = CapabilityAd::decode(&bytes).expect("decode");
        assert_eq!(back, ad);
    }

    #[test]
    fn empty_ad_round_trips() {
        let ad = CapabilityAd::new();
        let bytes = ad.encode();
        assert_eq!(bytes, header(0));
        assert_eq!(CapabilityAd::decode(&bytes).expect("decode"), ad);
    }

    #[test]
    fn ad_decode_rejects_bad_magic() {
        let err = CapabilityAd::decode(&[0; 16]).unwrap_err();
        assert!(matches!(err, CapabilityCodecError::BadMagic));
    }

    #[test]
    fn ad_decode_rejects_bad_version() {
        let mut bytes = header(0);
        bytes[CAP_AD_MAGIC.len()] = CAP_AD_VERSION.wrapping_add(1);
        let err = CapabilityAd::decode(&bytes).unwrap_err();
        assert!(matches!(err, CapabilityCodecError::BadMagic));
    }

    #[test]
    fn ad_decode_rejects_truncated() {
        let err = CapabilityAd::decode(&[]).unwrap_err();
        assert!(matches!(err, CapabilityCodecError::Truncated));
        // Header claims one entry but no entry bytes follow.
        let err = CapabilityAd::decode(&header(1)).unwrap_err();
        assert!(matches!(err, CapabilityCodecError::Truncated));
    }

    #[test]
    fn ad_decode_rejects_too_many_entries() {
        let too_many = u32::try_from(CAP_AD_MAX_ENTRIES + 1).expect("limit fits in u32");
        let err = CapabilityAd::decode(&header(too_many)).unwrap_err();
        assert!(
            matches!(err, CapabilityCodecError::TooManyEntries(n) if n == CAP_AD_MAX_ENTRIES + 1)
        );
    }

    #[test]
    fn ad_decode_rejects_oversized_name() {
        let mut bytes = header(1);
        bytes.extend_from_slice(&257u16.to_le_bytes());
        let err = CapabilityAd::decode(&bytes).unwrap_err();
        assert!(matches!(err, CapabilityCodecError::TooManyEntries(257)));
    }

    #[test]
    fn ad_decode_rejects_non_ascii_name() {
        let mut bytes = header(1);
        bytes.extend_from_slice(&2u16.to_le_bytes());
        bytes.extend_from_slice(&[0xC3, 0xA9]); // UTF-8 "é"
        bytes.extend_from_slice(&0u16.to_le_bytes());
        let err = CapabilityAd::decode(&bytes).unwrap_err();
        assert!(matches!(err, CapabilityCodecError::NonAsciiName));
    }

    #[test]
    fn current_falls_back_to_floor_before_negotiation() {
        let mut reg = CapabilityRegistry::new();
        assert!(reg.is_empty());
        reg.register(U32Cap {
            name: "framing",
            supported: vec![1, 2],
        });
        assert_eq!(reg.len(), 1);
        assert_eq!(reg.current::<U32Cap>("framing"), Some(1));
        assert_eq!(reg.current::<U32Cap>("missing"), None);
    }
}