use std::sync::Arc;
use async_trait::async_trait;
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use super::registry::AggregatorRegistry;
use crate::adapter::net::cortex::rpc::{
RpcContext, RpcHandler, RpcHandlerError, RpcResponsePayload, RpcStatus,
};
use crate::adapter::net::subnet::SubnetId;
/// Service name under which the registry RPC handler is served: every
/// `install_registry_service*` method in this file passes it to `mesh.serve_rpc`.
pub const REGISTRY_SERVICE: &str = "aggregator.registry";
/// Requests accepted by the registry RPC service.
///
/// Both handlers in this file decode this type from the RPC payload body with
/// `postcard::from_bytes` and dispatch it through `answer`.
///
/// NOTE(review): postcard encodes variants and fields positionally, so reordering
/// them presumably changes the wire format — confirm before rearranging.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum RegistryRequest {
/// Snapshot every registry entry; answered with `RegistryResponse::Groups`.
List,
/// Create a new group by delegating to the handler's `SpawnFn`.
Spawn {
/// Template identifier; forwarded to the spawner unchanged.
template_name: String,
/// Name for the new group; `answer` rejects it with `DuplicateGroupName` when
/// the registry already holds a group under this name.
group_name: String,
/// Replica count; forwarded to the spawner unchanged.
replica_count: u8,
},
/// Remove a group from the registry and stop it.
Unregister {
/// Name of the group to remove.
group_name: String,
},
/// Resize an existing group by delegating to the handler's `ScaleFn`.
Scale {
/// Name of the group to scale; must already be registered.
group_name: String,
/// Template identifier; forwarded to the scaler unchanged.
template_name: String,
/// Desired replica count; `answer` rejects `0` with `ScaleRejected`.
target_replica_count: u8,
},
}
/// Replies produced by `answer` (and by the handlers on decode failure), sent back
/// postcard-encoded in the RPC response body.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum RegistryResponse {
/// Reply to `List`: one summary per registry entry.
Groups(Vec<RegistryGroupSummary>),
/// Reply to a successful `Spawn`: the summary returned by the spawner.
Spawned(RegistryGroupSummary),
/// Reply to `Unregister`.
Unregistered {
/// `true` when the group was removed from the registry and stopped; `false`
/// when `registry.unregister` returned an error.
existed: bool,
},
/// Reply to a successful `Scale`: the summary returned by the scaler.
Scaled(RegistryGroupSummary),
/// The request could not be decoded or was rejected.
Error(RegistryRpcError),
}
/// Wire-level description of one registered group, as built by `snapshot_group`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RegistryGroupSummary {
/// The registry entry's name.
pub name: String,
/// The registry entry's `group_seed`.
pub group_seed: [u8; 32],
/// `source_subnet` from the first replica's config; `SubnetId::GLOBAL` when the
/// snapshot has no replicas.
pub source_subnet: SubnetId,
/// `fold_kinds` from the first replica's config; empty when the snapshot has no
/// replicas.
pub fold_kinds: Vec<u16>,
/// One row per replica, in snapshot order.
pub replicas: Vec<RegistryReplicaSummary>,
}
/// Wire-level description of a single replica, as built by `build_rows`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RegistryReplicaSummary {
/// Value of the replica's `generation()` at snapshot time.
pub generation: u64,
/// Health flag from the snapshot; `true` when the snapshot carries no health
/// record at this replica's index.
pub healthy: bool,
/// Diagnostic text from the snapshot's health record, if any.
pub diagnostic: Option<String>,
/// `node_id` of the snapshot's placement at this replica's index, if one exists.
pub placement_node_id: Option<u64>,
}
/// Typed failures carried inside `RegistryResponse::Error`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum RegistryRpcError {
/// The request body failed postcard decoding; carries the decode error's text.
DecodeFailed(String),
/// Not produced by `answer` itself; the test spawners in this file return it with
/// the unrecognised template name.
UnknownTemplate(String),
/// `Spawn` named a group that is already registered; carries that name.
DuplicateGroupName(String),
/// Not produced by `answer` itself; the test spawner in this file uses it to
/// report a failed spawn, with a free-form message.
SpawnRejected(String),
/// `Spawn` reached a handler that has no spawner.
SpawnNotSupported,
/// `Scale` named a group that is not registered; carries that name.
UnknownGroup(String),
/// `Scale` was refused; `answer` returns it for a zero `target_replica_count`.
ScaleRejected(String),
/// `Scale` reached a handler that has no scaler.
ScaleNotSupported,
}
/// Boxed callback that carries out a `Spawn` request.
///
/// Takes the owned `SpawnRequest` and returns a boxed `'static` future resolving to
/// the new group's summary or a typed error. `answer` only calls it after the
/// duplicate-name check passes and does not register anything itself, so registering
/// the group is left to the callback (as the test spawner in this file does).
pub type SpawnFn = Box<
dyn Fn(
SpawnRequest,
)
-> futures::future::BoxFuture<'static, Result<RegistryGroupSummary, RegistryRpcError>>
+ Send
+ Sync
+ 'static,
>;
/// Owned copy of the `RegistryRequest::Spawn` fields handed to a `SpawnFn`.
#[derive(Debug, Clone)]
pub struct SpawnRequest {
/// Template identifier from the request.
pub template_name: String,
/// Requested group name; `answer` has already checked it is not registered.
pub group_name: String,
/// Requested replica count.
pub replica_count: u8,
}
/// Boxed callback that carries out a `Scale` request.
///
/// Takes the owned `ScaleRequest` and returns a boxed `'static` future resolving to
/// the group's summary or a typed error. `answer` only calls it once the group is
/// known to the registry and the target count is non-zero.
pub type ScaleFn = Box<
dyn Fn(
ScaleRequest,
)
-> futures::future::BoxFuture<'static, Result<RegistryGroupSummary, RegistryRpcError>>
+ Send
+ Sync
+ 'static,
>;
/// Owned copy of the `RegistryRequest::Scale` fields handed to a `ScaleFn`.
#[derive(Debug, Clone)]
pub struct ScaleRequest {
/// Name of the group to scale; `answer` has already checked it is registered.
pub group_name: String,
/// Template identifier from the request.
pub template_name: String,
/// Desired replica count; `answer` has already rejected `0`.
pub target_replica_count: u8,
}
/// RPC handler built without a spawner or scaler.
///
/// `List` is served normally; `Spawn` and `Scale` are answered with
/// `SpawnNotSupported` / `ScaleNotSupported`.
///
/// NOTE(review): `call` forwards every request to `answer`, which executes
/// `Unregister` regardless of which handler received it — so this "read" handler can
/// still remove and stop groups. Confirm that is intended.
pub struct RegistryReadHandler {
    registry: Arc<AggregatorRegistry>,
}

