mlxrs-sys 0.1.0

Bindings for MLX-C API
// Copyright © 2023-2024 Apple Inc.

#include "python/src/trees.h"

// Checks that every entry of `subtrees` is compatible with the first one at
// the current level of a multi-tree traversal.
//
// T is the container type of subtrees[0]; U and V are the two other container
// types. An entry is rejected when it is a T whose size differs from the size
// of subtrees[0], or when it is a U or a V. Entries that are none of the three
// types are accepted.
//
// Throws std::invalid_argument on the first rejected entry. The message is
// prefixed with "[tree_map]" regardless of which traversal calls this helper.
template <typename T, typename U, typename V>
void validate_subtrees(const std::vector<nb::object>& subtrees) {
  const int expected_size = nb::cast<T>(subtrees[0]).size();
  for (const nb::object& candidate : subtrees) {
    // Same container kind as the first tree but a different length.
    const bool wrong_size = nb::isinstance<T>(candidate) &&
        nb::cast<T>(candidate).size() != expected_size;
    if (wrong_size || nb::isinstance<U>(candidate) ||
        nb::isinstance<V>(candidate)) {
      throw std::invalid_argument(
          "[tree_map] Additional input tree is not a valid prefix of the first tree.");
    }
  }
}

// Maps `transform` over the leaves of several trees traversed in lockstep.
//
// The first tree drives the traversal: lists, tuples (including tuple
// subclasses) and dicts are descended into, anything else is treated as a leaf
// and handed to `transform` together with the corresponding entries of the
// other trees. At each level, another tree that is the same kind of container
// is indexed alongside the first tree; otherwise it is forwarded unchanged to
// every child.
//
// Returns a new tree mirroring the first tree whose leaves are the values
// returned by `transform`. Throws std::invalid_argument when a tree does not
// line up with the first one (see validate_subtrees and the dict key check).
nb::object tree_map(
    const std::vector<nb::object>& trees,
    std::function<nb::object(const std::vector<nb::object>&)> transform) {
  std::function<nb::object(const std::vector<nb::object>&)> recurse;

  recurse = [&](const std::vector<nb::object>& subtrees) -> nb::object {
    const nb::object& head = subtrees[0];
    const size_t num_trees = subtrees.size();

    // ---- list ----
    if (nb::isinstance<nb::list>(head)) {
      validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
      auto head_list = nb::cast<nb::list>(head);
      std::vector<nb::object> children(num_trees);
      nb::list mapped;
      // The list size is re-read on every iteration, like the loop condition
      // it replaces.
      for (int i = 0; i < head_list.size(); ++i) {
        for (size_t j = 0; j < num_trees; ++j) {
          const nb::object& tree = subtrees[j];
          if (nb::isinstance<nb::list>(tree)) {
            children[j] = nb::cast<nb::list>(tree)[i];
          } else {
            children[j] = tree;
          }
        }
        mapped.append(recurse(children));
      }
      return nb::cast<nb::object>(mapped);
    }

    // ---- tuple (exact or subclass) ----
    if (nb::isinstance<nb::tuple>(head)) {
      validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
      const int num_items = nb::cast<nb::tuple>(head).size();
      auto tuple_type = head.type();
      std::vector<nb::object> children(num_trees);
      nb::list mapped;
      for (int i = 0; i < num_items; ++i) {
        for (size_t j = 0; j < num_trees; ++j) {
          const nb::object& tree = subtrees[j];
          if (nb::isinstance<nb::tuple>(tree)) {
            children[j] = nb::cast<nb::tuple>(tree)[i];
          } else {
            children[j] = tree;
          }
        }
        mapped.append(recurse(children));
      }
      // Exact tuples are rebuilt directly from the mapped items.
      if (PyTuple_CheckExact(head.ptr())) {
        return nb::cast<nb::object>(nb::tuple(mapped));
      }
      // Tuple subclasses are rebuilt through their own type: types exposing
      // `_fields` receive the items as separate positional arguments, all
      // others receive the list of items as a single argument.
      if (nb::hasattr(tuple_type, "_fields")) {
        return tuple_type(*mapped);
      }
      return tuple_type(mapped);
    }

    // ---- dict ----
    if (nb::isinstance<nb::dict>(head)) {
      validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
      auto head_dict = nb::cast<nb::dict>(head);
      std::vector<nb::object> children(num_trees);
      nb::dict mapped;
      for (auto entry : head_dict) {
        nb::handle key = entry.first;
        for (size_t j = 0; j < num_trees; ++j) {
          const nb::object& tree = subtrees[j];
          if (!nb::isinstance<nb::dict>(tree)) {
            children[j] = tree;
            continue;
          }
          auto tree_dict = nb::cast<nb::dict>(tree);
          // Every dict tree must provide each key of the first tree.
          if (!tree_dict.contains(key)) {
            throw std::invalid_argument(
                "[tree_map] Tree is not a valid prefix tree of the first tree.");
          }
          children[j] = tree_dict[key];
        }
        mapped[key] = recurse(children);
      }
      return nb::cast<nb::object>(mapped);
    }

    // ---- leaf ----
    return transform(subtrees);
  };

  return recurse(trees);
}

