use std::{
io::{Read, Stdin, Write},
os::fd::AsFd,
};
use nix::{
errno::Errno,
sys::sendfile::sendfile64,
unistd::{read, write},
};
use crate::{
compat::{fstatx, STATX_SIZE},
err2no,
fd::SafeOwnedFd,
retry::retry_on_eintr,
};
/// Fills `buf` from `fd`, issuing as many `read(2)` calls as needed.
///
/// Reading stops as soon as either the whole buffer has been filled or a
/// `read(2)` call returns 0 (end of file). Each call goes through
/// `retry_on_eintr`. An empty `buf` returns `Ok(0)` without touching `fd`.
///
/// # Returns
///
/// The number of bytes stored at the front of `buf`.
///
/// # Errors
///
/// Any error returned by `read(2)` (after `retry_on_eintr`) is propagated;
/// `EOVERFLOW` is returned if the running byte count would overflow `usize`.
pub fn read_buf<Fd: AsFd>(fd: Fd, buf: &mut [u8]) -> Result<usize, Errno> {
    let mut filled: usize = 0;
    loop {
        if filled >= buf.len() {
            return Ok(filled);
        }
        let got = retry_on_eintr(|| read(&fd, &mut buf[filled..]))?;
        if got == 0 {
            // End of file before the buffer was full.
            return Ok(filled);
        }
        filled = filled.checked_add(got).ok_or(Errno::EOVERFLOW)?;
    }
}
/// Reads up to the reported size of `fd` into a freshly allocated vector.
///
/// The size comes from `stx_size` as returned by `fstatx(fd, STATX_SIZE)`.
/// Exactly that many bytes are allocated (zero-filled), then `read_buf` fills
/// them starting at the descriptor's current offset. If fewer bytes are
/// available, the vector is truncated to what was actually read. Bytes beyond
/// the reported size are never read.
///
/// When the reported size is 0, an empty vector is returned without issuing
/// any read.
///
/// NOTE(review): descriptors that report a size of 0 but still produce data
/// (e.g. pipes) therefore come back empty — confirm this is what callers want.
///
/// # Errors
///
/// * Errors from `fstatx` or `read_buf` are propagated.
/// * `EOVERFLOW` if the reported size does not fit in `usize`.
/// * `ENOMEM` if the buffer cannot be reserved.
pub fn read_all<Fd: AsFd>(fd: Fd) -> Result<Vec<u8>, Errno> {
    let stx = fstatx(&fd, STATX_SIZE)?;
    let len = usize::try_from(stx.stx_size).or(Err(Errno::EOVERFLOW))?;

    let mut data = Vec::new();
    if len != 0 {
        data.try_reserve(len).or(Err(Errno::ENOMEM))?;
        data.resize(len, 0);
        let got = read_buf(fd, &mut data)?;
        data.truncate(got);
    }
    Ok(data)
}
/// Writes every byte of `data` to `fd`, issuing as many `write(2)` calls as needed.
///
/// Each call goes through `retry_on_eintr`. An empty `data` slice returns
/// `Ok(())` without touching `fd`.
///
/// # Errors
///
/// * `EPIPE` if a `write(2)` call reports 0 bytes written while data remains.
/// * `EOVERFLOW` if the running byte count would overflow `usize`.
/// * Any other error from `write(2)` (after `retry_on_eintr`) is propagated.
pub fn write_all<Fd: AsFd>(fd: Fd, data: &[u8]) -> Result<(), Errno> {
    let mut sent: usize = 0;
    // Keep going while there is a non-empty unwritten tail.
    while let Some(pending) = data.get(sent..).filter(|tail| !tail.is_empty()) {
        let wrote = retry_on_eintr(|| write(&fd, pending))?;
        if wrote == 0 {
            return Err(Errno::EPIPE);
        }
        sent = sent.checked_add(wrote).ok_or(Errno::EOVERFLOW)?;
    }
    Ok(())
}
/// Marker for types that are both a raw file descriptor holder ([`AsFd`])
/// and a [`Read`] stream; used as the source bound of [`copy`].
pub trait ReadFd: Read + AsFd {}