impl RegistryReadHandler {
    /// Wraps `registry` in a handler with no spawn or scale capability.
    pub fn new(registry: Arc<AggregatorRegistry>) -> Self {
        RegistryReadHandler { registry }
    }
}
#[async_trait]
impl RpcHandler for RegistryReadHandler {
async fn call(&self, ctx: RpcContext) -> Result<RpcResponsePayload, RpcHandlerError> {
let request: RegistryRequest = match postcard::from_bytes(&ctx.payload.body) {
Ok(req) => req,
Err(e) => {
let response =
RegistryResponse::Error(RegistryRpcError::DecodeFailed(e.to_string()));
return Ok(encode_response(&response));
}
};
let response = answer(&self.registry, None, None, &request).await;
Ok(encode_response(&response))
}
}
/// RPC handler that can spawn groups and, when a scaler is attached, scale them.
pub struct RegistryHandler {
    registry: Arc<AggregatorRegistry>,
    spawner: Arc<SpawnFn>,
    scaler: Option<Arc<ScaleFn>>,
}

impl RegistryHandler {
    /// Builds a handler that serves `Spawn` through `spawner`; `Scale` stays
    /// unsupported until `with_scaler` is called.
    pub fn new(registry: Arc<AggregatorRegistry>, spawner: SpawnFn) -> Self {
        let spawner = Arc::new(spawner);
        RegistryHandler {
            registry,
            spawner,
            scaler: None,
        }
    }

