// sha3sum 1.3.2
//
// sha3sum - compute and check SHA3 message digest.
// Documentation
/*
 *  This file is part of sha3sum
 *
 *  sha3sum is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  any later version.
 *
 *  sha3sum is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with sha3sum. If not, see <http://www.gnu.org/licenses/>
 */

//! Binary entry point for `sha3sum`.
//!
//! Wraps the [`sha3`](https://docs.rs/sha3) crate from [RustCrypto/hashes](https://github.com/RustCrypto/hashes).
//! Supports SHA3-224/256/384/512 and the Keccak family, mirroring the interface
//! of the GNU `sha256sum` / `sha512sum` utilities.
//!
//! CLI parsing is handled by [`clap`]. Multiple files are hashed in parallel
//! using one worker thread per logical CPU (`std::thread::available_parallelism`).

use clap::{ArgGroup, Parser};
use sha3::*;
use sha3sum::{
    wrapper::{hash_from_file, hash_from_reader, read_check_file},
    Mode, NO_BREAK_SPACE,
};
use sha3sum::wrapper::Sha3Mode;
use std::io::Error;
use std::path::Path;
use std::process::exit;
use std::sync::mpsc;
use std::thread;
use std::{fs, io};

// License identifier printed by --license
const LICENSE: &str = "GPL-3.0-or-later";
// Exit codes — mirroring the convention of GNU coreutils checksum tools
const EXIT_CODE_OK: i32 = 0;             // success
#[allow(dead_code)]
const EXIT_CODE_NOK: i32 = 1;            // reserved
const EXIT_CODE_WRONG_PARAMETERS: i32 = 2;   // bad / missing CLI arguments
const EXIT_CODE_FILE_ERROR: i32 = 64;        // file could not be opened or read
const EXIT_CODE_HASH_NOT_EQUAL: i32 = 65;    // computed hash differs from reference

/// A single hashing task, including all display parameters.
#[derive(Clone)]
struct HashTask {
    mode: Mode,
    file_name: String,
    hash_algorithm: Sha3Mode,
    is_bsd_display: bool,
    ref_hash: Option<String>,
    is_quiet: bool,
    is_status: bool,
}

/// A worker thread with its task-submission channel.
struct Worker {
    sender: Option<mpsc::Sender<HashTask>>,
}

/// CLI definition
#[derive(Parser)]
#[command(author=env!("CARGO_PKG_AUTHORS"), version=env!("CARGO_PKG_VERSION"), about=env!("CARGO_PKG_DESCRIPTION"), long_about = None)]
#[command(next_line_help = true)]
#[command(group(ArgGroup::new("tag_algo").args(["tag"]).requires("algorithm")))]
#[command(group(ArgGroup::new("hash").args(["check"]).conflicts_with("algorithm")))]
#[command(group(ArgGroup::new("read_mode").args(["text"]).conflicts_with("binary")))]
#[command(group(ArgGroup::new("show_help").args(["license"]).conflicts_with_all(["binary","text","quiet","status","tag","check"])))]
#[command(arg_required_else_help(true))]
struct Cli {
    #[arg(short, long, help = "sha3 algorithm {224 256 384 512 Keccak224 Keccak256 Keccak256Full Keccak384 Keccak512}")]
    algorithm: Option<String>,

    #[arg(short, long, help = "read SHA3 sums and file path from a file and check them")]
    check: Option<String>,

    #[arg(short, long, default_value_t = true, help = "read using Binary mode (default)")]
    binary: bool,

    #[arg(long, help = "don't print OK for each successfully verified file")]
    quiet: bool,

    #[arg(long, help = "don't output anything, status code shows success")]
    status: bool,

    #[arg(long, help = "create a BSD-style checksum")]
    tag: bool,

    #[arg(short, long, help = "read using Text mode")]
    text: bool,

    #[arg(short, long, help = "Prints license information")]
    license: bool,

    #[arg(help = "Displays the check sum for the files")]
    files: Option<Vec<String>>,
}

