use dashmap::DashMap;
use super::descriptor::{SubprotocolDescriptor, SubprotocolVersion};
use super::stream_window::SUBPROTOCOL_STREAM_WINDOW;
use crate::adapter::net::behavior::capability::{CapabilityFilter, CapabilitySet};
use crate::adapter::net::compute::SUBPROTOCOL_MIGRATION;
use crate::adapter::net::state::causal::{SUBPROTOCOL_CAUSAL, SUBPROTOCOL_SNAPSHOT};
pub struct SubprotocolRegistry {
entries: DashMap<u16, SubprotocolDescriptor>,
}
impl SubprotocolRegistry {
    /// Creates an empty registry with no subprotocols registered.
    pub fn new() -> Self {
        Self {
            entries: DashMap::new(),
        }
    }

    /// Creates a registry pre-populated with the built-in subprotocols.
    ///
    /// Registers `causal`, `snapshot`, `migration`, `negotiation` and
    /// `stream-window`, each at version 1.0 via `SubprotocolDescriptor::new`.
    pub fn with_defaults() -> Self {
        let reg = Self::new();
        // (id, name) of every built-in; all start out at version 1.0.
        let builtins = [
            (SUBPROTOCOL_CAUSAL, "causal"),
            (SUBPROTOCOL_SNAPSHOT, "snapshot"),
            (SUBPROTOCOL_MIGRATION, "migration"),
            (super::SUBPROTOCOL_NEGOTIATION, "negotiation"),
            (SUBPROTOCOL_STREAM_WINDOW, "stream-window"),
        ];
        for (id, name) in builtins {
            reg.register(SubprotocolDescriptor::new(
                id,
                name,
                SubprotocolVersion::new(1, 0),
            ));
        }
        reg
    }

    /// Inserts `descriptor` under its `id`, replacing any existing entry.
    ///
    /// Returns the descriptor previously registered under that id, if any
    /// (e.g. when upgrading a subprotocol to a newer version).
    pub fn register(&self, descriptor: SubprotocolDescriptor) -> Option<SubprotocolDescriptor> {
        self.entries.insert(descriptor.id, descriptor)
    }

    /// Removes the descriptor registered under `id`, returning it if present.
    pub fn unregister(&self, id: u16) -> Option<SubprotocolDescriptor> {
        self.entries.remove(&id).map(|(_, d)| d)
    }

    /// Looks up the descriptor registered under `id`.
    ///
    /// The returned guard borrows the entry inside the underlying `DashMap`.
    /// Drop it before calling [`register`](Self::register) or
    /// [`unregister`](Self::unregister) on this registry: DashMap documents
    /// that mutating the map while holding a reference into it may deadlock.
    pub fn get(
        &self,
        id: u16,
    ) -> Option<dashmap::mapref::one::Ref<'_, u16, SubprotocolDescriptor>> {
        self.entries.get(&id)
    }

    /// Returns `true` if `id` is registered *and* its descriptor has
    /// `handler_present` set; `false` for unknown ids and for entries whose
    /// `handler_present` is false.
    pub fn is_handled(&self, id: u16) -> bool {
        self.entries.get(&id).is_some_and(|d| d.handler_present)
    }

    /// Returns `true` if any descriptor (handled or not) is registered under `id`.
    pub fn is_registered(&self, id: u16) -> bool {
        self.entries.contains_key(&id)
    }

    /// Returns a snapshot (clones) of every registered descriptor, sorted by id.
    ///
    /// Sorting makes the output deterministic; the backing map's iteration
    /// order is not, and `capability_tags` already sorts the same way.
    pub fn list(&self) -> Vec<SubprotocolDescriptor> {
        let mut descriptors: Vec<SubprotocolDescriptor> =
            self.entries.iter().map(|e| e.value().clone()).collect();
        // Ids are map keys, hence unique, so an unstable sort is deterministic.
        descriptors.sort_unstable_by_key(|d| d.id);
        descriptors
    }

    /// Number of registered descriptors (handled and forwarding-only alike).
    pub fn count(&self) -> usize {
        self.entries.len()
    }

    /// Builds a capability filter requiring the tag for subprotocol `id`.
    ///
    /// The tag is `subprotocol:` followed by the id as `0x`-prefixed,
    /// zero-padded, 4-digit lowercase hex (e.g. `subprotocol:0x0400`).
    pub fn capability_filter_for(id: u16) -> CapabilityFilter {
        CapabilityFilter::new().require_tag(format!("subprotocol:{:#06x}", id))
    }

    /// Capability tags for every descriptor with `handler_present` set,
    /// ordered by subprotocol id so the result is deterministic.
    ///
    /// Forwarding-only descriptors are omitted.
    pub fn capability_tags(&self) -> Vec<String> {
        let mut entries: Vec<_> = self
            .entries
            .iter()
            .filter(|e| e.handler_present)
            .map(|e| (e.id, e.capability_tag()))
            .collect();
        // Keys are unique, so stability of the sort is irrelevant.
        entries.sort_unstable_by_key(|(id, _)| *id);
        entries.into_iter().map(|(_, tag)| tag).collect()
    }

    /// Returns `caps` with every tag from [`capability_tags`](Self::capability_tags)
    /// added; tags already present in `caps` are kept.
    pub fn enrich_capabilities(&self, mut caps: CapabilitySet) -> CapabilitySet {
        for tag in self.capability_tags() {
            caps = caps.add_tag(tag);
        }
        caps
    }
}
impl Default for SubprotocolRegistry {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for SubprotocolRegistry {
    /// Formats the registry as `SubprotocolRegistry { count: N }`, reporting
    /// only the number of entries rather than each descriptor.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.entries.len();
        let mut out = f.debug_struct("SubprotocolRegistry");
        out.field("count", &count);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapter::net::behavior::capability::CapabilitySet;

