#include "./ops.h"
#include "./helper.h"
#include "./tensor.h"
#include "megbrain/common.h"
#include "megbrain/imperative.h"
#include "megbrain/imperative/graph_builder.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/rng.h"
#include "megbrain/imperative/ops/utility.h"
#include <Python.h>
#include <unordered_map>
namespace py = pybind11;
using namespace mgb::imperative;
namespace {
/**
 * Upper-cases every character of `in`.
 *
 * The enum wrappers use this to normalize a member name before looking it
 * up in their `mem2value` table.
 *
 * Each character is converted to `unsigned char` before calling `toupper`.
 * Passing a negative `char` (any byte >= 0x80 where `char` is signed) is
 * undefined behavior.
 */
std::string normalize_enum(const std::string& in) {
    std::string ret;
    ret.reserve(in.size());
    for (char c : in) {
        ret += static_cast<char>(toupper(static_cast<unsigned char>(c)));
    }
    return ret;
}
}  // namespace
// CATCH_ALL(RETVAL): the catch clauses to place after a `try { ... }` block
// inside a CPython C callback. Each handler turns the caught exception into
// a pending Python error and returns RETVAL, the callback's error sentinel
// (e.g. -1 for setters, nullptr for functions returning PyObject*):
//   - py::error_already_set: re-installs the captured Python error;
//   - py::builtin_exception: sets the corresponding Python exception;
//   - any other std::exception: raises RuntimeError with e.what().
// The pybind11 handlers come before std::exception so they are matched first.
#define CATCH_ALL(RETVAL) \
catch (py::error_already_set & e) { \
e.restore(); \
return RETVAL; \
} \
catch (py::builtin_exception & e) { \
e.set_error(); \
return RETVAL; \
} \
catch (std::exception & e) { \
PyErr_SetString(PyExc_RuntimeError, e.what()); \
return RETVAL; \
}
namespace {
// Naming helpers for per-op Python wrappers:
//   PyOp(X)     -> the wrapper struct name for op type X (PyX);
//   PyOpType(X) -> that wrapper's static PyTypeObject.
#define PyOp(name) Py##name
#define PyOpType(name) PyOp(name)::py_type
// PyOpDefBegin(X) opens `struct PyX : PyOpDef`. Inside it, `Ty` aliases the C++
// op type and `inst()` returns the wrapped op via cast_final_safe<Ty>().
// It also declares the struct's own static `py_type`. The struct body is left
// open; presumably the generated bindings add members before PyOpDefEnd.
#define PyOpDefBegin(name) \
struct PyOp(name) : PyOpDef { \
using Ty = name; \
Ty& inst() { return op->cast_final_safe<Ty>(); } \
static PyTypeObject py_type;
// PyOpDefEnd(X) closes the struct opened by PyOpDefBegin(X) and defines its
// static `py_type` storage.
#define PyOpDefEnd(name) \
} \
; \
PyTypeObject PyOpType(name);
// RETURN_RICHCOMPARE(val1, val2, op): returns Py_True or Py_False from the
// enclosing function. The result comes from comparing val1 and val2 with the
// C++ operator that matches the rich-comparison code `op` (Py_EQ, Py_NE,
// Py_LT, Py_GT, Py_LE, Py_GE). Any other code aborts via Py_FatalError.
// Every listed case returns, so code after the macro runs only if it is
// skipped entirely.
#define RETURN_RICHCOMPARE(val1, val2, op) \
do { \
switch (op) { \
case Py_EQ: \
if ((val1) == (val2)) \
Py_RETURN_TRUE; \
Py_RETURN_FALSE; \
case Py_NE: \
if ((val1) != (val2)) \
Py_RETURN_TRUE; \
Py_RETURN_FALSE; \
case Py_LT: \
if ((val1) < (val2)) \
Py_RETURN_TRUE; \
Py_RETURN_FALSE; \
case Py_GT: \
if ((val1) > (val2)) \
Py_RETURN_TRUE; \
Py_RETURN_FALSE; \
case Py_LE: \
if ((val1) <= (val2)) \
Py_RETURN_TRUE; \
Py_RETURN_FALSE; \
case Py_GE: \
if ((val1) >= (val2)) \
Py_RETURN_TRUE; \
Py_RETURN_FALSE; \
default: \
Py_FatalError("Unreachable C code path reached"); \
} \
} while (0)
/**
 * Generic `tp_new` for op wrapper types.
 *
 * Allocates an instance of `type` and attaches a freshly created op
 * (`T::Ty::make()`). If `tp_alloc` fails, the NULL result is returned as-is.
 */
template <typename T>
PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) {
    PyObject* allocated = type->tp_alloc(type, 0);
    if (allocated == NULL) {
        return allocated;
    }
    reinterpret_cast<T*>(allocated)->op = T::Ty::make();
    return allocated;
}
/**
 * Default (de)serialization of an op attribute through pybind11 conversion.
 *
 * The second template parameter exists only so that partial specializations
 * can be selected with std::enable_if.
 */
