agent_kernel/
mxp_handlers.rs1use std::sync::Arc;
4use std::time::Instant;
5
6use agent_primitives::AgentId;
7use async_trait::async_trait;
8use mxp::{Message, MessageType};
9use thiserror::Error;
10
11#[derive(Debug, Clone)]
13pub struct HandlerContext {
14 agent_id: AgentId,
15 received_at: Instant,
16 message: Arc<Message>,
17}
18
19impl HandlerContext {
20 #[must_use]
22 pub fn from_message(agent_id: AgentId, message: Message) -> Self {
23 Self::from_shared(agent_id, Arc::new(message))
24 }
25
26 #[must_use]
28 pub fn from_shared(agent_id: AgentId, message: Arc<Message>) -> Self {
29 Self {
30 agent_id,
31 received_at: Instant::now(),
32 message,
33 }
34 }
35
36 #[must_use]
38 pub const fn agent_id(&self) -> AgentId {
39 self.agent_id
40 }
41
42 #[must_use]
44 pub fn received_at(&self) -> Instant {
45 self.received_at
46 }
47
48 #[must_use]
50 pub fn message(&self) -> &Message {
51 &self.message
52 }
53
54 pub fn message_type(&self) -> HandlerResult<MessageType> {
61 self.message
62 .message_type()
63 .ok_or(HandlerError::MissingMessageType)
64 }
65}
66
67#[derive(Debug, Error, PartialEq, Eq)]
69pub enum HandlerError {
70 #[error("message missing type information")]
72 MissingMessageType,
73 #[error("message type {0:?} is not supported")]
75 Unsupported(MessageType),
76 #[error("handler error: {0}")]
78 Custom(String),
79}
80
81impl HandlerError {
82 #[must_use]
84 pub fn custom(reason: impl Into<String>) -> Self {
85 Self::Custom(reason.into())
86 }
87}
88
89pub type HandlerResult<T = ()> = Result<T, HandlerError>;
91
92#[async_trait]
94pub trait AgentMessageHandler: Send + Sync {
95 async fn handle_agent_register(&self, ctx: HandlerContext) -> HandlerResult {
97 self.handle_unhandled(ctx, MessageType::AgentRegister).await
98 }
99
100 async fn handle_agent_discover(&self, ctx: HandlerContext) -> HandlerResult {
102 self.handle_unhandled(ctx, MessageType::AgentDiscover).await
103 }
104
105 async fn handle_agent_heartbeat(&self, ctx: HandlerContext) -> HandlerResult {
107 self.handle_unhandled(ctx, MessageType::AgentHeartbeat)
108 .await
109 }
110
111 async fn handle_call(&self, ctx: HandlerContext) -> HandlerResult {
113 self.handle_unhandled(ctx, MessageType::Call).await
114 }
115
116 async fn handle_response(&self, ctx: HandlerContext) -> HandlerResult {
118 self.handle_unhandled(ctx, MessageType::Response).await
119 }
120
121 async fn handle_event(&self, ctx: HandlerContext) -> HandlerResult {
123 self.handle_unhandled(ctx, MessageType::Event).await
124 }
125
126 async fn handle_stream_open(&self, ctx: HandlerContext) -> HandlerResult {
128 self.handle_unhandled(ctx, MessageType::StreamOpen).await
129 }
130
131 async fn handle_stream_chunk(&self, ctx: HandlerContext) -> HandlerResult {
133 self.handle_unhandled(ctx, MessageType::StreamChunk).await
134 }
135
136 async fn handle_stream_close(&self, ctx: HandlerContext) -> HandlerResult {
138 self.handle_unhandled(ctx, MessageType::StreamClose).await
139 }
140
141 async fn handle_ack(&self, ctx: HandlerContext) -> HandlerResult {
143 self.handle_unhandled(ctx, MessageType::Ack).await
144 }
145
146 async fn handle_error(&self, ctx: HandlerContext) -> HandlerResult {
148 self.handle_unhandled(ctx, MessageType::Error).await
149 }
150
151 async fn handle_unhandled(
153 &self,
154 ctx: HandlerContext,
155 message_type: MessageType,
156 ) -> HandlerResult {
157 let _ = ctx;
158 Err(HandlerError::Unsupported(message_type))
159 }
160}
161
162pub async fn dispatch_message<H>(handler: &H, ctx: HandlerContext) -> HandlerResult
168where
169 H: AgentMessageHandler + ?Sized,
170{
171 let message_type = ctx.message_type()?;
172
173 match message_type {
174 MessageType::AgentRegister => handler.handle_agent_register(ctx).await,
175 MessageType::AgentDiscover => handler.handle_agent_discover(ctx).await,
176 MessageType::AgentHeartbeat => handler.handle_agent_heartbeat(ctx).await,
177 MessageType::Call => handler.handle_call(ctx).await,
178 MessageType::Response => handler.handle_response(ctx).await,
179 MessageType::Event => handler.handle_event(ctx).await,
180 MessageType::StreamOpen => handler.handle_stream_open(ctx).await,
181 MessageType::StreamChunk => handler.handle_stream_chunk(ctx).await,
182 MessageType::StreamClose => handler.handle_stream_close(ctx).await,
183 MessageType::Ack => handler.handle_ack(ctx).await,
184 MessageType::Error => handler.handle_error(ctx).await,
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Handler that overrides only `handle_call` and counts its invocations.
    struct CountingHandler {
        calls: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl AgentMessageHandler for CountingHandler {
        async fn handle_call(&self, _ctx: HandlerContext) -> HandlerResult {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[tokio::test]
    async fn dispatches_to_specific_handler() {
        let handler = CountingHandler {
            calls: Arc::new(AtomicUsize::new(0)),
        };

        let message = Message::new(MessageType::Call, b"ping");
        let ctx = HandlerContext::from_message(AgentId::random(), message);
        dispatch_message(&handler, ctx).await.unwrap();

        assert_eq!(handler.calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn non_overridden_types_fall_back_to_unsupported() {
        let handler = CountingHandler {
            calls: Arc::new(AtomicUsize::new(0)),
        };

        let message = Message::new(MessageType::Event, b"noop");
        let ctx = HandlerContext::from_message(AgentId::random(), message);
        let err = dispatch_message(&handler, ctx)
            .await
            .expect_err("event is not overridden by CountingHandler");

        assert_eq!(err, HandlerError::Unsupported(MessageType::Event));
        // Routing an Event must not touch the overridden `handle_call`.
        assert_eq!(handler.calls.load(Ordering::SeqCst), 0);
    }

    /// Handler that overrides nothing, so every message hits the fallback.
    struct UnsupportedHandler;

    #[async_trait]
    impl AgentMessageHandler for UnsupportedHandler {}

    #[tokio::test]
    async fn unsupported_message_errors() {
        let handler = UnsupportedHandler;
        let message = Message::new(MessageType::Event, b"noop");
        let ctx = HandlerContext::from_message(AgentId::random(), message);
        let err = dispatch_message(&handler, ctx)
            .await
            .expect_err("should error");

        assert_eq!(err, HandlerError::Unsupported(MessageType::Event));
    }

    #[test]
    fn context_reports_message_type() {
        let message = Message::new(MessageType::Call, b"ping");
        let ctx = HandlerContext::from_message(AgentId::random(), message);

        assert_eq!(ctx.message_type(), Ok(MessageType::Call));
    }

    #[test]
    fn custom_error_wraps_and_formats_reason() {
        let err = HandlerError::custom("boom");

        assert_eq!(err, HandlerError::Custom("boom".to_owned()));
        assert_eq!(err.to_string(), "handler error: boom");
    }
}