1use crate::agent::Handler;
4use anyhow::{Context, Result};
5use bytes::Bytes;
6use mitoxide_proto::{Frame, FrameCodec, Message, Request, Response};
7use mitoxide_proto::message::{ErrorCode, ErrorDetails};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::io::AsyncWrite;
11use tokio::sync::{mpsc, oneshot, RwLock};
12use tracing::{debug, error, info, warn};
13use uuid::Uuid;
14
15#[derive(Debug, Clone)]
17struct StreamInfo {
18 stream_id: u32,
20 sequence: u32,
22 request_id: Option<Uuid>,
24}
25
26pub struct AgentRouter<W>
28where
29 W: AsyncWrite + Unpin + Send,
30{
31 writer: Arc<tokio::sync::Mutex<W>>,
33 codec: FrameCodec,
35 streams: Arc<RwLock<HashMap<u32, StreamInfo>>>,
37 handlers: Arc<RwLock<HashMap<String, Arc<dyn Handler>>>>,
39 request_tx: mpsc::UnboundedSender<(u32, u32, Request)>,
41 request_rx: Option<mpsc::UnboundedReceiver<(u32, u32, Request)>>,
43 shutdown_tx: Option<oneshot::Sender<()>>,
45}
46
47impl<W> AgentRouter<W>
48where
49 W: AsyncWrite + Unpin + Send + 'static,
50{
51 pub fn new(writer: W) -> Self {
53 let (request_tx, request_rx) = mpsc::unbounded_channel();
54 let (shutdown_tx, _) = oneshot::channel();
55
56 Self {
57 writer: Arc::new(tokio::sync::Mutex::new(writer)),
58 codec: FrameCodec::new(),
59 streams: Arc::new(RwLock::new(HashMap::new())),
60 handlers: Arc::new(RwLock::new(HashMap::new())),
61 request_tx,
62 request_rx: Some(request_rx),
63 shutdown_tx: Some(shutdown_tx),
64 }
65 }
66
67 pub async fn register_handler(&self, request_type: String, handler: Arc<dyn Handler>) {
69 let mut handlers = self.handlers.write().await;
70 debug!("Registered handler for request type: {}", request_type);
71 handlers.insert(request_type, handler);
72 }
73
74 pub fn shutdown_sender(&mut self) -> Option<oneshot::Sender<()>> {
76 self.shutdown_tx.take()
77 }
78
79 pub async fn route_frame(&self, frame: Frame) -> Result<()> {
81 debug!("Routing frame: stream_id={}, sequence={}, flags={:?}",
82 frame.stream_id, frame.sequence, frame.flags);
83
84 if frame.is_error() {
86 warn!("Received error frame: stream_id={}, payload={:?}",
87 frame.stream_id, frame.payload);
88 return Ok(());
89 }
90
91 if frame.is_end_stream() {
92 debug!("Received end-of-stream frame: stream_id={}", frame.stream_id);
93 self.close_stream(frame.stream_id).await;
94 return Ok(());
95 }
96
97 self.update_stream_info(frame.stream_id, frame.sequence).await;
99
100 let message = match rmp_serde::from_slice::<Message>(&frame.payload) {
102 Ok(msg) => msg,
103 Err(e) => {
104 error!("Failed to deserialize message: {}", e);
105 self.send_error_frame(frame.stream_id, frame.sequence,
106 ErrorCode::InvalidRequest,
107 format!("Invalid message format: {}", e)).await?;
108 return Ok(());
109 }
110 };
111
112 match message {
114 Message::Request(request) => {
115 if let Err(e) = self.request_tx.send((frame.stream_id, frame.sequence, request)) {
117 error!("Failed to send request for processing: {}", e);
118 }
119 }
120 Message::Response(_) => {
121 warn!("Received unexpected response message on agent router");
122 }
123 }
124
125 Ok(())
126 }
127
128 pub async fn start_processing(&mut self) -> Result<()> {
130 let mut request_rx = self.request_rx.take()
131 .context("Request receiver already taken")?;
132
133 let handlers = Arc::clone(&self.handlers);
134 let writer = Arc::clone(&self.writer);
135
136 info!("Starting request processing loop");
137
138 while let Some((stream_id, sequence, request)) = request_rx.recv().await {
139 let handlers = Arc::clone(&handlers);
140 let writer = Arc::clone(&writer);
141
142 tokio::spawn(async move {
144 let response = Self::process_request(request, &handlers).await;
145 let codec = FrameCodec::new(); if let Err(e) = Self::send_response(stream_id, sequence, response, &writer, &codec).await {
148 error!("Failed to send response: {}", e);
149 }
150 });
151 }
152
153 info!("Request processing loop stopped");
154 Ok(())
155 }
156
157 async fn process_request(request: Request, handlers: &Arc<RwLock<HashMap<String, Arc<dyn Handler>>>>) -> Response {
159 let request_id = request.id();
160 debug!("Processing request: id={}, type={:?}", request_id, std::mem::discriminant(&request));
161
162 let request_type = match &request {
164 Request::ProcessExec { .. } => "process_exec",
165 Request::FileGet { .. } => "file_get",
166 Request::FilePut { .. } => "file_put",
167 Request::DirList { .. } => "dir_list",
168 Request::WasmExec { .. } => "wasm_exec",
169 Request::JsonCall { .. } => "json_call",
170 Request::Ping { .. } => "ping",
171 Request::PtyExec { .. } => "pty_exec",
172 };
173
174 let handler = {
176 let handlers_guard = handlers.read().await;
177 handlers_guard.get(request_type).cloned()
178 };
179
180 match handler {
181 Some(handler) => {
182 match handler.handle(request).await {
184 Ok(response) => response,
185 Err(e) => {
186 error!("Handler error for request {}: {}", request_id, e);
187 Response::error(
188 request_id,
189 ErrorDetails::new(ErrorCode::InternalError, format!("Handler error: {}", e))
190 )
191 }
192 }
193 }
194 None => {
195 warn!("No handler registered for request type: {}", request_type);
196 Response::error(
197 request_id,
198 ErrorDetails::new(ErrorCode::Unsupported, format!("Unsupported request type: {}", request_type))
199 )
200 }
201 }
202 }
203
204 async fn send_response(
206 stream_id: u32,
207 sequence: u32,
208 response: Response,
209 writer: &Arc<tokio::sync::Mutex<W>>,
210 codec: &FrameCodec
211 ) -> Result<()> {
212 let message = Message::response(response);
213 let payload = rmp_serde::to_vec(&message)
214 .context("Failed to serialize response message")?;
215
216 let frame = Frame::data(stream_id, sequence, Bytes::from(payload));
217
218 let mut writer_guard = writer.lock().await;
219 codec.write_frame(&mut *writer_guard, &frame).await
220 .context("Failed to write response frame")?;
221
222 debug!("Sent response: stream_id={}, sequence={}", stream_id, sequence);
223 Ok(())
224 }
225
226 async fn send_error_frame(&self, stream_id: u32, sequence: u32,
228 error_code: ErrorCode, message: String) -> Result<()> {
229 let error_payload = rmp_serde::to_vec(&ErrorDetails::new(error_code, message))
230 .context("Failed to serialize error details")?;
231
232 let frame = Frame::error(stream_id, sequence, Bytes::from(error_payload));
233
234 let mut writer = self.writer.lock().await;
235 self.codec.write_frame(&mut *writer, &frame).await
236 .context("Failed to write error frame")?;
237
238 debug!("Sent error frame: stream_id={}, sequence={}", stream_id, sequence);
239 Ok(())
240 }
241
242 async fn update_stream_info(&self, stream_id: u32, sequence: u32) {
244 let mut streams = self.streams.write().await;
245 streams.insert(stream_id, StreamInfo {
246 stream_id,
247 sequence,
248 request_id: None,
249 });
250 }
251
252 async fn close_stream(&self, stream_id: u32) {
254 let mut streams = self.streams.write().await;
255 if streams.remove(&stream_id).is_some() {
256 debug!("Closed stream: {}", stream_id);
257 }
258 }
259
260 pub async fn active_stream_count(&self) -> usize {
262 let streams = self.streams.read().await;
263 streams.len()
264 }
265
266 pub async fn active_streams(&self) -> Vec<u32> {
268 let streams = self.streams.read().await;
269 streams.keys().copied().collect()
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handlers::PingHandler;
    use mitoxide_proto::{Request, Response};
    use std::collections::HashMap;
    use std::io::Cursor;

    /// Router type used throughout: writes into an in-memory buffer.
    type TestRouter = AgentRouter<Cursor<Vec<u8>>>;

    /// Creates a router over an empty in-memory sink.
    fn test_router() -> TestRouter {
        AgentRouter::new(Cursor::new(Vec::new()))
    }

    /// Encodes `request` as a data frame on `stream_id`/`sequence`.
    fn request_frame(stream_id: u32, sequence: u32, request: Request) -> Frame {
        let payload = rmp_serde::to_vec(&Message::request(request)).unwrap();
        Frame::data(stream_id, sequence, Bytes::from(payload))
    }

    /// Returns a copy of every byte the router has written so far.
    async fn written_bytes(router: &TestRouter) -> Vec<u8> {
        router.writer.lock().await.get_ref().clone()
    }

    #[tokio::test]
    async fn test_router_creation() {
        let router = test_router();

        assert_eq!(router.active_stream_count().await, 0);
        assert!(router.active_streams().await.is_empty());
    }

    #[tokio::test]
    async fn test_handler_registration() {
        let router = test_router();

        let handler = Arc::new(PingHandler);
        router.register_handler("ping".to_string(), handler).await;

        let handlers = router.handlers.read().await;
        assert!(handlers.contains_key("ping"));
    }

    #[tokio::test]
    async fn test_stream_management() {
        let router = test_router();

        router.update_stream_info(1, 42).await;
        assert_eq!(router.active_stream_count().await, 1);
        assert_eq!(router.active_streams().await, vec![1]);

        router.close_stream(1).await;
        assert_eq!(router.active_stream_count().await, 0);
        assert!(router.active_streams().await.is_empty());
    }

    #[tokio::test]
    async fn test_ping_request_routing() {
        let mut router = test_router();

        let handler = Arc::new(PingHandler);
        router.register_handler("ping".to_string(), handler).await;

        let result = router.route_frame(request_frame(1, 1, Request::ping())).await;
        assert!(result.is_ok());
        assert_eq!(router.active_stream_count().await, 1);

        // The request must be queued for the processing loop, tagged with the
        // stream and sequence it arrived on.
        let (stream_id, sequence, queued) = router
            .request_rx
            .as_mut()
            .expect("receiver is held until processing starts")
            .try_recv()
            .expect("ping request should be queued");
        assert_eq!((stream_id, sequence), (1, 1));
        assert!(matches!(queued, Request::Ping { .. }));
    }

    #[tokio::test]
    async fn test_invalid_message_routing() {
        let mut router = test_router();

        let frame = Frame::data(1, 1, Bytes::from(vec![0xFF, 0xFF, 0xFF, 0xFF]));
        let result = router.route_frame(frame).await;
        assert!(result.is_ok());

        // Undecodable payloads are answered with an error frame and never queued.
        assert!(!written_bytes(&router).await.is_empty());
        assert!(router.request_rx.as_mut().unwrap().try_recv().is_err());
    }

    #[tokio::test]
    async fn test_error_frame_routing() {
        let router = test_router();

        let error_frame = Frame::error(1, 1, Bytes::from("test error"));
        let result = router.route_frame(error_frame).await;
        assert!(result.is_ok());

        // Error frames are only logged: no stream is tracked, nothing is written.
        assert_eq!(router.active_stream_count().await, 0);
        assert!(written_bytes(&router).await.is_empty());
    }

    #[tokio::test]
    async fn test_end_stream_frame_routing() {
        let router = test_router();

        router.update_stream_info(1, 1).await;
        assert_eq!(router.active_stream_count().await, 1);

        let end_frame = Frame::end_stream(1, 2);
        let result = router.route_frame(end_frame).await;
        assert!(result.is_ok());

        assert_eq!(router.active_stream_count().await, 0);
        assert!(written_bytes(&router).await.is_empty());
    }

    #[tokio::test]
    async fn test_process_request_with_handler() {
        let handlers: Arc<RwLock<HashMap<String, Arc<dyn Handler>>>> =
            Arc::new(RwLock::new(HashMap::new()));

        let handler: Arc<dyn Handler> = Arc::new(PingHandler);
        handlers.write().await.insert("ping".to_string(), handler);

        let request = Request::ping();
        let request_id = request.id();
        let response = TestRouter::process_request(request, &handlers).await;

        match response {
            Response::Pong { request_id: resp_id, .. } => {
                assert_eq!(resp_id, request_id);
            }
            _ => panic!("Expected Pong response"),
        }
    }

    #[tokio::test]
    async fn test_process_request_without_handler() {
        let handlers: Arc<RwLock<HashMap<String, Arc<dyn Handler>>>> =
            Arc::new(RwLock::new(HashMap::new()));

        let request = Request::ping();
        let request_id = request.id();
        let response = TestRouter::process_request(request, &handlers).await;

        match response {
            Response::Error { request_id: resp_id, error } => {
                assert_eq!(resp_id, request_id);
                assert_eq!(error.code, ErrorCode::Unsupported);
            }
            _ => panic!("Expected Error response"),
        }
    }

    #[tokio::test]
    async fn test_concurrent_request_processing() {
        let mut router = test_router();

        let ping_handler: Arc<dyn Handler> = Arc::new(PingHandler);
        router.register_handler("ping".to_string(), ping_handler).await;

        for stream_id in 1..=5 {
            let result = router
                .route_frame(request_frame(stream_id, 1, Request::ping()))
                .await;
            assert!(result.is_ok());
        }

        assert_eq!(router.active_stream_count().await, 5);

        // Every request is queued in arrival order, each on its own stream.
        let rx = router.request_rx.as_mut().unwrap();
        for expected in 1..=5 {
            let (stream_id, _, _) = rx.try_recv().expect("request should be queued");
            assert_eq!(stream_id, expected);
        }
    }
}