tasign 0.1.2

TA ELF signing utilities with CMS/PKCS#7 support
Documentation
//! 命令行:签名 / 验签 / 写出中间文件。

use std::collections::HashMap;
use std::fs;
use std::path::Path;
use tasign::{
    append_ta_signature_objcopy, build_plain_bin, extract_cms_detached_parts,
    sign_gmssl_cms_attached_native, sign_sm2_cms, verify_elf_signature,
    verify_elf_signature_from_parts, write_outputs,
    CmsSignAlgorithm, SignInputs,
};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    env_logger::Builder::from_default_env().init();

    let args: Vec<String> = std::env::args().collect();
    if args.len() < 2 {
        eprintln!(
            "usage:\n  {} sign --elf P --out-dir D --leaf-cert L [--support-intermediates true|false] [--intermediate-count N] [--intermediate-certs a.der,b.der] [--intermediate-cert I] [--algorithm sm2-sm3|rsa-sha256|ecdsa-sha256] --leaf-key K --leaf-key-pass PASS [--gmssl PATH]]  # 省略 --gmssl 时:$GMSSL → PATH 中 gmssl → ./GmSSL/build/bin/gmssl\n  {} verify --pkcs7 P --plain T [--ca-cert root.pem]\n  {} verify-elf --elf E [--ca-cert root.pem]\n  {} write-elf --input-elf A --signature-bin S --output-elf B [--objcopy /path/to/objcopy]\n  {} plain --elf E --out O\n  {} cms-sign --plain P --leaf-cert L [--support-intermediates true|false] [--intermediate-count N] [--intermediate-certs a.der,b.der] [--intermediate-cert I] [--cms-attached true|false] [--cms-out der|pem] [--cms-compat pkcs7|gmssl] [--algorithm sm2-sm3|rsa-sha256|ecdsa-sha256] --leaf-key K --leaf-key-pass PASS --out O [--gmssl PATH]]\n  {} cms-extract --pkcs7 S --out-signed-attrs A --out-sm2-sig G --out-leaf-cert C",
            args[0], args[0], args[0], args[0], args[0], args[0], args[0]
        );
        std::process::exit(2);
    }

    let flags = parse_flags(&args[2..]);
    match args[1].as_str() {
        "sign" => {
            let elf = flags.get("--elf").cloned().ok_or("missing --elf")?;
            let out_dir = flags.get("--out-dir").cloned().ok_or("missing --out-dir")?;
            let leaf_cert = fs::read(flags.get("--leaf-cert").cloned().ok_or("missing --leaf-cert")?)?;
            let int_certs = load_intermediate_certs(&flags)?;
            let leaf_key = flags.get("--leaf-key").cloned().ok_or("missing --leaf-key")?;
            let leaf_key_pass = flags
                .get("--leaf-key-pass")
                .cloned()
                .ok_or("missing --leaf-key-pass")?;
            let gmssl_cli = flags.get("--gmssl").map(|s| Path::new(s.as_str()));
            let algorithm = parse_algorithm(flags.get("--algorithm").map(|s| s.as_str()))?;
            write_outputs(
                elf.as_ref(),
                out_dir.as_ref(),
                &leaf_cert,
                &int_certs,
                leaf_key.as_ref(),
                &leaf_key_pass,
                gmssl_cli,
                algorithm,
            )?;
            println!("wrote plain.bin and signature.bin to {out_dir}");
        }
        "verify" => {
            let pkcs7 = fs::read(flags.get("--pkcs7").cloned().ok_or("missing --pkcs7")?)?;
            let plain = fs::read(flags.get("--plain").cloned().ok_or("missing --plain")?)?;
            let ca_owned = if let Some(ca_cert) = flags.get("--ca-cert") {
                Some(fs::read(ca_cert)?)
            } else {
                None
            };
            verify_elf_signature_from_parts(&plain, &pkcs7, ca_owned.as_deref())?;
            println!("OK");
        }
        "verify-elf" => {
            let elf = fs::read(flags.get("--elf").cloned().ok_or("missing --elf")?)?;
            let ca_owned = if let Some(ca_cert) = flags.get("--ca-cert") {
                Some(fs::read(ca_cert)?)
            } else {
                None
            };
            verify_elf_signature(&elf, ca_owned.as_deref())?;
            println!("OK");
        }
        "write-elf" => {
            let elf_in = flags.get("--input-elf").cloned().ok_or("missing --input-elf")?;
            let sig = flags.get("--signature-bin").cloned().ok_or("missing --signature-bin")?;
            let elf_out = flags.get("--output-elf").cloned().ok_or("missing --output-elf")?;
            let objcopy = flags.get("--objcopy").map(|s| s.as_str()).unwrap_or("objcopy");
            append_ta_signature_objcopy(elf_in.as_ref(), sig.as_ref(), elf_out.as_ref(), objcopy)?;
            println!("wrote {elf_out}");
        }
        "plain" => {
            let elf = fs::read(flags.get("--elf").cloned().ok_or("missing --elf")?)?;
            let p = build_plain_bin(&elf)?;
            let out = flags.get("--out").cloned().ok_or("missing --out")?;
            fs::write(&out, &p)?;
            println!("wrote {} bytes to {}", p.len(), out);
        }
        "cms-sign" => {
            let plain_path = flags.get("--plain").cloned().ok_or("missing --plain")?;
            let plain = fs::read(&plain_path)?;
            let leaf_cert_path = flags.get("--leaf-cert").cloned().ok_or("missing --leaf-cert")?;
            let int_certs = load_intermediate_certs(&flags)?;
            let leaf_key = flags.get("--leaf-key").cloned().ok_or("missing --leaf-key")?;
            let leaf_key_pass = flags
                .get("--leaf-key-pass")
                .cloned()
                .ok_or("missing --leaf-key-pass")?;
            let gmssl_cli = flags.get("--gmssl").map(|s| Path::new(s.as_str()));
            let int_refs: Vec<&[u8]> = int_certs.iter().map(|v| v.as_slice()).collect();
            let algorithm = parse_algorithm(flags.get("--algorithm").map(|s| s.as_str()))?;
            let cms_attached = flags
                .get("--cms-attached")
                .map(|s| s == "true" || s == "1" || s.eq_ignore_ascii_case("yes"))
                .unwrap_or(false);
            let cms_out = flags.get("--cms-out").map(|s| s.as_str()).unwrap_or("der");
            let cms_compat = flags.get("--cms-compat").map(|s| s.as_str()).unwrap_or("pkcs7");
            let out = flags.get("--out").cloned().ok_or("missing --out")?;

            if cms_compat == "gmssl" {
                if !cms_attached {
                    return Err("cms-compat=gmssl 仅支持 --cms-attached true".into());
                }
                if !int_refs.is_empty() {
                    return Err("cms-compat=gmssl 当前仅支持不带中间证书".into());
                }
                if algorithm != CmsSignAlgorithm::Sm2WithSm3 {
                    return Err("cms-compat=gmssl 仅支持 --algorithm sm2-sm3".into());
                }
                let leaf_cert = fs::read(&leaf_cert_path)?;
                let key_pem = fs::read_to_string(&leaf_key)?;
                let der =
                    sign_gmssl_cms_attached_native(&leaf_cert, &key_pem, &leaf_key_pass, &plain)?;
                match cms_out {
                    "der" => fs::write(&out, &der)?,
                    "pem" => fs::write(&out, der_to_pem("CMS", &der))?,
                    _ => return Err("invalid --cms-out, expected der|pem".into()),
                }
            } else if cms_compat == "pkcs7" {
                let leaf_cert = fs::read(&leaf_cert_path)?;
                let der = sign_sm2_cms(SignInputs {
                    plain: &plain,
                    leaf_cert_der: &leaf_cert,
                    intermediate_certs_der: &int_refs,
                    cms_attached,
                    cms_use_gmssl_oid: false,
                    leaf_key_path: leaf_key.as_ref(),
                    leaf_key_pass: &leaf_key_pass,
                    gmssl_path: gmssl_cli,
                    algorithm,
                })?;
                match cms_out {
                    "der" => fs::write(&out, &der)?,
                    "pem" => fs::write(&out, der_to_pem("CMS", &der))?,
                    _ => return Err("invalid --cms-out, expected der|pem".into()),
                }
            } else {
                return Err("invalid --cms-compat, expected pkcs7|gmssl".into());
            }
            println!("wrote PKCS#7 to {out}");
        }
        "cms-extract" => {
            let pkcs7 = fs::read(flags.get("--pkcs7").cloned().ok_or("missing --pkcs7")?)?;
            let out_attrs = flags
                .get("--out-signed-attrs")
                .cloned()
                .ok_or("missing --out-signed-attrs")?;
            let out_sig = flags
                .get("--out-sm2-sig")
                .cloned()
                .ok_or("missing --out-sm2-sig")?;
            let out_leaf = flags
                .get("--out-leaf-cert")
                .cloned()
                .ok_or("missing --out-leaf-cert")?;
            let x = extract_cms_detached_parts(&pkcs7)?;
            fs::write(out_attrs, x.signed_attrs_der)?;
            fs::write(out_sig, x.signature_der)?;
            fs::write(out_leaf, x.leaf_cert_pem)?;
            println!("cms extracted");
        }
        _ => {
            eprintln!("unknown command");
            std::process::exit(2);
        }
    }

    Ok(())
}

