mlx-sys-burn 0.2.2

Low-level interface and binding generation for the mlx library (fork with additional operations for burn-mlx)
import argparse
import regex
import string
import mlxtypes as mt
import type_private_generator as tpg

parser = argparse.ArgumentParser("MLX C closure code generator", add_help=False)
parser.add_argument("--implementation", default=False, action="store_true")
parser.add_argument("--private", default=False, action="store_true")
args = parser.parse_args()


def replace_match_parenthesis(string, keyword, fun):
    pattern = regex.compile(keyword + r"(\((?:[^()]++|(?1))++\))")
    res = []
    pos = 0
    for m in pattern.finditer(string):
        res.append(string[pos : m.start()])
        res.append(fun(m[1][1:-1]))
        pos = m.end()
    res.append(string[pos:])
    return "".join(res)


decl_code = """
typedef struct NAME_ {
  void* ctx;
} NAME;
NAME NAME_new();
int NAME_free(NAME cls);
NAME NAME_new_func(int (*fun)(RCARGS_UNNAMED, CARGS_UNNAMED));
NAME NAME_new_func_payload(
    int (*fun)(RCARGS_UNNAMED, CARGS_UNNAMED, void*),
    void* payload,
    void (*dtor)(void*));
int NAME_set(NAME *cls, const NAME src);
int NAME_apply(RCARGS, NAME cls, CARGS);
"""


def generate(code, name, rcpptype, cpptypes):
    rcpparg = mt.cpptypes[rcpptype]["cpp"].replace("@", "")
    cppargs = ", ".join([mt.cpptypes[cpptype]["cpp_arg"]("") for cpptype in cpptypes])

    if code is None:
        return tpg.generate(name, "std::function<" + rcpparg + "(" + cppargs + ")>")

    cargs_untyped = []
    cargs = []
    cppargs_type_name = []
    cppargs_to_cargs = []
    cargs_free = []
    cargs_ctx = []
    for i in range(len(cpptypes)):
        cpptype = mt.cpptypes[cpptypes[i]]
        cpparg = cpptype["cpp"]
        suffix = "_" + str(i) if len(cpptypes) > 1 else ""
        cargs_untyped.append(cpptype["c_arg"]("input" + suffix, untyped=True))
        cargs.append(cpptype["c_arg"]("input" + suffix))
        cppargs_type_name.append(cpptype["cpp_arg"]("cpp_input" + suffix))
        cargs_free.append(cpptype["free"]("input" + suffix) + ";")
        cargs_ctx.append(cpptype["c_to_cpp"]("input" + suffix))
        cppargs_to_cargs.append(cpptype["c_new"]("input" + suffix) + ";")
        cppargs_to_cargs.append(
            cpptype["c_assign_from_cpp"](
                "input" + suffix, "cpp_input" + suffix, returned=False
            )
            + ";"
        )

    rcargs_new = mt.cpptypes[rcpptype]["c_new"]("res") + ";"
    rcargs_free = mt.cpptypes[rcpptype]["free"]("res") + ";"
    rcargs_to_cpp = "auto cpp_res = " + mt.cpptypes[rcpptype]["c_to_cpp"]("res") + ";"

    cargs_untyped = ", ".join(cargs_untyped)
    cargs = ", ".join(cargs)
    cppargs_type_name = ", ".join(cppargs_type_name)
    cppargs_to_cargs = "\n".join(cppargs_to_cargs)
    cargs_free = "\n".join(cargs_free)
    cargs_ctx = ", ".join(cargs_ctx)
    cargs_unnamed = " ".join(
        [mt.cpptypes[cpptype]["c_arg"]("") for cpptype in cpptypes]
    )
    rcargs_unnamed = mt.cpptypes[rcpptype]["c_return_arg"]("")
    rcargs = mt.cpptypes[rcpptype]["c_return_arg"]("res")
    rcargs_untyped = mt.cpptypes[rcpptype]["c_return_arg"]("res", untyped=True)

    code = code.replace("RCARGS_UNTYPED", rcargs_untyped)
    code = code.replace("RCARGS_UNNAMED", rcargs_unnamed)
    code = code.replace("CPPARGS_TYPE_NAME", cppargs_type_name)
    code = code.replace("CPPARGS_TO_CARGS", cppargs_to_cargs)
    code = code.replace("RCARGS_NEW", rcargs_new)
    code = code.replace("RCARGS_FREE", rcargs_free)
    code = code.replace("RCARGS_TO_CPP", rcargs_to_cpp)
    code = code.replace("CARGS_UNTYPED", cargs_untyped)
    code = code.replace("CARGS_CTX", cargs_ctx)
    code = code.replace("CARGS_FREE", cargs_free)
    code = code.replace("RCPPARG", rcpparg)
    code = code.replace(
        "CARGS_UNNAMED",
        ", ".join([mt.cpptypes[cpptype]["c_arg"]("") for cpptype in cpptypes]),
    )

    code = code.replace(
        "ASSIGN_CLS_TO_RCARGS",
        mt.cpptypes[rcpptype]["c_assign_from_cpp"](
            "res", "NAME_get_(cls)(" + cargs_ctx + ")", returned=True
        )
        + ";",
    )

    code = code.replace("CPPARGS", cppargs)
    code = code.replace("NAME", name)
    code = code.replace("RCARGS", rcargs)
    code = code.replace("CARGS", cargs)

    return code


