Skip to main content

nexus_core/
module_b.rs

1use std::path::PathBuf;
2
3use anyhow::Result;
4
/// Sender-side settings for a module-B run (see `run_parent`).
#[derive(Clone, Debug)]
pub struct ModuleBConfig {
    /// Size in bytes of the memfd the sender creates and maps.
    pub memfd_bytes: u64,
    /// Unix socket to listen on; `None` selects `/tmp/tracer-bullet-<pid>.sock`.
    pub socket_path: Option<PathBuf>,
    /// Bytes written into the tail of the memfd for the receiver to verify.
    pub marker: String,
}
11
/// Receiver-side settings passed to the spawned child (see `run_receiver_internal`).
#[derive(Clone, Debug)]
pub struct ModuleBReceiverConfig {
    /// Unix socket the receiver connects to.
    pub socket_path: PathBuf,
    /// Bytes the receiver expects to find at the tail of the memfd.
    pub marker: String,
    /// Number of bytes of the received memfd the receiver maps.
    pub memfd_bytes: u64,
}
18
/// Outcome of a module-B run, as returned by `run_parent`.
#[derive(Debug, Clone)]
pub struct ModuleBStats {
    /// Size in bytes of the memfd that was handed off.
    pub memfd_bytes: u64,
    /// Time from sending the descriptor to reading the receiver's ACK, in milliseconds.
    pub handoff_latency_ms: f64,
    /// Whether the receiver validated the marker.
    pub marker_ok: bool,
    /// Unix socket path used for the handoff.
    pub socket_path: PathBuf,
}

impl ModuleBStats {
    /// Serializes the stats as a single-line JSON object.
    ///
    /// The socket path is JSON-escaped (quotes, backslashes, control characters), so
    /// user-supplied paths still yield valid JSON. Non-UTF-8 path bytes are replaced lossily,
    /// as `Path::display` does. A non-finite latency is emitted as `null` because JSON cannot
    /// represent NaN or infinity.
    pub fn to_json(&self) -> String {
        let latency = if self.handoff_latency_ms.is_finite() {
            format!("{:.3}", self.handoff_latency_ms)
        } else {
            "null".to_string()
        };
        format!(
            "{{\"module\":\"B\",\"memfd_bytes\":{},\"handoff_latency_ms\":{},\"marker_ok\":{},\"socket_path\":\"{}\"}}",
            self.memfd_bytes,
            latency,
            self.marker_ok,
            json_escape(&self.socket_path.to_string_lossy())
        )
    }
}

