use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, OnceLock};
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use super::placement::PlacementFilter;
/// A filter held by the registry together with the binding it was registered under.
struct RegisteredEntry {
    /// Binding name (tests use labels such as `"node"` and `"python"`) whose
    /// invocation counter is bumped each time this entry is looked up.
    binding: String,
    /// The filter handed back (as a shared `Arc` clone) to lookup callers.
    filter: Arc<dyn PlacementFilter>,
}
/// Registry mapping filter ids to [`PlacementFilter`] implementations, with
/// per-binding counters of successful lookups.
///
/// All operations take `&self`; storage is a pair of `DashMap`s, so one registry
/// can be shared between threads.
pub struct PlacementFilterRegistry {
    /// Lookup counts keyed by binding name. Entries survive `unregister` and are
    /// only reset by `clear`.
    invocations: DashMap<String, AtomicU64>,
    /// Registered filters keyed by their unique id.
    filters: DashMap<String, RegisteredEntry>,
}
impl std::fmt::Debug for PlacementFilterRegistry {
    /// Formats the registry as its number of filters plus the list of registered ids.
    ///
    /// Invocation counters are not included. The ids are sorted before printing
    /// because `DashMap` iteration order depends on hashing and shard layout; without
    /// sorting, the same registry contents could print differently from run to run.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut ids: Vec<String> = self
            .filters
            .iter()
            .map(|entry| entry.key().clone())
            .collect();
        // Deterministic output for logs and snapshot-style assertions.
        ids.sort_unstable();
        f.debug_struct("PlacementFilterRegistry")
            .field("len", &ids.len())
            .field("ids", &ids)
            .finish()
    }
}
impl PlacementFilterRegistry {
    /// Creates an empty registry: no filters and no invocation counters.
    pub fn new() -> Self {
        let filters = DashMap::new();
        let invocations = DashMap::new();
        Self {
            filters,
            invocations,
        }
    }

    /// Registers `filter` under `id`, tagged with the name of the `binding` it came from.
    ///
    /// Returns `true` when the filter was stored and `false` when `id` is already in use.
    /// An existing registration is never overwritten, and a rejected call creates no
    /// counter. On success, the binding's invocation counter is created at zero if it
    /// does not exist yet, so it shows up in [`Self::invocations_by_binding`] before any
    /// lookup happens.
    pub fn register(
        &self,
        id: String,
        filter: Arc<dyn PlacementFilter>,
        binding: impl Into<String>,
    ) -> bool {
        let binding: String = binding.into();
        let Entry::Vacant(slot) = self.filters.entry(id) else {
            // Id already taken: leave the existing registration untouched.
            return false;
        };
        // Create the counter before publishing the filter, so a lookup of the new id
        // finds an existing counter.
        self.invocations
            .entry(binding.clone())
            .or_insert_with(|| AtomicU64::new(0));
        slot.insert(RegisteredEntry { filter, binding });
        true
    }

    /// Looks up the filter registered under `id`.
    ///
    /// Returns a shared clone of the filter, or `None` if `id` is unknown. A successful
    /// lookup increments the invocation counter of the entry's binding; a miss counts
    /// nothing.
    pub fn get(&self, id: &str) -> Option<Arc<dyn PlacementFilter>> {
        let registered = self.filters.get(id)?;
        let binding = &registered.binding;
        match self.invocations.get(binding) {
            Some(counter) => {
                counter.fetch_add(1, Ordering::Relaxed);
            }
            None => {
                // Defensive path. `clear` empties the two maps one after the other, so a
                // registration that interleaves with it can leave a filter without a
                // counter. In that case, recreate the counter on demand.
                self.invocations
                    .entry(binding.clone())
                    .or_insert_with(|| AtomicU64::new(0))
                    .fetch_add(1, Ordering::Relaxed);
            }
        }
        Some(Arc::clone(&registered.filter))
    }

    /// Removes the filter registered under `id`.
    ///
    /// Returns `true` if a filter was removed, or `false` if none was registered. The
    /// binding's invocation counter is deliberately kept, so counts stay cumulative
    /// across re-registrations.
    pub fn unregister(&self, id: &str) -> bool {
        let removed = self.filters.remove(id);
        removed.is_some()
    }

    /// Returns whether a filter is currently registered under `id`. Counters are not touched.
    pub fn contains(&self, id: &str) -> bool {
        self.filters.contains_key(id)
    }

    /// Number of registered filters.
    pub fn len(&self) -> usize {
        self.filters.len()
    }

    /// Returns `true` when no filters are registered.
    pub fn is_empty(&self) -> bool {
        self.filters.is_empty()
    }

    /// Drops every registered filter and resets all invocation counters.
    ///
    /// The two maps are cleared one after the other, not atomically together; this is
    /// primarily meant for test isolation.
    pub fn clear(&self) {
        self.filters.clear();
        self.invocations.clear();
    }

    /// Cumulative number of successful lookups for filters registered under `binding`.
    ///
    /// Returns `0` for unknown bindings. The value is read with `Ordering::Relaxed`.
    pub fn invocation_count(&self, binding: &str) -> u64 {
        match self.invocations.get(binding) {
            Some(counter) => counter.load(Ordering::Relaxed),
            None => 0,
        }
    }

