#include "rc4_cryption_impl.h"
#include "../../misc.h"
#include <cstring>
using namespace lite;
void RC4Impl::init_rc4_state() {
rc4::RC4RandStream enc_stream(m_enc_key);
rc4::FastHash64 dechash(m_hash_key);
size_t offset = 0;
std::vector<uint64_t> buffer(128);
size_t remaining = m_model_length - sizeof(uint64_t);
while (remaining > 0) {
size_t toread = std::min(remaining, buffer.size() * sizeof(uint64_t));
memcpy(buffer.data(), static_cast<const uint8_t*>(m_model_mem) + offset,
toread);
offset += toread;
remaining -= toread;
for (size_t i = 0; i < toread / sizeof(uint64_t); ++i) {
uint64_t value = buffer[i];
value ^= enc_stream.next64();
dechash.feed(value);
}
}
uint64_t hashvalue;
memcpy(&hashvalue, static_cast<const uint8_t*>(m_model_mem) + offset,
sizeof(hashvalue));
offset += sizeof(hashvalue);
hashvalue ^= dechash.get() ^ enc_stream.next64();
m_state.hash_stream.reset(hashvalue);
m_state.enc_stream.reset(m_enc_key);
}
/*!
 * \brief Decrypt the model buffer byte by byte.
 *
 * Each input byte is XORed with one byte drawn from m_state.hash_stream and
 * one byte drawn from m_state.enc_stream; both streams are (re)seeded by
 * init_rc4_state().
 *
 * \return a vector of exactly m_model_length bytes; no bytes are stripped
 *         from the end of the input.
 */
std::vector<uint8_t> RC4Impl::decrypt_model() {
    const uint8_t* cipher = static_cast<const uint8_t*>(m_model_mem);
    std::vector<uint8_t> plain(m_model_length, 0);

    size_t pos = 0;
    for (uint8_t& out : plain) {
        out = cipher[pos++];
        out ^= m_state.hash_stream.next8() ^ m_state.enc_stream.next8();
    }
    return plain;
}
/*!
 * \brief Encrypt the model buffer and append a protected hash word.
 *
 * Steps:
 *  1. Copy the model into a zero-padded buffer whose size is a multiple of
 *     sizeof(uint64_t).
 *  2. Hash all plain 64-bit words with a FastHash64 seeded by m_hash_key.
 *  3. XOR every word with a key stream seeded by that plain hash, feed the
 *     intermediate word into a second FastHash64 (also seeded by m_hash_key),
 *     then XOR it with a key stream seeded by m_enc_key and store it.
 *  4. Append plain_hash ^ second_hash ^ next word of the m_enc_key stream.
 *
 * \return the padded ciphertext followed by the 8-byte protected hash.
 */
std::vector<uint8_t> RC4Impl::encrypt_model() {
    // Pad to the 64-bit word size the loops below operate on. Padding to
    // sizeof(size_t) would leave an unprocessed tail where size_t is 4 bytes.
    constexpr size_t word_size = sizeof(uint64_t);
    const size_t total_length =
            (m_model_length + (word_size - 1)) / word_size * word_size;

    std::vector<uint8_t> pad_model(total_length, 0);
    if (m_model_length > 0) {
        memcpy(pad_model.data(), m_model_mem, m_model_length);
    }
    const size_t len = total_length / word_size;

    // Words are moved in and out of the byte buffer with memcpy rather than
    // through a uint64_t* alias of uint8_t storage.
    rc4::FastHash64 plainhash(m_hash_key);
    for (size_t i = 0; i < len; ++i) {
        uint64_t word;
        memcpy(&word, pad_model.data() + i * word_size, word_size);
        plainhash.feed(word);
    }
    uint64_t plainhash_value = plainhash.get();

    rc4::RC4RandStream hash_enc(plainhash_value);
    rc4::RC4RandStream outmost_enc(m_enc_key);
    rc4::FastHash64 afterhashenc_hash(m_hash_key);

    for (size_t i = 0; i < len; ++i) {
        uint8_t* const slot = pad_model.data() + i * word_size;
        uint64_t word;
        memcpy(&word, slot, word_size);

        const uint64_t value = word ^ hash_enc.next64();
        afterhashenc_hash.feed(value);

        const uint64_t encrypted = value ^ outmost_enc.next64();
        memcpy(slot, &encrypted, word_size);
    }

    const uint64_t protected_hash =
            plainhash_value ^ afterhashenc_hash.get() ^ outmost_enc.next64();

    // resize() may reallocate, so the destination is taken from data() again.
    pad_model.resize(total_length + word_size);
    memcpy(pad_model.data() + total_length, &protected_hash, word_size);
    return pad_model;
}
/*!
 * \brief Verify the trailing checksum and prepare m_state for
 *        SimpleFastRC4Impl::decrypt_model().
 *
 * The input buffer is read as a payload followed by one trailing 64-bit word.
 * Every complete 64-bit payload word is fed, still encrypted, into a
 * FastHash64 seeded by m_hash_key, and the resulting hash must equal the
 * trailing word. On success m_state.hash_stream is seeded with m_hash_key and
 * m_state.enc_stream with m_enc_key.
 *
 * Payload bytes beyond the last complete 64-bit word (if any) are skipped.
 *
 * An error is reported through LITE_THROW when the buffer is too short to
 * hold the trailing word or when the checksum does not match.
 */