template <typename T, typename EnableIf = void>
struct serialization {
    /// Converts a Python object to T (throws on failure, like py::cast).
    static T load(py::object obj) { return py::cast<T>(obj); }

    /// Converts a T to a Python object. Only enabled when U decays to T.
    template <
            typename U,
            typename = std::enable_if_t<std::is_same_v<T, std::decay_t<U>>>>
    static py::object dump(U&& t) {
        return py::cast(std::forward<U>(t));
    }
};
/// Generic `tp_dealloc`: releases the object's reference to its op, then
/// frees the Python object through its type's `tp_free`.
template <typename T>
void py_dealloc_generic(PyObject* obj) {
    auto* wrapper = reinterpret_cast<T*>(obj);
    wrapper->op.reset();
    PyTypeObject* tp = Py_TYPE(obj);
    tp->tp_free(obj);
}
/**
 * Generic PyGetSetDef getter: reads attribute `attr` of the wrapped op and
 * returns it as a new Python reference.
 *
 * Failures (a cast_final_safe mismatch in inst(), or a py::cast conversion
 * error) become a Python exception and return nullptr, the same way
 * py_set_generic_impl reports errors. This keeps C++ exceptions from
 * propagating through the interpreter.
 */
template <typename T, typename U, U T::Ty::*attr>
PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) {
    try {
        auto& op = reinterpret_cast<T*>(obj)->inst();
        return py::cast(op.*attr).release().ptr();
    }
    CATCH_ALL(nullptr)
}
#define py_get_generic(name, attr) \
    py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
/**
 * Generic PyGetSetDef setter: converts `value` to U and stores it in
 * attribute `attr` of the wrapped op.
 *
 * Deleting the attribute (value == NULL) is rejected with TypeError.
 * Returns 0 on success, or -1 with a Python error set.
 */
template <typename T, typename U, U T::Ty::*attr>
int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
    if (!value) {
        PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
        return -1;
    }
    auto& target = reinterpret_cast<T*>(obj)->inst();
    try {
        // Keeps any temporaries created by the type caster alive until the
        // assignment below has completed.
        py::detail::loader_life_support guard{};
        target.*attr = py::cast<U>(py::handle(value));
    }
    CATCH_ALL(-1)
    return 0;
}
#define py_set_generic(name, attr) \
    py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
/**
 * Layout of every OpDef Python object: the CPython header followed by a
 * shared pointer to the C++ op.
 *
 * `op` may be null for objects created through PyOpBase::tp_new, so
 * `py_repr` must not dereference it blindly.
 */
struct PyOpDef {
    PyObject_HEAD
    std::shared_ptr<OpDef> op;
    static PyTypeObject py_type;
    static std::unordered_map<mgb::Typeinfo*, PyTypeObject*> ctype2pytype;
    static PyGetSetDef py_getsetters[];
    static Py_hash_t tp_hash(PyObject* obj);
    static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op);

    /// `repr()`: the op's make_name(). An object without a C++ op falls back
    /// to the default object repr instead of dereferencing a null pointer.
    /// C++ exceptions are converted into Python errors.
    static PyObject* py_repr(PyObject* self) {
        try {
            auto& op = reinterpret_cast<PyOpDef*>(self)->op;
            if (!op) {
                return PyBaseObject_Type.tp_repr(self);
            }
            return py::cast(op->make_name()).release().ptr();
        }
        CATCH_ALL(nullptr)
    }
};
// Storage for the OpDef Python type object (initialized in _init_py_op_def).
// It is followed by the map that type_caster<OpDef>::cast consults to pick the
// Python type for an op's dynamic C++ type. The map is populated outside this
// section (presumably by the generated op bindings; TODO confirm).
PyTypeObject PyOpType(OpDef);
std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype;
/// Getter for `OpDef.scope`: returns the op's scope string as a new
/// reference. C++ exceptions become a Python error (returns nullptr)
/// instead of propagating into the interpreter.
PyObject* py_get_scope(PyObject* obj, void* /* closure */) {
    try {
        auto* self = reinterpret_cast<PyOp(OpDef)*>(obj);
        return py::cast(self->op->scope()).release().ptr();
    }
    CATCH_ALL(nullptr)
}
/// Setter for `OpDef.scope`: converts `value` to a string and stores it as
/// the op's scope. Deletion is rejected with TypeError. Returns 0 on success,
/// or -1 with a Python error set.
int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) {
    if (!value) {
        PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
        return -1;
    }
    auto* self = reinterpret_cast<PyOp(OpDef)*>(obj);
    try {
        self->op->set_scope(py::cast<std::string>(py::handle(value)));
    }
    CATCH_ALL(-1)
    return 0;
}
// Attribute table for OpDef: a read/write `scope` property backed by
// py_get_scope / py_set_scope, terminated by the NULL sentinel entry.
PyGetSetDef PyOp(OpDef)::py_getsetters[] = {
{const_cast<char*>("scope"), py_get_scope, py_set_scope,
const_cast<char*>("scope"), NULL},
{NULL}};
/**
 * `hash()` for OpDef objects: the C++ op's hash() narrowed to Py_hash_t.
 *
 * CPython reserves -1 as the error return of tp_hash, so a genuine -1 is
 * remapped to -2. An object without a C++ op uses the default identity hash.
 * This matches tp_richcompare, which defers to identity equality for such
 * objects.
 */