impl_code = """
extern "C" NAME NAME_new() {
  try {
    return NAME_new_();
  } catch (std::exception& e) {
    mlx_error(e.what());
    return NAME_new_();
  }
}

extern "C" int NAME_set(NAME *cls, const NAME src) {
  try {
    NAME_set_(*cls, NAME_get_(src));
  } catch (std::exception& e) {
    mlx_error(e.what());
    return 1;
  }
  return 0;
}

extern "C" int NAME_free(NAME cls) {
  try {
    NAME_free_(cls);
    return 0;
  } catch (std::exception& e) {
    mlx_error(e.what());
    return 1;
  }
}

extern "C" NAME NAME_new_func(int (*fun)(RCARGS_UNNAMED, CARGS_UNNAMED)) {
  try {
    auto cpp_closure = [fun](CPPARGS_TYPE_NAME) {
      CPPARGS_TO_CARGS
      RCARGS_NEW
      auto status = fun(RCARGS_UNTYPED, CARGS_UNTYPED);
      CARGS_FREE
      if(status) {
        RCARGS_FREE
        throw std::runtime_error("NAME returned a non-zero value");
      }
      RCARGS_TO_CPP
      RCARGS_FREE
      return cpp_res;
    };
    return NAME_new_(cpp_closure);
  } catch (std::exception& e) {
    mlx_error(e.what());
    return NAME_new_();
  }
}

extern "C" NAME NAME_new_func_payload(
    int (*fun)(RCARGS_UNNAMED, CARGS_UNNAMED, void*),
    void* payload,
    void (*dtor)(void*)) {
  try {
    std::shared_ptr<void> cpp_payload = nullptr;
    if (dtor) {
      cpp_payload = std::shared_ptr<void>(payload, dtor);
    } else {
      cpp_payload = std::shared_ptr<void>(payload, [](void*) {});
    }
    auto cpp_closure = [fun, cpp_payload, dtor](CPPARGS_TYPE_NAME) {
      CPPARGS_TO_CARGS
      RCARGS_NEW
      auto status = fun(RCARGS_UNTYPED, CARGS_UNTYPED, cpp_payload.get());
      CARGS_FREE
      if(status) {
        RCARGS_FREE
        throw std::runtime_error("NAME returned a non-zero value");
      }
      RCARGS_TO_CPP
      RCARGS_FREE
      return cpp_res;
    };
    return NAME_new_(cpp_closure);
  } catch (std::exception& e) {
    mlx_error(e.what());
    return NAME_new_();
  }
}

extern "C" int NAME_apply(RCARGS, NAME cls, CARGS) {
  try {
    ASSIGN_CLS_TO_RCARGS
  } catch (std::exception& e) {
    mlx_error(e.what());
    return 1;
  }
  return 0;
}
"""

priv_code = None

