#![allow(clippy::too_long_first_doc_paragraph, clippy::missing_const_for_fn)]
use std::collections::BTreeSet;
use std::fmt;
use std::str::FromStr;
/// Version of the wire protocol implemented by this module.
///
/// Treat this as frozen: it is pinned by the test suite, and changing it is a
/// wire-compatibility decision rather than a refactor.
pub const PROTOCOL_VERSION: u32 = 1;

/// Header name `mnem-protocol`.
///
/// NOTE(review): presumably the header that carries [`PROTOCOL_VERSION`]; confirm against
/// the transport layer.
pub const PROTOCOL_HEADER: &str = "mnem-protocol";

/// Header name `mnem-capabilities`.
///
/// NOTE(review): presumably the header that carries the comma-separated capability list;
/// confirm against the transport layer.
pub const CAPABILITIES_HEADER: &str = "mnem-capabilities";
/// An optional protocol feature, identified on the wire by a stable kebab-case token.
///
/// The declaration order of the variants below defines the derived `Ord`, and therefore
/// the iteration order of any `BTreeSet<Capability>`. Reordering variants changes that
/// order (the wire form is sorted separately, by token).
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[non_exhaustive]
pub enum Capability {
    /// Wire token `have-set-bloom`.
    HaveSetBloom,
    /// Wire token `have-set-rbsr`.
    HaveSetRbsr,
    /// Wire token `push-negotiate`.
    PushNegotiate,
    /// Wire token `filter-spec`.
    FilterSpec,
    /// Wire token `atomic-push`.
    AtomicPush,
    /// Wire token `signed-push`.
    SignedPush,
    /// Wire token `self-certifying-repo-id`.
    SelfCertifyingRepoId,
}

impl Capability {
    /// Returns the stable wire token for this capability.
    ///
    /// The `match` is exhaustive, so a new variant without a token fails to compile.
    #[must_use]
    pub const fn as_wire_str(&self) -> &'static str {
        match *self {
            Self::AtomicPush => "atomic-push",
            Self::FilterSpec => "filter-spec",
            Self::HaveSetBloom => "have-set-bloom",
            Self::HaveSetRbsr => "have-set-rbsr",
            Self::PushNegotiate => "push-negotiate",
            Self::SelfCertifyingRepoId => "self-certifying-repo-id",
            Self::SignedPush => "signed-push",
        }
    }

    /// Returns every known capability, listed in ascending wire-token order.
    ///
    /// NOTE: this list is maintained by hand. Unlike [`Capability::as_wire_str`], the
    /// compiler will not flag a variant that is missing here.
    #[must_use]
    pub fn all() -> &'static [Self] {
        // An array of unit variants is a constant expression, so this borrow is promoted
        // to `'static` without needing a named `const` item.
        &[
            Self::AtomicPush,
            Self::FilterSpec,
            Self::HaveSetBloom,
            Self::HaveSetRbsr,
            Self::PushNegotiate,
            Self::SelfCertifyingRepoId,
            Self::SignedPush,
        ]
    }
}
impl fmt::Display for Capability {
    /// Writes the capability's wire token verbatim.
    ///
    /// Uses `write_str`, so width/fill/alignment flags on the formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let token = Self::as_wire_str(self);
        f.write_str(token)
    }
}
impl FromStr for Capability {
    type Err = UnknownCapability;

    /// Parses one wire token (e.g. `"atomic-push"`) into a [`Capability`].
    ///
    /// Matching is exact and case-sensitive, and no whitespace trimming is performed here.
    ///
    /// # Errors
    ///
    /// Returns [`UnknownCapability`] holding `s` verbatim when it is not the wire token of
    /// any capability listed by [`Capability::all`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Derive the reverse lookup from `as_wire_str` instead of repeating every literal
        // here. This keeps a single source of truth for the token strings, so the forward
        // and reverse mappings cannot silently drift apart.
        Self::all()
            .iter()
            .copied()
            .find(|cap| cap.as_wire_str() == s)
            .ok_or_else(|| UnknownCapability(s.to_owned()))
    }
}
impl serde::Serialize for Capability {
    /// Serializes the capability as a bare string holding its wire token
    /// (e.g. `"have-set-bloom"`).
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let token = self.as_wire_str();
        serializer.serialize_str(token)
    }
}
impl<'de> serde::Deserialize<'de> for Capability {
fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let s = <&str as serde::Deserialize>::deserialize(d)?;
s.parse().map_err(serde::de::Error::custom)
}
}
#[derive(Debug, thiserror::Error)]
#[error("unknown capability: {0}")]
pub struct UnknownCapability(pub String);
/// Parses a comma-separated list of capability wire tokens into a set.
///
/// Parsing is lenient and never fails:
/// - surrounding whitespace is trimmed from each token;
/// - empty entries (e.g. from `",,"`) are skipped;
/// - tokens that are not a known [`Capability`] are silently dropped.
///
/// NOTE(review): presumably lenient so that a peer advertising newer capabilities does
/// not break parsing; confirm that is the intended contract.
#[must_use]
pub fn parse_capabilities(s: &str) -> BTreeSet<Capability> {
    let mut found = BTreeSet::new();
    for raw in s.split(',') {
        let token = raw.trim();
        if token.is_empty() {
            continue;
        }
        if let Ok(cap) = token.parse::<Capability>() {
            found.insert(cap);
        }
    }
    found
}
#[must_use]
pub fn serialize_capabilities<I>(caps: I) -> String
where
I: IntoIterator<Item = Capability>,
{
let set: BTreeSet<Capability> = caps.into_iter().collect();
let mut out = String::new();
let mut first = true;
let mut sorted: Vec<Capability> = set.into_iter().collect();
sorted.sort_by_key(Capability::as_wire_str);
for c in sorted {
if !first {
out.push(',');
}
first = false;
out.push_str(c.as_wire_str());
}
out
}
/// An ordered, deduplicated collection of [`Capability`] values.
///
/// Provides helpers for the comma-separated wire form and for intersecting two sets.
/// Membership order (see [`CapabilitySet::as_set`]) follows the derived `Ord` of
/// [`Capability`].
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct CapabilitySet(BTreeSet<Capability>);

