agent_kernel/
registry.rs

1//! Agent registry integration for MXP Nexus mesh discovery and heartbeats.
2
3use std::collections::HashMap;
4use std::fmt;
5use std::io::ErrorKind;
6use std::net::{SocketAddr, ToSocketAddrs};
7use std::num::NonZeroUsize;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::time::Duration;
11
12use agent_primitives::AgentManifest;
13use async_trait::async_trait;
14use mxp::protocol::Flags;
15use mxp::transport::{SocketError, Transport, TransportConfig, TransportHandle};
16use mxp::{Message, MessageType};
17use thiserror::Error;
18use tokio::task::JoinHandle;
19use tokio::time::{MissedTickBehavior, sleep};
20use tracing::{debug, info, warn};
21
22use crate::registry_wire::{
23    ErrorResponse, HeartbeatRequest, HeartbeatResponse, RegisterRequest, RegisterResponse,
24};
25use crate::{AgentState, SchedulerError, TaskScheduler};
26
/// Configuration for registration and heartbeat maintenance.
///
/// The constructor performs no checks; call [`RegistrationConfig::validate`] to verify
/// the values before relying on them.
#[derive(Debug, Clone, Copy)]
pub struct RegistrationConfig {
    /// Period of the heartbeat ticker started once registration has succeeded.
    heartbeat_interval: Duration,
    /// Delay slept after the first failed registration attempt.
    initial_retry_delay: Duration,
    /// Upper bound applied to the retry delay, which doubles after every failed attempt.
    max_retry_delay: Duration,
    /// Number of consecutive heartbeat failures that triggers a fresh registration.
    max_consecutive_failures: NonZeroUsize,
}
35
36impl RegistrationConfig {
37    /// Creates a new configuration.
38    #[must_use]
39    pub fn new(
40        heartbeat_interval: Duration,
41        initial_retry_delay: Duration,
42        max_retry_delay: Duration,
43        max_consecutive_failures: NonZeroUsize,
44    ) -> Self {
45        Self {
46            heartbeat_interval,
47            initial_retry_delay,
48            max_retry_delay,
49            max_consecutive_failures,
50        }
51    }
52
53    /// Returns the heartbeat interval.
54    #[must_use]
55    pub const fn heartbeat_interval(self) -> Duration {
56        self.heartbeat_interval
57    }
58
59    /// Returns the initial retry delay.
60    #[must_use]
61    pub const fn initial_retry_delay(self) -> Duration {
62        self.initial_retry_delay
63    }
64
65    /// Returns the maximum retry delay.
66    #[must_use]
67    pub const fn max_retry_delay(self) -> Duration {
68        self.max_retry_delay
69    }
70
71    /// Returns the limit on consecutive heartbeat failures before re-registration.
72    #[must_use]
73    pub const fn max_consecutive_failures(self) -> NonZeroUsize {
74        self.max_consecutive_failures
75    }
76
77    /// Validates the configuration.
78    ///
79    /// # Errors
80    ///
81    /// Returns [`RegistryError::InvalidConfig`] when any duration is zero or the
82    /// retry delay bounds are inconsistent.
83    pub fn validate(self) -> RegistryResult<()> {
84        if self.heartbeat_interval.is_zero() {
85            return Err(RegistryError::InvalidConfig(
86                "heartbeat interval must be greater than zero",
87            ));
88        }
89        if self.initial_retry_delay.is_zero() {
90            return Err(RegistryError::InvalidConfig(
91                "initial retry delay must be greater than zero",
92            ));
93        }
94        if self.max_retry_delay.is_zero() {
95            return Err(RegistryError::InvalidConfig(
96                "max retry delay must be greater than zero",
97            ));
98        }
99        if self.initial_retry_delay > self.max_retry_delay {
100            return Err(RegistryError::InvalidConfig(
101                "initial retry delay cannot exceed max retry delay",
102            ));
103        }
104        Ok(())
105    }
106}
107
108impl Default for RegistrationConfig {
109    fn default() -> Self {
110        Self {
111            heartbeat_interval: Duration::from_secs(10),
112            initial_retry_delay: Duration::from_secs(1),
113            max_retry_delay: Duration::from_secs(30),
114            max_consecutive_failures: NonZeroUsize::new(3).expect("non-zero"),
115        }
116    }
117}
118
/// Result alias for registry operations.
///
/// Shorthand for `Result<T, RegistryError>`, used by every fallible function in this module.
pub type RegistryResult<T> = Result<T, RegistryError>;
121
/// Errors surfaced by registry integration.
///
/// `Display` output comes from the `thiserror` attributes on each variant.
#[derive(Debug, Error)]
pub enum RegistryError {
    /// Registration configuration was invalid.
    ///
    /// Carries a static description of the violated rule, as produced by
    /// [`RegistrationConfig::validate`].
    #[error("invalid registration configuration: {0}")]
    InvalidConfig(&'static str),
    /// Scheduler rejected a task submission.
    ///
    /// Converted automatically from [`SchedulerError`] via `?`; its message is shown as-is.
    #[error(transparent)]
    Scheduler(#[from] SchedulerError),
    /// Registry backend failure.
    ///
    /// Used for transport, encoding/decoding and registry-reported failures alike.
    #[error("registry backend error: {reason}")]
    Backend {
        /// Human-readable context provided by the backend.
        reason: String,
    },
}
138
139impl RegistryError {
140    /// Convenience helper to construct backend errors.
141    #[must_use]
142    pub fn backend(reason: impl Into<String>) -> Self {
143        Self::Backend {
144            reason: reason.into(),
145        }
146    }
147}
148
/// MXP-backed registry client that speaks directly to the MXP Nexus registry service.
///
/// Created with [`MxpRegistryClient::connect`]; implements [`AgentRegistry`].
#[derive(Debug)]
pub struct MxpRegistryClient {
    /// Bound transport handle used to send requests and receive responses.
    handle: TransportHandle,
    /// Resolved address of the registry service that requests are sent to.
    registry_addr: SocketAddr,
    /// Address reported to the registry as this agent's endpoint when registering.
    agent_endpoint: SocketAddr,
}
156
157impl MxpRegistryClient {
158    /// Establishes a registry client using the provided endpoint configuration.
159    ///
160    /// # Errors
161    ///
162    /// Returns [`RegistryError::Backend`] if the transport cannot be bound.
163    pub fn connect(
164        registry_addr: impl ToSocketAddrs,
165        agent_endpoint: SocketAddr,
166        transport_config: Option<TransportConfig>,
167    ) -> RegistryResult<Self> {
168        let registry_addr = registry_addr
169            .to_socket_addrs()
170            .map_err(|err| {
171                RegistryError::backend(format!("failed to resolve registry endpoint: {err:?}"))
172            })?
173            .next()
174            .ok_or_else(|| RegistryError::backend("registry endpoint resolved to no address"))?;
175
176        let config = transport_config.unwrap_or_else(default_transport_config);
177        let transport = Transport::new(config);
178        let local_bind: SocketAddr = "0.0.0.0:0".parse().map_err(|err| {
179            RegistryError::backend(format!("invalid bind address configuration: {err:?}"))
180        })?;
181        let handle = transport
182            .bind(local_bind)
183            .map_err(|err| RegistryError::backend(format!("transport bind failed: {err:?}")))?;
184
185        Ok(Self {
186            handle,
187            registry_addr,
188            agent_endpoint,
189        })
190    }
191
192    fn agent_id(manifest: &AgentManifest) -> String {
193        manifest.id().to_string()
194    }
195
196    fn manifest_to_register_request(&self, manifest: &AgentManifest) -> RegisterRequest {
197        let capabilities = manifest
198            .capabilities()
199            .iter()
200            .map(|cap| cap.id().as_str().to_string())
201            .collect::<Vec<_>>();
202
203        let mut metadata = HashMap::new();
204        metadata.insert("version".to_string(), manifest.version().to_string());
205        if let Some(description) = manifest.description() {
206            metadata.insert("description".to_string(), description.to_string());
207        }
208        if !manifest.tags().is_empty() {
209            metadata.insert(
210                "tags".to_string(),
211                serde_json::to_string(manifest.tags()).unwrap_or_default(),
212            );
213        }
214
215        RegisterRequest {
216            id: Self::agent_id(manifest),
217            name: manifest.name().to_string(),
218            capabilities,
219            address: self.agent_endpoint,
220            metadata,
221        }
222    }
223
224    fn send_request_blocking(
225        handle: &TransportHandle,
226        registry_addr: SocketAddr,
227        message: &Message,
228    ) -> RegistryResult<Message> {
229        let encoded = message.encode();
230        let message_id = message.message_id();
231
232        handle
233            .send(&encoded, registry_addr)
234            .map_err(|err| RegistryError::backend(format!("send failed: {err:?}")))?;
235
236        let mut buffer = handle.acquire_buffer();
237        let response = loop {
238            match handle.receive(&mut buffer) {
239                Ok((_len, _addr)) => {
240                    let payload = buffer.as_slice().to_vec();
241                    match Message::decode(payload) {
242                        Ok(response) => {
243                            if response.message_id() == message_id {
244                                break response;
245                            }
246                        }
247                        Err(err) => {
248                            return Err(RegistryError::backend(format!(
249                                "failed to decode registry response: {err:?}"
250                            )));
251                        }
252                    }
253                }
254                Err(SocketError::Io(err))
255                    if matches!(
256                        err.kind(),
257                        ErrorKind::WouldBlock | ErrorKind::TimedOut | ErrorKind::Interrupted
258                    ) =>
259                {
260                    if err.kind() == ErrorKind::Interrupted {
261                        debug!("registry receive interrupted; retrying");
262                        continue;
263                    }
264                    return Err(RegistryError::backend(
265                        "timed out waiting for registry response",
266                    ));
267                }
268                Err(SocketError::Io(err)) => {
269                    return Err(RegistryError::backend(format!(
270                        "registry receive failed: {err:?}"
271                    )));
272                }
273            }
274        };
275
276        Ok(response)
277    }
278
279    async fn send_request(&self, message: Message) -> RegistryResult<Message> {
280        let handle = self.handle.clone();
281        let registry_addr = self.registry_addr;
282        tokio::task::spawn_blocking(move || {
283            Self::send_request_blocking(&handle, registry_addr, &message)
284        })
285        .await
286        .map_err(|err| RegistryError::backend(format!("registry task join error: {err:?}")))?
287    }
288
289    fn handle_error_message(message: &Message) -> RegistryResult<()> {
290        let payload =
291            serde_json::from_slice::<ErrorResponse>(message.payload()).map_err(|err| {
292                RegistryError::backend(format!("failed to parse registry error payload: {err:?}"))
293            })?;
294        Err(RegistryError::backend(payload.error))
295    }
296}
297
298fn default_transport_config() -> TransportConfig {
299    TransportConfig {
300        buffer_size: 16 * 1024,
301        max_buffers: 256,
302        read_timeout: Some(Duration::from_secs(5)),
303        write_timeout: Some(Duration::from_secs(5)),
304        #[cfg(feature = "debug-tools")]
305        pcap_send_path: None,
306        #[cfg(feature = "debug-tools")]
307        pcap_recv_path: None,
308    }
309}
310
/// Trait implemented by discovery/registry backends.
///
/// The registration worker in this module calls [`AgentRegistry::register`] until it
/// succeeds, then [`AgentRegistry::heartbeat`] on every tick; [`AgentRegistry::deregister`]
/// is invoked when the agent retires or terminates. Implementations must be `Send + Sync`
/// because they are shared across spawned tasks behind an `Arc`.
#[async_trait]
pub trait AgentRegistry: Send + Sync {
    /// Registers an agent manifest with the mesh.
    ///
    /// An `Err` makes the registration worker retry after a backoff delay.
    async fn register(&self, manifest: &AgentManifest) -> RegistryResult<()>;

    /// Sends a heartbeat for an already registered agent.
    ///
    /// Each `Err` counts as one consecutive failure towards re-registration.
    async fn heartbeat(&self, manifest: &AgentManifest) -> RegistryResult<()>;

    /// Removes the agent from the registry.
    ///
    /// Failures are only logged by the controller; they are not retried.
    async fn deregister(&self, manifest: &AgentManifest) -> RegistryResult<()>;
}
323
/// Drives registration, heartbeats and deregistration in response to agent state changes.
pub(crate) struct RegistrationController {
    /// Backend that receives the register/heartbeat/deregister calls.
    registry: Arc<dyn AgentRegistry>,
    /// Manifest of the agent being registered; shared with the spawned tasks.
    manifest: Arc<AgentManifest>,
    /// Timing and failure-threshold settings for the worker.
    config: RegistrationConfig,
    /// Set to `true` when the agent retires or terminates; polled by the worker loops.
    shutdown: Arc<AtomicBool>,
    /// Handle of the registration/heartbeat task, `None` until one has been spawned.
    worker: Option<JoinHandle<()>>,
}
331
332impl fmt::Debug for RegistrationController {
333    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
334        f.debug_struct("RegistrationController")
335            .field("registry", &"dyn AgentRegistry")
336            .field("manifest", &self.manifest.id())
337            .field("config", &self.config)
338            .field("shutdown", &self.shutdown.load(Ordering::Relaxed))
339            .field("worker", &self.worker.is_some())
340            .finish()
341    }
342}
343
344impl RegistrationController {
345    pub(crate) fn new(
346        registry: Arc<dyn AgentRegistry>,
347        manifest: AgentManifest,
348        config: RegistrationConfig,
349    ) -> Self {
350        Self {
351            registry,
352            manifest: Arc::new(manifest),
353            config,
354            shutdown: Arc::new(AtomicBool::new(false)),
355            worker: None,
356        }
357    }
358
359    pub(crate) fn on_state_change(
360        &mut self,
361        state: AgentState,
362        scheduler: &TaskScheduler,
363    ) -> RegistryResult<()> {
364        match state {
365            AgentState::Ready | AgentState::Active => {
366                self.ensure_worker(scheduler)?;
367            }
368            AgentState::Retiring | AgentState::Terminated => {
369                self.shutdown.store(true, Ordering::Release);
370                self.spawn_deregister(scheduler)?;
371                if let Some(handle) = self.worker.take() {
372                    handle.abort();
373                }
374            }
375            _ => {}
376        }
377
378        Ok(())
379    }
380
381    fn ensure_worker(&mut self, scheduler: &TaskScheduler) -> RegistryResult<()> {
382        if self.worker.is_some() {
383            return Ok(());
384        }
385
386        self.config.validate()?;
387
388        let registry = Arc::clone(&self.registry);
389        let manifest = Arc::clone(&self.manifest);
390        let shutdown = Arc::clone(&self.shutdown);
391        let config = self.config;
392
393        let handle = scheduler.spawn(async move {
394            run_registration_loop(registry, manifest, shutdown, config).await;
395        })?;
396
397        self.worker = Some(handle);
398        Ok(())
399    }
400
401    fn spawn_deregister(&self, scheduler: &TaskScheduler) -> RegistryResult<()> {
402        let registry = Arc::clone(&self.registry);
403        let manifest = Arc::clone(&self.manifest);
404        scheduler.spawn(async move {
405            if let Err(err) = registry.deregister(&manifest).await {
406                warn!(?err, "agent deregistration failed");
407            } else {
408                info!(agent_id = %manifest.id(), "agent deregistered");
409            }
410        })?;
411        Ok(())
412    }
413}
414
415async fn run_registration_loop(
416    registry: Arc<dyn AgentRegistry>,
417    manifest: Arc<AgentManifest>,
418    shutdown: Arc<AtomicBool>,
419    config: RegistrationConfig,
420) {
421    let mut retry_delay = config.initial_retry_delay();
422
423    loop {
424        if shutdown.load(Ordering::Acquire) {
425            break;
426        }
427
428        match registry.register(&manifest).await {
429            Ok(()) => {
430                info!(agent_id = %manifest.id(), "agent registered with mesh");
431                retry_delay = config.initial_retry_delay();
432                if !run_heartbeat_loop(
433                    Arc::clone(&registry),
434                    Arc::clone(&manifest),
435                    Arc::clone(&shutdown),
436                    config,
437                )
438                .await
439                {
440                    continue;
441                }
442                break;
443            }
444            Err(err) => {
445                warn!(?err, "agent registration failed; retrying");
446                sleep(retry_delay).await;
447                retry_delay = (retry_delay * 2).min(config.max_retry_delay());
448            }
449        }
450    }
451}
452
453async fn run_heartbeat_loop(
454    registry: Arc<dyn AgentRegistry>,
455    manifest: Arc<AgentManifest>,
456    shutdown: Arc<AtomicBool>,
457    config: RegistrationConfig,
458) -> bool {
459    let mut failures: usize = 0;
460    let mut interval = tokio::time::interval(config.heartbeat_interval());
461    interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
462
463    while !shutdown.load(Ordering::Acquire) {
464        interval.tick().await;
465        if shutdown.load(Ordering::Acquire) {
466            break;
467        }
468
469        match registry.heartbeat(&manifest).await {
470            Ok(()) => {
471                failures = 0;
472            }
473            Err(err) => {
474                failures += 1;
475                warn!(?err, failures, "heartbeat failure");
476                if failures >= config.max_consecutive_failures().get() {
477                    warn!(
478                        failures,
479                        "heartbeat failure threshold reached; attempting re-registration"
480                    );
481                    return false;
482                }
483            }
484        }
485    }
486
487    true
488}
489
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    use agent_primitives::{AgentId, Capability, CapabilityId};

