/// Descriptors that travel with one encoded frame, separately from its bytes:
/// returned by `collect_fds` and handed back to `provide_fds` for decoding.
#[cfg(unix)]
pub type FrameFds = Vec<std::os::fd::OwnedFd>;
/// On targets without Unix descriptor passing a frame never carries
/// descriptors, so the set is just `()`.
#[cfg(not(unix))]
pub type FrameFds = ();
#[cfg(not(unix))]
mod portable {
    //! Fallback for targets without Unix descriptor passing: both entry points
    //! simply run the closure, and a frame never carries descriptors.

    /// Runs `f` and pairs its result with the (always empty) unit descriptor set.
    pub fn collect_fds<R>(f: impl FnOnce() -> R) -> (R, super::FrameFds) {
        let result = f();
        (result, ())
    }

    /// Runs `f`; the unit descriptor set has nothing to install.
    pub fn provide_fds<R>(_fds: super::FrameFds, f: impl FnOnce() -> R) -> R {
        f()
    }
}
#[cfg(not(unix))]
pub use portable::{collect_fds, provide_fds};
/// Returns how many descriptors `fds` carries.
#[cfg(unix)]
pub fn frame_fds_len(fds: &FrameFds) -> usize {
    let descriptors: &[std::os::fd::OwnedFd] = fds;
    descriptors.len()
}
/// Returns how many descriptors `fds` carries: always zero, because targets
/// without Unix descriptor passing represent a frame's descriptors as `()`.
#[cfg(not(unix))]
pub fn frame_fds_len(_fds: &FrameFds) -> usize {
    0
}
#[cfg(unix)]
mod unix {
use std::cell::{Cell, RefCell};
use std::os::fd::{AsFd, AsRawFd, BorrowedFd, IntoRawFd, OwnedFd, RawFd};
use facet::{Facet, FacetOpaqueAdapter, OpaqueDeserialize, OpaqueSerialize, PtrConst};
use super::FrameFds;
/// Per-frame descriptor limit exported for callers. Presumably mirrors the Linux
/// kernel's `SCM_MAX_FD` (max descriptors in one `SCM_RIGHTS` message).
/// NOTE(review): nothing in this module enforces it (`collect_fd` never checks
/// the count) — confirm the sending side does.
pub const SCM_MAX_FD: usize = 253;
/// Sentinel wire index meaning "no descriptor was collected for this `Fd`":
/// written when no collector is installed, when the `Fd` holds no descriptor,
/// or when duplicating the descriptor fails. `take_fd` rejects it on decode.
const NOT_COLLECTED: u32 = u32::MAX;
/// A file descriptor that can be embedded in a facet-serialized message.
///
/// Encoding (via `FdAdapter`) writes only a `u32` index into the frame's
/// descriptor list gathered by `collect_fds`; decoding claims the descriptor at
/// that index from the list installed by `provide_fds`.
#[derive(Facet)]
#[facet(opaque = FdAdapter, traits(Debug))]
pub struct Fd {
// Held in a `Cell<Option<_>>` so code that only has `&Fd` (`as_raw_fd`, the
// serializer) can borrow the descriptor by taking it out and putting it back.
inner: Cell<Option<OwnedFd>>,
// `NOT_COLLECTED` after `new`; set by `FdAdapter::serialize_map` to the index
// it assigned, and by `deserialize_build` to the decoded index. The serializer
// encodes this field's value in place of the descriptor.
wire_index: Cell<u32>,
}
impl Fd {
    /// Wraps an owned descriptor (anything convertible into `OwnedFd`, such as a
    /// `File`). The wire index starts out as `NOT_COLLECTED`.
    pub fn new(fd: impl Into<OwnedFd>) -> Self {
        let owned: OwnedFd = fd.into();
        Self {
            wire_index: Cell::new(NOT_COLLECTED),
            inner: Cell::new(Some(owned)),
        }
    }

    /// Returns the raw descriptor number without giving up ownership, or `None`
    /// when no descriptor is held.
    pub fn as_raw_fd(&self) -> Option<RawFd> {
        // `Cell` cannot hand out a shared borrow of its contents, so move the
        // descriptor out, read its number, and put it straight back.
        let held = self.inner.take();
        let raw = held.as_ref().map(AsRawFd::as_raw_fd);
        self.inner.set(held);
        raw
    }