impl CapabilitySet {
    /// Creates an empty set.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds a set from any iterable of capabilities; duplicates collapse.
    #[must_use]
    pub fn with_caps<I: IntoIterator<Item = Capability>>(caps: I) -> Self {
        let members: BTreeSet<Capability> = caps.into_iter().collect();
        Self::from(members)
    }

    /// Returns a set containing every capability listed by [`Capability::all`].
    #[must_use]
    pub fn all_known() -> Self {
        Self::with_caps(Capability::all().iter().copied())
    }

    /// Parses the comma-separated wire form via [`parse_capabilities`].
    ///
    /// Unknown and empty tokens are dropped rather than reported.
    #[must_use]
    pub fn parse(s: &str) -> Self {
        Self::from(parse_capabilities(s))
    }

    /// Renders the set in comma-separated wire form via [`serialize_capabilities`].
    #[must_use]
    pub fn serialize(&self) -> String {
        serialize_capabilities(self.as_set().iter().copied())
    }

    /// Returns the capabilities present in both `self` and `other`.
    #[must_use]
    pub fn intersect(&self, other: &Self) -> Self {
        Self::with_caps(self.as_set().intersection(other.as_set()).copied())
    }

    /// Reports whether `cap` is a member of the set.
    #[must_use]
    pub fn contains(&self, cap: Capability) -> bool {
        self.as_set().contains(&cap)
    }

    /// Number of distinct capabilities in the set.
    #[must_use]
    pub fn len(&self) -> usize {
        self.as_set().len()
    }

    /// Whether the set has no members.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.as_set().is_empty()
    }

    /// Borrows the underlying ordered set.
    #[must_use]
    pub const fn as_set(&self) -> &BTreeSet<Capability> {
        &self.0
    }
}

impl From<BTreeSet<Capability>> for CapabilitySet {
    /// Wraps an existing set without copying it.
    fn from(members: BTreeSet<Capability>) -> Self {
        Self(members)
    }
}
#[cfg(test)]
mod tests {
    //! Tests for the wire format, parsing leniency, serde encoding, and set operations.
    use super::*;

    #[test]
    fn protocol_version_is_frozen() {
        assert_eq!(PROTOCOL_VERSION, 1);
        assert_eq!(PROTOCOL_HEADER, "mnem-protocol");
        assert_eq!(CAPABILITIES_HEADER, "mnem-capabilities");
    }

    #[test]
    fn capability_wire_strings_are_stable_kebab_case() {
        assert_eq!(Capability::HaveSetBloom.as_wire_str(), "have-set-bloom");
        assert_eq!(Capability::HaveSetRbsr.as_wire_str(), "have-set-rbsr");
        assert_eq!(Capability::PushNegotiate.as_wire_str(), "push-negotiate");
        assert_eq!(Capability::FilterSpec.as_wire_str(), "filter-spec");
        assert_eq!(Capability::AtomicPush.as_wire_str(), "atomic-push");
        assert_eq!(Capability::SignedPush.as_wire_str(), "signed-push");
        assert_eq!(
            Capability::SelfCertifyingRepoId.as_wire_str(),
            "self-certifying-repo-id",
        );
    }

    #[test]
    fn capability_all_has_no_duplicates() {
        let unique: BTreeSet<Capability> = Capability::all().iter().copied().collect();
        assert_eq!(unique.len(), Capability::all().len());
    }

    #[test]
    fn capability_round_trips_through_str() {
        for c in Capability::all() {
            let s = c.as_wire_str();
            let parsed: Capability = s.parse().unwrap();
            assert_eq!(parsed, *c, "round-trip failed for {s}");
        }
    }

    #[test]
    fn capability_display_matches_wire_str() {
        for c in Capability::all() {
            assert_eq!(c.to_string(), c.as_wire_str());
        }
    }

