agent_kernel/
mxp_handlers.rs

1//! Routing utilities for MXP protocol messages.
2
3use std::sync::Arc;
4use std::time::Instant;
5
6use agent_primitives::AgentId;
7use async_trait::async_trait;
8use mxp::{Message, MessageType};
9use thiserror::Error;
10
11/// Context provided to message handlers.
12#[derive(Debug, Clone)]
13pub struct HandlerContext {
14    agent_id: AgentId,
15    received_at: Instant,
16    message: Arc<Message>,
17}
18
19impl HandlerContext {
20    /// Constructs a context from an owned message.
21    #[must_use]
22    pub fn from_message(agent_id: AgentId, message: Message) -> Self {
23        Self::from_shared(agent_id, Arc::new(message))
24    }
25
26    /// Constructs a context from a shared message instance.
27    #[must_use]
28    pub fn from_shared(agent_id: AgentId, message: Arc<Message>) -> Self {
29        Self {
30            agent_id,
31            received_at: Instant::now(),
32            message,
33        }
34    }
35
36    /// Returns the agent identifier.
37    #[must_use]
38    pub const fn agent_id(&self) -> AgentId {
39        self.agent_id
40    }
41
42    /// Returns the time the message was received.
43    #[must_use]
44    pub fn received_at(&self) -> Instant {
45        self.received_at
46    }
47
48    /// Returns the underlying MXP message.
49    #[must_use]
50    pub fn message(&self) -> &Message {
51        &self.message
52    }
53
54    /// Returns the MXP message type.
55    ///
56    /// # Errors
57    ///
58    /// Returns [`HandlerError::MissingMessageType`] when the header could not be
59    /// decoded into a [`MessageType`].
60    pub fn message_type(&self) -> HandlerResult<MessageType> {
61        self.message
62            .message_type()
63            .ok_or(HandlerError::MissingMessageType)
64    }
65}
66
67/// Errors that can occur during message handling.
68#[derive(Debug, Error, PartialEq, Eq)]
69pub enum HandlerError {
70    /// The message header did not contain a valid message type.
71    #[error("message missing type information")]
72    MissingMessageType,
73    /// The agent does not handle the message type.
74    #[error("message type {0:?} is not supported")]
75    Unsupported(MessageType),
76    /// Custom handler error with human-readable context.
77    #[error("handler error: {0}")]
78    Custom(String),
79}
80
81impl HandlerError {
82    /// Creates a custom error variant from a string-like value.
83    #[must_use]
84    pub fn custom(reason: impl Into<String>) -> Self {
85        Self::Custom(reason.into())
86    }
87}
88
89/// Result alias for handler operations.
90pub type HandlerResult<T = ()> = Result<T, HandlerError>;
91
92/// Trait implemented by agent-specific MXP message handlers.
93#[async_trait]
94pub trait AgentMessageHandler: Send + Sync {
95    /// Called for `AgentRegister` messages.
96    async fn handle_agent_register(&self, ctx: HandlerContext) -> HandlerResult {
97        self.handle_unhandled(ctx, MessageType::AgentRegister).await
98    }
99
100    /// Called for `AgentDiscover` messages.
101    async fn handle_agent_discover(&self, ctx: HandlerContext) -> HandlerResult {
102        self.handle_unhandled(ctx, MessageType::AgentDiscover).await
103    }
104
105    /// Called for `AgentHeartbeat` messages.
106    async fn handle_agent_heartbeat(&self, ctx: HandlerContext) -> HandlerResult {
107        self.handle_unhandled(ctx, MessageType::AgentHeartbeat)
108            .await
109    }
110
111    /// Called for `Call` messages.
112    async fn handle_call(&self, ctx: HandlerContext) -> HandlerResult {
113        self.handle_unhandled(ctx, MessageType::Call).await
114    }
115
116    /// Called for `Response` messages.
117    async fn handle_response(&self, ctx: HandlerContext) -> HandlerResult {
118        self.handle_unhandled(ctx, MessageType::Response).await
119    }
120
121    /// Called for `Event` messages.
122    async fn handle_event(&self, ctx: HandlerContext) -> HandlerResult {
123        self.handle_unhandled(ctx, MessageType::Event).await
124    }
125
126    /// Called for `StreamOpen` messages.
127    async fn handle_stream_open(&self, ctx: HandlerContext) -> HandlerResult {
128        self.handle_unhandled(ctx, MessageType::StreamOpen).await
129    }
130
131    /// Called for `StreamChunk` messages.
132    async fn handle_stream_chunk(&self, ctx: HandlerContext) -> HandlerResult {
133        self.handle_unhandled(ctx, MessageType::StreamChunk).await
134    }
135
136    /// Called for `StreamClose` messages.
137    async fn handle_stream_close(&self, ctx: HandlerContext) -> HandlerResult {
138        self.handle_unhandled(ctx, MessageType::StreamClose).await
139    }
140
141    /// Called for `Ack` messages.
142    async fn handle_ack(&self, ctx: HandlerContext) -> HandlerResult {
143        self.handle_unhandled(ctx, MessageType::Ack).await
144    }
145
146    /// Called for protocol-level error responses.
147    async fn handle_error(&self, ctx: HandlerContext) -> HandlerResult {
148        self.handle_unhandled(ctx, MessageType::Error).await
149    }
150
151    /// Fallback invoked when a specialized handler is not implemented.
152    async fn handle_unhandled(
153        &self,
154        ctx: HandlerContext,
155        message_type: MessageType,
156    ) -> HandlerResult {
157        let _ = ctx;
158        Err(HandlerError::Unsupported(message_type))
159    }
160}
161
162/// Dispatches a message to the appropriate handler.
163///
164/// # Errors
165///
166/// Propagates errors returned by the underlying handler implementation.
167pub async fn dispatch_message<H>(handler: &H, ctx: HandlerContext) -> HandlerResult
168where
169    H: AgentMessageHandler + ?Sized,
170{
171    let message_type = ctx.message_type()?;
172
173    match message_type {
174        MessageType::AgentRegister => handler.handle_agent_register(ctx).await,
175        MessageType::AgentDiscover => handler.handle_agent_discover(ctx).await,
176        MessageType::AgentHeartbeat => handler.handle_agent_heartbeat(ctx).await,
177        MessageType::Call => handler.handle_call(ctx).await,
178        MessageType::Response => handler.handle_response(ctx).await,
179        MessageType::Event => handler.handle_event(ctx).await,
180        MessageType::StreamOpen => handler.handle_stream_open(ctx).await,
181        MessageType::StreamChunk => handler.handle_stream_chunk(ctx).await,
182        MessageType::StreamClose => handler.handle_stream_close(ctx).await,
183        MessageType::Ack => handler.handle_ack(ctx).await,
184        MessageType::Error => handler.handle_error(ctx).await,
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Counts `Call` messages and leaves every other hook at its default.
    struct CountingHandler {
        calls: AtomicUsize,
    }

    #[async_trait]
    impl AgentMessageHandler for CountingHandler {
        async fn handle_call(&self, _ctx: HandlerContext) -> HandlerResult {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[tokio::test]
    async fn dispatches_to_specific_handler() {
        let handler = CountingHandler {
            calls: AtomicUsize::new(0),
        };

        let message = Message::new(MessageType::Call, b"ping");
        let ctx = HandlerContext::from_message(AgentId::random(), message);
        dispatch_message(&handler, ctx).await.unwrap();

        assert_eq!(handler.calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn other_types_do_not_reach_specific_handler() {
        let handler = CountingHandler {
            calls: AtomicUsize::new(0),
        };

        let message = Message::new(MessageType::Event, b"noop");
        let ctx = HandlerContext::from_message(AgentId::random(), message);
        let err = dispatch_message(&handler, ctx)
            .await
            .expect_err("Event has no dedicated hook");

        assert_eq!(err, HandlerError::Unsupported(MessageType::Event));
        assert_eq!(handler.calls.load(Ordering::SeqCst), 0);
    }

    struct UnsupportedHandler;

    #[async_trait]
    impl AgentMessageHandler for UnsupportedHandler {}

    #[tokio::test]
    async fn unsupported_message_errors() {
        let handler = UnsupportedHandler;
        let message = Message::new(MessageType::Event, b"noop");
        let ctx = HandlerContext::from_message(AgentId::random(), message);
        let err = dispatch_message(&handler, ctx)
            .await
            .expect_err("should error");

        assert_eq!(err, HandlerError::Unsupported(MessageType::Event));
    }

    /// Accepts every message via the fallback hook, counting invocations.
    struct FallbackHandler {
        fallbacks: AtomicUsize,
    }

    #[async_trait]
    impl AgentMessageHandler for FallbackHandler {
        async fn handle_unhandled(
            &self,
            _ctx: HandlerContext,
            _message_type: MessageType,
        ) -> HandlerResult {
            self.fallbacks.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[tokio::test]
    async fn default_hooks_route_through_overridden_fallback() {
        let handler = FallbackHandler {
            fallbacks: AtomicUsize::new(0),
        };

        let message = Message::new(MessageType::Ack, b"ping");
        let ctx = HandlerContext::from_message(AgentId::random(), message);
        dispatch_message(&handler, ctx)
            .await
            .expect("overridden fallback accepts every message");

        assert_eq!(handler.fallbacks.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn custom_error_wraps_reason() {
        let err = HandlerError::custom("boom");

        assert_eq!(err, HandlerError::Custom("boom".to_owned()));
        assert_eq!(err.to_string(), "handler error: boom");
    }

    #[test]
    fn missing_type_error_message() {
        assert_eq!(
            HandlerError::MissingMessageType.to_string(),
            "message missing type information"
        );
    }
}