// Single-tree convenience overload of tree_map.
//
// Wraps `tree` in a one-element vector and forwards to the multi-tree
// overload, adapting `transform` so that it receives the only leaf directly.
nb::object tree_map(
    nb::object tree,
    std::function<nb::object(nb::handle)> transform) {
  const std::vector<nb::object> single_tree{tree};
  auto on_leaves = [&transform](const std::vector<nb::object>& leaves) {
    return transform(leaves[0]);
  };
  return tree_map(single_tree, on_leaves);
}

// Visits the leaves of several trees traversed in lockstep.
//
// The first tree drives the traversal: lists, tuples and dicts are descended
// into, anything else is a leaf for which `visitor` is called with the
// corresponding entries of all trees. At each level, another tree that is the
// same kind of container is indexed alongside the first tree; otherwise it is
// forwarded unchanged to every child.
//
// Throws std::invalid_argument when a tree does not line up with the first one
// (see validate_subtrees and the dict key check).
void tree_visit(
    const std::vector<nb::object>& trees,
    std::function<void(const std::vector<nb::object>&)> visitor) {
  std::function<void(const std::vector<nb::object>&)> walk;

  walk = [&](const std::vector<nb::object>& subtrees) {
    const nb::object& head = subtrees[0];
    const size_t num_trees = subtrees.size();

    // ---- list ----
    if (nb::isinstance<nb::list>(head)) {
      std::vector<nb::object> children(num_trees);
      validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
      auto head_list = nb::cast<nb::list>(head);
      // The list size is re-read on every iteration, like the loop condition
      // it replaces.
      for (int i = 0; i < head_list.size(); ++i) {
        for (size_t j = 0; j < num_trees; ++j) {
          const nb::object& tree = subtrees[j];
          if (nb::isinstance<nb::list>(tree)) {
            children[j] = nb::cast<nb::list>(tree)[i];
          } else {
            children[j] = tree;
          }
        }
        walk(children);
      }
      return;
    }

    // ---- tuple ----
    if (nb::isinstance<nb::tuple>(head)) {
      std::vector<nb::object> children(num_trees);
      const int num_items = nb::cast<nb::tuple>(head).size();
      validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
      for (int i = 0; i < num_items; ++i) {
        for (size_t j = 0; j < num_trees; ++j) {
          const nb::object& tree = subtrees[j];
          if (nb::isinstance<nb::tuple>(tree)) {
            children[j] = nb::cast<nb::tuple>(tree)[i];
          } else {
            children[j] = tree;
          }
        }
        walk(children);
      }
      return;
    }

    // ---- dict ----
    if (nb::isinstance<nb::dict>(head)) {
      std::vector<nb::object> children(num_trees);
      validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
      auto head_dict = nb::cast<nb::dict>(head);
      for (auto entry : head_dict) {
        nb::handle key = entry.first;
        for (size_t j = 0; j < num_trees; ++j) {
          const nb::object& tree = subtrees[j];
          if (!nb::isinstance<nb::dict>(tree)) {
            children[j] = tree;
            continue;
          }
          auto tree_dict = nb::cast<nb::dict>(tree);
          // Every dict tree must provide each key of the first tree.
          if (!tree_dict.contains(key)) {
            throw std::invalid_argument(
                "[tree_visit] Tree is not a valid prefix tree of the first tree.");
          }
          children[j] = tree_dict[key];
        }
        walk(children);
      }
      return;
    }

    // ---- leaf ----
    visitor(subtrees);
  };

  walk(trees);
}

