mitoxide_agent/
agent.rs

1//! Agent main loop and frame processing
2
3use anyhow::{Context, Result};
4use bytes::Bytes;
5use mitoxide_proto::{Frame, FrameCodec, Message, Request, Response};
6use mitoxide_proto::message::{ErrorCode, ErrorDetails};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::io::{stdin, stdout, AsyncRead, AsyncWrite};
10use tokio::sync::{oneshot, RwLock};
11use tracing::{debug, error, info, warn};
12
13/// Handler trait for processing requests
14#[async_trait::async_trait]
15pub trait Handler: Send + Sync {
16    /// Handle a request and return a response
17    async fn handle(&self, request: Request) -> Result<Response>;
18}
19
20/// Main agent loop for processing frames
21pub struct AgentLoop<R, W> 
22where
23    R: AsyncRead + Unpin + Send,
24    W: AsyncWrite + Unpin + Send,
25{
26    /// Input stream (typically stdin)
27    reader: R,
28    /// Output stream (typically stdout)
29    writer: W,
30    /// Frame codec for encoding/decoding
31    codec: FrameCodec,
32    /// Registered handlers by request type
33    handlers: Arc<RwLock<HashMap<String, Arc<dyn Handler>>>>,
34    /// Shutdown signal receiver
35    shutdown_rx: Option<oneshot::Receiver<()>>,
36    /// Shutdown signal sender (kept for graceful shutdown)
37    shutdown_tx: Option<oneshot::Sender<()>>,
38}
39
40impl AgentLoop<tokio::io::Stdin, tokio::io::Stdout> {
41    /// Create a new agent loop with stdin/stdout
42    pub fn new() -> Self {
43        let (shutdown_tx, shutdown_rx) = oneshot::channel();
44        Self {
45            reader: stdin(),
46            writer: stdout(),
47            codec: FrameCodec::new(),
48            handlers: Arc::new(RwLock::new(HashMap::new())),
49            shutdown_rx: Some(shutdown_rx),
50            shutdown_tx: Some(shutdown_tx),
51        }
52    }
53}
54
55impl<R, W> AgentLoop<R, W>
56where
57    R: AsyncRead + Unpin + Send,
58    W: AsyncWrite + Unpin + Send,
59{
60    /// Create a new agent loop with custom reader/writer
61    pub fn with_io(reader: R, writer: W) -> Self {
62        let (shutdown_tx, shutdown_rx) = oneshot::channel();
63        Self {
64            reader,
65            writer,
66            codec: FrameCodec::new(),
67            handlers: Arc::new(RwLock::new(HashMap::new())),
68            shutdown_rx: Some(shutdown_rx),
69            shutdown_tx: Some(shutdown_tx),
70        }
71    }
72    
73    /// Register a handler for a specific request type
74    pub async fn register_handler(&self, request_type: String, handler: Arc<dyn Handler>) {
75        let mut handlers = self.handlers.write().await;
76        debug!("Registered handler for request type: {}", request_type);
77        handlers.insert(request_type, handler);
78    }
79    
80    /// Get shutdown sender for graceful shutdown
81    pub fn shutdown_sender(&mut self) -> Option<oneshot::Sender<()>> {
82        self.shutdown_tx.take()
83    }
84    
85    /// Run the agent loop
86    pub async fn run(&mut self) -> Result<()> {
87        info!("Starting agent loop");
88        
89        let mut shutdown_rx = self.shutdown_rx.take()
90            .context("Shutdown receiver already taken")?;
91        
92        loop {
93            tokio::select! {
94                // Handle shutdown signal
95                _ = &mut shutdown_rx => {
96                    info!("Received shutdown signal, stopping agent loop");
97                    break;
98                }
99                
100                // Process incoming frames
101                frame_result = self.codec.read_frame(&mut self.reader) => {
102                    match frame_result {
103                        Ok(Some(frame)) => {
104                            if let Err(e) = self.process_frame(frame).await {
105                                error!("Error processing frame: {}", e);
106                                // Continue processing other frames on error
107                            }
108                        }
109                        Ok(None) => {
110                            info!("Input stream closed, stopping agent loop");
111                            break;
112                        }
113                        Err(e) => {
114                            error!("Error reading frame: {}", e);
115                            // Try to continue on protocol errors
116                            continue;
117                        }
118                    }
119                }
120            }
121        }
122        
123        info!("Agent loop stopped");
124        Ok(())
125    }
126    
127    /// Process a single frame
128    async fn process_frame(&mut self, frame: Frame) -> Result<()> {
129        debug!("Processing frame: stream_id={}, sequence={}, flags={:?}, payload_size={}", 
130               frame.stream_id, frame.sequence, frame.flags, frame.payload.len());
131        
132        // Handle control frames
133        if frame.is_error() {
134            warn!("Received error frame: stream_id={}, payload={:?}", 
135                  frame.stream_id, frame.payload);
136            return Ok(());
137        }
138        
139        if frame.is_end_stream() {
140            debug!("Received end-of-stream frame: stream_id={}", frame.stream_id);
141            return Ok(());
142        }
143        
144        // Deserialize message from frame payload
145        let message = match rmp_serde::from_slice::<Message>(&frame.payload) {
146            Ok(msg) => msg,
147            Err(e) => {
148                error!("Failed to deserialize message: {}", e);
149                self.send_error_frame(frame.stream_id, frame.sequence, 
150                                    ErrorCode::InvalidRequest, 
151                                    format!("Invalid message format: {}", e)).await?;
152                return Ok(());
153            }
154        };
155        
156        // Dispatch message
157        match message {
158            Message::Request(request) => {
159                self.handle_request(frame.stream_id, frame.sequence, request).await?;
160            }
161            Message::Response(_) => {
162                warn!("Received unexpected response message on agent");
163                // Agents typically don't handle responses, only requests
164            }
165        }
166        
167        Ok(())
168    }
169    
170    /// Handle a request message
171    async fn handle_request(&mut self, stream_id: u32, sequence: u32, request: Request) -> Result<()> {
172        let request_id = request.id();
173        debug!("Handling request: id={}, type={:?}", request_id, std::mem::discriminant(&request));
174        
175        // Determine request type for handler lookup
176        let request_type = match &request {
177            Request::ProcessExec { .. } => "process_exec",
178            Request::FileGet { .. } => "file_get",
179            Request::FilePut { .. } => "file_put",
180            Request::DirList { .. } => "dir_list",
181            Request::WasmExec { .. } => "wasm_exec",
182            Request::JsonCall { .. } => "json_call",
183            Request::Ping { .. } => "ping",
184            Request::PtyExec { .. } => "pty_exec",
185        };
186        
187        // Look up handler
188        let handler = {
189            let handlers = self.handlers.read().await;
190            handlers.get(request_type).cloned()
191        };
192        
193        let response = match handler {
194            Some(handler) => {
195                // Execute handler
196                match handler.handle(request).await {
197                    Ok(response) => response,
198                    Err(e) => {
199                        error!("Handler error for request {}: {}", request_id, e);
200                        Response::error(
201                            request_id,
202                            ErrorDetails::new(ErrorCode::InternalError, format!("Handler error: {}", e))
203                        )
204                    }
205                }
206            }
207            None => {
208                warn!("No handler registered for request type: {}", request_type);
209                Response::error(
210                    request_id,
211                    ErrorDetails::new(ErrorCode::Unsupported, format!("Unsupported request type: {}", request_type))
212                )
213            }
214        };
215        
216        // Send response
217        self.send_response(stream_id, sequence, response).await?;
218        
219        Ok(())
220    }
221    
222    /// Send a response message
223    async fn send_response(&mut self, stream_id: u32, sequence: u32, response: Response) -> Result<()> {
224        let message = Message::response(response);
225        let payload = rmp_serde::to_vec(&message)
226            .context("Failed to serialize response message")?;
227        
228        let frame = Frame::data(stream_id, sequence, Bytes::from(payload));
229        self.codec.write_frame(&mut self.writer, &frame).await
230            .context("Failed to write response frame")?;
231        
232        debug!("Sent response: stream_id={}, sequence={}", stream_id, sequence);
233        Ok(())
234    }
235    
236    /// Send an error frame
237    async fn send_error_frame(&mut self, stream_id: u32, sequence: u32, 
238                            error_code: ErrorCode, message: String) -> Result<()> {
239        let error_payload = rmp_serde::to_vec(&ErrorDetails::new(error_code, message))
240            .context("Failed to serialize error details")?;
241        
242        let frame = Frame::error(stream_id, sequence, Bytes::from(error_payload));
243        self.codec.write_frame(&mut self.writer, &frame).await
244            .context("Failed to write error frame")?;
245        
246        debug!("Sent error frame: stream_id={}, sequence={}", stream_id, sequence);
247        Ok(())
248    }
249}
250
251impl<R, W> Default for AgentLoop<R, W>
252where
253    R: AsyncRead + Unpin + Send + Default,
254    W: AsyncWrite + Unpin + Send + Default,
255{
256    fn default() -> Self {
257        Self::with_io(R::default(), W::default())
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use mitoxide_proto::{Request, Response};
    use std::io::Cursor;
    use tokio::time::{timeout, Duration};
    use uuid::Uuid;

    /// Agent wired to in-memory buffers so tests can inspect what it writes.
    type MemAgent = AgentLoop<Cursor<Vec<u8>>, Cursor<Vec<u8>>>;

    /// Mock handler for testing
    struct MockHandler {
        response: Response,
    }

    #[async_trait::async_trait]
    impl Handler for MockHandler {
        async fn handle(&self, request: Request) -> Result<Response> {
            // Echo back the request ID in a pong response
            match request {
                Request::Ping { id, timestamp } => Ok(Response::pong(id, timestamp)),
                _ => Ok(self.response.clone()),
            }
        }
    }

    /// Build an agent with empty in-memory input and output buffers.
    fn mem_agent() -> MemAgent {
        AgentLoop::with_io(Cursor::new(Vec::new()), Cursor::new(Vec::new()))
    }

    /// Decode the first frame the agent wrote to its output buffer.
    async fn first_written_frame(agent: &MemAgent) -> Frame {
        let mut output = Cursor::new(agent.writer.get_ref().clone());
        FrameCodec::new()
            .read_frame(&mut output)
            .await
            .expect("agent output should decode as a frame")
            .expect("agent should have written a frame")
    }

    #[tokio::test]
    async fn test_agent_loop_creation() {
        let agent = AgentLoop::new();
        assert!(agent.shutdown_tx.is_some());
        assert!(agent.shutdown_rx.is_some());
    }

    #[tokio::test]
    async fn test_handler_registration() {
        let agent = AgentLoop::new();
        let handler = Arc::new(MockHandler {
            response: Response::pong(Uuid::new_v4(), 12345),
        });

        agent.register_handler("test".to_string(), handler).await;

        let handlers = agent.handlers.read().await;
        assert!(handlers.contains_key("test"));
    }

    #[tokio::test]
    async fn test_graceful_shutdown() {
        // The peer half stays alive for the whole test, so the input never
        // reaches EOF: the loop can only stop because of the shutdown signal.
        let (input, _peer) = tokio::io::duplex(64);
        let output = Cursor::new(Vec::<u8>::new());
        let mut agent = AgentLoop::with_io(input, output);

        let shutdown_tx = agent.shutdown_sender().unwrap();

        let agent_task = tokio::spawn(async move { agent.run().await });

        shutdown_tx.send(()).unwrap();

        // Agent should stop gracefully
        let result = timeout(Duration::from_secs(1), agent_task).await;
        assert!(result.is_ok());
        assert!(result.unwrap().unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_ping_request_handling() {
        let request = Request::ping();
        let request_id = request.id();
        let payload = rmp_serde::to_vec(&Message::request(request)).unwrap();

        let mut agent = mem_agent();
        let handler = Arc::new(MockHandler {
            response: Response::pong(request_id, 12345),
        });
        agent.register_handler("ping".to_string(), handler).await;

        let result = agent
            .process_frame(Frame::data(1, 1, Bytes::from(payload)))
            .await;
        assert!(result.is_ok());

        // The reply must be a response message on the request's stream.
        let reply = first_written_frame(&agent).await;
        assert_eq!(reply.stream_id, 1);
        assert!(!reply.is_error());
        let message: Message = rmp_serde::from_slice(&reply.payload).unwrap();
        assert!(matches!(message, Message::Response(_)));
    }

    #[tokio::test]
    async fn test_invalid_message_handling() {
        // Create frame with invalid payload
        let frame = Frame::data(1, 1, Bytes::from(vec![0xFF, 0xFF, 0xFF, 0xFF]));
        let mut agent = mem_agent();

        // Should handle invalid message gracefully...
        let result = agent.process_frame(frame).await;
        assert!(result.is_ok());

        // ...and tell the peer its message was rejected.
        let reply = first_written_frame(&agent).await;
        assert_eq!(reply.stream_id, 1);
        assert!(reply.is_error());
    }

    #[tokio::test]
    async fn test_unsupported_request_handling() {
        // Create a request without a registered handler
        let request = Request::process_exec(
            vec!["echo".to_string()],
            std::collections::HashMap::new(),
            None,
            None,
            None,
        );
        let payload = rmp_serde::to_vec(&Message::request(request)).unwrap();
        let frame = Frame::data(1, 1, Bytes::from(payload));
        let mut agent = mem_agent();

        // Should handle unsupported request gracefully...
        let result = agent.process_frame(frame).await;
        assert!(result.is_ok());

        // ...by answering with an (error) response message, not an error frame.
        let reply = first_written_frame(&agent).await;
        assert!(!reply.is_error());
        let message: Message = rmp_serde::from_slice(&reply.payload).unwrap();
        assert!(matches!(message, Message::Response(_)));
    }

    #[tokio::test]
    async fn test_error_frame_handling() {
        let error_frame = Frame::error(1, 1, Bytes::from("test error"));
        let mut agent = mem_agent();

        // Should handle error frames gracefully, without replying
        let result = agent.process_frame(error_frame).await;
        assert!(result.is_ok());
        assert!(agent.writer.get_ref().is_empty());
    }

    #[tokio::test]
    async fn test_end_stream_frame_handling() {
        let end_frame = Frame::end_stream(1, 1);
        let mut agent = mem_agent();

        // Should handle end-of-stream frames gracefully, without replying
        let result = agent.process_frame(end_frame).await;
        assert!(result.is_ok());
        assert!(agent.writer.get_ref().is_empty());
    }
}