use std::{
io::{Error, ErrorKind},
os::unix::io::AsRawFd,
ptr,
};
use crate::socket::SocketResult;
/// Capacity, in bytes, requested for each splice pipe when no override has been configured.
pub const DEFAULT_SPLICE_PIPE_CAPACITY: usize = 65_536;

/// Process-wide override for the requested splice pipe capacity; it can be written at most once.
static PIPE_CAPACITY_OVERRIDE: std::sync::OnceLock<usize> = std::sync::OnceLock::new();

/// Records `bytes` as the capacity to request for splice pipes created from now on.
///
/// A value of `0` is ignored. Only the first non-zero call takes effect: the value lives in a
/// `OnceLock`, so any later call is silently discarded.
pub fn set_pipe_capacity(bytes: usize) {
    if bytes != 0 {
        // `OnceLock::set` returns `Err` once a value is present; first writer wins.
        let _ = PIPE_CAPACITY_OVERRIDE.set(bytes);
    }
}

/// Returns the capacity to request for a new splice pipe.
///
/// This is the configured override if one was set, otherwise [`DEFAULT_SPLICE_PIPE_CAPACITY`].
fn requested_pipe_capacity() -> usize {
    match PIPE_CAPACITY_OVERRIDE.get() {
        Some(&bytes) => bytes,
        None => DEFAULT_SPLICE_PIPE_CAPACITY,
    }
}
/// Two kernel pipes used to move bytes between descriptors with `splice(2)`.
///
/// Each pipe is stored as `[read_end, write_end]`, the layout `pipe2(2)` fills in. All four
/// descriptors are owned by this value and are closed by its `Drop` implementation.
pub struct SplicePipe {
    /// Descriptors of the first ("in") pipe, `[read_end, write_end]`.
    pub in_pipe: [libc::c_int; 2],
    /// Descriptors of the second ("out") pipe, `[read_end, write_end]`.
    pub out_pipe: [libc::c_int; 2],
    /// Pending byte count for `in_pipe`. Starts at 0; presumably maintained by callers.
    pub in_pipe_pending: usize,
    /// Pending byte count for `out_pipe`. Starts at 0; presumably maintained by callers.
    pub out_pipe_pending: usize,
    /// Effective capacity in bytes: the smaller of the two pipes' sizes after resizing.
    pub capacity: usize,
}

