coreutils_rs/common/io.rs
use std::fs::{self, File};
use std::io::{self, Read};
use std::ops::Deref;
use std::path::Path;

#[cfg(target_os = "linux")]
use std::sync::atomic::{AtomicBool, Ordering};

use memmap2::{Mmap, MmapOptions};

/// Holds file data — either zero-copy mmap or an owned Vec.
/// Dereferences to `&[u8]` for transparent use.
pub enum FileData {
    Mmap(Mmap),
    Owned(Vec<u8>),
}

impl Deref for FileData {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        match self {
            FileData::Mmap(m) => m,
            FileData::Owned(v) => v,
        }
    }
}

/// Threshold below which we use read() instead of mmap.
/// For files under 1MB, read() is faster since mmap has setup/teardown overhead
/// (page table creation for up to 256 pages, TLB flush on munmap) that exceeds
/// the zero-copy benefit.
const MMAP_THRESHOLD: u64 = 1024 * 1024;

/// Track whether O_NOATIME is supported to avoid repeated failed open() attempts.
/// After the first EPERM, we never try O_NOATIME again (saves one syscall per file).
#[cfg(target_os = "linux")]
static NOATIME_SUPPORTED: AtomicBool = AtomicBool::new(true);

/// Open a file with O_NOATIME on Linux to avoid atime inode writes.
/// Caches whether O_NOATIME works to avoid double-open on every file.
#[cfg(target_os = "linux")]
fn open_noatime(path: &Path) -> io::Result<File> {
    use std::os::unix::fs::OpenOptionsExt;
    if NOATIME_SUPPORTED.load(Ordering::Relaxed) {
        match fs::OpenOptions::new()
            .read(true)
            .custom_flags(libc::O_NOATIME)
            .open(path)
        {
            Ok(f) => return Ok(f),
            Err(ref e) if e.raw_os_error() == Some(libc::EPERM) => {
                // O_NOATIME requires file ownership or CAP_FOWNER — disable globally
                NOATIME_SUPPORTED.store(false, Ordering::Relaxed);
            }
            Err(e) => return Err(e), // Real error, propagate
        }
    }
    File::open(path)
}

#[cfg(not(target_os = "linux"))]
fn open_noatime(path: &Path) -> io::Result<File> {
    File::open(path)
}

/// Read a file with zero-copy mmap for large files or read() for small files.
/// Opens once with O_NOATIME, uses fstat for metadata to save a syscall.
pub fn read_file(path: &Path) -> io::Result<FileData> {
    let file = open_noatime(path)?;
    let metadata = file.metadata()?;
    let len = metadata.len();

    if len > 0 && metadata.file_type().is_file() {
        // Small files: exact-size read from the already-open fd.
        // Uses read_full into a pre-sized buffer instead of read_to_end,
        // which avoids the grow-and-probe pattern (saves 1-2 extra read() syscalls).
        if len < MMAP_THRESHOLD {
            let mut buf = vec![0u8; len as usize];
            let n = read_full(&mut &file, &mut buf)?;
            buf.truncate(n);
            return Ok(FileData::Owned(buf));
        }

        // SAFETY: read-only mapping; we assume the file is not truncated or
        // modified while it is mapped.
        match unsafe { MmapOptions::new().populate().map(&file) } {
            Ok(mmap) => {
                #[cfg(target_os = "linux")]
                {
                    // MADV_SEQUENTIAL lets the kernel prefetch ahead of our
                    // sequential access pattern; MADV_WILLNEED starts readahead now.
                    let _ = mmap.advise(memmap2::Advice::Sequential);
                    let _ = mmap.advise(memmap2::Advice::WillNeed);
                    // HUGEPAGE reduces TLB misses for large files (2MB+ = 1+ huge page).
                    // With 4KB pages, a 100MB file needs 25,600 TLB entries; with 2MB
                    // huge pages it needs only 50, i.e. ~512x fewer entries to miss on.
                    if len >= 2 * 1024 * 1024 {
                        let _ = mmap.advise(memmap2::Advice::HugePage);
                    }
                }
                Ok(FileData::Mmap(mmap))
            }
            Err(_) => {
                // mmap failed — fall back to read
                let mut buf = Vec::with_capacity(len as usize);
                let mut reader = file;
                reader.read_to_end(&mut buf)?;
                Ok(FileData::Owned(buf))
            }
        }
    } else {
        // Non-regular or size-0 files (pipes, /proc entries, empty files): read
        // from the open fd. Size-0 virtual files can still produce data, so we
        // don't shortcut to an empty buffer based on st_size alone.
        let mut buf = Vec::new();
        let mut reader = file;
        reader.read_to_end(&mut buf)?;
        Ok(FileData::Owned(buf))
    }
}
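
// Usage sketch for read_file/FileData (illustrative only, not called anywhere
// in this module; the file name is hypothetical): because FileData derefs to
// &[u8], a wc-style caller can treat both variants identically.
//
//     let data = read_file(Path::new("input.txt"))?;            // FileData
//     let lines = data.iter().filter(|&&b| b == b'\n').count(); // slice methods
//     let bytes = data.len();                                   // via Deref<Target = [u8]>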

/// Get file size without reading it (for byte-count-only optimization).
pub fn file_size(path: &Path) -> io::Result<u64> {
    Ok(fs::metadata(path)?.len())
}

/// Read all bytes from stdin into a Vec.
/// On Linux, uses raw libc::read() to bypass Rust's StdinLock/BufReader overhead.
/// Uses a direct read() loop into a pre-allocated buffer instead of read_to_end(),
/// which avoids Vec's grow-and-probe pattern (extra read() calls and memcpy).
/// Callers should enlarge the pipe buffer via fcntl(F_SETPIPE_SZ) before calling.
/// Uses the full spare capacity for each read() to minimize syscalls.
pub fn read_stdin() -> io::Result<Vec<u8>> {
    #[cfg(target_os = "linux")]
    return read_stdin_raw();

    #[cfg(not(target_os = "linux"))]
    read_stdin_generic()
}
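
// The doc comments above and below assume the caller has already enlarged the
// stdin pipe via fcntl(F_SETPIPE_SZ). A minimal sketch of what such a caller
// might do (hypothetical helper, not defined in this module; the 1MB target is
// an arbitrary choice and the call is best-effort):
//
//     #[cfg(target_os = "linux")]
//     fn enlarge_stdin_pipe() {
//         // Ignore the result: F_SETPIPE_SZ fails on non-pipes and when the
//         // requested size exceeds /proc/sys/fs/pipe-max-size.
//         unsafe {
//             let _ = libc::fcntl(0, libc::F_SETPIPE_SZ, 1024 * 1024);
//         }
//     }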

/// Raw libc::read() implementation for Linux — bypasses Rust's StdinLock
/// and BufReader layers entirely. StdinLock uses an internal 8KB BufReader
/// which adds an extra memcpy for every read; raw read() goes directly
/// from the kernel pipe buffer to our Vec.
///
/// Pre-allocates 16MB to cover most workloads (benchmark = 10MB) without
/// over-allocating. For inputs > 16MB, doubles capacity on demand.
/// Each read() uses the full spare capacity to maximize bytes per syscall.
///
/// Note: callers (ftac, ftr, fbase64) are expected to enlarge the pipe
/// buffer via fcntl(F_SETPIPE_SZ) before calling this function. We don't
/// do it here to avoid accidentally shrinking a previously enlarged pipe.
#[cfg(target_os = "linux")]
fn read_stdin_raw() -> io::Result<Vec<u8>> {
    const PREALLOC: usize = 16 * 1024 * 1024;

    let mut buf: Vec<u8> = Vec::with_capacity(PREALLOC);

    loop {
        let spare_cap = buf.capacity() - buf.len();
        if spare_cap < 1024 * 1024 {
            // Grow by doubling (or by at least PREALLOC) to minimize realloc count
            let new_cap = (buf.capacity() * 2).max(buf.len() + PREALLOC);
            buf.reserve(new_cap - buf.len());
        }
        let spare_cap = buf.capacity() - buf.len();
        let start = buf.len();

        // SAFETY: read() writes into the uninitialized spare capacity, and
        // set_len advances the length only by the number of bytes actually read.
        let ret = unsafe {
            libc::read(
                0,
                buf.as_mut_ptr().add(start) as *mut libc::c_void,
                spare_cap,
            )
        };
        if ret < 0 {
            let err = io::Error::last_os_error();
            if err.kind() == io::ErrorKind::Interrupted {
                continue;
            }
            return Err(err);
        }
        if ret == 0 {
            break;
        }
        unsafe { buf.set_len(start + ret as usize) };
    }

    Ok(buf)
}
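
// Growth arithmetic for the loop above (illustrative): a 100MB input starts at
// the 16MB preallocation and doubles 16 -> 32 -> 64 -> 128MB, i.e. at most
// three reallocations (and three copies of the bytes read so far).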

/// Generic read_stdin for non-Linux platforms.
#[cfg(not(target_os = "linux"))]
fn read_stdin_generic() -> io::Result<Vec<u8>> {
    const PREALLOC: usize = 16 * 1024 * 1024;
    const READ_BUF: usize = 4 * 1024 * 1024;

    let mut stdin = io::stdin().lock();
    let mut buf: Vec<u8> = Vec::with_capacity(PREALLOC);

    loop {
        let spare_cap = buf.capacity() - buf.len();
        if spare_cap < READ_BUF {
            buf.reserve(PREALLOC);
        }
        let spare_cap = buf.capacity() - buf.len();

        let start = buf.len();
        // Zero-fill the spare capacity before handing it to read(); exposing
        // uninitialized memory through a &mut [u8] would be undefined behavior.
        buf.resize(start + spare_cap, 0);
        match stdin.read(&mut buf[start..start + spare_cap]) {
            Ok(0) => {
                buf.truncate(start);
                break;
            }
            Ok(n) => {
                buf.truncate(start + n);
            }
            Err(e) if e.kind() == io::ErrorKind::Interrupted => {
                buf.truncate(start);
                continue;
            }
            Err(e) => return Err(e),
        }
    }

    Ok(buf)
}

/// Read as many bytes as possible into buf, retrying on partial reads.
/// Ensures the full buffer is filled (or EOF reached), avoiding the
/// probe-read overhead of read_to_end.
/// Fast path: regular file reads usually return the full buffer on the first call.
#[inline]
fn read_full(reader: &mut impl Read, buf: &mut [u8]) -> io::Result<usize> {
    // Fast path: first read() usually fills the entire buffer for regular files
    let n = reader.read(buf)?;
    if n == buf.len() || n == 0 {
        return Ok(n);
    }
    // Slow path: partial read — retry to fill buffer (pipes, slow devices)
    let mut total = n;
    while total < buf.len() {
        match reader.read(&mut buf[total..]) {
            Ok(0) => break,
            Ok(n) => total += n,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    }
    Ok(total)
}
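
// A minimal test sketch (assumptions: a writable std::env::temp_dir() and that
// this module is built with `cargo test`). It exercises only the small-file
// read() path, file_size, and read_full; the mmap path would need a >1MB
// fixture and is intentionally left out.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn small_file_roundtrip() {
        // Unique temp path per process to avoid clashes between test runs.
        let path = std::env::temp_dir().join(format!("io_rs_test_{}", std::process::id()));
        fs::write(&path, b"hello\nworld\n").unwrap();

        let data = read_file(&path).unwrap();
        // FileData derefs to &[u8], so slicing and comparison work directly.
        assert_eq!(&data[..], &b"hello\nworld\n"[..]);
        assert_eq!(file_size(&path).unwrap(), 12);

        let _ = fs::remove_file(&path);
    }

    #[test]
    fn read_full_fills_buffer_from_reader() {
        let mut src = Cursor::new(vec![7u8; 100]);
        let mut buf = [0u8; 64];
        // Cursor returns all requested bytes, so the fast path fills the buffer.
        let n = read_full(&mut src, &mut buf).unwrap();
        assert_eq!(n, 64);
        assert!(buf.iter().all(|&b| b == 7));
    }
}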