    #[test]
    fn test_empty_registry() {
        let reg = SubprotocolRegistry::new();
        assert_eq!(reg.count(), 0);
        assert!(!reg.is_handled(0x0400));
        assert!(!reg.is_registered(0x0400));
        assert!(reg.get(0x0400).is_none());
        assert!(reg.list().is_empty());
        assert!(reg.capability_tags().is_empty());
    }

    #[test]
    fn test_default_is_empty() {
        // `Default` is the empty registry, not `with_defaults()`.
        assert_eq!(SubprotocolRegistry::default().count(), 0);
    }

    #[test]
    fn test_with_defaults() {
        let reg = SubprotocolRegistry::with_defaults();
        assert!(reg.is_handled(SUBPROTOCOL_CAUSAL));
        assert!(reg.is_handled(SUBPROTOCOL_SNAPSHOT));
        assert!(reg.is_handled(SUBPROTOCOL_MIGRATION));
        assert!(reg.is_handled(super::super::SUBPROTOCOL_NEGOTIATION));
        assert!(reg.is_handled(SUBPROTOCOL_STREAM_WINDOW));
        // Exactly the five built-ins; an id collision between two defaults
        // would silently overwrite one and show up here as a smaller count.
        assert_eq!(reg.count(), 5);
    }

    #[test]
    fn test_register_and_lookup() {
        let reg = SubprotocolRegistry::new();
        let desc = SubprotocolDescriptor::new(0x1000, "vendor-test", SubprotocolVersion::new(1, 0));
        // First registration under an id replaces nothing.
        assert!(reg.register(desc).is_none());
        assert!(reg.is_handled(0x1000));
        assert!(reg.is_registered(0x1000));
        let retrieved = reg.get(0x1000).unwrap();
        assert_eq!(retrieved.name, "vendor-test");
    }

    #[test]
    fn test_register_upgrade() {
        let reg = SubprotocolRegistry::new();
        let v1 = SubprotocolDescriptor::new(0x1000, "test", SubprotocolVersion::new(1, 0));
        reg.register(v1);
        let v2 = SubprotocolDescriptor::new(0x1000, "test", SubprotocolVersion::new(2, 0));
        let old = reg.register(v2);
        assert!(old.is_some());
        assert_eq!(old.unwrap().version, SubprotocolVersion::new(1, 0));
        let current = reg.get(0x1000).unwrap();
        assert_eq!(current.version, SubprotocolVersion::new(2, 0));
    }

    #[test]
    fn test_unregister() {
        let reg = SubprotocolRegistry::new();
        reg.register(SubprotocolDescriptor::new(
            0x1000,
            "test",
            SubprotocolVersion::new(1, 0),
        ));
        let removed = reg.unregister(0x1000);
        assert!(removed.is_some());
        assert!(!reg.is_registered(0x1000));
        assert_eq!(reg.count(), 0);
        // Removing an id that is no longer present is a no-op.
        assert!(reg.unregister(0x1000).is_none());
    }