/// Collects `--name value` pairs from `args` into a map keyed by the full flag (`--name`).
///
/// An argument starting with `--` that has a following argument takes that argument
/// verbatim as its value, even if the value also begins with `--`; both are consumed.
/// Arguments that don't start with `--`, and a trailing `--flag` with nothing after it,
/// are skipped. If a flag repeats, the last value wins.
fn parse_flags(args: &[String]) -> HashMap<String, String> {
    let mut flags = HashMap::new();
    let mut rest = args;
    while let Some((head, tail)) = rest.split_first() {
        match tail.split_first() {
            Some((value, after)) if head.starts_with("--") => {
                flags.insert(head.clone(), value.clone());
                rest = after;
            }
            _ => rest = tail,
        }
    }
    flags
}

/// Resolves the intermediate certificates requested on the command line and reads them.
///
/// Flags consulted:
/// * `--support-intermediates` — default true. `1`, or `true`/`yes` in any letter case, mean
///   true; any other value means false. When false, no files are read and any
///   `--intermediate-cert(s)` paths are ignored. A non-zero `--intermediate-count` is still
///   rejected as contradictory.
/// * `--intermediate-cert PATH` — a single certificate file, placed first.
/// * `--intermediate-certs A,B,...` — comma-separated paths. Entries are trimmed, and blank
///   entries are skipped.
/// * `--intermediate-count N` — optional cross-check against the number of paths collected.
///
/// Returns the raw bytes of each file, in order. The usage text suggests DER, but the
/// contents are not validated here.
///
/// # Errors
/// Returns an error in these cases:
/// * `--intermediate-count` is not a valid `usize`.
/// * The count contradicts the settings or the number of paths.
/// * A certificate file cannot be read. The message names the path.
fn load_intermediate_certs(flags: &HashMap<String, String>) -> Result<Vec<Vec<u8>>, Box<dyn std::error::Error>> {
    // Accept `true`/`yes` case-insensitively, so `TRUE` doesn't silently disable intermediates.
    fn is_true(s: &str) -> bool {
        s == "1" || s.eq_ignore_ascii_case("true") || s.eq_ignore_ascii_case("yes")
    }

    let support = flags
        .get("--support-intermediates")
        .map_or(true, |s| is_true(s));

    let expected: Option<usize> = match flags.get("--intermediate-count") {
        Some(raw) => Some(
            raw.parse()
                .map_err(|e| format!("invalid --intermediate-count {raw:?}: {e}"))?,
        ),
        None => None,
    };

    if !support {
        if matches!(expected, Some(n) if n != 0) {
            return Err("support-intermediates=false 但 intermediate-count 非 0".into());
        }
        return Ok(Vec::new());
    }

    // Put the single `--intermediate-cert` first, then the comma-separated list.
    let single = flags.get("--intermediate-cert").map(String::as_str);
    let listed = flags
        .get("--intermediate-certs")
        .into_iter()
        .flat_map(|list| list.split(','))
        .map(str::trim)
        .filter(|p| !p.is_empty());
    let paths: Vec<&str> = single.into_iter().chain(listed).collect();

    if let Some(n) = expected {
        if n != paths.len() {
            return Err(format!(
                "intermediate-count={} 与实际提供证书数量={} 不一致",
                n,
                paths.len()
            )
            .into());
        }
    }

    let mut certs = Vec::with_capacity(paths.len());
    for path in paths {
        let der = fs::read(path)
            .map_err(|e| format!("failed to read intermediate certificate {path}: {e}"))?;
        certs.push(der);
    }
    Ok(certs)
}

