#include "platform.h"
#include "Pzstd.h"
#include "SkippableFrame.h"
#include "utils/FileSystem.h"
#include "utils/Range.h"
#include "utils/ScopeGuard.h"
#include "utils/ThreadPool.h"
#include "utils/WorkQueue.h"
#include <chrono>
#include <cinttypes>
#include <cstddef>
#include <cstdio>
#include <memory>
#include <string>
namespace pzstd {
// File-local constant naming the platform's null device: "nul" when _WIN32 is
// defined, "/dev/null" otherwise. openOutputFile() compares the output path
// against it so that this path skips the "file already exists" check.
namespace {
#ifdef _WIN32
const std::string nullOutput = "nul";
#else
const std::string nullOutput = "/dev/null";
#endif
}
using std::size_t;
/// Returns the size of `file`, or 0 when it cannot be determined.
///
/// The name "-" (used by this file for stdin/stdout) is never queried and
/// always yields 0. Any error reported through the `std::error_code` by
/// `file_size` also yields 0.
static std::uintmax_t fileSizeOrZero(const std::string &file) {
  const bool isStandardStream = (file == "-");
  if (!isStandardStream) {
    std::error_code ec;
    const auto bytes = file_size(file, ec);
    if (!ec) {
      return bytes;
    }
  }
  return 0;
}
/// (De)compresses one already-opened input stream into one already-opened
/// output stream and logs a one-line summary on success.
///
/// @param options     Run configuration (thread count, direction, parameters).
/// @param inputFile   Name of the input; "-" is reported as "stdin".
/// @param inputFd     Stream the reader job consumes.
/// @param outputFile  Name of the output; "-" is reported as "stdout".
/// @param outputFd    Stream the calling thread writes to.
/// @param state       Shared error holder, logger and context pools.
/// @return The value returned by writeFile() (bytes written to `outputFd`).
static std::uint64_t handleOneInput(const Options &options,
                                    const std::string &inputFile,
                                    FILE* inputFd,
                                    const std::string &outputFile,
                                    FILE* outputFd,
                                    SharedState& state) {
  auto inputSize = fileSizeOrZero(inputFile);
  // Declared before the thread pools below, so it is destroyed after them.
  WorkQueue<std::shared_ptr<BufferWorkQueue>> outs{options.numThreads + 1};
  std::uint64_t bytesRead;
  std::uint64_t bytesWritten;
  {
    // Worker pool for the (de)compression jobs, plus a single extra thread
    // that runs the reader job.
    ThreadPool executor(options.numThreads);
    ThreadPool readExecutor(1);
    if (options.decompress) {
      readExecutor.add([&state, &outs, &executor, inputFd, &bytesRead] {
        bytesRead = asyncDecompressFrames(state, outs, executor, inputFd);
      });
    } else {
      readExecutor.add(
          [&state, &outs, &executor, inputFd, inputSize, &options, &bytesRead] {
            bytesRead = asyncCompressChunks(
                state,
                outs,
                executor,
                inputFd,
                inputSize,
                options.numThreads,
                options.determineParameters());
          });
    }
    // The calling thread acts as the writer while the reader job runs.
    bytesWritten = writeFile(state, outs, outputFd, options.decompress);
    // Both pools go out of scope here, before bytesRead is inspected below.
  }

  if (state.errorHolder.hasError()) {
    return bytesWritten;
  }

  const std::string inputFileName = inputFile == "-" ? "stdin" : inputFile;
  const std::string outputFileName = outputFile == "-" ? "stdout" : outputFile;
  if (options.decompress) {
    state.log(kLogInfo, "%-20s: %" PRIu64 " bytes \n",
              inputFileName.c_str(), bytesWritten);
  } else {
    // Use 1 as the denominator for empty input to avoid dividing by zero.
    const std::uint64_t denominator = bytesRead != 0 ? bytesRead : 1;
    const double ratio = static_cast<double>(bytesWritten) /
                         static_cast<double>(denominator);
    state.log(kLogInfo, "%-20s :%6.2f%% (%6" PRIu64 " => %6" PRIu64
                        " bytes, %s)\n",
              inputFileName.c_str(), ratio * 100, bytesRead, bytesWritten,
              outputFileName.c_str());
  }
  return bytesWritten;
}
/// Opens `inputFile` for binary reading.
///
/// @param inputFile    Path to open; "-" selects stdin (switched to binary
///                     mode via SET_BINARY_MODE).
/// @param errorHolder  Receives the error message on failure.
/// @return The opened stream, or nullptr if the path is a directory or the
///         file could not be opened (an error is recorded in both cases).
static FILE *openInputFile(const std::string &inputFile,
                           ErrorHolder &errorHolder) {
  if (inputFile == "-") {
    SET_BINARY_MODE(stdin);
    return stdin;
  }
  // Directories are rejected up front with a dedicated message.
  {
    std::error_code ec;
    if (is_directory(inputFile, ec)) {
      // The path being checked is the input, so the message must say so.
      errorHolder.setError("Input file is a directory -- ignored");
      return nullptr;
    }
  }
  auto inputFd = std::fopen(inputFile.c_str(), "rb");
  if (!errorHolder.check(inputFd != nullptr, "Failed to open input file")) {
    return nullptr;
  }
  return inputFd;
}
/// Opens `outputFile` for binary writing, asking for confirmation before
/// overwriting an existing file unless `options.overwrite` is set.
///
/// @param options     Supplies the `overwrite` flag.
/// @param outputFile  Path to open; "-" selects stdout (switched to binary
///                    mode via SET_BINARY_MODE).
/// @param state       Logger used for the prompt and error holder for failures.
/// @return The opened stream, or nullptr after recording an error when the
///         file exists and may not be overwritten, or cannot be opened.
static FILE *openOutputFile(const Options &options,
                            const std::string &outputFile,
                            SharedState& state) {
  if (outputFile == "-") {
    SET_BINARY_MODE(stdout);
    return stdout;
  }
  // The null device never needs an overwrite confirmation.
  if (!options.overwrite && outputFile != nullOutput) {
    // Probe for existence by trying to open the file for reading.
    auto existingFd = std::fopen(outputFile.c_str(), "rb");
    if (existingFd != nullptr) {
      std::fclose(existingFd);
      // If the prompt would not be shown at this log level, refuse outright.
      if (!state.log.logsAt(kLogInfo)) {
        state.errorHolder.setError("Output file exists");
        return nullptr;
      }
      state.log(
          kLogInfo,
          "pzstd: %s already exists; do you wish to overwrite (y/n) ? ",
          outputFile.c_str());
      int c = std::getchar();
      const bool confirmed = (c == 'y' || c == 'Y');
      // Consume the rest of the answer line. Otherwise the leftover '\n'
      // would be read as the answer to the next file's prompt and silently
      // decline it.
      while (c != EOF && c != '\n') {
        c = std::getchar();
      }
      if (!confirmed) {
        state.errorHolder.setError("Not overwritten");
        return nullptr;
      }
    }
  }
  auto outputFd = std::fopen(outputFile.c_str(), "wb");
  if (!state.errorHolder.check(
          outputFd != nullptr, "Failed to open output file")) {
    return nullptr;
  }
  return outputFd;
}
/// Processes every file in `options.inputFiles` in order.
///
/// For each input: open it, derive and open the output, run the
/// (de)compression, and - unless `options.keepSource` is set - close both
/// streams and remove the input. A failure on one file is logged and the
/// loop moves on to the next file.
///
/// @return 0 if every file succeeded, 1 if any file recorded an error.
int pzstdMain(const Options &options) {
  int returnCode = 0;
  SharedState state(options);
  for (const auto& input : options.inputFiles) {
    // Runs at the end of every iteration (including each `continue`) and
    // reports whatever error was recorded for this input.
    auto printErrorGuard = makeScopeGuard([&] {
      if (state.errorHolder.hasError()) {
        returnCode = 1;
        state.log(kLogError, "pzstd: %s: %s.\n", input.c_str(),
                  state.errorHolder.getError().c_str());
      }
    });

    auto inputFd = openInputFile(input, state.errorHolder);
    if (inputFd == nullptr) {
      continue;
    }
    auto closeInputGuard = makeScopeGuard([&] { std::fclose(inputFd); });

    // An empty output name is treated as "cannot derive an output file".
    auto outputFile = options.getOutputFile(input);
    if (!state.errorHolder.check(outputFile != "",
                                 "Input file does not have extension .zst")) {
      continue;
    }
    auto outputFd = openOutputFile(options, outputFile, state);
    if (outputFd == nullptr) {
      continue;
    }
    auto closeOutputGuard = makeScopeGuard([&] { std::fclose(outputFd); });

    handleOneInput(options, input, inputFd, outputFile, outputFd, state);
    if (state.errorHolder.hasError()) {
      continue;
    }

    if (!options.keepSource) {
      // Close both streams explicitly so close failures are detected before
      // the source is deleted. Each guard is dismissed BEFORE its explicit
      // fclose: the stream is gone after fclose whether or not it succeeds,
      // so leaving the guard armed on failure would close it a second time
      // (undefined behaviour).
      closeInputGuard.dismiss();
      if (!state.errorHolder.check(std::fclose(inputFd) == 0,
                                   "Failed to close input file")) {
        continue;
      }
      closeOutputGuard.dismiss();
      if (!state.errorHolder.check(std::fclose(outputFd) == 0,
                                   "Failed to close output file")) {
        continue;
      }
      if (std::remove(input.c_str()) != 0) {
        state.errorHolder.setError("Failed to remove input file");
        continue;
      }
    }
    // NOTE(review): when keepSource is set, the streams are closed by the
    // guards and the fclose results are not checked - confirm this is intended.
  }
  return returnCode;
}
/// Builds a `ZSTD_inBuffer` covering all of `buffer`, with the read position
/// at the start.
static ZSTD_inBuffer makeZstdInBuffer(const Buffer& buffer) {
  ZSTD_inBuffer view{};
  view.src = buffer.data();
  view.size = buffer.size();
  view.pos = 0;
  return view;
}
/// Drops the bytes zstd has already consumed (`inBuffer.pos`) from both
/// `inBuffer` and `buffer`, leaving `inBuffer.pos` at 0.
void advance(Buffer& buffer, ZSTD_inBuffer& inBuffer) {
  const auto consumed = inBuffer.pos;
  const auto* start = static_cast<const unsigned char*>(inBuffer.src);
  inBuffer.src = start + consumed;
  inBuffer.size -= consumed;
  inBuffer.pos = 0;
  buffer.advance(consumed);
}
/// Builds a `ZSTD_outBuffer` covering all of `buffer`, with the write
/// position at the start.
static ZSTD_outBuffer makeZstdOutBuffer(Buffer& buffer) {
  ZSTD_outBuffer view{};
  view.dst = buffer.data();
  view.size = buffer.size();
  view.pos = 0;
  return view;
}
/// Splits the bytes zstd has produced so far (`outBuffer.pos`) off the front
/// of `buffer` and returns them. `outBuffer` is re-pointed at the space that
/// remains and its position is reset to 0.
Buffer split(Buffer& buffer, ZSTD_outBuffer& outBuffer) {
  const auto produced = outBuffer.pos;
  auto* start = static_cast<unsigned char*>(outBuffer.dst);
  outBuffer.dst = start + produced;
  outBuffer.size -= produced;
  outBuffer.pos = 0;
  return buffer.splitAt(produced);
}
/// Compresses everything popped from `in` as one zstd stream and pushes the
/// compressed pieces to `out`.
///
/// @param state         Supplies the error holder and the CStream pool.
/// @param in            Queue of input buffers for this frame.
/// @param out           Queue receiving compressed buffers; `finish()` is
///                      called on it on every exit path.
/// @param maxInputSize  Upper bound on the input size, used to size the
///                      output buffer via ZSTD_compressBound().
static void compress(
    SharedState& state,
    std::shared_ptr<BufferWorkQueue> in,
    std::shared_ptr<BufferWorkQueue> out,
    size_t maxInputSize) {
  auto& errors = state.errorHolder;
  auto finishOutput = makeScopeGuard([&] { out->finish(); });

  // Acquire a compression context and reset it for a new session.
  auto ctx = state.cStreamPool->get();
  if (!errors.check(ctx != nullptr, "Failed to allocate ZSTD_CStream")) {
    return;
  }
  {
    const auto resetCode =
        ZSTD_CCtx_reset(ctx.get(), ZSTD_reset_session_only);
    if (!errors.check(
            !ZSTD_isError(resetCode), ZSTD_getErrorName(resetCode))) {
      return;
    }
  }

  // One output buffer for the whole frame; pieces are split off as they
  // are produced.
  auto outBuffer = Buffer(ZSTD_compressBound(maxInputSize));
  auto zstdOut = makeZstdOutBuffer(outBuffer);

  // Records an error and yields false once the output buffer is exhausted.
  const auto hasOutputSpace = [&] {
    return errors.check(
        !outBuffer.empty(), "ZSTD_compressBound() was too small");
  };

  {
    Buffer pending;
    while (in->pop(pending) && !errors.hasError()) {
      auto zstdIn = makeZstdInBuffer(pending);
      // Feed the popped buffer to zstd until all of it has been consumed.
      while (!pending.empty() && !errors.hasError()) {
        if (!hasOutputSpace()) {
          return;
        }
        const auto code =
            ZSTD_compressStream(ctx.get(), &zstdOut, &zstdIn);
        if (!errors.check(!ZSTD_isError(code), ZSTD_getErrorName(code))) {
          return;
        }
        out->push(split(outBuffer, zstdOut));
        advance(pending, zstdIn);
      }
    }
  }

  // Finish the stream; repeat until zstd reports nothing left to flush.
  for (;;) {
    if (!hasOutputSpace()) {
      return;
    }
    const size_t bytesLeft = ZSTD_endStream(ctx.get(), &zstdOut);
    if (!errors.check(
            !ZSTD_isError(bytesLeft), ZSTD_getErrorName(bytesLeft))) {
      return;
    }
    out->push(split(outBuffer, zstdOut));
    if (bytesLeft == 0 || errors.hasError()) {
      break;
    }
  }
}
static size_t calculateStep(
std::uintmax_t size,
size_t numThreads,
const ZSTD_parameters ¶ms) {
(void)size;
(void)numThreads;
return size_t{1} << (params.cParams.windowLog + 2);
}
namespace {
/// Outcome of inspecting a stream after a read.
enum class FileStatus { Continue, Done, Error };

/// Classifies the state of `fd`: `Done` when its end-of-file indicator is
/// set, otherwise `Error` when its error indicator is set, otherwise
/// `Continue`. End-of-file takes precedence over an error.
FileStatus fileStatus(FILE* fd) {
  const bool reachedEnd = std::feof(fd) != 0;
  if (reachedEnd) {
    return FileStatus::Done;
  }
  return std::ferror(fd) != 0 ? FileStatus::Error : FileStatus::Continue;
}
}
/// Reads up to `size` bytes from `fd` in pieces of at most `chunkSize` bytes
/// and pushes each piece to `queue`.
///
/// @param queue           Destination for the pieces that were read.
/// @param chunkSize       Maximum number of bytes requested per fread().
/// @param size            Total number of bytes to read.
/// @param fd              Stream to read from.
/// @param totalBytesRead  Incremented by the number of bytes actually read.
/// @return `Continue` if `size` bytes were read without hitting end-of-file
///         or an error, otherwise the stream status that stopped the read.
static FileStatus
readData(BufferWorkQueue& queue, size_t chunkSize, size_t size, FILE* fd,
         std::uint64_t *totalBytesRead) {
  Buffer remaining(size);
  auto status = FileStatus::Continue;
  while (status == FileStatus::Continue && !remaining.empty()) {
    const auto wanted = std::min(chunkSize, remaining.size());
    const auto got = std::fread(remaining.data(), 1, wanted, fd);
    *totalBytesRead += got;
    // Hand over exactly the bytes that were read; the rest stays here.
    queue.push(remaining.splitAt(got));
    status = fileStatus(fd);
  }
  return status;
}
/// Reads `fd` to the end, cutting it into fixed-size frames and starting one
/// compression job per frame on `executor`.
///
/// @param state       Shared error holder, logger and context pools.
/// @param chunks      Receives each frame's output queue, in input order;
///                    `finish()` is called on it when this function returns.
/// @param executor    Thread pool that runs the compression jobs.
/// @param fd          Stream to read from.
/// @param size        Input size, forwarded to calculateStep().
/// @param numThreads  Worker count, forwarded to calculateStep().
/// @param params      Compression parameters, forwarded to calculateStep().
/// @return Total number of bytes read from `fd`.
std::uint64_t asyncCompressChunks(
    SharedState& state,
    WorkQueue<std::shared_ptr<BufferWorkQueue>>& chunks,
    ThreadPool& executor,
    FILE* fd,
    std::uintmax_t size,
    size_t numThreads,
    ZSTD_parameters params) {
  auto finishChunks = makeScopeGuard([&] { chunks.finish(); });
  std::uint64_t bytesRead = 0;

  const size_t step = calculateStep(size, numThreads, params);
  state.log(kLogDebug, "Chosen frame size: %zu\n", step);

  auto readStatus = FileStatus::Continue;
  while (readStatus == FileStatus::Continue &&
         !state.errorHolder.hasError()) {
    // Per-frame queues: `in` carries raw input to the job, `out` carries
    // the job's compressed output to the writer.
    auto in = std::make_shared<BufferWorkQueue>();
    auto finishIn = makeScopeGuard([&] { in->finish(); });
    auto out = std::make_shared<BufferWorkQueue>();

    // The job is started before any data is read; it consumes `in` as the
    // data arrives.
    executor.add([&state, in, out, step] {
      compress(state, in, out, step);
    });
    chunks.push(std::move(out));
    state.log(kLogVerbose, "%s\n", "Starting a new frame");

    readStatus = readData(*in, ZSTD_CStreamInSize(), step, fd, &bytesRead);
  }
  state.errorHolder.check(
      readStatus != FileStatus::Error, "Error reading input");
  return bytesRead;
}
/// Decompresses everything popped from `in` and pushes the decompressed
/// pieces to `out`.
///
/// @param state  Supplies the error holder and the DStream pool.
/// @param in     Queue of compressed input buffers.
/// @param out    Queue receiving decompressed buffers; `finish()` is called
///               on it on every exit path.
static void decompress(
    SharedState& state,
    std::shared_ptr<BufferWorkQueue> in,
    std::shared_ptr<BufferWorkQueue> out) {
  auto& errors = state.errorHolder;
  auto finishOutput = makeScopeGuard([&] { out->finish(); });

  // Acquire a decompression context and reset it for a new session.
  auto ctx = state.dStreamPool->get();
  if (!errors.check(ctx != nullptr, "Failed to allocate ZSTD_DStream")) {
    return;
  }
  {
    const auto resetCode =
        ZSTD_DCtx_reset(ctx.get(), ZSTD_reset_session_only);
    if (!errors.check(
            !ZSTD_isError(resetCode), ZSTD_getErrorName(resetCode))) {
      return;
    }
  }

  const size_t outSize = ZSTD_DStreamOutSize();
  // Result of the most recent ZSTD_decompressStream() call.
  size_t lastResult = 0;
  Buffer pending;
  while (in->pop(pending) && !errors.hasError()) {
    auto zstdIn = makeZstdInBuffer(pending);
    // Feed the popped buffer to zstd until all of it has been consumed.
    while (!pending.empty() && !errors.hasError()) {
      // A fresh output buffer is allocated for every call.
      Buffer produced(outSize);
      auto zstdOut = makeZstdOutBuffer(produced);
      lastResult = ZSTD_decompressStream(ctx.get(), &zstdOut, &zstdIn);
      if (!errors.check(
              !ZSTD_isError(lastResult), ZSTD_getErrorName(lastResult))) {
        return;
      }
      out->push(split(produced, zstdOut));
      advance(pending, zstdIn);
      if (lastResult == 0) {
        // A result of 0 is followed by re-initialising the stream before
        // any further input is fed.
        ZSTD_initDStream(ctx.get());
      }
    }
  }

  // Input is exhausted: anything above 1 is reported as truncated input.
  if (!errors.check(lastResult <= 1, "Incomplete block")) {
    return;
  }
  // While the result is exactly 1, keep calling with no input to flush.
  while (lastResult == 1) {
    Buffer produced(outSize);
    auto zstdOut = makeZstdOutBuffer(produced);
    ZSTD_inBuffer noInput{nullptr, 0, 0};
    lastResult = ZSTD_decompressStream(ctx.get(), &zstdOut, &noInput);
    if (!errors.check(
            !ZSTD_isError(lastResult), ZSTD_getErrorName(lastResult))) {
      return;
    }
    out->push(split(produced, zstdOut));
  }
}
/// Reads `fd` and starts decompression jobs on `executor`, pushing each job's
/// output queue to `frames` in input order.
///
/// Each iteration reads SkippableFrame::kSize bytes and passes them to
/// SkippableFrame::tryRead(). A non-zero result is used as the number of
/// bytes to read for that job. A zero result is treated as "input not in
/// pzstd format": the rest of the stream is fed to a single job.
///
/// @param state     Shared error holder, logger and context pools.
/// @param frames    Receives one output queue per started job; `finish()` is
///                  called on it when this function returns.
/// @param executor  Thread pool that runs the decompression jobs.
/// @param fd        Stream to read from.
/// @return Total number of bytes read from `fd`.
std::uint64_t asyncDecompressFrames(
SharedState& state,
WorkQueue<std::shared_ptr<BufferWorkQueue>>& frames,
ThreadPool& executor,
FILE* fd) {
auto framesGuard = makeScopeGuard([&] { frames.finish(); });
std::uint64_t totalBytesRead = 0;
const size_t chunkSize = ZSTD_DStreamInSize();
auto status = FileStatus::Continue;
while (status == FileStatus::Continue && !state.errorHolder.hasError()) {
// Per-job queues: `in` carries compressed bytes to the job, `out` carries
// the job's decompressed output to the writer.
auto in = std::make_shared<BufferWorkQueue>();
auto inGuard = makeScopeGuard([&] { in->finish(); });
auto out = std::make_shared<BufferWorkQueue>();
size_t frameSize;
{
// Read the candidate header and try to decode the next frame's size.
Buffer buffer(SkippableFrame::kSize);
auto bytesRead = std::fread(buffer.data(), 1, buffer.size(), fd);
totalBytesRead += bytesRead;
status = fileStatus(fd);
// Nothing read and the stream is at end-of-file or in error: stop
// without starting another job.
if (bytesRead == 0 && status != FileStatus::Continue) {
break;
}
// Shrink the buffer to the bytes actually read, then forward those bytes
// to the job as well.
buffer.subtract(buffer.size() - bytesRead);
frameSize = SkippableFrame::tryRead(buffer.range());
in->push(std::move(buffer));
}
if (frameSize == 0) {
// Fallback path: a single job will receive the rest of the stream, so
// both of its queues are capped.
in->setMaxSize(64);
out->setMaxSize(64);
}
// The job is started before its input has been read.
executor.add([&state, in, out] {
return decompress(state, std::move(in), std::move(out));
});
frames.push(std::move(out));
if (frameSize == 0) {
state.log(kLogVerbose, "%s\n",
"Input not in pzstd format, falling back to serial decompression");
// Feed the remainder of the stream to this one job, then stop.
while (status == FileStatus::Continue && !state.errorHolder.hasError()) {
status = readData(*in, chunkSize, chunkSize, fd, &totalBytesRead);
}
break;
}
state.log(kLogVerbose, "Decompressing a frame of size %zu", frameSize);
// Read exactly this frame's bytes into the job's input queue.
status = readData(*in, chunkSize, frameSize, fd, &totalBytesRead);
}
state.errorHolder.check(status != FileStatus::Error, "Error reading input");
return totalBytesRead;
}
/// Writes all of `data` to `fd`, retrying after short writes.
///
/// @return true once every byte has been written, false as soon as the
///         stream's error indicator is set.
static bool writeData(ByteRange data, FILE* fd) {
  for (; !data.empty();) {
    const auto written = std::fwrite(data.begin(), 1, data.size(), fd);
    data.advance(written);
    if (std::ferror(fd) != 0) {
      return false;
    }
  }
  return true;
}
/// Pops the per-frame output queues from `outs` in order and writes their
/// contents to `outputFd`.
///
/// @param state       Shared error holder and logger (progress line).
/// @param outs        Queue of per-frame output queues.
/// @param outputFd    Stream to write to.
/// @param decompress  When false, a SkippableFrame built from the frame
///                    queue's `size()` is written ahead of each frame.
/// @return Number of bytes accounted as written. On a write failure an
///         error is recorded and the count so far is returned immediately.
///
/// NOTE(review): the early returns on write failure stop popping from
/// `outs`, unlike the drain-on-error branch at the top of the loop - confirm
/// producers cannot block on a full queue in that case.
std::uint64_t writeFile(
    SharedState& state,
    WorkQueue<std::shared_ptr<BufferWorkQueue>>& outs,
    FILE* outputFd,
    bool decompress) {
  auto& errors = state.errorHolder;
  // Clear the progress line on every exit path.
  auto clearProgressLine = makeScopeGuard([&state] {
    state.log.clear(kLogInfo);
  });

  std::uint64_t total = 0;
  std::shared_ptr<BufferWorkQueue> frameQueue;
  while (outs.pop(frameQueue)) {
    // After an error, keep popping so `outs` is drained, but write nothing.
    if (errors.hasError()) {
      continue;
    }

    const bool compressing = !decompress;
    if (compressing) {
      SkippableFrame frame(frameQueue->size());
      if (!writeData(frame.data(), outputFd)) {
        errors.setError("Failed to write output");
        return total;
      }
      total += frame.kSize;
    }

    Buffer piece;
    while (frameQueue->pop(piece) && !errors.hasError()) {
      if (!writeData(piece.range(), outputFd)) {
        errors.setError("Failed to write output");
        return total;
      }
      total += piece.size();
      const auto megabytes = static_cast<std::uint32_t>(total >> 20);
      state.log.update(kLogInfo, "Written: %u MB ", megabytes);
    }
  }
  return total;
}
}