#pragma once
#include <optional>
#include <string>
#include "pybind11/pybind11.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/transformation.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/imperative/value.h"
#include "megbrain/utils/small_vector.h"
namespace mgb::imperative::python {
/*!
 * \brief Keeps track of registered transformations, grouped by Segment.
 *
 * Every segment owns a list of the transformations that were registered into
 * it through register_at(). A single shared instance is available through
 * get_instance().
 */
struct TransformationManager {
public:
    //! Segments in declaration order; register_at() scans from the requested
    //! segment towards Eval when looking for a neighbour to register against.
    enum Segment {
        ModuleTrace,
        DTypePromote,
        DimExpansion,
        Grad,
        Scalar,
        Trace,
        Eval,
    };

    //! Number of segments. Derived from the last enumerator so that the
    //! storage below cannot silently fall out of sync with the enum.
    static constexpr size_t kNumSegments = static_cast<size_t>(Eval) + 1;

    //! Transformations currently registered, indexed by Segment.
    std::array<std::vector<std::shared_ptr<Transformation>>, kNumSegments> segments;

private:
    /*!
     * \brief Removes a transformation that was previously added to \p segment.
     *
     * Calls Transformation::unregister() on it and drops it from the
     * segment's list. Asserts that the transformation is present.
     *
     * NOTE(review): this function is noexcept; if mgb_assert reports failure by
     * throwing, a failed check here ends the process -- confirm this is intended.
     */
    template <Segment segment>
    void unregister(std::shared_ptr<Transformation> transformation) noexcept {
        // The segment is a compile-time value, so validate it at compile time.
        static_assert(
                static_cast<size_t>(segment) < kNumSegments,
                "invalid transformation segment");
        auto& slot = segments[segment];
        auto iter = std::find(slot.begin(), slot.end(), transformation);
        mgb_assert(iter != slot.end());
        transformation->unregister();
        slot.erase(iter);
    }

public:
    /*!
     * \brief Registers \p transformation into \p segment.
     *
     * The transformation is registered at the position reported by the last
     * entry of the first non-empty segment at or after \p segment; when all of
     * those segments are empty it is registered at Transformation::bottom().
     *
     * \param transformation transformation to register; must not be null.
     * \return a guard whose callback unregisters the transformation again.
     *         The callback captures this manager, so the guard must not be
     *         used after the manager is gone.
     */
    template <Segment segment>
    [[nodiscard]] std::unique_ptr<CleanupGuard<>> register_at(
            std::shared_ptr<Transformation> transformation) {
        static_assert(
                static_cast<size_t>(segment) < kNumSegments,
                "invalid transformation segment");
        // Reject a null pointer before it is dereferenced below.
        mgb_assert(transformation != nullptr);

        // Find the closest already-registered transformation at or after
        // the requested segment.
        std::shared_ptr<Transformation> next;
        for (size_t i = segment; i < segments.size(); ++i) {
            if (!segments[i].empty()) {
                next = segments[i].back();
                break;
            }
        }

        if (next) {
            transformation->register_at(next->pos());
        } else {
            transformation->register_at(Transformation::bottom());
        }
        segments[segment].push_back(transformation);

        return std::make_unique<CleanupGuard<>>(
                [this, transformation]() { unregister<segment>(transformation); });
    }

    //! Returns the shared manager instance (a function-local static).
    static TransformationManager& get_instance() {
        static TransformationManager sl_instance;
        return sl_instance;
    }
};
/*!
 * \brief Value type declared as a PrimitiveValue over pybind11::object.
 */
class PyValue final : public PrimitiveValue<PyValue, pybind11::object> {
public:
    // Reuse the constructors provided by the PrimitiveValue base.
    using PrimitiveValue::PrimitiveValue;

    //! Returns the pybind11::str conversion of this value as a std::string.
    std::string to_string() const {
        // View this value as a pybind11::object so it can be handed to
        // pybind11::str.
        const pybind11::object& held = (const pybind11::object&)*this;
        return pybind11::str(held).cast<std::string>();
    }
};
}