    /// Consumes the wrapper and returns the owned descriptor, if any.
    pub fn into_owned_fd(self) -> Option<OwnedFd> {
        self.inner.into_inner()
    }

    /// Consumes the wrapper and releases the descriptor as a bare `RawFd`; the
    /// caller becomes responsible for closing it.
    pub fn into_raw_fd(self) -> Option<RawFd> {
        self.into_owned_fd().map(IntoRawFd::into_raw_fd)
    }
}
impl std::fmt::Debug for Fd {
    /// Renders as `Fd(<raw number>)`, or `Fd("<consumed>")` when no descriptor is held.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple("Fd");
        if let Some(raw) = self.as_raw_fd() {
            tuple.field(&raw);
        } else {
            tuple.field(&"<consumed>");
        }
        tuple.finish()
    }
}
// `Fd` is `Send` automatically: its fields are `Cell<Option<OwnedFd>>` and
// `Cell<u32>`, and `Cell<T>: Send` whenever `T: Send`. Check that at compile time
// instead of asserting it with `unsafe impl Send`, so adding a non-`Send` field
// later becomes a compile error rather than silent unsoundness. (`Fd` stays
// `!Sync` because of the `Cell`s.)
const _: fn() = || {
    fn assert_send<T: Send>() {}
    assert_send::<Fd>();
};
/// Accumulator installed by `collect_fds` for the duration of its closure.
struct FdCollector {
// Duplicated descriptors in wire-index order; handed back as the frame's `FrameFds`.
fds: Vec<OwnedFd>,
// Dedup map from an `Fd`'s address to the index it was already assigned in
// `fds`, so serializing the same `Fd` more than once in a scope collects it once.
seen: std::collections::HashMap<usize, u32>,
}
std::thread_local! {
// Collector of the innermost active `collect_fds` call on this thread; `None`
// outside any such call. Being thread-local, serialization that should gather
// descriptors must run on the same thread that called `collect_fds`.
static FD_COLLECTOR: RefCell<Option<FdCollector>> = const { RefCell::new(None) };
// Descriptor slots of the innermost active `provide_fds` call on this thread;
// `None` outside any such call. `take_fd` empties a slot when an `Fd` claims it.
static FD_SOURCE: RefCell<Option<Vec<Option<OwnedFd>>>> = const { RefCell::new(None) };
}
/// Runs `f` with a fresh descriptor collector installed on this thread and
/// returns `f`'s result together with every descriptor gathered while it ran
/// (duplicated copies, in wire-index order).
///
/// Whatever collector was installed before (an enclosing `collect_fds`) is put
/// back afterwards, including when `f` panics.
pub fn collect_fds<R>(f: impl FnOnce() -> R) -> (R, FrameFds) {
    /// Reinstalls the saved collector when dropped.
    struct Reinstall(Option<FdCollector>);
    impl Drop for Reinstall {
        fn drop(&mut self) {
            let saved = self.0.take();
            FD_COLLECTOR.with(|cell| *cell.borrow_mut() = saved);
        }
    }

    let fresh = FdCollector {
        fds: Vec::new(),
        seen: std::collections::HashMap::new(),
    };
    let saved = FD_COLLECTOR.with(|cell| cell.borrow_mut().replace(fresh));
    let _reinstall = Reinstall(saved);

    let result = f();
    let gathered = FD_COLLECTOR.with(|cell| match cell.borrow_mut().as_mut() {
        Some(collector) => std::mem::take(&mut collector.fds),
        None => Vec::new(),
    });
    (result, gathered)
}
/// Runs `f` with `fds` installed as this thread's descriptor source, so `Fd`
/// values decoded inside `f` can claim their descriptors by wire index.
///
/// Descriptors no `Fd` claimed are dropped (closed) when the previous source is
/// reinstalled after `f` returns, which also happens if `f` panics.
pub fn provide_fds<R>(fds: FrameFds, f: impl FnOnce() -> R) -> R {
    /// Reinstalls the saved source when dropped.
    struct Reinstall(Option<Vec<Option<OwnedFd>>>);
    impl Drop for Reinstall {
        fn drop(&mut self) {
            let saved = self.0.take();
            FD_SOURCE.with(|cell| *cell.borrow_mut() = saved);
        }
    }

    // Every slot starts occupied; `take_fd` empties a slot when it is claimed.
    let slots = fds.into_iter().map(Some).collect::<Vec<_>>();
    let saved = FD_SOURCE.with(|cell| cell.borrow_mut().replace(slots));
    let _reinstall = Reinstall(saved);
    f()
}
/// Records `fd` in this thread's active collector under dedup `key` (the
/// address of the `Fd` being serialized) and returns its wire index.
///
/// The collector keeps a duplicate, so the caller retains its own descriptor.
/// Returns `NOT_COLLECTED`, recording nothing, when no collector is installed
/// or the descriptor cannot be duplicated.
///
/// NOTE(review): dedup is by address alone. Two distinct `Fd`s that occupy the
/// same address at different times inside one `collect_fds` scope (e.g. a
/// temporary dropped and another created in its place) would share an index.
fn collect_fd(key: usize, fd: BorrowedFd<'_>) -> u32 {
    use std::collections::hash_map::Entry;

    FD_COLLECTOR.with(|cell| {
        let mut guard = cell.borrow_mut();
        let Some(collector) = guard.as_mut() else {
            return NOT_COLLECTED;
        };
        match collector.seen.entry(key) {
            Entry::Occupied(known) => *known.get(),
            Entry::Vacant(vacant) => match fd.try_clone_to_owned() {
                Ok(dup) => {
                    let index = collector.fds.len() as u32;
                    collector.fds.push(dup);
                    vacant.insert(index);
                    index
                }
                Err(_) => NOT_COLLECTED,
            },
        }
    })
}
/// Claims the descriptor at wire `index` from this thread's fd source.
///
/// # Errors
/// Returns a message when `index` is `NOT_COLLECTED` (the sender had no
/// descriptor), when no source is installed (called outside `provide_fds`),
/// when `index` is past the end of the source, or when the slot was already
/// claimed by an earlier `Fd`.
fn take_fd(index: u32) -> Result<OwnedFd, String> {
    if index == NOT_COLLECTED {
        return Err("Fd was sent without a descriptor".to_string());
    }
    FD_SOURCE.with(|source| {
        let mut guard = source.borrow_mut();
        let Some(slots) = guard.as_mut() else {
            return Err("Fd decoded with no fd source installed".to_string());
        };
        let len = slots.len();
        match slots.get_mut(index as usize) {
            None => Err(format!("Fd wire index {index} out of range ({len})")),
            Some(slot) => slot
                .take()
                .ok_or_else(|| format!("Fd wire index {index} already claimed")),
        }
    })
}
/// Facet opaque adapter that encodes an `Fd` as a `u32` index into the frame's
/// descriptor list rather than as the descriptor itself.
pub struct FdAdapter;
impl FacetOpaqueAdapter for FdAdapter {
    type Error = String;
    type SendValue<'a> = Fd;
    type RecvValue<'de> = Fd;