    /// Snapshot of every binding's cumulative lookup count, as a plain `HashMap`.
    ///
    /// Bindings registered but never looked up are included with a count of `0`.
    pub fn invocations_by_binding(&self) -> HashMap<String, u64> {
        let mut snapshot = HashMap::new();
        for counter in self.invocations.iter() {
            snapshot.insert(counter.key().clone(), counter.value().load(Ordering::Relaxed));
        }
        snapshot
    }
}
impl Default for PlacementFilterRegistry {
fn default() -> Self {
Self::new()
}
}
/// Backing storage for the process-wide registry; populated on first access.
static GLOBAL_REGISTRY: OnceLock<PlacementFilterRegistry> = OnceLock::new();

/// Returns the process-wide [`PlacementFilterRegistry`], creating an empty one on the
/// first call. Every call returns a reference to the same instance.
pub fn global_placement_filter_registry() -> &'static PlacementFilterRegistry {
    GLOBAL_REGISTRY.get_or_init(PlacementFilterRegistry::default)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adapter::net::behavior::capability::CapabilitySet;
    use crate::adapter::net::behavior::placement::{Artifact, NodeId};

    /// Test filter that returns the same score for every node and artifact.
    struct FixedFilter(f32);

    impl PlacementFilter for FixedFilter {
        fn placement_score(&self, _: &NodeId, _: &Artifact<'_>) -> Option<f32> {
            Some(self.0)
        }
    }

    /// Scores `filter` against a fixed node id and an empty daemon artifact.
    ///
    /// Tests use the score to tell which `FixedFilter` a lookup returned.
    fn score(filter: &Arc<dyn PlacementFilter>) -> Option<f32> {
        let required = CapabilitySet::default();
        let optional = CapabilitySet::default();
        let artifact = Artifact::Daemon {
            daemon_id: [0u8; 32],
            required: &required,
            optional: &optional,
        };
        filter.placement_score(&0x1234, &artifact)
    }

    #[test]
    fn register_and_get_returns_same_filter() {
        let reg = PlacementFilterRegistry::new();
        let filter: Arc<dyn PlacementFilter> = Arc::new(FixedFilter(0.7));
        assert!(reg.register("pf-1".into(), filter.clone(), "test"));
        assert_eq!(reg.len(), 1);
        assert!(reg.contains("pf-1"));
        let got = reg
            .get("pf-1")
            .expect("registered filter must be retrievable");
        // Compare data pointers only (thin casts), so vtable metadata cannot matter.
        assert_eq!(
            Arc::as_ptr(&got).cast::<()>(),
            Arc::as_ptr(&filter).cast::<()>(),
            "get must return the registered allocation, not a copy"
        );
        assert_eq!(score(&got), Some(0.7));
    }

    #[test]
    fn register_refuses_to_overwrite_existing_id() {
        let reg = PlacementFilterRegistry::new();
        let original: Arc<dyn PlacementFilter> = Arc::new(FixedFilter(0.5));
        let challenger: Arc<dyn PlacementFilter> = Arc::new(FixedFilter(0.9));
        assert!(reg.register("pf-1".into(), original, "test"));
        assert!(
            !reg.register("pf-1".into(), challenger, "test"),
            "second register call must report failure"
        );
        let got = reg.get("pf-1").unwrap();
        assert_eq!(score(&got), Some(0.5));
    }

    #[test]
    fn register_collision_does_not_pre_create_counter_for_failed_binding() {
        let reg = PlacementFilterRegistry::new();
        let original: Arc<dyn PlacementFilter> = Arc::new(FixedFilter(0.5));
        let challenger: Arc<dyn PlacementFilter> = Arc::new(FixedFilter(0.9));
        assert!(reg.register("pf-1".into(), original, "alpha"));
        assert!(!reg.register("pf-1".into(), challenger, "beta"));
        let invocations = reg.invocations_by_binding();
        assert!(
            !invocations.contains_key("beta"),
            "failed register must not create a phantom binding counter; got {invocations:?}",
        );
        assert_eq!(invocations.get("alpha").copied(), Some(0));
    }

    #[test]
    fn unregister_returns_true_only_on_first_call() {
        let reg = PlacementFilterRegistry::new();
        assert!(reg.register("pf-1".into(), Arc::new(FixedFilter(0.3)), "test"));
        assert!(reg.unregister("pf-1"));
        assert!(!reg.contains("pf-1"));
        assert!(reg.is_empty());
        assert!(!reg.unregister("pf-1"));
    }

    #[test]
    fn get_clone_outlives_unregister() {
        let reg = PlacementFilterRegistry::new();
        assert!(reg.register("pf-1".into(), Arc::new(FixedFilter(0.42)), "test"));
        let held = reg.get("pf-1").expect("filter is registered");
        assert!(reg.unregister("pf-1"));
        assert!(!reg.contains("pf-1"));
        assert_eq!(score(&held), Some(0.42));
    }

    #[test]
    fn get_unknown_id_returns_none() {
        let reg = PlacementFilterRegistry::new();
        assert!(reg.get("pf-missing").is_none());
    }

    #[test]
    fn clear_drops_every_registration() {
        let reg = PlacementFilterRegistry::new();
        assert!(reg.register("pf-1".into(), Arc::new(FixedFilter(0.1)), "test"));
        assert!(reg.register("pf-2".into(), Arc::new(FixedFilter(0.2)), "test"));
        assert!(reg.register("pf-3".into(), Arc::new(FixedFilter(0.3)), "test"));
        assert_eq!(reg.len(), 3);
        reg.clear();
        assert_eq!(reg.len(), 0);
        assert!(reg.is_empty());
        assert!(reg.get("pf-1").is_none());
    }

    #[test]
    fn concurrent_registers_under_unique_keys_all_succeed() {
        let reg = Arc::new(PlacementFilterRegistry::new());
        let n = 16usize;
        let handles: Vec<_> = (0..n)
            .map(|i| {
                let reg = reg.clone();
                std::thread::spawn(move || {
                    let f: Arc<dyn PlacementFilter> = Arc::new(FixedFilter(i as f32 / n as f32));
                    assert!(reg.register(format!("pf-{i}"), f, "test"));
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(reg.len(), n);
    }

    #[test]
    fn global_singleton_is_shared_across_calls() {
        let reg_a = global_placement_filter_registry();
        let reg_b = global_placement_filter_registry();
        assert!(std::ptr::eq(reg_a, reg_b));
        let id = "pf-singleton-test-unique-key";
        assert!(reg_a.register(id.into(), Arc::new(FixedFilter(0.6)), "test"));
        assert!(reg_b.contains(id));
        assert!(reg_b.unregister(id));
    }

    #[test]
    fn get_increments_per_binding_invocation_counter() {
        let reg = PlacementFilterRegistry::new();
        assert!(reg.register("pf-1".into(), Arc::new(FixedFilter(1.0)), "node"));
        assert_eq!(reg.invocation_count("node"), 0);
        let _ = reg.get("pf-1");
        let _ = reg.get("pf-1");
        let _ = reg.get("pf-1");
        assert_eq!(reg.invocation_count("node"), 3);
        let _ = reg.get("pf-missing");
        assert_eq!(reg.invocation_count("node"), 3);
    }

    #[test]
    fn invocation_counter_aggregates_across_ids_within_binding() {
        let reg = PlacementFilterRegistry::new();
        assert!(reg.register("pf-1".into(), Arc::new(FixedFilter(1.0)), "node"));
        assert!(reg.register("pf-2".into(), Arc::new(FixedFilter(0.5)), "node"));
        assert!(reg.register("pf-py".into(), Arc::new(FixedFilter(0.7)), "python"));
        let _ = reg.get("pf-1");
        let _ = reg.get("pf-1");
        let _ = reg.get("pf-2");
        let _ = reg.get("pf-py");
        assert_eq!(reg.invocation_count("node"), 3);
        assert_eq!(reg.invocation_count("python"), 1);
        assert_eq!(reg.invocation_count("go"), 0);
    }

    #[test]
    fn invocations_by_binding_returns_full_snapshot() {
        let reg = PlacementFilterRegistry::new();
        assert!(reg.register("pf-1".into(), Arc::new(FixedFilter(1.0)), "node"));
        assert!(reg.register("pf-2".into(), Arc::new(FixedFilter(1.0)), "python"));
        assert!(reg.register("pf-3".into(), Arc::new(FixedFilter(1.0)), "go"));
        let _ = reg.get("pf-1");
        let _ = reg.get("pf-1");
        let _ = reg.get("pf-2");
        let snap = reg.invocations_by_binding();
        assert_eq!(snap.get("node").copied(), Some(2));
        assert_eq!(snap.get("python").copied(), Some(1));
        assert_eq!(snap.get("go").copied(), Some(0));
    }

    #[test]
    fn unregister_preserves_invocation_counters() {
        let reg = PlacementFilterRegistry::new();
        assert!(reg.register("pf-1".into(), Arc::new(FixedFilter(1.0)), "node"));
        let _ = reg.get("pf-1");
        let _ = reg.get("pf-1");
        assert_eq!(reg.invocation_count("node"), 2);
        assert!(reg.unregister("pf-1"));
        assert_eq!(
            reg.invocation_count("node"),
            2,
            "counter must survive unregister (cumulative semantics)",
        );
        assert!(reg.register("pf-1".into(), Arc::new(FixedFilter(1.0)), "node"));
        let _ = reg.get("pf-1");
        assert_eq!(
            reg.invocation_count("node"),
            3,
            "counter must accumulate across re-registrations",
        );
    }

    #[test]
    fn clear_resets_invocation_counters() {
        let reg = PlacementFilterRegistry::new();
        assert!(reg.register("pf-1".into(), Arc::new(FixedFilter(1.0)), "node"));
        let _ = reg.get("pf-1");
        assert_eq!(reg.invocation_count("node"), 1);
        reg.clear();
        assert_eq!(
            reg.invocation_count("node"),
            0,
            "clear() resets counters for test isolation",
        );
    }
}