1use std::collections::HashMap;
4use std::fmt;
5use std::io::ErrorKind;
6use std::net::{SocketAddr, ToSocketAddrs};
7use std::num::NonZeroUsize;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::time::Duration;
11
12use agent_primitives::AgentManifest;
13use async_trait::async_trait;
14use mxp::protocol::Flags;
15use mxp::transport::{SocketError, Transport, TransportConfig, TransportHandle};
16use mxp::{Message, MessageType};
17use thiserror::Error;
18use tokio::task::JoinHandle;
19use tokio::time::{MissedTickBehavior, sleep};
20use tracing::{debug, info, warn};
21
22use crate::registry_wire::{
23 ErrorResponse, HeartbeatRequest, HeartbeatResponse, RegisterRequest, RegisterResponse,
24};
25use crate::{AgentState, SchedulerError, TaskScheduler};
26
/// Timing and failure-tolerance settings for registering an agent with the
/// mesh registry and keeping that registration alive with heartbeats.
///
/// Build one with `RegistrationConfig::new` or `Default::default`, and check it
/// with `RegistrationConfig::validate` before use.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RegistrationConfig {
    /// Period between heartbeats once the agent is registered.
    heartbeat_interval: Duration,
    /// Delay before the first registration retry; doubled after each failure.
    initial_retry_delay: Duration,
    /// Upper bound applied to the doubling registration retry delay.
    max_retry_delay: Duration,
    /// Consecutive heartbeat failures that trigger a re-registration attempt.
    max_consecutive_failures: NonZeroUsize,
}
35
36impl RegistrationConfig {
37 #[must_use]
39 pub fn new(
40 heartbeat_interval: Duration,
41 initial_retry_delay: Duration,
42 max_retry_delay: Duration,
43 max_consecutive_failures: NonZeroUsize,
44 ) -> Self {
45 Self {
46 heartbeat_interval,
47 initial_retry_delay,
48 max_retry_delay,
49 max_consecutive_failures,
50 }
51 }
52
53 #[must_use]
55 pub const fn heartbeat_interval(self) -> Duration {
56 self.heartbeat_interval
57 }
58
59 #[must_use]
61 pub const fn initial_retry_delay(self) -> Duration {
62 self.initial_retry_delay
63 }
64
65 #[must_use]
67 pub const fn max_retry_delay(self) -> Duration {
68 self.max_retry_delay
69 }
70
71 #[must_use]
73 pub const fn max_consecutive_failures(self) -> NonZeroUsize {
74 self.max_consecutive_failures
75 }
76
77 pub fn validate(self) -> RegistryResult<()> {
84 if self.heartbeat_interval.is_zero() {
85 return Err(RegistryError::InvalidConfig(
86 "heartbeat interval must be greater than zero",
87 ));
88 }
89 if self.initial_retry_delay.is_zero() {
90 return Err(RegistryError::InvalidConfig(
91 "initial retry delay must be greater than zero",
92 ));
93 }
94 if self.max_retry_delay.is_zero() {
95 return Err(RegistryError::InvalidConfig(
96 "max retry delay must be greater than zero",
97 ));
98 }
99 if self.initial_retry_delay > self.max_retry_delay {
100 return Err(RegistryError::InvalidConfig(
101 "initial retry delay cannot exceed max retry delay",
102 ));
103 }
104 Ok(())
105 }
106}
107
108impl Default for RegistrationConfig {
109 fn default() -> Self {
110 Self {
111 heartbeat_interval: Duration::from_secs(10),
112 initial_retry_delay: Duration::from_secs(1),
113 max_retry_delay: Duration::from_secs(30),
114 max_consecutive_failures: NonZeroUsize::new(3).expect("non-zero"),
115 }
116 }
117}
118
119pub type RegistryResult<T> = Result<T, RegistryError>;
121
122#[derive(Debug, Error)]
124pub enum RegistryError {
125 #[error("invalid registration configuration: {0}")]
127 InvalidConfig(&'static str),
128 #[error(transparent)]
130 Scheduler(#[from] SchedulerError),
131 #[error("registry backend error: {reason}")]
133 Backend {
134 reason: String,
136 },
137}
138
139impl RegistryError {
140 #[must_use]
142 pub fn backend(reason: impl Into<String>) -> Self {
143 Self::Backend {
144 reason: reason.into(),
145 }
146 }
147}
148
149#[derive(Debug)]
151pub struct MxpRegistryClient {
152 handle: TransportHandle,
153 registry_addr: SocketAddr,
154 agent_endpoint: SocketAddr,
155}
156
157impl MxpRegistryClient {
158 pub fn connect(
164 registry_addr: impl ToSocketAddrs,
165 agent_endpoint: SocketAddr,
166 transport_config: Option<TransportConfig>,
167 ) -> RegistryResult<Self> {
168 let registry_addr = registry_addr
169 .to_socket_addrs()
170 .map_err(|err| {
171 RegistryError::backend(format!("failed to resolve registry endpoint: {err:?}"))
172 })?
173 .next()
174 .ok_or_else(|| RegistryError::backend("registry endpoint resolved to no address"))?;
175
176 let config = transport_config.unwrap_or_else(default_transport_config);
177 let transport = Transport::new(config);
178 let local_bind: SocketAddr = "0.0.0.0:0".parse().map_err(|err| {
179 RegistryError::backend(format!("invalid bind address configuration: {err:?}"))
180 })?;
181 let handle = transport
182 .bind(local_bind)
183 .map_err(|err| RegistryError::backend(format!("transport bind failed: {err:?}")))?;
184
185 Ok(Self {
186 handle,
187 registry_addr,
188 agent_endpoint,
189 })
190 }
191
192 fn agent_id(manifest: &AgentManifest) -> String {
193 manifest.id().to_string()
194 }
195
196 fn manifest_to_register_request(&self, manifest: &AgentManifest) -> RegisterRequest {
197 let capabilities = manifest
198 .capabilities()
199 .iter()
200 .map(|cap| cap.id().as_str().to_string())
201 .collect::<Vec<_>>();
202
203 let mut metadata = HashMap::new();
204 metadata.insert("version".to_string(), manifest.version().to_string());
205 if let Some(description) = manifest.description() {
206 metadata.insert("description".to_string(), description.to_string());
207 }
208 if !manifest.tags().is_empty() {
209 metadata.insert(
210 "tags".to_string(),
211 serde_json::to_string(manifest.tags()).unwrap_or_default(),
212 );
213 }
214
215 RegisterRequest {
216 id: Self::agent_id(manifest),
217 name: manifest.name().to_string(),
218 capabilities,
219 address: self.agent_endpoint,
220 metadata,
221 }
222 }
223
224 fn send_request_blocking(
225 handle: &TransportHandle,
226 registry_addr: SocketAddr,
227 message: &Message,
228 ) -> RegistryResult<Message> {
229 let encoded = message.encode();
230 let message_id = message.message_id();
231
232 handle
233 .send(&encoded, registry_addr)
234 .map_err(|err| RegistryError::backend(format!("send failed: {err:?}")))?;
235
236 let mut buffer = handle.acquire_buffer();
237 let response = loop {
238 match handle.receive(&mut buffer) {
239 Ok((_len, _addr)) => {
240 let payload = buffer.as_slice().to_vec();
241 match Message::decode(payload) {
242 Ok(response) => {
243 if response.message_id() == message_id {
244 break response;
245 }
246 }
247 Err(err) => {
248 return Err(RegistryError::backend(format!(
249 "failed to decode registry response: {err:?}"
250 )));
251 }
252 }
253 }
254 Err(SocketError::Io(err))
255 if matches!(
256 err.kind(),
257 ErrorKind::WouldBlock | ErrorKind::TimedOut | ErrorKind::Interrupted
258 ) =>
259 {
260 if err.kind() == ErrorKind::Interrupted {
261 debug!("registry receive interrupted; retrying");
262 continue;
263 }
264 return Err(RegistryError::backend(
265 "timed out waiting for registry response",
266 ));
267 }
268 Err(SocketError::Io(err)) => {
269 return Err(RegistryError::backend(format!(
270 "registry receive failed: {err:?}"
271 )));
272 }
273 }
274 };
275
276 Ok(response)
277 }
278
279 async fn send_request(&self, message: Message) -> RegistryResult<Message> {
280 let handle = self.handle.clone();
281 let registry_addr = self.registry_addr;
282 tokio::task::spawn_blocking(move || {
283 Self::send_request_blocking(&handle, registry_addr, &message)
284 })
285 .await
286 .map_err(|err| RegistryError::backend(format!("registry task join error: {err:?}")))?
287 }
288
289 fn handle_error_message(message: &Message) -> RegistryResult<()> {
290 let payload =
291 serde_json::from_slice::<ErrorResponse>(message.payload()).map_err(|err| {
292 RegistryError::backend(format!("failed to parse registry error payload: {err:?}"))
293 })?;
294 Err(RegistryError::backend(payload.error))
295 }
296}
297
298fn default_transport_config() -> TransportConfig {
299 TransportConfig {
300 buffer_size: 16 * 1024,
301 max_buffers: 256,
302 read_timeout: Some(Duration::from_secs(5)),
303 write_timeout: Some(Duration::from_secs(5)),
304 #[cfg(feature = "debug-tools")]
305 pcap_send_path: None,
306 #[cfg(feature = "debug-tools")]
307 pcap_recv_path: None,
308 }
309}
310
311#[async_trait]
313pub trait AgentRegistry: Send + Sync {
314 async fn register(&self, manifest: &AgentManifest) -> RegistryResult<()>;
316
317 async fn heartbeat(&self, manifest: &AgentManifest) -> RegistryResult<()>;
319
320 async fn deregister(&self, manifest: &AgentManifest) -> RegistryResult<()>;
322}
323
324pub(crate) struct RegistrationController {
325 registry: Arc<dyn AgentRegistry>,
326 manifest: Arc<AgentManifest>,
327 config: RegistrationConfig,
328 shutdown: Arc<AtomicBool>,
329 worker: Option<JoinHandle<()>>,
330}
331
332impl fmt::Debug for RegistrationController {
333 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
334 f.debug_struct("RegistrationController")
335 .field("registry", &"dyn AgentRegistry")
336 .field("manifest", &self.manifest.id())
337 .field("config", &self.config)
338 .field("shutdown", &self.shutdown.load(Ordering::Relaxed))
339 .field("worker", &self.worker.is_some())
340 .finish()
341 }
342}
343
344impl RegistrationController {
345 pub(crate) fn new(
346 registry: Arc<dyn AgentRegistry>,
347 manifest: AgentManifest,
348 config: RegistrationConfig,
349 ) -> Self {
350 Self {
351 registry,
352 manifest: Arc::new(manifest),
353 config,
354 shutdown: Arc::new(AtomicBool::new(false)),
355 worker: None,
356 }
357 }
358
359 pub(crate) fn on_state_change(
360 &mut self,
361 state: AgentState,
362 scheduler: &TaskScheduler,
363 ) -> RegistryResult<()> {
364 match state {
365 AgentState::Ready | AgentState::Active => {
366 self.ensure_worker(scheduler)?;
367 }
368 AgentState::Retiring | AgentState::Terminated => {
369 self.shutdown.store(true, Ordering::Release);
370 self.spawn_deregister(scheduler)?;
371 if let Some(handle) = self.worker.take() {
372 handle.abort();
373 }
374 }
375 _ => {}
376 }
377
378 Ok(())
379 }
380
381 fn ensure_worker(&mut self, scheduler: &TaskScheduler) -> RegistryResult<()> {
382 if self.worker.is_some() {
383 return Ok(());
384 }
385
386 self.config.validate()?;
387
388 let registry = Arc::clone(&self.registry);
389 let manifest = Arc::clone(&self.manifest);
390 let shutdown = Arc::clone(&self.shutdown);
391 let config = self.config;
392
393 let handle = scheduler.spawn(async move {
394 run_registration_loop(registry, manifest, shutdown, config).await;
395 })?;
396
397 self.worker = Some(handle);
398 Ok(())
399 }
400
401 fn spawn_deregister(&self, scheduler: &TaskScheduler) -> RegistryResult<()> {
402 let registry = Arc::clone(&self.registry);
403 let manifest = Arc::clone(&self.manifest);
404 scheduler.spawn(async move {
405 if let Err(err) = registry.deregister(&manifest).await {
406 warn!(?err, "agent deregistration failed");
407 } else {
408 info!(agent_id = %manifest.id(), "agent deregistered");
409 }
410 })?;
411 Ok(())
412 }
413}
414
415async fn run_registration_loop(
416 registry: Arc<dyn AgentRegistry>,
417 manifest: Arc<AgentManifest>,
418 shutdown: Arc<AtomicBool>,
419 config: RegistrationConfig,
420) {
421 let mut retry_delay = config.initial_retry_delay();
422
423 loop {
424 if shutdown.load(Ordering::Acquire) {
425 break;
426 }
427
428 match registry.register(&manifest).await {
429 Ok(()) => {
430 info!(agent_id = %manifest.id(), "agent registered with mesh");
431 retry_delay = config.initial_retry_delay();
432 if !run_heartbeat_loop(
433 Arc::clone(®istry),
434 Arc::clone(&manifest),
435 Arc::clone(&shutdown),
436 config,
437 )
438 .await
439 {
440 continue;
441 }
442 break;
443 }
444 Err(err) => {
445 warn!(?err, "agent registration failed; retrying");
446 sleep(retry_delay).await;
447 retry_delay = (retry_delay * 2).min(config.max_retry_delay());
448 }
449 }
450 }
451}
452
453async fn run_heartbeat_loop(
454 registry: Arc<dyn AgentRegistry>,
455 manifest: Arc<AgentManifest>,
456 shutdown: Arc<AtomicBool>,
457 config: RegistrationConfig,
458) -> bool {
459 let mut failures: usize = 0;
460 let mut interval = tokio::time::interval(config.heartbeat_interval());
461 interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
462
463 while !shutdown.load(Ordering::Acquire) {
464 interval.tick().await;
465 if shutdown.load(Ordering::Acquire) {
466 break;
467 }
468
469 match registry.heartbeat(&manifest).await {
470 Ok(()) => {
471 failures = 0;
472 }
473 Err(err) => {
474 failures += 1;
475 warn!(?err, failures, "heartbeat failure");
476 if failures >= config.max_consecutive_failures().get() {
477 warn!(
478 failures,
479 "heartbeat failure threshold reached; attempting re-registration"
480 );
481 return false;
482 }
483 }
484 }
485 }
486
487 true
488}
489
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    use agent_primitives::{AgentId, Capability, CapabilityId};

    /// In-memory registry that only counts how often each operation is called.
    struct MockRegistry {
        registers: Arc<AtomicUsize>,
        heartbeats: Arc<AtomicUsize>,
        deregistrations: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl AgentRegistry for MockRegistry {
        async fn register(&self, _manifest: &AgentManifest) -> RegistryResult<()> {
            self.registers.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn heartbeat(&self, _manifest: &AgentManifest) -> RegistryResult<()> {
            self.heartbeats.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn deregister(&self, _manifest: &AgentManifest) -> RegistryResult<()> {
            self.deregistrations.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Builds a minimal valid manifest with a single capability.
    fn manifest() -> AgentManifest {
        let capability = Capability::builder(CapabilityId::new("mock.cap").unwrap())
            .name("Mock")
            .unwrap()
            .version("1.0.0")
            .unwrap()
            .add_scope("read:mock")
            .unwrap()
            .build()
            .unwrap();

        AgentManifest::builder(AgentId::random())
            .name("mock-agent")
            .unwrap()
            .version("0.1.0")
            .unwrap()
            .capabilities(vec![capability])
            .build()
            .unwrap()
    }

    #[test]
    fn default_config_is_valid() {
        assert!(RegistrationConfig::default().validate().is_ok());
    }

    #[test]
    fn validate_rejects_zero_and_inverted_delays() {
        let one = NonZeroUsize::new(1).unwrap();

        let zero_heartbeat = RegistrationConfig::new(
            Duration::ZERO,
            Duration::from_millis(1),
            Duration::from_millis(2),
            one,
        );
        assert!(matches!(
            zero_heartbeat.validate(),
            Err(RegistryError::InvalidConfig(_))
        ));

        let zero_initial = RegistrationConfig::new(
            Duration::from_millis(1),
            Duration::ZERO,
            Duration::from_millis(2),
            one,
        );
        assert!(matches!(
            zero_initial.validate(),
            Err(RegistryError::InvalidConfig(_))
        ));

        let inverted = RegistrationConfig::new(
            Duration::from_millis(10),
            Duration::from_millis(50),
            Duration::from_millis(20),
            one,
        );
        assert!(matches!(
            inverted.validate(),
            Err(RegistryError::InvalidConfig(_))
        ));
    }

    #[tokio::test]
    async fn lifecycle_starts_and_stops_heartbeat() {
        let registry = Arc::new(MockRegistry {
            registers: Arc::new(AtomicUsize::new(0)),
            heartbeats: Arc::new(AtomicUsize::new(0)),
            deregistrations: Arc::new(AtomicUsize::new(0)),
        });

        let manifest = manifest();
        let config = RegistrationConfig::new(
            Duration::from_millis(10),
            Duration::from_millis(5),
            Duration::from_millis(20),
            NonZeroUsize::new(3).unwrap(),
        );

        let mut controller = RegistrationController::new(registry.clone(), manifest, config);
        let scheduler = TaskScheduler::default();

        controller
            .on_state_change(AgentState::Ready, &scheduler)
            .unwrap();

        tokio::time::sleep(Duration::from_millis(40)).await;

        assert!(registry.registers.load(Ordering::SeqCst) >= 1);
        assert!(registry.heartbeats.load(Ordering::SeqCst) >= 1);

        controller
            .on_state_change(AgentState::Retiring, &scheduler)
            .unwrap();
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert!(registry.deregistrations.load(Ordering::SeqCst) >= 1);

        // The test name promises heartbeats stop: after retiring, the counter
        // must not advance across several heartbeat intervals.
        let heartbeats_after_retire = registry.heartbeats.load(Ordering::SeqCst);
        tokio::time::sleep(Duration::from_millis(40)).await;
        assert_eq!(
            registry.heartbeats.load(Ordering::SeqCst),
            heartbeats_after_retire
        );
    }
}
577
578#[async_trait]
579impl AgentRegistry for MxpRegistryClient {
580 async fn register(&self, manifest: &AgentManifest) -> RegistryResult<()> {
581 let request = self.manifest_to_register_request(manifest);
582 let payload = serde_json::to_vec(&request)
583 .map_err(|err| RegistryError::backend(format!("encode register payload: {err:?}")))?;
584 let message = Message::new(MessageType::AgentRegister, payload);
585 let response = self.send_request(message).await?;
586
587 match response.message_type() {
588 Some(MessageType::Response) => {
589 let ack = serde_json::from_slice::<RegisterResponse>(response.payload()).map_err(
590 |err| {
591 RegistryError::backend(format!("parse register response failed: {err:?}"))
592 },
593 )?;
594 if ack.success {
595 debug!(agent_id = ack.agent_id, "registry registration acked");
596 Ok(())
597 } else {
598 Err(RegistryError::backend(format!(
599 "registry rejected registration: {}",
600 ack.message
601 )))
602 }
603 }
604 Some(MessageType::Error) => Self::handle_error_message(&response),
605 other => Err(RegistryError::backend(format!(
606 "unexpected message type {other:?} for register response"
607 ))),
608 }
609 }
610
611 async fn heartbeat(&self, manifest: &AgentManifest) -> RegistryResult<()> {
612 let request = HeartbeatRequest {
613 agent_id: Self::agent_id(manifest),
614 };
615 let payload = serde_json::to_vec(&request)
616 .map_err(|err| RegistryError::backend(format!("encode heartbeat payload: {err:?}")))?;
617 let message = Message::new(MessageType::AgentHeartbeat, payload);
618 let response = self.send_request(message).await?;
619
620 match response.message_type() {
621 Some(MessageType::Response) => {
622 let ack = serde_json::from_slice::<HeartbeatResponse>(response.payload()).map_err(
623 |err| {
624 RegistryError::backend(format!("parse heartbeat response failed: {err:?}"))
625 },
626 )?;
627 if ack.success && !ack.needs_register {
628 Ok(())
629 } else if ack.needs_register {
630 Err(RegistryError::backend("registry requested re-registration"))
631 } else {
632 Err(RegistryError::backend(
633 ack.message
634 .unwrap_or_else(|| "heartbeat rejected".to_string()),
635 ))
636 }
637 }
638 Some(MessageType::Error) => Self::handle_error_message(&response),
639 other => Err(RegistryError::backend(format!(
640 "unexpected message type {other:?} for heartbeat response"
641 ))),
642 }
643 }
644
645 async fn deregister(&self, manifest: &AgentManifest) -> RegistryResult<()> {
646 let request = HeartbeatRequest {
647 agent_id: Self::agent_id(manifest),
648 };
649 let payload = serde_json::to_vec(&request)
650 .map_err(|err| RegistryError::backend(format!("encode deregister payload: {err:?}")))?;
651 let mut message = Message::new(MessageType::AgentHeartbeat, payload);
652 message.set_flags(message.flags().with(Flags::FINAL));
653 let response = self.send_request(message).await?;
654
655 match response.message_type() {
656 Some(MessageType::Response) => {
657 let ack = serde_json::from_slice::<HeartbeatResponse>(response.payload()).map_err(
658 |err| {
659 RegistryError::backend(format!("parse deregister response failed: {err:?}"))
660 },
661 )?;
662 if ack.success {
663 Ok(())
664 } else {
665 Err(RegistryError::backend(
666 ack.message
667 .unwrap_or_else(|| "deregister failed".to_string()),
668 ))
669 }
670 }
671 Some(MessageType::Error) => Self::handle_error_message(&response),
672 other => Err(RegistryError::backend(format!(
673 "unexpected message type {other:?} for deregister response"
674 ))),
675 }
676 }
677}