/// Wraps `der` in PEM armour.
///
/// The output is a `-----BEGIN {label}-----` line, then the standard-alphabet base64 of
/// `der` split into lines of at most 64 characters, then `-----END {label}-----`. Every
/// line, including the last, ends with `\n`. Empty input produces just the two marker
/// lines.
fn der_to_pem(label: &str, der: &[u8]) -> String {
    use base64::engine::general_purpose::STANDARD;
    use base64::Engine;

    let encoded = STANDARD.encode(der);
    let body: String = encoded
        .as_bytes()
        .chunks(64)
        .map(|line| format!("{}\n", std::str::from_utf8(line).unwrap_or("")))
        .collect();
    format!("-----BEGIN {label}-----\n{body}-----END {label}-----\n")
}

/// Maps an `--algorithm` value to a `CmsSignAlgorithm`.
///
/// An absent value (`None`) selects `sm2-sm3`. Accepted names are `sm2-sm3`, `rsa-sha256`
/// and `ecdsa-sha256`, matched exactly (case-sensitive).
///
/// # Errors
/// Returns an error naming the rejected value when it is not one of the accepted names.
fn parse_algorithm(v: Option<&str>) -> Result<CmsSignAlgorithm, Box<dyn std::error::Error>> {
    let name = v.unwrap_or("sm2-sm3");
    let algorithm = match name {
        "sm2-sm3" => CmsSignAlgorithm::Sm2WithSm3,
        "rsa-sha256" => CmsSignAlgorithm::Rsa2048WithSha256,
        "ecdsa-sha256" => CmsSignAlgorithm::EcdsaWithSha256,
        other => {
            return Err(format!(
                "invalid --algorithm {other}, expected sm2-sm3|rsa-sha256|ecdsa-sha256"
            )
            .into());
        }
    };
    Ok(algorithm)
}