// rsmp4decrypt 0.2.0
//
// Rust bindings and a CLI for Bento4 mp4decrypt.
#include "rsmp4decrypt.h"

#include "Ap4.h"

#include <cstdlib>
#include <cstring>

// State behind the opaque context handle exposed through the C API.
// rsmp4decrypt_context_new() fills key_map, and every decrypt call made with
// the handle passes a pointer to this same map to the decrypting processor.
struct RsMp4DecryptContext {
  AP4_ProtectionKeyMap key_map{};
};

class RustProgressListener : public AP4_Processor::ProgressListener {
public:
  RustProgressListener(rsmp4decrypt_progress_callback callback, void *user_data)
      : callback_(callback), user_data_(user_data) {}

  AP4_Result OnProgress(unsigned int step, unsigned int total) override {
    if (callback_) {
      callback_(step, total, user_data_);
    }
    return AP4_SUCCESS;
  }

private:
  rsmp4decrypt_progress_callback callback_;
  void *user_data_;
};

// Picks the Bento4 decrypting processor matching the media's protection.
//
// The stream probed is fragments_info when supplied, otherwise input.
// Detection order:
//   1. ftyp major/compatible brand: OMA DCF, then Marlin IPMP, then PIFF;
//   2. otherwise, sample description 0 of each track, looking for a
//      CENC-family scheme (cenc, cbc1, cens, cbcs);
//   3. otherwise, AP4_StandardDecryptingProcessor.
//
// The returned processor is heap-allocated; the caller deletes it. Parsing
// here moves the probed stream's read position, so the caller must rewind
// that stream before processing.
static AP4_Processor *
create_decrypting_processor(AP4_ProtectionKeyMap *key_map, AP4_ByteStream &input,
                            AP4_ByteStream *fragments_info) {
  AP4_ByteStream &probed_stream = fragments_info ? *fragments_info : input;
  AP4_File probed_file(probed_stream);

  // Does the ftyp box list `brand` as its major or as a compatible brand?
  const auto advertises = [](AP4_FtypAtom *file_type, AP4_UI32 brand) {
    return file_type->GetMajorBrand() == brand ||
           file_type->HasCompatibleBrand(brand);
  };

  AP4_FtypAtom *file_type = probed_file.GetFileType();
  if (file_type != NULL) {
    if (advertises(file_type, AP4_OMA_DCF_BRAND_ODCF)) {
      return new AP4_OmaDcfDecryptingProcessor(key_map);
    }
    if (advertises(file_type, AP4_MARLIN_BRAND_MGSV)) {
      return new AP4_MarlinIpmpDecryptingProcessor(key_map);
    }
    if (advertises(file_type, AP4_PIFF_BRAND)) {
      return new AP4_CencDecryptingProcessor(key_map);
    }
  }

  // No decisive brand: look for a track whose first sample description is
  // protected with a CENC-family scheme.
  AP4_Movie *movie = probed_file.GetMovie();
  if (movie != NULL) {
    for (AP4_List<AP4_Track>::Item *item = movie->GetTracks().FirstItem();
         item != NULL; item = item->GetNext()) {
      AP4_Track *track = item->GetData();
      if (track == NULL) {
        continue;
      }

      AP4_SampleDescription *description = track->GetSampleDescription(0);
      if (description == NULL ||
          description->GetType() != AP4_SampleDescription::TYPE_PROTECTED) {
        continue;
      }

      AP4_ProtectedSampleDescription *protected_description =
          AP4_DYNAMIC_CAST(AP4_ProtectedSampleDescription, description);
      if (protected_description == NULL) {
        continue;
      }

      const AP4_UI32 scheme = protected_description->GetSchemeType();
      const bool is_cenc_family = scheme == AP4_PROTECTION_SCHEME_TYPE_CENC ||
                                  scheme == AP4_PROTECTION_SCHEME_TYPE_CBC1 ||
                                  scheme == AP4_PROTECTION_SCHEME_TYPE_CENS ||
                                  scheme == AP4_PROTECTION_SCHEME_TYPE_CBCS;
      if (is_cenc_family) {
        return new AP4_CencDecryptingProcessor(key_map);
      }
    }
  }

  // Nothing more specific matched.
  return new AP4_StandardDecryptingProcessor(key_map);
}

