Skip to main content

coreutils_rs/head/
core.rs

1use std::io::{self, Read, Write};
2use std::path::Path;
3
4use memchr::{memchr_iter, memrchr_iter};
5
6use crate::common::io::{FileData, read_file, read_stdin};
7
/// Selects which portion of the input `head` writes out.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum HeadMode {
    /// Emit the first N lines (the default configuration uses 10).
    Lines(u64),
    /// Emit everything except the last N lines.
    LinesFromEnd(u64),
    /// Emit the first N bytes.
    Bytes(u64),
    /// Emit everything except the last N bytes.
    BytesFromEnd(u64),
}

/// Settings for one `head` invocation.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HeadConfig {
    /// How much of the input to emit.
    pub mode: HeadMode,
    /// When `true`, records are terminated by NUL (`\0`) instead of newline.
    pub zero_terminated: bool,
}

impl Default for HeadConfig {
    /// Returns the default settings: the first 10 newline-terminated lines.
    fn default() -> Self {
        Self {
            mode: HeadMode::Lines(10),
            zero_terminated: false,
        }
    }
}
36
/// Convert a textual count such as `10`, `4K` or `2MB` into a number.
///
/// After trimming surrounding whitespace, the input is an optional leading
/// sign followed by decimal digits, then an optional multiplier suffix:
/// `b` (512), `kB` (1000), `K`/`KiB` (1024), `MB` (10^6), `M`/`MiB` (2^20),
/// and likewise `GB`/`G`, `TB`/`T`, `PB`/`P`, `EB`/`E`. The `Z` and `Y`
/// families do not fit in a `u64`: a non-zero count saturates to `u64::MAX`
/// and a zero count stays zero.
///
/// A leading `-` is accepted while scanning the number but is then rejected
/// by the unsigned parse.
///
/// # Errors
///
/// Returns a message when the input is empty, does not start with a valid
/// number, carries an unknown suffix, or the scaled value overflows `u64`.
pub fn parse_size(s: &str) -> Result<u64, String> {
    let s = s.trim();
    if s.is_empty() {
        return Err("empty size".to_string());
    }

    // The numeric prefix is at most one sign character followed by ASCII
    // digits. Everything involved is ASCII, so byte lengths are valid
    // character boundaries.
    let sign_len = match s.as_bytes()[0] {
        b'+' | b'-' => 1,
        _ => 0,
    };
    let digit_len = s[sign_len..]
        .bytes()
        .take_while(|b| b.is_ascii_digit())
        .count();
    let split = sign_len + digit_len;
    if split == 0 {
        return Err(format!("invalid number: '{}'", s));
    }

    let (num_str, suffix) = s.split_at(split);
    let num = match num_str.parse::<u64>() {
        Ok(value) => value,
        Err(_) => return Err(format!("invalid number: '{}'", num_str)),
    };

    let multiplier: u64 = match suffix {
        "" => 1,
        "b" => 512,
        "kB" => 10u64.pow(3),
        "K" | "KiB" => 1 << 10,
        "MB" => 10u64.pow(6),
        "M" | "MiB" => 1 << 20,
        "GB" => 10u64.pow(9),
        "G" | "GiB" => 1 << 30,
        "TB" => 10u64.pow(12),
        "T" | "TiB" => 1 << 40,
        "PB" => 10u64.pow(15),
        "P" | "PiB" => 1 << 50,
        "EB" => 10u64.pow(18),
        "E" | "EiB" => 1 << 60,
        // These multipliers exceed u64, so any non-zero count saturates.
        "ZB" | "Z" | "ZiB" | "YB" | "Y" | "YiB" => {
            return Ok(if num > 0 { u64::MAX } else { 0 });
        }
        _ => return Err(format!("invalid suffix in '{}'", s)),
    };

    match num.checked_mul(multiplier) {
        Some(total) => Ok(total),
        None => Err(format!("number too large: '{}'", s)),
    }
}
95
96/// Output first N lines from data
97pub fn head_lines(data: &[u8], n: u64, delimiter: u8, out: &mut impl Write) -> io::Result<()> {
98    if n == 0 || data.is_empty() {
99        return Ok(());
100    }
101
102    let mut count = 0u64;
103    for pos in memchr_iter(delimiter, data) {
104        count += 1;
105        if count == n {
106            return out.write_all(&data[..=pos]);
107        }
108    }
109
110    // Fewer than N lines — output everything
111    out.write_all(data)
112}
113
114/// Output all but last N lines from data.
115/// Uses reverse scanning (memrchr_iter) for single-pass O(n) instead of 2-pass.
116pub fn head_lines_from_end(
117    data: &[u8],
118    n: u64,
119    delimiter: u8,
120    out: &mut impl Write,
121) -> io::Result<()> {
122    if n == 0 {
123        return out.write_all(data);
124    }
125    if data.is_empty() {
126        return Ok(());
127    }
128
129    // Scan backward: skip N delimiters (= N lines), then the next delimiter
130    // marks the end of the last line to keep.
131    let mut count = 0u64;
132    for pos in memrchr_iter(delimiter, data) {
133        count += 1;
134        if count > n {
135            return out.write_all(&data[..=pos]);
136        }
137    }
138
139    // Fewer than N+1 delimiters → N >= total lines → output nothing
140    Ok(())
141}
142
/// Write at most the first `n` bytes of `data` to `out`.
///
/// Nothing is written when the resulting prefix is empty.
pub fn head_bytes(data: &[u8], n: u64, out: &mut impl Write) -> io::Result<()> {
    // The cast is lossless: it only happens when `n` is below `data.len()`.
    let prefix = if n < data.len() as u64 {
        &data[..n as usize]
    } else {
        data
    };

    if prefix.is_empty() {
        return Ok(());
    }
    out.write_all(prefix)
}
151
/// Write `data` minus its last `n` bytes to `out`.
///
/// Nothing is written when `n` is at least the length of `data`.
pub fn head_bytes_from_end(data: &[u8], n: u64, out: &mut impl Write) -> io::Result<()> {
    match (data.len() as u64).checked_sub(n) {
        // `keep` never exceeds `data.len()`, so the cast back is lossless.
        Some(keep) if keep > 0 => out.write_all(&data[..keep as usize]),
        _ => Ok(()),
    }
}
163
164/// Use sendfile for zero-copy byte output on Linux
165#[cfg(target_os = "linux")]
166pub fn sendfile_bytes(path: &Path, n: u64, out_fd: i32) -> io::Result<bool> {
167    use std::os::unix::fs::OpenOptionsExt;
168
169    let file = std::fs::OpenOptions::new()
170        .read(true)
171        .custom_flags(libc::O_NOATIME)
172        .open(path)
173        .or_else(|_| std::fs::File::open(path))?;
174
175    let metadata = file.metadata()?;
176    let file_size = metadata.len();
177    let to_send = n.min(file_size) as usize;
178
179    if to_send == 0 {
180        return Ok(true);
181    }
182
183    use std::os::unix::io::AsRawFd;
184    let in_fd = file.as_raw_fd();
185    let mut offset: libc::off_t = 0;
186    let mut remaining = to_send;
187
188    while remaining > 0 {
189        let chunk = remaining.min(0x7ffff000); // sendfile max per call
190        let ret = unsafe { libc::sendfile(out_fd, in_fd, &mut offset, chunk) };
191        if ret > 0 {
192            remaining -= ret as usize;
193        } else if ret == 0 {
194            break;
195        } else {
196            let err = io::Error::last_os_error();
197            if err.kind() == io::ErrorKind::Interrupted {
198                continue;
199            }
200            return Err(err);
201        }
202    }
203
204    Ok(true)
205}
206
207/// Streaming head for positive line count on a regular file.
208/// Reads small chunks from the start, never mmaps the whole file.
209/// This is the critical fast path: `head -n 10` on a 100MB file
210/// reads only a few KB instead of mapping all 100MB.
211fn head_lines_streaming_file(
212    path: &Path,
213    n: u64,
214    delimiter: u8,
215    out: &mut impl Write,
216) -> io::Result<bool> {
217    if n == 0 {
218        return Ok(true);
219    }
220
221    #[cfg(target_os = "linux")]
222    let file = {
223        use std::os::unix::fs::OpenOptionsExt;
224        std::fs::OpenOptions::new()
225            .read(true)
226            .custom_flags(libc::O_NOATIME)
227            .open(path)
228            .or_else(|_| std::fs::File::open(path))?
229    };
230    #[cfg(not(target_os = "linux"))]
231    let file = std::fs::File::open(path)?;
232
233    let mut file = file;
234    let mut buf = [0u8; 65536];
235    let mut count = 0u64;
236
237    loop {
238        let bytes_read = match file.read(&mut buf) {
239            Ok(0) => break,
240            Ok(n) => n,
241            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
242            Err(e) => return Err(e),
243        };
244
245        let chunk = &buf[..bytes_read];
246
247        for pos in memchr_iter(delimiter, chunk) {
248            count += 1;
249            if count == n {
250                out.write_all(&chunk[..=pos])?;
251                return Ok(true);
252            }
253        }
254
255        out.write_all(chunk)?;
256    }
257
258    Ok(true)
259}
260
261/// Process a single file/stdin for head
262pub fn head_file(
263    filename: &str,
264    config: &HeadConfig,
265    out: &mut impl Write,
266    tool_name: &str,
267) -> io::Result<bool> {
268    let delimiter = if config.zero_terminated { b'\0' } else { b'\n' };
269
270    if filename != "-" {
271        let path = Path::new(filename);
272
273        // Fast paths that avoid reading/mmapping the whole file
274        match &config.mode {
275            HeadMode::Lines(n) => {
276                // Streaming: read small chunks, stop after N lines
277                match head_lines_streaming_file(path, *n, delimiter, out) {
278                    Ok(true) => return Ok(true),
279                    Err(e) => {
280                        eprintln!(
281                            "{}: cannot open '{}' for reading: {}",
282                            tool_name,
283                            filename,
284                            crate::common::io_error_msg(&e)
285                        );
286                        return Ok(false);
287                    }
288                    _ => {}
289                }
290            }
291            HeadMode::Bytes(n) => {
292                // sendfile: zero-copy, reads only N bytes
293                #[cfg(target_os = "linux")]
294                {
295                    use std::os::unix::io::AsRawFd;
296                    let stdout = io::stdout();
297                    let out_fd = stdout.as_raw_fd();
298                    if let Ok(true) = sendfile_bytes(path, *n, out_fd) {
299                        return Ok(true);
300                    }
301                }
302                // Non-Linux: still avoid full mmap
303                #[cfg(not(target_os = "linux"))]
304                {
305                    if let Ok(true) = head_bytes_streaming_file(path, *n, out) {
306                        return Ok(true);
307                    }
308                }
309            }
310            _ => {
311                // LinesFromEnd and BytesFromEnd need the whole file — use mmap
312            }
313        }
314    }
315
316    // Slow path: read entire file (needed for -n -N, -c -N, or stdin)
317    let data: FileData = if filename == "-" {
318        match read_stdin() {
319            Ok(d) => FileData::Owned(d),
320            Err(e) => {
321                eprintln!(
322                    "{}: standard input: {}",
323                    tool_name,
324                    crate::common::io_error_msg(&e)
325                );
326                return Ok(false);
327            }
328        }
329    } else {
330        match read_file(Path::new(filename)) {
331            Ok(d) => d,
332            Err(e) => {
333                eprintln!(
334                    "{}: cannot open '{}' for reading: {}",
335                    tool_name,
336                    filename,
337                    crate::common::io_error_msg(&e)
338                );
339                return Ok(false);
340            }
341        }
342    };
343
344    match &config.mode {
345        HeadMode::Lines(n) => head_lines(&data, *n, delimiter, out)?,
346        HeadMode::LinesFromEnd(n) => head_lines_from_end(&data, *n, delimiter, out)?,
347        HeadMode::Bytes(n) => head_bytes(&data, *n, out)?,
348        HeadMode::BytesFromEnd(n) => head_bytes_from_end(&data, *n, out)?,
349    }
350
351    Ok(true)
352}
353
/// Stream at most the first `n` bytes of the file at `path` to `out`
/// (used on targets other than Linux).
///
/// The byte limit stays a `u64` for the whole copy, so large counts are not
/// truncated on targets where `usize` is narrower than 64 bits.
///
/// Always returns `Ok(true)` on success.
///
/// # Errors
///
/// Propagates failures from opening or reading the file and from writing to
/// `out`; `io::copy` retries interrupted operations itself.
#[cfg(not(target_os = "linux"))]
fn head_bytes_streaming_file(path: &Path, n: u64, out: &mut impl Write) -> io::Result<bool> {
    let file = std::fs::File::open(path)?;
    // `take` caps the reader at `n` bytes; `copy` drains it into `out`.
    io::copy(&mut file.take(n), out)?;
    Ok(true)
}
375
376/// Process head for stdin streaming (line mode, positive count)
377/// Reads chunks and counts lines, stopping early once count reached.
378pub fn head_stdin_lines_streaming(n: u64, delimiter: u8, out: &mut impl Write) -> io::Result<()> {
379    if n == 0 {
380        return Ok(());
381    }
382
383    let stdin = io::stdin();
384    let mut reader = stdin.lock();
385    let mut buf = [0u8; 262144];
386    let mut count = 0u64;
387
388    loop {
389        let bytes_read = match reader.read(&mut buf) {
390            Ok(0) => break,
391            Ok(n) => n,
392            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
393            Err(e) => return Err(e),
394        };
395
396        let chunk = &buf[..bytes_read];
397
398        // Count delimiters in this chunk
399        for pos in memchr_iter(delimiter, chunk) {
400            count += 1;
401            if count == n {
402                out.write_all(&chunk[..=pos])?;
403                return Ok(());
404            }
405        }
406
407        // Haven't reached N lines yet, output entire chunk
408        out.write_all(chunk)?;
409    }
410
411    Ok(())
412}