Skip to main content

coreutils_rs/head/
core.rs

1use std::io::{self, Read, Write};
2use std::path::Path;
3
4use memchr::{memchr_iter, memrchr_iter};
5
6use crate::common::io::{FileData, read_file, read_stdin};
7
/// Selects which part of the input `head` emits.
///
/// Every variant carries the count `N` it applies to. The type is `Copy` and
/// comparable so that configurations can be passed around and asserted on
/// cheaply.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HeadMode {
    /// Emit the first `N` lines (the default configuration uses 10).
    Lines(u64),
    /// Emit everything except the last `N` lines.
    LinesFromEnd(u64),
    /// Emit the first `N` bytes.
    Bytes(u64),
    /// Emit everything except the last `N` bytes.
    BytesFromEnd(u64),
}
20
21/// Configuration for head
22#[derive(Clone, Debug)]
23pub struct HeadConfig {
24    pub mode: HeadMode,
25    pub zero_terminated: bool,
26}
27
28impl Default for HeadConfig {
29    fn default() -> Self {
30        Self {
31            mode: HeadMode::Lines(10),
32            zero_terminated: false,
33        }
34    }
35}
36
/// Parse a non-negative count with an optional size suffix.
///
/// The input is trimmed, then split into a numeric part (an optional leading
/// `+` followed by ASCII digits) and a suffix. Recognised suffixes:
///
/// * `b` = 512
/// * `kB` = 1000; `k`, `K`, `KiB` = 1024
/// * `MB`, `GB`, `TB`, `PB`, `EB` = powers of 1000
/// * `M`/`MiB`, `G`/`GiB`, `T`/`TiB`, `P`/`PiB`, `E`/`EiB` = powers of 1024
/// * `ZB`, `Z`, `ZiB`, `YB`, `Y`, `YiB` do not fit in a `u64`: the result is
///   `u64::MAX` for any non-zero number and `0` for zero
///
/// A digit string that is too large for `u64` is clamped to `u64::MAX`.
///
/// # Errors
///
/// Returns a human-readable message when the input is empty, has no numeric
/// part, is negative, carries an unknown suffix, or when multiplying by the
/// suffix overflows `u64`.
pub fn parse_size(s: &str) -> Result<u64, String> {
    let s = s.trim();
    if s.is_empty() {
        return Err("empty size".to_string());
    }

    // The numeric part ends at the first character that is neither a digit
    // nor a sign in the very first position.
    let num_end = s
        .char_indices()
        .find(|&(i, c)| !(c.is_ascii_digit() || (i == 0 && (c == '+' || c == '-'))))
        .map_or(s.len(), |(i, _)| i);

    if num_end == 0 {
        return Err(format!("invalid number: '{}'", s));
    }

    let (num_str, suffix) = s.split_at(num_end);

    // A count can never be negative. Without this check the overflow clamp
    // below would turn "-5" (and even "-0") into u64::MAX, because parsing a
    // '-' as u64 fails just like an overflow does.
    if num_str.starts_with('-') {
        return Err(format!("invalid number: '{}'", num_str));
    }

    let num: u64 = match num_str.parse::<u64>() {
        Ok(n) => n,
        Err(e) => match e.kind() {
            // Valid digits that merely exceed u64: clamp, as GNU coreutils
            // does for huge counts.
            std::num::IntErrorKind::PosOverflow => u64::MAX,
            // E.g. a lone "+" with no digits.
            _ => return Err(format!("invalid number: '{}'", num_str)),
        },
    };

    let multiplier: u64 = match suffix {
        "" => 1,
        "b" => 512,
        "kB" => 1000,
        "k" | "K" | "KiB" => 1024,
        "MB" => 1_000_000,
        "M" | "MiB" => 1_048_576,
        "GB" => 1_000_000_000,
        "G" | "GiB" => 1_073_741_824,
        "TB" => 1_000_000_000_000,
        "T" | "TiB" => 1_099_511_627_776,
        "PB" => 1_000_000_000_000_000,
        "P" | "PiB" => 1_125_899_906_842_624,
        "EB" => 1_000_000_000_000_000_000,
        "E" | "EiB" => 1_152_921_504_606_846_976,
        // The multiplier itself would overflow u64, so saturate directly.
        "ZB" | "Z" | "ZiB" | "YB" | "Y" | "YiB" => {
            return Ok(if num > 0 { u64::MAX } else { 0 });
        }
        _ => return Err(format!("invalid suffix in '{}'", s)),
    };

    num.checked_mul(multiplier)
        .ok_or_else(|| format!("number too large: '{}'", s))
}
108
109/// Output first N lines from data
110pub fn head_lines(data: &[u8], n: u64, delimiter: u8, out: &mut impl Write) -> io::Result<()> {
111    if n == 0 || data.is_empty() {
112        return Ok(());
113    }
114
115    let mut count = 0u64;
116    for pos in memchr_iter(delimiter, data) {
117        count += 1;
118        if count == n {
119            return out.write_all(&data[..=pos]);
120        }
121    }
122
123    // Fewer than N lines — output everything
124    out.write_all(data)
125}
126
127/// Output all but last N lines from data.
128/// Uses reverse scanning (memrchr_iter) for single-pass O(n) instead of 2-pass.
129pub fn head_lines_from_end(
130    data: &[u8],
131    n: u64,
132    delimiter: u8,
133    out: &mut impl Write,
134) -> io::Result<()> {
135    if n == 0 {
136        return out.write_all(data);
137    }
138    if data.is_empty() {
139        return Ok(());
140    }
141
142    // Scan backward: skip N delimiters (= N lines), then the next delimiter
143    // marks the end of the last line to keep.
144    // If the data does not end with a delimiter, the unterminated last "line"
145    // counts as one line to skip.
146    let mut count = if !data.is_empty() && *data.last().unwrap() != delimiter {
147        1u64
148    } else {
149        0u64
150    };
151    for pos in memrchr_iter(delimiter, data) {
152        count += 1;
153        if count > n {
154            return out.write_all(&data[..=pos]);
155        }
156    }
157
158    // Fewer than N+1 lines → N >= total lines → output nothing
159    Ok(())
160}
161
/// Write the first `n` bytes of `data` to `out`, or all of `data` when it is
/// shorter than `n`. Nothing is written when the resulting length is zero.
pub fn head_bytes(data: &[u8], n: u64, out: &mut impl Write) -> io::Result<()> {
    // When `n` is below the slice length it necessarily fits in `usize`.
    let take = if n < data.len() as u64 {
        n as usize
    } else {
        data.len()
    };

    match take {
        0 => Ok(()),
        _ => out.write_all(&data[..take]),
    }
}
170
/// Write all of `data` except its last `n` bytes to `out`.
///
/// Nothing is written when `n` is at least `data.len()`.
pub fn head_bytes_from_end(data: &[u8], n: u64, out: &mut impl Write) -> io::Result<()> {
    // Saturating at zero covers `n >= len`; any non-zero result is at most
    // `data.len()` and therefore fits in `usize`.
    let keep = (data.len() as u64).saturating_sub(n) as usize;

    if keep == 0 {
        Ok(())
    } else {
        out.write_all(&data[..keep])
    }
}
182
183/// Raw write(2) to stdout, bypassing all Rust I/O layers.
184/// Avoids stdout.lock(), BufWriter allocation, and Write trait overhead.
185#[cfg(target_os = "linux")]
186fn write_all_raw(mut data: &[u8]) -> io::Result<()> {
187    while !data.is_empty() {
188        let ret = unsafe { libc::write(1, data.as_ptr() as *const libc::c_void, data.len()) };
189        if ret > 0 {
190            data = &data[ret as usize..];
191        } else if ret == 0 {
192            return Err(io::Error::new(io::ErrorKind::WriteZero, "write returned 0"));
193        } else {
194            let err = io::Error::last_os_error();
195            if err.kind() == io::ErrorKind::Interrupted {
196                continue;
197            }
198            return Err(err);
199        }
200    }
201    Ok(())
202}
203
204/// Ultra-fast direct path: single file, positive line count, writes directly
205/// to stdout fd without BufWriter overhead. Uses raw write(2) on Linux;
206/// on other platforms uses a small stack-buffered stdout.
207/// Returns Ok(true) on success, Ok(false) on file error (already printed).
208pub fn head_file_direct(filename: &str, n: u64, delimiter: u8) -> io::Result<bool> {
209    if n == 0 {
210        return Ok(true);
211    }
212
213    let path = Path::new(filename);
214
215    #[cfg(target_os = "linux")]
216    {
217        use std::os::unix::fs::OpenOptionsExt;
218        let file = std::fs::OpenOptions::new()
219            .read(true)
220            .custom_flags(libc::O_NOATIME)
221            .open(path)
222            .or_else(|_| std::fs::File::open(path));
223        let mut file = match file {
224            Ok(f) => f,
225            Err(e) => {
226                eprintln!(
227                    "head: cannot open '{}' for reading: {}",
228                    filename,
229                    crate::common::io_error_msg(&e)
230                );
231                return Ok(false);
232            }
233        };
234
235        let mut buf = [0u8; 8192];
236        let mut count = 0u64;
237
238        loop {
239            let bytes_read = match file.read(&mut buf) {
240                Ok(0) => break,
241                Ok(n) => n,
242                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
243                Err(e) => return Err(e),
244            };
245
246            let chunk = &buf[..bytes_read];
247
248            for pos in memchr_iter(delimiter, chunk) {
249                count += 1;
250                if count == n {
251                    write_all_raw(&chunk[..=pos])?;
252                    return Ok(true);
253                }
254            }
255
256            write_all_raw(chunk)?;
257        }
258
259        return Ok(true);
260    }
261
262    #[cfg(not(target_os = "linux"))]
263    {
264        let stdout = io::stdout();
265        let mut out = io::BufWriter::with_capacity(8192, stdout.lock());
266        match head_lines_streaming_file(path, n, delimiter, &mut out) {
267            Ok(true) => {
268                out.flush()?;
269                Ok(true)
270            }
271            Ok(false) => Ok(false),
272            Err(e) => {
273                eprintln!(
274                    "head: cannot open '{}' for reading: {}",
275                    filename,
276                    crate::common::io_error_msg(&e)
277                );
278                Ok(false)
279            }
280        }
281    }
282}
283
284/// Use sendfile for zero-copy byte output on Linux.
285/// Falls back to read+write if sendfile fails (e.g., stdout is a terminal).
286#[cfg(target_os = "linux")]
287pub fn sendfile_bytes(path: &Path, n: u64, out_fd: i32) -> io::Result<bool> {
288    use std::os::unix::fs::OpenOptionsExt;
289
290    let file = std::fs::OpenOptions::new()
291        .read(true)
292        .custom_flags(libc::O_NOATIME)
293        .open(path)
294        .or_else(|_| std::fs::File::open(path))?;
295
296    let metadata = file.metadata()?;
297    let file_size = metadata.len();
298    let to_send = n.min(file_size) as usize;
299
300    if to_send == 0 {
301        return Ok(true);
302    }
303
304    use std::os::unix::io::AsRawFd;
305    let in_fd = file.as_raw_fd();
306    let mut offset: libc::off_t = 0;
307    let mut remaining = to_send;
308    let total = to_send;
309
310    while remaining > 0 {
311        let chunk = remaining.min(0x7ffff000); // sendfile max per call
312        let ret = unsafe { libc::sendfile(out_fd, in_fd, &mut offset, chunk) };
313        if ret > 0 {
314            remaining -= ret as usize;
315        } else if ret == 0 {
316            break;
317        } else {
318            let err = io::Error::last_os_error();
319            if err.kind() == io::ErrorKind::Interrupted {
320                continue;
321            }
322            // sendfile fails with EINVAL for terminal fds; fall back to read+write
323            if err.raw_os_error() == Some(libc::EINVAL) && remaining == total {
324                let mut file = file;
325                let mut buf = [0u8; 65536];
326                let mut left = to_send;
327                while left > 0 {
328                    let to_read = left.min(buf.len());
329                    let nr = match file.read(&mut buf[..to_read]) {
330                        Ok(0) => break,
331                        Ok(nr) => nr,
332                        Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
333                        Err(e) => return Err(e),
334                    };
335                    write_all_raw(&buf[..nr])?;
336                    left -= nr;
337                }
338                return Ok(true);
339            }
340            return Err(err);
341        }
342    }
343
344    Ok(true)
345}
346
347/// Streaming head for positive line count on a regular file.
348/// Reads small chunks from the start, never mmaps the whole file.
349/// This is the critical fast path: `head -n 10` on a 100MB file
350/// reads only a few KB instead of mapping all 100MB.
351fn head_lines_streaming_file(
352    path: &Path,
353    n: u64,
354    delimiter: u8,
355    out: &mut impl Write,
356) -> io::Result<bool> {
357    if n == 0 {
358        return Ok(true);
359    }
360
361    #[cfg(target_os = "linux")]
362    let file = {
363        use std::os::unix::fs::OpenOptionsExt;
364        std::fs::OpenOptions::new()
365            .read(true)
366            .custom_flags(libc::O_NOATIME)
367            .open(path)
368            .or_else(|_| std::fs::File::open(path))?
369    };
370    #[cfg(not(target_os = "linux"))]
371    let file = std::fs::File::open(path)?;
372
373    let mut file = file;
374    // Use 8KB buffer: default 10 lines almost always fits in one read.
375    // Avoids reading 65KB just to extract the first few lines.
376    let mut buf = [0u8; 8192];
377    let mut count = 0u64;
378
379    loop {
380        let bytes_read = match file.read(&mut buf) {
381            Ok(0) => break,
382            Ok(n) => n,
383            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
384            Err(e) => return Err(e),
385        };
386
387        let chunk = &buf[..bytes_read];
388
389        for pos in memchr_iter(delimiter, chunk) {
390            count += 1;
391            if count == n {
392                out.write_all(&chunk[..=pos])?;
393                return Ok(true);
394            }
395        }
396
397        out.write_all(chunk)?;
398    }
399
400    Ok(true)
401}
402
403/// Process a single file/stdin for head
404pub fn head_file(
405    filename: &str,
406    config: &HeadConfig,
407    out: &mut impl Write,
408    tool_name: &str,
409) -> io::Result<bool> {
410    let delimiter = if config.zero_terminated { b'\0' } else { b'\n' };
411
412    if filename != "-" {
413        let path = Path::new(filename);
414
415        // Fast paths that avoid reading/mmapping the whole file
416        match &config.mode {
417            HeadMode::Lines(n) => {
418                // Streaming: read small chunks, stop after N lines
419                match head_lines_streaming_file(path, *n, delimiter, out) {
420                    Ok(true) => return Ok(true),
421                    Err(e) => {
422                        eprintln!(
423                            "{}: cannot open '{}' for reading: {}",
424                            tool_name,
425                            filename,
426                            crate::common::io_error_msg(&e)
427                        );
428                        return Ok(false);
429                    }
430                    _ => {}
431                }
432            }
433            HeadMode::Bytes(n) => {
434                // sendfile: zero-copy, reads only N bytes
435                #[cfg(target_os = "linux")]
436                {
437                    use std::os::unix::io::AsRawFd;
438                    let stdout = io::stdout();
439                    let out_fd = stdout.as_raw_fd();
440                    if let Ok(true) = sendfile_bytes(path, *n, out_fd) {
441                        return Ok(true);
442                    }
443                }
444                // Non-Linux: still avoid full mmap
445                #[cfg(not(target_os = "linux"))]
446                {
447                    if let Ok(true) = head_bytes_streaming_file(path, *n, out) {
448                        return Ok(true);
449                    }
450                }
451            }
452            _ => {
453                // LinesFromEnd and BytesFromEnd need the whole file — use mmap
454            }
455        }
456    }
457
458    // Fast path for stdin with positive line/byte counts — stream without buffering everything.
459    if filename == "-" {
460        match &config.mode {
461            HeadMode::Lines(n) => {
462                return match head_stdin_lines_streaming(*n, delimiter, out) {
463                    Ok(()) => Ok(true),
464                    Err(e) if e.kind() == io::ErrorKind::BrokenPipe => Ok(true),
465                    Err(e) => {
466                        eprintln!(
467                            "{}: standard input: {}",
468                            tool_name,
469                            crate::common::io_error_msg(&e)
470                        );
471                        Ok(false)
472                    }
473                };
474            }
475            HeadMode::Bytes(n) => {
476                return match head_stdin_bytes_streaming(*n, out) {
477                    Ok(()) => Ok(true),
478                    Err(e) if e.kind() == io::ErrorKind::BrokenPipe => Ok(true),
479                    Err(e) => {
480                        eprintln!(
481                            "{}: standard input: {}",
482                            tool_name,
483                            crate::common::io_error_msg(&e)
484                        );
485                        Ok(false)
486                    }
487                };
488            }
489            _ => {} // LinesFromEnd/BytesFromEnd need full buffer
490        }
491    }
492
493    // Slow path: read entire file (needed for -n -N, -c -N, or stdin from-end modes)
494    let data: FileData = if filename == "-" {
495        match read_stdin() {
496            Ok(d) => FileData::Owned(d),
497            Err(e) => {
498                eprintln!(
499                    "{}: standard input: {}",
500                    tool_name,
501                    crate::common::io_error_msg(&e)
502                );
503                return Ok(false);
504            }
505        }
506    } else {
507        match read_file(Path::new(filename)) {
508            Ok(d) => d,
509            Err(e) => {
510                eprintln!(
511                    "{}: cannot open '{}' for reading: {}",
512                    tool_name,
513                    filename,
514                    crate::common::io_error_msg(&e)
515                );
516                return Ok(false);
517            }
518        }
519    };
520
521    match &config.mode {
522        HeadMode::Lines(n) => head_lines(&data, *n, delimiter, out)?,
523        HeadMode::LinesFromEnd(n) => head_lines_from_end(&data, *n, delimiter, out)?,
524        HeadMode::Bytes(n) => head_bytes(&data, *n, out)?,
525        HeadMode::BytesFromEnd(n) => head_bytes_from_end(&data, *n, out)?,
526    }
527
528    Ok(true)
529}
530
/// Stream the first `n` bytes of the file at `path` into `out` (non-Linux).
///
/// Reads in chunks of at most 64 KiB and stops after `n` bytes or at end of
/// file, whichever comes first. Always returns `Ok(true)` on success.
///
/// The remaining budget is tracked as `u64` and clamped to the buffer size
/// before it is narrowed, so a count larger than `usize::MAX` cannot truncate
/// on 32-bit targets.
///
/// # Errors
///
/// Returns the I/O error from opening or reading the file, or from writing to
/// `out`.
#[cfg(not(target_os = "linux"))]
fn head_bytes_streaming_file(path: &Path, n: u64, out: &mut impl Write) -> io::Result<bool> {
    let mut file = std::fs::File::open(path)?;
    let mut remaining = n;
    let mut buf = [0u8; 65536];

    while remaining > 0 {
        let to_read = remaining.min(buf.len() as u64) as usize;
        let bytes_read = match file.read(&mut buf[..to_read]) {
            Ok(0) => break,
            Ok(n) => n,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        out.write_all(&buf[..bytes_read])?;
        remaining -= bytes_read as u64;
    }

    Ok(true)
}
552
553/// Process head for stdin streaming (line mode, positive count)
554/// Reads chunks and counts lines, stopping early once count reached.
555pub fn head_stdin_lines_streaming(n: u64, delimiter: u8, out: &mut impl Write) -> io::Result<()> {
556    if n == 0 {
557        return Ok(());
558    }
559
560    let stdin = io::stdin();
561    let mut reader = stdin.lock();
562    let mut buf = [0u8; 262144];
563    let mut count = 0u64;
564
565    loop {
566        let bytes_read = match reader.read(&mut buf) {
567            Ok(0) => break,
568            Ok(n) => n,
569            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
570            Err(e) => return Err(e),
571        };
572
573        let chunk = &buf[..bytes_read];
574
575        // Count delimiters in this chunk
576        for pos in memchr_iter(delimiter, chunk) {
577            count += 1;
578            if count == n {
579                out.write_all(&chunk[..=pos])?;
580                return Ok(());
581            }
582        }
583
584        // Haven't reached N lines yet, output entire chunk
585        out.write_all(chunk)?;
586    }
587
588    Ok(())
589}
590
/// Copy the first `n` bytes from standard input to `out`.
///
/// Reads in chunks of at most 256 KiB and stops once `n` bytes have been
/// forwarded or stdin reaches end of input. Returns immediately, without
/// touching stdin, when `n` is zero.
///
/// The remaining budget is clamped to the buffer size while still a `u64`, so
/// a count larger than `usize::MAX` cannot truncate (and end the copy early)
/// on 32-bit targets.
///
/// # Errors
///
/// Returns the I/O error from reading stdin or writing to `out`.
fn head_stdin_bytes_streaming(n: u64, out: &mut impl Write) -> io::Result<()> {
    if n == 0 {
        return Ok(());
    }

    let stdin = io::stdin();
    let mut reader = stdin.lock();
    let mut buf = [0u8; 262144];
    let mut remaining = n;

    while remaining > 0 {
        let to_read = remaining.min(buf.len() as u64) as usize;
        let bytes_read = match reader.read(&mut buf[..to_read]) {
            Ok(0) => break,
            Ok(n) => n,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        out.write_all(&buf[..bytes_read])?;
        remaining -= bytes_read as u64;
    }

    Ok(())
}
619}