#include "absl/strings/cord_analysis.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <unordered_set>
#include "absl/base/config.h"
#include "absl/base/nullability.h"
#include "absl/strings/internal/cord_data_edge.h"
#include "absl/strings/internal/cord_internal.h"
#include "absl/strings/internal/cord_rep_btree.h"
#include "absl/strings/internal/cord_rep_crc.h"
namespace absl {
ABSL_NAMESPACE_BEGIN
namespace cord_internal {
namespace {
// Selects how node sizes are accumulated while walking a cord.
enum class Mode {
  // Each node's size is weighted by a fraction that is divided by the
  // reference count of the node and of every node above it on the path.
  kFairShare,
  // Each visited node contributes its full size.
  kTotal,
  // Like kTotal, except a given node pointer contributes at most once.
  kTotalMorePrecise,
};
template <Mode mode>
struct CordRepRef {
explicit CordRepRef(absl::Nonnull<const CordRep*> r) : rep(r) {}
CordRepRef Child(absl::Nonnull<const CordRep*> child) const {
return CordRepRef(child);
}
absl::Nonnull<const CordRep*> rep;
};
// Accumulator for the byte total of a cord walk.
//
// This primary template is used for every mode that has no dedicated
// specialization: each reported size is added in full.
template <Mode mode>
struct RawUsage {
  // Running sum of all sizes passed to Add().
  size_t total{0};

  // Adds `size` to the running total. The node reference is accepted so that
  // all RawUsage variants share one call shape; it is not inspected here.
  void Add(size_t size, CordRepRef<mode> /*rep*/) { total += size; }
};
// Accumulator for Mode::kTotalMorePrecise: a node's size is added only the
// first time that node pointer is reported.
template <>
struct RawUsage<Mode::kTotalMorePrecise> {
  // Running sum of the sizes of all distinct nodes reported so far.
  size_t total{0};

  // Node pointers that have already contributed to `total`.
  std::unordered_set<absl::Nonnull<const CordRep*>> counted;

  // Adds `size` to the total unless `repref.rep` was already recorded.
  void Add(size_t size, CordRepRef<Mode::kTotalMorePrecise> repref) {
    // insert() reports through `.second` whether the pointer was new.
    const bool newly_seen = counted.insert(repref.rep).second;
    if (!newly_seen) {
      return;
    }
    total += size;
  }
};
// Returns `d` unchanged when `refcount` equals 1, and `d / refcount`
// otherwise.
template <typename refcount_t>
double MaybeDiv(double d, refcount_t refcount) {
  if (refcount == 1) {
    return d;
  }
  return d / refcount;
}
template <>
struct CordRepRef<Mode::kFairShare> {
explicit CordRepRef(absl::Nonnull<const CordRep*> r, double frac = 1.0)
: rep(r), fraction(MaybeDiv(frac, r->refcount.Get())) {}
CordRepRef Child(absl::Nonnull<const CordRep*> child) const {
return CordRepRef(child, fraction);
}
absl::Nonnull<const CordRep*> rep;
double fraction;
};
template <>
struct RawUsage<Mode::kFairShare> {
double total = 0;
void Add(size_t size, CordRepRef<Mode::kFairShare> rep) {
total += static_cast<double>(size) * rep.fraction;
}
};
// Adds the memory attributed to the data edge `rep` to `raw_usage`.
//
// `rep` must satisfy IsDataEdge() (checked with assert). A SUBSTRING node
// contributes sizeof(CordRepSubstring) and is then replaced by its child
// before the underlying node is measured.
template <Mode mode>
void AnalyzeDataEdge(CordRepRef<mode> rep, RawUsage<mode>& raw_usage) {
  assert(IsDataEdge(rep.rep));

  // Account for a substring wrapper, then continue with the wrapped node.
  if (rep.rep->tag == SUBSTRING) {
    raw_usage.Add(sizeof(CordRepSubstring), rep);
    rep = rep.Child(rep.rep->substring()->child);
  }

  // Tags at or above FLAT are measured by their allocated size; any other tag
  // is measured as its length plus sizeof(CordRepExternalImpl<intptr_t>).
  const CordRep* const leaf = rep.rep;
  const bool is_flat = leaf->tag >= FLAT;
  const size_t leaf_bytes =
      is_flat ? leaf->flat()->AllocatedSize()
              : leaf->length + sizeof(CordRepExternalImpl<intptr_t>);
  raw_usage.Add(leaf_bytes, rep);
}
// Adds the memory attributed to the btree node `rep`, and recursively to
// everything below it, to `raw_usage`.
//
// The node itself contributes sizeof(CordRepBtree). Edges of a node with
// height greater than zero are analyzed as btree nodes; edges of a node with
// height zero are analyzed as data edges.
template <Mode mode>
void AnalyzeBtree(CordRepRef<mode> rep, RawUsage<mode>& raw_usage) {
  raw_usage.Add(sizeof(CordRepBtree), rep);

  const CordRepBtree* tree = rep.rep->btree();
  const bool edges_are_btrees = tree->height() > 0;

  for (CordRep* edge : tree->Edges()) {
    const CordRepRef<mode> child = rep.Child(edge);
    if (edges_are_btrees) {
      AnalyzeBtree(child, raw_usage);
    } else {
      AnalyzeDataEdge(child, raw_usage);
    }
  }
}
// Walks the cord rooted at `rep` and returns the usage accumulated for
// `mode`, converted to size_t.
//
// A top-level CRC node contributes sizeof(CordRepCrc); if it has no child the
// walk ends there, otherwise the walk continues with that child. The remaining
// node must be either a data edge or a BTREE node; anything else trips an
// assert.
template <Mode mode>
size_t GetEstimatedUsage(absl::Nonnull<const CordRep*> rep) {
  RawUsage<mode> usage;
  CordRepRef<mode> current(rep);

  // Peel off the optional CRC wrapper.
  if (current.rep->tag == CRC) {
    usage.Add(sizeof(CordRepCrc), current);
    const CordRep* const wrapped = current.rep->crc()->child;
    if (wrapped == nullptr) {
      return static_cast<size_t>(usage.total);
    }
    current = current.Child(wrapped);
  }

  // Dispatch on the kind of node that remains.
  const bool is_data_edge = IsDataEdge(current.rep);
  if (is_data_edge) {
    AnalyzeDataEdge(current, usage);
  } else if (current.rep->tag == BTREE) {
    AnalyzeBtree(current, usage);
  } else {
    assert(false);
  }

  return static_cast<size_t>(usage.total);
}
}
// Returns the usage of the cord rooted at `rep` as computed by
// GetEstimatedUsage() in Mode::kTotal.
size_t GetEstimatedMemoryUsage(absl::Nonnull<const CordRep*> rep) {
  constexpr Mode kMode = Mode::kTotal;
  return GetEstimatedUsage<kMode>(rep);
}
// Returns the usage of the cord rooted at `rep` as computed by
// GetEstimatedUsage() in Mode::kFairShare.
size_t GetEstimatedFairShareMemoryUsage(absl::Nonnull<const CordRep*> rep) {
  constexpr Mode kMode = Mode::kFairShare;
  return GetEstimatedUsage<kMode>(rep);
}
// Returns the usage of the cord rooted at `rep` as computed by
// GetEstimatedUsage() in Mode::kTotalMorePrecise.
size_t GetMorePreciseMemoryUsage(absl::Nonnull<const CordRep*> rep) {
  constexpr Mode kMode = Mode::kTotalMorePrecise;
  return GetEstimatedUsage<kMode>(rep);
}
} ABSL_NAMESPACE_END
}