// Runs one decryption pass from input to output with the keys held by ctx.
//
// fragments_info is optional. When present, it is the stream probed to pick
// the processor, and it is passed to the three-stream Process() overload.
// progress may be NULL, in which case no listener is installed.
//
// Returns AP4_ERROR_INVALID_PARAMETERS for a null ctx, otherwise the result
// of rewinding the probed stream or of the processing itself.
static AP4_Result process_streams(RsMp4DecryptContext *ctx,
                                  AP4_ByteStream &input,
                                  AP4_ByteStream &output,
                                  AP4_ByteStream *fragments_info,
                                  rsmp4decrypt_progress_callback progress,
                                  void *user_data) {
  if (ctx == NULL) {
    return AP4_ERROR_INVALID_PARAMETERS;
  }

  AP4_Processor *processor =
      create_decrypting_processor(&ctx->key_map, input, fragments_info);
  if (processor == NULL) {
    return AP4_ERROR_OUT_OF_MEMORY;
  }

  // create_decrypting_processor() parsed this stream to choose a processor;
  // rewind it so processing starts from the beginning.
  AP4_ByteStream &probed = fragments_info != NULL ? *fragments_info : input;
  AP4_Result result = probed.Seek(0);

  if (AP4_SUCCEEDED(result)) {
    RustProgressListener listener(progress, user_data);
    AP4_Processor::ProgressListener *observer =
        progress != NULL ? &listener : NULL;

    if (fragments_info != NULL) {
      result = processor->Process(input, output, *fragments_info, observer);
    } else {
      result = processor->Process(input, output, observer);
    }
  }

  delete processor;
  return result;
}

// Builds a decryption context from key_count entries in keys.
//
// Each entry is registered either against a track ID or against a KID,
// depending on entry.kind; keys are passed to Bento4 as 16 bytes.
//
// Returns NULL when keys is NULL, key_count is 0, or any entry has an
// unrecognised kind. Otherwise the returned context must be released with
// rsmp4decrypt_context_free().
extern "C" RsMp4DecryptContext *
rsmp4decrypt_context_new(const RsMp4DecryptKeyEntry *keys,
                         unsigned int key_count) {
  if (keys == NULL || key_count == 0) {
    return NULL;
  }

  RsMp4DecryptContext *ctx = new RsMp4DecryptContext();
  const RsMp4DecryptKeyEntry *const last = keys + key_count;
  for (const RsMp4DecryptKeyEntry *entry = keys; entry != last; ++entry) {
    switch (entry->kind) {
    case RSMP4DECRYPT_KEY_KIND_TRACK_ID:
      ctx->key_map.SetKey(entry->track_id, entry->key, 16);
      break;
    case RSMP4DECRYPT_KEY_KIND_KID:
      ctx->key_map.SetKeyForKid(entry->kid, entry->key, 16);
      break;
    default:
      // Unknown key kind: discard everything registered so far.
      delete ctx;
      return NULL;
    }
  }

  return ctx;
}

// Destroys a context created by rsmp4decrypt_context_new(), including its
// key map. Passing NULL does nothing.
extern "C" void rsmp4decrypt_context_free(RsMp4DecryptContext *ctx) {
  if (ctx == NULL) {
    return;
  }
  delete ctx;
}

// Releases a buffer handed out by rsmp4decrypt_decrypt_memory(). Those
// buffers come from malloc(), so they are returned with free(); NULL is
// accepted and ignored.
extern "C" void rsmp4decrypt_free(unsigned char *ptr) {
  std::free(ptr);
}

// Translates a result code returned by the decrypt functions into Bento4's
// textual description. The pointer comes straight from AP4_ResultText() and
// is not allocated here, so it must not be passed to rsmp4decrypt_free().
extern "C" const char *rsmp4decrypt_result_text(int result) {
  const char *const description = AP4_ResultText(result);
  return description;
}

// Decrypts the MP4 at input_path into output_path using the keys in ctx.
//
// fragments_info_path is optional. When given, that file is the one parsed
// to choose the decrypting processor, and it is passed to the processor
// together with the input. progress/user_data are forwarded to the progress
// listener; progress may be NULL.
//
// When stage is non-NULL it receives the step that was running when the call
// returned: RSMP4DECRYPT_STAGE_NONE for argument errors, OPEN_* when a file
// could not be opened, PROCESS once decryption has started.
//
// Returns AP4_SUCCESS or a Bento4 AP4_Result error code.
extern "C" int rsmp4decrypt_decrypt_file(
    RsMp4DecryptContext *ctx, const char *input_path, const char *output_path,
    const char *fragments_info_path, rsmp4decrypt_progress_callback progress,
    void *user_data, unsigned int *stage) {
  if (stage != NULL) {
    *stage = RSMP4DECRYPT_STAGE_NONE;
  }
  if (ctx == NULL || input_path == NULL || output_path == NULL) {
    return AP4_ERROR_INVALID_PARAMETERS;
  }
  // Opening the output in write mode truncates it. Writing onto a file that
  // is about to be read would destroy that file before it is parsed. This
  // only catches identically spelled paths, not aliases such as symlinks or
  // relative vs. absolute spellings.
  if (strcmp(input_path, output_path) == 0 ||
      (fragments_info_path != NULL &&
       strcmp(fragments_info_path, output_path) == 0)) {
    return AP4_ERROR_INVALID_PARAMETERS;
  }

  AP4_ByteStream *input = NULL;
  AP4_Result result = AP4_FileByteStream::Create(
      input_path, AP4_FileByteStream::STREAM_MODE_READ, input);
  if (AP4_FAILED(result)) {
    if (stage != NULL) {
      *stage = RSMP4DECRYPT_STAGE_OPEN_INPUT;
    }
    return result;
  }

  // Open every read-side stream before creating the output. A bad
  // fragments-info path then fails without leaving an empty or truncated
  // output file behind.
  AP4_ByteStream *fragments_info = NULL;
  if (fragments_info_path != NULL) {
    result = AP4_FileByteStream::Create(fragments_info_path,
                                        AP4_FileByteStream::STREAM_MODE_READ,
                                        fragments_info);
    if (AP4_FAILED(result)) {
      if (stage != NULL) {
        *stage = RSMP4DECRYPT_STAGE_OPEN_FRAGMENTS_INFO;
      }
      input->Release();
      return result;
    }
  }

  AP4_ByteStream *output = NULL;
  result = AP4_FileByteStream::Create(
      output_path, AP4_FileByteStream::STREAM_MODE_WRITE, output);
  if (AP4_FAILED(result)) {
    if (stage != NULL) {
      *stage = RSMP4DECRYPT_STAGE_OPEN_OUTPUT;
    }
    input->Release();
    if (fragments_info != NULL) {
      fragments_info->Release();
    }
    return result;
  }

  if (stage != NULL) {
    *stage = RSMP4DECRYPT_STAGE_PROCESS;
  }
  result = process_streams(ctx, *input, *output, fragments_info, progress,
                           user_data);

  input->Release();
  output->Release();
  if (fragments_info != NULL) {
    fragments_info->Release();
  }

  return result;
}