    #[test]
    fn unknown_capability_parses_as_err_not_panic() {
        let res: Result<Capability, _> = "no-such-thing".parse();
        assert!(res.is_err());
    }

    #[test]
    fn unknown_capability_error_carries_input() {
        let err = "no-such-thing".parse::<Capability>().unwrap_err();
        assert_eq!(err.0, "no-such-thing");
        assert_eq!(err.to_string(), "unknown capability: no-such-thing");
    }

    #[test]
    fn parse_capabilities_tolerates_unknowns_and_whitespace() {
        let caps = parse_capabilities(" have-set-bloom , no-such-thing,atomic-push ");
        assert_eq!(caps.len(), 2);
        assert!(caps.contains(&Capability::HaveSetBloom));
        assert!(caps.contains(&Capability::AtomicPush));
    }

    #[test]
    fn parse_capabilities_skips_empty_entries() {
        let caps = parse_capabilities(",,have-set-bloom,,");
        assert_eq!(caps.len(), 1);
        assert!(caps.contains(&Capability::HaveSetBloom));
    }

    #[test]
    fn empty_inputs_round_trip() {
        assert_eq!(serialize_capabilities(std::iter::empty()), "");
        assert!(parse_capabilities("").is_empty());
        assert!(CapabilitySet::parse("").is_empty());
        assert_eq!(CapabilitySet::new().serialize(), "");
    }

    #[test]
    fn serialize_capabilities_is_deterministic() {
        let caps = [
            Capability::SignedPush,
            Capability::HaveSetBloom,
            Capability::AtomicPush,
        ];
        let a = serialize_capabilities(caps);
        let b = serialize_capabilities(caps.iter().copied().rev());
        assert_eq!(a, b, "output must be order-independent");
        assert_eq!(a, "atomic-push,have-set-bloom,signed-push");
    }

    #[test]
    fn serialize_capabilities_dedupes() {
        let wire = serialize_capabilities([
            Capability::AtomicPush,
            Capability::FilterSpec,
            Capability::AtomicPush,
        ]);
        assert_eq!(wire, "atomic-push,filter-spec");
    }

    #[test]
    fn serialize_then_parse_round_trips() {
        let original: BTreeSet<Capability> = [
            Capability::HaveSetBloom,
            Capability::PushNegotiate,
            Capability::FilterSpec,
        ]
        .into_iter()
        .collect();
        let wire = serialize_capabilities(original.iter().copied());
        let parsed = parse_capabilities(&wire);
        assert_eq!(parsed, original);
    }

    #[test]
    fn capability_serde_round_trips_through_json() {
        let c = Capability::HaveSetBloom;
        let j = serde_json::to_string(&c).unwrap();
        assert_eq!(j, "\"have-set-bloom\"");
        let back: Capability = serde_json::from_str(&j).unwrap();
        assert_eq!(back, c);
    }

    #[test]
    fn capability_serde_rejects_unknown() {
        assert!(serde_json::from_str::<Capability>("\"no-such-thing\"").is_err());
    }

    #[test]
    fn capability_set_intersect_empty() {
        let a = CapabilitySet::new();
        let b = CapabilitySet::all_known();
        assert!(a.intersect(&b).is_empty());
        assert!(b.intersect(&a).is_empty());
    }

    #[test]
    fn capability_set_intersect_identical() {
        let a = CapabilitySet::all_known();
        let r = a.intersect(&a);
        assert_eq!(r, a);
        assert_eq!(r.len(), Capability::all().len());
    }

    #[test]
    fn capability_set_intersect_disjoint() {
        let a = CapabilitySet::with_caps([Capability::HaveSetBloom, Capability::AtomicPush]);
        let b = CapabilitySet::with_caps([Capability::SignedPush, Capability::FilterSpec]);
        let r = a.intersect(&b);
        assert!(r.is_empty());
    }

    #[test]
    fn capability_set_intersect_partial() {
        let a = CapabilitySet::with_caps([
            Capability::HaveSetBloom,
            Capability::AtomicPush,
            Capability::SignedPush,
        ]);
        let b = CapabilitySet::with_caps([Capability::AtomicPush, Capability::FilterSpec]);
        let r = a.intersect(&b);
        assert_eq!(r.len(), 1);
        assert!(r.contains(Capability::AtomicPush));
    }

    #[test]
    fn capability_set_wire_round_trip() {
        let a = CapabilitySet::with_caps([Capability::HaveSetBloom, Capability::AtomicPush]);
        let wire = a.serialize();
        assert_eq!(wire, "atomic-push,have-set-bloom");
        let b = CapabilitySet::parse(&wire);
        assert_eq!(a, b);
    }

    #[test]
    fn capability_set_from_btreeset_preserves_members() {
        let raw: BTreeSet<Capability> = [Capability::SignedPush].into_iter().collect();
        let set = CapabilitySet::from(raw.clone());
        assert_eq!(set.as_set(), &raw);
        assert!(set.contains(Capability::SignedPush));
        assert!(!set.contains(Capability::AtomicPush));
    }
}