impl SplicePipe {
    /// Creates both pipes (non-blocking, close-on-exec) and resizes them to the configured
    /// capacity on a best-effort basis.
    ///
    /// Returns `None` if either pipe cannot be created. On that path no descriptor is left open.
    pub fn new() -> Option<Self> {
        let in_pipe = create_pipe()?;
        let Some(out_pipe) = create_pipe() else {
            // The first pipe is already open; release it so the failure path leaks nothing.
            // SAFETY: both descriptors were just returned by `create_pipe` and are owned only here.
            unsafe {
                libc::close(in_pipe[0]);
                libc::close(in_pipe[1]);
            }
            return None;
        };
        let capacity = apply_pipe_capacity(in_pipe[0], out_pipe[0], requested_pipe_capacity());
        Some(Self {
            in_pipe,
            out_pipe,
            in_pipe_pending: 0,
            out_pipe_pending: 0,
            capacity,
        })
    }
}
/// Resizes both pipes (addressed through their read ends) to `requested` bytes.
///
/// Returns the smaller of the two capacities reported afterwards. The "in" pipe is handled
/// first, then the "out" pipe.
fn apply_pipe_capacity(in_read: libc::c_int, out_read: libc::c_int, requested: usize) -> usize {
    // Function arguments are evaluated left to right, so the call order matches the original.
    std::cmp::min(
        set_and_query_pipe_size(in_read, requested),
        set_and_query_pipe_size(out_read, requested),
    )
}
/// Asks the kernel to resize the pipe behind `fd` to `requested` bytes.
///
/// Returns the capacity reported by `F_GETPIPE_SZ` afterwards.
///
/// Resizing is best-effort. If `F_SETPIPE_SZ` fails, a warning is logged and the pipe keeps
/// its current size. If the size cannot be queried either, [`DEFAULT_SPLICE_PIPE_CAPACITY`]
/// is returned as a fallback.
fn set_and_query_pipe_size(fd: libc::c_int, requested: usize) -> usize {
    // `fcntl` takes a C `int`. Clamp instead of a bare `as` cast, which would silently truncate
    // (or turn negative) any request above `c_int::MAX`. An oversized request is then passed
    // through and rejected by the kernel, and that failure is reported below.
    let requested_arg = requested.min(libc::c_int::MAX as usize) as libc::c_int;
    // SAFETY: F_SETPIPE_SZ takes a plain integer argument; no memory is passed to the kernel.
    let set_ret = unsafe { libc::fcntl(fd, libc::F_SETPIPE_SZ, requested_arg) };
    if set_ret == -1 {
        let err = Error::last_os_error();
        warn!(
            "SPLICE\tF_SETPIPE_SZ({}) on pipe fd({}) failed: {:?}; keeping the kernel default. Lower the requested value or raise /proc/sys/fs/pipe-max-size.",
            requested, fd, err
        );
    }
    // SAFETY: F_GETPIPE_SZ takes no argument and only returns an integer.
    let get_ret = unsafe { libc::fcntl(fd, libc::F_GETPIPE_SZ) };
    if get_ret < 0 {
        DEFAULT_SPLICE_PIPE_CAPACITY
    } else {
        get_ret as usize
    }
}
impl Drop for SplicePipe {
    /// Closes all four pipe descriptors: the "in" pipe's read and write ends, then the "out"
    /// pipe's read and write ends.
    fn drop(&mut self) {
        for &fd in self.in_pipe.iter().chain(self.out_pipe.iter()) {
            // SAFETY: these descriptors were created for and are owned by this `SplicePipe`;
            // `drop` runs once, so each is closed exactly once.
            unsafe {
                libc::close(fd);
            }
        }
    }
}
/// Opens a new pipe whose ends are both non-blocking and close-on-exec.
///
/// Returns `[read_end, write_end]`, or `None` if `pipe2(2)` reports failure.
fn create_pipe() -> Option<[libc::c_int; 2]> {
    let mut ends: [libc::c_int; 2] = [0; 2];
    // SAFETY: `ends` is a writable array of exactly two `c_int`s, which is what pipe2 fills in.
    let rc = unsafe { libc::pipe2(ends.as_mut_ptr(), libc::O_NONBLOCK | libc::O_CLOEXEC) };
    (rc == 0).then_some(ends)
}
/// Moves up to `len` bytes from `fd` into the pipe whose write end is `pipe_write_end`, without
/// copying through user space.
///
/// Returns the number of bytes moved together with a status:
/// - `Continue` when bytes were moved, or immediately (no syscall) when `len` is 0;
/// - `WouldBlock` when the kernel reports the operation would block;
/// - `Closed` when `splice` returns 0 for a non-empty request (end of input);
/// - `Error` for any other failure, which is also logged.
pub fn splice_in(
    fd: &dyn AsRawFd,
    pipe_write_end: libc::c_int,
    len: usize,
) -> (usize, SocketResult) {
    // splice(2) returns 0 straight away for a zero-length request. That 0 would otherwise be
    // misreported as a closed socket (e.g. when the caller has no free pipe space left).
    // This matches the guard in `splice_out`.
    if len == 0 {
        return (0, SocketResult::Continue);
    }
    let raw_fd = fd.as_raw_fd();
    // SAFETY: only integer descriptors and NULL offsets (required for pipes and sockets) are
    // passed; the kernel does not touch any user-space memory here.
    let res = unsafe {
        libc::splice(
            raw_fd,
            ptr::null_mut(),
            pipe_write_end,
            ptr::null_mut(),
            len,
            libc::SPLICE_F_NONBLOCK | libc::SPLICE_F_MOVE,
        )
    };
    match res {
        -1 => {
            let err = Error::last_os_error();
            match err.kind() {
                ErrorKind::WouldBlock => (0, SocketResult::WouldBlock),
                _ => {
                    error!(
                        "SPLICE\terr splicing from fd({}) to pipe({}): {:?}",
                        raw_fd,
                        pipe_write_end,
                        err
                    );
                    (0, SocketResult::Error)
                }
            }
        }
        0 => (0, SocketResult::Closed),
        n => (n as usize, SocketResult::Continue),
    }
}
/// Moves up to `len` bytes from the pipe whose read end is `pipe_read_end` into `fd`, without
/// copying through user space.
///
/// Returns the number of bytes moved together with a status:
/// - `Continue` when bytes were moved, or immediately (no syscall) when `len` is 0;
/// - `WouldBlock` when the kernel reports the operation would block;
/// - `Closed` when `splice` returns 0;
/// - `Error` for any other failure, which is also logged.
pub fn splice_out(
    pipe_read_end: libc::c_int,
    fd: &dyn AsRawFd,
    len: usize,
) -> (usize, SocketResult) {
    if len == 0 {
        return (0, SocketResult::Continue);
    }
    let flags = libc::SPLICE_F_NONBLOCK | libc::SPLICE_F_MOVE;
    // SAFETY: only integer descriptors and NULL offsets (required for pipes and sockets) are
    // passed; the kernel does not touch any user-space memory here.
    let moved = unsafe {
        libc::splice(
            pipe_read_end,
            ptr::null_mut(),
            fd.as_raw_fd(),
            ptr::null_mut(),
            len,
            flags,
        )
    };
    if moved == 0 {
        return (0, SocketResult::Closed);
    }
    if moved != -1 {
        return (moved as usize, SocketResult::Continue);
    }
    // No syscall has run since `splice`, so errno still describes its failure.
    let err = Error::last_os_error();
    if err.kind() == ErrorKind::WouldBlock {
        return (0, SocketResult::WouldBlock);
    }
    error!(
        "SPLICE\terr splicing from pipe({}) to fd({}): {:?}",
        pipe_read_end,
        fd.as_raw_fd(),
        err
    );
    (0, SocketResult::Error)
}
#[cfg(test)]
mod tests {
    use std::{
        io::{Read, Write},
        net::{TcpListener, TcpStream},
        thread,
        time::Duration,
    };
    use super::*;
    /// Round-trips a payload through a pipe using `splice_in` then `splice_out`.
    ///
    /// A "proxy" thread accepts one loopback connection and splices the bytes it receives into
    /// a pipe. It then splices the same number of bytes back out to that connection, so the
    /// client should read back exactly what it wrote.
    #[test]
    fn splice_roundtrip() {
        // Port 0 lets the OS pick a free ephemeral port on loopback.
        let proxy_listener = TcpListener::bind("127.0.0.1:0").expect("bind proxy");
        let proxy_addr = proxy_listener.local_addr().expect("local_addr");
        // Raw pipe descriptors (not a `SplicePipe`), so they are closed manually at the end.
        let pipe = create_pipe().expect("create_pipe");
        let pipe_thread = thread::spawn(move || {
            let (conn, _) = proxy_listener.accept().expect("accept");
            // NOTE(review): presumably a guard against hanging; whether the read timeout
            // affects `splice` on this socket is not established here.
            conn.set_read_timeout(Some(Duration::from_secs(2))).ok();
            let mut moved = 0usize;
            // Poll up to 50 times, 20 ms apart, because the client's data may not have
            // arrived yet when the first `splice_in` runs.
            for _ in 0..50 {
                let (sz, status) = splice_in(&conn, pipe[1], DEFAULT_SPLICE_PIPE_CAPACITY);
                if sz > 0 {
                    moved = sz;
                    assert_eq!(status, SocketResult::Continue);
                    break;
                }
                thread::sleep(Duration::from_millis(20));
            }
            assert!(moved > 0, "splice_in moved 0 bytes");
            // Echo the buffered bytes back out to the same connection.
            let (sz_out, status_out) = splice_out(pipe[0], &conn, moved);
            assert_eq!(sz_out, moved, "splice_out byte count mismatch");
            assert_eq!(status_out, SocketResult::Continue);
            // NOTE(review): if any assertion above panics, these descriptors are not closed.
            unsafe {
                libc::close(pipe[0]);
                libc::close(pipe[1]);
            }
        });
        let mut client = TcpStream::connect(proxy_addr).expect("connect");
        client.set_read_timeout(Some(Duration::from_secs(2))).ok();
        let payload = b"splice test data";
        client.write_all(payload).expect("client write");
        let mut buf = [0u8; 128];
        // A single `read`: this assumes the echoed payload arrives in one chunk.
        let n = client.read(&mut buf).expect("client read");
        assert_eq!(&buf[..n], payload);
        pipe_thread.join().expect("pipe thread");
    }
}