1#![warn(missing_docs, clippy::pedantic)]
7
8mod call;
9mod lifecycle;
10mod mxp_handlers;
11mod registry;
12mod registry_wire;
13mod scheduler;
14
15use std::sync::Arc;
16
17use agent_primitives::{AgentId, AgentManifest};
18use mxp::Message;
19use mxp_handlers::dispatch_message;
20use thiserror::Error;
21use tokio::task::JoinHandle;
22use tracing::warn;
23
24pub use call::{
25 AuditEmitter, CallExecutor, CallOutcome, CallOutcomeSink, CollectingSink,
26 CompositeAuditEmitter, CompositePolicyObserver, GovernanceAuditEmitter, KernelMessageHandler,
27 KernelMessageHandlerBuilder, MxpAuditObserver, PolicyObserver, ToolInvocationResult,
28 TracingAuditEmitter, TracingCallSink, TracingPolicyObserver,
29};
30pub use lifecycle::{AgentState, Lifecycle, LifecycleError, LifecycleEvent, LifecycleResult};
31pub use mxp_handlers::{AgentMessageHandler, HandlerContext, HandlerError, HandlerResult};
32pub use registry::{
33 AgentRegistry, MxpRegistryClient, RegistrationConfig, RegistryError, RegistryResult,
34};
35pub use registry_wire::{
36 AgentRecord, AgentStatus as WireAgentStatus, DiscoverRequest, DiscoverResponse, ErrorResponse,
37 HeartbeatRequest, HeartbeatResponse, RegisterRequest, RegisterResponse,
38};
39pub use scheduler::{SchedulerConfig, SchedulerError, SchedulerResult, TaskScheduler};
40
41use registry::RegistrationController;
42
43#[derive(Debug)]
45pub struct AgentKernel<H>
46where
47 H: AgentMessageHandler + 'static,
48{
49 agent_id: AgentId,
50 lifecycle: Lifecycle,
51 handler: Arc<H>,
52 scheduler: TaskScheduler,
53 registry: Option<RegistrationController>,
54}
55
56impl<H> AgentKernel<H>
57where
58 H: AgentMessageHandler + 'static,
59{
60 #[must_use]
62 pub fn new(agent_id: AgentId, handler: Arc<H>, scheduler: TaskScheduler) -> Self {
63 Self {
64 agent_id,
65 lifecycle: Lifecycle::new(agent_id),
66 handler,
67 scheduler,
68 registry: None,
69 }
70 }
71
72 pub fn set_registry<R>(
74 &mut self,
75 registry: Arc<R>,
76 manifest: AgentManifest,
77 config: RegistrationConfig,
78 ) where
79 R: AgentRegistry + 'static,
80 {
81 let registry: Arc<dyn AgentRegistry> = registry;
82 self.registry = Some(RegistrationController::new(registry, manifest, config));
83 }
84
85 #[must_use]
87 pub const fn agent_id(&self) -> AgentId {
88 self.agent_id
89 }
90
91 #[must_use]
93 pub fn state(&self) -> AgentState {
94 self.lifecycle.state()
95 }
96
97 pub fn transition(&mut self, event: LifecycleEvent) -> KernelResult<AgentState> {
104 let state = self.lifecycle.transition(event)?;
105 if let Some(controller) = &mut self.registry
106 && let Err(err) = controller.on_state_change(state, &self.scheduler)
107 {
108 warn!(?err, "registry hook failed during state transition");
109 return Err(err.into());
110 }
111
112 Ok(state)
113 }
114
115 pub async fn handle_message(&self, message: Message) -> HandlerResult {
121 let ctx = HandlerContext::from_message(self.agent_id, message);
122 dispatch_message(self.handler.as_ref(), ctx).await
123 }
124
125 pub fn schedule_message(&self, message: Message) -> SchedulerResult<JoinHandle<HandlerResult>> {
131 let handler = Arc::clone(&self.handler);
132 let agent_id = self.agent_id;
133 self.scheduler.spawn(async move {
134 let ctx = HandlerContext::from_message(agent_id, message);
135 dispatch_message(handler.as_ref(), ctx).await
136 })
137 }
138
139 #[must_use]
141 pub fn scheduler(&self) -> &TaskScheduler {
142 &self.scheduler
143 }
144}
145
146#[derive(Debug, Error)]
148pub enum KernelError {
149 #[error(transparent)]
151 Lifecycle(#[from] LifecycleError),
152 #[error(transparent)]
154 Registry(#[from] RegistryError),
155}
156
157pub type KernelResult<T> = Result<T, KernelError>;
159
#[cfg(test)]
mod tests {
    use super::*;
    use std::num::NonZeroUsize;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    use agent_primitives::{Capability, CapabilityId};

    /// How often `wait_for_count` re-checks a counter.
    const POLL_INTERVAL: Duration = Duration::from_millis(2);
    /// Upper bound on how long `wait_for_count` waits before failing the test.
    ///
    /// Generous on purpose: a loaded CI runner must not cause spurious failures, while an
    /// idle machine still returns after a few milliseconds.
    const WAIT_LIMIT: Duration = Duration::from_secs(2);

    /// Handler that relies entirely on the trait's default method implementations.
    struct NullHandler;

    impl AgentMessageHandler for NullHandler {}

    /// Registry double that only counts how many times each operation is invoked.
    ///
    /// The registry is already shared via `Arc`, so plain atomics are enough here.
    #[derive(Default)]
    struct CountingRegistry {
        registers: AtomicUsize,
        heartbeats: AtomicUsize,
        deregisters: AtomicUsize,
    }

    #[async_trait::async_trait]
    impl AgentRegistry for CountingRegistry {
        async fn register(&self, _manifest: &AgentManifest) -> RegistryResult<()> {
            self.registers.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn heartbeat(&self, _manifest: &AgentManifest) -> RegistryResult<()> {
            self.heartbeats.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn deregister(&self, _manifest: &AgentManifest) -> RegistryResult<()> {
            self.deregisters.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Builds a minimal valid capability for the test manifest.
    fn capability() -> Capability {
        Capability::builder(CapabilityId::new("kernel.test").unwrap())
            .name("Test")
            .unwrap()
            .version("1.0.0")
            .unwrap()
            .add_scope("read:test")
            .unwrap()
            .build()
            .unwrap()
    }

    /// Builds a manifest for a random agent id carrying `capability()`.
    fn manifest() -> AgentManifest {
        AgentManifest::builder(AgentId::random())
            .name("kernel-agent")
            .unwrap()
            .version("0.0.1")
            .unwrap()
            .capabilities(vec![capability()])
            .build()
            .unwrap()
    }

    /// Waits until `counter` reaches at least `min`. Panics with a message naming `what` if
    /// that does not happen within `WAIT_LIMIT`.
    ///
    /// The registry calls are observed asynchronously after `transition` returns. A single
    /// hard-coded sleep is therefore either needlessly slow or flaky, depending on machine
    /// load; polling with a deadline is neither.
    async fn wait_for_count(counter: &AtomicUsize, min: usize, what: &str) {
        let outcome = tokio::time::timeout(WAIT_LIMIT, async {
            while counter.load(Ordering::SeqCst) < min {
                tokio::time::sleep(POLL_INTERVAL).await;
            }
        })
        .await;
        assert!(
            outcome.is_ok(),
            "timed out after {WAIT_LIMIT:?} waiting for {what} >= {min}"
        );
    }

    #[tokio::test]
    async fn registry_hooks_trigger_lifecycle_actions() {
        let scheduler = TaskScheduler::default();
        let handler = Arc::new(NullHandler);
        let mut kernel = AgentKernel::new(AgentId::random(), handler, scheduler.clone());

        let registry = Arc::new(CountingRegistry::default());
        let config = RegistrationConfig::new(
            Duration::from_millis(10),
            Duration::from_millis(5),
            Duration::from_millis(20),
            NonZeroUsize::new(3).unwrap(),
        );
        kernel.set_registry(Arc::clone(&registry), manifest(), config);

        kernel.transition(LifecycleEvent::Boot).unwrap();
        kernel.transition(LifecycleEvent::Activate).unwrap();

        wait_for_count(&registry.registers, 1, "registers").await;
        wait_for_count(&registry.heartbeats, 1, "heartbeats").await;

        kernel.transition(LifecycleEvent::Retire).unwrap();
        kernel.transition(LifecycleEvent::Terminate).unwrap();

        wait_for_count(&registry.deregisters, 1, "deregisters").await;
    }
}