Py_hash_t PyOp(OpDef)::tp_hash(PyObject* obj) {
    auto& op = reinterpret_cast<PyOp(OpDef)*>(obj)->op;
    if (!op) {
        return PyBaseObject_Type.tp_hash(obj);
    }
    auto h = static_cast<Py_hash_t>(op->hash());
    return h == -1 ? -2 : h;
}
/**
 * Rich comparison for OpDef objects. Only == and != are supported; they
 * compare the wrapped ops with is_same().
 *
 * NotImplemented is returned for ordering operators and for a non-OpDef
 * `other`; previously such an object was reinterpret_cast and dereferenced.
 * It is also returned when either side has no C++ op. In those cases Python
 * falls back to its default (identity) equality.
 */
PyObject* PyOp(OpDef)::tp_richcompare(PyObject* self, PyObject* other, int op) {
    if (op != Py_EQ && op != Py_NE) {
        Py_RETURN_NOTIMPLEMENTED;
    }
    if (!PyObject_TypeCheck(other, &PyOpType(OpDef))) {
        Py_RETURN_NOTIMPLEMENTED;
    }
    auto& lhs = reinterpret_cast<PyOp(OpDef)*>(self)->op;
    auto& rhs = reinterpret_cast<PyOp(OpDef)*>(other)->op;
    if (!lhs || !rhs) {
        Py_RETURN_NOTIMPLEMENTED;
    }
    bool same = lhs->is_same(*rhs);
    RETURN_RICHCOMPARE(same, true, op);
    Py_RETURN_NOTIMPLEMENTED;
}
// Per-enum traits, specialized outside this section. The wrappers read
// EnumTrait<T>::name and EnumTrait<T>::max.
template <typename T>
struct EnumTrait;
// Common leading members of the enum wrapper objects. Put it first in the
// struct so the PyObject header is at offset 0:
//   value       - the wrapped enumerator, right after PyObject_HEAD;
//   name        - Python-visible type name from EnumTrait<T>::name;
//   type        - the wrapper's Python type object;
//   members     - member names, indexed by value (EnumWrapper) or by bit
//                 position (BitCombinedEnumWrapper);
//   mem2value   - member lookup table, queried with normalize_enum'd names;
//   pyobj_insts - cached Python objects returned (with a new reference) by
//                 cast(); presumably one per member, TODO confirm.
#define PyEnumHead \
static_assert(std::is_enum_v<T>); \
PyObject_HEAD T value; \
constexpr static const char* name = EnumTrait<T>::name; \
static PyTypeObject* type; \
static const char* members[]; \
static std::unordered_map<std::string, T> mem2value; \
static PyObject* pyobj_insts[];
/**
 * Python object wrapping a single enumerator of a plain enum type T.
 *
 * The static tables declared by PyEnumHead (members, mem2value,
 * pyobj_insts, type) are defined outside this section.
 */
template <typename T>
struct EnumWrapper {
    PyEnumHead

    /// Name of the wrapped enumerator, looked up by its integral value.
    std::string to_string() const {
        const auto index = static_cast<size_t>(value);
        return members[index];
    }

    /// `repr()`: "<EnumName>.<MEMBER>".
    static PyObject* py_repr(PyObject* self) {
        auto* wrapper = reinterpret_cast<EnumWrapper*>(self);
        std::string text = std::string(name) + "." + wrapper->to_string();
        return py::cast(text).release().ptr();
    }

    /// Serialized form: just the member name.
    static PyObject* py_dump(PyObject* self) {
        std::string text = reinterpret_cast<EnumWrapper*>(self)->to_string();
        return py::cast(text).release().ptr();
    }

