#include <numeric>
#include <optional>
#include <sstream>
#include <nanobind/ndarray.h>
#include "mlx/ops.h"
#include "python/src/convert.h"
#include "python/src/indexing.h"
// Returns true when the slice is the bare `:` form, i.e. its start, stop and
// step attributes are all None. Attributes are inspected in that order and
// the scan stops at the first non-None one.
bool is_none_slice(const nb::slice& in_slice) {
  for (const char* attr : {"start", "stop", "step"}) {
    if (!nb::getattr(in_slice, attr).is_none()) {
      return false;
    }
  }
  return true;
}
// Returns true if `obj` can be used as a single integer index.
//
// Python bools are rejected even though they support the index protocol.
// Objects that pass PyIndex_Check but are ndarrays with one or more
// dimensions are rejected too; 0-d ndarrays count as scalars.
bool is_index_scalar(const nb::object& obj) {
  const bool is_bool = nb::isinstance<nb::bool_>(obj);
  if (is_bool || !PyIndex_Check(obj.ptr())) {
    return false;
  }
  const bool is_nd_array_with_dims = nb::ndarray_check(obj) &&
      nb::cast<nb::ndarray<nb::ro>>(obj).ndim() > 0;
  return !is_nd_array_with_dims;
}
// Converts an index-like Python object to a 32-bit int via the Python index
// protocol (PyNumber_Index).
//
// Throws nb::python_error if the protocol call fails (the Python error is
// left set), and std::invalid_argument if the value does not fit in int32.
int safe_to_int32(nb::object obj) {
  // PyNumber_Index hands back a new reference (or nullptr on failure), so
  // ownership is stolen into an nb::object.
  PyObject* raw_index = PyNumber_Index(obj.ptr());
  auto as_index = nb::steal<nb::object>(raw_index);
  if (!as_index.is_valid()) {
    throw nb::python_error();
  }
  const int64_t wide = nb::cast<int64_t>(nb::cast<nb::int_>(as_index));
  const bool fits = wide >= INT32_MIN && wide <= INT32_MAX;
  if (!fits) {
    throw std::invalid_argument("Slice indices must be 32-bit integers.");
  }
  return static_cast<int>(wide);
}
// Reads one slice component (start/stop/step).
//
// Returns `default_val` when the component is None. Otherwise the component
// must be an integer scalar, and it is converted with safe_to_int32.
// Throws std::invalid_argument for non-integer components.
int get_slice_int(nb::object obj, int default_val) {
  if (obj.is_none()) {
    return default_val;
  }
  if (!is_index_scalar(obj)) {
    throw std::invalid_argument("Slice indices must be integers or None.");
  }
  return safe_to_int32(obj);
}
// Extracts (start, stop, step) from a Python slice for an axis of length
// `axis_size`, writing them into the three output references.
//
// The step is read first (default 1) because the defaults of the other two
// depend on its sign:
//   - positive step: start = 0,             stop = axis_size
//   - negative step: start = axis_size - 1, stop = -axis_size - 1
// Explicit start/stop values are passed through unnormalized (they may be
// negative or out of range).
void get_slice_params(
    mx::ShapeElem& starts,
    mx::ShapeElem& ends,
    mx::ShapeElem& strides,
    const nb::slice& in_slice,
    int axis_size) {
  strides = get_slice_int(nb::getattr(in_slice, "step"), 1);
  const bool reversed = strides < 0;
  const int default_start = reversed ? axis_size - 1 : 0;
  const int default_stop = reversed ? -axis_size - 1 : axis_size;
  starts = get_slice_int(nb::getattr(in_slice, "start"), default_start);
  ends = get_slice_int(nb::getattr(in_slice, "stop"), default_stop);
}
// Builds a 0-d uint32 index array from an integer-like Python object.
//
// A negative index is wrapped once by adding `axis_size`. No bounds check is
// done here.
mx::array get_int_index(nb::object idx, int axis_size) {
  int position = safe_to_int32(idx);
  if (position < 0) {
    position += axis_size;
  }
  return mx::array(position, mx::uint32);
}
// Returns true if `obj` is one of the entry types accepted inside an index
// tuple: slice, integer scalar, mx.array, None, Ellipsis or a Python list.
// The checks run in a fixed order and stop at the first match.
bool is_valid_index_type(const nb::object& obj) {
  if (nb::isinstance<nb::slice>(obj)) {
    return true;
  }
  if (is_index_scalar(obj)) {
    return true;
  }
  if (nb::isinstance<mx::array>(obj)) {
    return true;
  }
  if (obj.is_none()) {
    return true;
  }
  if (nb::ellipsis().is(obj)) {
    return true;
  }
  return nb::isinstance<nb::list>(obj);
}
// Implements `src[start:stop:step]`, slicing along axis 0 only.
//
// Throws std::invalid_argument for 0-d inputs. A bare `:` returns `src`
// without building a slice op.
mx::array mlx_get_item_slice(const mx::array& src, const nb::slice& in_slice) {
  if (src.ndim() == 0) {
    throw std::invalid_argument(
        "too many indices for array: array is 0-dimensional");
  }
  if (is_none_slice(in_slice)) {
    return src;
  }
  // All axes other than 0 keep their full extent with unit stride.
  mx::Shape lower(src.ndim(), 0);
  mx::Shape upper = src.shape();
  mx::Shape step(src.ndim(), 1);
  const int axis0_size = upper[0];
  get_slice_params(lower[0], upper[0], step[0], in_slice, axis0_size);
  return slice(src, lower, upper, step);
}
// Implements `src[indices]` for an integer mx.array index: takes along axis 0.
//
// Throws std::invalid_argument for 0-d inputs and for boolean index arrays,
// which this path does not support.
mx::array mlx_get_item_array(const mx::array& src, const mx::array& indices) {
  const bool src_is_scalar = src.ndim() == 0;
  if (src_is_scalar) {
    throw std::invalid_argument(
        "too many indices for array: array is 0-dimensional");
  }
  const bool is_mask = indices.dtype() == mx::bool_;
  if (is_mask) {
    throw std::invalid_argument("boolean indices are not yet supported");
  }
  return take(src, indices, 0);
}
// Implements `src[i]` for an integer `i`: takes a single position along
// axis 0 (negative `i` wraps once). Throws std::invalid_argument for 0-d
// inputs.
mx::array mlx_get_item_int(const mx::array& src, const nb::object& idx) {
  if (src.ndim() == 0) {
    throw std::invalid_argument(
        "too many indices for array: array is 0-dimensional");
  }
  auto position = get_int_index(idx, src.shape(0));
  return take(src, position, 0);
}
// Gathers from `src` using one index object per leading axis of `src`.
//
// `indices[i]` applies to axis i and must be a slice, an integer scalar or an
// mx.array. Callers strip None entries before calling; any other type is
// silently skipped, which would misalign `is_slice` with `gather_indices`.
//  - slices become uint32 aranges over the selected positions,
//  - integers become 0-d uint32 indices,
//  - arrays are used as-is; `max_dims` is set to the largest array rank.
//
// `gather_first` controls where the broadcast array dims land relative to
// the slice dims in the index shape:
//  - true:  [array dims (max_dims)] + [one dim per slice]
//  - false: the slices are assumed to be the first `num_slices` entries
//           (see `i < num_slices` below); their dims come first and the
//           arrays broadcast into the trailing dims.
//
// The result has the broadcast index dims followed by the un-indexed axes of
// `src`; the size-1 axes produced by gather for the indexed axes are squeezed.
//
// NOTE(review): negative slice endpoints are wrapped once by the axis size,
// but out-of-range endpoints (e.g. `:100` on a shorter axis) are not clamped
// here. Confirm whether gather tolerates that.
mx::array mlx_gather_nd(
mx::array src,
const std::vector<nb::object>& indices,
bool gather_first,
int& max_dims) {
max_dims = 0;
std::vector<mx::array> gather_indices;
std::vector<bool> is_slice(indices.size(), false);
int num_slices = 0;
// Convert every index object into an index array for its axis.
for (int i = 0; i < indices.size(); i++) {
auto& idx = indices[i];
if (nb::isinstance<nb::slice>(idx)) {
mx::ShapeElem start, end, stride;
get_slice_params(
start, end, stride, nb::cast<nb::slice>(idx), src.shape(i));
// Wrap negative endpoints once so arange produces non-negative positions.
start = (start < 0) ? start + src.shape(i) : start;
end = (end < 0) ? end + src.shape(i) : end;
gather_indices.push_back(arange(start, end, stride, mx::uint32));
num_slices++;
is_slice[i] = true;
} else if (is_index_scalar(idx)) {
gather_indices.push_back(get_int_index(idx, src.shape(i)));
} else if (nb::isinstance<mx::array>(idx)) {
auto arr = nb::cast<mx::array>(idx);
max_dims = std::max(static_cast<int>(arr.ndim()), max_dims);
gather_indices.push_back(arr);
}
}
// Reshape the index arrays so they broadcast to a common shape of rank
// max_dims + num_slices.
if (gather_first) {
int slice_index = 0;
for (int i = 0; i < gather_indices.size(); i++) {
if (is_slice[i]) {
// Each slice owns one dim after the array dims.
mx::Shape index_shape(max_dims + num_slices, 1);
index_shape[max_dims + slice_index] = gather_indices[i].shape(0);
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
slice_index++;
} else {
// Arrays/scalars get trailing unit dims for the slice dims.
auto index_shape = gather_indices[i].shape();
index_shape.insert(index_shape.end(), num_slices, 1);
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
}
}
} else {
for (int i = 0; i < gather_indices.size(); i++) {
if (i < num_slices) {
// Leading slice i owns dim i; arrays are left as-is and broadcast
// into the trailing max_dims dims.
mx::Shape index_shape(max_dims + num_slices, 1);
index_shape[i] = gather_indices[i].shape(0);
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
}
}
}
// Gather one element along each indexed axis; keep the rest in full.
std::vector<int> axes(indices.size());
std::iota(axes.begin(), axes.end(), 0);
auto slice_sizes = src.shape();
std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1);
src = gather(src, gather_indices, axes, slice_sizes);
// The gathered size-1 axes sit right after the index dims; squeeze them.
for (auto& ax : axes) {
ax += max_dims + num_slices;
}
return mx::squeeze(src, axes);
}
// Expands a tuple index into a flat list of per-position index objects.
//
// Entries before the (single, optional) `...` are kept as-is. The ellipsis is
// replaced by full `slice(0, dim, 1)` objects for every axis not claimed by
// the non-None entries around it. Entries after the ellipsis are appended in
// their original order.
//
// Returns a pair (non_none_indices, indices). `non_none_indices` counts all
// non-None entries, including the generated slices. It may exceed
// `shape.size()`; callers are expected to check for that and report
// "Too many indices".
//
// Throws std::invalid_argument for unsupported entry types or for more than
// one ellipsis.
auto mlx_expand_ellipsis(const mx::Shape& shape, const nb::tuple& entries) {
  std::vector<nb::object> indices;
  int non_none_indices_before = 0;
  int non_none_indices_after = 0;
  std::vector<nb::object> r_indices;
  int i = 0;
  bool has_ellipsis = false;

  // Forward scan up to (and stopping at) the first ellipsis.
  for (; i < entries.size(); i++) {
    auto idx = entries[i];
    if (!is_valid_index_type(idx)) {
      throw std::invalid_argument(
          "Cannot index mlx array using the given type yet");
    }
    if (!nb::ellipsis().is(idx)) {
      indices.push_back(idx);
      non_none_indices_before += !idx.is_none();
    } else {
      has_ellipsis = true;
      break;
    }
  }

  // Backward scan from the end down to, but excluding, the ellipsis. A second
  // ellipsis found here is an error. Entries are collected in reverse.
  for (int j = static_cast<int>(entries.size()) - 1; j > i; j--) {
    auto idx = entries[j];
    if (!is_valid_index_type(idx)) {
      throw std::invalid_argument(
          "Cannot index mlx array using the given type yet");
    }
    if (nb::ellipsis().is(idx)) {
      throw std::invalid_argument(
          "An index can only have a single ellipsis (...)");
    }
    r_indices.push_back(idx);
    non_none_indices_after += !idx.is_none();
  }

  int non_none_indices = non_none_indices_before + non_none_indices_after;

  if (has_ellipsis) {
    // Compute the bound in signed arithmetic. As size_t, `shape.size() -
    // non_none_indices_after` wraps when more explicit indices follow the
    // ellipsis than there are dimensions, and the loop then reads past the
    // end of `shape`. With the signed bound no slices are generated and the
    // caller's "Too many indices" check fires instead.
    const int end_axis =
        static_cast<int>(shape.size()) - non_none_indices_after;
    for (int axis = non_none_indices_before; axis < end_axis; axis++) {
      indices.push_back(
          nb::slice(mx::ShapeElem{0}, shape[axis], mx::ShapeElem{1}));
      non_none_indices++;
    }
  }

  indices.insert(indices.end(), r_indices.rbegin(), r_indices.rend());
  return std::make_pair(non_none_indices, std::move(indices));
}
// Implements `src[entries]` for a tuple index.
//
// Outline:
//  1. An empty tuple returns `src` unchanged. Otherwise the ellipsis is
//     expanded (mlx_expand_ellipsis) and nested Python lists are converted
//     to mx.arrays.
//  2. If at least one entry is an mx.array, the array/integer entries (and
//     any slices before the last of them) are handled by mlx_gather_nd. The
//     remaining work on the gathered result is described by
//     `remaining_indices`.
//  3. Whatever is left (or the full index when no gather happened) is
//     applied as a single `slice`. Integer-indexed axes are then squeezed,
//     and None entries insert new axes via expand_dims.
//
// Throws std::invalid_argument when there are more non-None entries than
// dimensions, or (via mlx_expand_ellipsis) for unsupported entry types.
mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
if (entries.size() == 0) {
return src;
}
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
// Lists are accepted as index entries; turn them into mx.arrays up front.
for (auto& idx : indices) {
if (nb::isinstance<nb::list>(idx)) {
idx = nb::cast(array_from_list(nb::cast<nb::list>(idx), {}));
}
}
if (non_none_indices > src.ndim()) {
std::ostringstream msg;
msg << "Too many indices for array with " << src.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
std::vector<nb::object> remaining_indices;
bool have_array = false;
{
bool have_non_array = false;
bool gather_first = false;
// gather_first becomes true when array/integer entries are separated by a
// non-array entry (e.g. `a[arr, :, arr]`). mlx_gather_nd then puts the
// broadcast array dims in front, which appears to mirror NumPy's rule for
// non-adjacent advanced indices.
for (auto& idx : indices) {
if (nb::isinstance<mx::array>(idx) || is_index_scalar(idx)) {
if (have_array && have_non_array) {
gather_first = true;
break;
}
have_array = true;
} else {
have_non_array |= have_array;
}
}
// Integers alone don't need a gather: only take the gather path when at
// least one real mx.array is present.
int n_arr = 0;
for (auto& idx : indices) {
n_arr += nb::isinstance<mx::array>(idx);
}
have_array &= n_arr > 0;
if (have_array) {
// Position of the last array/integer entry; everything up to it (minus
// None entries) is handed to the gather.
int last_array;
for (last_array = indices.size() - 1; last_array >= 0; last_array--) {
auto& idx = indices[last_array];
if (nb::isinstance<mx::array>(idx) || is_index_scalar(idx)) {
break;
}
}
std::vector<nb::object> gather_indices;
for (int i = 0; i <= last_array; i++) {
auto& idx = indices[i];
if (!idx.is_none()) {
gather_indices.push_back(idx);
}
}
int max_dims;
src = mlx_gather_nd(src, gather_indices, gather_first, max_dims);
// Build the index applied to the gathered result: full slices over
// the dims produced by the gather, None kept as new axes, and the
// entries after the last array passed through unchanged.
if (gather_first) {
// Layout: [array dims][gathered slice dims][trailing axes].
for (int i = 0; i < max_dims; i++) {
remaining_indices.push_back(
nb::slice(nb::none(), nb::none(), nb::none()));
}
for (int i = 0; i < last_array; i++) {
auto& idx = indices[i];
if (idx.is_none()) {
remaining_indices.push_back(indices[i]);
} else if (nb::isinstance<nb::slice>(idx)) {
remaining_indices.push_back(
nb::slice(nb::none(), nb::none(), nb::none()));
}
}
for (int i = last_array + 1; i < indices.size(); i++) {
remaining_indices.push_back(indices[i]);
}
} else {
// Layout: [leading slice dims][array dims][trailing axes].
for (int i = 0; i < indices.size(); i++) {
auto& idx = indices[i];
if (nb::isinstance<mx::array>(idx) || is_index_scalar(idx)) {
break;
} else if (idx.is_none()) {
remaining_indices.push_back(idx);
} else {
remaining_indices.push_back(
nb::slice(nb::none(), nb::none(), nb::none()));
}
}
for (int i = 0; i < max_dims; i++) {
remaining_indices.push_back(
nb::slice(nb::none(), nb::none(), nb::none()));
}
for (int i = last_array + 1; i < indices.size(); i++) {
remaining_indices.push_back(indices[i]);
}
}
}
}
// The gather consumed every entry; nothing left to slice.
if (have_array && remaining_indices.empty()) {
return src;
}
// No gather happened: slice with the full (expanded) index list.
if (remaining_indices.empty()) {
remaining_indices = indices;
}
bool squeeze_needed = false;
bool unsqueeze_needed = false;
{
// One combined slice over all axes; `axis` advances only on non-None
// entries because None does not consume a source axis.
mx::Shape starts(src.ndim(), 0);
auto ends = src.shape();
mx::Shape strides(src.ndim(), 1);
int axis = 0;
for (auto& idx : remaining_indices) {
if (!idx.is_none()) {
if (!have_array && is_index_scalar(idx)) {
// Integer entry (non-gather path only): select one element and
// squeeze that axis afterwards.
int st = safe_to_int32(idx);
st = (st < 0) ? st + src.shape(axis) : st;
starts[axis] = st;
ends[axis] = st + 1;
squeeze_needed = true;
} else {
get_slice_params(
starts[axis],
ends[axis],
strides[axis],
nb::cast<nb::slice>(idx),
ends[axis]);
}
axis++;
} else {
unsqueeze_needed = true;
}
}
src = slice(src, starts, ends, strides);
}
if (unsqueeze_needed || squeeze_needed) {
// Compute squeeze/expand positions in the output, each adjusted for the
// axes removed or inserted by the other before it.
std::vector<int> squeeze_axes;
std::vector<int> unsqueeze_axes;
for (int axis = 0; axis < remaining_indices.size(); ++axis) {
auto& idx = remaining_indices[axis];
if (unsqueeze_needed && idx.is_none()) {
unsqueeze_axes.push_back(axis - squeeze_axes.size());
} else if (squeeze_needed && is_index_scalar(idx)) {
squeeze_axes.push_back(axis - unsqueeze_axes.size());
}
}
if (!squeeze_axes.empty()) {
src = squeeze(src, std::move(squeeze_axes));
}
if (!unsqueeze_axes.empty()) {
src = expand_dims(src, std::move(unsqueeze_axes));
}
}
return src;
}
// Entry point for `src[obj]`: dispatches on the Python type of the index.
//
// Checks run in this order: slice, mx.array, integer scalar, tuple, Ellipsis,
// None, list. Any other type raises std::invalid_argument.
mx::array mlx_get_item(const mx::array& src, const nb::object& obj) {
  if (nb::isinstance<nb::slice>(obj)) {
    return mlx_get_item_slice(src, nb::cast<nb::slice>(obj));
  }
  if (nb::isinstance<mx::array>(obj)) {
    return mlx_get_item_array(src, nb::cast<mx::array>(obj));
  }
  if (is_index_scalar(obj)) {
    return mlx_get_item_int(src, obj);
  }
  if (nb::isinstance<nb::tuple>(obj)) {
    return mlx_get_item_nd(src, nb::cast<nb::tuple>(obj));
  }
  if (nb::isinstance<nb::ellipsis>(obj)) {
    // `a[...]` selects the whole array.
    return src;
  }
  if (obj.is_none()) {
    // `a[None]` adds a leading unit axis.
    return expand_dims(src, 0);
  }
  if (nb::isinstance<nb::list>(obj)) {
    auto list_indices = array_from_list(nb::cast<nb::list>(obj), {});
    return mlx_get_item_array(src, list_indices);
  }
  throw std::invalid_argument("Cannot index mlx array using the given type.");
}
// Builds scatter arguments for `src[i] = update` with an integer `i`.
//
// The update's leading size-1 axes are dropped. The result is broadcast to
// `src.shape()` with axis 0 replaced by 1. Returns ({index}, update, {0}).
// Throws std::invalid_argument for 0-d `src`.
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
mlx_scatter_args_int(
    const mx::array& src,
    const nb::object& idx,
    const mx::array& update) {
  if (src.ndim() == 0) {
    throw std::invalid_argument(
        "too many indices for array: array is 0-dimensional");
  }
  const auto& update_shape = update.shape();
  auto first_kept = update_shape.begin();
  while (first_kept != update_shape.end() && *first_kept == 1) {
    ++first_kept;
  }
  mx::Shape trimmed(first_kept, update_shape.end());

  mx::Shape target = src.shape();
  target[0] = 1;

  std::vector<mx::array> index_list{get_int_index(idx, src.shape(0))};
  auto prepared = broadcast_to(reshape(update, trimmed), target);
  return {index_list, prepared, {0}};
}
// Removes the run of size-1 axes at the front of `in` (stopping at the first
// axis whose size is not 1).
mx::array squeeze_leading_singletons(const mx::array& in) {
  std::vector<int> unit_axes;
  for (int axis = 0; axis < in.ndim() && in.shape(axis) == 1; ++axis) {
    unit_axes.push_back(axis);
  }
  return mx::squeeze(in, unit_axes);
}
// Builds scatter arguments for `src[indices] = update` with an index array
// applied to axis 0.
//
// The update (minus its leading size-1 axes) is broadcast to
// indices.shape() + src.shape()[1:]. It is then reshaped with a unit axis
// inserted right after the index dims, standing for the scattered axis 0.
// Returns ({indices}, update, {0}). Throws std::invalid_argument for 0-d
// `src`.
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
mlx_scatter_args_array(
    const mx::array& src,
    const mx::array& indices,
    const mx::array& update) {
  if (src.ndim() == 0) {
    throw std::invalid_argument(
        "too many indices for array: array is 0-dimensional");
  }
  const auto& src_shape = src.shape();
  mx::Shape target = indices.shape();
  target.insert(target.end(), src_shape.begin() + 1, src_shape.end());

  auto prepared = broadcast_to(squeeze_leading_singletons(update), target);

  target.insert(target.begin() + indices.ndim(), 1);
  prepared = reshape(prepared, target);
  return {{indices}, prepared, {0}};
}
// Builds scatter arguments for `src[start:stop:step] = update` on axis 0.
//
//  - A bare `:` yields no indices and the update broadcast to `src.shape()`.
//  - A unit step yields a single start index plus an update shaped
//    (1, stop - start, *src.shape()[1:]).
//  - Any other step is turned into an explicit arange and delegated to
//    mlx_scatter_args_array.
// Start/stop are used as returned by get_slice_params (not normalized here).
// Throws std::invalid_argument for 0-d `src`.
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
mlx_scatter_args_slice(
    const mx::array& src,
    const nb::slice& in_slice,
    const mx::array& update) {
  if (src.ndim() == 0) {
    throw std::invalid_argument(
        "too many indices for array: array is 0-dimensional");
  }
  if (is_none_slice(in_slice)) {
    return {
        {}, broadcast_to(squeeze_leading_singletons(update), src.shape()), {}};
  }

  const int axis_len = src.shape(0);
  mx::ShapeElem first = 0;
  mx::ShapeElem last = axis_len;
  mx::ShapeElem step = 1;
  get_slice_params(first, last, step, in_slice, axis_len);

  if (step != 1) {
    return mlx_scatter_args_array(
        src, arange(first, last, step, mx::uint32), update);
  }

  auto prepared = squeeze_leading_singletons(update);
  auto start_index = mx::array({first}, {1}, mx::uint32);
  const int extent = last - first;
  mx::Shape target = {1, extent};
  const auto& src_shape = src.shape();
  target.insert(target.end(), src_shape.begin() + 1, src_shape.end());
  prepared = broadcast_to(prepared, target);
  return {{start_index}, prepared, {0}};
}
// Builds scatter arguments (indices, update, axes) for `src[entries] = update`
// when the tuple contains at least one mx.array or list entry.
//
// Outline:
//  1. Expand the ellipsis, convert lists to mx.arrays and check the index
//     count. When there are no non-None entries, return no indices and the
//     update broadcast to `src.shape()`.
//  2. First pass: count None / slice / array entries, strided slices
//     (step != 1) and the trailing run of unit-step slices after the last
//     array or strided slice (`num_simple_slices_post`). `arrays_first` is
//     set when an array appears after a non-array entry that itself followed
//     an array. Integer scalars are not counted in this pass.
//  3. Second pass: build one index array per non-None axis, all reshaped to
//     rank `idx_ndim` so they broadcast together. Trailing unit-step slices
//     are not materialized: they contribute only their start index, and
//     their extent goes into the update's shape instead.
//  4. Broadcast the update to [index dims][trailing slice extents][untouched
//     src dims], then reshape it to [index dims][per-axis extents][untouched
//     src dims] as the scatter expects.
//
// Throws std::invalid_argument for too many indices or unsupported entries.
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
mlx_scatter_args_nd(
const mx::array& src,
const nb::tuple& entries,
const mx::array& update) {
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
for (auto& idx : indices) {
if (nb::isinstance<nb::list>(idx)) {
idx = nb::cast(array_from_list(nb::cast<nb::list>(idx), {}));
}
}
if (non_none_indices > src.ndim()) {
std::ostringstream msg;
msg << "Too many indices for array with " << src.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
auto up = squeeze_leading_singletons(update);
// Only None entries: overwrite the whole array.
if (non_none_indices == 0) {
return {{}, broadcast_to(up, src.shape()), {}};
}
size_t max_dim = 0;
bool arrays_first = false;
int num_none = 0;
int num_slices = 0;
int num_arrays = 0;
int num_strided_slices = 0;
int num_simple_slices_post = 0;
{
// First pass: classify entries (see the outline above).
bool have_array = false;
bool have_non_array = false;
for (auto& idx : indices) {
if (idx.is_none()) {
have_non_array = have_array;
num_none++;
} else if (nb::isinstance<nb::slice>(idx)) {
have_non_array = have_array;
num_slices++;
auto slice = nb::cast<nb::slice>(idx);
int stride = get_slice_int(nb::getattr(slice, "step"), 1);
if (stride != 1) {
num_strided_slices++;
num_simple_slices_post = 0;
} else {
num_simple_slices_post++;
}
} else if (nb::isinstance<mx::array>(idx)) {
have_array = true;
if (have_array && have_non_array) {
arrays_first = true;
}
max_dim = std::max(nb::cast<mx::array>(idx).ndim(), max_dim);
num_arrays++;
num_simple_slices_post = 0;
}
}
}
// Rank shared by all index arrays: array dims plus one dim per None and per
// materialized slice (trailing unit-step slices excluded). At least 1.
int idx_ndim = max_dim + num_none + num_slices - num_simple_slices_post;
idx_ndim = idx_ndim == 0 ? 1 : idx_ndim;
std::vector<mx::array> arr_indices;
// slice_num: next free dim for slice/None dims in the index shape.
// array_num: arrays seen so far. ax: current source axis.
int slice_num = 0;
int array_num = 0;
int ax = 0;
// Per-indexed-axis extent of the update (1 for gathered axes).
std::vector<int> update_shape(non_none_indices, 1);
// Extents of the non-materialized (trailing unit-step) slices.
std::vector<int> slice_shapes;
// Second pass: build the index arrays.
for (int i = 0; i < indices.size(); ++i) {
auto& pyidx = indices[i];
if (nb::isinstance<nb::slice>(pyidx)) {
mx::ShapeElem start, end, stride;
auto axis_size = src.shape(ax++);
get_slice_params(
start, end, stride, nb::cast<nb::slice>(pyidx), axis_size);
start = (start < 0) ? start + axis_size : start;
end = (end < 0) ? end + axis_size : end;
mx::Shape idx_shape(idx_ndim, 1);
if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) {
// Unit-step slice with no arrays or strided slices left: index only
// the start and carry the extent in the update's shape.
auto idx = mx::array({start}, idx_shape, mx::uint32);
slice_shapes.push_back(end - start);
arr_indices.push_back(idx);
update_shape[ax - 1] = slice_shapes.back();
}
else {
// Materialize every selected position and give it its own dim
// (after the array dims when arrays_first).
auto idx = arange(start, end, stride, mx::uint32);
auto loc = slice_num + (arrays_first ? max_dim : 0);
idx_shape[loc] = idx.size();
arr_indices.push_back(reshape(idx, idx_shape));
slice_num++;
num_strided_slices--;
update_shape[ax - 1] = 1;
}
} else if (is_index_scalar(pyidx)) {
arr_indices.push_back(get_int_index(pyidx, src.shape(ax++)));
update_shape[ax - 1] = 1;
} else if (pyidx.is_none()) {
// None consumes an index dim but no source axis.
slice_num++;
} else if (nb::isinstance<mx::array>(pyidx)) {
ax++;
auto idx = nb::cast<mx::array>(pyidx);
mx::Shape idx_shape(idx_ndim, 1);
// Right-align this array's dims within the max_dim array block.
int st = (!arrays_first) * slice_num + max_dim - idx.ndim();
for (int j = 0; j < idx.ndim(); j++) {
idx_shape[st + j] = idx.shape()[j];
}
arr_indices.push_back(reshape(idx, idx_shape));
// After the last array, later slice/None dims go past the array block.
if (!arrays_first && ++array_num == num_arrays) {
slice_num += max_dim;
}
update_shape[ax - 1] = 1;
} else {
throw std::invalid_argument(
"Cannot index mlx array using the given type yet");
}
}
arr_indices = broadcast_arrays(arr_indices);
// Update as the user sees it: [index dims][slice extents][untouched dims].
auto up_shape_broadcast = arr_indices[0].shape();
up_shape_broadcast.insert(
up_shape_broadcast.end(), slice_shapes.begin(), slice_shapes.end());
up_shape_broadcast.insert(
up_shape_broadcast.end(),
src.shape().begin() + non_none_indices,
src.shape().end());
up = broadcast_to(up, up_shape_broadcast);
// Update as scatter expects: [index dims][per-axis extents][untouched dims].
auto up_reshape = arr_indices[0].shape();
up_reshape.insert(up_reshape.end(), update_shape.begin(), update_shape.end());
up_reshape.insert(
up_reshape.end(),
src.shape().begin() + non_none_indices,
src.shape().end());
up = reshape(up, up_reshape);
std::vector<int> axes(arr_indices.size(), 0);
std::iota(axes.begin(), axes.end(), 0);
return {arr_indices, up, axes};
}
// Converts `v` to `src.dtype()` and builds scatter arguments
// (indices, update, axes) for `src[obj]`, dispatching on the index type.
//
// Checks run in this order: slice, mx.array, integer scalar, tuple, None,
// list. An empty index list means "replace everything with `update`".
// Unsupported types raise std::invalid_argument.
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
mlx_compute_scatter_args(
    const mx::array& src,
    const nb::object& obj,
    const ScalarOrArray& v) {
  auto values = to_array(v, src.dtype());
  if (nb::isinstance<nb::slice>(obj)) {
    return mlx_scatter_args_slice(src, nb::cast<nb::slice>(obj), values);
  }
  if (nb::isinstance<mx::array>(obj)) {
    return mlx_scatter_args_array(src, nb::cast<mx::array>(obj), values);
  }
  if (is_index_scalar(obj)) {
    return mlx_scatter_args_int(src, obj, values);
  }
  if (nb::isinstance<nb::tuple>(obj)) {
    return mlx_scatter_args_nd(src, nb::cast<nb::tuple>(obj), values);
  }
  if (obj.is_none()) {
    return {{}, broadcast_to(values, src.shape()), {}};
  }
  if (nb::isinstance<nb::list>(obj)) {
    auto list_indices = array_from_list(nb::cast<nb::list>(obj), {});
    return mlx_scatter_args_array(src, list_indices, values);
  }
  throw std::invalid_argument("Cannot index mlx array using the given type.");
}
// Tries to express an indexed update `src[obj] (op)= v` as a single slice
// update over `src`.
//
// Returns a tuple (update, starts, stops, strides):
//  - `update` is std::nullopt when the index cannot be handled as a plain
//    slice: `src` is 0-d, `obj` is not a slice / integer / tuple, or a tuple
//    entry is an mx.array or a list. Callers then fall back to a scatter.
//    starts/stops/strides are the untouched defaults (full range, unit
//    strides) in that case.
//  - otherwise `update` is `v` converted to `src.dtype()` and reshaped to
//    line up with the sliced region described by starts/stops/strides.
//
// Throws std::invalid_argument when a tuple index has more non-None entries
// than `src` has dimensions.
std::tuple<std::optional<mx::array>, mx::Shape, mx::Shape, mx::Shape>
mlx_compute_slice_update_args(
    const mx::array& src,
    const nb::object& obj,
    const ScalarOrArray& v) {
  mx::Shape starts(src.ndim(), 0);
  mx::Shape stops = src.shape();
  mx::Shape strides(src.ndim(), 1);

  // Only slices, integers and tuples of those can become a slice update.
  if (src.ndim() == 0 ||
      (!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj) &&
       !is_index_scalar(obj))) {
    return std::make_tuple(
        std::nullopt, std::move(starts), std::move(stops), std::move(strides));
  }

  // Array or list entries inside a tuple require the scatter path.
  if (nb::isinstance<nb::tuple>(obj)) {
    for (auto idx : nb::cast<nb::tuple>(obj)) {
      if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::list>(idx)) {
        return std::make_tuple(
            std::nullopt,
            std::move(starts),
            std::move(stops),
            std::move(strides));
      }
    }
  }

  auto update = to_array(v, src.dtype());

  // Drop leading size-1 axes of the update while it still has more
  // dimensions than `src`, always keeping at least one axis.
  int s = 0;
  for (; s < static_cast<int>(update.ndim()) - 1 && update.shape(s) == 1 &&
       (update.ndim() - s) > src.ndim();
       s++) {
  }
  auto squeeze_axes = std::vector<int>(s);
  std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0);
  update = mx::squeeze(update, squeeze_axes);

  // Single integer: update one position along axis 0. `src` is at least 1-d
  // here because 0-d sources already returned above.
  if (is_index_scalar(obj)) {
    auto idx = safe_to_int32(obj);
    idx = idx < 0 ? idx + stops[0] : idx;
    starts[0] = idx;
    stops[0] = idx + 1;
    return std::make_tuple(
        update, std::move(starts), std::move(stops), std::move(strides));
  }

  // Single slice: only axis 0 is restricted.
  if (nb::isinstance<nb::slice>(obj)) {
    get_slice_params(
        starts[0],
        stops[0],
        strides[0],
        nb::cast<nb::slice>(obj),
        src.shape(0));
    return std::make_tuple(
        update, std::move(starts), std::move(stops), std::move(strides));
  }

  // Tuple of slices / integers / None (with at most one ellipsis).
  auto entries = nb::cast<nb::tuple>(obj);
  auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
  if (non_none_indices > src.ndim()) {
    std::ostringstream msg;
    msg << "Too many indices for array with " << src.ndim() << " dimensions.";
    throw std::invalid_argument(msg.str());
  }

  // Only None entries: the whole array is updated.
  if (non_none_indices == 0) {
    return std::make_tuple(
        broadcast_to(update, src.shape()),
        std::move(starts),
        std::move(stops),
        std::move(strides));
  }

  // Walk the entries right-to-left:
  //  - `ax` is the source axis of the current non-None entry.
  //  - `upd_ax` is the matching axis of `update`; it goes negative once
  //    `update` has run out of dimensions.
  // Integer entries record a unit axis to insert into `update`. None entries
  // record an axis of `update` to squeeze.
  const int n_indices = static_cast<int>(indices.size());
  const int unspecified = static_cast<int>(src.ndim()) - non_none_indices;
  std::vector<int> squeeze_dims;
  std::vector<int> expand_dims;
  for (int i = n_indices - 1,
           ax = non_none_indices - 1,
           upd_ax = static_cast<int>(update.ndim()) - unspecified - 1;
       i >= 0;
       --i) {
    auto& pyidx = indices[i];
    if (nb::isinstance<nb::slice>(pyidx)) {
      get_slice_params(
          starts[ax],
          stops[ax],
          strides[ax],
          nb::cast<nb::slice>(pyidx),
          src.shape(ax));
      ax--;
      upd_ax--;
    } else if (is_index_scalar(pyidx)) {
      int st = safe_to_int32(pyidx);
      // Wrap with the size of source axis `ax`. `i` also counts None entries,
      // so it is not a source axis: e.g. `a[None, -1]` must wrap by
      // a.shape[0].
      st = (st < 0) ? st + src.shape(ax) : st;
      starts[ax] = st;
      stops[ax] = st + 1;
      if (upd_ax >= 0) {
        expand_dims.push_back(i - n_indices - unspecified);
      }
      ax--;
    } else if (pyidx.is_none()) {
      if (upd_ax-- >= 0) {
        squeeze_dims.push_back(i - n_indices - unspecified);
      }
    }
  }
  update = mx::squeeze(
      mx::expand_dims(update, std::move(expand_dims)), std::move(squeeze_dims));
  return std::make_tuple(
      update, std::move(starts), std::move(stops), std::move(strides));
}
// Returns `obj` as a boolean mx.array when it is usable as a boolean mask,
// otherwise std::nullopt.
//
// Accepted inputs, checked in this order:
//  - a Python bool (becomes a 0-d mask),
//  - an mx.array with dtype bool_,
//  - a C-contiguous CPU ndarray with dtype bool,
//  - a Python list whose converted dtype is bool_.
// Only the first matching container type is tried. For example, a non-bool
// mx.array yields std::nullopt without being re-checked as an ndarray.
std::optional<mx::array> extract_boolean_mask(const nb::object& obj) {
  using NDArray = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
  if (nb::isinstance<nb::bool_>(obj)) {
    return mx::array(nb::cast<bool>(obj), mx::bool_);
  }

  std::optional<mx::array> candidate;
  if (nb::isinstance<mx::array>(obj)) {
    candidate = nb::cast<mx::array>(obj);
  } else if (nb::isinstance<NDArray>(obj)) {
    auto host = nb::cast<NDArray>(obj);
    if (host.dtype() != nb::dtype<bool>()) {
      return std::nullopt;
    }
    return nd_array_to_mlx(host, mx::bool_);
  } else if (nb::isinstance<nb::list>(obj)) {
    candidate = array_from_list(nb::cast<nb::list>(obj), {});
  }

  if (candidate && candidate->dtype() == mx::bool_) {
    return candidate;
  }
  return std::nullopt;
}
// Implements `src[obj] = v` by rewriting `src`'s descriptor in place.
//
// Strategy, in order:
//  1. plain slice / integer indexing -> one slice_update,
//  2. boolean mask -> masked_scatter,
//  3. everything else -> scatter. When the scatter needs no indices, the
//     update (already broadcast to src's shape) replaces `src` outright.
void mlx_set_item(
    mx::array& src,
    const nb::object& obj,
    const ScalarOrArray& v) {
  auto [update, starts, stops, strides] =
      mlx_compute_slice_update_args(src, obj, v);
  if (update) {
    auto sliced = slice_update(src, *update, starts, stops, strides);
    src.overwrite_descriptor(sliced);
    return;
  }

  if (auto mask = extract_boolean_mask(obj)) {
    auto masked = masked_scatter(src, *mask, to_array(v, src.dtype()));
    src.overwrite_descriptor(masked);
    return;
  }

  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
  auto result =
      indices.empty() ? updates : scatter(src, indices, updates, axes);
  src.overwrite_descriptor(result);
}
// Returns a copy of `src` with `v` added at `src[obj]` (out-of-place `+=`).
// Uses slice_update_add when the index is a plain slice, otherwise
// scatter_add. With no scatter indices it falls back to elementwise `+`.
mx::array mlx_add_item(
    const mx::array& src,
    const nb::object& obj,
    const ScalarOrArray& v) {
  auto [update, starts, stops, strides] =
      mlx_compute_slice_update_args(src, obj, v);
  if (!update) {
    auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
    if (indices.empty()) {
      return src + updates;
    }
    return scatter_add(src, indices, updates, axes);
  }
  return slice_update_add(src, *update, starts, stops, strides);
}
// Returns a copy of `src` with `v` subtracted at `src[obj]` (out-of-place
// `-=`). Subtraction is expressed as adding the negated update via
// slice_update_add / scatter_add. With no scatter indices it falls back to
// elementwise `-`.
mx::array mlx_subtract_item(
    const mx::array& src,
    const nb::object& obj,
    const ScalarOrArray& v) {
  auto [update, starts, stops, strides] =
      mlx_compute_slice_update_args(src, obj, v);
  if (update) {
    auto negated = -(*update);
    return slice_update_add(src, negated, starts, stops, strides);
  }
  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
  return indices.empty() ? src - updates
                         : scatter_add(src, indices, -updates, axes);
}
// Returns a copy of `src` with `src[obj]` multiplied by `v` (out-of-place
// `*=`). Uses slice_update_prod for plain slices, otherwise scatter_prod.
// With no scatter indices it falls back to elementwise `*`.
mx::array mlx_multiply_item(
    const mx::array& src,
    const nb::object& obj,
    const ScalarOrArray& v) {
  auto [update, starts, stops, strides] =
      mlx_compute_slice_update_args(src, obj, v);
  if (update) {
    return slice_update_prod(src, *update, starts, stops, strides);
  }
  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
  if (indices.empty()) {
    return src * updates;
  }
  return scatter_prod(src, indices, updates, axes);
}
// Returns a copy of `src` with `src[obj]` divided by `v` (out-of-place `/=`).
// Division is expressed as multiplying by the reciprocal via
// slice_update_prod / scatter_prod. With no scatter indices it falls back to
// elementwise `/`.
mx::array mlx_divide_item(
    const mx::array& src,
    const nb::object& obj,
    const ScalarOrArray& v) {
  auto [update, starts, stops, strides] =
      mlx_compute_slice_update_args(src, obj, v);
  if (!update) {
    auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
    return indices.empty() ? src / updates
                           : scatter_prod(src, indices, reciprocal(updates), axes);
  }
  auto inverse = reciprocal(*update);
  return slice_update_prod(src, inverse, starts, stops, strides);
}
// Returns a copy of `src` where `src[obj]` becomes max(src[obj], v).
// Uses slice_update_max for plain slices, otherwise scatter_max. With no
// scatter indices it falls back to an elementwise maximum.
mx::array mlx_maximum_item(
    const mx::array& src,
    const nb::object& obj,
    const ScalarOrArray& v) {
  auto [update, starts, stops, strides] =
      mlx_compute_slice_update_args(src, obj, v);
  if (update) {
    return slice_update_max(src, *update, starts, stops, strides);
  }
  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
  return indices.empty() ? maximum(src, updates)
                         : scatter_max(src, indices, updates, axes);
}
// Returns a copy of `src` where `src[obj]` becomes min(src[obj], v).
// Uses slice_update_min for plain slices, otherwise scatter_min. With no
// scatter indices it falls back to an elementwise minimum.
mx::array mlx_minimum_item(
    const mx::array& src,
    const nb::object& obj,
    const ScalarOrArray& v) {
  auto [update, starts, stops, strides] =
      mlx_compute_slice_update_args(src, obj, v);
  if (!update) {
    auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
    if (indices.empty()) {
      return minimum(src, updates);
    }
    return scatter_min(src, indices, updates, axes);
  }
  return slice_update_min(src, *update, starts, stops, strides);
}