    #[test]
    fn test_forwarding_only() {
        let reg = SubprotocolRegistry::new();
        let desc = SubprotocolDescriptor::new(0x2000, "remote-only", SubprotocolVersion::new(1, 0))
            .forwarding_only();
        reg.register(desc);
        assert!(reg.is_registered(0x2000));
        // Registered, but no local handler.
        assert!(!reg.is_handled(0x2000));
    }

    #[test]
    fn test_capability_tags() {
        let reg = SubprotocolRegistry::new();
        reg.register(SubprotocolDescriptor::new(
            0x1000,
            "handled",
            SubprotocolVersion::new(1, 0),
        ));
        reg.register(
            SubprotocolDescriptor::new(0x2000, "forwarded", SubprotocolVersion::new(1, 0))
                .forwarding_only(),
        );
        let tags = reg.capability_tags();
        // Only the handled subprotocol is advertised.
        assert_eq!(tags.len(), 1);
        assert!(tags[0].contains("0x1000"));
    }

    #[test]
    fn test_capability_filter_for() {
        let filter = SubprotocolRegistry::capability_filter_for(0x0400);
        assert!(!filter.require_tags.is_empty());
        assert_eq!(filter.require_tags[0], "subprotocol:0x0400");
    }

    #[test]
    fn test_capability_filter_matches_advertised_tag() {
        // A filter built for an id must require exactly the tag the registry
        // advertises for that id, otherwise no peer could ever match it. The
        // id contains hex letters so a case mismatch between formatters shows.
        let reg = SubprotocolRegistry::new();
        reg.register(SubprotocolDescriptor::new(
            0x0abc,
            "lettered",
            SubprotocolVersion::new(1, 0),
        ));
        let tags = reg.capability_tags();
        let filter = SubprotocolRegistry::capability_filter_for(0x0abc);
        assert_eq!(tags.len(), 1);
        assert_eq!(filter.require_tags[0], tags[0]);
    }

    #[test]
    fn test_list() {
        let reg = SubprotocolRegistry::new();
        reg.register(SubprotocolDescriptor::new(
            0x1000,
            "a",
            SubprotocolVersion::new(1, 0),
        ));
        reg.register(SubprotocolDescriptor::new(
            0x2000,
            "b",
            SubprotocolVersion::new(2, 0),
        ));
        let list = reg.list();
        assert_eq!(list.len(), 2);
        // Check contents independently of the order `list` returns them in.
        let mut ids: Vec<u16> = list.iter().map(|d| d.id).collect();
        ids.sort_unstable();
        assert_eq!(ids, vec![0x1000, 0x2000]);
    }

    #[test]
    fn test_enrich_capabilities() {
        let reg = SubprotocolRegistry::new();
        reg.register(SubprotocolDescriptor::new(
            0x0500,
            "migration",
            SubprotocolVersion::new(1, 0),
        ));
        reg.register(
            SubprotocolDescriptor::new(0x2000, "forwarded", SubprotocolVersion::new(1, 0))
                .forwarding_only(),
        );
        let caps = CapabilitySet::new();
        let enriched = reg.enrich_capabilities(caps);
        assert!(enriched.has_tag("subprotocol:0x0500"));
        // Forwarding-only subprotocols are not advertised.
        assert!(!enriched.has_tag("subprotocol:0x2000"));
    }

    #[test]
    fn test_enrich_capabilities_with_defaults() {
        let reg = SubprotocolRegistry::with_defaults();
        let caps = reg.enrich_capabilities(CapabilitySet::new());
        assert!(caps.has_tag("subprotocol:0x0400"));
        assert!(caps.has_tag("subprotocol:0x0401"));
        assert!(caps.has_tag("subprotocol:0x0500"));
        assert!(caps.has_tag("subprotocol:0x0600"));
    }

    #[test]
    fn test_enrich_preserves_existing_tags() {
        let reg = SubprotocolRegistry::new();
        reg.register(SubprotocolDescriptor::new(
            0x0500,
            "migration",
            SubprotocolVersion::new(1, 0),
        ));
        let caps = CapabilitySet::new().add_tag("custom:my-tag");
        let enriched = reg.enrich_capabilities(caps);
        assert!(enriched.has_tag("custom:my-tag"));
        assert!(enriched.has_tag("subprotocol:0x0500"));
    }
}