1use std::path::PathBuf;
2
3use anyhow::Result;
4
/// Settings for the sending (parent) side of the module-B memfd handoff.
#[derive(Debug, Clone)]
pub struct ModuleBConfig {
    /// Size in bytes of the memfd the parent creates; `run_parent` rejects zero.
    pub memfd_bytes: u64,
    /// Unix-domain socket path to bind. When `None`, `run_parent` derives a
    /// path from the current process id.
    pub socket_path: Option<PathBuf>,
    /// Marker text written at the tail of the memfd; `run_parent` requires it
    /// to be non-empty and no longer than `memfd_bytes`.
    pub marker: String,
}
11
/// Settings for the receiving (child) side of the module-B memfd handoff.
#[derive(Debug, Clone)]
pub struct ModuleBReceiverConfig {
    /// Unix-domain socket path the receiver connects to.
    pub socket_path: PathBuf,
    /// Marker text the receiver expects to find at the tail of the mapped memfd.
    pub marker: String,
    /// Number of bytes of the received memfd that the receiver maps.
    pub memfd_bytes: u64,
}
18
/// Result of a module-B handoff as reported by `run_parent`.
#[derive(Debug, Clone)]
pub struct ModuleBStats {
    /// Size in bytes of the memfd that was handed over.
    pub memfd_bytes: u64,
    /// Milliseconds between sending the descriptor and reading the receiver's ACK.
    pub handoff_latency_ms: f64,
    /// Whether the receiver acknowledged the marker.
    pub marker_ok: bool,
    /// Unix-domain socket path that was used for the handoff.
    pub socket_path: PathBuf,
}

impl ModuleBStats {
    /// Renders the stats as a single-line JSON object.
    ///
    /// The socket path is escaped as a JSON string, so quotes, backslashes and
    /// control characters in the path cannot break the document. A non-finite
    /// latency is emitted as `null` because JSON has no `NaN`/`Infinity`
    /// literals. Non-UTF-8 path bytes are rendered lossily, as `Path::display`
    /// does.
    pub fn to_json(&self) -> String {
        let latency = if self.handoff_latency_ms.is_finite() {
            format!("{:.3}", self.handoff_latency_ms)
        } else {
            String::from("null")
        };

        format!(
            "{{\"module\":\"B\",\"memfd_bytes\":{},\"handoff_latency_ms\":{},\"marker_ok\":{},\"socket_path\":\"{}\"}}",
            self.memfd_bytes,
            latency,
            self.marker_ok,
            Self::escape_json(&self.socket_path.display().to_string())
        )
    }

