1#![warn(missing_docs, clippy::pedantic)]
7
8mod call;
9mod lifecycle;
10mod mxp_handlers;
11mod recovery;
12mod registry;
13mod registry_wire;
14mod scheduler;
15mod shutdown;
16mod validation;
17
18use std::sync::Arc;
19
20use agent_primitives::{AgentId, AgentManifest};
21use mxp::Message;
22use mxp_handlers::dispatch_message;
23use thiserror::Error;
24use tokio::task::JoinHandle;
25use tracing::warn;
26
27pub use call::{
28 AuditEmitter, CallExecutor, CallOutcome, CallOutcomeSink, CollectingSink,
29 CompositeAuditEmitter, CompositePolicyObserver, GovernanceAuditEmitter, KernelMessageHandler,
30 KernelMessageHandlerBuilder, MxpAuditObserver, PolicyObserver, ToolInvocationResult,
31 TracingAuditEmitter, TracingCallSink, TracingPolicyObserver,
32};
33pub use lifecycle::{AgentState, Lifecycle, LifecycleError, LifecycleEvent, LifecycleResult};
34pub use mxp_handlers::{AgentMessageHandler, HandlerContext, HandlerError, HandlerResult};
35pub use recovery::{
36 CheckpointMetadata, RecoveryError, RecoveryResult, StateCheckpoint, StateRecovery,
37};
38pub use registry::{
39 AgentRegistry, MxpRegistryClient, RegistrationConfig, RegistryError, RegistryResult,
40};
41pub use registry_wire::{
42 AgentRecord, AgentStatus as WireAgentStatus, DiscoverRequest, DiscoverResponse, ErrorResponse,
43 HeartbeatRequest, HeartbeatResponse, RegisterRequest, RegisterResponse,
44};
45pub use scheduler::{SchedulerConfig, SchedulerError, SchedulerResult, TaskScheduler};
46pub use shutdown::{ShutdownCoordinator, ShutdownError, ShutdownResult, ShutdownState, WorkGuard};
47pub use validation::{ValidationConfig, ValidationError, ValidationResult, validate_message};
48
49use registry::RegistrationController;
50
51#[derive(Debug)]
53pub struct AgentKernel<H>
54where
55 H: AgentMessageHandler + 'static,
56{
57 agent_id: AgentId,
58 lifecycle: Lifecycle,
59 handler: Arc<H>,
60 scheduler: TaskScheduler,
61 registry: Option<RegistrationController>,
62 shutdown: Arc<ShutdownCoordinator>,
63 recovery: Option<Arc<recovery::StateRecovery>>,
64}
65
66impl<H> AgentKernel<H>
67where
68 H: AgentMessageHandler + 'static,
69{
70 #[must_use]
72 pub fn new(agent_id: AgentId, handler: Arc<H>, scheduler: TaskScheduler) -> Self {
73 Self {
74 agent_id,
75 lifecycle: Lifecycle::new(agent_id),
76 handler,
77 scheduler,
78 registry: None,
79 shutdown: Arc::new(ShutdownCoordinator::new(std::time::Duration::from_secs(30))),
80 recovery: None,
81 }
82 }
83
84 pub fn set_registry<R>(
86 &mut self,
87 registry: Arc<R>,
88 manifest: AgentManifest,
89 config: RegistrationConfig,
90 ) where
91 R: AgentRegistry + 'static,
92 {
93 let registry: Arc<dyn AgentRegistry> = registry;
94 self.registry = Some(RegistrationController::new(registry, manifest, config));
95 }
96
97 pub fn set_recovery(&mut self, recovery: Arc<recovery::StateRecovery>) {
99 self.recovery = Some(recovery);
100 }
101
102 #[must_use]
104 pub const fn agent_id(&self) -> AgentId {
105 self.agent_id
106 }
107
108 #[must_use]
110 pub fn state(&self) -> AgentState {
111 self.lifecycle.state()
112 }
113
114 pub fn transition(&mut self, event: LifecycleEvent) -> KernelResult<AgentState> {
121 let state = self.lifecycle.transition(event)?;
122 if let Some(controller) = &mut self.registry
123 && let Err(err) = controller.on_state_change(state, &self.scheduler)
124 {
125 warn!(?err, "registry hook failed during state transition");
126 return Err(err.into());
127 }
128
129 Ok(state)
130 }
131
132 pub async fn handle_message(&self, message: Message) -> HandlerResult {
138 let _guard = self.shutdown.register_work().ok();
139 let ctx = HandlerContext::from_message(self.agent_id, message);
140 dispatch_message(self.handler.as_ref(), ctx).await
141 }
142
143 pub fn schedule_message(&self, message: Message) -> SchedulerResult<JoinHandle<HandlerResult>> {
149 let handler = Arc::clone(&self.handler);
150 let agent_id = self.agent_id;
151 let shutdown = Arc::clone(&self.shutdown);
152 self.scheduler.spawn(async move {
153 let _guard = shutdown.register_work().ok();
154 let ctx = HandlerContext::from_message(agent_id, message);
155 dispatch_message(handler.as_ref(), ctx).await
156 })
157 }
158
159 #[must_use]
161 pub fn scheduler(&self) -> &TaskScheduler {
162 &self.scheduler
163 }
164
165 #[must_use]
167 pub fn shutdown(&self) -> &ShutdownCoordinator {
168 &self.shutdown
169 }
170
171 pub async fn graceful_shutdown(&mut self) -> ShutdownResult<()> {
180 self.transition(LifecycleEvent::Retire).ok();
181
182 if let Some(recovery) = &self.recovery {
184 let checkpoint = recovery::StateRecovery::create_checkpoint(
185 self.agent_id,
186 self.lifecycle.state(),
187 serde_json::json!({}),
188 self.shutdown.in_flight_count(),
189 );
190
191 if let Err(e) = recovery.persist_checkpoint(&checkpoint).await {
192 warn!("failed to persist checkpoint: {}", e);
193 }
194 }
195
196 self.shutdown.shutdown().await
197 }
198}
199
200#[derive(Debug, Error)]
202pub enum KernelError {
203 #[error(transparent)]
205 Lifecycle(#[from] LifecycleError),
206 #[error(transparent)]
208 Registry(#[from] RegistryError),
209}
210
211pub type KernelResult<T> = Result<T, KernelError>;
213
#[cfg(test)]
mod tests {
    use super::*;
    use std::num::NonZeroUsize;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    use agent_primitives::{Capability, CapabilityId};

    /// Handler that relies entirely on the trait's default behavior.
    struct NullHandler;

    impl AgentMessageHandler for NullHandler {}

    /// Registry double that only counts how many times each hook is invoked.
    #[derive(Default)]
    struct CountingRegistry {
        registers: AtomicUsize,
        heartbeats: AtomicUsize,
        deregisters: AtomicUsize,
    }

    #[async_trait::async_trait]
    impl AgentRegistry for CountingRegistry {
        async fn register(&self, _manifest: &AgentManifest) -> RegistryResult<()> {
            self.registers.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn heartbeat(&self, _manifest: &AgentManifest) -> RegistryResult<()> {
            self.heartbeats.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn deregister(&self, _manifest: &AgentManifest) -> RegistryResult<()> {
            self.deregisters.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Builds the single capability advertised by the test manifest.
    fn capability() -> Capability {
        let id = CapabilityId::new("kernel.test").unwrap();
        let builder = Capability::builder(id).name("Test").unwrap();
        let builder = builder.version("1.0.0").unwrap();
        builder.add_scope("read:test").unwrap().build().unwrap()
    }

    /// Builds a manifest for a freshly generated agent id.
    fn manifest() -> AgentManifest {
        let builder = AgentManifest::builder(AgentId::random()).name("kernel-agent").unwrap();
        let builder = builder.version("0.0.1").unwrap();
        builder.capabilities(vec![capability()]).build().unwrap()
    }

    /// Registration config with millisecond-scale durations so the test runs quickly.
    fn fast_registration_config() -> RegistrationConfig {
        RegistrationConfig::new(
            Duration::from_millis(10),
            Duration::from_millis(5),
            Duration::from_millis(20),
            NonZeroUsize::new(3).unwrap(),
        )
    }

    #[tokio::test]
    async fn registry_hooks_trigger_lifecycle_actions() {
        let scheduler = TaskScheduler::default();
        let mut kernel =
            AgentKernel::new(AgentId::random(), Arc::new(NullHandler), scheduler.clone());

        let counting = Arc::new(CountingRegistry::default());
        kernel.set_registry(Arc::clone(&counting), manifest(), fast_registration_config());

        for event in [LifecycleEvent::Boot, LifecycleEvent::Activate] {
            kernel.transition(event).unwrap();
        }

        tokio::time::sleep(Duration::from_millis(35)).await;
        assert!(counting.registers.load(Ordering::SeqCst) >= 1);
        assert!(counting.heartbeats.load(Ordering::SeqCst) >= 1);

        for event in [LifecycleEvent::Retire, LifecycleEvent::Terminate] {
            kernel.transition(event).unwrap();
        }
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert!(counting.deregisters.load(Ordering::SeqCst) >= 1);
    }
}