    /// Only == and != are supported. If either operand cannot be loaded as
    /// T, the operands are treated as unequal.
    static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
        if (op == Py_EQ || op == Py_NE) {
            T lhs, rhs;
            const bool both_loaded = load(other, rhs) && load(self, lhs);
            if (!both_loaded) {
                RETURN_RICHCOMPARE(0, 1, op);
            }
            RETURN_RICHCOMPARE(lhs, rhs, op);
        }
        Py_RETURN_NOTIMPLEMENTED;
    }

    /// Converts `src` to T. Accepts an instance of this wrapper type, or a
    /// string whose normalize_enum form is a key of mem2value. Returns false
    /// for anything else.
    static bool load(py::handle src, T& value) {
        PyObject* obj = src.ptr();
        if (PyObject_TypeCheck(obj, type)) {
            value = reinterpret_cast<EnumWrapper*>(obj)->value;
            return true;
        }
        if (!py::isinstance<py::str>(src)) {
            return false;
        }
        auto found = mem2value.find(normalize_enum(py::cast<std::string>(src)));
        if (found == mem2value.end()) {
            return false;
        }
        value = found->second;
        return true;
    }

    /// Returns a new reference to the cached Python object for `value`.
    static PyObject* cast(const T& value) {
        auto v = static_cast<std::underlying_type_t<T>>(value);
        mgb_assert(v <= EnumTrait<T>::max);
        PyObject* cached = pyobj_insts[v];
        Py_INCREF(cached);
        return cached;
    }
};
template <typename T>
struct BitCombinedEnumWrapper {
PyEnumHead std::string to_string() const {
uint32_t value_int = static_cast<uint32_t>(value);
if (value_int == 0) {
return "None";
} else {
std::string ret;
bool first = true;
for (uint32_t i = 0; i < 32; i++) {
if (value_int >> i & 1) {
if (!first) {
ret += " + ";
} else {
first = false;
}
ret += (std::string(name) + "." + members[i]);
}
}
return ret;
}
}
static PyObject* py_new_combined_enum(
PyTypeObject* type, PyObject* args, PyObject*) {
if (!PyTuple_Size(args)) {
PyObject* obj = type->tp_alloc(type, 0);
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = T();
return obj;
} else {
PyObject* input;
if (!PyArg_ParseTuple(args, "|O", &input)) {
return nullptr;
}
T value;
if (load(input, value)) {
return cast(value);
} else {
PyErr_SetString(
PyExc_RuntimeError,
mgb::ssprintf(
"Cannot convert type %s to type %s\n",
input->ob_type->tp_name, name)
.c_str());
return nullptr;
}
}
}
static PyObject* py_repr(PyObject* self) {
return py::cast(reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string())
.release()
.ptr();
}
static PyObject* py_dump(PyObject* self) {
std::vector<std::string> result;
auto value = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value;
uint32_t value_int = static_cast<uint32_t>(value);
for (uint32_t i = 0; i < 32; i++) {
if (value_int >> i & 1) {
result.push_back(members[i]);
}
}
return py::tuple(py::cast(result)).release().ptr();
}
static PyObject* py_or(PyObject* self, PyObject* other) {
if (!(self->ob_type == other->ob_type)) {
return PyErr_Format(
PyExc_RuntimeError,
"Operand in or operator must be the same type.");
}
T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
return cast(lhs | rhs);
}
static PyObject* py_and(PyObject* self, PyObject* other) {
if (!(self->ob_type == other->ob_type)) {
return PyErr_Format(
PyExc_RuntimeError,
"Operand in and operator must be the same type.");
}
T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
return cast(lhs & rhs);
}
static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
if (op == Py_EQ || op == Py_NE) {
T lhs, rhs;
if (load(other, rhs) && load(self, lhs)) {
RETURN_RICHCOMPARE(lhs, rhs, op);
} else {
RETURN_RICHCOMPARE(0, 1, op);
}
}
Py_RETURN_NOTIMPLEMENTED;
}
static bool load(py::handle src, T& value) {
PyObject* obj = src.ptr();
if (PyObject_TypeCheck(obj, type)) {
value = reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value;
return true;
}
if (py::isinstance<py::str>(src)) {
auto&& iter = mem2value.find(normalize_enum(py::cast<std::string>(src)));
if (iter != mem2value.end()) {
value = iter->second;
return true;
} else {
return false;
}
}
if (py::isinstance<py::tuple>(src)) {
auto params = py::cast<std::vector<std::string>>(src);
bool first = true;
for (auto s : params) {
auto&& iter = mem2value.find(normalize_enum(s));
if (iter != mem2value.end()) {
if (first) {
value = iter->second;
first = false;
} else {
value |= iter->second;
}
} else {
return false;
}
}
return true;
}
if (py::isinstance<py::int_>(obj)) {
auto v = py::cast<std::underlying_type_t<T>>(src);
if (v > EnumTrait<T>::max) {
return false;
}
value = static_cast<T>(v);
return true;
}
return false;
}
static PyObject* cast(const T& value) {
auto v = static_cast<std::underlying_type_t<T>>(value);
mgb_assert(v <= EnumTrait<T>::max);
if ((!v) || (v & (v - 1))) {
PyObject* obj = type->tp_alloc(type, 0);
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value;
return obj;
} else {
PyObject* obj = pyobj_insts[__builtin_ctz(v)];
Py_INCREF(obj);
return obj;
}
}
};
/**
 * Serialization of enum-typed op attributes.
 *
 * `dump` returns the result of calling `.dump()` on the Python wrapper of
 * the value. `load` converts back through the enum's pybind11 type_caster.
 */