    /// Escapes `raw` so it can be placed between double quotes in a JSON document.
    fn escape_json(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' => escaped.push_str("\\\""),
                '\\' => escaped.push_str("\\\\"),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\t' => escaped.push_str("\\t"),
                // Remaining control characters must use the \uXXXX form.
                c if u32::from(c) < 0x20 => {
                    escaped.push_str(&format!("\\u{:04x}", u32::from(c)));
                }
                c => escaped.push(c),
            }
        }
        escaped
    }
}
38
39#[cfg(target_os = "linux")]
40mod linux {
41 use std::ffi::CString;
42 use std::fs;
43 use std::io::{IoSlice, IoSliceMut, Read, Write};
44 use std::num::NonZeroUsize;
45 use std::os::fd::{AsFd, AsRawFd, FromRawFd, OwnedFd, RawFd};
46 use std::os::unix::net::{UnixListener, UnixStream};
47 use std::path::{Path, PathBuf};
48 use std::process::Command;
49 use std::ptr::NonNull;
50 use std::thread;
51 use std::time::{Duration, Instant};
52
53 use anyhow::{Context, Result};
54 use nix::cmsg_space;
55 use nix::sys::memfd::{memfd_create, MFdFlags};
56 use nix::sys::mman::{mmap, munmap, MapFlags, ProtFlags};
57 use nix::sys::socket::{recvmsg, sendmsg, ControlMessage, ControlMessageOwned, MsgFlags};
58 use nix::unistd::ftruncate;
59
60 use crate::module_b::{ModuleBConfig, ModuleBReceiverConfig, ModuleBStats};
61
/// One-byte acknowledgement the receiver writes back once the marker has validated.
const ACK_SUCCESS: u8 = 0x01;
63
64 pub fn run_parent(config: ModuleBConfig) -> Result<ModuleBStats> {
65 let memfd_len = usize::try_from(config.memfd_bytes)
66 .context("memfd_bytes exceeds usize on this platform")?;
67 if memfd_len == 0 {
68 anyhow::bail!("memfd_bytes must be > 0");
69 }
70 if config.marker.is_empty() {
71 anyhow::bail!("marker must not be empty");
72 }
73 if config.marker.len() > memfd_len {
74 anyhow::bail!("marker length must be <= memfd_bytes");
75 }
76
77 let socket_path = config
78 .socket_path
79 .unwrap_or_else(|| default_socket_path(std::process::id()));
80 if socket_path.exists() {
81 fs::remove_file(&socket_path).with_context(|| {
82 format!(
83 "failed to remove pre-existing socket path {}",
84 socket_path.display()
85 )
86 })?;
87 }
88
89 let listener = UnixListener::bind(&socket_path)
90 .with_context(|| format!("failed to bind UDS listener at {}", socket_path.display()))?;
91
92 let memfd_name = CString::new("tracer-bullet-memfd").context("invalid memfd name")?;
93 let memfd = memfd_create(memfd_name.as_c_str(), MFdFlags::MFD_CLOEXEC)
94 .context("memfd_create failed")?;
95 ftruncate(&memfd, config.memfd_bytes as i64).context("ftruncate(memfd) failed")?;
96
97 let mut writable_map = MmapGuard::map(&memfd, memfd_len, true)
98 .context("failed to mmap writable sender memory")?;
99 write_marker(writable_map.as_mut_slice(), config.marker.as_bytes())
100 .context("failed to write validation marker into memfd")?;
101
102 let exe = std::env::current_exe().context("failed to locate current executable")?;
103 let mut child = Command::new(exe)
104 .arg("module-b-receiver-internal")
105 .arg("--socket-path")
106 .arg(&socket_path)
107 .arg("--marker")
108 .arg(&config.marker)
109 .arg("--memfd-bytes")
110 .arg(config.memfd_bytes.to_string())
111 .spawn()
112 .context("failed to spawn module-b receiver child process")?;
113
114 let (mut stream, _) = listener
115 .accept()
116 .context("receiver failed to connect to UDS")?;
117
118 let start = Instant::now();
119 send_fd(stream.as_raw_fd(), memfd.as_raw_fd()).context("sendmsg SCM_RIGHTS failed")?;
120
121 let mut ack = [0_u8; 1];
122 stream
123 .read_exact(&mut ack)
124 .context("failed to read receiver ACK")?;
125 let handoff_latency_ms = start.elapsed().as_secs_f64() * 1_000.0;
126
127 let status = child
128 .wait()
129 .context("failed waiting for receiver child process")?;
130 if !status.success() {
131 anyhow::bail!("receiver child exited unsuccessfully: {status}");
132 }
133
134 drop(writable_map);
135
136 if socket_path.exists() {
137 let _ = fs::remove_file(&socket_path);
138 }
139
140 if ack[0] != ACK_SUCCESS {
141 anyhow::bail!("receiver returned non-success ACK byte: {}", ack[0]);
142 }
143
144 Ok(ModuleBStats {
145 memfd_bytes: config.memfd_bytes,
146 handoff_latency_ms,
147 marker_ok: true,
148 socket_path,
149 })
150 }
151
152 pub fn run_receiver_internal(config: ModuleBReceiverConfig) -> Result<()> {
153 let memfd_len = usize::try_from(config.memfd_bytes)
154 .context("memfd_bytes exceeds usize on this platform")?;
155 if memfd_len == 0 {
156 anyhow::bail!("memfd_bytes must be > 0");
157 }
158
159 let mut stream = connect_with_retry(&config.socket_path, 200, Duration::from_millis(10))
160 .with_context(|| {
161 format!(
162 "failed to connect to sender socket {}",
163 config.socket_path.display()
164 )
165 })?;
166
167 let received_fd = recv_fd(stream.as_raw_fd()).context("recvmsg SCM_RIGHTS failed")?;
168 let mapped = MmapGuard::map(&received_fd, memfd_len, false)
169 .context("receiver failed to mmap read-only memfd")?;
170
171 validate_marker(mapped.as_slice(), config.marker.as_bytes())
172 .context("receiver marker validation failed")?;
173
174 stream
175 .write_all(&[ACK_SUCCESS])
176 .context("failed to write ACK to sender")?;
177 stream.flush().context("failed to flush ACK to sender")?;
178
179 Ok(())
180 }
181
/// Builds the fallback socket location for the process with id `pid`.
fn default_socket_path(pid: u32) -> PathBuf {
    let location: String = format!("/tmp/tracer-bullet-{}.sock", pid);
    location.into()
}
185
186 fn send_fd(socket_fd: RawFd, fd_to_send: RawFd) -> Result<()> {
187 let payload = [0xAB_u8];
188 let iov = [IoSlice::new(&payload)];
189 let fds = [fd_to_send];
190 let cmsg = [ControlMessage::ScmRights(&fds)];
191
192 let bytes = sendmsg::<()>(socket_fd, &iov, &cmsg, MsgFlags::empty(), None)
193 .context("sendmsg failed")?;
194 if bytes != payload.len() {
195 anyhow::bail!("sendmsg wrote {} bytes instead of {}", bytes, payload.len());
196 }
197
198 Ok(())
199 }
200
201 fn recv_fd(socket_fd: RawFd) -> Result<OwnedFd> {
202 let mut payload = [0_u8; 1];
203 let mut iov = [IoSliceMut::new(&mut payload)];
204 let mut cmsgspace = cmsg_space!([RawFd; 1]);
205
206 let msg = recvmsg::<()>(socket_fd, &mut iov, Some(&mut cmsgspace), MsgFlags::empty())
207 .context("recvmsg failed")?;
208
209 if msg.bytes == 0 {
210 anyhow::bail!("recvmsg received no payload bytes");
211 }
212 if msg
213 .flags
214 .intersects(MsgFlags::MSG_CTRUNC | MsgFlags::MSG_TRUNC)
215 {
216 anyhow::bail!("recvmsg control/payload data was truncated");
217 }
218
219 for cmsg in msg
220 .cmsgs()
221 .context("failed to parse recvmsg control messages")?
222 {
223 if let ControlMessageOwned::ScmRights(fds) = cmsg {
224 if let Some(fd) = fds.first().copied() {
225 let owned = unsafe { OwnedFd::from_raw_fd(fd) };
226 return Ok(owned);
227 }
228 }
229 }
230
231 anyhow::bail!("no SCM_RIGHTS file descriptor received")
232 }
233
234 fn connect_with_retry(path: &Path, attempts: usize, sleep: Duration) -> Result<UnixStream> {
235 let mut last_err = None;
236 for _ in 0..attempts {
237 match UnixStream::connect(path) {
238 Ok(stream) => return Ok(stream),
239 Err(err) => {
240 last_err = Some(err);
241 thread::sleep(sleep);
242 }
243 }
244 }
245
246 match last_err {
247 Some(err) => {
248 Err(err).with_context(|| format!("unable to connect to {}", path.display()))
249 }
250 None => anyhow::bail!("unable to connect to {}", path.display()),
251 }
252 }
253
254 fn write_marker(mapped: &mut [u8], marker: &[u8]) -> Result<()> {
255 if marker.len() > mapped.len() {
256 anyhow::bail!("marker longer than mapped memory");
257 }
258 let start = mapped.len() - marker.len();
259 mapped[start..].copy_from_slice(marker);
260 Ok(())
261 }
262
263 fn validate_marker(mapped: &[u8], marker: &[u8]) -> Result<()> {
264 if marker.len() > mapped.len() {
265 anyhow::bail!("marker longer than mapped memory");
266 }
267 let start = mapped.len() - marker.len();
268 if mapped[start..] != *marker {
269 anyhow::bail!("marker mismatch at memfd tail");
270 }
271 Ok(())
272 }
273
274 struct MmapGuard {
275 ptr: NonNull<std::ffi::c_void>,
276 len: usize,
277 }
278
279 impl MmapGuard {
280 fn map<Fd: AsFd>(fd: &Fd, len: usize, writable: bool) -> Result<Self> {
281 let nz_len = NonZeroUsize::new(len).context("mmap length must be > 0")?;
282 let prot = if writable {
283 ProtFlags::PROT_READ | ProtFlags::PROT_WRITE
284 } else {
285 ProtFlags::PROT_READ
286 };
287
288 let ptr = unsafe { mmap(None, nz_len, prot, MapFlags::MAP_SHARED, fd, 0) }
289 .context("mmap syscall failed")?;
290
291 Ok(Self { ptr, len })
292 }
293
294 fn as_slice(&self) -> &[u8] {
295 unsafe { std::slice::from_raw_parts(self.ptr.as_ptr() as *const u8, self.len) }
296 }
297
298 fn as_mut_slice(&mut self) -> &mut [u8] {
299 unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr() as *mut u8, self.len) }
300 }
301 }
302
303 impl Drop for MmapGuard {
304 fn drop(&mut self) {
305 let _ = unsafe { munmap(self.ptr, self.len) };
306 }
307 }
308
#[cfg(test)]
mod tests {
    use std::io::{Read, Write};
    use std::os::fd::AsRawFd;
    use std::os::unix::net::UnixStream;
    use std::thread;