    /// Registry double that always succeeds and counts how often each call was made.
    #[derive(Default)]
    struct MockRegistry {
        registers: Arc<AtomicUsize>,
        heartbeats: Arc<AtomicUsize>,
        deregistrations: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl AgentRegistry for MockRegistry {
        async fn register(&self, _manifest: &AgentManifest) -> RegistryResult<()> {
            self.registers.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn heartbeat(&self, _manifest: &AgentManifest) -> RegistryResult<()> {
            self.heartbeats.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn deregister(&self, _manifest: &AgentManifest) -> RegistryResult<()> {
            self.deregistrations.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Builds a minimal manifest with a single mock capability and a random agent id.
    fn manifest() -> AgentManifest {
        let capability_id = CapabilityId::new("mock.cap").unwrap();
        let capability = Capability::builder(capability_id)
            .name("Mock")
            .unwrap()
            .version("1.0.0")
            .unwrap()
            .add_scope("read:mock")
            .unwrap()
            .build()
            .unwrap();

        AgentManifest::builder(AgentId::random())
            .name("mock-agent")
            .unwrap()
            .version("0.1.0")
            .unwrap()
            .capabilities(vec![capability])
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn lifecycle_starts_and_stops_heartbeat() {
        let registry = Arc::new(MockRegistry::default());

        // Short intervals so the worker registers and heartbeats within the sleeps below.
        let config = RegistrationConfig::new(
            Duration::from_millis(10),
            Duration::from_millis(5),
            Duration::from_millis(20),
            NonZeroUsize::new(3).unwrap(),
        );
        let agent_manifest = manifest();

        let scheduler = TaskScheduler::default();
        let mut controller = RegistrationController::new(registry.clone(), agent_manifest, config);

        // Becoming ready starts the worker.
        controller
            .on_state_change(AgentState::Ready, &scheduler)
            .unwrap();
        tokio::time::sleep(Duration::from_millis(40)).await;

        assert!(registry.registers.load(Ordering::SeqCst) >= 1);
        assert!(registry.heartbeats.load(Ordering::SeqCst) >= 1);

        // Retiring schedules the deregistration.
        controller
            .on_state_change(AgentState::Retiring, &scheduler)
            .unwrap();
        tokio::time::sleep(Duration::from_millis(20)).await;

        assert!(registry.deregistrations.load(Ordering::SeqCst) >= 1);
    }
}
577
578#[async_trait]
579impl AgentRegistry for MxpRegistryClient {
580    async fn register(&self, manifest: &AgentManifest) -> RegistryResult<()> {
581        let request = self.manifest_to_register_request(manifest);
582        let payload = serde_json::to_vec(&request)
583            .map_err(|err| RegistryError::backend(format!("encode register payload: {err:?}")))?;
584        let message = Message::new(MessageType::AgentRegister, payload);
585        let response = self.send_request(message).await?;
586
587        match response.message_type() {
588            Some(MessageType::Response) => {
589                let ack = serde_json::from_slice::<RegisterResponse>(response.payload()).map_err(
590                    |err| {
591                        RegistryError::backend(format!("parse register response failed: {err:?}"))
592                    },
593                )?;
594                if ack.success {
595                    debug!(agent_id = ack.agent_id, "registry registration acked");
596                    Ok(())
597                } else {
598                    Err(RegistryError::backend(format!(
599                        "registry rejected registration: {}",
600                        ack.message
601                    )))
602                }
603            }
604            Some(MessageType::Error) => Self::handle_error_message(&response),
605            other => Err(RegistryError::backend(format!(
606                "unexpected message type {other:?} for register response"
607            ))),
608        }
609    }
610
611    async fn heartbeat(&self, manifest: &AgentManifest) -> RegistryResult<()> {
612        let request = HeartbeatRequest {
613            agent_id: Self::agent_id(manifest),
614        };
615        let payload = serde_json::to_vec(&request)
616            .map_err(|err| RegistryError::backend(format!("encode heartbeat payload: {err:?}")))?;
617        let message = Message::new(MessageType::AgentHeartbeat, payload);
618        let response = self.send_request(message).await?;
619
620        match response.message_type() {
621            Some(MessageType::Response) => {
622                let ack = serde_json::from_slice::<HeartbeatResponse>(response.payload()).map_err(
623                    |err| {
624                        RegistryError::backend(format!("parse heartbeat response failed: {err:?}"))
625                    },
626                )?;
627                if ack.success && !ack.needs_register {
628                    Ok(())
629                } else if ack.needs_register {
630                    Err(RegistryError::backend("registry requested re-registration"))
631                } else {
632                    Err(RegistryError::backend(
633                        ack.message
634                            .unwrap_or_else(|| "heartbeat rejected".to_string()),
635                    ))
636                }
637            }
638            Some(MessageType::Error) => Self::handle_error_message(&response),
639            other => Err(RegistryError::backend(format!(
640                "unexpected message type {other:?} for heartbeat response"
641            ))),
642        }
643    }
644
645    async fn deregister(&self, manifest: &AgentManifest) -> RegistryResult<()> {
646        let request = HeartbeatRequest {
647            agent_id: Self::agent_id(manifest),
648        };
649        let payload = serde_json::to_vec(&request)
650            .map_err(|err| RegistryError::backend(format!("encode deregister payload: {err:?}")))?;
651        let mut message = Message::new(MessageType::AgentHeartbeat, payload);
652        message.set_flags(message.flags().with(Flags::FINAL));
653        let response = self.send_request(message).await?;
654
655        match response.message_type() {
656            Some(MessageType::Response) => {
657                let ack = serde_json::from_slice::<HeartbeatResponse>(response.payload()).map_err(
658                    |err| {
659                        RegistryError::backend(format!("parse deregister response failed: {err:?}"))
660                    },
661                )?;
662                if ack.success {
663                    Ok(())
664                } else {
665                    Err(RegistryError::backend(
666                        ack.message
667                            .unwrap_or_else(|| "deregister failed".to_string()),
668                    ))
669                }
670            }
671            Some(MessageType::Error) => Self::handle_error_message(&response),
672            other => Err(RegistryError::backend(format!(
673                "unexpected message type {other:?} for deregister response"
674            ))),
675        }
676    }
677}