1use std::collections::HashMap;
4use std::fmt;
5use std::io::ErrorKind;
6use std::net::{SocketAddr, ToSocketAddrs};
7use std::num::NonZeroUsize;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::time::Duration;
11
12use agent_primitives::AgentManifest;
13use async_trait::async_trait;
14use mxp::protocol::Flags;
15use mxp::transport::{SocketError, Transport, TransportConfig, TransportHandle};
16use mxp::{Message, MessageType};
17use thiserror::Error;
18use tokio::task::JoinHandle;
19use tokio::time::{MissedTickBehavior, sleep};
20use tracing::{debug, info, warn};
21
22use crate::registry_wire::{
23 ErrorResponse, HeartbeatRequest, HeartbeatResponse, RegisterRequest, RegisterResponse,
24};
25use crate::{AgentState, SchedulerError, TaskScheduler};
26
/// Timing and failure-tolerance settings for the background registration worker.
///
/// Build one with [`RegistrationConfig::new`] or [`Default`];
/// [`RegistrationConfig::validate`] reports settings the worker cannot run with.
#[derive(Debug, Clone, Copy)]
pub struct RegistrationConfig {
    /// Period between heartbeats once the agent is registered.
    heartbeat_interval: Duration,
    /// Delay before the first registration retry; doubled after every failed attempt.
    initial_retry_delay: Duration,
    /// Upper bound for the exponential registration backoff.
    max_retry_delay: Duration,
    /// Number of consecutive heartbeat failures that triggers re-registration.
    max_consecutive_failures: NonZeroUsize,
}
35
36impl RegistrationConfig {
37 #[must_use]
39 pub fn new(
40 heartbeat_interval: Duration,
41 initial_retry_delay: Duration,
42 max_retry_delay: Duration,
43 max_consecutive_failures: NonZeroUsize,
44 ) -> Self {
45 Self {
46 heartbeat_interval,
47 initial_retry_delay,
48 max_retry_delay,
49 max_consecutive_failures,
50 }
51 }
52
53 #[must_use]
55 pub const fn heartbeat_interval(self) -> Duration {
56 self.heartbeat_interval
57 }
58
59 #[must_use]
61 pub const fn initial_retry_delay(self) -> Duration {
62 self.initial_retry_delay
63 }
64
65 #[must_use]
67 pub const fn max_retry_delay(self) -> Duration {
68 self.max_retry_delay
69 }
70
71 #[must_use]
73 pub const fn max_consecutive_failures(self) -> NonZeroUsize {
74 self.max_consecutive_failures
75 }
76
77 pub fn validate(self) -> RegistryResult<()> {
84 if self.heartbeat_interval.is_zero() {
85 return Err(RegistryError::InvalidConfig(
86 "heartbeat interval must be greater than zero",
87 ));
88 }
89 if self.initial_retry_delay.is_zero() {
90 return Err(RegistryError::InvalidConfig(
91 "initial retry delay must be greater than zero",
92 ));
93 }
94 if self.max_retry_delay.is_zero() {
95 return Err(RegistryError::InvalidConfig(
96 "max retry delay must be greater than zero",
97 ));
98 }
99 if self.initial_retry_delay > self.max_retry_delay {
100 return Err(RegistryError::InvalidConfig(
101 "initial retry delay cannot exceed max retry delay",
102 ));
103 }
104 Ok(())
105 }
106}
107
108impl Default for RegistrationConfig {
109 fn default() -> Self {
110 Self {
111 heartbeat_interval: Duration::from_secs(10),
112 initial_retry_delay: Duration::from_secs(1),
113 max_retry_delay: Duration::from_secs(30),
114 max_consecutive_failures: NonZeroUsize::new(3).expect("non-zero"),
115 }
116 }
117}
118
119pub type RegistryResult<T> = Result<T, RegistryError>;
121
122#[derive(Debug, Error)]
124pub enum RegistryError {
125 #[error("invalid registration configuration: {0}")]
127 InvalidConfig(&'static str),
128 #[error(transparent)]
130 Scheduler(#[from] SchedulerError),
131 #[error("registry backend error: {reason}")]
133 Backend {
134 reason: String,
136 },
137}
138
139impl RegistryError {
140 #[must_use]
142 pub fn backend(reason: impl Into<String>) -> Self {
143 Self::Backend {
144 reason: reason.into(),
145 }
146 }
147}
148
149#[derive(Debug)]
151pub struct MxpRegistryClient {
152 handle: TransportHandle,
153 registry_addr: SocketAddr,
154 agent_endpoint: SocketAddr,
155}
156
157impl MxpRegistryClient {
158 pub fn connect(
164 registry_addr: impl ToSocketAddrs,
165 agent_endpoint: SocketAddr,
166 transport_config: Option<TransportConfig>,
167 ) -> RegistryResult<Self> {
168 let registry_addr = registry_addr
169 .to_socket_addrs()
170 .map_err(|err| {
171 RegistryError::backend(format!("failed to resolve registry endpoint: {err:?}"))
172 })?
173 .next()
174 .ok_or_else(|| RegistryError::backend("registry endpoint resolved to no address"))?;
175
176 let config = transport_config.unwrap_or_else(default_transport_config);
177 let transport = Transport::new(config);
178 let local_bind: SocketAddr = "0.0.0.0:0".parse().map_err(|err| {
179 RegistryError::backend(format!("invalid bind address configuration: {err:?}"))
180 })?;
181 let handle = transport
182 .bind(local_bind)
183 .map_err(|err| RegistryError::backend(format!("transport bind failed: {err:?}")))?;
184
185 Ok(Self {
186 handle,
187 registry_addr,
188 agent_endpoint,
189 })
190 }
191
192 fn agent_id(manifest: &AgentManifest) -> String {
193 manifest.id().to_string()
194 }
195
196 fn manifest_to_register_request(&self, manifest: &AgentManifest) -> RegisterRequest {
197 let capabilities = manifest
198 .capabilities()
199 .iter()
200 .map(|cap| cap.id().as_str().to_string())
201 .collect::<Vec<_>>();
202
203 let mut metadata = HashMap::new();
204 metadata.insert("version".to_string(), manifest.version().to_string());
205 if let Some(description) = manifest.description() {
206 metadata.insert("description".to_string(), description.to_string());
207 }
208 if !manifest.tags().is_empty() {
209 metadata.insert(
210 "tags".to_string(),
211 serde_json::to_string(manifest.tags()).unwrap_or_default(),
212 );
213 }
214
215 RegisterRequest {
216 id: Self::agent_id(manifest),
217 name: manifest.name().to_string(),
218 capabilities,
219 address: self.agent_endpoint,
220 metadata,
221 }
222 }
223
224 fn send_request_blocking(
225 handle: &TransportHandle,
226 registry_addr: SocketAddr,
227 message: &Message,
228 ) -> RegistryResult<Message> {
229 let encoded = message.encode();
230 let message_id = message.message_id();
231
232 handle
233 .send(&encoded, registry_addr)
234 .map_err(|err| RegistryError::backend(format!("send failed: {err:?}")))?;
235
236 let mut buffer = handle.acquire_buffer();
237 let response = loop {
238 match handle.receive(&mut buffer) {
239 Ok((_len, _addr)) => {
240 let payload = buffer.as_slice().to_vec();
241 match Message::decode(payload) {
242 Ok(response) => {
243 if response.message_id() == message_id {
244 break response;
245 }
246 }
247 Err(err) => {
248 return Err(RegistryError::backend(format!(
249 "failed to decode registry response: {err:?}"
250 )));
251 }
252 }
253 }
254 Err(SocketError::Io(err))
255 if matches!(
256 err.kind(),
257 ErrorKind::WouldBlock | ErrorKind::TimedOut | ErrorKind::Interrupted
258 ) =>
259 {
260 if err.kind() == ErrorKind::Interrupted {
261 debug!("registry receive interrupted; retrying");
262 continue;
263 }
264 return Err(RegistryError::backend(
265 "timed out waiting for registry response",
266 ));
267 }
268 Err(SocketError::Io(err)) => {
269 return Err(RegistryError::backend(format!(
270 "registry receive failed: {err:?}"
271 )));
272 }
273 }
274 };
275
276 Ok(response)
277 }
278
279 async fn send_request(&self, message: Message) -> RegistryResult<Message> {
280 let handle = self.handle.clone();
281 let registry_addr = self.registry_addr;
282 tokio::task::spawn_blocking(move || {
283 Self::send_request_blocking(&handle, registry_addr, &message)
284 })
285 .await
286 .map_err(|err| RegistryError::backend(format!("registry task join error: {err:?}")))?
287 }
288
289 fn handle_error_message(message: &Message) -> RegistryResult<()> {
290 let payload =
291 serde_json::from_slice::<ErrorResponse>(message.payload()).map_err(|err| {
292 RegistryError::backend(format!("failed to parse registry error payload: {err:?}"))
293 })?;
294 Err(RegistryError::backend(payload.error))
295 }
296}
297
298fn default_transport_config() -> TransportConfig {
299 TransportConfig {
300 buffer_size: 16 * 1024,
301 max_buffers: 256,
302 read_timeout: Some(Duration::from_secs(5)),
303 write_timeout: Some(Duration::from_secs(5)),
304 }
305}
306
307#[async_trait]
309pub trait AgentRegistry: Send + Sync {
310 async fn register(&self, manifest: &AgentManifest) -> RegistryResult<()>;
312
313 async fn heartbeat(&self, manifest: &AgentManifest) -> RegistryResult<()>;
315
316 async fn deregister(&self, manifest: &AgentManifest) -> RegistryResult<()>;
318}
319
320pub(crate) struct RegistrationController {
321 registry: Arc<dyn AgentRegistry>,
322 manifest: Arc<AgentManifest>,
323 config: RegistrationConfig,
324 shutdown: Arc<AtomicBool>,
325 worker: Option<JoinHandle<()>>,
326}
327
328impl fmt::Debug for RegistrationController {
329 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
330 f.debug_struct("RegistrationController")
331 .field("registry", &"dyn AgentRegistry")
332 .field("manifest", &self.manifest.id())
333 .field("config", &self.config)
334 .field("shutdown", &self.shutdown.load(Ordering::Relaxed))
335 .field("worker", &self.worker.is_some())
336 .finish()
337 }
338}
339
340impl RegistrationController {
341 pub(crate) fn new(
342 registry: Arc<dyn AgentRegistry>,
343 manifest: AgentManifest,
344 config: RegistrationConfig,
345 ) -> Self {
346 Self {
347 registry,
348 manifest: Arc::new(manifest),
349 config,
350 shutdown: Arc::new(AtomicBool::new(false)),
351 worker: None,
352 }
353 }
354
355 pub(crate) fn on_state_change(
356 &mut self,
357 state: AgentState,
358 scheduler: &TaskScheduler,
359 ) -> RegistryResult<()> {
360 match state {
361 AgentState::Ready | AgentState::Active => {
362 self.ensure_worker(scheduler)?;
363 }
364 AgentState::Retiring | AgentState::Terminated => {
365 self.shutdown.store(true, Ordering::Release);
366 self.spawn_deregister(scheduler)?;
367 if let Some(handle) = self.worker.take() {
368 handle.abort();
369 }
370 }
371 _ => {}
372 }
373
374 Ok(())
375 }
376
377 fn ensure_worker(&mut self, scheduler: &TaskScheduler) -> RegistryResult<()> {
378 if self.worker.is_some() {
379 return Ok(());
380 }
381
382 self.config.validate()?;
383
384 let registry = Arc::clone(&self.registry);
385 let manifest = Arc::clone(&self.manifest);
386 let shutdown = Arc::clone(&self.shutdown);
387 let config = self.config;
388
389 let handle = scheduler.spawn(async move {
390 run_registration_loop(registry, manifest, shutdown, config).await;
391 })?;
392
393 self.worker = Some(handle);
394 Ok(())
395 }
396
397 fn spawn_deregister(&self, scheduler: &TaskScheduler) -> RegistryResult<()> {
398 let registry = Arc::clone(&self.registry);
399 let manifest = Arc::clone(&self.manifest);
400 scheduler.spawn(async move {
401 if let Err(err) = registry.deregister(&manifest).await {
402 warn!(?err, "agent deregistration failed");
403 } else {
404 info!(agent_id = %manifest.id(), "agent deregistered");
405 }
406 })?;
407 Ok(())
408 }
409}
410
411async fn run_registration_loop(
412 registry: Arc<dyn AgentRegistry>,
413 manifest: Arc<AgentManifest>,
414 shutdown: Arc<AtomicBool>,
415 config: RegistrationConfig,
416) {
417 let mut retry_delay = config.initial_retry_delay();
418
419 loop {
420 if shutdown.load(Ordering::Acquire) {
421 break;
422 }
423
424 match registry.register(&manifest).await {
425 Ok(()) => {
426 info!(agent_id = %manifest.id(), "agent registered with mesh");
427 retry_delay = config.initial_retry_delay();
428 if !run_heartbeat_loop(
429 Arc::clone(®istry),
430 Arc::clone(&manifest),
431 Arc::clone(&shutdown),
432 config,
433 )
434 .await
435 {
436 continue;
437 }
438 break;
439 }
440 Err(err) => {
441 warn!(?err, "agent registration failed; retrying");
442 sleep(retry_delay).await;
443 retry_delay = (retry_delay * 2).min(config.max_retry_delay());
444 }
445 }
446 }
447}
448
449async fn run_heartbeat_loop(
450 registry: Arc<dyn AgentRegistry>,
451 manifest: Arc<AgentManifest>,
452 shutdown: Arc<AtomicBool>,
453 config: RegistrationConfig,
454) -> bool {
455 let mut failures: usize = 0;
456 let mut interval = tokio::time::interval(config.heartbeat_interval());
457 interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
458
459 while !shutdown.load(Ordering::Acquire) {
460 interval.tick().await;
461 if shutdown.load(Ordering::Acquire) {
462 break;
463 }
464
465 match registry.heartbeat(&manifest).await {
466 Ok(()) => {
467 failures = 0;
468 }
469 Err(err) => {
470 failures += 1;
471 warn!(?err, failures, "heartbeat failure");
472 if failures >= config.max_consecutive_failures().get() {
473 warn!(
474 failures,
475 "heartbeat failure threshold reached; attempting re-registration"
476 );
477 return false;
478 }
479 }
480 }
481 }
482
483 true
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    use agent_primitives::{AgentId, Capability, CapabilityId};

    /// Registry double that only counts how often each operation is invoked.
    #[derive(Default)]
    struct MockRegistry {
        registers: Arc<AtomicUsize>,
        heartbeats: Arc<AtomicUsize>,
        deregistrations: Arc<AtomicUsize>,
    }

    /// Bumps `counter` and reports success, mimicking a healthy backend.
    fn record(counter: &AtomicUsize) -> RegistryResult<()> {
        counter.fetch_add(1, Ordering::SeqCst);
        Ok(())
    }

    /// Current value of `counter`.
    fn count(counter: &AtomicUsize) -> usize {
        counter.load(Ordering::SeqCst)
    }

    #[async_trait]
    impl AgentRegistry for MockRegistry {
        async fn register(&self, _manifest: &AgentManifest) -> RegistryResult<()> {
            record(&self.registers)
        }

        async fn heartbeat(&self, _manifest: &AgentManifest) -> RegistryResult<()> {
            record(&self.heartbeats)
        }

        async fn deregister(&self, _manifest: &AgentManifest) -> RegistryResult<()> {
            record(&self.deregistrations)
        }
    }

    /// Minimal manifest with a single mock capability.
    fn sample_manifest() -> AgentManifest {
        let capability = Capability::builder(CapabilityId::new("mock.cap").unwrap())
            .name("Mock")
            .unwrap()
            .version("1.0.0")
            .unwrap()
            .add_scope("read:mock")
            .unwrap()
            .build()
            .unwrap();

        AgentManifest::builder(AgentId::random())
            .name("mock-agent")
            .unwrap()
            .version("0.1.0")
            .unwrap()
            .capabilities(vec![capability])
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn lifecycle_starts_and_stops_heartbeat() {
        let registry = Arc::new(MockRegistry::default());
        let config = RegistrationConfig::new(
            Duration::from_millis(10),
            Duration::from_millis(5),
            Duration::from_millis(20),
            NonZeroUsize::new(3).unwrap(),
        );

        let mut controller =
            RegistrationController::new(registry.clone(), sample_manifest(), config);
        let scheduler = TaskScheduler::default();

        controller
            .on_state_change(AgentState::Ready, &scheduler)
            .unwrap();
        tokio::time::sleep(Duration::from_millis(40)).await;
        assert!(count(&registry.registers) >= 1);
        assert!(count(&registry.heartbeats) >= 1);

        controller
            .on_state_change(AgentState::Retiring, &scheduler)
            .unwrap();
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert!(count(&registry.deregistrations) >= 1);
    }
}
573
574#[async_trait]
575impl AgentRegistry for MxpRegistryClient {
576 async fn register(&self, manifest: &AgentManifest) -> RegistryResult<()> {
577 let request = self.manifest_to_register_request(manifest);
578 let payload = serde_json::to_vec(&request)
579 .map_err(|err| RegistryError::backend(format!("encode register payload: {err:?}")))?;
580 let message = Message::new(MessageType::AgentRegister, payload);
581 let response = self.send_request(message).await?;
582
583 match response.message_type() {
584 Some(MessageType::Response) => {
585 let ack = serde_json::from_slice::<RegisterResponse>(response.payload()).map_err(
586 |err| {
587 RegistryError::backend(format!("parse register response failed: {err:?}"))
588 },
589 )?;
590 if ack.success {
591 debug!(agent_id = ack.agent_id, "registry registration acked");
592 Ok(())
593 } else {
594 Err(RegistryError::backend(format!(
595 "registry rejected registration: {}",
596 ack.message
597 )))
598 }
599 }
600 Some(MessageType::Error) => Self::handle_error_message(&response),
601 other => Err(RegistryError::backend(format!(
602 "unexpected message type {other:?} for register response"
603 ))),
604 }
605 }
606
607 async fn heartbeat(&self, manifest: &AgentManifest) -> RegistryResult<()> {
608 let request = HeartbeatRequest {
609 agent_id: Self::agent_id(manifest),
610 };
611 let payload = serde_json::to_vec(&request)
612 .map_err(|err| RegistryError::backend(format!("encode heartbeat payload: {err:?}")))?;
613 let message = Message::new(MessageType::AgentHeartbeat, payload);
614 let response = self.send_request(message).await?;
615
616 match response.message_type() {
617 Some(MessageType::Response) => {
618 let ack = serde_json::from_slice::<HeartbeatResponse>(response.payload()).map_err(
619 |err| {
620 RegistryError::backend(format!("parse heartbeat response failed: {err:?}"))
621 },
622 )?;
623 if ack.success && !ack.needs_register {
624 Ok(())
625 } else if ack.needs_register {
626 Err(RegistryError::backend("registry requested re-registration"))
627 } else {
628 Err(RegistryError::backend(
629 ack.message
630 .unwrap_or_else(|| "heartbeat rejected".to_string()),
631 ))
632 }
633 }
634 Some(MessageType::Error) => Self::handle_error_message(&response),
635 other => Err(RegistryError::backend(format!(
636 "unexpected message type {other:?} for heartbeat response"
637 ))),
638 }
639 }
640
641 async fn deregister(&self, manifest: &AgentManifest) -> RegistryResult<()> {
642 let request = HeartbeatRequest {
643 agent_id: Self::agent_id(manifest),
644 };
645 let payload = serde_json::to_vec(&request)
646 .map_err(|err| RegistryError::backend(format!("encode deregister payload: {err:?}")))?;
647 let mut message = Message::new(MessageType::AgentHeartbeat, payload);
648 message.set_flags(message.flags().with(Flags::FINAL));
649 let response = self.send_request(message).await?;
650
651 match response.message_type() {
652 Some(MessageType::Response) => {
653 let ack = serde_json::from_slice::<HeartbeatResponse>(response.payload()).map_err(
654 |err| {
655 RegistryError::backend(format!("parse deregister response failed: {err:?}"))
656 },
657 )?;
658 if ack.success {
659 Ok(())
660 } else {
661 Err(RegistryError::backend(
662 ack.message
663 .unwrap_or_else(|| "deregister failed".to_string()),
664 ))
665 }
666 }
667 Some(MessageType::Error) => Self::handle_error_message(&response),
668 other => Err(RegistryError::backend(format!(
669 "unexpected message type {other:?} for deregister response"
670 ))),
671 }
672 }
673}