#include "rsmp4decrypt.h"
#include "Ap4.h"
#include <cstdlib>
#include <cstring>
// State behind the RsMp4DecryptContext* handle returned by
// rsmp4decrypt_context_new and released by rsmp4decrypt_context_free.
// key_map holds the keys registered at creation time; each decrypt call passes
// it to the decrypting processor it constructs.
struct RsMp4DecryptContext {
AP4_ProtectionKeyMap key_map;
};
// Bridges Bento4 processor progress notifications to the C callback supplied
// by the caller of the FFI entry points. When no callback was supplied every
// notification is silently acknowledged.
class RustProgressListener : public AP4_Processor::ProgressListener {
public:
  // `callback` may be null; `user_data` is forwarded untouched to it.
  RustProgressListener(rsmp4decrypt_progress_callback callback, void *user_data)
      : on_progress_(callback), opaque_(user_data) {}

  // Forwards (step, total) to the callback, if any. Always reports success so
  // that progress reporting can never abort processing.
  AP4_Result OnProgress(unsigned int step, unsigned int total) override {
    if (!on_progress_) {
      return AP4_SUCCESS;
    }
    on_progress_(step, total, opaque_);
    return AP4_SUCCESS;
  }

private:
  rsmp4decrypt_progress_callback on_progress_;
  void *opaque_;
};
static AP4_Processor *
create_decrypting_processor(AP4_ProtectionKeyMap *key_map, AP4_ByteStream &input,
AP4_ByteStream *fragments_info) {
AP4_Processor *processor = NULL;
AP4_File input_file(fragments_info ? *fragments_info : input);
AP4_FtypAtom *ftyp = input_file.GetFileType();
if (ftyp) {
if (ftyp->GetMajorBrand() == AP4_OMA_DCF_BRAND_ODCF ||
ftyp->HasCompatibleBrand(AP4_OMA_DCF_BRAND_ODCF)) {
processor = new AP4_OmaDcfDecryptingProcessor(key_map);
} else if (ftyp->GetMajorBrand() == AP4_MARLIN_BRAND_MGSV ||
ftyp->HasCompatibleBrand(AP4_MARLIN_BRAND_MGSV)) {
processor = new AP4_MarlinIpmpDecryptingProcessor(key_map);
} else if (ftyp->GetMajorBrand() == AP4_PIFF_BRAND ||
ftyp->HasCompatibleBrand(AP4_PIFF_BRAND)) {
processor = new AP4_CencDecryptingProcessor(key_map);
}
}
if (processor == NULL) {
AP4_Movie *movie = input_file.GetMovie();
if (movie) {
AP4_List<AP4_Track> &tracks = movie->GetTracks();
for (unsigned int index = 0; index < tracks.ItemCount(); ++index) {
AP4_Track *track = NULL;
tracks.Get(index, track);
if (track == NULL) {
continue;
}
AP4_SampleDescription *sample_description =
track->GetSampleDescription(0);
if (sample_description == NULL ||
sample_description->GetType() !=
AP4_SampleDescription::TYPE_PROTECTED) {
continue;
}
AP4_ProtectedSampleDescription *protected_sample_description =
AP4_DYNAMIC_CAST(AP4_ProtectedSampleDescription,
sample_description);
if (protected_sample_description == NULL) {
continue;
}
AP4_UI32 scheme = protected_sample_description->GetSchemeType();
if (scheme == AP4_PROTECTION_SCHEME_TYPE_CENC ||
scheme == AP4_PROTECTION_SCHEME_TYPE_CBC1 ||
scheme == AP4_PROTECTION_SCHEME_TYPE_CENS ||
scheme == AP4_PROTECTION_SCHEME_TYPE_CBCS) {
processor = new AP4_CencDecryptingProcessor(key_map);
break;
}
}
}
}
if (processor == NULL) {
processor = new AP4_StandardDecryptingProcessor(key_map);
}
return processor;
}
// Decrypts `input` into `output` using the keys held by `ctx`.
// When `fragments_info` is non-null it is used both to choose the processor and
// as the fragments stream passed to AP4_Processor::Process. `progress` (may be
// null) receives progress notifications together with `user_data`.
// Returns AP4_ERROR_INVALID_PARAMETERS for a null context, otherwise the result
// of rewinding the probed stream or of processing.
static AP4_Result process_streams(RsMp4DecryptContext *ctx,
                                  AP4_ByteStream &input,
                                  AP4_ByteStream &output,
                                  AP4_ByteStream *fragments_info,
                                  rsmp4decrypt_progress_callback progress,
                                  void *user_data) {
  if (!ctx) {
    return AP4_ERROR_INVALID_PARAMETERS;
  }
  AP4_Processor *const decryptor =
      create_decrypting_processor(&ctx->key_map, input, fragments_info);
  if (!decryptor) {
    return AP4_ERROR_OUT_OF_MEMORY;
  }

  // The processor factory parsed this stream, so rewind it before processing.
  AP4_ByteStream &probed = fragments_info ? *fragments_info : input;
  AP4_Result status = probed.Seek(0);
  if (!AP4_FAILED(status)) {
    RustProgressListener bridge(progress, user_data);
    AP4_Processor::ProgressListener *observer = NULL;
    if (progress) {
      observer = &bridge;
    }
    if (fragments_info) {
      status = decryptor->Process(input, output, *fragments_info, observer);
    } else {
      status = decryptor->Process(input, output, observer);
    }
  }

  delete decryptor;
  return status;
}
// Creates a decryption context from `key_count` entries in `keys`.
// Each entry registers a key either for a track ID
// (RSMP4DECRYPT_KEY_KIND_TRACK_ID) or for a KID (RSMP4DECRYPT_KEY_KIND_KID).
// Returns NULL when `keys` is NULL, `key_count` is 0, an entry has an unknown
// kind, or Bento4 rejects a key. Release the result with
// rsmp4decrypt_context_free.
extern "C" RsMp4DecryptContext *
rsmp4decrypt_context_new(const RsMp4DecryptKeyEntry *keys,
                         unsigned int key_count) {
  if (keys == NULL || key_count == 0) {
    return NULL;
  }
  // Every key is registered as 16 bytes; presumably the header declares
  // RsMp4DecryptKeyEntry::key as a 16-byte (AES-128) buffer -- TODO confirm.
  const AP4_Size key_size = 16;
  RsMp4DecryptContext *ctx = new RsMp4DecryptContext();
  for (unsigned int index = 0; index < key_count; ++index) {
    const RsMp4DecryptKeyEntry &entry = keys[index];
    AP4_Result result;
    if (entry.kind == RSMP4DECRYPT_KEY_KIND_TRACK_ID) {
      result = ctx->key_map.SetKey(entry.track_id, entry.key, key_size);
    } else if (entry.kind == RSMP4DECRYPT_KEY_KIND_KID) {
      result = ctx->key_map.SetKeyForKid(entry.kid, entry.key, key_size);
    } else {
      result = AP4_ERROR_INVALID_PARAMETERS;
    }
    // SetKey/SetKeyForKid report failure through AP4_Result; do not hand out a
    // context whose key map is only partially populated.
    if (AP4_FAILED(result)) {
      delete ctx;
      return NULL;
    }
  }
  return ctx;
}
// Destroys a context obtained from rsmp4decrypt_context_new. NULL is accepted
// and ignored.
extern "C" void
rsmp4decrypt_context_free(RsMp4DecryptContext *ctx) {
  if (ctx == NULL) {
    return;
  }
  delete ctx;
}
// Releases a buffer returned through rsmp4decrypt_decrypt_memory's
// `output_data`. That buffer comes from malloc, so it must be freed here with
// the matching allocator rather than by the caller's own runtime. NULL is a
// no-op.
extern "C" void rsmp4decrypt_free(unsigned char *ptr) {
  std::free(ptr);
}
// Maps a Bento4 result code (as returned by the rsmp4decrypt_* functions) to
// the descriptive string provided by AP4_ResultText.
extern "C" const char *
rsmp4decrypt_result_text(int result) {
  const char *description = AP4_ResultText(result);
  return description;
}
// Decrypts the file at `input_path` into `output_path` using the keys in `ctx`.
// `fragments_info_path` (optional) names a file used both to choose the
// decrypting processor and as the fragments stream during processing.
// `progress`/`user_data` (optional) receive progress notifications.
// `stage` (optional) is set to the step that failed:
//   RSMP4DECRYPT_STAGE_OPEN_INPUT / _OPEN_FRAGMENTS_INFO / _OPEN_OUTPUT when a
//   stream cannot be opened, RSMP4DECRYPT_STAGE_PROCESS once decryption starts.
// Returns an AP4_Result code (AP4_SUCCESS on success).
extern "C" int rsmp4decrypt_decrypt_file(
    RsMp4DecryptContext *ctx, const char *input_path, const char *output_path,
    const char *fragments_info_path, rsmp4decrypt_progress_callback progress,
    void *user_data, unsigned int *stage) {
  if (stage != NULL) {
    *stage = RSMP4DECRYPT_STAGE_NONE;
  }
  if (ctx == NULL || input_path == NULL || output_path == NULL) {
    return AP4_ERROR_INVALID_PARAMETERS;
  }

  // Open every source stream before creating the destination. Opening the
  // output first would create/truncate output_path even when the
  // fragments-info file later turns out to be missing.
  AP4_ByteStream *input = NULL;
  AP4_Result result = AP4_FileByteStream::Create(
      input_path, AP4_FileByteStream::STREAM_MODE_READ, input);
  if (AP4_FAILED(result)) {
    if (stage != NULL) {
      *stage = RSMP4DECRYPT_STAGE_OPEN_INPUT;
    }
    return result;
  }

  AP4_ByteStream *fragments_info = NULL;
  if (fragments_info_path != NULL) {
    result = AP4_FileByteStream::Create(
        fragments_info_path, AP4_FileByteStream::STREAM_MODE_READ,
        fragments_info);
    if (AP4_FAILED(result)) {
      if (stage != NULL) {
        *stage = RSMP4DECRYPT_STAGE_OPEN_FRAGMENTS_INFO;
      }
      input->Release();
      return result;
    }
  }

  AP4_ByteStream *output = NULL;
  result = AP4_FileByteStream::Create(
      output_path, AP4_FileByteStream::STREAM_MODE_WRITE, output);
  if (AP4_FAILED(result)) {
    if (stage != NULL) {
      *stage = RSMP4DECRYPT_STAGE_OPEN_OUTPUT;
    }
    if (fragments_info != NULL) {
      fragments_info->Release();
    }
    input->Release();
    return result;
  }

  if (stage != NULL) {
    *stage = RSMP4DECRYPT_STAGE_PROCESS;
  }
  result = process_streams(ctx, *input, *output, fragments_info, progress,
                           user_data);

  input->Release();
  output->Release();
  if (fragments_info != NULL) {
    fragments_info->Release();
  }
  return result;
}
// Decrypts `input_size` bytes at `input_data` using the keys in `ctx`.
// `fragments_info_data`/`fragments_info_size` (optional; ignored when NULL or
// empty) supply a fragments-info stream. `progress`/`user_data` (optional)
// receive progress notifications.
// On success `*output_data` is a malloc'd buffer of `*output_size` bytes that
// must be released with rsmp4decrypt_free. A zero-length result yields
// (NULL, 0). On any failure after argument validation, the outputs are
// (NULL, 0).
// `stage` (optional) is set to RSMP4DECRYPT_STAGE_PROCESS while decrypting and
// to RSMP4DECRYPT_STAGE_COPY_OUTPUT while handing the result back.
// Returns an AP4_Result code (AP4_SUCCESS on success).
extern "C" int rsmp4decrypt_decrypt_memory(
    RsMp4DecryptContext *ctx, const unsigned char *input_data,
    unsigned int input_size, const unsigned char *fragments_info_data,
    unsigned int fragments_info_size, rsmp4decrypt_progress_callback progress,
    void *user_data, unsigned char **output_data, unsigned int *output_size,
    unsigned int *stage) {
  if (stage != NULL) {
    *stage = RSMP4DECRYPT_STAGE_NONE;
  }
  if (ctx == NULL || input_data == NULL || output_data == NULL ||
      output_size == NULL) {
    return AP4_ERROR_INVALID_PARAMETERS;
  }
  // Give the out-parameters a defined value up front so that every failure
  // below hands the caller (NULL, 0) rather than stale or uninitialized data.
  *output_data = NULL;
  *output_size = 0;

  // The const-pointer AP4_MemoryByteStream constructor copies the caller's
  // bytes, so the caller's buffers are never written to.
  AP4_ByteStream *input = new AP4_MemoryByteStream(input_data, input_size);
  AP4_ByteStream *fragments_info = NULL;
  if (fragments_info_data != NULL && fragments_info_size > 0) {
    fragments_info =
        new AP4_MemoryByteStream(fragments_info_data, fragments_info_size);
  }
  AP4_MemoryByteStream *output = new AP4_MemoryByteStream();

  if (stage != NULL) {
    *stage = RSMP4DECRYPT_STAGE_PROCESS;
  }
  AP4_Result result = process_streams(ctx, *input, *output, fragments_info,
                                      progress, user_data);
  input->Release();
  if (fragments_info != NULL) {
    fragments_info->Release();
  }
  if (AP4_FAILED(result)) {
    output->Release();
    return result;
  }

  // From here on, failures concern handing the decrypted bytes back.
  if (stage != NULL) {
    *stage = RSMP4DECRYPT_STAGE_COPY_OUTPUT;
  }
  AP4_LargeSize data_size = output->GetDataSize();
  if (data_size > 0xFFFFFFFFu) {
    // The result cannot be described through an unsigned int size.
    output->Release();
    return AP4_ERROR_OUT_OF_RANGE;
  }
  const unsigned int size = static_cast<unsigned int>(data_size);
  if (size == 0) {
    output->Release();
    return AP4_SUCCESS;
  }
  unsigned char *buffer = static_cast<unsigned char *>(malloc(size));
  if (buffer == NULL) {
    output->Release();
    return AP4_ERROR_OUT_OF_MEMORY;
  }
  memcpy(buffer, output->GetData(), size);
  output->Release();

  // Publish the outputs only once the copy is complete.
  *output_data = buffer;
  *output_size = size;
  return AP4_SUCCESS;
}