1use anyhow::{Context, Result};
4use bytes::Bytes;
5use mitoxide_proto::{Frame, FrameCodec, Message, Request, Response};
6use mitoxide_proto::message::{ErrorCode, ErrorDetails};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::io::{stdin, stdout, AsyncRead, AsyncWrite};
10use tokio::sync::{oneshot, RwLock};
11use tracing::{debug, error, info, warn};
12
13#[async_trait::async_trait]
15pub trait Handler: Send + Sync {
16 async fn handle(&self, request: Request) -> Result<Response>;
18}
19
20pub struct AgentLoop<R, W>
22where
23 R: AsyncRead + Unpin + Send,
24 W: AsyncWrite + Unpin + Send,
25{
26 reader: R,
28 writer: W,
30 codec: FrameCodec,
32 handlers: Arc<RwLock<HashMap<String, Arc<dyn Handler>>>>,
34 shutdown_rx: Option<oneshot::Receiver<()>>,
36 shutdown_tx: Option<oneshot::Sender<()>>,
38}
39
40impl AgentLoop<tokio::io::Stdin, tokio::io::Stdout> {
41 pub fn new() -> Self {
43 let (shutdown_tx, shutdown_rx) = oneshot::channel();
44 Self {
45 reader: stdin(),
46 writer: stdout(),
47 codec: FrameCodec::new(),
48 handlers: Arc::new(RwLock::new(HashMap::new())),
49 shutdown_rx: Some(shutdown_rx),
50 shutdown_tx: Some(shutdown_tx),
51 }
52 }
53}
54
55impl<R, W> AgentLoop<R, W>
56where
57 R: AsyncRead + Unpin + Send,
58 W: AsyncWrite + Unpin + Send,
59{
60 pub fn with_io(reader: R, writer: W) -> Self {
62 let (shutdown_tx, shutdown_rx) = oneshot::channel();
63 Self {
64 reader,
65 writer,
66 codec: FrameCodec::new(),
67 handlers: Arc::new(RwLock::new(HashMap::new())),
68 shutdown_rx: Some(shutdown_rx),
69 shutdown_tx: Some(shutdown_tx),
70 }
71 }
72
73 pub async fn register_handler(&self, request_type: String, handler: Arc<dyn Handler>) {
75 let mut handlers = self.handlers.write().await;
76 debug!("Registered handler for request type: {}", request_type);
77 handlers.insert(request_type, handler);
78 }
79
80 pub fn shutdown_sender(&mut self) -> Option<oneshot::Sender<()>> {
82 self.shutdown_tx.take()
83 }
84
85 pub async fn run(&mut self) -> Result<()> {
87 info!("Starting agent loop");
88
89 let mut shutdown_rx = self.shutdown_rx.take()
90 .context("Shutdown receiver already taken")?;
91
92 loop {
93 tokio::select! {
94 _ = &mut shutdown_rx => {
96 info!("Received shutdown signal, stopping agent loop");
97 break;
98 }
99
100 frame_result = self.codec.read_frame(&mut self.reader) => {
102 match frame_result {
103 Ok(Some(frame)) => {
104 if let Err(e) = self.process_frame(frame).await {
105 error!("Error processing frame: {}", e);
106 }
108 }
109 Ok(None) => {
110 info!("Input stream closed, stopping agent loop");
111 break;
112 }
113 Err(e) => {
114 error!("Error reading frame: {}", e);
115 continue;
117 }
118 }
119 }
120 }
121 }
122
123 info!("Agent loop stopped");
124 Ok(())
125 }
126
127 async fn process_frame(&mut self, frame: Frame) -> Result<()> {
129 debug!("Processing frame: stream_id={}, sequence={}, flags={:?}, payload_size={}",
130 frame.stream_id, frame.sequence, frame.flags, frame.payload.len());
131
132 if frame.is_error() {
134 warn!("Received error frame: stream_id={}, payload={:?}",
135 frame.stream_id, frame.payload);
136 return Ok(());
137 }
138
139 if frame.is_end_stream() {
140 debug!("Received end-of-stream frame: stream_id={}", frame.stream_id);
141 return Ok(());
142 }
143
144 let message = match rmp_serde::from_slice::<Message>(&frame.payload) {
146 Ok(msg) => msg,
147 Err(e) => {
148 error!("Failed to deserialize message: {}", e);
149 self.send_error_frame(frame.stream_id, frame.sequence,
150 ErrorCode::InvalidRequest,
151 format!("Invalid message format: {}", e)).await?;
152 return Ok(());
153 }
154 };
155
156 match message {
158 Message::Request(request) => {
159 self.handle_request(frame.stream_id, frame.sequence, request).await?;
160 }
161 Message::Response(_) => {
162 warn!("Received unexpected response message on agent");
163 }
165 }
166
167 Ok(())
168 }
169
170 async fn handle_request(&mut self, stream_id: u32, sequence: u32, request: Request) -> Result<()> {
172 let request_id = request.id();
173 debug!("Handling request: id={}, type={:?}", request_id, std::mem::discriminant(&request));
174
175 let request_type = match &request {
177 Request::ProcessExec { .. } => "process_exec",
178 Request::FileGet { .. } => "file_get",
179 Request::FilePut { .. } => "file_put",
180 Request::DirList { .. } => "dir_list",
181 Request::WasmExec { .. } => "wasm_exec",
182 Request::JsonCall { .. } => "json_call",
183 Request::Ping { .. } => "ping",
184 Request::PtyExec { .. } => "pty_exec",
185 };
186
187 let handler = {
189 let handlers = self.handlers.read().await;
190 handlers.get(request_type).cloned()
191 };
192
193 let response = match handler {
194 Some(handler) => {
195 match handler.handle(request).await {
197 Ok(response) => response,
198 Err(e) => {
199 error!("Handler error for request {}: {}", request_id, e);
200 Response::error(
201 request_id,
202 ErrorDetails::new(ErrorCode::InternalError, format!("Handler error: {}", e))
203 )
204 }
205 }
206 }
207 None => {
208 warn!("No handler registered for request type: {}", request_type);
209 Response::error(
210 request_id,
211 ErrorDetails::new(ErrorCode::Unsupported, format!("Unsupported request type: {}", request_type))
212 )
213 }
214 };
215
216 self.send_response(stream_id, sequence, response).await?;
218
219 Ok(())
220 }
221
222 async fn send_response(&mut self, stream_id: u32, sequence: u32, response: Response) -> Result<()> {
224 let message = Message::response(response);
225 let payload = rmp_serde::to_vec(&message)
226 .context("Failed to serialize response message")?;
227
228 let frame = Frame::data(stream_id, sequence, Bytes::from(payload));
229 self.codec.write_frame(&mut self.writer, &frame).await
230 .context("Failed to write response frame")?;
231
232 debug!("Sent response: stream_id={}, sequence={}", stream_id, sequence);
233 Ok(())
234 }
235
236 async fn send_error_frame(&mut self, stream_id: u32, sequence: u32,
238 error_code: ErrorCode, message: String) -> Result<()> {
239 let error_payload = rmp_serde::to_vec(&ErrorDetails::new(error_code, message))
240 .context("Failed to serialize error details")?;
241
242 let frame = Frame::error(stream_id, sequence, Bytes::from(error_payload));
243 self.codec.write_frame(&mut self.writer, &frame).await
244 .context("Failed to write error frame")?;
245
246 debug!("Sent error frame: stream_id={}, sequence={}", stream_id, sequence);
247 Ok(())
248 }
249}
250
251impl<R, W> Default for AgentLoop<R, W>
252where
253 R: AsyncRead + Unpin + Send + Default,
254 W: AsyncWrite + Unpin + Send + Default,
255{
256 fn default() -> Self {
257 Self::with_io(R::default(), W::default())
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use mitoxide_proto::{Request, Response};
    use std::io::Cursor;
    use tokio::time::{timeout, Duration};
    use uuid::Uuid;

    /// In-memory agent type used by most tests.
    type MemAgent = AgentLoop<Cursor<Vec<u8>>, Cursor<Vec<u8>>>;

    /// Test double: answers `Ping` with a matching pong and every other request
    /// with the canned `response`.
    struct MockHandler {
        response: Response,
    }

    #[async_trait::async_trait]
    impl Handler for MockHandler {
        async fn handle(&self, request: Request) -> Result<Response> {
            match request {
                Request::Ping { id, timestamp } => Ok(Response::pong(id, timestamp)),
                _ => Ok(self.response.clone()),
            }
        }
    }

    /// Builds an agent over an empty in-memory input and an empty output buffer.
    fn in_memory_agent() -> MemAgent {
        AgentLoop::with_io(Cursor::new(Vec::<u8>::new()), Cursor::new(Vec::<u8>::new()))
    }

    /// Bytes the agent has written to its output buffer so far.
    fn written(agent: &MemAgent) -> &[u8] {
        agent.writer.get_ref()
    }

    #[tokio::test]
    async fn test_agent_loop_creation() {
        let agent = AgentLoop::new();
        assert!(agent.shutdown_tx.is_some());
        assert!(agent.shutdown_rx.is_some());
    }

    #[tokio::test]
    async fn test_handler_registration() {
        let agent = AgentLoop::new();
        let handler = Arc::new(MockHandler {
            response: Response::pong(Uuid::new_v4(), 12345),
        });

        agent.register_handler("test".to_string(), handler).await;

        let handlers = agent.handlers.read().await;
        assert!(handlers.contains_key("test"));
    }

    #[tokio::test]
    async fn test_graceful_shutdown() {
        let mut agent = in_memory_agent();
        let shutdown_tx = agent.shutdown_sender().unwrap();

        let agent_task = tokio::spawn(async move { agent.run().await });

        shutdown_tx.send(()).unwrap();

        let result = timeout(Duration::from_secs(1), agent_task).await;
        assert!(result.is_ok());
        assert!(result.unwrap().unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_ping_request_handling() {
        let request = Request::ping();
        let request_id = request.id();
        let payload = rmp_serde::to_vec(&Message::request(request)).unwrap();

        let mut agent = in_memory_agent();
        let handler = Arc::new(MockHandler {
            response: Response::pong(request_id, 12345),
        });
        agent.register_handler("ping".to_string(), handler).await;

        let result = agent.process_frame(Frame::data(1, 1, Bytes::from(payload))).await;
        assert!(result.is_ok());
        // A handled request must produce a response frame on the output.
        assert!(!written(&agent).is_empty());
    }

    #[tokio::test]
    async fn test_invalid_message_handling() {
        let frame = Frame::data(1, 1, Bytes::from(vec![0xFF, 0xFF, 0xFF, 0xFF]));
        let mut agent = in_memory_agent();

        let result = agent.process_frame(frame).await;
        assert!(result.is_ok());
        // An undecodable payload is answered with an error frame.
        assert!(!written(&agent).is_empty());
    }

    #[tokio::test]
    async fn test_unsupported_request_handling() {
        let request = Request::process_exec(
            vec!["echo".to_string()],
            std::collections::HashMap::new(),
            None,
            None,
            None,
        );
        let payload = rmp_serde::to_vec(&Message::request(request)).unwrap();
        let frame = Frame::data(1, 1, Bytes::from(payload));
        let mut agent = in_memory_agent();

        let result = agent.process_frame(frame).await;
        assert!(result.is_ok());
        // No handler registered: an `Unsupported` error response is still written.
        assert!(!written(&agent).is_empty());
    }

    #[tokio::test]
    async fn test_error_frame_handling() {
        let error_frame = Frame::error(1, 1, Bytes::from("test error"));
        let mut agent = in_memory_agent();

        let result = agent.process_frame(error_frame).await;
        assert!(result.is_ok());
        // Incoming error frames are only logged, never answered.
        assert!(written(&agent).is_empty());
    }

    #[tokio::test]
    async fn test_end_stream_frame_handling() {
        let end_frame = Frame::end_stream(1, 1);
        let mut agent = in_memory_agent();

        let result = agent.process_frame(end_frame).await;
        assert!(result.is_ok());
        // End-of-stream frames are only logged, never answered.
        assert!(written(&agent).is_empty());
    }
}
423}