#pragma once
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/physical_tensor.h"
#include "megbrain/imperative/utils/to_string.h"
namespace mgb::imperative {
namespace interpreter::intl {
//! Eviction state tag stored per tensor (see TensorInfo::evict_type, which
//! defaults to NONE).
enum EvictType {
//! No eviction recorded.
NONE = 0,
//! NOTE(review): presumably marks a tensor whose storage was dropped and can
//! be recomputed later -- confirm against the interpreter implementation.
DROP = 1,
};
//! Node of a parent-linked forest (presumably a disjoint-set union, judging by
//! the name -- confirm with the code that links nodes together).
struct DsuNode {
    //! Creates a parentless (root) node carrying the value \p _t.
    DsuNode(double _t) : t(_t) {}

    //! Link to the parent node; empty while this node is a root.
    std::shared_ptr<DsuNode> parent;

    //! \return true when the node has no parent.
    //! Marked const so it can be queried through const references/pointers.
    bool is_root() const noexcept { return !bool(parent); }

    //! Value attached to the node.
    //! NOTE(review): looks like an accumulated time/cost -- confirm with users.
    double t;
};
struct TensorInfo;
using TensorInfoPtr = std::shared_ptr<TensorInfo>;
/*!
 * \brief Per-tensor bookkeeping record.
 *
 * Besides the tensor handle and its descriptor, the record keeps the links
 * between tensors and the operations that connect them: \c producer is the
 * ComputePath that lists this tensor among its outputs, and \c users are the
 * ComputePaths that list it among their inputs.
 */
struct TensorInfo {
    enum Status {
        InvalidStatus,
        Allocated,
        Produced,
        Dropped,
        Deleted,
    };

    //! Identifier; the default -1 wraps to the largest uint64_t value.
    uint64_t id = -1;
    std::string name;
    TensorPtr ptr;
    LogicalTensorDesc desc;

    //! NOTE(review): presumably the measured cost of producing this tensor;
    //! not read anywhere in this struct.
    double compute_time = 0;
    //! Size term used by eval_func() and size_exceeds_thd().
    //! NOTE(review): presumably in bytes (eval_func divides by 1024*1024).
    size_t memory = 0;
    //! Timestamp subtracted from the current time in eval_func().
    double last_used_time = 0;

    bool invalid = false;
    bool allow_delete = false;

    EvictType evict_type = NONE;
    Status status = InvalidStatus;

    HostTensorND h_value;

    //! Pin counter, changed only through pin() / unpin().
    size_t pinned = 0;
    //! Exponent applied to param_recompute_times in eval_func().
    size_t recompute_times = 0;
    //! Raised by ComputePath::make() for every path that consumes this tensor.
    size_t ref_cnt = 0;

    std::shared_ptr<DsuNode> dsu_ptr;

    size_t ptr_use_count = 0;

    /*!
     * \brief One recorded operation: op, its input tensors and output tensors.
     *
     * Instances are heap-allocated by make() and deleted by
     * TensorInfo::detach_producer() once every output slot has been cleared.
     */
    struct ComputePath {
        uint64_t id;
        std::shared_ptr<OpDef> op;
        SmallVector<TensorInfo*> inputs;
        //! \c inputs sorted by address with duplicates removed.
        SmallVector<TensorInfo*> unique_inputs;
        //! Output slots; a slot is set to nullptr when that output detaches.
        SmallVector<TensorInfo*> outputs;
        SmallVector<LogicalTensorDesc> outputs_descs;

        //! \return number of output slots that are still non-null.
        size_t ref_cnt() const {
            return outputs.size() - std::count(outputs.begin(), outputs.end(), nullptr);
        }