fn main() {
    let mut exit_code: i32 = EXIT_CODE_WRONG_PARAMETERS;
    let cli = Cli::parse();

    if cli.license {
        println!("The license is {LICENSE}. The gpl.txt file contains the full Text.");
        exit_code = EXIT_CODE_OK;
    }

    let mode = if cli.text { Mode::Text } else { Mode::Binary };

    if let Some(check) = cli.check.as_deref() {
        let result_list = read_check_file(check, cli.status);
        if let Ok(list) = result_list.as_ref() {
            let tasks = list.iter().map(|item| HashTask {
                mode: item.mode,
                file_name: item.file_name.clone(),
                hash_algorithm: item.algorithm,
                is_bsd_display: false,
                ref_hash: Some(item.hash.clone()),
                is_quiet: cli.quiet,
                is_status: cli.status,
            }).collect();
            exit_code = do_hashes(tasks);
        } else {
            if !cli.status {
                eprintln!("{}", result_list.unwrap_err());
            }
            exit_code = EXIT_CODE_FILE_ERROR;
        }
    } else if let Some(algorithm) = cli.algorithm.as_deref() {
        let selected_sha3_mode: Option<Sha3Mode> = match algorithm.to_lowercase().as_str() {
            "224" | "sha3_224"    => Some(Sha3Mode::Sha3_224),
            "256" | "sha3_256"    => Some(Sha3Mode::Sha3_256),
            "384" | "sha3_384"    => Some(Sha3Mode::Sha3_384),
            "512" | "sha3_512"    => Some(Sha3Mode::Sha3_512),
            "keccak224"           => Some(Sha3Mode::Keccak224),
            "keccak256"           => Some(Sha3Mode::Keccak256),
            "keccak256full"       => Some(Sha3Mode::Keccak256Full),
            "keccak384"           => Some(Sha3Mode::Keccak384),
            "keccak512"           => Some(Sha3Mode::Keccak512),
            // SHAKE-128/256 are XOFs (variable-length output) — not yet supported
            "shake128"            => unimplemented!(),
            "shake256"            => unimplemented!(),
            v => { eprintln!("Invalid value for algorithm. {v}"); None }
        };

        if let Some(selected_mode) = selected_sha3_mode {
            if let Some(files) = cli.files.as_deref() {
                let mut all_file: Vec<String> = Vec::new();
                for param in files {
                    let candidate = Path::new(param.as_str());
                    if candidate.is_file() {
                        all_file.push(param.to_string());
                    } else if candidate.is_dir() {
                        if let Ok(entries) = fs::read_dir(param) {
                            entries.filter_map(Result::ok)
                                .filter(|d| d.metadata().unwrap().is_file())
                                .filter(|d| !d.file_name().into_string().unwrap().starts_with('.'))
                                .for_each(|f| all_file.push(f.path().to_str().unwrap().to_string()));
                        }
                    } else {
                        eprintln!("Error file: {param} has been rejected.");
                    }
                }
                let tasks = all_file.into_iter().map(|file| HashTask {
                    mode,
                    file_name: file,
                    hash_algorithm: selected_mode,
                    is_bsd_display: cli.tag,
                    ref_hash: None,
                    is_quiet: false,
                    is_status: false,
                }).collect();
                exit_code = do_hashes(tasks);
            } else {
                // Hash stdin
                let result = match selected_mode {
                    Sha3Mode::Sha3_224     => hash_from_reader::<Sha3_224>(Box::new(io::stdin())),
                    Sha3Mode::Sha3_256     => hash_from_reader::<Sha3_256>(Box::new(io::stdin())),
                    Sha3Mode::Sha3_384     => hash_from_reader::<Sha3_384>(Box::new(io::stdin())),
                    Sha3Mode::Sha3_512     => hash_from_reader::<Sha3_512>(Box::new(io::stdin())),
                    Sha3Mode::Keccak224    => hash_from_reader::<Keccak224>(Box::new(io::stdin())),
                    Sha3Mode::Keccak256    => hash_from_reader::<Keccak256>(Box::new(io::stdin())),
                    Sha3Mode::Keccak384    => hash_from_reader::<Keccak384>(Box::new(io::stdin())),
                    Sha3Mode::Keccak256Full=> hash_from_reader::<Keccak256Full>(Box::new(io::stdin())),
                    Sha3Mode::Keccak512    => hash_from_reader::<Keccak512>(Box::new(io::stdin())),
                    _ => Err(Error::other("Could not determine algorithm.")),
                };
                if let Ok(hash) = result.as_ref() {
                    display_result("-", Some(Mode::Binary), hash.as_str(), None, selected_sha3_mode, false, false);
                    exit_code = EXIT_CODE_OK;
                } else {
                    eprintln!("{}", result.unwrap_err());
                    exit_code = EXIT_CODE_FILE_ERROR;
                }
            }
        } else {
            eprintln!("Invalid parameters");
            exit_code = EXIT_CODE_WRONG_PARAMETERS;
        }
    } else if exit_code != EXIT_CODE_OK {
        eprintln!("Nothing to do!\n");
    }

    exit(exit_code);
}