    use super::*;

    /// Hands a marked 8 MiB memfd to a receiver thread over a socketpair and
    /// expects the thread to validate the marker and acknowledge it.
    #[test]
    fn transfers_fd_and_validates_marker() {
        let marker = b"TRACER_BULLET_SUCCESS";
        let len = 8 * 1024 * 1024;

        let memfd_name = CString::new("module-b-test").expect("valid test memfd name");
        let memfd = memfd_create(memfd_name.as_c_str(), MFdFlags::MFD_CLOEXEC)
            .expect("memfd_create should succeed in test");
        ftruncate(&memfd, len as i64).expect("ftruncate should succeed in test");

        // The sender mapping stays alive until the end of the test.
        let mut map = MmapGuard::map(&memfd, len, true).expect("mmap should succeed in test");
        write_marker(map.as_mut_slice(), marker).expect("writing marker should succeed");

        let (mut send_stream, recv_stream) =
            UnixStream::pair().expect("socketpair should succeed");

        let expected_marker = marker.to_vec();
        let receiver = thread::spawn(move || {
            let mut peer = recv_stream;
            let fd = recv_fd(peer.as_raw_fd()).expect("recv_fd should succeed in thread");
            let mapped = MmapGuard::map(&fd, len, false).expect("receiver mmap should succeed");
            validate_marker(mapped.as_slice(), &expected_marker)
                .expect("marker should validate in receiver");
            peer.write_all(&[ACK_SUCCESS])
                .expect("thread ACK write should succeed");
        });

        send_fd(send_stream.as_raw_fd(), memfd.as_raw_fd()).expect("send_fd should succeed");

        let mut ack = [0_u8; 1];
        send_stream
            .read_exact(&mut ack)
            .expect("ACK read should succeed in sender");
        assert_eq!(ack, [ACK_SUCCESS]);

        receiver
            .join()
            .expect("receiver thread should join successfully");
    }