        /*!
         * \brief Allocates a path and wires it into the given tensors.
         *
         * Side effects on the arguments' pointees:
         *  - the path is appended to \c users of every distinct input;
         *  - \c producer of every output is set to the new path;
         *  - \c ref_cnt of every entry of \p inputs (duplicates counted each
         *    time) grows by the number of outputs.
         *
         * All pointers in \p inputs and \p outputs are dereferenced, so they
         * must be non-null.
         *
         * \return the new path; it is released by detach_producer().
         */
        static ComputePath* make(
                uint64_t id, std::shared_ptr<OpDef> op, SmallVector<TensorInfo*> inputs,
                SmallVector<TensorInfo*> outputs,
                SmallVector<LogicalTensorDesc> outputs_descs) {
            auto* path = new TensorInfo::ComputePath();
            path->id = id;
            path->op = op;
            path->inputs = inputs;
            path->outputs = outputs;
            path->outputs_descs = outputs_descs;

            // Build the de-duplicated input list directly inside the path
            // instead of in a temporary that is copied afterwards.
            auto& unique_inputs = path->unique_inputs;
            unique_inputs = inputs;
            std::sort(unique_inputs.begin(), unique_inputs.end());
            unique_inputs.erase(
                    std::unique(unique_inputs.begin(), unique_inputs.end()),
                    unique_inputs.end());

            // Register the path once per distinct input ...
            for (auto* input : unique_inputs) {
                input->users.push_back(path);
            }
            for (auto* output : outputs) {
                output->producer = path;
            }
            // ... but count references per occurrence in the original list.
            for (auto* input : inputs) {
                input->ref_cnt += outputs.size();
            }
            return path;
        }
    }* producer = nullptr;

    /*!
     * \brief Scores this tensor from a cost and the tensor's own statistics.
     *
     * Computes
     *   (cost + 1e-3)^param_cost * param_recompute_times^recompute_times
     *   ---------------------------------------------------------------
     *   ((memory + free_mem) / 1024 / 1024)^param_mem
     *       * (cur_time - last_used_time + 1e-3)^param_time
     *
     * NOTE(review): presumably the eviction heuristic where a lower score
     * means a better candidate -- confirm with the caller.
     */
    double eval_func(
            double cost, double free_mem, double cur_time, double param_cost,
            double param_mem, double param_time, double param_recompute_times) const {
        return pow(cost + 1e-3, param_cost) *
               pow(param_recompute_times, (double)recompute_times) /
               (pow((memory + free_mem) / 1024.0 / 1024.0, param_mem) *
                pow((double)(cur_time - last_used_time + 1e-3), param_time));
    }

    //! Increments the pin counter.
    void pin() { ++pinned; }

    //! Decrements the pin counter; calls must be balanced with pin().
    void unpin() { --pinned; }

    /*!
     * \brief Unlinks this tensor from its producer path.
     *
     * Clears this tensor's slot in the producer's outputs. When no output
     * slot remains, the path is removed from the \c users list of each of its
     * distinct inputs and deleted.
     *
     * \return true if the producer path was deleted; false if there was no
     *      producer or other outputs still reference it.
     */
    bool detach_producer() {
        if (!producer) {
            return false;
        }
        auto output =
                std::find(producer->outputs.begin(), producer->outputs.end(), this);
        mgb_assert(output != producer->outputs.end());
        *output = nullptr;
        bool deleted = false;
        if (producer->ref_cnt() == 0) {
            for (auto* input : producer->unique_inputs) {
                auto user =
                        std::find(input->users.begin(), input->users.end(), producer);
                // Erasing an end iterator is undefined behaviour; fail loudly
                // if the users list and the path ever get out of sync.
                mgb_assert(user != input->users.end());
                input->users.erase(user);
            }
            delete producer;
            deleted = true;
        }
        producer = nullptr;
        return deleted;
    }

    //! \return true when \c memory is strictly greater than \p thd.
    bool size_exceeds_thd(size_t thd) const { return memory > thd; }

    //! Paths that list this tensor among their inputs (filled by make()).
    SmallVector<ComputePath*> users;

    //! NOTE(review): presumably the index in a candidate container, with
    //! UINT_MAX meaning "not a candidate" -- confirm with the owner.
    size_t cand_index = UINT_MAX;
};
}
}