mitoxide_agent/
router.rs

1//! Agent-side routing for multiplexed streams
2
3use crate::agent::Handler;
4use anyhow::{Context, Result};
5use bytes::Bytes;
6use mitoxide_proto::{Frame, FrameCodec, Message, Request, Response};
7use mitoxide_proto::message::{ErrorCode, ErrorDetails};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::io::AsyncWrite;
11use tokio::sync::{mpsc, oneshot, RwLock};
12use tracing::{debug, error, info, warn};
13use uuid::Uuid;
14
15/// Stream information for tracking active streams
16#[derive(Debug, Clone)]
17struct StreamInfo {
18    /// Stream ID
19    stream_id: u32,
20    /// Current sequence number
21    sequence: u32,
22    /// Request ID being processed
23    request_id: Option<Uuid>,
24}
25
26/// Agent-side router for handling multiplexed streams and request dispatch
27pub struct AgentRouter<W>
28where
29    W: AsyncWrite + Unpin + Send,
30{
31    /// Output writer for sending responses
32    writer: Arc<tokio::sync::Mutex<W>>,
33    /// Frame codec for encoding responses
34    codec: FrameCodec,
35    /// Active streams
36    streams: Arc<RwLock<HashMap<u32, StreamInfo>>>,
37    /// Registered handlers by request type
38    handlers: Arc<RwLock<HashMap<String, Arc<dyn Handler>>>>,
39    /// Channel for sending requests to be processed
40    request_tx: mpsc::UnboundedSender<(u32, u32, Request)>,
41    /// Channel for receiving requests to process
42    request_rx: Option<mpsc::UnboundedReceiver<(u32, u32, Request)>>,
43    /// Shutdown signal
44    shutdown_tx: Option<oneshot::Sender<()>>,
45}
46
47impl<W> AgentRouter<W>
48where
49    W: AsyncWrite + Unpin + Send + 'static,
50{
51    /// Create a new agent router
52    pub fn new(writer: W) -> Self {
53        let (request_tx, request_rx) = mpsc::unbounded_channel();
54        let (shutdown_tx, _) = oneshot::channel();
55        
56        Self {
57            writer: Arc::new(tokio::sync::Mutex::new(writer)),
58            codec: FrameCodec::new(),
59            streams: Arc::new(RwLock::new(HashMap::new())),
60            handlers: Arc::new(RwLock::new(HashMap::new())),
61            request_tx,
62            request_rx: Some(request_rx),
63            shutdown_tx: Some(shutdown_tx),
64        }
65    }
66    
67    /// Register a handler for a specific request type
68    pub async fn register_handler(&self, request_type: String, handler: Arc<dyn Handler>) {
69        let mut handlers = self.handlers.write().await;
70        debug!("Registered handler for request type: {}", request_type);
71        handlers.insert(request_type, handler);
72    }
73    
74    /// Get shutdown sender for graceful shutdown
75    pub fn shutdown_sender(&mut self) -> Option<oneshot::Sender<()>> {
76        self.shutdown_tx.take()
77    }
78    
79    /// Route an incoming frame to the appropriate handler
80    pub async fn route_frame(&self, frame: Frame) -> Result<()> {
81        debug!("Routing frame: stream_id={}, sequence={}, flags={:?}", 
82               frame.stream_id, frame.sequence, frame.flags);
83        
84        // Handle control frames
85        if frame.is_error() {
86            warn!("Received error frame: stream_id={}, payload={:?}", 
87                  frame.stream_id, frame.payload);
88            return Ok(());
89        }
90        
91        if frame.is_end_stream() {
92            debug!("Received end-of-stream frame: stream_id={}", frame.stream_id);
93            self.close_stream(frame.stream_id).await;
94            return Ok(());
95        }
96        
97        // Update stream info
98        self.update_stream_info(frame.stream_id, frame.sequence).await;
99        
100        // Deserialize message from frame payload
101        let message = match rmp_serde::from_slice::<Message>(&frame.payload) {
102            Ok(msg) => msg,
103            Err(e) => {
104                error!("Failed to deserialize message: {}", e);
105                self.send_error_frame(frame.stream_id, frame.sequence, 
106                                    ErrorCode::InvalidRequest, 
107                                    format!("Invalid message format: {}", e)).await?;
108                return Ok(());
109            }
110        };
111        
112        // Route message
113        match message {
114            Message::Request(request) => {
115                // Send request for processing
116                if let Err(e) = self.request_tx.send((frame.stream_id, frame.sequence, request)) {
117                    error!("Failed to send request for processing: {}", e);
118                }
119            }
120            Message::Response(_) => {
121                warn!("Received unexpected response message on agent router");
122            }
123        }
124        
125        Ok(())
126    }
127    
128    /// Start the request processing loop
129    pub async fn start_processing(&mut self) -> Result<()> {
130        let mut request_rx = self.request_rx.take()
131            .context("Request receiver already taken")?;
132        
133        let handlers = Arc::clone(&self.handlers);
134        let writer = Arc::clone(&self.writer);
135        
136        info!("Starting request processing loop");
137        
138        while let Some((stream_id, sequence, request)) = request_rx.recv().await {
139            let handlers = Arc::clone(&handlers);
140            let writer = Arc::clone(&writer);
141            
142            // Process request in a separate task
143            tokio::spawn(async move {
144                let response = Self::process_request(request, &handlers).await;
145                let codec = FrameCodec::new(); // Create new codec instance
146                
147                if let Err(e) = Self::send_response(stream_id, sequence, response, &writer, &codec).await {
148                    error!("Failed to send response: {}", e);
149                }
150            });
151        }
152        
153        info!("Request processing loop stopped");
154        Ok(())
155    }
156    
157    /// Process a single request using registered handlers
158    async fn process_request(request: Request, handlers: &Arc<RwLock<HashMap<String, Arc<dyn Handler>>>>) -> Response {
159        let request_id = request.id();
160        debug!("Processing request: id={}, type={:?}", request_id, std::mem::discriminant(&request));
161        
162        // Determine request type for handler lookup
163        let request_type = match &request {
164            Request::ProcessExec { .. } => "process_exec",
165            Request::FileGet { .. } => "file_get",
166            Request::FilePut { .. } => "file_put",
167            Request::DirList { .. } => "dir_list",
168            Request::WasmExec { .. } => "wasm_exec",
169            Request::JsonCall { .. } => "json_call",
170            Request::Ping { .. } => "ping",
171            Request::PtyExec { .. } => "pty_exec",
172        };
173        
174        // Look up handler
175        let handler = {
176            let handlers_guard = handlers.read().await;
177            handlers_guard.get(request_type).cloned()
178        };
179        
180        match handler {
181            Some(handler) => {
182                // Execute handler
183                match handler.handle(request).await {
184                    Ok(response) => response,
185                    Err(e) => {
186                        error!("Handler error for request {}: {}", request_id, e);
187                        Response::error(
188                            request_id,
189                            ErrorDetails::new(ErrorCode::InternalError, format!("Handler error: {}", e))
190                        )
191                    }
192                }
193            }
194            None => {
195                warn!("No handler registered for request type: {}", request_type);
196                Response::error(
197                    request_id,
198                    ErrorDetails::new(ErrorCode::Unsupported, format!("Unsupported request type: {}", request_type))
199                )
200            }
201        }
202    }
203    
204    /// Send a response back to the client
205    async fn send_response(
206        stream_id: u32, 
207        sequence: u32, 
208        response: Response,
209        writer: &Arc<tokio::sync::Mutex<W>>,
210        codec: &FrameCodec
211    ) -> Result<()> {
212        let message = Message::response(response);
213        let payload = rmp_serde::to_vec(&message)
214            .context("Failed to serialize response message")?;
215        
216        let frame = Frame::data(stream_id, sequence, Bytes::from(payload));
217        
218        let mut writer_guard = writer.lock().await;
219        codec.write_frame(&mut *writer_guard, &frame).await
220            .context("Failed to write response frame")?;
221        
222        debug!("Sent response: stream_id={}, sequence={}", stream_id, sequence);
223        Ok(())
224    }
225    
226    /// Send an error frame
227    async fn send_error_frame(&self, stream_id: u32, sequence: u32, 
228                            error_code: ErrorCode, message: String) -> Result<()> {
229        let error_payload = rmp_serde::to_vec(&ErrorDetails::new(error_code, message))
230            .context("Failed to serialize error details")?;
231        
232        let frame = Frame::error(stream_id, sequence, Bytes::from(error_payload));
233        
234        let mut writer = self.writer.lock().await;
235        self.codec.write_frame(&mut *writer, &frame).await
236            .context("Failed to write error frame")?;
237        
238        debug!("Sent error frame: stream_id={}, sequence={}", stream_id, sequence);
239        Ok(())
240    }
241    
242    /// Update stream information
243    async fn update_stream_info(&self, stream_id: u32, sequence: u32) {
244        let mut streams = self.streams.write().await;
245        streams.insert(stream_id, StreamInfo {
246            stream_id,
247            sequence,
248            request_id: None,
249        });
250    }
251    
252    /// Close a stream
253    async fn close_stream(&self, stream_id: u32) {
254        let mut streams = self.streams.write().await;
255        if streams.remove(&stream_id).is_some() {
256            debug!("Closed stream: {}", stream_id);
257        }
258    }
259    
260    /// Get active stream count
261    pub async fn active_stream_count(&self) -> usize {
262        let streams = self.streams.read().await;
263        streams.len()
264    }
265    
266    /// Get list of active stream IDs
267    pub async fn active_streams(&self) -> Vec<u32> {
268        let streams = self.streams.read().await;
269        streams.keys().copied().collect()
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handlers::PingHandler;
    use mitoxide_proto::{Request, Response};
    use std::collections::HashMap;
    use std::io::Cursor;

    /// Handler table shape expected by `AgentRouter::process_request`.
    type HandlerMap = Arc<RwLock<HashMap<String, Arc<dyn Handler>>>>;

    /// Router whose output goes to an in-memory buffer.
    fn memory_router() -> AgentRouter<Cursor<Vec<u8>>> {
        AgentRouter::new(Cursor::new(Vec::new()))
    }

    /// Data frame (sequence 1) carrying a serialized ping request on `stream_id`.
    fn ping_frame(stream_id: u32) -> Frame {
        let encoded = rmp_serde::to_vec(&Message::request(Request::ping())).unwrap();
        Frame::data(stream_id, 1, Bytes::from(encoded))
    }

    #[tokio::test]
    async fn test_router_creation() {
        let router = memory_router();

        assert_eq!(router.active_stream_count().await, 0);
        assert!(router.active_streams().await.is_empty());
    }

    #[tokio::test]
    async fn test_handler_registration() {
        let router = memory_router();
        router
            .register_handler("ping".to_string(), Arc::new(PingHandler))
            .await;

        assert!(router.handlers.read().await.contains_key("ping"));
    }

    #[tokio::test]
    async fn test_stream_management() {
        let router = memory_router();

        router.update_stream_info(1, 42).await;
        assert_eq!(router.active_stream_count().await, 1);
        assert_eq!(router.active_streams().await, vec![1]);

        router.close_stream(1).await;
        assert_eq!(router.active_stream_count().await, 0);
        assert!(router.active_streams().await.is_empty());
    }

    #[tokio::test]
    async fn test_ping_request_routing() {
        let router = memory_router();
        router
            .register_handler("ping".to_string(), Arc::new(PingHandler))
            .await;

        // Routing only queues the request; the stream must now be tracked.
        assert!(router.route_frame(ping_frame(1)).await.is_ok());
        assert_eq!(router.active_stream_count().await, 1);
    }

    #[tokio::test]
    async fn test_invalid_message_routing() {
        let router = memory_router();
        let garbage = Frame::data(1, 1, Bytes::from(vec![0xFF; 4]));

        // An undecodable payload is answered on the wire, not surfaced as Err.
        assert!(router.route_frame(garbage).await.is_ok());
    }

    #[tokio::test]
    async fn test_error_frame_routing() {
        let router = memory_router();

        let outcome = router
            .route_frame(Frame::error(1, 1, Bytes::from("test error")))
            .await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_end_stream_frame_routing() {
        let router = memory_router();

        router.update_stream_info(1, 1).await;
        assert_eq!(router.active_stream_count().await, 1);

        assert!(router.route_frame(Frame::end_stream(1, 2)).await.is_ok());
        assert_eq!(router.active_stream_count().await, 0);
    }

    #[tokio::test]
    async fn test_process_request_with_handler() {
        let handlers: HandlerMap = Arc::new(RwLock::new(HashMap::new()));
        let ping: Arc<dyn Handler> = Arc::new(PingHandler);
        handlers.write().await.insert("ping".to_string(), ping);

        let request = Request::ping();
        let expected_id = request.id();

        match AgentRouter::<Cursor<Vec<u8>>>::process_request(request, &handlers).await {
            Response::Pong { request_id, .. } => assert_eq!(request_id, expected_id),
            _ => panic!("Expected Pong response"),
        }
    }

    #[tokio::test]
    async fn test_process_request_without_handler() {
        let handlers: HandlerMap = Arc::new(RwLock::new(HashMap::new()));

        let request = Request::ping();
        let expected_id = request.id();

        match AgentRouter::<Cursor<Vec<u8>>>::process_request(request, &handlers).await {
            Response::Error { request_id, error } => {
                assert_eq!(request_id, expected_id);
                assert_eq!(error.code, ErrorCode::Unsupported);
            }
            _ => panic!("Expected Error response"),
        }
    }

    #[tokio::test]
    async fn test_concurrent_request_processing() {
        let router = memory_router();
        let ping_handler: Arc<dyn Handler> = Arc::new(PingHandler);
        router.register_handler("ping".to_string(), ping_handler).await;

        // Frames are routed one after another, each on a distinct stream.
        for stream_id in 1..=5 {
            assert!(router.route_frame(ping_frame(stream_id)).await.is_ok());
        }

        assert_eq!(router.active_stream_count().await, 5);
    }
}