// Depth-first walk over a single tree.
//
// Lists and tuples are iterated element by element, dicts are iterated over
// their values, and `visitor` is called on every other object encountered.
void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor) {
  std::function<void(nb::handle)> walk;

  walk = [&](nb::handle node) {
    const bool is_sequence =
        nb::isinstance<nb::list>(node) || nb::isinstance<nb::tuple>(node);
    if (is_sequence) {
      for (auto child : node) {
        walk(child);
      }
      return;
    }

    if (nb::isinstance<nb::dict>(node)) {
      // Only the values are visited; keys are never passed to the visitor.
      for (auto entry : nb::cast<nb::dict>(node)) {
        walk(entry.second);
      }
      return;
    }

    visitor(node);
  };

  walk(tree);
}

// Walks `tree` and replaces every mx::array leaf with the result of `visitor`.
//
// - Lists and dicts are updated in place: each item / value is overwritten
//   with the updated child.
// - Tuples are not modified: a new tuple (or tuple-subclass instance) is
//   built from the updated items and handed back to the parent container.
// - Leaves that are not mx::array are returned unchanged.
//
// The value produced for the root node is discarded.
void tree_visit_update(
    nb::object tree,
    std::function<nb::object(nb::handle)> visitor) {
  std::function<nb::object(nb::handle)> update;

  update = [&](nb::handle node) -> nb::object {
    if (nb::isinstance<nb::list>(node)) {
      auto items = nb::cast<nb::list>(node);
      for (int i = 0; i < items.size(); ++i) {
        items[i] = update(items[i]);
      }
      return nb::cast<nb::object>(items);
    }

    if (nb::isinstance<nb::tuple>(node)) {
      auto tuple_type = node.type();
      // Copy the items into a list so they can be overwritten.
      nb::list items(node);
      for (int i = 0; i < items.size(); ++i) {
        items[i] = update(items[i]);
      }
      if (PyTuple_CheckExact(node.ptr())) {
        return nb::cast<nb::object>(nb::tuple(items));
      }
      // Tuple subclasses are rebuilt through their own type: types exposing
      // `_fields` receive the items as separate positional arguments, all
      // others receive the list of items as a single argument.
      if (nb::hasattr(tuple_type, "_fields")) {
        return tuple_type(*items);
      }
      return tuple_type(items);
    }

    if (nb::isinstance<nb::dict>(node)) {
      auto mapping = nb::cast<nb::dict>(node);
      for (auto entry : mapping) {
        mapping[entry.first] = update(entry.second);
      }
      return nb::cast<nb::object>(mapping);
    }

    if (nb::isinstance<mx::array>(node)) {
      return visitor(node);
    }

    return nb::cast<nb::object>(node);
  };

  update(tree);
}

// Fill a pytree with the given arrays.
//
// The tree is walked with tree_visit_update: each mx::array leaf is replaced,
// in traversal order, by the next entry of `values`. Leaves that are not
// arrays are left untouched.
//
// Throws std::invalid_argument if the tree contains more arrays than `values`
// provides, instead of reading past the end of the vector.
void tree_fill(nb::object& tree, const std::vector<mx::array>& values) {
  size_t index = 0;
  tree_visit_update(tree, [&](nb::handle /* node */) {
    if (index >= values.size()) {
      throw std::invalid_argument(
          "[tree_fill] The tree contains more arrays than values were provided.");
    }
    return nb::cast(values[index++]);
  });
}

// Replace all the arrays from the src values with the dst values in the tree.
//
// Arrays are matched by id(): src[i] is replaced by dst[i]. When the same id
// appears several times in `src`, the first occurrence wins. Arrays of the
// tree whose id is not in `src` are kept as they are.
//
// Throws std::invalid_argument if `dst` has fewer entries than `src`, since
// some source array would then have no replacement. Extra entries in `dst`
// are ignored.
void tree_replace(
    nb::object& tree,
    const std::vector<mx::array>& src,
    const std::vector<mx::array>& dst) {
  if (dst.size() < src.size()) {
    throw std::invalid_argument(
        "[tree_replace] Fewer destination arrays than source arrays were provided.");
  }

  std::unordered_map<uintptr_t, mx::array> src_to_dst;
  src_to_dst.reserve(src.size());
  for (size_t i = 0; i < src.size(); ++i) {
    src_to_dst.insert({src[i].id(), dst[i]});
  }

  // tree_visit_update only calls the visitor on mx::array leaves.
  tree_visit_update(tree, [&](nb::handle node) {
    auto arr = nb::cast<mx::array>(node);
    if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
      return nb::cast(it->second);
    }
    return nb::cast(arr);
  });
}