decl_begin = """/* Copyright © 2023-2024 Apple Inc.                   */
/*                                                    */
/* This file is auto-generated. Do not edit manually. */
/*                                                    */

#ifndef MLX_CLOSURE_H
#define MLX_CLOSURE_H

#include "mlx/c/array.h"
#include "mlx/c/map.h"
#include "mlx/c/optional.h"
#include "mlx/c/stream.h"
#include "mlx/c/vector.h"

#ifdef __cplusplus
extern "C" {
#endif

/**
 * \defgroup mlx_closure Closures
 * MLX closure objects.
 */
/**@{*/
"""

decl_end = """
/**@}*/

#ifdef __cplusplus
}
#endif

#endif
"""

impl_begin = """/* Copyright © 2023-2024 Apple Inc.                   */
/*                                                    */
/* This file is auto-generated. Do not edit manually. */
/*                                                    */

#include "mlx/c/closure.h"
#include "mlx/c/error.h"
#include "mlx/c/private/mlx.h"
"""

impl_end = """
"""

priv_begin = """/* Copyright © 2023-2024 Apple Inc.                   */
/*                                                    */
/* This file is auto-generated. Do not edit manually. */
/*                                                    */

#ifndef MLX_CLOSURE_PRIVATE_H
#define MLX_CLOSURE_PRIVATE_H

#include "mlx/c/closure.h"
#include "mlx/mlx.h"

"""

priv_end = """
#endif
"""

if args.implementation:
    code = impl_code
    begin = impl_begin
    end = impl_end
elif args.private:
    code = priv_code
    begin = priv_begin
    end = priv_end
else:
    code = decl_code
    begin = decl_begin
    end = decl_end


print(begin)
print(
    generate(
        code,
        "mlx_closure",
        "std::vector<mlx::core::array>",
        ["std::vector<mlx::core::array>"],
    )
)
if args.implementation:
    print(
        """
extern "C" mlx_closure mlx_closure_new_unary(
    int (*fun)(mlx_array*, const mlx_array)) {
  try {
    auto cpp_closure = [fun](const std::vector<mlx::core::array>& cpp_input) {
      if (cpp_input.size() != 1) {
        throw std::runtime_error("closure: expected unary input");
      }
      auto input = mlx_array_new_(cpp_input[0]);
      auto res = mlx_array_new_();
      auto status = fun(&res, input);
      if(status) {
        mlx_array_free_(res);
        throw std::runtime_error("mlx_closure returned a non-zero value");
      }
      mlx_array_free(input);
      std::vector<mlx::core::array> cpp_res = {mlx_array_get_(res)};
      mlx_array_free(res);
      return cpp_res;
    };
    return mlx_closure_new_(cpp_closure);
  } catch (std::exception& e) {
    mlx_error(e.what());
    return mlx_closure_new_();
  }
}
"""
    )
elif args.private:
    pass
else:
    print(
        """
mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array));
    """
    )
print(
    generate(
        code,
        "mlx_closure_kwargs",
        "std::vector<mlx::core::array>",
        [
            "std::vector<mlx::core::array>",
            "std::unordered_map<std::string, mlx::core::array>",
        ],
    )
)
print(
    generate(
        code,
        "mlx_closure_value_and_grad",
        "std::pair<std::vector<mlx::core::array>, std::vector<mlx::core::array>>",
        ["std::vector<mlx::core::array>"],
    )
)
print(
    generate(
        code,
        "mlx_closure_custom",
        "std::vector<mlx::core::array>",
        ["std::vector<mlx::core::array>"] * 3,
    )
)
print(
    generate(
        code,
        "mlx_closure_custom_jvp",
        "std::vector<mlx::core::array>",
        [
            "std::vector<mlx::core::array>",
            "std::vector<mlx::core::array>",
            "std::vector<int>",
        ],
    )
)
print(
    generate(
        code,
        "mlx_closure_custom_vmap",
        "std::pair<std::vector<mlx::core::array>, @std::vector<int>>",
        ["std::vector<mlx::core::array>", "std::vector<int>"],
    )
)
if args.private:
    print(
        """
     """
    )

print(end)