1use std::fmt;
4use std::num::NonZeroUsize;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::time::Duration;
8
9use agent_primitives::AgentManifest;
10use async_trait::async_trait;
11use thiserror::Error;
12use tokio::task::JoinHandle;
13use tokio::time::{MissedTickBehavior, sleep};
14use tracing::{info, warn};
15
16use crate::{AgentState, SchedulerError, TaskScheduler};
17
/// Timing and failure-tolerance settings for registering an agent and keeping
/// that registration alive with heartbeats.
///
/// Values are stored exactly as supplied; call [`RegistrationConfig::validate`]
/// to check them before use.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RegistrationConfig {
    /// Period between consecutive heartbeats once registration has succeeded.
    heartbeat_interval: Duration,
    /// Delay applied after the first failed registration attempt.
    initial_retry_delay: Duration,
    /// Upper bound the retry delay is clamped to as it grows.
    max_retry_delay: Duration,
    /// Number of consecutive heartbeat failures that triggers re-registration.
    max_consecutive_failures: NonZeroUsize,
}
26
27impl RegistrationConfig {
28 #[must_use]
30 pub fn new(
31 heartbeat_interval: Duration,
32 initial_retry_delay: Duration,
33 max_retry_delay: Duration,
34 max_consecutive_failures: NonZeroUsize,
35 ) -> Self {
36 Self {
37 heartbeat_interval,
38 initial_retry_delay,
39 max_retry_delay,
40 max_consecutive_failures,
41 }
42 }
43
44 #[must_use]
46 pub const fn heartbeat_interval(self) -> Duration {
47 self.heartbeat_interval
48 }
49
50 #[must_use]
52 pub const fn initial_retry_delay(self) -> Duration {
53 self.initial_retry_delay
54 }
55
56 #[must_use]
58 pub const fn max_retry_delay(self) -> Duration {
59 self.max_retry_delay
60 }
61
62 #[must_use]
64 pub const fn max_consecutive_failures(self) -> NonZeroUsize {
65 self.max_consecutive_failures
66 }
67
68 pub fn validate(self) -> RegistryResult<()> {
75 if self.heartbeat_interval.is_zero() {
76 return Err(RegistryError::InvalidConfig(
77 "heartbeat interval must be greater than zero",
78 ));
79 }
80 if self.initial_retry_delay.is_zero() {
81 return Err(RegistryError::InvalidConfig(
82 "initial retry delay must be greater than zero",
83 ));
84 }
85 if self.max_retry_delay.is_zero() {
86 return Err(RegistryError::InvalidConfig(
87 "max retry delay must be greater than zero",
88 ));
89 }
90 if self.initial_retry_delay > self.max_retry_delay {
91 return Err(RegistryError::InvalidConfig(
92 "initial retry delay cannot exceed max retry delay",
93 ));
94 }
95 Ok(())
96 }
97}
98
99impl Default for RegistrationConfig {
100 fn default() -> Self {
101 Self {
102 heartbeat_interval: Duration::from_secs(10),
103 initial_retry_delay: Duration::from_secs(1),
104 max_retry_delay: Duration::from_secs(30),
105 max_consecutive_failures: NonZeroUsize::new(3).expect("non-zero"),
106 }
107 }
108}
109
110pub type RegistryResult<T> = Result<T, RegistryError>;
112
113#[derive(Debug, Error)]
115pub enum RegistryError {
116 #[error("invalid registration configuration: {0}")]
118 InvalidConfig(&'static str),
119 #[error(transparent)]
121 Scheduler(#[from] SchedulerError),
122 #[error("registry backend error: {reason}")]
124 Backend {
125 reason: String,
127 },
128}
129
130impl RegistryError {
131 #[must_use]
133 pub fn backend(reason: impl Into<String>) -> Self {
134 Self::Backend {
135 reason: reason.into(),
136 }
137 }
138}
139
140#[async_trait]
142pub trait AgentRegistry: Send + Sync {
143 async fn register(&self, manifest: &AgentManifest) -> RegistryResult<()>;
145
146 async fn heartbeat(&self, manifest: &AgentManifest) -> RegistryResult<()>;
148
149 async fn deregister(&self, manifest: &AgentManifest) -> RegistryResult<()>;
151}
152
153pub(crate) struct RegistrationController {
154 registry: Arc<dyn AgentRegistry>,
155 manifest: Arc<AgentManifest>,
156 config: RegistrationConfig,
157 shutdown: Arc<AtomicBool>,
158 worker: Option<JoinHandle<()>>,
159}
160
161impl fmt::Debug for RegistrationController {
162 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163 f.debug_struct("RegistrationController")
164 .field("registry", &"dyn AgentRegistry")
165 .field("manifest", &self.manifest.id())
166 .field("config", &self.config)
167 .field("shutdown", &self.shutdown.load(Ordering::Relaxed))
168 .field("worker", &self.worker.is_some())
169 .finish()
170 }
171}
172
173impl RegistrationController {
174 pub(crate) fn new(
175 registry: Arc<dyn AgentRegistry>,
176 manifest: AgentManifest,
177 config: RegistrationConfig,
178 ) -> Self {
179 Self {
180 registry,
181 manifest: Arc::new(manifest),
182 config,
183 shutdown: Arc::new(AtomicBool::new(false)),
184 worker: None,
185 }
186 }
187
188 pub(crate) fn on_state_change(
189 &mut self,
190 state: AgentState,
191 scheduler: &TaskScheduler,
192 ) -> RegistryResult<()> {
193 match state {
194 AgentState::Ready | AgentState::Active => {
195 self.ensure_worker(scheduler)?;
196 }
197 AgentState::Retiring | AgentState::Terminated => {
198 self.shutdown.store(true, Ordering::Release);
199 self.spawn_deregister(scheduler)?;
200 if let Some(handle) = self.worker.take() {
201 handle.abort();
202 }
203 }
204 _ => {}
205 }
206
207 Ok(())
208 }
209
210 fn ensure_worker(&mut self, scheduler: &TaskScheduler) -> RegistryResult<()> {
211 if self.worker.is_some() {
212 return Ok(());
213 }
214
215 self.config.validate()?;
216
217 let registry = Arc::clone(&self.registry);
218 let manifest = Arc::clone(&self.manifest);
219 let shutdown = Arc::clone(&self.shutdown);
220 let config = self.config;
221
222 let handle = scheduler.spawn(async move {
223 run_registration_loop(registry, manifest, shutdown, config).await;
224 })?;
225
226 self.worker = Some(handle);
227 Ok(())
228 }
229
230 fn spawn_deregister(&self, scheduler: &TaskScheduler) -> RegistryResult<()> {
231 let registry = Arc::clone(&self.registry);
232 let manifest = Arc::clone(&self.manifest);
233 scheduler.spawn(async move {
234 if let Err(err) = registry.deregister(&manifest).await {
235 warn!(?err, "agent deregistration failed");
236 } else {
237 info!(agent_id = %manifest.id(), "agent deregistered");
238 }
239 })?;
240 Ok(())
241 }
242}
243
244async fn run_registration_loop(
245 registry: Arc<dyn AgentRegistry>,
246 manifest: Arc<AgentManifest>,
247 shutdown: Arc<AtomicBool>,
248 config: RegistrationConfig,
249) {
250 let mut retry_delay = config.initial_retry_delay();
251
252 loop {
253 if shutdown.load(Ordering::Acquire) {
254 break;
255 }
256
257 match registry.register(&manifest).await {
258 Ok(()) => {
259 info!(agent_id = %manifest.id(), "agent registered with mesh");
260 retry_delay = config.initial_retry_delay();
261 if !run_heartbeat_loop(
262 Arc::clone(®istry),
263 Arc::clone(&manifest),
264 Arc::clone(&shutdown),
265 config,
266 )
267 .await
268 {
269 continue;
270 }
271 break;
272 }
273 Err(err) => {
274 warn!(?err, "agent registration failed; retrying");
275 sleep(retry_delay).await;
276 retry_delay = (retry_delay * 2).min(config.max_retry_delay());
277 }
278 }
279 }
280}
281
282async fn run_heartbeat_loop(
283 registry: Arc<dyn AgentRegistry>,
284 manifest: Arc<AgentManifest>,
285 shutdown: Arc<AtomicBool>,
286 config: RegistrationConfig,
287) -> bool {
288 let mut failures: usize = 0;
289 let mut interval = tokio::time::interval(config.heartbeat_interval());
290 interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
291
292 while !shutdown.load(Ordering::Acquire) {
293 interval.tick().await;
294 if shutdown.load(Ordering::Acquire) {
295 break;
296 }
297
298 match registry.heartbeat(&manifest).await {
299 Ok(()) => {
300 failures = 0;
301 }
302 Err(err) => {
303 failures += 1;
304 warn!(?err, failures, "heartbeat failure");
305 if failures >= config.max_consecutive_failures().get() {
306 warn!(
307 failures,
308 "heartbeat failure threshold reached; attempting re-registration"
309 );
310 return false;
311 }
312 }
313 }
314 }
315
316 true
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    use agent_primitives::{AgentId, Capability, CapabilityId};

    /// Registry double that always succeeds and only counts calls.
    #[derive(Default)]
    struct MockRegistry {
        registers: AtomicUsize,
        heartbeats: AtomicUsize,
        deregistrations: AtomicUsize,
    }

    #[async_trait]
    impl AgentRegistry for MockRegistry {
        async fn register(&self, _manifest: &AgentManifest) -> RegistryResult<()> {
            self.registers.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn heartbeat(&self, _manifest: &AgentManifest) -> RegistryResult<()> {
            self.heartbeats.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn deregister(&self, _manifest: &AgentManifest) -> RegistryResult<()> {
            self.deregistrations.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Builds a minimal manifest with a single mock capability.
    fn manifest() -> AgentManifest {
        let capability_id = CapabilityId::new("mock.cap").unwrap();
        let capability = Capability::builder(capability_id)
            .name("Mock")
            .unwrap()
            .version("1.0.0")
            .unwrap()
            .add_scope("read:mock")
            .unwrap()
            .build()
            .unwrap();

        let builder = AgentManifest::builder(AgentId::random())
            .name("mock-agent")
            .unwrap()
            .version("0.1.0")
            .unwrap();

        builder.capabilities(vec![capability]).build().unwrap()
    }

    #[tokio::test]
    async fn lifecycle_starts_and_stops_heartbeat() {
        let registry = Arc::new(MockRegistry::default());
        let scheduler = TaskScheduler::default();

        let heartbeat_interval = Duration::from_millis(10);
        let initial_retry_delay = Duration::from_millis(5);
        let max_retry_delay = Duration::from_millis(20);
        let config = RegistrationConfig::new(
            heartbeat_interval,
            initial_retry_delay,
            max_retry_delay,
            NonZeroUsize::new(3).unwrap(),
        );

        let mut controller = RegistrationController::new(registry.clone(), manifest(), config);

        // Becoming ready starts the worker: expect a registration and at
        // least one heartbeat within a few intervals.
        controller
            .on_state_change(AgentState::Ready, &scheduler)
            .unwrap();
        tokio::time::sleep(Duration::from_millis(40)).await;

        assert!(registry.registers.load(Ordering::SeqCst) >= 1);
        assert!(registry.heartbeats.load(Ordering::SeqCst) >= 1);

        // Retiring schedules a deregistration.
        controller
            .on_state_change(AgentState::Retiring, &scheduler)
            .unwrap();
        tokio::time::sleep(Duration::from_millis(20)).await;

        assert!(registry.deregistrations.load(Ordering::SeqCst) >= 1);
    }
}