// Collects the mx::array leaves of `tree` in traversal order.
//
// With `strict` set, any leaf that is not an mx::array raises
// std::invalid_argument; otherwise such leaves are silently skipped.
std::vector<mx::array> tree_flatten(nb::handle tree, bool strict /* = true */) {
  std::vector<mx::array> arrays;

  auto collect = [&arrays, strict](nb::handle leaf) {
    if (!nb::isinstance<mx::array>(leaf)) {
      if (strict) {
        throw std::invalid_argument(
            "[tree_flatten] The argument should contain only arrays");
      }
      return;
    }
    arrays.push_back(nb::cast<mx::array>(leaf));
  };
  tree_visit(tree, collect);

  return arrays;
}

// Builds a copy of `tree` in which every mx::array leaf is replaced, in
// traversal order, by consecutive entries of `values` starting at `index`.
// Leaves that are not arrays are carried over unchanged.
//
// Throws std::invalid_argument if a replacement would have to be read from
// outside `values` (negative `index`, or more arrays in the tree than values
// available from `index` onwards).
nb::object tree_unflatten(
    nb::object tree,
    const std::vector<mx::array>& values,
    int index /* = 0 */) {
  return tree_map(tree, [&](nb::handle obj) {
    if (nb::isinstance<mx::array>(obj)) {
      // Guard the lookup: operator[] performs no bounds checking.
      if (index < 0 || static_cast<size_t>(index) >= values.size()) {
        throw std::invalid_argument(
            "[tree_unflatten] Not enough values to fill the arrays of the tree.");
      }
      return nb::cast(values[index++]);
    } else {
      return nb::cast<nb::object>(obj);
    }
  });
}

// Returns the marker object used as a placeholder in tree structures.
//
// The marker is created on the first call as a capsule wrapping the address
// of the function-local static that stores it; later calls return that same
// object.
nb::object structure_sentinel() {
  static nb::object sentinel;

  // Already created by an earlier call.
  if (sentinel.ptr() != nullptr) {
    return sentinel;
  }

  sentinel = nb::capsule(&sentinel);
  // Probably not needed, but the extra reference should make certain that the
  // sentinel is never deleted.
  sentinel.inc_ref();

  return sentinel;
}

// Flattens `tree` and records its structure.
//
// Returns a pair made of:
//   - the mx::array leaves of `tree` in traversal order, and
//   - a copy of the tree in which every array leaf is replaced by the object
//     returned by structure_sentinel().
//
// With `strict` set, a leaf that is not an mx::array raises
// std::invalid_argument; otherwise such leaves are kept as they are in the
// returned structure.
std::pair<std::vector<mx::array>, nb::object> tree_flatten_with_structure(
    nb::object tree,
    bool strict /* = true */) {
  auto sentinel = structure_sentinel();
  std::vector<mx::array> flat_tree;
  auto structure = tree_map(
      tree,
      [&flat_tree, sentinel = std::move(sentinel), strict](nb::handle obj) {
        if (nb::isinstance<mx::array>(obj)) {
          flat_tree.push_back(nb::cast<mx::array>(obj));
          return sentinel;
        } else if (!strict) {
          return nb::cast<nb::object>(obj);
        } else {
          throw std::invalid_argument(
              "[tree_flatten] The argument should contain only arrays");
        }
      });

  // Both locals are done with: move them into the result instead of copying
  // the whole vector of arrays.
  return {std::move(flat_tree), std::move(structure)};
}

// Rebuilds a tree from `structure`: every leaf that is the object returned by
// structure_sentinel() is replaced, in traversal order, by consecutive entries
// of `values` starting at `index`. All other leaves are carried over
// unchanged.
//
// Throws std::invalid_argument if a replacement would have to be read from
// outside `values` (negative `index`, or more placeholders in the structure
// than values available from `index` onwards).
nb::object tree_unflatten_from_structure(
    nb::object structure,
    const std::vector<mx::array>& values,
    int index /* = 0 */) {
  auto sentinel = structure_sentinel();
  return tree_map(structure, [&](nb::handle obj) {
    if (obj.is(sentinel)) {
      // Guard the lookup: operator[] performs no bounds checking.
      if (index < 0 || static_cast<size_t>(index) >= values.size()) {
        throw std::invalid_argument(
            "[tree_unflatten_from_structure] Not enough values to fill the structure.");
      }
      return nb::cast(values[index++]);
    } else {
      return nb::cast<nb::object>(obj);
    }
  });
}