a653rs_linux/
syscall.rs

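//! Implementation of the mechanism to perform system calls.
//!
//! A system call is a single request/response exchange with the hypervisor
//! over a Unix socket:
//!
//! 1. The caller creates two memory file descriptors (one for the request,
//!    one for the response) and an eventfd.
//! 2. The serialized [`SyscallRequest`] is written into the request memfd,
//!    which is then finalized with `Seals::Readable`.
//! 3. All three descriptors are passed to the hypervisor in one `SCM_RIGHTS`
//!    control message.
//! 4. The caller blocks until the eventfd becomes readable. The hypervisor is
//!    expected to write the serialized [`SyscallResponse`] into the response
//!    memfd before signalling the eventfd.
//! 5. The response is read back from the response memfd and deserialized.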
4
5use std::io::IoSlice;
6use std::num::NonZeroUsize;
7use std::os::fd::{AsFd, AsRawFd, BorrowedFd};
8
9use anyhow::Result;
10use nix::libc::EINTR;
11use nix::sys::eventfd::EventFd;
12use nix::sys::socket::{sendmsg, ControlMessage, MsgFlags};
13use polling::{Event, Events, Poller};
14
15use a653rs_linux_core::mfd::{Mfd, Seals};
16use a653rs_linux_core::syscall::{SyscallRequest, SyscallResponse};
17
18use crate::SYSCALL;
19
20/// Sends a vector of file descriptors through a Unix socket
21fn send_fds<const COUNT: usize, T: AsRawFd>(hv: BorrowedFd, fds: [T; COUNT]) -> Result<()> {
22    let fds = fds.map(|f| f.as_raw_fd());
23    let cmsg = [ControlMessage::ScmRights(&fds)];
24    let buffer = 0_u64.to_ne_bytes();
25    let iov = [IoSlice::new(buffer.as_slice())];
26    sendmsg::<()>(hv.as_raw_fd(), &iov, &cmsg, MsgFlags::empty(), None)?;
27    Ok(())
28}
29
30/// Waits for action on the event fd
31// TODO: Consider timeout
32fn wait_event(event_fd: BorrowedFd) -> Result<()> {
33    let poller = Poller::new()?;
34    let mut events = Events::with_capacity(NonZeroUsize::MIN);
35    unsafe {
36        poller.add(event_fd.as_raw_fd(), Event::readable(0))?;
37    }
38
39    loop {
40        match poller.wait(&mut events, None) {
41            Ok(1) => break,
42            Err(e) => {
43                if e.raw_os_error() == Some(EINTR) {
44                    continue;
45                } else {
46                    panic!("poller failed with {:?}", e)
47                }
48            }
49            _ => panic!("unknown poller state"),
50        }
51    }
52
53    Ok(())
54}
55
56fn execute_fd(fd: BorrowedFd, request: SyscallRequest) -> Result<SyscallResponse> {
57    // Create the file descriptor triple
58    let mut request_fd = Mfd::create("requ")?;
59    let mut response_fd = Mfd::create("resp")?;
60    let event_fd = EventFd::new()?;
61
62    // Initialize the request file descriptor
63    request_fd.write(&request.serialize()?)?;
64    request_fd.finalize(Seals::Readable)?;
65
66    // Send the file descriptors to the hypervisor
67    send_fds(
68        fd,
69        [request_fd.as_fd(), response_fd.as_fd(), event_fd.as_fd()],
70    )?;
71
72    wait_event(event_fd.as_fd())?;
73
74    let response = SyscallResponse::deserialize(&response_fd.read_all()?)?;
75    Ok(response)
76}
77
78pub fn execute(request: SyscallRequest) -> Result<SyscallResponse> {
79    execute_fd(SYSCALL.as_fd(), request)
80}
81
#[cfg(test)]
mod tests {
    use std::io::IoSliceMut;
    use std::os::fd::{FromRawFd, OwnedFd, RawFd};

    use nix::sys::socket::{
        recvmsg, socketpair, AddressFamily, ControlMessageOwned, SockFlag, SockType,
    };
    use nix::{cmsg_space, unistd};

    use a653rs_linux_core::syscall::ApexSyscall;

    use super::*;

    /// Plays the hypervisor for exactly one system call arriving on `socket`.
    ///
    /// The helper runs the hypervisor side of the exchange:
    /// 1. Receive the request/response/event descriptor triple.
    /// 2. Check that the request is `Start` with params `[1, 2, 42]`.
    /// 3. Write a `Start` response with status 42.
    /// 4. Signal the eventfd.
    fn serve_single_request(socket: impl AsRawFd) {
        // Receive the three descriptors sent by `send_fds`.
        let mut cmsg_buf = cmsg_space!([RawFd; 3]);
        let mut payload = [0u8];
        let mut iov = [IoSliceMut::new(&mut payload)];
        let msg = recvmsg::<()>(
            socket.as_raw_fd(),
            &mut iov,
            Some(&mut cmsg_buf),
            MsgFlags::empty(),
        )
        .unwrap();

        let first = msg.cmsgs().unwrap().next().unwrap();
        let ControlMessageOwned::ScmRights(raw_fds) = first else {
            panic!("unknown cmsg received")
        };
        // SAFETY: the kernel just installed these descriptors via SCM_RIGHTS,
        // so each one is open and not owned by anything else yet.
        let owned: Vec<OwnedFd> = raw_fds
            .into_iter()
            .map(|raw| unsafe { OwnedFd::from_raw_fd(raw) })
            .collect();

        let [request, response, event_fd] = owned.try_into().unwrap();
        let mut request_mfd = Mfd::from_fd(request).unwrap();
        let mut response_mfd = Mfd::from_fd(response).unwrap();

        // Decode and check the incoming request.
        let bytes = request_mfd.read_all().unwrap();
        let received = SyscallRequest::deserialize(&bytes).unwrap();
        assert_eq!(received.id, ApexSyscall::Start);
        assert_eq!(received.params, vec![1, 2, 42]);

        // Publish the answer.
        let reply = SyscallResponse {
            id: ApexSyscall::Start,
            status: 42,
        }
        .serialize()
        .unwrap();
        response_mfd.write(&reply).unwrap();
        response_mfd.finalize(Seals::Readable).unwrap();

        // Wake up the caller blocked in `wait_event`.
        unistd::write(event_fd, &1_u64.to_ne_bytes()).unwrap();
    }

    #[test]
    fn test_execute() {
        let (requester, responder) = socketpair(
            AddressFamily::Unix,
            SockType::Datagram,
            None,
            SockFlag::empty(),
        )
        .unwrap();

        let caller = std::thread::spawn(move || {
            let request = SyscallRequest {
                id: ApexSyscall::Start,
                params: vec![1, 2, 42],
            };
            let response = execute_fd(requester.as_fd(), request).unwrap();

            assert_eq!(response.id, ApexSyscall::Start);
            assert_eq!(response.status, 42);
        });
        let hypervisor = std::thread::spawn(move || serve_single_request(responder));

        caller.join().unwrap();
        hypervisor.join().unwrap();
    }
}
170}