// Decrypts an in-memory MP4 using the keys in ctx.
//
// input_data/input_size hold the source bytes and are wrapped in a Bento4
// memory stream. fragments_info_data/fragments_info_size are optional; they
// are used only when the pointer is non-NULL and the size is non-zero.
// progress/user_data are forwarded to the progress listener; progress may
// be NULL.
//
// Out-parameters:
//  - Success with non-empty output: *output_data receives a malloc()'d
//    buffer of *output_size bytes, which the caller releases with
//    rsmp4decrypt_free().
//  - Success with empty output: (*output_data, *output_size) is (NULL, 0).
//  - Any failure: (*output_data, *output_size) is also (NULL, 0), so a
//    caller never sees uninitialised or half-set values.
//
// When stage is non-NULL it receives RSMP4DECRYPT_STAGE_NONE for argument
// errors, PROCESS while decrypting, and COPY_OUTPUT while handing the result
// back (size-range check, allocation, copy).
//
// Returns AP4_SUCCESS or a Bento4 AP4_Result error code.
extern "C" int rsmp4decrypt_decrypt_memory(
    RsMp4DecryptContext *ctx, const unsigned char *input_data,
    unsigned int input_size, const unsigned char *fragments_info_data,
    unsigned int fragments_info_size, rsmp4decrypt_progress_callback progress,
    void *user_data, unsigned char **output_data, unsigned int *output_size,
    unsigned int *stage) {
  if (stage != NULL) {
    *stage = RSMP4DECRYPT_STAGE_NONE;
  }
  if (ctx == NULL || input_data == NULL || output_data == NULL ||
      output_size == NULL) {
    return AP4_ERROR_INVALID_PARAMETERS;
  }

  // Give the out-parameters a defined value before any failure can occur.
  // They are overwritten only once the copy has fully succeeded.
  *output_data = NULL;
  *output_size = 0;

  AP4_ByteStream *input = new AP4_MemoryByteStream(input_data, input_size);
  AP4_ByteStream *fragments_info = NULL;
  if (fragments_info_data != NULL && fragments_info_size > 0) {
    fragments_info =
        new AP4_MemoryByteStream(fragments_info_data, fragments_info_size);
  }

  AP4_MemoryByteStream *output = new AP4_MemoryByteStream();
  if (stage != NULL) {
    *stage = RSMP4DECRYPT_STAGE_PROCESS;
  }
  AP4_Result result = process_streams(ctx, *input, *output, fragments_info,
                                      progress, user_data);

  input->Release();
  if (fragments_info != NULL) {
    fragments_info->Release();
  }

  if (AP4_FAILED(result)) {
    output->Release();
    return result;
  }

  // From here on, failures concern returning the output, not decrypting it.
  if (stage != NULL) {
    *stage = RSMP4DECRYPT_STAGE_COPY_OUTPUT;
  }

  const AP4_LargeSize data_size = output->GetDataSize();
  // The size is reported through an unsigned int, so refuse anything wider.
  if (data_size > 0xFFFFFFFFu) {
    output->Release();
    return AP4_ERROR_OUT_OF_RANGE;
  }

  const unsigned int size = static_cast<unsigned int>(data_size);
  if (size == 0) {
    // Empty result: the out-parameters already hold (NULL, 0).
    output->Release();
    return AP4_SUCCESS;
  }

  unsigned char *buffer = static_cast<unsigned char *>(malloc(size));
  if (buffer == NULL) {
    output->Release();
    return AP4_ERROR_OUT_OF_MEMORY;
  }

  memcpy(buffer, output->GetData(), size);
  output->Release();

  *output_data = buffer;
  *output_size = size;
  return AP4_SUCCESS;
}