    /// Returns the handler with `scaler` attached, replacing any previous scaler.
    pub fn with_scaler(self, scaler: ScaleFn) -> Self {
        RegistryHandler {
            scaler: Some(Arc::new(scaler)),
            ..self
        }
    }
}
#[async_trait]
impl RpcHandler for RegistryHandler {
async fn call(&self, ctx: RpcContext) -> Result<RpcResponsePayload, RpcHandlerError> {
let request: RegistryRequest = match postcard::from_bytes(&ctx.payload.body) {
Ok(req) => req,
Err(e) => {
let response =
RegistryResponse::Error(RegistryRpcError::DecodeFailed(e.to_string()));
return Ok(encode_response(&response));
}
};
let response = answer(
&self.registry,
Some(&self.spawner),
self.scaler.as_deref(),
&request,
)
.await;
Ok(encode_response(&response))
}
}
pub(crate) async fn answer(
registry: &Arc<AggregatorRegistry>,
spawner: Option<&SpawnFn>,
scaler: Option<&ScaleFn>,
request: &RegistryRequest,
) -> RegistryResponse {
match request {
RegistryRequest::List => {
let entries = registry.entries();
let mut groups = Vec::with_capacity(entries.len());
for entry in entries {
groups.push(snapshot_group(&entry).await);
}
RegistryResponse::Groups(groups)
}
RegistryRequest::Spawn {
template_name,
group_name,
replica_count,
} => {
let Some(spawner) = spawner else {
return RegistryResponse::Error(RegistryRpcError::SpawnNotSupported);
};
if registry.get(group_name).is_some() {
return RegistryResponse::Error(RegistryRpcError::DuplicateGroupName(
group_name.clone(),
));
}
let req = SpawnRequest {
template_name: template_name.clone(),
group_name: group_name.clone(),
replica_count: *replica_count,
};
match (spawner)(req).await {
Ok(summary) => RegistryResponse::Spawned(summary),
Err(e) => RegistryResponse::Error(e),
}
}
RegistryRequest::Unregister { group_name } => match registry.unregister(group_name).await {
Ok(group) => {
group.stop().await;
RegistryResponse::Unregistered { existed: true }
}
Err(_) => RegistryResponse::Unregistered { existed: false },
},
RegistryRequest::Scale {
group_name,
template_name,
target_replica_count,
} => {
let Some(scaler) = scaler else {
return RegistryResponse::Error(RegistryRpcError::ScaleNotSupported);
};
if registry.get(group_name).is_none() {
return RegistryResponse::Error(RegistryRpcError::UnknownGroup(group_name.clone()));
}
if *target_replica_count == 0 {
return RegistryResponse::Error(RegistryRpcError::ScaleRejected(
"target_replica_count must be > 0".into(),
));
}
let req = ScaleRequest {
group_name: group_name.clone(),
template_name: template_name.clone(),
target_replica_count: *target_replica_count,
};
match (scaler)(req).await {
Ok(summary) => RegistryResponse::Scaled(summary),
Err(e) => RegistryResponse::Error(e),
}
}
}
}
/// Builds the wire summary for one registry entry from a fresh snapshot.
///
/// `source_subnet` and `fold_kinds` are read from the first replica's config; a
/// snapshot without replicas reports `SubnetId::GLOBAL` and no fold kinds.
pub async fn snapshot_group(entry: &Arc<super::AggregatorGroupEntry>) -> RegistryGroupSummary {
    let snapshot = entry.snapshot().await;
    let replicas = build_rows(&snapshot);

    let source_subnet;
    let fold_kinds;
    if let Some(first) = snapshot.replicas.first() {
        let config = first.config();
        source_subnet = config.source_subnet;
        fold_kinds = config.fold_kinds.clone();
    } else {
        source_subnet = SubnetId::GLOBAL;
        fold_kinds = Vec::new();
    }

    RegistryGroupSummary {
        name: entry.name.clone(),
        group_seed: entry.group_seed,
        source_subnet,
        fold_kinds,
        replicas,
    }
}
/// Converts a snapshot into one `RegistryReplicaSummary` per replica.
///
/// Health and placement are looked up by the replica's index. A missing health
/// record is reported as healthy with no diagnostic; a missing placement yields
/// `placement_node_id: None`.
fn build_rows(snap: &super::EntrySnapshot) -> Vec<RegistryReplicaSummary> {
    let mut rows = Vec::with_capacity(snap.replicas.len());
    for (slot, replica) in snap.replicas.iter().enumerate() {
        let (healthy, diagnostic) = match snap.healths.get(slot) {
            Some(health) => (health.healthy, health.diagnostic.clone()),
            None => (true, None),
        };
        let placement_node_id = snap.placements.get(slot).map(|placement| placement.node_id);
        rows.push(RegistryReplicaSummary {
            generation: replica.generation(),
            healthy,
            diagnostic,
            placement_node_id,
        });
    }
    rows
}
impl AggregatorRegistry {
    /// Serves `REGISTRY_SERVICE` on `mesh` with a `RegistryReadHandler`, i.e. with
    /// no spawner and no scaler.
    ///
    /// Returns whatever `mesh.serve_rpc` returns.
    pub fn install_registry_service(
        self: &Arc<Self>,
        mesh: &Arc<crate::adapter::net::MeshNode>,
    ) -> Result<crate::adapter::net::mesh_rpc::ServeHandle, crate::adapter::net::mesh_rpc::ServeError>
    {
        let handler = RegistryReadHandler::new(Arc::clone(self));
        mesh.serve_rpc(REGISTRY_SERVICE, Arc::new(handler))
    }