/// Marker for types that are both a raw file descriptor holder ([`AsFd`])
/// and a [`Write`] sink; used as the destination bound of [`copy`].
pub trait WriteFd: Write + AsFd {}

// Readable descriptor types.
impl ReadFd for SafeOwnedFd {}
impl ReadFd for Stdin {}
#[expect(clippy::disallowed_types)]
impl ReadFd for std::fs::File {}

// Writable descriptor types.
impl WriteFd for SafeOwnedFd {}
#[expect(clippy::disallowed_types)]
impl WriteFd for std::fs::File {}
/// Copies everything remaining in `src` to `dst`, preferring `sendfile(2)`.
///
/// `sendfile64` is called repeatedly with no explicit offset, so the source's
/// own file offset advances with each transfer. The loop runs until a call
/// returns 0. Interrupted calls (`EINTR`) are retried.
///
/// If `sendfile64` fails with `EINVAL` or `ENOSYS`, the remainder is copied
/// with `std::io::copy`. Because `src` has already advanced past whatever
/// `sendfile64` moved, the fallback continues where it left off. The bytes
/// from both phases are included in the returned total.
///
/// # Returns
///
/// The total number of bytes copied.
///
/// # Errors
///
/// * Any other `sendfile64` error is returned as-is.
/// * Errors from the `std::io::copy` fallback are mapped via `err2no`.
/// * `EOVERFLOW` if the running total would overflow `u64`.
pub fn copy<Fd1, Fd2>(src: &mut Fd1, dst: &mut Fd2) -> Result<u64, Errno>
where
    Fd1: ReadFd,
    Fd2: WriteFd,
{
    // Per the sendfile(2) man page, a single call transfers at most
    // 0x7ffff000 bytes on Linux.
    const MAX: usize = 0x7ffff000;
    let mut ncopy: u64 = 0;
    loop {
        match sendfile64(&dst, &src, None, MAX) {
            Ok(0) => return Ok(ncopy),
            Ok(n) => {
                let n = u64::try_from(n).or(Err(Errno::EOVERFLOW))?;
                ncopy = ncopy.checked_add(n).ok_or(Errno::EOVERFLOW)?;
            }
            Err(Errno::EINTR) => {}
            Err(Errno::EINVAL | Errno::ENOSYS) => {
                #[expect(clippy::disallowed_methods)]
                let rest = std::io::copy(src, dst).map_err(|err| err2no(&err))?;
                // Count the bytes sendfile64 already moved before falling back.
                return ncopy.checked_add(rest).ok_or(Errno::EOVERFLOW);
            }
            Err(errno) => return Err(errno),
        }
    }
}
#[cfg(test)]
mod tests {
    use std::io::{Seek, SeekFrom, Write as IoWrite};

    use super::*;

    /// Creates an anonymous temporary file holding `data`, rewound to offset 0.
    fn tempfile_with(data: &[u8]) -> std::fs::File {
        let mut file = tempfile::tempfile().unwrap();
        file.write_all(data).unwrap();
        file.seek(SeekFrom::Start(0)).unwrap();
        file
    }

    /// Rewinds `file` to offset 0 and returns everything it contains.
    fn read_back(mut file: std::fs::File) -> Vec<u8> {
        file.seek(SeekFrom::Start(0)).unwrap();
        let mut contents = Vec::new();
        std::io::Read::read_to_end(&mut file, &mut contents).unwrap();
        contents
    }

    #[test]
    fn test_read_buf_1() {
        // Buffer exactly as large as the file is filled completely.
        let file = tempfile_with(b"hello");
        let mut buf = [0u8; 5];
        assert_eq!(read_buf(&file, &mut buf).unwrap(), 5);
        assert_eq!(&buf, b"hello");
    }