/// Escapes `raw` so it can be embedded between the quotes of a JSON string literal.
fn json_escape(raw: &str) -> String {
    let mut out = String::with_capacity(raw.len());
    for ch in raw.chars() {
        match ch {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            // JSON forbids raw control characters inside strings.
            c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
            c => out.push(c),
        }
    }
    out
}
38
39#[cfg(target_os = "linux")]
40mod linux {
41    use std::ffi::CString;
42    use std::fs;
43    use std::io::{IoSlice, IoSliceMut, Read, Write};
44    use std::num::NonZeroUsize;
45    use std::os::fd::{AsFd, AsRawFd, FromRawFd, OwnedFd, RawFd};
46    use std::os::unix::net::{UnixListener, UnixStream};
47    use std::path::{Path, PathBuf};
48    use std::process::Command;
49    use std::ptr::NonNull;
50    use std::thread;
51    use std::time::{Duration, Instant};
52
53    use anyhow::{Context, Result};
54    use nix::cmsg_space;
55    use nix::sys::memfd::{memfd_create, MFdFlags};
56    use nix::sys::mman::{mmap, munmap, MapFlags, ProtFlags};
57    use nix::sys::socket::{recvmsg, sendmsg, ControlMessage, ControlMessageOwned, MsgFlags};
58    use nix::unistd::ftruncate;
59
60    use crate::module_b::{ModuleBConfig, ModuleBReceiverConfig, ModuleBStats};
61
62    const ACK_SUCCESS: u8 = 1;
63
64    pub fn run_parent(config: ModuleBConfig) -> Result<ModuleBStats> {
65        let memfd_len = usize::try_from(config.memfd_bytes)
66            .context("memfd_bytes exceeds usize on this platform")?;
67        if memfd_len == 0 {
68            anyhow::bail!("memfd_bytes must be > 0");
69        }
70        if config.marker.is_empty() {
71            anyhow::bail!("marker must not be empty");
72        }
73        if config.marker.len() > memfd_len {
74            anyhow::bail!("marker length must be <= memfd_bytes");
75        }
76
77        let socket_path = config
78            .socket_path
79            .unwrap_or_else(|| default_socket_path(std::process::id()));
80        if socket_path.exists() {
81            fs::remove_file(&socket_path).with_context(|| {
82                format!(
83                    "failed to remove pre-existing socket path {}",
84                    socket_path.display()
85                )
86            })?;
87        }
88
89        let listener = UnixListener::bind(&socket_path)
90            .with_context(|| format!("failed to bind UDS listener at {}", socket_path.display()))?;
91
92        let memfd_name = CString::new("tracer-bullet-memfd").context("invalid memfd name")?;
93        let memfd = memfd_create(memfd_name.as_c_str(), MFdFlags::MFD_CLOEXEC)
94            .context("memfd_create failed")?;
95        ftruncate(&memfd, config.memfd_bytes as i64).context("ftruncate(memfd) failed")?;
96
97        let mut writable_map = MmapGuard::map(&memfd, memfd_len, true)
98            .context("failed to mmap writable sender memory")?;
99        write_marker(writable_map.as_mut_slice(), config.marker.as_bytes())
100            .context("failed to write validation marker into memfd")?;
101
102        let exe = std::env::current_exe().context("failed to locate current executable")?;
103        let mut child = Command::new(exe)
104            .arg("module-b-receiver-internal")
105            .arg("--socket-path")
106            .arg(&socket_path)
107            .arg("--marker")
108            .arg(&config.marker)
109            .arg("--memfd-bytes")
110            .arg(config.memfd_bytes.to_string())
111            .spawn()
112            .context("failed to spawn module-b receiver child process")?;
113
114        let (mut stream, _) = listener
115            .accept()
116            .context("receiver failed to connect to UDS")?;
117
118        let start = Instant::now();
119        send_fd(stream.as_raw_fd(), memfd.as_raw_fd()).context("sendmsg SCM_RIGHTS failed")?;
120
121        let mut ack = [0_u8; 1];
122        stream
123            .read_exact(&mut ack)
124            .context("failed to read receiver ACK")?;
125        let handoff_latency_ms = start.elapsed().as_secs_f64() * 1_000.0;
126
127        let status = child
128            .wait()
129            .context("failed waiting for receiver child process")?;
130        if !status.success() {
131            anyhow::bail!("receiver child exited unsuccessfully: {status}");
132        }
133
134        drop(writable_map);
135
136        if socket_path.exists() {
137            let _ = fs::remove_file(&socket_path);
138        }
139
140        if ack[0] != ACK_SUCCESS {
141            anyhow::bail!("receiver returned non-success ACK byte: {}", ack[0]);
142        }
143
144        Ok(ModuleBStats {
145            memfd_bytes: config.memfd_bytes,
146            handoff_latency_ms,
147            marker_ok: true,
148            socket_path,
149        })
150    }
151
152    pub fn run_receiver_internal(config: ModuleBReceiverConfig) -> Result<()> {
153        let memfd_len = usize::try_from(config.memfd_bytes)
154            .context("memfd_bytes exceeds usize on this platform")?;
155        if memfd_len == 0 {
156            anyhow::bail!("memfd_bytes must be > 0");
157        }
158
159        let mut stream = connect_with_retry(&config.socket_path, 200, Duration::from_millis(10))
160            .with_context(|| {
161                format!(
162                    "failed to connect to sender socket {}",
163                    config.socket_path.display()
164                )
165            })?;
166
167        let received_fd = recv_fd(stream.as_raw_fd()).context("recvmsg SCM_RIGHTS failed")?;
168        let mapped = MmapGuard::map(&received_fd, memfd_len, false)
169            .context("receiver failed to mmap read-only memfd")?;
170
171        validate_marker(mapped.as_slice(), config.marker.as_bytes())
172            .context("receiver marker validation failed")?;
173
174        stream
175            .write_all(&[ACK_SUCCESS])
176            .context("failed to write ACK to sender")?;
177        stream.flush().context("failed to flush ACK to sender")?;
178
179        Ok(())
180    }
181
182    fn default_socket_path(pid: u32) -> PathBuf {
183        PathBuf::from(format!("/tmp/tracer-bullet-{}.sock", pid))
184    }
185
186    fn send_fd(socket_fd: RawFd, fd_to_send: RawFd) -> Result<()> {
187        let payload = [0xAB_u8];
188        let iov = [IoSlice::new(&payload)];
189        let fds = [fd_to_send];
190        let cmsg = [ControlMessage::ScmRights(&fds)];
191
192        let bytes = sendmsg::<()>(socket_fd, &iov, &cmsg, MsgFlags::empty(), None)
193            .context("sendmsg failed")?;
194        if bytes != payload.len() {
195            anyhow::bail!("sendmsg wrote {} bytes instead of {}", bytes, payload.len());
196        }
197
198        Ok(())
199    }
200
201    fn recv_fd(socket_fd: RawFd) -> Result<OwnedFd> {
202        let mut payload = [0_u8; 1];
203        let mut iov = [IoSliceMut::new(&mut payload)];
204        let mut cmsgspace = cmsg_space!([RawFd; 1]);
205
206        let msg = recvmsg::<()>(socket_fd, &mut iov, Some(&mut cmsgspace), MsgFlags::empty())
207            .context("recvmsg failed")?;
208
209        if msg.bytes == 0 {
210            anyhow::bail!("recvmsg received no payload bytes");
211        }
212        if msg
213            .flags
214            .intersects(MsgFlags::MSG_CTRUNC | MsgFlags::MSG_TRUNC)
215        {
216            anyhow::bail!("recvmsg control/payload data was truncated");
217        }
218
219        for cmsg in msg
220            .cmsgs()
221            .context("failed to parse recvmsg control messages")?
222        {
223            if let ControlMessageOwned::ScmRights(fds) = cmsg {
224                if let Some(fd) = fds.first().copied() {
225                    let owned = unsafe { OwnedFd::from_raw_fd(fd) };
226                    return Ok(owned);
227                }
228            }
229        }
230
231        anyhow::bail!("no SCM_RIGHTS file descriptor received")
232    }
233
234    fn connect_with_retry(path: &Path, attempts: usize, sleep: Duration) -> Result<UnixStream> {
235        let mut last_err = None;
236        for _ in 0..attempts {
237            match UnixStream::connect(path) {
238                Ok(stream) => return Ok(stream),
239                Err(err) => {
240                    last_err = Some(err);
241                    thread::sleep(sleep);
242                }
243            }
244        }
245
246        match last_err {
247            Some(err) => {
248                Err(err).with_context(|| format!("unable to connect to {}", path.display()))
249            }
250            None => anyhow::bail!("unable to connect to {}", path.display()),
251        }
252    }
253
254    fn write_marker(mapped: &mut [u8], marker: &[u8]) -> Result<()> {
255        if marker.len() > mapped.len() {
256            anyhow::bail!("marker longer than mapped memory");
257        }
258        let start = mapped.len() - marker.len();
259        mapped[start..].copy_from_slice(marker);
260        Ok(())
261    }
262
263    fn validate_marker(mapped: &[u8], marker: &[u8]) -> Result<()> {
264        if marker.len() > mapped.len() {
265            anyhow::bail!("marker longer than mapped memory");
266        }
267        let start = mapped.len() - marker.len();
268        if mapped[start..] != *marker {
269            anyhow::bail!("marker mismatch at memfd tail");
270        }
271        Ok(())
272    }
273
274    struct MmapGuard {
275        ptr: NonNull<std::ffi::c_void>,
276        len: usize,
277    }
278
279    impl MmapGuard {
280        fn map<Fd: AsFd>(fd: &Fd, len: usize, writable: bool) -> Result<Self> {
281            let nz_len = NonZeroUsize::new(len).context("mmap length must be > 0")?;
282            let prot = if writable {
283                ProtFlags::PROT_READ | ProtFlags::PROT_WRITE
284            } else {
285                ProtFlags::PROT_READ
286            };
287
288            let ptr = unsafe { mmap(None, nz_len, prot, MapFlags::MAP_SHARED, fd, 0) }
289                .context("mmap syscall failed")?;
290
291            Ok(Self { ptr, len })
292        }
293
294        fn as_slice(&self) -> &[u8] {
295            unsafe { std::slice::from_raw_parts(self.ptr.as_ptr() as *const u8, self.len) }
296        }
297
298        fn as_mut_slice(&mut self) -> &mut [u8] {
299            unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr() as *mut u8, self.len) }
300        }
301    }
302
303    impl Drop for MmapGuard {
304        fn drop(&mut self) {
305            let _ = unsafe { munmap(self.ptr, self.len) };
306        }
307    }
308
309    #[cfg(test)]
310    mod tests {
311        use std::io::{Read, Write};
312        use std::os::fd::AsRawFd;
313        use std::os::unix::net::UnixStream;
314        use std::thread;
315
316        use super::*;
317
318        #[test]
319        fn transfers_fd_and_validates_marker() {
320            let memfd_name = CString::new("module-b-test").expect("valid test memfd name");
321            let memfd = memfd_create(memfd_name.as_c_str(), MFdFlags::MFD_CLOEXEC)
322                .expect("memfd_create should succeed in test");
323
324            let len = 8 * 1024 * 1024;
325            let marker = b"TRACER_BULLET_SUCCESS";
326            ftruncate(&memfd, len as i64).expect("ftruncate should succeed in test");
327
328            let mut map = MmapGuard::map(&memfd, len, true).expect("mmap should succeed in test");
329            write_marker(map.as_mut_slice(), marker).expect("writing marker should succeed");
330
331            let (mut send_stream, recv_stream) =
332                UnixStream::pair().expect("socketpair should succeed");
333
334            let marker_vec = marker.to_vec();
335            let join = thread::spawn(move || {
336                let fd =
337                    recv_fd(recv_stream.as_raw_fd()).expect("recv_fd should succeed in thread");
338                let mapped = MmapGuard::map(&fd, len, false).expect("receiver mmap should succeed");
339                validate_marker(mapped.as_slice(), &marker_vec)
340                    .expect("marker should validate in receiver");
341                let mut stream = recv_stream;
342                stream
343                    .write_all(&[ACK_SUCCESS])
344                    .expect("thread ACK write should succeed");
345            });
346
347            send_fd(send_stream.as_raw_fd(), memfd.as_raw_fd()).expect("send_fd should succeed");
348            let mut ack = [0_u8; 1];
349            send_stream
350                .read_exact(&mut ack)
351                .expect("ACK read should succeed in sender");
352            assert_eq!(ack, [ACK_SUCCESS]);
353
354            join.join()
355                .expect("receiver thread should join successfully");
356        }
357
358        #[test]
359        #[ignore = "heavy memory mapping validation"]
360        fn validates_full_5gb_mapping_without_copy() {
361            let memfd_name = CString::new("module-b-test-5gb").expect("valid test memfd name");
362            let memfd = memfd_create(memfd_name.as_c_str(), MFdFlags::MFD_CLOEXEC)
363                .expect("memfd_create should succeed in test");
364
365            let len = 5 * 1024 * 1024 * 1024usize;
366            let marker = b"TRACER_BULLET_SUCCESS";
367            ftruncate(&memfd, len as i64).expect("ftruncate should succeed in test");
368
369            let mut map = MmapGuard::map(&memfd, len, true).expect("mmap should succeed in test");
370            write_marker(map.as_mut_slice(), marker).expect("write marker should succeed");
371            validate_marker(map.as_slice(), marker).expect("validate marker should succeed");
372        }
373    }
374}
375
376#[cfg(target_os = "linux")]
377pub fn run_parent(config: ModuleBConfig) -> Result<ModuleBStats> {
378    linux::run_parent(config)
379}
380
381#[cfg(target_os = "linux")]
382pub fn run_receiver_internal(config: ModuleBReceiverConfig) -> Result<()> {
383    linux::run_receiver_internal(config)
384}
385
/// Fallback for non-Linux targets. Always fails, because module B's implementation is compiled
/// only for Linux.
#[cfg(not(target_os = "linux"))]
pub fn run_parent(_config: ModuleBConfig) -> Result<ModuleBStats> {
    Err(anyhow::anyhow!("module-b requires Linux for memfd + SCM_RIGHTS"))
}
390
/// Fallback for non-Linux targets. Always fails, because module B's receiver is compiled only
/// for Linux.
#[cfg(not(target_os = "linux"))]
pub fn run_receiver_internal(_config: ModuleBReceiverConfig) -> Result<()> {
    Err(anyhow::anyhow!("module-b receiver requires Linux"))
}