// rsmp4decrypt 0.2.0
//
// Rust bindings and a CLI for Bento4 mp4decrypt.
#include "rsmp4decrypt.h"

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>

// Maps a RSMP4DECRYPT_STAGE_* value to a short human-readable phrase used in
// error reports. Unrecognised stages fall back to a generic description.
static const char *stage_name(unsigned int stage) {
  static const struct {
    unsigned int id;
    const char *label;
  } kStageLabels[] = {
      {RSMP4DECRYPT_STAGE_OPEN_INPUT, "opening input media"},
      {RSMP4DECRYPT_STAGE_OPEN_OUTPUT, "opening output media"},
      {RSMP4DECRYPT_STAGE_OPEN_FRAGMENTS_INFO, "opening fragments info media"},
      {RSMP4DECRYPT_STAGE_PROCESS, "decrypting media"},
      {RSMP4DECRYPT_STAGE_COPY_OUTPUT, "finalizing decrypted output"},
  };

  const size_t label_count = sizeof(kStageLabels) / sizeof(kStageLabels[0]);
  for (size_t slot = 0; slot < label_count; ++slot) {
    if (kStageLabels[slot].id == stage) {
      return kStageLabels[slot].label;
    }
  }
  return "processing media";
}

// Writes a single tab-separated "RSMP4DECRYPT_WORKER_ERROR" line to stderr.
// Used for argument/setup failures that happen before decryption starts.
// `message` must be a non-null C string (all callers pass literals).
static void print_usage_error(const char *message) {
  FILE *const sink = stderr;
  std::fprintf(sink, "RSMP4DECRYPT_WORKER_ERROR\t%s\n", message);
}

// Reports a failure from the decryption library as one tab-separated line on
// stderr: marker, stage description, numeric result code, and the symbolic
// name returned by rsmp4decrypt_result_text (or "UNKNOWN_BENTO4_ERROR" when
// the library has no name for the code).
static void print_bento4_error(unsigned int stage, int code) {
  const char *const text = rsmp4decrypt_result_text(code);
  const char *const label = (text != NULL) ? text : "UNKNOWN_BENTO4_ERROR";
  const char *const where = stage_name(stage);

  std::fprintf(stderr, "RSMP4DECRYPT_ERROR\t%s\t%d\t%s\n", where, code, label);
}

// Converts one ASCII hexadecimal digit (0-9, a-f, A-F) to its value 0..15.
// Returns -1 for any other character.
static int hex_nibble(char value) {
  if ('0' <= value && value <= '9') {
    return value - '0';
  }

  // OR-ing in 0x20 maps ASCII 'A'-'F' onto 'a'-'f' and leaves 'a'-'f' as-is;
  // no other byte value can land in the 'a'..'f' range this way.
  const int folded = static_cast<unsigned char>(value) | 0x20;
  if ('a' <= folded && folded <= 'f') {
    return folded - 'a' + 10;
  }
  return -1;
}

// Decodes exactly 32 hexadecimal characters into 16 bytes in `output`.
// Returns false if the length is not 32 or any character is not a hex digit.
// On failure, bytes decoded before the bad character may already have been
// written to `output`.
static bool parse_hex_16(const std::string &input, unsigned char output[16]) {
  if (input.size() != 32) {
    return false;
  }

  std::string::const_iterator cursor = input.begin();
  for (unsigned char *byte = output; byte != output + 16; ++byte) {
    const int high = hex_nibble(*cursor++);
    const int low = hex_nibble(*cursor++);
    if (high < 0 || low < 0) {
      return false;
    }
    *byte = static_cast<unsigned char>(high * 16 + low);
  }
  return true;
}