template <typename T>
struct serialization<T, std::enable_if_t<std::is_enum_v<std::decay_t<T>>>> {
    static T load(py::object obj) {
        auto caster = pybind11::detail::type_caster<T>();
        const bool loaded = caster.load(obj, true);
        if (!loaded) {
            // NOTE(review): on failure only a pending Python error is set, and
            // the value still held by the caster is returned. Callers must
            // check PyErr_Occurred(); TODO confirm the generated callers do.
            PyErr_SetString(PyExc_RuntimeError, "load faild \n");
        }
        return caster;
    }
    static py::object dump(T t) {
        py::object wrapped = py::cast(t);
        return wrapped.attr("dump")();
    }
};
/// Initializes the `OpDef` Python base type and registers it on module `m`.
void _init_py_op_def(py::module m) {
    using py_op = PyOp(OpDef);
    auto& py_type = PyOpType(OpDef);
    py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
    // identity, layout and inheritance
    py_type.tp_name = "megengine.core._imperative_rt.OpDef";
    py_type.tp_doc = "OpDef";
    py_type.tp_basicsize = sizeof(py_op);
    py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    py_type.tp_base = &PyBaseObject_Type;
    // behaviour slots
    py_type.tp_dealloc = py_dealloc_generic<py_op>;
    py_type.tp_repr = py_op::py_repr;
    py_type.tp_hash = py_op::tp_hash;
    py_type.tp_richcompare = py_op::tp_richcompare;
    py_type.tp_getset = py_op::py_getsetters;
    mgb_assert(PyType_Ready(&py_type) >= 0);
    m.add_object("OpDef", reinterpret_cast<PyObject*>(&py_type));
}
/**
 * Base type for ops implemented in Python.
 *
 * Its tp_new allocates the object and placement-constructs an empty `op`
 * shared pointer. The C++ op is therefore null for such objects.
 */
struct PyOpBase : PyOpDef {
    static PyTypeObject py_type;
    static PyObject* tp_new(PyTypeObject* type, PyObject*, PyObject*) {
        PyObject* obj = type->tp_alloc(type, 0);
        if (!obj) {
            return obj;
        }
        auto* self = reinterpret_cast<PyOpBase*>(obj);
        new (&self->op) std::shared_ptr<OpDef>();
        return obj;
    }
};
PyTypeObject PyOpBase::py_type;
/// Initializes the `PyOpBase` Python type (a subclass of OpDef) and
/// registers it on module `m`.
void _init_py_op_base(py::module m) {
    using py_op = PyOpBase;
    PyTypeObject& py_type = py_op::py_type;
    py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
    py_type.tp_name = "megengine.core._imperative_rt.ops.PyOpBase";
    py_type.tp_doc = "PyOpBase";
    py_type.tp_basicsize = sizeof(py_op);
    py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    py_type.tp_base = &PyOpType(OpDef);
    py_type.tp_new = py_op::tp_new;
    py_type.tp_dealloc = py_dealloc_generic<py_op>;
    mgb_assert(PyType_Ready(&py_type) >= 0);
    m.add_object("PyOpBase", reinterpret_cast<PyObject*>(&py_type));
}
#include "opdef.cpy.inl"
#undef CATCH_ALL
}
namespace PYBIND11_NAMESPACE {
namespace detail {
/// Accepts any OpDef Python object. If the object carries no C++ op (for
/// example a PyOpBase-derived object), it is wrapped in a GenericPyOp that
/// references the Python object.
bool type_caster<OpDef>::load(handle src, bool convert) {
    PyObject* raw = src.ptr();
    if (!PyObject_TypeCheck(raw, &PyOpType(OpDef))) {
        return false;
    }
    if (auto& held = reinterpret_cast<PyOp(OpDef)*>(raw)->op; held) {
        value = held;
    } else {
        value = std::make_shared<GenericPyOp>(reinterpret_borrow<object>(src));
    }
    return true;
}
/**
 * Converts an OpDef to Python.
 *
 * A GenericPyOp returns its original Python object. Any other op gets a new
 * object of the Python type registered for its dynamic C++ type in
 * ctype2pytype, or the OpDef base type if none is registered; the op is
 * shared with the C++ side. If allocation fails, a null handle is returned
 * (tp_alloc has set the error) rather than dereferencing a null pointer.
 */
handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) {
    if (auto* pyop = op.try_cast_final<GenericPyOp>()) {
        return object(pyop->obj).release();
    }
    auto& c2p = PyOp(OpDef)::ctype2pytype;
    auto iter = c2p.find(op.dyn_typeinfo());
    PyTypeObject* pytype = iter != c2p.end() ? iter->second : &PyOpType(OpDef);
    PyObject* obj = pytype->tp_alloc(pytype, 0);
    if (obj == nullptr) {
        return handle();
    }
    mgb_assert(PyObject_TypeCheck(obj, &PyOpType(OpDef)));
    reinterpret_cast<PyOp(OpDef)*>(obj)->op = const_cast<OpDef&>(op).shared_from_this();
    return py::handle(obj);
}
// Out-of-line definitions of pybind11 type_caster<T>::load/cast for every
// plain enum parameter type listed by FOR_EACH_ENUM_PARAM (defined outside
// this section). Both delegate to EnumWrapper<T>.
#define ENUM_CASTER_IMPL(T) \
bool type_caster<T>::load(handle src, bool) { \
return EnumWrapper<T>::load(src, value); \
} \
handle type_caster<T>::cast(const T& value, return_value_policy, handle) { \
return EnumWrapper<T>::cast(value); \
}
FOR_EACH_ENUM_PARAM(ENUM_CASTER_IMPL)
// Same for the bit-combined (flag) enum types listed by
// FOR_EACH_BIT_COMBINED_ENUM_PARAM, which delegate to BitCombinedEnumWrapper<T>.
#define BIT_COMBINED_ENUM_CASTER_IMPL(T) \
bool type_caster<T>::load(handle src, bool) { \
return BitCombinedEnumWrapper<T>::load(src, value); \
} \
handle type_caster<T>::cast(const T& value, return_value_policy, handle) { \
return BitCombinedEnumWrapper<T>::cast(value); \
}
FOR_EACH_BIT_COMBINED_ENUM_PARAM(BIT_COMBINED_ENUM_CASTER_IMPL)
} }
/**
 * Registers the op bindings on module `m`:
 *  - the OpDef and PyOpBase base types and all op types (INIT_ALL_OP);
 *  - RNG helpers (new/delete handle, global seed get/set, handle comp node);
 *  - the SubgraphBuilder class used to assemble Subgraph ops from Python;
 *  - JIT switches (`set_jit_enabled`, `jit_supported`);
 *  - the `_custom` submodule (init_custom).
 */