    /// Serves `REGISTRY_SERVICE` on `mesh` with a `RegistryHandler` that can spawn
    /// through `spawner` but has no scaler.
    ///
    /// Returns whatever `mesh.serve_rpc` returns.
    pub fn install_registry_service_with_spawner(
        self: &Arc<Self>,
        mesh: &Arc<crate::adapter::net::MeshNode>,
        spawner: SpawnFn,
    ) -> Result<crate::adapter::net::mesh_rpc::ServeHandle, crate::adapter::net::mesh_rpc::ServeError>
    {
        let handler = RegistryHandler::new(Arc::clone(self), spawner);
        mesh.serve_rpc(REGISTRY_SERVICE, Arc::new(handler))
    }

    /// Serves `REGISTRY_SERVICE` on `mesh` with a `RegistryHandler` wired to both
    /// `spawner` and `scaler`.
    ///
    /// Returns whatever `mesh.serve_rpc` returns.
    pub fn install_registry_service_with_handlers(
        self: &Arc<Self>,
        mesh: &Arc<crate::adapter::net::MeshNode>,
        spawner: SpawnFn,
        scaler: ScaleFn,
    ) -> Result<crate::adapter::net::mesh_rpc::ServeHandle, crate::adapter::net::mesh_rpc::ServeError>
    {
        let handler = RegistryHandler::new(Arc::clone(self), spawner).with_scaler(scaler);
        mesh.serve_rpc(REGISTRY_SERVICE, Arc::new(handler))
    }
}
/// Serialises `response` with postcard into an RPC payload with `RpcStatus::Ok`
/// and no headers.
///
/// Best effort: if encoding fails, a warning is logged and the payload is sent
/// with an empty body (the status stays `Ok`).
fn encode_response(response: &RegistryResponse) -> RpcResponsePayload {
    let body = postcard::to_allocvec(response)
        .map(Bytes::from)
        .unwrap_or_else(|err| {
            tracing::warn!(
                error = %err,
                "aggregator: registry response encode failed; replying with empty body",
            );
            Bytes::new()
        });
    RpcResponsePayload {
        status: RpcStatus::Ok,
        headers: Vec::new(),
        body,
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::adapter::net::behavior::aggregator::{
AggregatorConfig, AggregatorDaemon, AggregatorRegistry,
};
use crate::adapter::net::behavior::fold::capability::CapabilityFold;
use crate::adapter::net::behavior::fold::FoldKind;
use crate::adapter::net::behavior::lifecycle::LifecycleGroup;
use crate::adapter::net::identity::EntityKeypair;
use crate::adapter::net::{MeshNode, MeshNodeConfig, SubnetId};
use std::net::SocketAddr;
use std::time::Duration;
/// Creates a `MeshNode` configured with `127.0.0.1:0` and a fresh keypair.
/// Panics if the node cannot be created.
async fn build_mesh() -> Arc<MeshNode> {
    let bind_addr = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
    let node_cfg = MeshNodeConfig::new(bind_addr, [0x17u8; 32]);
    let node = MeshNode::new(EntityKeypair::generate(), node_cfg)
        .await
        .expect("MeshNode::new");
    Arc::new(node)
}
/// Spawns a two-replica `AggregatorDaemon` group on its own mesh node, using the
/// capability fold and the given interval in milliseconds.
///
/// `name` is accepted for call-site readability only; it is discarded.
async fn spawn_group(name: &str, interval_ms: u64) -> LifecycleGroup<AggregatorDaemon> {
    let _ = name;
    let mesh = build_mesh().await;
    let interval = Duration::from_millis(interval_ms);
    let cfg = AggregatorConfig::new(SubnetId::GLOBAL)
        .with_fold_kind(CapabilityFold::KIND_ID)
        .with_interval(interval);
    // The factory closure owns its own copies; each replica clones from those.
    let cfg_clone = cfg.clone();
    let mesh_clone = mesh.clone();
    let seed = [0xABu8; 32];
    let spawned = LifecycleGroup::<AggregatorDaemon>::spawn(2, seed, move |_idx| {
        let daemon = AggregatorDaemon::new(cfg_clone.clone(), mesh_clone.clone()).expect("new");
        Arc::new(daemon)
    })
    .await;
    spawned.expect("spawn")
}
/// `List` reports both registered groups, each with two healthy replicas.
#[tokio::test]
async fn list_returns_every_registered_group() {
    let registry = Arc::new(AggregatorRegistry::new());
    registry
        .register("alpha", spawn_group("alpha", 50).await)
        .expect("register alpha");
    registry
        .register("beta", spawn_group("beta", 50).await)
        .expect("register beta");
    // Fixed: the argument had been mangled into `®istry` (HTML entity for `&reg`).
    let response = answer(&registry, None, None, &RegistryRequest::List).await;
    match response {
        RegistryResponse::Groups(groups) => {
            assert_eq!(groups.len(), 2);
            let names: Vec<&str> = groups.iter().map(|g| g.name.as_str()).collect();
            assert_eq!(names, vec!["alpha", "beta"]);
            for g in &groups {
                assert_eq!(g.replicas.len(), 2);
                for r in &g.replicas {
                    assert!(r.healthy);
                }
            }
        }
        other => panic!("expected Groups, got {other:?}"),
    }
    // Tear the groups down so their daemons do not outlive the test.
    for n in ["alpha", "beta"] {
        let g = registry.unregister(n).await.expect("unregister");
        g.stop().await;
    }
}
/// `List` on an empty registry yields `Groups` with no entries.
#[tokio::test]
async fn list_against_empty_registry_returns_empty_groups() {
    let registry = Arc::new(AggregatorRegistry::new());
    // Fixed: the argument had been mangled into `®istry` (HTML entity for `&reg`).
    let response = answer(&registry, None, None, &RegistryRequest::List).await;
    match response {
        RegistryResponse::Groups(groups) => assert!(groups.is_empty()),
        other => panic!("expected empty Groups, got {other:?}"),
    }
}
/// `Unregister` of a registered group replies `existed: true` and leaves the
/// registry empty.
#[tokio::test]
async fn unregister_drives_group_shutdown_and_returns_existed_true() {
    let registry = Arc::new(AggregatorRegistry::new());
    registry
        .register("agg", spawn_group("agg", 50).await)
        .expect("register");
    // Fixed: the first argument had been mangled into `®istry`.
    let response = answer(
        &registry,
        None,
        None,
        &RegistryRequest::Unregister {
            group_name: "agg".into(),
        },
    )
    .await;
    match response {
        RegistryResponse::Unregistered { existed } => assert!(existed),
        other => panic!("expected Unregistered, got {other:?}"),
    }
    assert!(registry.is_empty());
}
/// `Unregister` of a name the registry does not hold replies `existed: false`.
#[tokio::test]
async fn unregister_unknown_group_returns_existed_false() {
    let registry = Arc::new(AggregatorRegistry::new());
    // Fixed: the first argument had been mangled into `®istry`.
    let response = answer(
        &registry,
        None,
        None,
        &RegistryRequest::Unregister {
            group_name: "missing".into(),
        },
    )
    .await;
    match response {
        RegistryResponse::Unregistered { existed } => assert!(!existed),
        other => panic!("expected Unregistered, got {other:?}"),
    }
}
/// `Spawn` without a spawner is answered with `SpawnNotSupported`.
#[tokio::test]
async fn spawn_without_spawner_returns_spawn_not_supported() {
    let registry = Arc::new(AggregatorRegistry::new());
    // Fixed: the first argument had been mangled into `®istry`.
    let response = answer(
        &registry,
        None,
        None,
        &RegistryRequest::Spawn {
            template_name: "primary".into(),
            group_name: "newgrp".into(),
            replica_count: 2,
        },
    )
    .await;
    match response {
        RegistryResponse::Error(RegistryRpcError::SpawnNotSupported) => {}
        other => panic!("expected SpawnNotSupported, got {other:?}"),
    }
}
/// A `Spawn` request served by a real spawner creates and registers the group,
/// and the returned summary describes it.
#[tokio::test]
async fn spawn_with_spawner_round_trips_a_new_group() {
    let registry: Arc<AggregatorRegistry> = Arc::new(AggregatorRegistry::new());
    let registry_for_spawner = registry.clone();
    // The spawner builds its own mesh node and group, then registers the group;
    // `answer` never registers anything on its own.
    let spawner: SpawnFn = Box::new(move |req: SpawnRequest| {
        let registry = registry_for_spawner.clone();
        Box::pin(async move {
            if req.template_name != "primary" {
                return Err(RegistryRpcError::UnknownTemplate(req.template_name));
            }
            let mesh = {
                let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();
                let cfg = crate::adapter::net::MeshNodeConfig::new(addr, [0x17u8; 32]);
                Arc::new(
                    crate::adapter::net::MeshNode::new(
                        crate::adapter::net::identity::EntityKeypair::generate(),
                        cfg,
                    )
                    .await
                    .map_err(|e| RegistryRpcError::SpawnRejected(format!("{e:?}")))?,
                )
            };
            let cfg = crate::adapter::net::behavior::aggregator::AggregatorConfig::new(
                crate::adapter::net::SubnetId::GLOBAL,
            )
            .with_fold_kind(
                crate::adapter::net::behavior::fold::capability::CapabilityFold::KIND_ID,
            )
            .with_interval(std::time::Duration::from_millis(50));
            let cfg_clone = cfg.clone();
            let mesh_clone = mesh.clone();
            let group = crate::adapter::net::behavior::lifecycle::LifecycleGroup::<
                crate::adapter::net::behavior::aggregator::AggregatorDaemon,
            >::spawn(req.replica_count, [0xCDu8; 32], move |_idx| {
                Arc::new(
                    crate::adapter::net::behavior::aggregator::AggregatorDaemon::new(
                        cfg_clone.clone(),
                        mesh_clone.clone(),
                    )
                    .expect("new"),
                )
            })
            .await
            .map_err(|e| RegistryRpcError::SpawnRejected(format!("{e}")))?;
            let entry = registry
                .register(req.group_name.clone(), group)
                .map_err(|e| RegistryRpcError::SpawnRejected(format!("{e}")))?;
            Ok(snapshot_group(&entry).await)
        })
    });
    // Fixed: the first argument had been mangled into `®istry`.
    let response = answer(
        &registry,
        Some(&spawner),
        None,
        &RegistryRequest::Spawn {
            template_name: "primary".into(),
            group_name: "dynamic".into(),
            replica_count: 2,
        },
    )
    .await;
    match response {
        RegistryResponse::Spawned(summary) => {
            assert_eq!(summary.name, "dynamic");
            assert_eq!(summary.replicas.len(), 2);
        }
        other => panic!("expected Spawned, got {other:?}"),
    }
    assert_eq!(registry.len(), 1);
    // Clean up through the RPC path (same `®istry` fix applied here).
    let _ = answer(
        &registry,
        None,
        None,
        &RegistryRequest::Unregister {
            group_name: "dynamic".into(),
        },
    )
    .await;
}
/// An error returned by the spawner reaches the caller unchanged.
#[tokio::test]
async fn spawn_with_unknown_template_surfaces_typed_error() {
    let registry: Arc<AggregatorRegistry> = Arc::new(AggregatorRegistry::new());
    let spawner: SpawnFn = Box::new(|req: SpawnRequest| {
        Box::pin(async move { Err(RegistryRpcError::UnknownTemplate(req.template_name)) })
    });
    // Fixed: the first argument had been mangled into `®istry`.
    let response = answer(
        &registry,
        Some(&spawner),
        None,
        &RegistryRequest::Spawn {
            template_name: "nope".into(),
            group_name: "x".into(),
            replica_count: 1,
        },
    )
    .await;
    match response {
        RegistryResponse::Error(RegistryRpcError::UnknownTemplate(t)) => {
            assert_eq!(t, "nope");
        }
        other => panic!("expected UnknownTemplate, got {other:?}"),
    }
}
/// A `Spawn` for an already-registered name is refused with `DuplicateGroupName`
/// and the spawner is never called.
#[tokio::test]
async fn spawn_rejects_duplicate_group_name_before_invoking_spawner() {
    let registry = Arc::new(AggregatorRegistry::new());
    registry
        .register("existing", spawn_group("existing", 50).await)
        .expect("register existing");
    // Flag flipped by the spawner so the test can prove it was not invoked.
    let invoked = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
    let invoked_clone = invoked.clone();
    let spawner: SpawnFn = Box::new(move |_req: SpawnRequest| {
        invoked_clone.store(true, std::sync::atomic::Ordering::Release);
        Box::pin(async { Err(RegistryRpcError::SpawnRejected("should not run".into())) })
    });
    // Fixed: the first argument had been mangled into `®istry`.
    let response = answer(
        &registry,
        Some(&spawner),
        None,
        &RegistryRequest::Spawn {
            template_name: "anything".into(),
            group_name: "existing".into(),
            replica_count: 1,
        },
    )
    .await;
    match response {
        RegistryResponse::Error(RegistryRpcError::DuplicateGroupName(n)) => {
            assert_eq!(n, "existing");
        }
        other => panic!("expected DuplicateGroupName, got {other:?}"),
    }
    assert!(
        !invoked.load(std::sync::atomic::Ordering::Acquire),
        "spawner must not be invoked on duplicate-name short-circuit"
    );
    let g = registry.unregister("existing").await.expect("unregister");
    g.stop().await;
}
/// Every request variant and every response/error variant survives a postcard
/// encode → decode round trip unchanged.
#[test]
fn registry_request_response_round_trip_through_postcard() {
    let requests = [
        RegistryRequest::List,
        RegistryRequest::Spawn {
            template_name: String::from("primary"),
            group_name: String::from("newgrp"),
            replica_count: 3,
        },
        RegistryRequest::Unregister {
            group_name: String::from("old"),
        },
        RegistryRequest::Scale {
            group_name: String::from("grow"),
            template_name: String::from("primary"),
            target_replica_count: 5,
        },
    ];
    requests.iter().for_each(|req| {
        let bytes = postcard::to_allocvec(req).expect("encode req");
        let decoded: RegistryRequest = postcard::from_bytes(&bytes).expect("decode req");
        assert_eq!(*req, decoded);
    });

    let replica = RegistryReplicaSummary {
        generation: 42,
        healthy: false,
        diagnostic: Some(String::from("stuck")),
        placement_node_id: Some(0xBEEF),
    };
    let group_summary = RegistryGroupSummary {
        name: String::from("test"),
        group_seed: [0xCDu8; 32],
        source_subnet: SubnetId::GLOBAL,
        fold_kinds: vec![0x0001],
        replicas: vec![replica],
    };
    let responses = [
        RegistryResponse::Groups(vec![group_summary.clone()]),
        RegistryResponse::Spawned(group_summary.clone()),
        RegistryResponse::Unregistered { existed: true },
        RegistryResponse::Unregistered { existed: false },
        RegistryResponse::Scaled(group_summary),
        RegistryResponse::Error(RegistryRpcError::DecodeFailed(String::from("bad bytes"))),
        RegistryResponse::Error(RegistryRpcError::UnknownTemplate(String::from("missing"))),
        RegistryResponse::Error(RegistryRpcError::DuplicateGroupName(String::from("dup"))),
        RegistryResponse::Error(RegistryRpcError::SpawnRejected(String::from("oops"))),
        RegistryResponse::Error(RegistryRpcError::SpawnNotSupported),
        RegistryResponse::Error(RegistryRpcError::UnknownGroup(String::from("ghost"))),
        RegistryResponse::Error(RegistryRpcError::ScaleRejected(String::from(
            "template mismatch",
        ))),
        RegistryResponse::Error(RegistryRpcError::ScaleNotSupported),
    ];
    responses.iter().for_each(|resp| {
        let bytes = postcard::to_allocvec(resp).expect("encode resp");
        let decoded: RegistryResponse = postcard::from_bytes(&bytes).expect("decode resp");
        assert_eq!(*resp, decoded);
    });
}
}