use crate::dm::DmCapabilities;
use crate::dm_capability::{
now_unix_ms, CapabilityAdvert, CapabilityStore, ADVERT_PUBLISH_INTERVAL_SECS,
DM_CAPABILITY_TOPIC,
};
use crate::error::{NetworkError, NetworkResult};
use crate::gossip::{PubSubManager, SigningContext};
use crate::identity::{AgentId, MachineId};
use bytes::Bytes;
use std::sync::Arc;
use std::time::Duration;
use tokio::task::JoinHandle;
/// Version stamped into every advert built by [`build_signed_advert`].
/// The subscriber task drops received adverts whose `protocol_version` differs.
pub const ADVERT_PROTOCOL_VERSION: u16 = 1;
/// Delay, in milliseconds, the publisher task sleeps before its first publish.
const FIRST_PUBLISH_DELAY_MS: u64 = 250;
/// Delays, in milliseconds, between successive publishes right after startup.
/// Once every entry has been used the publisher falls back to the configured
/// publish interval; a capability change restarts the schedule from the first entry.
const STARTUP_BURST_INTERVALS_MS: &[u64] = &[5_000, 10_000, 20_000, 45_000];
/// Handle to the two background tasks that publish and receive capability
/// adverts. Dropping the handle aborts both tasks.
pub struct CapabilityAdvertService {
/// Task that builds, signs and publishes this node's capability advert.
publisher: JoinHandle<()>,
/// Task that receives adverts from the capability topic, validates them and
/// records them in the capability store.
subscriber: JoinHandle<()>,
}
impl CapabilityAdvertService {
#[allow(clippy::too_many_arguments)]
pub async fn spawn(
pubsub: Arc<PubSubManager>,
signing: Arc<SigningContext>,
self_agent_id: AgentId,
self_machine_id: MachineId,
caps_rx: tokio::sync::watch::Receiver<DmCapabilities>,
store: Arc<CapabilityStore>,
publish_interval: Duration,
) -> NetworkResult<Self> {
let mut subscription = pubsub.subscribe(DM_CAPABILITY_TOPIC.to_string()).await;
let store_sub = Arc::clone(&store);
let self_agent_for_sub = self_agent_id;
let rebroadcast_pubsub = Arc::clone(&pubsub);
let subscriber = tokio::spawn(async move {
let mut rebroadcast_state: std::collections::HashMap<
([u8; 32], u64),
std::time::Instant,
> = std::collections::HashMap::new();
const REBROADCAST_MIN_INTERVAL: std::time::Duration =
std::time::Duration::from_secs(20);
while let Some(message) = subscription.recv().await {
let raw_payload = message.payload.clone();
let (pubsub_sender, sender_pubkey) =
match (message.sender, message.sender_public_key.as_deref()) {
(Some(s), Some(pk)) if message.verified => (s, pk.to_vec()),
_ => continue,
};
if pubsub_sender == self_agent_for_sub {
continue;
}
let advert: CapabilityAdvert = match postcard::from_bytes(&message.payload) {
Ok(a) => a,
Err(_) => continue,
};
if advert.protocol_version != ADVERT_PROTOCOL_VERSION {
continue;
}
if advert.agent_id != *pubsub_sender.as_bytes() {
continue;
}
if !verify_advert_signature(&advert, &sender_pubkey) {
continue;
}
store_sub.insert(
AgentId(advert.agent_id),
MachineId(advert.machine_id),
advert.capabilities,
);
tracing::debug!(
"cached capability advert from {}",
hex::encode(advert.agent_id)
);
let key = (advert.agent_id, advert.created_at_unix_ms);
let should_forward = match rebroadcast_state.get(&key) {
None => true,
Some(last) => last.elapsed() >= REBROADCAST_MIN_INTERVAL,
};
if should_forward {
rebroadcast_state.insert(key, std::time::Instant::now());
if rebroadcast_state.len() > 1024 {
let cutoff =
std::time::Instant::now() - std::time::Duration::from_secs(3600);
rebroadcast_state.retain(|_, t| *t >= cutoff);
}
let pubsub = Arc::clone(&rebroadcast_pubsub);
tokio::spawn(async move {
if let Err(e) = pubsub
.publish(DM_CAPABILITY_TOPIC.to_string(), raw_payload)
.await
{
tracing::debug!("capability advert re-broadcast failed: {e}");
}
});
}
}
tracing::debug!("capability advert subscriber exited");
});
let publisher_pubsub = Arc::clone(&pubsub);
let publisher_signing = Arc::clone(&signing);
let mut publisher_caps_rx = caps_rx;
let publisher = tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(FIRST_PUBLISH_DELAY_MS)).await;
let mut burst_idx: usize = 0;
loop {
let caps_snapshot = publisher_caps_rx.borrow().clone();
match build_signed_advert(
&publisher_signing,
self_agent_id,
self_machine_id,
caps_snapshot,
) {
Ok(bytes) => {
if let Err(e) = publisher_pubsub
.publish(DM_CAPABILITY_TOPIC.to_string(), Bytes::from(bytes))
.await
{
tracing::warn!("capability advert publish failed: {e}");
} else {
tracing::debug!("capability advert published");
}
}
Err(e) => tracing::warn!("capability advert build failed: {e}"),
}
let next_delay = if burst_idx < STARTUP_BURST_INTERVALS_MS.len() {
let d = Duration::from_millis(STARTUP_BURST_INTERVALS_MS[burst_idx]);
burst_idx += 1;
d
} else {
publish_interval
};
tokio::select! {
_ = tokio::time::sleep(next_delay) => {}
res = publisher_caps_rx.changed() => {
if res.is_ok() {
tracing::debug!("capability advert upgraded; republishing");
burst_idx = 0;
}
}
}
}
});
Ok(Self {
publisher,
subscriber,
})
}
pub async fn spawn_default(
pubsub: Arc<PubSubManager>,
signing: Arc<SigningContext>,
self_agent_id: AgentId,
self_machine_id: MachineId,
caps_rx: tokio::sync::watch::Receiver<DmCapabilities>,
store: Arc<CapabilityStore>,
) -> NetworkResult<Self> {
Self::spawn(
pubsub,
signing,
self_agent_id,
self_machine_id,
caps_rx,
store,
Duration::from_secs(ADVERT_PUBLISH_INTERVAL_SECS),
)
.await
}
pub fn abort(&self) {
self.publisher.abort();
self.subscriber.abort();
}
}
impl Drop for CapabilityAdvertService {
    /// Stops both background tasks when the service handle is dropped, so they
    /// never outlive their owner.
    fn drop(&mut self) {
        // Same effect as calling `abort()`.
        self.publisher.abort();
        self.subscriber.abort();
    }
}
/// Builds this node's capability advert, signs it and returns the
/// postcard-encoded bytes ready to publish.
///
/// The advert carries `ADVERT_PROTOCOL_VERSION`, the raw bytes of
/// `self_agent_id` and `self_machine_id`, the current time from
/// `now_unix_ms()` and the given `capabilities`.
///
/// # Errors
///
/// Returns `NetworkError::SerializationError` when the bytes to sign or the
/// final encoding cannot be produced, and propagates any error from
/// `signing.sign`.
pub fn build_signed_advert(
    signing: &SigningContext,
    self_agent_id: AgentId,
    self_machine_id: MachineId,
    capabilities: DmCapabilities,
) -> NetworkResult<Vec<u8>> {
    let agent_id = *self_agent_id.as_bytes();
    let machine_id = *self_machine_id.as_bytes();

    // The signature field is still empty when `signed_bytes()` is computed;
    // it is filled in afterwards.
    let mut record = CapabilityAdvert {
        protocol_version: ADVERT_PROTOCOL_VERSION,
        agent_id,
        machine_id,
        created_at_unix_ms: now_unix_ms(),
        capabilities,
        signature: Vec::new(),
    };

    let to_sign = match record.signed_bytes() {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(NetworkError::SerializationError(format!(
                "advert sign-bytes: {e}"
            )))
        }
    };
    record.signature = signing.sign(&to_sign)?;

    match postcard::to_stdvec(&record) {
        Ok(wire) => Ok(wire),
        Err(e) => Err(NetworkError::SerializationError(format!(
            "advert encode: {e}"
        ))),
    }
}
/// Checks that `advert` was signed by the holder of `public_key_bytes`.
///
/// Returns `true` only when all of the following hold:
/// 1. the advert's signed bytes can be produced,
/// 2. `public_key_bytes` parses as an ML-DSA public key,
/// 3. the agent id derived from that key equals `advert.agent_id`,
/// 4. `advert.signature` parses as an ML-DSA signature, and
/// 5. the signature verifies over the signed bytes.
///
/// Any failure along the way yields `false`; no error detail is reported.
pub fn verify_advert_signature(advert: &CapabilityAdvert, public_key_bytes: &[u8]) -> bool {
    use ant_quic::crypto::raw_public_keys::pqc::{verify_with_ml_dsa, MlDsaSignature};

    // Each step bails out with `None` on failure; the order matches the list above.
    let verified = || -> Option<()> {
        let message = advert.signed_bytes().ok()?;
        let key = ant_quic::MlDsaPublicKey::from_bytes(public_key_bytes).ok()?;

        // The key must belong to the agent the advert claims to describe.
        let key_owner = crate::identity::AgentId::from_public_key(&key);
        if key_owner.0 != advert.agent_id {
            return None;
        }

        let sig = MlDsaSignature::from_bytes(&advert.signature).ok()?;
        verify_with_ml_dsa(&key, &message, &sig).ok().map(|_| ())
    };

    verified().is_some()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::identity::AgentKeypair;

    /// Builds an advert signed by `signing` for the given identity and decodes
    /// it back from its wire form.
    fn signed_advert(
        signing: &SigningContext,
        agent_id: AgentId,
        machine_id: MachineId,
    ) -> CapabilityAdvert {
        let wire = build_signed_advert(
            signing,
            agent_id,
            machine_id,
            DmCapabilities::v1_gossip_ready(vec![0u8; 1184]),
        )
        .expect("build");
        postcard::from_bytes(&wire).expect("decode")
    }

    /// A freshly built advert verifies against the signer's own public key.
    #[test]
    fn build_and_verify_advert_roundtrip() {
        let keypair = AgentKeypair::generate().expect("keygen");
        let signing = SigningContext::from_keypair(&keypair);

        let advert = signed_advert(&signing, keypair.agent_id(), MachineId([9u8; 32]));

        assert!(verify_advert_signature(&advert, &signing.public_key_bytes));
    }

    /// Flipping a single bit of the signature makes verification fail.
    #[test]
    fn verify_advert_rejects_tampered_signature() {
        let keypair = AgentKeypair::generate().expect("keygen");
        let signing = SigningContext::from_keypair(&keypair);

        let mut advert = signed_advert(&signing, keypair.agent_id(), MachineId([0u8; 32]));
        advert.signature[0] ^= 0x01;

        assert!(!verify_advert_signature(&advert, &signing.public_key_bytes));
    }
}