void init_ops(py::module m) {
    _init_py_op_def(m);
    _init_py_op_base(m);
    INIT_ALL_OP(m)

    m.def("new_rng_handle", &rng::new_handle);
    m.def(
            "delete_rng_handle",
            [](size_t handle) {
                // Synchronize all comp nodes and drain the Python task queue
                // before freeing the handle (presumably so no in-flight work
                // still references it).
                mgb::CompNode::sync_all();
                py_task_q.wait_all_task_finish();
                rng::delete_handle(handle);
            },
            py::call_guard<py::gil_scoped_release>());
    m.def("set_global_rng_seed", [](uint64_t seed) -> void {
        mgb_assert(
                python::interpreter_for_py->check_available(),
                "set global random seed failed since imperative interpreter has been "
                "destroyed");
        python::interpreter_for_py->sync();
        mgb::CompNode::sync_all();
        rng::set_global_rng_seed(seed);
    });
    m.def("get_global_rng_seed", &rng::get_global_rng_seed);
    m.def("get_rng_handle_compnode", &rng::get_rng_handle_compnode);

    // Incrementally assembles a Subgraph from Python. Vars are numbered from 1
    // in creation order. The first build() assigns `key`; after that, every
    // mutating method asserts `key == nullptr`, so the builder is frozen.
    struct PySubgraphBuilder {
        explicit PySubgraphBuilder(std::string name) : name{std::move(name)} {}
        std::string name;
        Subgraph graph;
        mgb::SmallVector<bool> output_grad_mask;
        Subgraph::var_t next_var = 1;
        std::shared_ptr<mgb::Hashable> key = nullptr;
        std::shared_ptr<OpDef> build() {
            if (key == nullptr) {
                key = std::make_shared<UniqueKey>();
            }
            return SubgraphOp::make(
                    name, std::make_shared<Subgraph>(graph), output_grad_mask, key);
        }
    };
    py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
            .def(py::init<std::string>())
            .def(py::init<PySubgraphBuilder>())
            .def("input",
                 [](PySubgraphBuilder& self) {
                     mgb_assert(self.key == nullptr);
                     auto var = self.next_var++;
                     self.graph.inputs.push_back(var);
                     return var;
                 })
            .def("apply",
                 [](PySubgraphBuilder& self, std::shared_ptr<OpDef> op,
                    Subgraph::vars_t inputs, size_t nr_outputs) {
                     mgb_assert(self.key == nullptr);
                     Subgraph::vars_t outputs;
                     for (size_t i = 0; i < nr_outputs; ++i) {
                         outputs.push_back(self.next_var++);
                     }
                     self.graph.exprs.push_back({op, inputs, outputs});
                     return outputs;
                 })
            .def("apply_const",
                 [](PySubgraphBuilder& self, py::object value, mgb::DType dtype,
                    mgb::CompNode cn) {
                     mgb_assert(self.key == nullptr);
                     auto var = self.next_var++;
                     mgb::HostTensorND hvalue(cn);
                     npy::np2tensor(
                             value.cast<py::array>().ptr(),
                             npy::Meth::copy_into(&hvalue), dtype);
                     self.graph.constants.push_back({var, Tensor::make(hvalue)});
                     return var;
                 })
            .def("outputs",
                 [](PySubgraphBuilder& self, Subgraph::vars_t outputs) {
                     mgb_assert(self.key == nullptr);
                     self.graph.outputs = outputs;
                     self.output_grad_mask.resize(outputs.size(), true);
                 })
            .def("outputs_has_grad",
                 [](PySubgraphBuilder& self, mgb::SmallVector<bool> outputs_has_grad) {
                     mgb_assert(self.key == nullptr);
                     mgb_assert(
                             self.graph.outputs.size() == self.output_grad_mask.size());
                     // The replacement mask must carry one flag per output;
                     // previously only the old mask's size was checked.
                     mgb_assert(
                             outputs_has_grad.size() == self.graph.outputs.size(),
                             "outputs_has_grad expects %zu flags, got %zu",
                             static_cast<size_t>(self.graph.outputs.size()),
                             static_cast<size_t>(outputs_has_grad.size()));
                     self.output_grad_mask = outputs_has_grad;
                 })
            .def("get",
                 [](PySubgraphBuilder& self) {
                     return (std::shared_ptr<OpDef>)self.build();
                 })
            .def("compile",
                 [](PySubgraphBuilder& self, int gopt_level) {
                     return (std::shared_ptr<OpDef>)CompiledOp::make(
                             self.build(), gopt_level);
                 })
            .def("jit_fuse", [](PySubgraphBuilder& self) {
                return (std::shared_ptr<OpDef>)CompiledOp::make(
                        JITFusionOp::make(self.build()));
            });

    m.def("set_jit_enabled", &JITFusionOp::set_enabled);
    bool jit_supported = false;
#if MGB_JIT
    jit_supported = true;
#endif
    m.attr("jit_supported") = jit_supported;

    auto custom = submodule(m, "_custom");
    init_custom(custom);
}
// switch-case generators used inside make_custom_op. Each one expects the
// enclosing scope to provide `kv` (a (name, value) item from the kwargs dict)
// and `param_val` (the custom op parameter being filled in).
// CUSTOM_CASE_TO_PARSE_NON_LIST: scalar/string parameter, where the Python
// value is cast directly to `static_type`.
#define CUSTOM_CASE_TO_PARSE_NON_LIST(dyn_type, static_type) \
case custom::ParamDynType::dyn_type: { \
param_val = py::handle(kv.second).cast<static_type>(); \
break; \
}
// CUSTOM_CASE_TO_PARSE_LIST: list parameter. The Python value is treated as a
// list and each element is cast to the vector's element type.
#define CUSTOM_CASE_TO_PARSE_LIST(dyn_type, static_type) \
case custom::ParamDynType::dyn_type: { \
auto pyvals = py::handle(kv.second).cast<py::list>(); \
static_type vals; \
using basic_type = custom::get_vector_template_arg_type<static_type>::type; \
for (auto& pyval : pyvals) { \
vals.push_back(py::handle(pyval).cast<basic_type>()); \
} \
param_val = vals; \
break; \
}
/**
 * METH_FASTCALL implementation of `_custom._make_custom_op(op_name, params)`.
 *
 * Creates the custom op named `args[0]` and fills its parameters from the
 * dict `args[1]`. Unknown parameter names are logged and skipped. Returns a
 * new OpDef Python object wrapping the op.
 *
 * Returns nullptr with a Python error set in these cases:
 *  - the argument count is wrong;
 *  - a conversion fails;
 *  - a C++ exception is raised, including mgb_assert failures and the case
 *    where the build has custom ops disabled.
 * C++ exceptions must not propagate into the interpreter. CATCH_ALL is
 * already undefined at this point, so the handlers are written out.
 */
