agent_kernel/
registry.rs

1//! Agent registry integration for MXP Nexus mesh discovery and heartbeats.
2
3use std::collections::HashMap;
4use std::fmt;
5use std::io::ErrorKind;
6use std::net::{SocketAddr, ToSocketAddrs};
7use std::num::NonZeroUsize;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::time::Duration;
11
12use agent_primitives::AgentManifest;
13use async_trait::async_trait;
14use mxp::protocol::Flags;
15use mxp::transport::{SocketError, Transport, TransportConfig, TransportHandle};
16use mxp::{Message, MessageType};
17use thiserror::Error;
18use tokio::task::JoinHandle;
19use tokio::time::{MissedTickBehavior, sleep};
20use tracing::{debug, info, warn};
21
22use crate::registry_wire::{
23    ErrorResponse, HeartbeatRequest, HeartbeatResponse, RegisterRequest, RegisterResponse,
24};
25use crate::{AgentState, SchedulerError, TaskScheduler};
26
/// Timing and tolerance parameters for keeping an agent registered with the mesh.
///
/// Controls how often heartbeats are sent, how registration retries back off,
/// and how many heartbeat failures in a row trigger a fresh registration.
#[derive(Debug, Clone, Copy)]
pub struct RegistrationConfig {
    /// Period between consecutive heartbeats.
    heartbeat_interval: Duration,
    /// First backoff delay after a failed registration attempt.
    initial_retry_delay: Duration,
    /// Upper bound for the registration backoff delay.
    max_retry_delay: Duration,
    /// Number of heartbeat failures in a row that triggers re-registration.
    max_consecutive_failures: NonZeroUsize,
}
35
36impl RegistrationConfig {
37    /// Creates a new configuration.
38    #[must_use]
39    pub fn new(
40        heartbeat_interval: Duration,
41        initial_retry_delay: Duration,
42        max_retry_delay: Duration,
43        max_consecutive_failures: NonZeroUsize,
44    ) -> Self {
45        Self {
46            heartbeat_interval,
47            initial_retry_delay,
48            max_retry_delay,
49            max_consecutive_failures,
50        }
51    }
52
53    /// Returns the heartbeat interval.
54    #[must_use]
55    pub const fn heartbeat_interval(self) -> Duration {
56        self.heartbeat_interval
57    }
58
59    /// Returns the initial retry delay.
60    #[must_use]
61    pub const fn initial_retry_delay(self) -> Duration {
62        self.initial_retry_delay
63    }
64
65    /// Returns the maximum retry delay.
66    #[must_use]
67    pub const fn max_retry_delay(self) -> Duration {
68        self.max_retry_delay
69    }
70
71    /// Returns the limit on consecutive heartbeat failures before re-registration.
72    #[must_use]
73    pub const fn max_consecutive_failures(self) -> NonZeroUsize {
74        self.max_consecutive_failures
75    }
76
77    /// Validates the configuration.
78    ///
79    /// # Errors
80    ///
81    /// Returns [`RegistryError::InvalidConfig`] when any duration is zero or the
82    /// retry delay bounds are inconsistent.
83    pub fn validate(self) -> RegistryResult<()> {
84        if self.heartbeat_interval.is_zero() {
85            return Err(RegistryError::InvalidConfig(
86                "heartbeat interval must be greater than zero",
87            ));
88        }
89        if self.initial_retry_delay.is_zero() {
90            return Err(RegistryError::InvalidConfig(
91                "initial retry delay must be greater than zero",
92            ));
93        }
94        if self.max_retry_delay.is_zero() {
95            return Err(RegistryError::InvalidConfig(
96                "max retry delay must be greater than zero",
97            ));
98        }
99        if self.initial_retry_delay > self.max_retry_delay {
100            return Err(RegistryError::InvalidConfig(
101                "initial retry delay cannot exceed max retry delay",
102            ));
103        }
104        Ok(())
105    }
106}
107
108impl Default for RegistrationConfig {
109    fn default() -> Self {
110        Self {
111            heartbeat_interval: Duration::from_secs(10),
112            initial_retry_delay: Duration::from_secs(1),
113            max_retry_delay: Duration::from_secs(30),
114            max_consecutive_failures: NonZeroUsize::new(3).expect("non-zero"),
115        }
116    }
117}
118
119/// Result alias for registry operations.
120pub type RegistryResult<T> = Result<T, RegistryError>;
121
122/// Errors surfaced by registry integration.
123#[derive(Debug, Error)]
124pub enum RegistryError {
125    /// Registration configuration was invalid.
126    #[error("invalid registration configuration: {0}")]
127    InvalidConfig(&'static str),
128    /// Scheduler rejected a task submission.
129    #[error(transparent)]
130    Scheduler(#[from] SchedulerError),
131    /// Registry backend failure.
132    #[error("registry backend error: {reason}")]
133    Backend {
134        /// Human-readable context provided by the backend.
135        reason: String,
136    },
137}
138
139impl RegistryError {
140    /// Convenience helper to construct backend errors.
141    #[must_use]
142    pub fn backend(reason: impl Into<String>) -> Self {
143        Self::Backend {
144            reason: reason.into(),
145        }
146    }
147}
148
149/// MXP-backed registry client that speaks directly to the MXP Nexus registry service.
150#[derive(Debug)]
151pub struct MxpRegistryClient {
152    handle: TransportHandle,
153    registry_addr: SocketAddr,
154    agent_endpoint: SocketAddr,
155}
156
157impl MxpRegistryClient {
158    /// Establishes a registry client using the provided endpoint configuration.
159    ///
160    /// # Errors
161    ///
162    /// Returns [`RegistryError::Backend`] if the transport cannot be bound.
163    pub fn connect(
164        registry_addr: impl ToSocketAddrs,
165        agent_endpoint: SocketAddr,
166        transport_config: Option<TransportConfig>,
167    ) -> RegistryResult<Self> {
168        let registry_addr = registry_addr
169            .to_socket_addrs()
170            .map_err(|err| {
171                RegistryError::backend(format!("failed to resolve registry endpoint: {err:?}"))
172            })?
173            .next()
174            .ok_or_else(|| RegistryError::backend("registry endpoint resolved to no address"))?;
175
176        let config = transport_config.unwrap_or_else(default_transport_config);
177        let transport = Transport::new(config);
178        let local_bind: SocketAddr = "0.0.0.0:0".parse().map_err(|err| {
179            RegistryError::backend(format!("invalid bind address configuration: {err:?}"))
180        })?;
181        let handle = transport
182            .bind(local_bind)
183            .map_err(|err| RegistryError::backend(format!("transport bind failed: {err:?}")))?;
184
185        Ok(Self {
186            handle,
187            registry_addr,
188            agent_endpoint,
189        })
190    }
191
192    fn agent_id(manifest: &AgentManifest) -> String {
193        manifest.id().to_string()
194    }
195
196    fn manifest_to_register_request(&self, manifest: &AgentManifest) -> RegisterRequest {
197        let capabilities = manifest
198            .capabilities()
199            .iter()
200            .map(|cap| cap.id().as_str().to_string())
201            .collect::<Vec<_>>();
202
203        let mut metadata = HashMap::new();
204        metadata.insert("version".to_string(), manifest.version().to_string());
205        if let Some(description) = manifest.description() {
206            metadata.insert("description".to_string(), description.to_string());
207        }
208        if !manifest.tags().is_empty() {
209            metadata.insert(
210                "tags".to_string(),
211                serde_json::to_string(manifest.tags()).unwrap_or_default(),
212            );
213        }
214
215        RegisterRequest {
216            id: Self::agent_id(manifest),
217            name: manifest.name().to_string(),
218            capabilities,
219            address: self.agent_endpoint,
220            metadata,
221        }
222    }
223
224    fn send_request_blocking(
225        handle: &TransportHandle,
226        registry_addr: SocketAddr,
227        message: &Message,
228    ) -> RegistryResult<Message> {
229        let encoded = message.encode();
230        let message_id = message.message_id();
231
232        handle
233            .send(&encoded, registry_addr)
234            .map_err(|err| RegistryError::backend(format!("send failed: {err:?}")))?;
235
236        let mut buffer = handle.acquire_buffer();
237        let response = loop {
238            match handle.receive(&mut buffer) {
239                Ok((_len, _addr)) => {
240                    let payload = buffer.as_slice().to_vec();
241                    match Message::decode(payload) {
242                        Ok(response) => {
243                            if response.message_id() == message_id {
244                                break response;
245                            }
246                        }
247                        Err(err) => {
248                            return Err(RegistryError::backend(format!(
249                                "failed to decode registry response: {err:?}"
250                            )));
251                        }
252                    }
253                }
254                Err(SocketError::Io(err))
255                    if matches!(
256                        err.kind(),
257                        ErrorKind::WouldBlock | ErrorKind::TimedOut | ErrorKind::Interrupted
258                    ) =>
259                {
260                    if err.kind() == ErrorKind::Interrupted {
261                        debug!("registry receive interrupted; retrying");
262                        continue;
263                    }
264                    return Err(RegistryError::backend(
265                        "timed out waiting for registry response",
266                    ));
267                }
268                Err(SocketError::Io(err)) => {
269                    return Err(RegistryError::backend(format!(
270                        "registry receive failed: {err:?}"
271                    )));
272                }
273            }
274        };
275
276        Ok(response)
277    }
278
279    async fn send_request(&self, message: Message) -> RegistryResult<Message> {
280        let handle = self.handle.clone();
281        let registry_addr = self.registry_addr;
282        tokio::task::spawn_blocking(move || {
283            Self::send_request_blocking(&handle, registry_addr, &message)
284        })
285        .await
286        .map_err(|err| RegistryError::backend(format!("registry task join error: {err:?}")))?
287    }
288
289    fn handle_error_message(message: &Message) -> RegistryResult<()> {
290        let payload =
291            serde_json::from_slice::<ErrorResponse>(message.payload()).map_err(|err| {
292                RegistryError::backend(format!("failed to parse registry error payload: {err:?}"))
293            })?;
294        Err(RegistryError::backend(payload.error))
295    }
296}
297
298fn default_transport_config() -> TransportConfig {
299    TransportConfig {
300        buffer_size: 16 * 1024,
301        max_buffers: 256,
302        read_timeout: Some(Duration::from_secs(5)),
303        write_timeout: Some(Duration::from_secs(5)),
304    }
305}
306
307/// Trait implemented by discovery/registry backends.
308#[async_trait]
309pub trait AgentRegistry: Send + Sync {
310    /// Registers an agent manifest with the mesh.
311    async fn register(&self, manifest: &AgentManifest) -> RegistryResult<()>;
312
313    /// Sends a heartbeat for an already registered agent.
314    async fn heartbeat(&self, manifest: &AgentManifest) -> RegistryResult<()>;
315
316    /// Removes the agent from the registry.
317    async fn deregister(&self, manifest: &AgentManifest) -> RegistryResult<()>;
318}
319
320pub(crate) struct RegistrationController {
321    registry: Arc<dyn AgentRegistry>,
322    manifest: Arc<AgentManifest>,
323    config: RegistrationConfig,
324    shutdown: Arc<AtomicBool>,
325    worker: Option<JoinHandle<()>>,
326}
327
328impl fmt::Debug for RegistrationController {
329    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
330        f.debug_struct("RegistrationController")
331            .field("registry", &"dyn AgentRegistry")
332            .field("manifest", &self.manifest.id())
333            .field("config", &self.config)
334            .field("shutdown", &self.shutdown.load(Ordering::Relaxed))
335            .field("worker", &self.worker.is_some())
336            .finish()
337    }
338}
339
340impl RegistrationController {
341    pub(crate) fn new(
342        registry: Arc<dyn AgentRegistry>,
343        manifest: AgentManifest,
344        config: RegistrationConfig,
345    ) -> Self {
346        Self {
347            registry,
348            manifest: Arc::new(manifest),
349            config,
350            shutdown: Arc::new(AtomicBool::new(false)),
351            worker: None,
352        }
353    }
354
355    pub(crate) fn on_state_change(
356        &mut self,
357        state: AgentState,
358        scheduler: &TaskScheduler,
359    ) -> RegistryResult<()> {
360        match state {
361            AgentState::Ready | AgentState::Active => {
362                self.ensure_worker(scheduler)?;
363            }
364            AgentState::Retiring | AgentState::Terminated => {
365                self.shutdown.store(true, Ordering::Release);
366                self.spawn_deregister(scheduler)?;
367                if let Some(handle) = self.worker.take() {
368                    handle.abort();
369                }
370            }
371            _ => {}
372        }
373
374        Ok(())
375    }
376
377    fn ensure_worker(&mut self, scheduler: &TaskScheduler) -> RegistryResult<()> {
378        if self.worker.is_some() {
379            return Ok(());
380        }
381
382        self.config.validate()?;
383
384        let registry = Arc::clone(&self.registry);
385        let manifest = Arc::clone(&self.manifest);
386        let shutdown = Arc::clone(&self.shutdown);
387        let config = self.config;
388
389        let handle = scheduler.spawn(async move {
390            run_registration_loop(registry, manifest, shutdown, config).await;
391        })?;
392
393        self.worker = Some(handle);
394        Ok(())
395    }
396
397    fn spawn_deregister(&self, scheduler: &TaskScheduler) -> RegistryResult<()> {
398        let registry = Arc::clone(&self.registry);
399        let manifest = Arc::clone(&self.manifest);
400        scheduler.spawn(async move {
401            if let Err(err) = registry.deregister(&manifest).await {
402                warn!(?err, "agent deregistration failed");
403            } else {
404                info!(agent_id = %manifest.id(), "agent deregistered");
405            }
406        })?;
407        Ok(())
408    }
409}
410
411async fn run_registration_loop(
412    registry: Arc<dyn AgentRegistry>,
413    manifest: Arc<AgentManifest>,
414    shutdown: Arc<AtomicBool>,
415    config: RegistrationConfig,
416) {
417    let mut retry_delay = config.initial_retry_delay();
418
419    loop {
420        if shutdown.load(Ordering::Acquire) {
421            break;
422        }
423
424        match registry.register(&manifest).await {
425            Ok(()) => {
426                info!(agent_id = %manifest.id(), "agent registered with mesh");
427                retry_delay = config.initial_retry_delay();
428                if !run_heartbeat_loop(
429                    Arc::clone(&registry),
430                    Arc::clone(&manifest),
431                    Arc::clone(&shutdown),
432                    config,
433                )
434                .await
435                {
436                    continue;
437                }
438                break;
439            }
440            Err(err) => {
441                warn!(?err, "agent registration failed; retrying");
442                sleep(retry_delay).await;
443                retry_delay = (retry_delay * 2).min(config.max_retry_delay());
444            }
445        }
446    }
447}
448
449async fn run_heartbeat_loop(
450    registry: Arc<dyn AgentRegistry>,
451    manifest: Arc<AgentManifest>,
452    shutdown: Arc<AtomicBool>,
453    config: RegistrationConfig,
454) -> bool {
455    let mut failures: usize = 0;
456    let mut interval = tokio::time::interval(config.heartbeat_interval());
457    interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
458
459    while !shutdown.load(Ordering::Acquire) {
460        interval.tick().await;
461        if shutdown.load(Ordering::Acquire) {
462            break;
463        }
464
465        match registry.heartbeat(&manifest).await {
466            Ok(()) => {
467                failures = 0;
468            }
469            Err(err) => {
470                failures += 1;
471                warn!(?err, failures, "heartbeat failure");
472                if failures >= config.max_consecutive_failures().get() {
473                    warn!(
474                        failures,
475                        "heartbeat failure threshold reached; attempting re-registration"
476                    );
477                    return false;
478                }
479            }
480        }
481    }
482
483    true
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    use agent_primitives::{AgentId, Capability, CapabilityId};

    /// Registry double that only counts how often each operation is invoked.
    struct MockRegistry {
        registers: Arc<AtomicUsize>,
        heartbeats: Arc<AtomicUsize>,
        deregistrations: Arc<AtomicUsize>,
    }

    impl MockRegistry {
        /// Creates a mock with all counters at zero.
        fn new() -> Self {
            Self {
                registers: Arc::new(AtomicUsize::new(0)),
                heartbeats: Arc::new(AtomicUsize::new(0)),
                deregistrations: Arc::new(AtomicUsize::new(0)),
            }
        }
    }

    /// Counts one call on `counter` and reports success.
    fn record(counter: &AtomicUsize) -> RegistryResult<()> {
        counter.fetch_add(1, Ordering::SeqCst);
        Ok(())
    }

    #[async_trait]
    impl AgentRegistry for MockRegistry {
        async fn register(&self, _manifest: &AgentManifest) -> RegistryResult<()> {
            record(&self.registers)
        }

        async fn heartbeat(&self, _manifest: &AgentManifest) -> RegistryResult<()> {
            record(&self.heartbeats)
        }

        async fn deregister(&self, _manifest: &AgentManifest) -> RegistryResult<()> {
            record(&self.deregistrations)
        }
    }

    /// Builds a minimal valid manifest with a single mock capability.
    fn manifest() -> AgentManifest {
        let capability_id = CapabilityId::new("mock.cap").unwrap();
        let capability = Capability::builder(capability_id)
            .name("Mock")
            .unwrap()
            .version("1.0.0")
            .unwrap()
            .add_scope("read:mock")
            .unwrap()
            .build()
            .unwrap();

        AgentManifest::builder(AgentId::random())
            .name("mock-agent")
            .unwrap()
            .version("0.1.0")
            .unwrap()
            .capabilities(vec![capability])
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn lifecycle_starts_and_stops_heartbeat() {
        let registry = Arc::new(MockRegistry::new());
        let manifest = manifest();
        let config = RegistrationConfig::new(
            Duration::from_millis(10),
            Duration::from_millis(5),
            Duration::from_millis(20),
            NonZeroUsize::new(3).unwrap(),
        );
        let mut controller = RegistrationController::new(registry.clone(), manifest, config);
        let scheduler = TaskScheduler::default();

        // Becoming ready should register and then start heartbeating.
        controller
            .on_state_change(AgentState::Ready, &scheduler)
            .unwrap();
        tokio::time::sleep(Duration::from_millis(40)).await;
        assert!(registry.registers.load(Ordering::SeqCst) >= 1);
        assert!(registry.heartbeats.load(Ordering::SeqCst) >= 1);

        // Retiring should trigger a deregistration.
        controller
            .on_state_change(AgentState::Retiring, &scheduler)
            .unwrap();
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert!(registry.deregistrations.load(Ordering::SeqCst) >= 1);
    }
}
573
574#[async_trait]
575impl AgentRegistry for MxpRegistryClient {
576    async fn register(&self, manifest: &AgentManifest) -> RegistryResult<()> {
577        let request = self.manifest_to_register_request(manifest);
578        let payload = serde_json::to_vec(&request)
579            .map_err(|err| RegistryError::backend(format!("encode register payload: {err:?}")))?;
580        let message = Message::new(MessageType::AgentRegister, payload);
581        let response = self.send_request(message).await?;
582
583        match response.message_type() {
584            Some(MessageType::Response) => {
585                let ack = serde_json::from_slice::<RegisterResponse>(response.payload()).map_err(
586                    |err| {
587                        RegistryError::backend(format!("parse register response failed: {err:?}"))
588                    },
589                )?;
590                if ack.success {
591                    debug!(agent_id = ack.agent_id, "registry registration acked");
592                    Ok(())
593                } else {
594                    Err(RegistryError::backend(format!(
595                        "registry rejected registration: {}",
596                        ack.message
597                    )))
598                }
599            }
600            Some(MessageType::Error) => Self::handle_error_message(&response),
601            other => Err(RegistryError::backend(format!(
602                "unexpected message type {other:?} for register response"
603            ))),
604        }
605    }
606
607    async fn heartbeat(&self, manifest: &AgentManifest) -> RegistryResult<()> {
608        let request = HeartbeatRequest {
609            agent_id: Self::agent_id(manifest),
610        };
611        let payload = serde_json::to_vec(&request)
612            .map_err(|err| RegistryError::backend(format!("encode heartbeat payload: {err:?}")))?;
613        let message = Message::new(MessageType::AgentHeartbeat, payload);
614        let response = self.send_request(message).await?;
615
616        match response.message_type() {
617            Some(MessageType::Response) => {
618                let ack = serde_json::from_slice::<HeartbeatResponse>(response.payload()).map_err(
619                    |err| {
620                        RegistryError::backend(format!("parse heartbeat response failed: {err:?}"))
621                    },
622                )?;
623                if ack.success && !ack.needs_register {
624                    Ok(())
625                } else if ack.needs_register {
626                    Err(RegistryError::backend("registry requested re-registration"))
627                } else {
628                    Err(RegistryError::backend(
629                        ack.message
630                            .unwrap_or_else(|| "heartbeat rejected".to_string()),
631                    ))
632                }
633            }
634            Some(MessageType::Error) => Self::handle_error_message(&response),
635            other => Err(RegistryError::backend(format!(
636                "unexpected message type {other:?} for heartbeat response"
637            ))),
638        }
639    }
640
641    async fn deregister(&self, manifest: &AgentManifest) -> RegistryResult<()> {
642        let request = HeartbeatRequest {
643            agent_id: Self::agent_id(manifest),
644        };
645        let payload = serde_json::to_vec(&request)
646            .map_err(|err| RegistryError::backend(format!("encode deregister payload: {err:?}")))?;
647        let mut message = Message::new(MessageType::AgentHeartbeat, payload);
648        message.set_flags(message.flags().with(Flags::FINAL));
649        let response = self.send_request(message).await?;
650
651        match response.message_type() {
652            Some(MessageType::Response) => {
653                let ack = serde_json::from_slice::<HeartbeatResponse>(response.payload()).map_err(
654                    |err| {
655                        RegistryError::backend(format!("parse deregister response failed: {err:?}"))
656                    },
657                )?;
658                if ack.success {
659                    Ok(())
660                } else {
661                    Err(RegistryError::backend(
662                        ack.message
663                            .unwrap_or_else(|| "deregister failed".to_string()),
664                    ))
665                }
666            }
667            Some(MessageType::Error) => Self::handle_error_message(&response),
668            other => Err(RegistryError::backend(format!(
669                "unexpected message type {other:?} for deregister response"
670            ))),
671        }
672    }
673}