void SimpleFastRC4Impl::init_sfrc4_state() {
    // Without this guard the unsigned subtraction below wraps around and the
    // loop would read far beyond the end of the model buffer.
    if (m_model_length < sizeof(uint64_t)) {
        LITE_THROW(
                "The encrypted model is too short to contain the trailing "
                "checksum.");
    }

    rc4::FastHash64 dechash(m_hash_key);

    const uint8_t* const bytes = static_cast<const uint8_t*>(m_model_mem);
    const size_t payload_length = m_model_length - sizeof(uint64_t);
    const size_t word_count = payload_length / sizeof(uint64_t);

    for (size_t i = 0; i < word_count; ++i) {
        // memcpy keeps the load valid for any alignment of the source buffer.
        uint64_t value;
        memcpy(&value, bytes + i * sizeof(uint64_t), sizeof(value));
        dechash.feed(value);
    }

    // The stored checksum sits directly after the payload.
    uint64_t hashvalue;
    memcpy(&hashvalue, bytes + payload_length, sizeof(hashvalue));

    if (hashvalue != dechash.get())
        LITE_THROW(
                "The checksum of the file cannot be verified. The file may "
                "be encrypted in the wrong algorithm or different keys.");

    m_state.hash_stream.reset(m_hash_key);
    m_state.enc_stream.reset(m_enc_key);
}
/*!
 * \brief Decrypt the model buffer byte by byte.
 *
 * Each input byte is XORed with one byte drawn from m_state.enc_stream, which
 * init_sfrc4_state() seeds with m_enc_key.
 *
 * \return a vector of exactly m_model_length bytes; no bytes are stripped
 *         from the end of the input.
 */
std::vector<uint8_t> SimpleFastRC4Impl::decrypt_model() {
    std::vector<uint8_t> plain(m_model_length, 0);
    const uint8_t* cipher = static_cast<const uint8_t*>(m_model_mem);

    for (uint8_t& byte : plain) {
        byte = *cipher++;
        byte ^= m_state.enc_stream.next8();
    }
    return plain;
}
/*!
 * \brief Encrypt the model buffer and append a checksum of the ciphertext.
 *
 * The model is copied into a zero-padded buffer whose size is a multiple of
 * sizeof(uint64_t). Every 64-bit word is XORed with a key stream seeded by
 * m_enc_key, and the encrypted word is fed into a FastHash64 seeded by
 * m_hash_key. The final hash value is appended after the ciphertext.
 *
 * \return the padded ciphertext followed by the 8-byte checksum.
 */
std::vector<uint8_t> SimpleFastRC4Impl::encrypt_model() {
    // Pad to the 64-bit word size the loop below operates on. Padding to
    // sizeof(size_t) would leave an unprocessed tail where size_t is 4 bytes.
    constexpr size_t word_size = sizeof(uint64_t);
    const size_t total_length =
            (m_model_length + (word_size - 1)) / word_size * word_size;

    std::vector<uint8_t> pad_model(total_length, 0);
    if (m_model_length > 0) {
        memcpy(pad_model.data(), m_model_mem, m_model_length);
    }
    const size_t len = total_length / word_size;

    rc4::FastHash64 enchash(m_hash_key);
    rc4::RC4RandStream out_enc(m_enc_key);

    // Words are moved in and out of the byte buffer with memcpy rather than
    // through a uint64_t* alias of uint8_t storage.
    for (size_t i = 0; i < len; ++i) {
        uint8_t* const slot = pad_model.data() + i * word_size;
        uint64_t word;
        memcpy(&word, slot, word_size);

        word = word ^ out_enc.next64();
        memcpy(slot, &word, word_size);

        // The checksum covers the encrypted words, not the plain ones.
        enchash.feed(word);
    }

    const uint64_t hash_value = enchash.get();

    // resize() may reallocate, so the destination is taken from data() again.
    pad_model.resize(total_length + word_size);
    memcpy(pad_model.data() + total_length, &hash_value, word_size);
    return pad_model;
}