PyObject* make_custom_op(PyObject* self, PyObject** args, Py_ssize_t nargs) {
#if MGB_CUSTOM_OP
    if (nargs != 2) {
        PyErr_Format(
                PyExc_TypeError,
                "_make_custom_op expects 2 arguments (op_name, params), got %zd",
                nargs);
        return nullptr;
    }
    try {
        auto op_name = py::handle(args[0]).cast<std::string>();
        auto kwargs = py::handle(args[1]).cast<py::dict>();
        std::shared_ptr<OpDef> opdef =
                CustomOpDefFactory::inst()->create_opdef(op_name);
        auto& custom_opdef = static_cast<mgb::imperative::CustomOpDef&>(*opdef);
        auto& param = custom_opdef.param();
        for (auto&& kv : kwargs) {
            std::string param_name = py::handle(kv.first).cast<std::string>();
            if (!param.exist(param_name)) {
                mgb_log_warn(
                        "op %s have no param named %s, ignore this param parsed from "
                        "python",
                        op_name.c_str(), param_name.c_str());
                continue;
            }
            auto& param_val = param[param_name];
            switch (param_val.type()) {
                CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PARSE_NON_LIST)
                CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_PARSE_NON_LIST)
                CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
                CUSTOM_FOR_BOOL_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
                CUSTOM_FOR_STRING_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
                default: {
                    mgb_assert(
                            false, "param dtype of %s:%s is invalid", op_name.c_str(),
                            param_name.c_str());
                }
            }
        }
        PyTypeObject* pytype = &PyOpType(OpDef);
        PyObject* obj = pytype->tp_alloc(pytype, 0);
        if (obj == nullptr) {
            return nullptr;  // tp_alloc has already set the error
        }
        reinterpret_cast<PyOp(OpDef)*>(obj)->op = opdef;
        return obj;
    } catch (py::error_already_set& e) {
        e.restore();
        return nullptr;
    } catch (py::builtin_exception& e) {
        e.set_error();
        return nullptr;
    } catch (std::exception& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return nullptr;
    }