    #[test]
    fn test_read_buf_2() {
        // Buffer larger than the file: stops at end of file.
        let file = tempfile_with(b"hi");
        let mut buf = [0u8; 10];
        let got = read_buf(&file, &mut buf).unwrap();
        assert_eq!(got, 2);
        assert_eq!(&buf[..got], b"hi");
    }

    #[test]
    fn test_read_buf_3() {
        // Empty file yields zero bytes.
        let file = tempfile_with(b"");
        let mut buf = [0u8; 4];
        assert_eq!(read_buf(&file, &mut buf).unwrap(), 0);
    }

    #[test]
    fn test_read_buf_4() {
        // Zero-length buffer reads nothing even if data is available.
        let file = tempfile_with(b"abc");
        let mut buf: [u8; 0] = [];
        assert_eq!(read_buf(&file, &mut buf).unwrap(), 0);
    }

    #[test]
    fn test_read_all_1() {
        let file = tempfile_with(b"syd rocks");
        assert_eq!(read_all(&file).unwrap(), b"syd rocks");
    }

    #[test]
    fn test_read_all_2() {
        let file = tempfile_with(b"");
        assert!(read_all(&file).unwrap().is_empty());
    }

    #[test]
    fn test_read_all_3() {
        let payload = vec![0xffu8; 8192];
        let file = tempfile_with(&payload);
        assert_eq!(read_all(&file).unwrap(), payload);
    }

    #[test]
    fn test_write_all_1() {
        let file = tempfile::tempfile().unwrap();
        write_all(&file, b"hello world").unwrap();
        assert_eq!(read_back(file), b"hello world");
    }

    #[test]
    fn test_write_all_2() {
        let file = tempfile::tempfile().unwrap();
        write_all(&file, b"").unwrap();
        assert!(read_back(file).is_empty());
    }

    #[test]
    fn test_write_all_3() {
        let payload = vec![0xabu8; 16384];
        let file = tempfile::tempfile().unwrap();
        write_all(&file, &payload).unwrap();
        assert_eq!(read_back(file), payload);
    }

    #[test]
    fn test_copy_1() {
        let mut src = tempfile_with(b"copy me");
        let mut dst = tempfile::tempfile().unwrap();
        assert_eq!(copy(&mut src, &mut dst).unwrap(), 7);
        assert_eq!(read_back(dst), b"copy me");
    }

    #[test]
    fn test_copy_2() {
        let mut src = tempfile_with(b"");
        let mut dst = tempfile::tempfile().unwrap();
        assert_eq!(copy(&mut src, &mut dst).unwrap(), 0);
    }

    #[test]
    fn test_copy_3() {
        let payload = vec![0x42u8; 65536];
        let mut src = tempfile_with(&payload);
        let mut dst = tempfile::tempfile().unwrap();
        let copied = copy(&mut src, &mut dst).unwrap();
        assert_eq!(copied as usize, payload.len());
        assert_eq!(read_back(dst), payload);
    }

    #[test]
    fn test_readfd_1() {
        // A File can be used through a `dyn ReadFd` trait object.
        fn accept_readfd(reader: &mut dyn ReadFd) -> Vec<u8> {
            let mut collected = Vec::new();
            reader.read_to_end(&mut collected).unwrap();
            collected
        }
        let mut file = tempfile_with(b"trait test");
        assert_eq!(accept_readfd(&mut file), b"trait test");
    }

    #[test]
    fn test_writefd_1() {
        // A File can be used through a `dyn WriteFd` trait object.
        fn accept_writefd(writer: &mut dyn WriteFd, data: &[u8]) {
            writer.write_all(data).unwrap();
        }
        let mut file = tempfile::tempfile().unwrap();
        accept_writefd(&mut file, b"trait write");
        assert_eq!(read_back(file), b"trait write");
    }
}