Skip to main content

coreutils_rs/head/
core.rs

1use std::io::{self, Read, Write};
2use std::path::Path;
3
4use memchr::{memchr_iter, memrchr_iter};
5
6use crate::common::io::{FileData, read_file, read_stdin};
7
/// Which part of each input `head` prints.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum HeadMode {
    /// The first N lines (the default mode uses N = 10).
    Lines(u64),
    /// Everything except the last N lines.
    LinesFromEnd(u64),
    /// The first N bytes.
    Bytes(u64),
    /// Everything except the last N bytes.
    BytesFromEnd(u64),
}

/// Options controlling a `head` run.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HeadConfig {
    /// Which part of each input to print.
    pub mode: HeadMode,
    /// When true, lines are terminated by NUL (`\0`) instead of newline.
    pub zero_terminated: bool,
}

impl Default for HeadConfig {
    /// The classic `head` behaviour: the first 10 newline-terminated lines.
    fn default() -> Self {
        Self {
            mode: HeadMode::Lines(10),
            zero_terminated: false,
        }
    }
}
36
/// Parse a count argument with an optional multiplier suffix.
///
/// Supported suffixes: `b` (512), `kB` (1000), `k`/`K`/`KiB` (1024), `MB` (10^6),
/// `M`/`MiB` (2^20), `GB` (10^9), `G`/`GiB` (2^30), and likewise `TB`/`T`/`TiB`,
/// `PB`/`P`/`PiB`, `EB`/`E`/`EiB`, `ZB`/`Z`/`ZiB`, `YB`/`Y`/`YiB`. Surrounding whitespace
/// is ignored and a leading `+` is accepted.
///
/// Counts too large for a `u64` saturate to `u64::MAX`. This applies whether the digits
/// alone overflow or the suffix multiplication does, and it is the value the `Z`/`Y`
/// suffixes produce as well. A huge count therefore simply means "everything".
///
/// # Errors
///
/// Returns a message when the input is empty, has no leading digits, is negative, or
/// ends in an unrecognised suffix.
pub fn parse_size(s: &str) -> Result<u64, String> {
    let s = s.trim();
    if s.is_empty() {
        return Err("empty size".to_string());
    }

    // The numeric part is an optional leading sign followed by ASCII digits.
    let mut num_end = 0;
    for (i, c) in s.char_indices() {
        if c.is_ascii_digit() || (i == 0 && (c == '+' || c == '-')) {
            num_end = i + c.len_utf8();
        } else {
            break;
        }
    }

    if num_end == 0 {
        return Err(format!("invalid number: '{}'", s));
    }

    let (num_str, suffix) = s.split_at(num_end);

    let num: u64 = match num_str.parse::<u64>() {
        Ok(n) => n,
        // Well-formed digits that merely overflow u64 are clamped, like GNU coreutils does
        // for huge counts.
        Err(e) if matches!(e.kind(), std::num::IntErrorKind::PosOverflow) => u64::MAX,
        // Every other failure is a real error, not "infinity". That covers a bare sign, and
        // also a leading '-', which unsigned parsing rejects as an invalid digit.
        Err(_) => return Err(format!("invalid number: '{}'", num_str)),
    };

    let multiplier: u64 = match suffix {
        "" => 1,
        "b" => 512,
        "kB" => 1000,
        "k" | "K" | "KiB" => 1024,
        "MB" => 1_000_000,
        "M" | "MiB" => 1_048_576,
        "GB" => 1_000_000_000,
        "G" | "GiB" => 1_073_741_824,
        "TB" => 1_000_000_000_000,
        "T" | "TiB" => 1_099_511_627_776,
        "PB" => 1_000_000_000_000_000,
        "P" | "PiB" => 1_125_899_906_842_624,
        "EB" => 1_000_000_000_000_000_000,
        "E" | "EiB" => 1_152_921_504_606_846_976,
        // ZB/Z/YB/Y multipliers do not fit in u64: any non-zero count saturates.
        "ZB" | "Z" | "ZiB" | "YB" | "Y" | "YiB" => {
            return Ok(if num > 0 { u64::MAX } else { 0 });
        }
        _ => return Err(format!("invalid suffix in '{}'", s)),
    };

    // Saturate rather than fail so that e.g. `20E` behaves like `1Z` and like
    // overflowing digits.
    Ok(num.saturating_mul(multiplier))
}
108
109/// Output first N lines from data
110pub fn head_lines(data: &[u8], n: u64, delimiter: u8, out: &mut impl Write) -> io::Result<()> {
111    if n == 0 || data.is_empty() {
112        return Ok(());
113    }
114
115    let mut count = 0u64;
116    for pos in memchr_iter(delimiter, data) {
117        count += 1;
118        if count == n {
119            return out.write_all(&data[..=pos]);
120        }
121    }
122
123    // Fewer than N lines — output everything
124    out.write_all(data)
125}
126
127/// Output all but last N lines from data.
128/// Uses reverse scanning (memrchr_iter) for single-pass O(n) instead of 2-pass.
129pub fn head_lines_from_end(
130    data: &[u8],
131    n: u64,
132    delimiter: u8,
133    out: &mut impl Write,
134) -> io::Result<()> {
135    if n == 0 {
136        return out.write_all(data);
137    }
138    if data.is_empty() {
139        return Ok(());
140    }
141
142    // Scan backward: skip N delimiters (= N lines), then the next delimiter
143    // marks the end of the last line to keep.
144    // If the data does not end with a delimiter, the unterminated last "line"
145    // counts as one line to skip.
146    let mut count = if !data.is_empty() && *data.last().unwrap() != delimiter {
147        1u64
148    } else {
149        0u64
150    };
151    for pos in memrchr_iter(delimiter, data) {
152        count += 1;
153        if count > n {
154            return out.write_all(&data[..=pos]);
155        }
156    }
157
158    // Fewer than N+1 lines → N >= total lines → output nothing
159    Ok(())
160}
161
/// Write the first `n` bytes of `data` (or all of it, if shorter) to `out`.
///
/// `out` is not touched when there is nothing to write.
pub fn head_bytes(data: &[u8], n: u64, out: &mut impl Write) -> io::Result<()> {
    // A count that does not even fit in usize necessarily covers the whole slice.
    let len = usize::try_from(n).map_or(data.len(), |want| want.min(data.len()));
    if len == 0 {
        return Ok(());
    }
    out.write_all(&data[..len])
}
170
/// Write everything in `data` except its last `n` bytes to `out`.
///
/// Nothing is written when `n` is at least `data.len()`.
pub fn head_bytes_from_end(data: &[u8], n: u64, out: &mut impl Write) -> io::Result<()> {
    let total = data.len() as u64;
    match total.checked_sub(n) {
        // `keep <= data.len()`, so the narrowing cast cannot lose information.
        Some(keep) if keep > 0 => out.write_all(&data[..keep as usize]),
        _ => Ok(()),
    }
}
182
183/// Use sendfile for zero-copy byte output on Linux
184#[cfg(target_os = "linux")]
185pub fn sendfile_bytes(path: &Path, n: u64, out_fd: i32) -> io::Result<bool> {
186    use std::os::unix::fs::OpenOptionsExt;
187
188    let file = std::fs::OpenOptions::new()
189        .read(true)
190        .custom_flags(libc::O_NOATIME)
191        .open(path)
192        .or_else(|_| std::fs::File::open(path))?;
193
194    let metadata = file.metadata()?;
195    let file_size = metadata.len();
196    let to_send = n.min(file_size) as usize;
197
198    if to_send == 0 {
199        return Ok(true);
200    }
201
202    use std::os::unix::io::AsRawFd;
203    let in_fd = file.as_raw_fd();
204    let mut offset: libc::off_t = 0;
205    let mut remaining = to_send;
206
207    while remaining > 0 {
208        let chunk = remaining.min(0x7ffff000); // sendfile max per call
209        let ret = unsafe { libc::sendfile(out_fd, in_fd, &mut offset, chunk) };
210        if ret > 0 {
211            remaining -= ret as usize;
212        } else if ret == 0 {
213            break;
214        } else {
215            let err = io::Error::last_os_error();
216            if err.kind() == io::ErrorKind::Interrupted {
217                continue;
218            }
219            return Err(err);
220        }
221    }
222
223    Ok(true)
224}
225
226/// Streaming head for positive line count on a regular file.
227/// Reads small chunks from the start, never mmaps the whole file.
228/// This is the critical fast path: `head -n 10` on a 100MB file
229/// reads only a few KB instead of mapping all 100MB.
230fn head_lines_streaming_file(
231    path: &Path,
232    n: u64,
233    delimiter: u8,
234    out: &mut impl Write,
235) -> io::Result<bool> {
236    if n == 0 {
237        return Ok(true);
238    }
239
240    #[cfg(target_os = "linux")]
241    let file = {
242        use std::os::unix::fs::OpenOptionsExt;
243        std::fs::OpenOptions::new()
244            .read(true)
245            .custom_flags(libc::O_NOATIME)
246            .open(path)
247            .or_else(|_| std::fs::File::open(path))?
248    };
249    #[cfg(not(target_os = "linux"))]
250    let file = std::fs::File::open(path)?;
251
252    let mut file = file;
253    let mut buf = [0u8; 65536];
254    let mut count = 0u64;
255
256    loop {
257        let bytes_read = match file.read(&mut buf) {
258            Ok(0) => break,
259            Ok(n) => n,
260            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
261            Err(e) => return Err(e),
262        };
263
264        let chunk = &buf[..bytes_read];
265
266        for pos in memchr_iter(delimiter, chunk) {
267            count += 1;
268            if count == n {
269                out.write_all(&chunk[..=pos])?;
270                return Ok(true);
271            }
272        }
273
274        out.write_all(chunk)?;
275    }
276
277    Ok(true)
278}
279
280/// Process a single file/stdin for head
281pub fn head_file(
282    filename: &str,
283    config: &HeadConfig,
284    out: &mut impl Write,
285    tool_name: &str,
286) -> io::Result<bool> {
287    let delimiter = if config.zero_terminated { b'\0' } else { b'\n' };
288
289    if filename != "-" {
290        let path = Path::new(filename);
291
292        // Fast paths that avoid reading/mmapping the whole file
293        match &config.mode {
294            HeadMode::Lines(n) => {
295                // Streaming: read small chunks, stop after N lines
296                match head_lines_streaming_file(path, *n, delimiter, out) {
297                    Ok(true) => return Ok(true),
298                    Err(e) => {
299                        eprintln!(
300                            "{}: cannot open '{}' for reading: {}",
301                            tool_name,
302                            filename,
303                            crate::common::io_error_msg(&e)
304                        );
305                        return Ok(false);
306                    }
307                    _ => {}
308                }
309            }
310            HeadMode::Bytes(n) => {
311                // sendfile: zero-copy, reads only N bytes
312                #[cfg(target_os = "linux")]
313                {
314                    use std::os::unix::io::AsRawFd;
315                    let stdout = io::stdout();
316                    let out_fd = stdout.as_raw_fd();
317                    if let Ok(true) = sendfile_bytes(path, *n, out_fd) {
318                        return Ok(true);
319                    }
320                }
321                // Non-Linux: still avoid full mmap
322                #[cfg(not(target_os = "linux"))]
323                {
324                    if let Ok(true) = head_bytes_streaming_file(path, *n, out) {
325                        return Ok(true);
326                    }
327                }
328            }
329            _ => {
330                // LinesFromEnd and BytesFromEnd need the whole file — use mmap
331            }
332        }
333    }
334
335    // Fast path for stdin with positive line/byte counts — stream without buffering everything.
336    if filename == "-" {
337        match &config.mode {
338            HeadMode::Lines(n) => {
339                return match head_stdin_lines_streaming(*n, delimiter, out) {
340                    Ok(()) => Ok(true),
341                    Err(e) if e.kind() == io::ErrorKind::BrokenPipe => Ok(true),
342                    Err(e) => {
343                        eprintln!(
344                            "{}: standard input: {}",
345                            tool_name,
346                            crate::common::io_error_msg(&e)
347                        );
348                        Ok(false)
349                    }
350                };
351            }
352            HeadMode::Bytes(n) => {
353                return match head_stdin_bytes_streaming(*n, out) {
354                    Ok(()) => Ok(true),
355                    Err(e) if e.kind() == io::ErrorKind::BrokenPipe => Ok(true),
356                    Err(e) => {
357                        eprintln!(
358                            "{}: standard input: {}",
359                            tool_name,
360                            crate::common::io_error_msg(&e)
361                        );
362                        Ok(false)
363                    }
364                };
365            }
366            _ => {} // LinesFromEnd/BytesFromEnd need full buffer
367        }
368    }
369
370    // Slow path: read entire file (needed for -n -N, -c -N, or stdin from-end modes)
371    let data: FileData = if filename == "-" {
372        match read_stdin() {
373            Ok(d) => FileData::Owned(d),
374            Err(e) => {
375                eprintln!(
376                    "{}: standard input: {}",
377                    tool_name,
378                    crate::common::io_error_msg(&e)
379                );
380                return Ok(false);
381            }
382        }
383    } else {
384        match read_file(Path::new(filename)) {
385            Ok(d) => d,
386            Err(e) => {
387                eprintln!(
388                    "{}: cannot open '{}' for reading: {}",
389                    tool_name,
390                    filename,
391                    crate::common::io_error_msg(&e)
392                );
393                return Ok(false);
394            }
395        }
396    };
397
398    match &config.mode {
399        HeadMode::Lines(n) => head_lines(&data, *n, delimiter, out)?,
400        HeadMode::LinesFromEnd(n) => head_lines_from_end(&data, *n, delimiter, out)?,
401        HeadMode::Bytes(n) => head_bytes(&data, *n, out)?,
402        HeadMode::BytesFromEnd(n) => head_bytes_from_end(&data, *n, out)?,
403    }
404
405    Ok(true)
406}
407
/// Copy the first `n` bytes of the file at `path` to `out`. This is the byte-count fast
/// path on non-Linux targets.
///
/// The file is read in 64 KiB chunks, stopping after `n` bytes or at end of file, so only
/// the requested prefix is read. Returns `Ok(true)` on success.
///
/// # Errors
///
/// Propagates errors from opening or reading the file and from writing to `out`.
/// Interrupted reads are retried.
#[cfg(not(target_os = "linux"))]
fn head_bytes_streaming_file(path: &Path, n: u64, out: &mut impl Write) -> io::Result<bool> {
    let mut file = std::fs::File::open(path)?;
    // Keep the budget in u64. On 32-bit targets `n as usize` would truncate large counts
    // (4 GiB becomes 0) and silently cut the output short.
    let mut remaining = n;
    let mut buf = [0u8; 65536];

    while remaining > 0 {
        // Clamp before narrowing: the result is at most buf.len(), so the cast is lossless.
        let to_read = remaining.min(buf.len() as u64) as usize;
        let bytes_read = match file.read(&mut buf[..to_read]) {
            Ok(0) => break,
            Ok(read) => read,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        out.write_all(&buf[..bytes_read])?;
        remaining -= bytes_read as u64;
    }

    Ok(true)
}
429
430/// Process head for stdin streaming (line mode, positive count)
431/// Reads chunks and counts lines, stopping early once count reached.
432pub fn head_stdin_lines_streaming(n: u64, delimiter: u8, out: &mut impl Write) -> io::Result<()> {
433    if n == 0 {
434        return Ok(());
435    }
436
437    let stdin = io::stdin();
438    let mut reader = stdin.lock();
439    let mut buf = [0u8; 262144];
440    let mut count = 0u64;
441
442    loop {
443        let bytes_read = match reader.read(&mut buf) {
444            Ok(0) => break,
445            Ok(n) => n,
446            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
447            Err(e) => return Err(e),
448        };
449
450        let chunk = &buf[..bytes_read];
451
452        // Count delimiters in this chunk
453        for pos in memchr_iter(delimiter, chunk) {
454            count += 1;
455            if count == n {
456                out.write_all(&chunk[..=pos])?;
457                return Ok(());
458            }
459        }
460
461        // Haven't reached N lines yet, output entire chunk
462        out.write_all(chunk)?;
463    }
464
465    Ok(())
466}
467
/// Copy the first `n` bytes of standard input to `out`, then stop reading.
///
/// Stdin is read in chunks of at most 256 KiB, never requesting more than the bytes still
/// owed, so input past byte `n` is left unread. For `n == 0`, stdin is not touched. Reading
/// also ends early at EOF. Interrupted reads are retried; other read or write errors are
/// returned.
fn head_stdin_bytes_streaming(n: u64, out: &mut impl Write) -> io::Result<()> {
    if n == 0 {
        return Ok(());
    }

    let stdin = io::stdin();
    let mut reader = stdin.lock();
    let mut buf = [0u8; 262144];
    let mut remaining = n;

    while remaining > 0 {
        // Clamp in u64 before narrowing. On 32-bit targets `remaining as usize` would
        // truncate (4 GiB becomes 0), which reads into an empty buffer and ends the copy.
        let to_read = remaining.min(buf.len() as u64) as usize;
        let bytes_read = match reader.read(&mut buf[..to_read]) {
            Ok(0) => break,
            Ok(read) => read,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        out.write_all(&buf[..bytes_read])?;
        remaining -= bytes_read as u64;
    }

    Ok(())
}