#else
    try {
        mgb_assert(
                false,
                "Custom Op is disabled now, please build megengine with Custom Op open");
    } catch (std::exception& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
    }
    return nullptr;
#endif
}
#undef CUSTOM_CASE_TO_PARSE_LIST
#undef CUSTOM_CASE_TO_PARSE_NON_LIST
/// Loads the custom-op library at `path` under `name`. Returns a Python list
/// of whatever LibManager::install reports for that library (presumably the
/// op names; TODO confirm). Asserts when custom ops are compiled out.
py::list install_custom(const std::string& name, const std::string& path) {
#if MGB_CUSTOM_OP
    const auto& installed = custom::LibManager::inst()->install(name, path);
    py::list result;
    for (const auto& entry : installed) {
        result.append(entry);
    }
    return result;
#else
    mgb_assert(
            false,
            "Custom Op is disabled now, please build megengine with Custom Op open");
    return py::list();
#endif
}
/// Unloads the custom-op library registered as `name` and returns
/// LibManager::uninstall's result. Asserts when custom ops are compiled out.
bool uninstall_custom(const std::string& name) {
#if MGB_CUSTOM_OP
    auto&& manager = custom::LibManager::inst();
    const bool removed = manager->uninstall(name);
    return removed;
#else
    mgb_assert(
            false,
            "Custom Op is disabled now, please build megengine with Custom Op open");
    return false;
#endif
}
/// Returns the names of all registered custom ops as a Python list.
/// Asserts when custom ops are compiled out.
py::list get_custom_op_list(void) {
#if MGB_CUSTOM_OP
    py::list names;
    const std::vector<std::string> all_ops = CustomOpDefFactory::inst()->op_list();
    for (const std::string& op_name : all_ops) {
        names.append(op_name);
    }
    return names;
#else
    mgb_assert(
            false,
            "Custom Op is disabled now, please build megengine with Custom Op open");
    return py::list();
#endif
}
#ifndef METH_FASTCALL
/// METH_VARARGS adapter for interpreters without METH_FASTCALL. It forwards
/// the argument tuple's item array and length to make_custom_op.
PyObject* py35_make_custom_op(PyObject* self, PyObject* args) {
    const Py_ssize_t count = PyTuple_GET_SIZE(args);
    PyObject** items = &PyTuple_GET_ITEM(args, 0);
    return make_custom_op(self, items, count);
}
#endif
/**
 * Populates the `_custom` submodule with the library install/uninstall and
 * listing helpers, the ABI-tag query and `_make_custom_op`.
 *
 * `_make_custom_op` is registered through the raw CPython API: with
 * METH_FASTCALL when available, otherwise through the METH_VARARGS adapter.
 */
void init_custom(pybind11::module m) {
    m.def("_install", &install_custom);
    m.def("_uninstall", &uninstall_custom);
    m.def("_get_custom_op_list", &get_custom_op_list);
    m.def("get_custom_op_abi_tag", [](void) -> int {
        int ret = 0;
#ifdef _GLIBCXX_USE_CXX11_ABI
        ret = _GLIBCXX_USE_CXX11_ABI;
#endif
        return ret;
    });
    // The intermediate void(*)() cast is the conventional way to store a
    // fastcall function in PyMethodDef::ml_meth without a cast-function-type
    // warning.
    static PyMethodDef method_def = {
#ifdef METH_FASTCALL
            "_make_custom_op",
            reinterpret_cast<PyCFunction>(reinterpret_cast<void (*)()>(make_custom_op)),
            METH_FASTCALL, ""
#else
            "_make_custom_op", py35_make_custom_op, METH_VARARGS, ""
#endif
    };
    // PyCFunction_NewEx returns a new reference. Steal it so it is released
    // once setattr has given the module its own reference (it used to leak).
    auto func = pybind11::reinterpret_steal<pybind11::object>(
            PyCFunction_NewEx(&method_def, nullptr, nullptr));
    if (!func) {
        throw pybind11::error_already_set();
    }
    pybind11::setattr(m, method_def.ml_name, func);
}