1use std::io::IoSlice;
6use std::num::NonZeroUsize;
7use std::os::fd::{AsFd, AsRawFd, BorrowedFd};
8
9use anyhow::Result;
10use nix::libc::EINTR;
11use nix::sys::eventfd::EventFd;
12use nix::sys::socket::{sendmsg, ControlMessage, MsgFlags};
13use polling::{Event, Events, Poller};
14
15use a653rs_linux_core::mfd::{Mfd, Seals};
16use a653rs_linux_core::syscall::{SyscallRequest, SyscallResponse};
17
18use crate::SYSCALL;
19
20fn send_fds<const COUNT: usize, T: AsRawFd>(hv: BorrowedFd, fds: [T; COUNT]) -> Result<()> {
22 let fds = fds.map(|f| f.as_raw_fd());
23 let cmsg = [ControlMessage::ScmRights(&fds)];
24 let buffer = 0_u64.to_ne_bytes();
25 let iov = [IoSlice::new(buffer.as_slice())];
26 sendmsg::<()>(hv.as_raw_fd(), &iov, &cmsg, MsgFlags::empty(), None)?;
27 Ok(())
28}
29
30fn wait_event(event_fd: BorrowedFd) -> Result<()> {
33 let poller = Poller::new()?;
34 let mut events = Events::with_capacity(NonZeroUsize::MIN);
35 unsafe {
36 poller.add(event_fd.as_raw_fd(), Event::readable(0))?;
37 }
38
39 loop {
40 match poller.wait(&mut events, None) {
41 Ok(1) => break,
42 Err(e) => {
43 if e.raw_os_error() == Some(EINTR) {
44 continue;
45 } else {
46 panic!("poller failed with {:?}", e)
47 }
48 }
49 _ => panic!("unknown poller state"),
50 }
51 }
52
53 Ok(())
54}
55
56fn execute_fd(fd: BorrowedFd, request: SyscallRequest) -> Result<SyscallResponse> {
57 let mut request_fd = Mfd::create("requ")?;
59 let mut response_fd = Mfd::create("resp")?;
60 let event_fd = EventFd::new()?;
61
62 request_fd.write(&request.serialize()?)?;
64 request_fd.finalize(Seals::Readable)?;
65
66 send_fds(
68 fd,
69 [request_fd.as_fd(), response_fd.as_fd(), event_fd.as_fd()],
70 )?;
71
72 wait_event(event_fd.as_fd())?;
73
74 let response = SyscallResponse::deserialize(&response_fd.read_all()?)?;
75 Ok(response)
76}
77
78pub fn execute(request: SyscallRequest) -> Result<SyscallResponse> {
79 execute_fd(SYSCALL.as_fd(), request)
80}
81
#[cfg(test)]
mod tests {
    use std::io::IoSliceMut;
    use std::os::fd::{FromRawFd, OwnedFd, RawFd};

    use nix::sys::socket::{
        recvmsg, socketpair, AddressFamily, ControlMessageOwned, SockFlag, SockType,
    };
    use nix::{cmsg_space, unistd};

    use a653rs_linux_core::syscall::ApexSyscall;

    use super::*;

    /// Acts as the receiving side of one `execute_fd` exchange on `socket`.
    ///
    /// Receives the three descriptors, checks the request found in the first
    /// memfd, and writes a `Start` response with status 42 into the second
    /// memfd. Finally it writes `1` to the eventfd to signal completion.
    fn answer_once(socket: OwnedFd) {
        let mut cmsg_buf = cmsg_space!([RawFd; 3]);
        // The payload itself is irrelevant; only the ancillary data matters.
        let mut byte = [0u8];
        let mut iov = [IoSliceMut::new(&mut byte)];
        let msg = recvmsg::<()>(
            socket.as_raw_fd(),
            &mut iov,
            Some(&mut cmsg_buf),
            MsgFlags::empty(),
        )
        .unwrap();

        let raw_fds = match msg.cmsgs().unwrap().next().unwrap() {
            ControlMessageOwned::ScmRights(raw) => raw,
            _ => panic!("unknown cmsg received"),
        };
        let owned: Vec<OwnedFd> = raw_fds
            .into_iter()
            // SAFETY: descriptors received via SCM_RIGHTS are fresh fds owned
            // by this process and are not wrapped anywhere else.
            .map(|raw| unsafe { OwnedFd::from_raw_fd(raw) })
            .collect();

        let [request, response, event_fd]: [OwnedFd; 3] = owned.try_into().unwrap();
        let mut request_mfd = Mfd::from_fd(request).unwrap();
        let mut response_mfd = Mfd::from_fd(response).unwrap();

        let incoming = SyscallRequest::deserialize(&request_mfd.read_all().unwrap()).unwrap();
        assert_eq!(incoming.id, ApexSyscall::Start);
        assert_eq!(incoming.params, vec![1, 2, 42]);

        let reply = SyscallResponse {
            id: ApexSyscall::Start,
            status: 42,
        }
        .serialize()
        .unwrap();
        response_mfd.write(&reply).unwrap();
        response_mfd.finalize(Seals::Readable).unwrap();

        unistd::write(event_fd, &1_u64.to_ne_bytes()).unwrap();
    }

    #[test]
    fn test_execute() {
        let (requester, responder) = socketpair(
            AddressFamily::Unix,
            SockType::Datagram,
            None,
            SockFlag::empty(),
        )
        .unwrap();

        let caller = std::thread::spawn(move || {
            let request = SyscallRequest {
                id: ApexSyscall::Start,
                params: vec![1, 2, 42],
            };
            let response = execute_fd(requester.as_fd(), request).unwrap();

            assert_eq!(response.id, ApexSyscall::Start);
            assert_eq!(response.status, 42);
        });
        let callee = std::thread::spawn(move || answer_once(responder));

        caller.join().unwrap();
        callee.join().unwrap();
    }
}