    /// Maps a 5 GiB memfd and round-trips the marker at its tail in place.
    #[test]
    #[ignore = "heavy memory mapping validation"]
    fn validates_full_5gb_mapping_without_copy() {
        let marker = b"TRACER_BULLET_SUCCESS";
        let len = 5 * 1024 * 1024 * 1024usize;

        let memfd_name = CString::new("module-b-test-5gb").expect("valid test memfd name");
        let memfd = memfd_create(memfd_name.as_c_str(), MFdFlags::MFD_CLOEXEC)
            .expect("memfd_create should succeed in test");
        ftruncate(&memfd, len as i64).expect("ftruncate should succeed in test");

        let mut map = MmapGuard::map(&memfd, len, true).expect("mmap should succeed in test");
        write_marker(map.as_mut_slice(), marker).expect("write marker should succeed");
        validate_marker(map.as_slice(), marker).expect("validate marker should succeed");
    }
}
374}
375
376#[cfg(target_os = "linux")]
377pub fn run_parent(config: ModuleBConfig) -> Result<ModuleBStats> {
378 linux::run_parent(config)
379}
380
381#[cfg(target_os = "linux")]
382pub fn run_receiver_internal(config: ModuleBReceiverConfig) -> Result<()> {
383 linux::run_receiver_internal(config)
384}
385
/// Non-Linux build of the module-B sender entry point: always returns an error.
#[cfg(not(target_os = "linux"))]
pub fn run_parent(_config: ModuleBConfig) -> Result<ModuleBStats> {
    Err(anyhow::anyhow!(
        "module-b requires Linux for memfd + SCM_RIGHTS"
    ))
}
390
/// Non-Linux build of the module-B receiver entry point: always returns an error.
#[cfg(not(target_os = "linux"))]
pub fn run_receiver_internal(_config: ModuleBReceiverConfig) -> Result<()> {
    Err(anyhow::anyhow!("module-b receiver requires Linux"))
}