#pragma once
#include "utils/Buffer.h"
#include <atomic>
#include <cassert>
#include <cstddef>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <mutex>
#include <queue>
namespace pzstd {
/**
 * Mutex-protected FIFO work queue with an optional capacity limit.
 *
 * Producers call `push()`, consumers call `pop()`.  Once `finish()` has been
 * called no further items are accepted; consumers may still drain whatever is
 * left in the queue.
 */
template <typename T>
class WorkQueue {
  // Guards every member below.
  std::mutex mutex_;
  // Signalled when an item becomes available or the queue is finished.
  std::condition_variable readerCv_;
  // Signalled when space may be available, the limit changes, or on finish.
  std::condition_variable writerCv_;
  // Signalled when `finish()` is called.
  std::condition_variable finishCv_;

  std::queue<T> queue_;
  bool done_ = false;
  std::size_t maxSize_;

  /// True when a capacity limit is set and reached.  Caller must hold `mutex_`.
  bool full() const {
    return maxSize_ != 0 && queue_.size() >= maxSize_;
  }

 public:
  /**
   * Creates an empty queue.
   *
   * @param maxSize  Capacity limit; `0` means no limit.
   */
  WorkQueue(std::size_t maxSize = 0) : maxSize_(maxSize) {}

  /**
   * Appends `item` to the queue, blocking while the queue is full, then wakes
   * one waiting reader.
   *
   * @param item  Item to enqueue.  Left untouched when `false` is returned.
   * @returns     `true` if the item was enqueued, `false` if `finish()` has
   *              been called.
   */
  bool push(T&& item) {
    std::unique_lock<std::mutex> guard(mutex_);
    writerCv_.wait(guard, [this] { return done_ || !full(); });
    if (done_) {
      return false;
    }
    queue_.push(std::move(item));
    // Release the lock before notifying so the woken reader can take it.
    guard.unlock();
    readerCv_.notify_one();
    return true;
  }

  /**
   * Removes the oldest item, blocking until one is available or `finish()`
   * has been called, then wakes one waiting writer.
   *
   * @param[out] item  Receives the dequeued item; not assigned when `false`
   *                   is returned.
   * @returns          `true` if an item was dequeued, `false` if the queue is
   *                   empty and finished.
   */
  bool pop(T& item) {
    std::unique_lock<std::mutex> guard(mutex_);
    readerCv_.wait(guard, [this] { return done_ || !queue_.empty(); });
    if (queue_.empty()) {
      // The wait only ends with an empty queue when the queue is finished.
      assert(done_);
      return false;
    }
    item = std::move(queue_.front());
    queue_.pop();
    guard.unlock();
    writerCv_.notify_one();
    return true;
  }

  /**
   * Replaces the capacity limit and wakes all blocked writers so they
   * re-evaluate it.
   *
   * @param maxSize  New capacity limit; `0` means no limit.
   */
  void setMaxSize(std::size_t maxSize) {
    std::unique_lock<std::mutex> guard(mutex_);
    maxSize_ = maxSize;
    guard.unlock();
    writerCv_.notify_all();
  }

  /**
   * Marks the queue as finished and wakes every waiting reader, writer and
   * `waitUntilFinished()` caller.  Asserts that it has not been called before.
   */
  void finish() {
    std::unique_lock<std::mutex> guard(mutex_);
    assert(!done_);
    done_ = true;
    guard.unlock();
    readerCv_.notify_all();
    writerCv_.notify_all();
    finishCv_.notify_all();
  }

  /// Blocks until `finish()` has been called; the queue may still hold items.
  void waitUntilFinished() {
    std::unique_lock<std::mutex> guard(mutex_);
    finishCv_.wait(guard, [this] { return done_; });
  }
};
/**
 * Work queue of `Buffer`s that also keeps a running total of `Buffer::size()`
 * over the buffers currently held in the queue.
 */
class BufferWorkQueue {
  WorkQueue<Buffer> queue_;
  // Sum of `size()` of every buffer pushed and not yet popped.
  std::atomic<std::size_t> size_;

 public:
  /**
   * Creates an empty queue.
   *
   * @param maxSize  Forwarded to the underlying `WorkQueue` as its limit.
   */
  BufferWorkQueue(std::size_t maxSize = 0) : queue_(maxSize), size_(0) {}

  /**
   * Pushes `buffer` onto the queue and adds its size to the running total.
   * If the underlying queue rejects the buffer (its `push()` returns false),
   * the total is left unchanged.
   *
   * @param buffer  Buffer to enqueue.
   */
  void push(Buffer buffer) {
    // Read the size up front: `buffer` is handed over to the queue below.
    const std::size_t bytes = buffer.size();
    // Count before enqueueing so a concurrent `pop()` never subtracts a size
    // that has not been added yet.
    size_.fetch_add(bytes);
    if (!queue_.push(std::move(buffer))) {
      // The buffer never entered the queue, so undo the accounting; otherwise
      // `size()` would report a buffer that is not there.
      size_.fetch_sub(bytes);
    }
  }

  /**
   * Pops a buffer from the underlying queue and, on success, subtracts its
   * size from the running total.
   *
   * @param[out] buffer  Receives the popped buffer when `true` is returned.
   * @returns            The result of the underlying queue's `pop()`.
   */
  bool pop(Buffer& buffer) {
    bool result = queue_.pop(buffer);
    if (result) {
      size_.fetch_sub(buffer.size());
    }
    return result;
  }

  /// Forwards the new limit to the underlying queue.
  void setMaxSize(std::size_t maxSize) {
    queue_.setMaxSize(maxSize);
  }

  /// Marks the underlying queue as finished.
  void finish() {
    queue_.finish();
  }

  /**
   * Blocks until `finish()` has been called on the queue.
   *
   * @returns The running total of the sizes of the buffers in the queue.
   */
  std::size_t size() {
    queue_.waitUntilFinished();
    return size_.load();
  }
};
}