/// Dispatch tasks: single task runs inline; multiple tasks use a thread pool.
fn do_hashes(tasks: Vec<HashTask>) -> i32 {
    let mut result = EXIT_CODE_OK;
    if tasks.is_empty() {
        return EXIT_CODE_FILE_ERROR;
    } else if tasks.len() == 1 {
        return execute_task(tasks.into_iter().next().unwrap());
    }

    let (result_tx, result_rx) = mpsc::channel::<i32>();
    let mut workers: Vec<Worker> = Vec::new();
    create_workers(&mut workers, result_tx);
    let nb_worker = workers.len();
    let nb_task = tasks.len();

    for (i, task) in tasks.into_iter().enumerate() {
        workers[i % nb_worker].sender.as_ref().unwrap()
            .send(task).expect("failed to send task to worker");
    }
    for (done, code) in result_rx.iter().enumerate() {
        if code != EXIT_CODE_OK {
            result += code;
        }
        if done >= nb_task - 1 {
            break;
        }
    }
    result
}

/// Spawn one worker thread per logical CPU.
fn create_workers(workers: &mut Vec<Worker>, result_sender: mpsc::Sender<i32>) {
    let cpus = thread::available_parallelism().map_or(1, |n| n.get());
    for i in 0..cpus {
        let (task_tx, task_rx) = mpsc::channel::<HashTask>();
        let rs = result_sender.clone();
        thread::Builder::new()
            .name(format!("worker-{i}"))
            .spawn(move || {
                for task in task_rx {
                    rs.send(execute_task(task)).unwrap();
                }
            })
            .unwrap_or_else(|_| panic!("failed to spawn worker thread {i}"));
        workers.push(Worker { sender: Some(task_tx) });
    }
}

/// Execute one hashing task and return an exit code.
fn execute_task(task: HashTask) -> i32 {
    let result = match task.hash_algorithm {
        Sha3Mode::Sha3_224     => hash_from_file::<Sha3_224>(    &task.file_name, task.mode, task.is_status),
        Sha3Mode::Sha3_256     => hash_from_file::<Sha3_256>(    &task.file_name, task.mode, task.is_status),
        Sha3Mode::Sha3_384     => hash_from_file::<Sha3_384>(    &task.file_name, task.mode, task.is_status),
        Sha3Mode::Sha3_512     => hash_from_file::<Sha3_512>(    &task.file_name, task.mode, task.is_status),
        Sha3Mode::Keccak224    => hash_from_file::<Keccak224>(   &task.file_name, task.mode, task.is_status),
        Sha3Mode::Keccak256    => hash_from_file::<Keccak256>(   &task.file_name, task.mode, task.is_status),
        Sha3Mode::Keccak384    => hash_from_file::<Keccak384>(   &task.file_name, task.mode, task.is_status),
        Sha3Mode::Keccak256Full=> hash_from_file::<Keccak256Full>(&task.file_name, task.mode, task.is_status),
        Sha3Mode::Keccak512    => hash_from_file::<Keccak512>(   &task.file_name, task.mode, task.is_status),
        _ => Err(Error::other("Could not determine algorithm.")),
    };
    match result {
        Ok(hash) => {
            let display_mode = if task.is_bsd_display { None } else { Some(task.mode) };
            display_result(&task.file_name, display_mode, &hash, task.ref_hash.as_ref(), Some(task.hash_algorithm), task.is_quiet, task.is_status);
            if let Some(ref_hash) = &task.ref_hash {
                if !ref_hash.eq_ignore_ascii_case(&hash) {
                    return EXIT_CODE_HASH_NOT_EQUAL;
                }
            }
            EXIT_CODE_OK
        }
        Err(e) => {
            if !task.is_status {
                eprintln!("{e}");
            }
            EXIT_CODE_FILE_ERROR
        }
    }
}

/// Format and print a result line in one of four formats:
///
/// - BSD tag:    `ALGO (file) = hash`
/// - Binary:     `hash *file`
/// - Text:       `hash file`
/// - Check:      `file Ok` / `file NOk`
fn display_result(
    file_name: &str,
    mode: Option<Mode>,
    hash: &str,
    ref_hash: Option<&String>,
    algorithm: Option<Sha3Mode>,
    is_quiet: bool,
    is_status: bool,
) {
    let text = if let Some(the_ref) = ref_hash {
        let ok = the_ref.eq_ignore_ascii_case(hash);
        if is_status || (is_quiet && ok) {
            None
        } else {
            let label = if ok { "Ok" } else { "NOk" };
            Some(format!("{file_name}{NO_BREAK_SPACE}{label}"))
        }
    } else if let Some(the_mode) = mode {
        Some(match the_mode {
            Mode::Binary => format!("{hash}{NO_BREAK_SPACE}*{file_name}"),
            Mode::Text   => format!("{hash}{NO_BREAK_SPACE}{file_name}"),
        })
    } else {
        Some(format!(
            "{}{NO_BREAK_SPACE}({file_name}){NO_BREAK_SPACE}={NO_BREAK_SPACE}{hash}",
            algorithm.unwrap()
        ))
    };
    if let Some(t) = text {
        println!("{t}");
    }
}

#[test]
fn verify_cli() {
    use clap::CommandFactory;
    Cli::command().debug_assert()
}