// mlxrs-sys 0.1.0
//
// Bindings for MLX-C API
// Copyright © 2023-2025 Apple Inc.

#include <optional>
#include <sstream>

#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/variant.h>

#include "mlx/device.h"
#include "mlx/utils.h"

namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;

void init_device(nb::module_& m) {
  auto device_class = nb::class_<mx::Device>(
      m, "Device", R"pbdoc(A device to run operations on.)pbdoc");
  nb::enum_<mx::Device::DeviceType>(m, "DeviceType")
      .value("cpu", mx::Device::DeviceType::cpu)
      .value("gpu", mx::Device::DeviceType::gpu)
      .export_values()
      .def(
          "__eq__",
          [](const mx::Device::DeviceType& d, const nb::object& other) {
            if (!nb::isinstance<mx::Device>(other) &&
                !nb::isinstance<mx::Device::DeviceType>(other)) {
              return false;
            }
            return d == nb::cast<mx::Device>(other);
          });

  device_class
      .def(nb::init<mx::Device::DeviceType, int>(), "type"_a, "index"_a = 0)
      .def_ro("type", &mx::Device::type)
      .def(
          "__repr__",
          [](const mx::Device& d) {
            std::ostringstream os;
            os << d;
            return os.str();
          })
      .def("__eq__", [](const mx::Device& d, const nb::object& other) {
        if (!nb::isinstance<mx::Device>(other) &&
            !nb::isinstance<mx::Device::DeviceType>(other)) {
          return false;
        }
        return d == nb::cast<mx::Device>(other);
      });

  nb::implicitly_convertible<mx::Device::DeviceType, mx::Device>();

  m.def(
      "default_device",
      &mx::default_device,
      R"pbdoc(Get the default device.)pbdoc");
  m.def(
      "set_default_device",
      &mx::set_default_device,
      "device"_a,
      R"pbdoc(Set the default device.)pbdoc");
  m.def(
      "is_available",
      &mx::is_available,
      "device"_a,
      R"pbdoc(Check if a back-end is available for the given device.)pbdoc");
  m.def(
      "device_count",
      &mx::device_count,
      "device_type"_a,
      R"pbdoc(
      Get the number of available devices for the given device type.

      Args:
          device_type (DeviceType): The type of device to query (cpu or gpu).

      Returns:
          int: Number of devices.
      )pbdoc");
  m.def(
      "device_info",
      [](std::optional<mx::Device> d) {
        return mx::device_info(d.value_or(mx::default_device()));
      },
      "d"_a = nb::none(),
      R"pbdoc(
      Get information about a device.

      Returns a dictionary with device properties. Available keys depend
      on the backend and device type. Common keys include ``device_name``,
      ``architecture``, and ``total_memory`` (or ``memory_size``).

      Args:
          d (Device): The device to query (defaults to the default device).

      Returns:
          dict: Device information.
      )pbdoc");
}