    /// Registers `value`'s descriptor with the active collector (if any),
    /// records the resulting index in `value.wire_index`, and hands facet a
    /// pointer to that `u32` to encode.
    fn serialize_map(value: &Self::SendValue<'_>) -> OpaqueSerialize {
        // Move the descriptor out of the `Cell` to borrow it, then restore it.
        let taken = value.inner.take();
        let idx = match taken.as_ref() {
            // The `Fd`'s address is the dedup key, so serializing the same value
            // again within one `collect_fds` scope reuses its slot.
            Some(owned) => collect_fd(value as *const Fd as usize, owned.as_fd()),
            None => NOT_COLLECTED,
        };
        value.inner.set(taken);
        value.wire_index.set(idx);
        // The pointer targets a field of `value`; presumably facet reads it while
        // it still holds the `&Fd` passed in here — confirm against facet's contract.
        OpaqueSerialize {
            ptr: PtrConst::new(value.wire_index.as_ptr().cast::<u8>()),
            shape: <u32 as Facet>::SHAPE,
        }
    }

    /// Decodes the varint wire index and claims the matching descriptor from
    /// the source installed by `provide_fds`.
    ///
    /// # Errors
    /// Fails when the varint is malformed, when its value does not fit in `u32`,
    /// or when `take_fd` rejects the index.
    fn deserialize_build<'de>(
        input: OpaqueDeserialize<'de>,
    ) -> Result<Self::RecvValue<'de>, Self::Error> {
        let bytes = match &input {
            OpaqueDeserialize::Borrowed(b) => *b,
            OpaqueDeserialize::Owned(b) => b.as_slice(),
        };
        let mut cursor = vox_postcard::decode::Cursor::new(bytes);
        let raw = cursor
            .read_varint()
            .map_err(|e| format!("Fd index varint: {e}"))?;
        // The index comes from the peer: reject out-of-range values instead of
        // truncating with `as`, which would wrap e.g. 2^32 onto slot 0 and
        // hand out a real descriptor.
        let index =
            u32::try_from(raw).map_err(|_| format!("Fd index {raw} does not fit in u32"))?;
        let owned = take_fd(index)?;
        Ok(Fd {
            inner: Cell::new(Some(owned)),
            wire_index: Cell::new(index),
        })
    }
}
#[cfg(test)]
mod tests {
// Round-trip tests for `Fd` through `vox_postcard`, driven by `collect_fds` /
// `provide_fds` from this module.
use super::*;
use std::io::{Read, Seek, Write};
/// Returns a read/write handle pre-filled with `seed` and rewound to the start.
///
/// The path (temp dir + process id + current nanoseconds) is unlinked right
/// after opening (best-effort: the removal result is ignored), so the data is
/// reachable only through the returned handle.
fn temp_file_with(seed: &[u8]) -> std::fs::File {
let mut path = std::env::temp_dir();
let nanos = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos();
path.push(format!("vox-fd-test-{}-{nanos}", std::process::id()));
let mut f = std::fs::OpenOptions::new()
.create(true)
.read(true)
.write(true)
.truncate(true)
.open(&path)
.unwrap();
let _ = std::fs::remove_file(&path);
f.write_all(seed).unwrap();
f.rewind().unwrap();
f
}
#[test]
fn fd_round_trips_through_postcard() {
let file = temp_file_with(b"vox-fd-payload");
let msg = Fd::new(OwnedFd::from(file));
let (bytes, collected) = collect_fds(|| vox_postcard::to_vec(&msg).expect("encode"));
// The single `Fd` yields exactly one collected (duplicated) descriptor.
assert_eq!(collected.len(), 1, "one fd collected");
// Encoded bytes as pinned here: a 4-byte little-endian length of 1 followed
// by the varint wire index 0 (presumably vox_postcard's opaque-payload
// framing — confirm).
assert_eq!(&bytes[..4], &1u32.to_le_bytes());
assert_eq!(bytes[4], 0);
// Decode with the collected descriptors installed; the `Fd` claims index 0.
let decoded: Fd = provide_fds(collected, || {
vox_postcard::from_slice(&bytes).expect("decode")
});
// Reading the payload back through the decoded descriptor shows it refers
// to the original file.
let mut f = std::fs::File::from(decoded.into_owned_fd().expect("owned fd"));
let mut got = String::new();
f.read_to_string(&mut got).unwrap();
assert_eq!(got, "vox-fd-payload");
}
#[test]
fn missing_source_is_a_clean_error() {
let msg = Fd::new(OwnedFd::from(temp_file_with(b"x")));
// `_fds` stays alive but is never installed via `provide_fds`, so decoding
// below runs with no fd source.
let (bytes, _fds) = collect_fds(|| vox_postcard::to_vec(&msg).unwrap());
// Accepts either an `Err` or a panic from the decoder. NOTE(review): the test
// name promises a clean error; if decoding reliably returns `Err`, the panic
// allowance could be removed — confirm.
let r = std::panic::catch_unwind(|| vox_postcard::from_slice::<Fd>(&bytes));
assert!(
r.is_err() || r.unwrap().is_err(),
"decoding an Fd with no source must fail"
);
}
}
}
#[cfg(unix)]
pub use unix::{Fd, FdAdapter, SCM_MAX_FD, collect_fds, provide_fds};