Skip to main content

microsandbox_agentd/
agent.rs

1//! Main agent loop: serial I/O, session management, heartbeat.
2
3use std::{collections::HashMap, fs::OpenOptions, os::fd::AsRawFd};
4
5use chrono::Utc;
6use tokio::{
7    io::unix::AsyncFd,
8    sync::mpsc,
9    time::{Duration, interval},
10};
11
12use microsandbox_protocol::{
13    codec::{MAX_FRAME_SIZE, encode_to_buf, try_decode_from_buf},
14    core::Ready,
15    exec::{
16        ExecExited, ExecRequest, ExecResize, ExecSignal, ExecStarted, ExecStderr, ExecStdin,
17        ExecStdout,
18    },
19    fs::{FsData, FsRequest},
20    message::{Message, MessageType},
21};
22
23use crate::{
24    error::{AgentdError, AgentdResult},
25    fs::FsWriteSession,
26    heartbeat::{heartbeat_dir_exists, write_heartbeat},
27    serial::{AGENT_PORT_NAME, find_serial_port},
28    session::{ExecSession, SessionOutput},
29};
30
31//--------------------------------------------------------------------------------------------------
32// Constants
33//--------------------------------------------------------------------------------------------------
34
35/// Heartbeat interval in seconds.
36const HEARTBEAT_INTERVAL_SECS: u64 = 5;
37
38/// Read buffer size for the serial port.
39const SERIAL_READ_BUF_SIZE: usize = 64 * 1024;
40
41/// Maximum allowed input buffer size (frame size limit + 4 bytes for length prefix).
42const MAX_INPUT_BUF_SIZE: usize = MAX_FRAME_SIZE as usize + 4;
43
44//--------------------------------------------------------------------------------------------------
45// Functions
46//--------------------------------------------------------------------------------------------------
47
48/// Runs the main agent loop.
49///
50/// Discovers the virtio serial port, sends `core.ready` with boot timing data,
51/// then enters the main select loop handling serial I/O, process output, and heartbeat.
52///
53/// - `boot_time_ns`: `CLOCK_BOOTTIME` at `main()` start (kernel boot duration).
54/// - `init_time_ns`: nanoseconds spent in `init::init()`.
55pub async fn run(boot_time_ns: u64, init_time_ns: u64) -> AgentdResult<()> {
56    // Discover serial port.
57    let port_path = find_serial_port(AGENT_PORT_NAME)?;
58
59    // Open the port once with read+write. Virtio-console multiport devices
60    // only allow a single open; a second open returns EBUSY.
61    let port_file = OpenOptions::new().read(true).write(true).open(&port_path)?;
62
63    // Set non-blocking for async I/O.
64    let port_fd = port_file.as_raw_fd();
65    set_nonblocking(port_fd)?;
66
67    // A single AsyncFd tracks both readable and writable readiness.
68    let async_port = AsyncFd::new(port_file)?;
69
70    // Buffer for serial reads.
71    let mut read_buf = vec![0u8; SERIAL_READ_BUF_SIZE];
72    let mut serial_in_buf = Vec::new();
73    let mut serial_out_buf = Vec::new();
74
75    // Active exec sessions.
76    let mut sessions: HashMap<u32, ExecSession> = HashMap::new();
77
78    // Active filesystem write sessions.
79    let mut write_sessions: HashMap<u32, FsWriteSession> = HashMap::new();
80
81    // Channel for session output events.
82    let (session_tx, mut session_rx) = mpsc::unbounded_channel::<(u32, SessionOutput)>();
83
84    // Heartbeat state.
85    let mut last_activity = Utc::now();
86    let mut heartbeat_timer = interval(Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
87
88    // Send core.ready with boot timing data.
89    let ready_time_ns = crate::clock::boottime_ns();
90    let ready_msg = Message::with_payload(
91        MessageType::Ready,
92        0,
93        &Ready {
94            boot_time_ns,
95            init_time_ns,
96            ready_time_ns,
97        },
98    )
99    .map_err(|e| AgentdError::ExecSession(format!("encode ready: {e}")))?;
100    encode_to_buf(&ready_msg, &mut serial_out_buf)
101        .map_err(|e| AgentdError::ExecSession(format!("encode ready frame: {e}")))?;
102    flush_write_buf(&async_port, &mut serial_out_buf).await?;
103
104    // Main loop.
105    loop {
106        tokio::select! {
107            // Read from serial port.
108            result = async_read_ready(&async_port) => {
109                if result.is_ok() {
110                    match read_from_fd(port_fd, &mut read_buf) {
111                        Ok(n) if n > 0 => {
112                            serial_in_buf.extend_from_slice(&read_buf[..n]);
113                            last_activity = Utc::now();
114
115                            // Guard against unbounded buffer growth.
116                            if serial_in_buf.len() > MAX_INPUT_BUF_SIZE {
117                                return Err(AgentdError::ExecSession(
118                                    "serial input buffer exceeded maximum size".into(),
119                                ));
120                            }
121
122                            // Try to parse complete messages.
123                            while let Some(msg) = try_decode_from_buf(&mut serial_in_buf)
124                                .map_err(|e| AgentdError::ExecSession(format!("decode: {e}")))?
125                            {
126                                handle_message(
127                                    msg,
128                                    &mut sessions,
129                                    &mut write_sessions,
130                                    &session_tx,
131                                    &mut serial_out_buf,
132                                ).await?;
133                            }
134
135                            // Flush any outgoing messages.
136                            if !serial_out_buf.is_empty() {
137                                flush_write_buf(&async_port, &mut serial_out_buf).await?;
138                            }
139                        }
140                        Ok(_) => {
141                            // EOF on serial — host disconnected.
142                            break;
143                        }
144                        Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
145                            // No data available, continue.
146                        }
147                        Err(e) => return Err(e.into()),
148                    }
149                }
150            }
151
152            // Receive output events from session reader tasks.
153            Some((id, output)) = session_rx.recv() => {
154                match output {
155                    SessionOutput::Stdout(data) => {
156                        let msg = Message::with_payload(MessageType::ExecStdout, id, &ExecStdout { data })
157                            .map_err(|e| AgentdError::ExecSession(format!("encode stdout: {e}")))?;
158                        encode_to_buf(&msg, &mut serial_out_buf)
159                            .map_err(|e| AgentdError::ExecSession(format!("encode stdout frame: {e}")))?;
160                    }
161                    SessionOutput::Stderr(data) => {
162                        let msg = Message::with_payload(MessageType::ExecStderr, id, &ExecStderr { data })
163                            .map_err(|e| AgentdError::ExecSession(format!("encode stderr: {e}")))?;
164                        encode_to_buf(&msg, &mut serial_out_buf)
165                            .map_err(|e| AgentdError::ExecSession(format!("encode stderr frame: {e}")))?;
166                    }
167                    SessionOutput::Exited(code) => {
168                        let msg = Message::with_payload(MessageType::ExecExited, id, &ExecExited { code })
169                            .map_err(|e| AgentdError::ExecSession(format!("encode exited: {e}")))?;
170                        encode_to_buf(&msg, &mut serial_out_buf)
171                            .map_err(|e| AgentdError::ExecSession(format!("encode exited frame: {e}")))?;
172                        sessions.remove(&id);
173                    }
174                    SessionOutput::Raw(frame_bytes) => {
175                        // Pre-encoded frame — write directly to output buffer.
176                        serial_out_buf.extend_from_slice(&frame_bytes);
177                    }
178                }
179
180                if !serial_out_buf.is_empty() {
181                    flush_write_buf(&async_port, &mut serial_out_buf).await?;
182                }
183            }
184
185            // Heartbeat tick.
186            _ = heartbeat_timer.tick() => {
187                if heartbeat_dir_exists() {
188                    let _ = write_heartbeat(
189                        sessions.len() as u32,
190                        last_activity,
191                    ).await;
192                }
193            }
194        }
195    }
196
197    Ok(())
198}
199
200//--------------------------------------------------------------------------------------------------
201// Functions: Helpers
202//--------------------------------------------------------------------------------------------------
203
204/// Handles a single incoming message from the host.
205async fn handle_message(
206    msg: Message,
207    sessions: &mut HashMap<u32, ExecSession>,
208    write_sessions: &mut HashMap<u32, FsWriteSession>,
209    session_tx: &mpsc::UnboundedSender<(u32, SessionOutput)>,
210    out_buf: &mut Vec<u8>,
211) -> AgentdResult<()> {
212    match msg.t {
213        MessageType::ExecRequest => {
214            let req: ExecRequest = msg
215                .payload()
216                .map_err(|e| AgentdError::ExecSession(format!("decode exec request: {e}")))?;
217            match ExecSession::spawn(msg.id, &req, session_tx.clone()) {
218                Ok(session) => {
219                    let reply = Message::with_payload(
220                        MessageType::ExecStarted,
221                        msg.id,
222                        &ExecStarted { pid: session.pid() },
223                    )
224                    .map_err(|e| AgentdError::ExecSession(format!("encode started: {e}")))?;
225                    encode_to_buf(&reply, out_buf).map_err(|e| {
226                        AgentdError::ExecSession(format!("encode started frame: {e}"))
227                    })?;
228                    sessions.insert(msg.id, session);
229                }
230                Err(e) => {
231                    // Send an immediate exit with code -1 on spawn failure.
232                    let reply = Message::with_payload(
233                        MessageType::ExecExited,
234                        msg.id,
235                        &ExecExited { code: -1 },
236                    )
237                    .map_err(|e| AgentdError::ExecSession(format!("encode exited: {e}")))?;
238                    encode_to_buf(&reply, out_buf).map_err(|e| {
239                        AgentdError::ExecSession(format!("encode exited frame: {e}"))
240                    })?;
241                    eprintln!("failed to spawn exec session {}: {e}", msg.id);
242                }
243            }
244        }
245
246        MessageType::ExecStdin => {
247            let stdin: ExecStdin = msg
248                .payload()
249                .map_err(|e| AgentdError::ExecSession(format!("decode stdin: {e}")))?;
250            if let Some(session) = sessions.get_mut(&msg.id) {
251                if stdin.data.is_empty() {
252                    // Empty data signals EOF — close stdin.
253                    session.close_stdin();
254                } else {
255                    let _ = session.write_stdin(&stdin.data).await;
256                }
257            }
258        }
259
260        MessageType::ExecResize => {
261            let resize: ExecResize = msg
262                .payload()
263                .map_err(|e| AgentdError::ExecSession(format!("decode resize: {e}")))?;
264            if let Some(session) = sessions.get(&msg.id) {
265                let _ = session.resize(resize.rows, resize.cols);
266            }
267        }
268
269        MessageType::ExecSignal => {
270            let signal: ExecSignal = msg
271                .payload()
272                .map_err(|e| AgentdError::ExecSession(format!("decode signal: {e}")))?;
273            if let Some(session) = sessions.get(&msg.id) {
274                let _ = session.send_signal(signal.signal);
275            }
276        }
277
278        MessageType::FsRequest => {
279            let req: FsRequest = msg
280                .payload()
281                .map_err(|e| AgentdError::ExecSession(format!("decode fs request: {e}")))?;
282            match crate::fs::handle_fs_request(msg.id, req, out_buf, session_tx).await {
283                Ok(Some(ws)) => {
284                    write_sessions.insert(msg.id, ws);
285                }
286                Ok(None) => {}
287                Err(e) => {
288                    eprintln!("fs request error for {}: {e}", msg.id);
289                }
290            }
291        }
292
293        MessageType::FsData => {
294            let data: FsData = msg
295                .payload()
296                .map_err(|e| AgentdError::ExecSession(format!("decode fs data: {e}")))?;
297            if let Some(session) = write_sessions.get_mut(&msg.id) {
298                match crate::fs::handle_fs_data(msg.id, data, session, out_buf).await {
299                    Ok(true) => {
300                        // Session complete — remove it.
301                        write_sessions.remove(&msg.id);
302                    }
303                    Ok(false) => {}
304                    Err(e) => {
305                        eprintln!("fs data error for {}: {e}", msg.id);
306                        write_sessions.remove(&msg.id);
307                    }
308                }
309            } else {
310                // No write session for this ID — send error response.
311                let resp = microsandbox_protocol::fs::FsResponse {
312                    ok: false,
313                    error: Some(format!("unknown write session: {}", msg.id)),
314                    data: None,
315                };
316                let reply = Message::with_payload(MessageType::FsResponse, msg.id, &resp)
317                    .map_err(|e| AgentdError::ExecSession(format!("encode fs error: {e}")))?;
318                encode_to_buf(&reply, out_buf)
319                    .map_err(|e| AgentdError::ExecSession(format!("encode fs error frame: {e}")))?;
320            }
321        }
322
323        MessageType::Shutdown => {
324            // Graceful shutdown — signal all sessions and break from main loop.
325            for (_, session) in sessions.drain() {
326                let _ = session.send_signal(15); // SIGTERM
327            }
328            write_sessions.clear();
329            return Err(AgentdError::Shutdown);
330        }
331
332        _ => {
333            // Ignore unknown or unexpected message types.
334        }
335    }
336
337    Ok(())
338}
339
340/// Sets a file descriptor to non-blocking mode.
341fn set_nonblocking(fd: i32) -> AgentdResult<()> {
342    let flags = unsafe { libc::fcntl(fd, libc::F_GETFL) };
343    if flags < 0 {
344        return Err(std::io::Error::last_os_error().into());
345    }
346    let ret = unsafe { libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) };
347    if ret < 0 {
348        return Err(std::io::Error::last_os_error().into());
349    }
350    Ok(())
351}
352
353/// Waits for the async fd to be readable.
354async fn async_read_ready(fd: &AsyncFd<std::fs::File>) -> std::io::Result<()> {
355    let mut guard = fd.readable().await?;
356    guard.clear_ready();
357    Ok(())
358}
359
360/// Reads from a raw fd (non-blocking).
361fn read_from_fd(fd: i32, buf: &mut [u8]) -> std::io::Result<usize> {
362    let n = unsafe { libc::read(fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len()) };
363    if n < 0 {
364        Err(std::io::Error::last_os_error())
365    } else {
366        Ok(n as usize)
367    }
368}
369
370/// Flushes the write buffer to the async fd.
371async fn flush_write_buf(fd: &AsyncFd<std::fs::File>, buf: &mut Vec<u8>) -> AgentdResult<()> {
372    while !buf.is_empty() {
373        let mut guard = fd.writable().await?;
374        let raw_fd = fd.as_raw_fd();
375        match write_to_fd(raw_fd, buf) {
376            Ok(n) => {
377                buf.drain(..n);
378            }
379            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
380                guard.clear_ready();
381                continue;
382            }
383            Err(e) => return Err(e.into()),
384        }
385        guard.clear_ready();
386    }
387    Ok(())
388}
389
390/// Writes to a raw fd (non-blocking).
391fn write_to_fd(fd: i32, buf: &[u8]) -> std::io::Result<usize> {
392    let n = unsafe { libc::write(fd, buf.as_ptr() as *const libc::c_void, buf.len()) };
393    if n < 0 {
394        Err(std::io::Error::last_os_error())
395    } else {
396        Ok(n as usize)
397    }
398}