// Parses a non-empty string of ASCII decimal digits into a 32-bit track ID.
// Rejects empty input, signs, whitespace, any non-digit, and values above
// 0xFFFFFFFF. This avoids strtoul, which accepts "" as 0, skips leading
// whitespace, accepts "+"/"-" (with "-1" wrapping), and, where unsigned long
// is 32 bits, clamps overflow to 0xFFFFFFFF instead of reporting it.
static bool parse_track_id(const std::string &text, unsigned int &track_id) {
  if (text.empty()) {
    return false;
  }

  unsigned long long value = 0;
  for (size_t index = 0; index < text.size(); ++index) {
    const char digit = text[index];
    if (digit < '0' || digit > '9') {
      return false;
    }
    value = value * 10 + static_cast<unsigned long long>(digit - '0');
    // Checked every step, so `value` never exceeds 0xFFFFFFFF before the next
    // multiply and cannot overflow unsigned long long.
    if (value > 0xFFFFFFFFull) {
      return false;
    }
  }

  track_id = static_cast<unsigned int>(value);
  return true;
}

// Parses a key specification of the form "<id>:<key>" into `entry`.
//
//   <key> must be exactly 32 hex characters (16 bytes).
//   <id> is either exactly 32 hex characters, stored as a KID
//        (RSMP4DECRYPT_KEY_KIND_KID), or a decimal track ID, stored as
//        RSMP4DECRYPT_KEY_KIND_TRACK_ID.
//
// The string is split at the first ':'. `entry` is zeroed before it is filled.
// Returns false on any malformed component; `entry` may then be partially
// written.
static bool parse_key_spec(const std::string &spec, RsMp4DecryptKeyEntry &entry) {
  const size_t separator = spec.find(':');
  if (separator == std::string::npos) {
    return false;
  }

  const std::string id_text = spec.substr(0, separator);
  const std::string key_text = spec.substr(separator + 1);

  std::memset(&entry, 0, sizeof(entry));
  if (!parse_hex_16(key_text, entry.key)) {
    return false;
  }

  // A 32-character ID is always treated as a hex KID, even if it is all
  // decimal digits.
  if (id_text.size() == 32) {
    entry.kind = RSMP4DECRYPT_KEY_KIND_KID;
    return parse_hex_16(id_text, entry.kid);
  }

  unsigned int track_id = 0;
  if (!parse_track_id(id_text, track_id)) {
    return false;
  }

  entry.kind = RSMP4DECRYPT_KEY_KIND_TRACK_ID;
  entry.track_id = track_id;
  return true;
}

int main(int argc, char **argv) {
  std::vector<std::string> key_specs;
  const char *fragments_info = NULL;
  const char *input = NULL;
  const char *output = NULL;

  for (int index = 1; index < argc; ++index) {
    std::string arg = argv[index];

    if (arg == "--key") {
      if (index + 1 >= argc) {
        print_usage_error("--key requires a value");
        return 2;
      }
      key_specs.push_back(argv[++index]);
      continue;
    }

    if (arg == "--fragments-info") {
      if (index + 1 >= argc) {
        print_usage_error("--fragments-info requires a value");
        return 2;
      }
      fragments_info = argv[++index];
      continue;
    }

    if (arg.rfind("--", 0) == 0) {
      print_usage_error("unknown flag");
      return 2;
    }

    if (input == NULL) {
      input = argv[index];
    } else if (output == NULL) {
      output = argv[index];
    } else {
      print_usage_error("expected only input and output paths");
      return 2;
    }
  }

  if (key_specs.empty()) {
    print_usage_error("at least one key is required");
    return 2;
  }
  if (input == NULL || output == NULL) {
    print_usage_error("input and output paths are required");
    return 2;
  }

  std::vector<RsMp4DecryptKeyEntry> entries(key_specs.size());
  for (size_t index = 0; index < key_specs.size(); ++index) {
    if (!parse_key_spec(key_specs[index], entries[index])) {
      print_usage_error("invalid key specification");
      return 2;
    }
  }

  RsMp4DecryptContext *context =
      rsmp4decrypt_context_new(entries.data(),
                               static_cast<unsigned int>(entries.size()));
  if (context == NULL) {
    print_usage_error("failed to create decryptor context");
    return 2;
  }

  unsigned int stage = RSMP4DECRYPT_STAGE_NONE;
  int result = rsmp4decrypt_decrypt_file(context, input, output, fragments_info,
                                         NULL, NULL, &stage);
  rsmp4decrypt_context_free(context);

  if (result != 0) {
    print_bento4_error(stage, result);
    return 1;
  }

  return 0;
}