use std::io::{self, Read, Write};
use std::path::Path;
use memchr::{memchr_iter, memrchr_iter};
use crate::common::io::{FileData, read_file, read_stdin};
/// Selects which part of the input `head` emits.
#[derive(Clone, Debug)]
pub enum HeadMode {
/// Emit the first N delimiter-terminated lines.
Lines(u64),
/// Emit everything except the last N lines.
LinesFromEnd(u64),
/// Emit the first N bytes.
Bytes(u64),
/// Emit everything except the last N bytes.
BytesFromEnd(u64),
}
/// Options for a single `head` operation.
#[derive(Clone, Debug)]
pub struct HeadConfig {
/// Which portion of the input to emit.
pub mode: HeadMode,
/// When `true`, lines are delimited by NUL (`\0`) instead of newline (`\n`).
pub zero_terminated: bool,
}
impl Default for HeadConfig {
fn default() -> Self {
Self {
mode: HeadMode::Lines(10),
zero_terminated: false,
}
}
}
/// Parses a count with an optional size suffix (for example `10`, `+4k`, `2MB`).
///
/// Surrounding whitespace is ignored. The numeric part is a run of ASCII digits
/// with an optional leading `+`; the remainder must be one of the recognised
/// suffixes:
///
/// * `b` = 512
/// * `kB`, `MB`, `GB`, `TB`, `PB`, `EB` = powers of 1000
/// * `k`/`K`/`KiB`, `M`/`MiB`, `G`/`GiB`, `T`/`TiB`, `P`/`PiB`, `E`/`EiB` = powers of 1024
/// * `Z`/`ZB`/`ZiB`, `Y`/`YB`/`YiB` = too large for `u64`: any non-zero count
///   yields `u64::MAX`, zero yields 0
///
/// A digit string too large for `u64` saturates to `u64::MAX`.
///
/// # Errors
///
/// Returns a message when the input is empty, has no valid number, is negative,
/// has an unknown suffix, or when applying the multiplier overflows `u64`.
pub fn parse_size(s: &str) -> Result<u64, String> {
    let s = s.trim();
    if s.is_empty() {
        return Err("empty size".to_string());
    }

    // The numeric prefix ends at the first character that is neither a digit
    // nor a sign in the very first position.
    let num_end = s
        .char_indices()
        .find(|&(i, c)| !(c.is_ascii_digit() || (i == 0 && (c == '+' || c == '-'))))
        .map_or(s.len(), |(i, _)| i);
    if num_end == 0 {
        return Err(format!("invalid number: '{}'", s));
    }
    let (num_str, suffix) = s.split_at(num_end);

    let num: u64 = match num_str.parse() {
        Ok(value) => value,
        Err(_) => {
            // Only a well-formed non-negative digit string can fail because of
            // overflow. A leading '-' must not be stripped here: doing so made
            // inputs such as "-5" look like an overflow and silently turned
            // them into u64::MAX instead of an error.
            let digits = num_str.strip_prefix('+').unwrap_or(num_str);
            let overflowed = !digits.is_empty() && digits.bytes().all(|b| b.is_ascii_digit());
            if !overflowed {
                return Err(format!("invalid number: '{}'", num_str));
            }
            u64::MAX
        }
    };

    let multiplier: u64 = match suffix {
        "" => 1,
        "b" => 512,
        "kB" => 1000,
        "k" | "K" | "KiB" => 1024,
        "MB" => 1_000_000,
        "M" | "MiB" => 1_048_576,
        "GB" => 1_000_000_000,
        "G" | "GiB" => 1_073_741_824,
        "TB" => 1_000_000_000_000,
        "T" | "TiB" => 1_099_511_627_776,
        "PB" => 1_000_000_000_000_000,
        "P" | "PiB" => 1_125_899_906_842_624,
        "EB" => 1_000_000_000_000_000_000,
        "E" | "EiB" => 1_152_921_504_606_846_976,
        // These multipliers do not fit in a u64, so the product is either
        // zero or saturated.
        "ZB" | "Z" | "ZiB" | "YB" | "Y" | "YiB" => {
            return Ok(if num > 0 { u64::MAX } else { 0 });
        }
        _ => return Err(format!("invalid suffix in '{}'", s)),
    };

    num.checked_mul(multiplier)
        .ok_or_else(|| format!("number too large: '{}'", s))
}
/// Writes the first `n` `delimiter`-terminated lines of `data` to `out`.
///
/// If `data` holds fewer than `n` delimiters, all of it is written (including
/// any unterminated final line). Nothing is written when `n` is zero or
/// `data` is empty.
pub fn head_lines(data: &[u8], n: u64, delimiter: u8, out: &mut impl Write) -> io::Result<()> {
    if data.is_empty() || n == 0 {
        return Ok(());
    }

    // Count down the delimiters still needed; `end` stays at the full length
    // unless the n-th delimiter is found.
    let mut lines_left = n;
    let mut end = data.len();
    for idx in memchr_iter(delimiter, data) {
        lines_left -= 1;
        if lines_left == 0 {
            end = idx + 1;
            break;
        }
    }
    out.write_all(&data[..end])
}
/// Writes all of `data` except its last `n` lines to `out`.
///
/// A final line without a trailing `delimiter` counts as one line. With
/// `n == 0` the whole input is written; if `data` has `n` lines or fewer,
/// nothing is written.
pub fn head_lines_from_end(
    data: &[u8],
    n: u64,
    delimiter: u8,
    out: &mut impl Write,
) -> io::Result<()> {
    if n == 0 {
        return out.write_all(data);
    }
    let has_unterminated_tail = match data.last() {
        None => return Ok(()),
        Some(&last) => last != delimiter,
    };

    // Walk delimiters from the back. Each one closes the line before it, so
    // once `n` lines have been skipped the next delimiter marks the cut.
    let mut lines_skipped = u64::from(has_unterminated_tail);
    for boundary in memrchr_iter(delimiter, data) {
        if lines_skipped >= n {
            return out.write_all(&data[..=boundary]);
        }
        lines_skipped += 1;
    }
    Ok(())
}
/// Writes the first `n` bytes of `data` to `out`, or all of `data` if it is shorter.
///
/// No write is issued when the resulting prefix is empty.
pub fn head_bytes(data: &[u8], n: u64, out: &mut impl Write) -> io::Result<()> {
    let available = data.len() as u64;
    // `n < available` guarantees `n` fits in `usize`.
    let prefix = if n < available {
        &data[..n as usize]
    } else {
        data
    };
    if prefix.is_empty() {
        return Ok(());
    }
    out.write_all(prefix)
}
/// Writes all of `data` except its last `n` bytes to `out`.
///
/// Nothing is written when `n` is at least the length of `data`.
pub fn head_bytes_from_end(data: &[u8], n: u64, out: &mut impl Write) -> io::Result<()> {
    match (data.len() as u64).checked_sub(n) {
        // `keep` is at most `data.len()`, so it fits in `usize`.
        Some(keep) if keep > 0 => out.write_all(&data[..keep as usize]),
        _ => Ok(()),
    }
}
/// Writes all of `data` to file descriptor 1 using `libc::write` directly.
///
/// Short writes are continued and `EINTR` is retried. A write that reports
/// zero bytes is returned as `ErrorKind::WriteZero`; any other failure is
/// returned as the OS error.
#[cfg(target_os = "linux")]
fn write_all_raw(mut data: &[u8]) -> io::Result<()> {
    loop {
        if data.is_empty() {
            return Ok(());
        }
        // SAFETY: the pointer and length describe the live `data` slice, which
        // `write` only reads from.
        let written = unsafe { libc::write(1, data.as_ptr() as *const libc::c_void, data.len()) };
        match written {
            count if count > 0 => data = &data[count as usize..],
            0 => return Err(io::Error::new(io::ErrorKind::WriteZero, "write returned 0")),
            _ => {
                let err = io::Error::last_os_error();
                if err.kind() != io::ErrorKind::Interrupted {
                    return Err(err);
                }
            }
        }
    }
}
/// Emits the first `n` `delimiter`-terminated lines of `filename` without
/// going through a caller-supplied writer.
///
/// On Linux the file is read in 64 KiB chunks and each chunk is handed to
/// `write_all_raw`. On other platforms the work is delegated to
/// `head_lines_streaming_file` with a buffered, locked stdout.
///
/// Returns `Ok(true)` on success (including `n == 0`, which does not touch the
/// file) and `Ok(false)` after printing a `head: cannot open ...` message to
/// stderr. On Linux, read and write failures after a successful open are
/// returned as `Err`; on other platforms every failure of the streaming helper
/// is reported through that same message.
pub fn head_file_direct(filename: &str, n: u64, delimiter: u8) -> io::Result<bool> {
    if n == 0 {
        return Ok(true);
    }
    let path = Path::new(filename);

    #[cfg(target_os = "linux")]
    {
        use std::os::unix::fs::OpenOptionsExt;
        use std::os::unix::io::AsRawFd;

        // Try O_NOATIME first; fall back to a plain open if that is refused.
        let opened = std::fs::OpenOptions::new()
            .read(true)
            .custom_flags(libc::O_NOATIME)
            .open(path)
            .or_else(|_| std::fs::File::open(path));
        let mut file = match opened {
            Ok(f) => f,
            Err(e) => {
                eprintln!(
                    "head: cannot open '{}' for reading: {}",
                    filename,
                    crate::common::io_error_msg(&e)
                );
                return Ok(false);
            }
        };

        // SAFETY: the descriptor belongs to `file`, which is alive for the
        // whole call; only integers are passed. The result is ignored.
        unsafe {
            libc::posix_fadvise(file.as_raw_fd(), 0, 0, libc::POSIX_FADV_SEQUENTIAL);
        }

        let mut buf = [0u8; 65536];
        let mut lines_left = n;
        loop {
            let len = match file.read(&mut buf) {
                Ok(0) => break,
                Ok(len) => len,
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            };
            let chunk = &buf[..len];

            // Find the delimiter that completes the requested line count, if
            // it lies within this chunk.
            let mut cut = None;
            for pos in memchr_iter(delimiter, chunk) {
                lines_left -= 1;
                if lines_left == 0 {
                    cut = Some(pos);
                    break;
                }
            }
            match cut {
                Some(pos) => {
                    write_all_raw(&chunk[..=pos])?;
                    return Ok(true);
                }
                None => write_all_raw(chunk)?,
            }
        }
        return Ok(true);
    }

    #[cfg(not(target_os = "linux"))]
    {
        let stdout = io::stdout();
        let mut out = io::BufWriter::with_capacity(8192, stdout.lock());
        let outcome = head_lines_streaming_file(path, n, delimiter, &mut out);
        match outcome {
            Ok(done) => {
                if done {
                    out.flush()?;
                }
                Ok(done)
            }
            Err(e) => {
                eprintln!(
                    "head: cannot open '{}' for reading: {}",
                    filename,
                    crate::common::io_error_msg(&e)
                );
                Ok(false)
            }
        }
    }
}
/// Copies the first `n` bytes of the file at `path` to `out_fd`, using
/// `sendfile(2)` where possible.
///
/// The size reported by `metadata()` is only trusted for non-empty regular
/// files. Anything else (character devices, FIFOs, procfs/sysfs entries and
/// other files that report a length of 0) is copied with a read/write loop
/// bounded by `n`, because clamping to a reported size of 0 would emit nothing
/// for inputs such as `/dev/urandom` or `/proc/cpuinfo`.
///
/// If the very first `sendfile` call fails with `EINVAL`, the same read/write
/// loop is used instead.
///
/// NOTE: the read/write fallback emits through `write_all_raw`, not through
/// `out_fd`.
///
/// # Errors
///
/// Returns any error from opening the file, querying its metadata, reading,
/// writing, or `sendfile` (other than the handled `EINTR`/initial `EINVAL`).
#[cfg(target_os = "linux")]
pub fn sendfile_bytes(path: &Path, n: u64, out_fd: i32) -> io::Result<bool> {
    use std::os::unix::fs::OpenOptionsExt;
    use std::os::unix::io::AsRawFd;

    /// Copies at most `limit` bytes from `file` with plain `read` calls,
    /// stopping early at end of input.
    fn copy_with_read(file: &mut std::fs::File, limit: u64) -> io::Result<()> {
        let mut buf = [0u8; 65536];
        let mut left = limit;
        while left > 0 {
            // Clamp in u64 first so the narrowing cast cannot truncate.
            let to_read = left.min(buf.len() as u64) as usize;
            let nr = match file.read(&mut buf[..to_read]) {
                Ok(0) => break,
                Ok(nr) => nr,
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            };
            write_all_raw(&buf[..nr])?;
            left -= nr as u64;
        }
        Ok(())
    }

    // Try O_NOATIME first; fall back to a plain open if that is refused.
    let mut file = std::fs::OpenOptions::new()
        .read(true)
        .custom_flags(libc::O_NOATIME)
        .open(path)
        .or_else(|_| std::fs::File::open(path))?;

    // SAFETY: the descriptor belongs to `file`, which is alive for the whole
    // call; only integers are passed. The result is ignored.
    unsafe {
        libc::posix_fadvise(file.as_raw_fd(), 0, 0, libc::POSIX_FADV_SEQUENTIAL);
    }

    let metadata = file.metadata()?;
    if n == 0 {
        return Ok(true);
    }

    let file_size = metadata.len();
    if !metadata.is_file() || file_size == 0 {
        // The reported size says nothing about how much can be read here.
        copy_with_read(&mut file, n)?;
        return Ok(true);
    }

    let in_fd = file.as_raw_fd();
    let total = n.min(file_size);
    let mut remaining = total;
    // An explicit offset leaves the file's own read position at 0, which the
    // EINVAL fallback below relies on.
    let mut offset: libc::off_t = 0;
    while remaining > 0 {
        // A single sendfile call transfers at most 0x7ffff000 bytes (see sendfile(2)).
        let chunk = remaining.min(0x7fff_f000) as usize;
        // SAFETY: both descriptors are open for the duration of the call and
        // `offset` is a valid, exclusively borrowed `off_t`.
        let ret = unsafe { libc::sendfile(out_fd, in_fd, &mut offset, chunk) };
        if ret > 0 {
            remaining -= ret as u64;
        } else if ret == 0 {
            // End of input reached before `total` bytes (e.g. the file shrank).
            break;
        } else {
            let err = io::Error::last_os_error();
            if err.kind() == io::ErrorKind::Interrupted {
                continue;
            }
            // Fall back only if nothing has been sent yet, so no output is duplicated.
            if err.raw_os_error() == Some(libc::EINVAL) && remaining == total {
                copy_with_read(&mut file, total)?;
                return Ok(true);
            }
            return Err(err);
        }
    }
    Ok(true)
}
/// Streams the first `n` `delimiter`-terminated lines of the file at `path`
/// into `out`, reading in 64 KiB chunks.
///
/// Returns `Ok(true)` once the lines have been written or the file ended
/// first. With `n == 0` the file is not opened at all. Open, read and write
/// failures are all returned as `Err`.
fn head_lines_streaming_file(
    path: &Path,
    n: u64,
    delimiter: u8,
    out: &mut impl Write,
) -> io::Result<bool> {
    if n == 0 {
        return Ok(true);
    }

    #[cfg(target_os = "linux")]
    let mut file = {
        use std::os::unix::fs::OpenOptionsExt;
        use std::os::unix::io::AsRawFd;

        // Try O_NOATIME first; fall back to a plain open if that is refused.
        let opened = std::fs::OpenOptions::new()
            .read(true)
            .custom_flags(libc::O_NOATIME)
            .open(path)
            .or_else(|_| std::fs::File::open(path))?;
        // SAFETY: the descriptor belongs to `opened`, which is alive here;
        // only integers are passed. The result is ignored.
        unsafe {
            libc::posix_fadvise(opened.as_raw_fd(), 0, 0, libc::POSIX_FADV_SEQUENTIAL);
        }
        opened
    };
    #[cfg(not(target_os = "linux"))]
    let mut file = std::fs::File::open(path)?;

    let mut buf = [0u8; 65536];
    let mut lines_left = n;
    loop {
        let len = match file.read(&mut buf) {
            Ok(0) => return Ok(true),
            Ok(len) => len,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        let chunk = &buf[..len];

        // Find the delimiter that completes the requested line count, if it
        // lies within this chunk.
        let mut cut = None;
        for pos in memchr_iter(delimiter, chunk) {
            lines_left -= 1;
            if lines_left == 0 {
                cut = Some(pos);
                break;
            }
        }
        match cut {
            Some(pos) => {
                out.write_all(&chunk[..=pos])?;
                return Ok(true);
            }
            None => out.write_all(chunk)?,
        }
    }
}
/// Runs `head` on one input (`filename`, or standard input when it is `"-"`)
/// according to `config`.
///
/// Fast paths:
/// * stdin with `Lines`/`Bytes` is streamed straight into `out`;
/// * a file with `Lines` is streamed into `out`;
/// * a file with `Bytes` uses `sendfile_bytes` on Linux and
///   `head_bytes_streaming_file` elsewhere.
///
/// Everything else (the `*FromEnd` modes, or a fast path that did not report
/// success) reads the whole input into memory and slices it.
///
/// Returns `Ok(true)` on success and `Ok(false)` after printing a diagnostic
/// prefixed with `tool_name` to stderr. A `BrokenPipe` while streaming counts
/// as success. Write errors on the in-memory path are returned as `Err`.
pub fn head_file(
    filename: &str,
    config: &HeadConfig,
    out: &mut impl Write,
    tool_name: &str,
) -> io::Result<bool> {
    let delimiter = if config.zero_terminated { b'\0' } else { b'\n' };
    let is_stdin = filename == "-";

    if is_stdin {
        let streamed = match &config.mode {
            HeadMode::Lines(n) => Some(head_stdin_lines_streaming(*n, delimiter, out)),
            HeadMode::Bytes(n) => Some(head_stdin_bytes_streaming(*n, out)),
            // The from-end modes need the whole input; handled below.
            HeadMode::LinesFromEnd(_) | HeadMode::BytesFromEnd(_) => None,
        };
        if let Some(result) = streamed {
            return match result {
                Ok(()) => Ok(true),
                Err(e) if e.kind() == io::ErrorKind::BrokenPipe => Ok(true),
                Err(e) => {
                    eprintln!(
                        "{}: standard input: {}",
                        tool_name,
                        crate::common::io_error_msg(&e)
                    );
                    Ok(false)
                }
            };
        }
    } else {
        let path = Path::new(filename);
        match &config.mode {
            HeadMode::Lines(n) => match head_lines_streaming_file(path, *n, delimiter, out) {
                Ok(true) => return Ok(true),
                Ok(false) => {}
                // A closed output pipe is not a failure to open the input.
                // Treat it as the stdin path does instead of printing a
                // misleading "cannot open" diagnostic.
                Err(e) if e.kind() == io::ErrorKind::BrokenPipe => return Ok(true),
                Err(e) => {
                    eprintln!(
                        "{}: cannot open '{}' for reading: {}",
                        tool_name,
                        filename,
                        crate::common::io_error_msg(&e)
                    );
                    return Ok(false);
                }
            },
            HeadMode::Bytes(n) => {
                #[cfg(target_os = "linux")]
                {
                    use std::os::unix::io::AsRawFd;
                    // NOTE(review): this path writes to the stdout descriptor
                    // directly, bypassing `out`; anything still buffered in
                    // `out` is presumably flushed by the caller first. Confirm.
                    let stdout = io::stdout();
                    let out_fd = stdout.as_raw_fd();
                    if let Ok(true) = sendfile_bytes(path, *n, out_fd) {
                        return Ok(true);
                    }
                }
                #[cfg(not(target_os = "linux"))]
                {
                    if let Ok(true) = head_bytes_streaming_file(path, *n, out) {
                        return Ok(true);
                    }
                }
            }
            // The from-end modes need the whole input; handled below.
            HeadMode::LinesFromEnd(_) | HeadMode::BytesFromEnd(_) => {}
        }
    }

    // Slow path: load the whole input, then slice it.
    let data: FileData = if is_stdin {
        match read_stdin() {
            Ok(d) => FileData::Owned(d),
            Err(e) => {
                eprintln!(
                    "{}: standard input: {}",
                    tool_name,
                    crate::common::io_error_msg(&e)
                );
                return Ok(false);
            }
        }
    } else {
        match read_file(Path::new(filename)) {
            Ok(d) => d,
            Err(e) => {
                eprintln!(
                    "{}: cannot open '{}' for reading: {}",
                    tool_name,
                    filename,
                    crate::common::io_error_msg(&e)
                );
                return Ok(false);
            }
        }
    };

    match &config.mode {
        HeadMode::Lines(n) => head_lines(&data, *n, delimiter, out)?,
        HeadMode::LinesFromEnd(n) => head_lines_from_end(&data, *n, delimiter, out)?,
        HeadMode::Bytes(n) => head_bytes(&data, *n, out)?,
        HeadMode::BytesFromEnd(n) => head_bytes_from_end(&data, *n, out)?,
    }
    Ok(true)
}
/// Streams the first `n` bytes of the file at `path` into `out`, reading in
/// 64 KiB chunks, and returns `Ok(true)` when done (or when the file ends
/// first).
///
/// # Errors
///
/// Returns any error from opening or reading the file, or from writing to `out`.
#[cfg(not(target_os = "linux"))]
fn head_bytes_streaming_file(path: &Path, n: u64, out: &mut impl Write) -> io::Result<bool> {
    let mut file = std::fs::File::open(path)?;
    let mut buf = [0u8; 65536];
    // Keep the count in u64: casting `n` to usize up front would truncate on
    // 32-bit targets (e.g. 2^32 would become 0 and nothing would be written).
    let mut remaining = n;
    while remaining > 0 {
        let to_read = remaining.min(buf.len() as u64) as usize;
        let bytes_read = match file.read(&mut buf[..to_read]) {
            Ok(0) => break,
            Ok(len) => len,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        out.write_all(&buf[..bytes_read])?;
        remaining -= bytes_read as u64;
    }
    Ok(true)
}
/// Streams the first `n` `delimiter`-terminated lines from standard input
/// into `out`, reading in 256 KiB chunks.
///
/// Returns without touching stdin when `n == 0`. If the input ends before `n`
/// delimiters are seen, everything read is written.
pub fn head_stdin_lines_streaming(n: u64, delimiter: u8, out: &mut impl Write) -> io::Result<()> {
    if n == 0 {
        return Ok(());
    }

    let stdin = io::stdin();
    let mut input = stdin.lock();
    let mut buf = [0u8; 262144];
    let mut lines_left = n;
    loop {
        let len = match input.read(&mut buf) {
            Ok(0) => return Ok(()),
            Ok(len) => len,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        let chunk = &buf[..len];

        // Find the delimiter that completes the requested line count, if it
        // lies within this chunk.
        let mut cut = None;
        for pos in memchr_iter(delimiter, chunk) {
            lines_left -= 1;
            if lines_left == 0 {
                cut = Some(pos);
                break;
            }
        }
        match cut {
            Some(pos) => return out.write_all(&chunk[..=pos]),
            None => out.write_all(chunk)?,
        }
    }
}
/// Streams the first `n` bytes from standard input into `out`, reading in
/// chunks of up to 256 KiB.
///
/// Returns without touching stdin when `n == 0`, and stops early if the input
/// ends before `n` bytes were read.
///
/// # Errors
///
/// Returns any error from reading stdin (other than a retried interruption)
/// or from writing to `out`.
fn head_stdin_bytes_streaming(n: u64, out: &mut impl Write) -> io::Result<()> {
    if n == 0 {
        return Ok(());
    }
    let stdin = io::stdin();
    let mut reader = stdin.lock();
    let mut buf = [0u8; 262144];
    let mut remaining = n;
    while remaining > 0 {
        // Clamp in u64 before narrowing: `remaining as usize` alone truncates
        // on 32-bit targets and could produce a zero-length read that looks
        // like end of input.
        let to_read = remaining.min(buf.len() as u64) as usize;
        let bytes_read = match reader.read(&mut buf[..to_read]) {
            Ok(0) => break,
            Ok(len) => len,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        out.write_all(&buf[..bytes_read])?;
        remaining -= bytes_read as u64;
    }
    Ok(())
}