#include "mlx/c/fast.h"

#include <utility>

#include "mlx/c/error.h"
#include "mlx/c/private/mlx.h"
#include "mlx/fast.h"
// C entry point wrapping mlx::core::fast::affine_dequantize.
//
// On success the array produced by the wrapped call is stored into `*res` and
// 0 is returned. Any std::exception raised along the way is reported through
// mlx_error() and turned into a return value of 1.
extern "C" int mlx_fast_affine_dequantize(
    mlx_array* res,
    const mlx_array w,
    const mlx_array scales,
    const mlx_array biases,
    int group_size,
    int bits,
    const mlx_stream s) {
  try {
    auto dequantized = mlx::core::fast::affine_dequantize(
        mlx_array_get_(w),
        mlx_array_get_(scales),
        mlx_array_get_(biases),
        group_size,
        bits,
        mlx_stream_get_(s));
    mlx_array_set_(*res, dequantized);
    return 0;
  } catch (const std::exception& err) {
    mlx_error(err.what());
  }
  return 1;
}
// C entry point wrapping mlx::core::fast::affine_quantize.
//
// The wrapped call yields three arrays; they are stored, in the order they
// are returned, into `*res_0`, `*res_1` and `*res_2`. Returns 0 on success.
// A std::exception is reported through mlx_error() and results in 1.
extern "C" int mlx_fast_affine_quantize(
    mlx_array* res_0,
    mlx_array* res_1,
    mlx_array* res_2,
    const mlx_array w,
    int group_size,
    int bits,
    const mlx_stream s) {
  try {
    auto [first, second, third] = mlx::core::fast::affine_quantize(
        mlx_array_get_(w), group_size, bits, mlx_stream_get_(s));
    mlx_array_set_(*res_0, first);
    mlx_array_set_(*res_1, second);
    mlx_array_set_(*res_2, third);
    return 0;
  } catch (const std::exception& err) {
    mlx_error(err.what());
  }
  return 1;
}
// C entry point wrapping mlx::core::fast::layer_norm.
//
// `weight` and `bias` are optional: a handle whose `ctx` is null is forwarded
// as std::nullopt, otherwise the wrapped array is passed along.
// Stores the result into `*res` and returns 0; a std::exception is reported
// through mlx_error() and results in 1.
extern "C" int mlx_fast_layer_norm(
    mlx_array* res,
    const mlx_array x,
    const mlx_array weight,
    const mlx_array bias,
    float eps,
    const mlx_stream s) {
  try {
    auto maybe_weight = weight.ctx
        ? std::make_optional(mlx_array_get_(weight))
        : std::nullopt;
    auto maybe_bias =
        bias.ctx ? std::make_optional(mlx_array_get_(bias)) : std::nullopt;
    mlx_array_set_(
        *res,
        mlx::core::fast::layer_norm(
            mlx_array_get_(x),
            maybe_weight,
            maybe_bias,
            eps,
            mlx_stream_get_(s)));
    return 0;
  } catch (const std::exception& err) {
    mlx_error(err.what());
  }
  return 1;
}
// C++ state stored behind the opaque mlx_fast_metal_kernel_config handle.
// The mlx_fast_metal_kernel_config_* setters in this file fill it in, and
// mlx_fast_metal_kernel_apply forwards every field to the kernel call.
struct mlx_fast_metal_kernel_config_cpp_ {
  // One entry per output; both lists are appended to by
  // mlx_fast_metal_kernel_config_add_output_arg.
  std::vector<std::vector<int>> output_shapes;
  std::vector<mlx::core::Dtype> output_dtypes;
  // Written by mlx_fast_metal_kernel_config_set_grid.
  std::tuple<int, int, int> grid;
  // Written by mlx_fast_metal_kernel_config_set_thread_group.
  std::tuple<int, int, int> thread_group;
  // (name, value) pairs appended by the add_template_arg_* functions.
  std::vector<std::pair<std::string, mlx::core::fast::TemplateArg>>
      template_args;
  // Empty until mlx_fast_metal_kernel_config_set_init_value is called.
  std::optional<float> init_value;
  // Explicit default so the flag is well-defined however the struct is
  // created, not only when it is value-initialized with `new T()`.
  bool verbose{false};
};
// Allocates a value-initialized config state on the heap and wraps the
// pointer in a handle. The matching release is
// mlx_fast_metal_kernel_config_free_.
inline mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new_() {
  auto* state = new mlx_fast_metal_kernel_config_cpp_();
  return {state};
}
// Returns the C++ state behind a config handle.
// Throws std::runtime_error when the handle's `ctx` is null.
inline mlx_fast_metal_kernel_config_cpp_& mlx_fast_metal_kernel_config_get_(
    mlx_fast_metal_kernel_config d) {
  auto* state = static_cast<mlx_fast_metal_kernel_config_cpp_*>(d.ctx);
  if (state == nullptr) {
    throw std::runtime_error(
        "expected a non-empty mlx_fast_metal_kernel_config");
  }
  return *state;
}
// Destroys the C++ state behind a config handle. An empty handle is
// accepted: `delete` on a null pointer does nothing.
inline void mlx_fast_metal_kernel_config_free_(mlx_fast_metal_kernel_config d) {
  delete static_cast<mlx_fast_metal_kernel_config_cpp_*>(d.ctx);
}
// Creates a new kernel configuration handle.
//
// If a std::exception is thrown during creation it is reported through
// mlx_error() and a handle with a null `ctx` is returned instead.
extern "C" mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new() {
  mlx_fast_metal_kernel_config handle = {nullptr};
  try {
    handle = mlx_fast_metal_kernel_config_new_();
  } catch (const std::exception& err) {
    mlx_error(err.what());
  }
  return handle;
}
// Releases a kernel configuration handle. A handle with a null `ctx` is
// accepted, since `delete` on a null pointer does nothing.
extern "C" void mlx_fast_metal_kernel_config_free(
    mlx_fast_metal_kernel_config cls) {
  delete static_cast<mlx_fast_metal_kernel_config_cpp_*>(cls.ctx);
}
struct mlx_fast_metal_kernel_cpp_ {
mlx::core::fast::MetalKernelFunction mkf;
mlx_fast_metal_kernel_cpp_(mlx::core::fast::MetalKernelFunction mkf)
: mkf(mkf) {};
};
// Builds a kernel via mlx::core::fast::metal_kernel, forwarding every
// argument unchanged, and wraps the heap-allocated result in a handle.
// The matching release is mlx_fast_metal_kernel_free_.
inline mlx_fast_metal_kernel mlx_fast_metal_kernel_new_(
    const std::string& name,
    const std::vector<std::string>& input_names,
    const std::vector<std::string>& output_names,
    const std::string& source,
    const std::string& header,
    bool ensure_row_contiguous,
    bool atomic_outputs) {
  auto* impl = new mlx_fast_metal_kernel_cpp_(mlx::core::fast::metal_kernel(
      name,
      input_names,
      output_names,
      source,
      header,
      ensure_row_contiguous,
      atomic_outputs));
  return {impl};
}
// Returns the kernel callable stored behind a kernel handle.
// Throws std::runtime_error when the handle's `ctx` is null.
inline mlx::core::fast::MetalKernelFunction& mlx_fast_metal_kernel_get_(
    mlx_fast_metal_kernel d) {
  auto* impl = static_cast<mlx_fast_metal_kernel_cpp_*>(d.ctx);
  if (impl == nullptr) {
    throw std::runtime_error("expected a non-empty mlx_fast_metal_kernel");
  }
  return impl->mkf;
}
// Destroys the C++ state behind a kernel handle. An empty handle is
// accepted: `delete` on a null pointer does nothing.
inline void mlx_fast_metal_kernel_free_(mlx_fast_metal_kernel d) {
  delete static_cast<mlx_fast_metal_kernel_cpp_*>(d.ctx);
}
// Creates a kernel handle from C strings and string-vector handles.
//
// `name`, `source` and `header` are converted to std::string, so they must be
// non-NULL; a NULL pointer is reported through mlx_error() instead of being
// passed to the std::string constructor (which would be undefined behavior).
//
// Returns the new handle, or a handle with a null `ctx` if an argument was
// NULL or a std::exception was thrown (reported through mlx_error()).
extern "C" mlx_fast_metal_kernel mlx_fast_metal_kernel_new(
    const char* name,
    const mlx_vector_string input_names,
    const mlx_vector_string output_names,
    const char* source,
    const char* header,
    bool ensure_row_contiguous,
    bool atomic_outputs) {
  if (name == nullptr || source == nullptr || header == nullptr) {
    mlx_error(
        "mlx_fast_metal_kernel_new: name, source and header must not be NULL");
    return {nullptr};
  }
  try {
    return mlx_fast_metal_kernel_new_(
        name,
        mlx_vector_string_get_(input_names),
        mlx_vector_string_get_(output_names),
        source,
        header,
        ensure_row_contiguous,
        atomic_outputs);
  } catch (const std::exception& e) {
    mlx_error(e.what());
  }
  return {nullptr};
}
// Releases a kernel handle. A handle with a null `ctx` is accepted, since
// `delete` on a null pointer does nothing.
extern "C" void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls) {
  delete static_cast<mlx_fast_metal_kernel_cpp_*>(cls.ctx);
}
// Appends one output description (shape + dtype) to a kernel configuration.
//
// `shape` points to `size` ints that are copied into the configuration. It
// may be NULL only when `size` is 0 (an empty shape).
//
// Returns 0 on success. Returns 1 after reporting through mlx_error() when
// `shape` is NULL with a non-zero `size`, or when a std::exception is thrown
// (for example because `cls` is an empty handle).
extern "C" int mlx_fast_metal_kernel_config_add_output_arg(
    mlx_fast_metal_kernel_config cls,
    const int* shape,
    size_t size,
    mlx_dtype dtype) {
  if (shape == nullptr && size != 0) {
    mlx_error(
        "mlx_fast_metal_kernel_config_add_output_arg: shape must not be NULL "
        "when size is non-zero");
    return 1;
  }
  try {
    // Resolve the handle once instead of once per field.
    auto& config = mlx_fast_metal_kernel_config_get_(cls);
    // Convert the dtype before touching the configuration: if the conversion
    // throws, the shape and dtype lists keep the same length.
    const auto cpp_dtype = mlx_dtype_to_cpp(dtype);
    config.output_shapes.emplace_back(shape, shape + size);
    config.output_dtypes.push_back(cpp_dtype);
  } catch (const std::exception& e) {
    mlx_error(e.what());
    return 1;
  }
  return 0;
}
// Stores (grid1, grid2, grid3) as the `grid` of a kernel configuration.
// Returns 0 on success; a std::exception (e.g. from an empty handle) is
// reported through mlx_error() and results in 1.
extern "C" int mlx_fast_metal_kernel_config_set_grid(
    mlx_fast_metal_kernel_config cls,
    int grid1,
    int grid2,
    int grid3) {
  try {
    auto& config = mlx_fast_metal_kernel_config_get_(cls);
    config.grid = std::tuple<int, int, int>(grid1, grid2, grid3);
    return 0;
  } catch (const std::exception& err) {
    mlx_error(err.what());
  }
  return 1;
}
// Stores (thread1, thread2, thread3) as the `thread_group` of a kernel
// configuration. Returns 0 on success; a std::exception (e.g. from an empty
// handle) is reported through mlx_error() and results in 1.
extern "C" int mlx_fast_metal_kernel_config_set_thread_group(
    mlx_fast_metal_kernel_config cls,
    int thread1,
    int thread2,
    int thread3) {
  try {
    auto& config = mlx_fast_metal_kernel_config_get_(cls);
    config.thread_group = std::tuple<int, int, int>(thread1, thread2, thread3);
    return 0;
  } catch (const std::exception& err) {
    mlx_error(err.what());
  }
  return 1;
}
// Sets the optional `init_value` of a kernel configuration to `value`.
// Returns 0 on success; a std::exception (e.g. from an empty handle) is
// reported through mlx_error() and results in 1.
extern "C" int mlx_fast_metal_kernel_config_set_init_value(
    mlx_fast_metal_kernel_config cls,
    float value) {
  try {
    auto& config = mlx_fast_metal_kernel_config_get_(cls);
    config.init_value = value;
    return 0;
  } catch (const std::exception& err) {
    mlx_error(err.what());
  }
  return 1;
}
// Sets the `verbose` flag of a kernel configuration.
// Returns 0 on success; a std::exception (e.g. from an empty handle) is
// reported through mlx_error() and results in 1.
extern "C" int mlx_fast_metal_kernel_config_set_verbose(
    mlx_fast_metal_kernel_config cls,
    bool verbose) {
  try {
    auto& config = mlx_fast_metal_kernel_config_get_(cls);
    config.verbose = verbose;
    return 0;
  } catch (const std::exception& err) {
    mlx_error(err.what());
  }
  return 1;
}
// Appends a (name, dtype) template argument to a kernel configuration.
//
// `name` is copied into a std::string and therefore must be non-NULL; a NULL
// name is reported through mlx_error() rather than being handed to the
// std::string constructor (undefined behavior).
//
// Returns 0 on success, 1 if `name` is NULL or a std::exception was thrown
// (reported through mlx_error()).
extern "C" int mlx_fast_metal_kernel_config_add_template_arg_dtype(
    mlx_fast_metal_kernel_config cls,
    const char* name,
    mlx_dtype dtype) {
  if (name == nullptr) {
    mlx_error(
        "mlx_fast_metal_kernel_config_add_template_arg_dtype: name must not "
        "be NULL");
    return 1;
  }
  try {
    mlx_fast_metal_kernel_config_get_(cls).template_args.push_back(
        std::make_pair(std::string(name), mlx_dtype_to_cpp(dtype)));
  } catch (const std::exception& e) {
    mlx_error(e.what());
    return 1;
  }
  return 0;
}
// Appends a (name, int) template argument to a kernel configuration.
//
// `name` is copied into a std::string and therefore must be non-NULL; a NULL
// name is reported through mlx_error() rather than being handed to the
// std::string constructor (undefined behavior).
//
// Returns 0 on success, 1 if `name` is NULL or a std::exception was thrown
// (reported through mlx_error()).
extern "C" int mlx_fast_metal_kernel_config_add_template_arg_int(
    mlx_fast_metal_kernel_config cls,
    const char* name,
    int value) {
  if (name == nullptr) {
    mlx_error(
        "mlx_fast_metal_kernel_config_add_template_arg_int: name must not be "
        "NULL");
    return 1;
  }
  try {
    mlx_fast_metal_kernel_config_get_(cls).template_args.push_back(
        std::make_pair(std::string(name), value));
  } catch (const std::exception& e) {
    mlx_error(e.what());
    return 1;
  }
  return 0;
}
// Appends a (name, bool) template argument to a kernel configuration.
//
// `name` is copied into a std::string and therefore must be non-NULL; a NULL
// name is reported through mlx_error() rather than being handed to the
// std::string constructor (undefined behavior).
//
// Returns 0 on success, 1 if `name` is NULL or a std::exception was thrown
// (reported through mlx_error()).
extern "C" int mlx_fast_metal_kernel_config_add_template_arg_bool(
    mlx_fast_metal_kernel_config cls,
    const char* name,
    bool value) {
  if (name == nullptr) {
    mlx_error(
        "mlx_fast_metal_kernel_config_add_template_arg_bool: name must not be "
        "NULL");
    return 1;
  }
  try {
    mlx_fast_metal_kernel_config_get_(cls).template_args.push_back(
        std::make_pair(std::string(name), value));
  } catch (const std::exception& e) {
    mlx_error(e.what());
    return 1;
  }
  return 0;
}
// Invokes the kernel behind `cls` on `inputs`, using the settings stored in
// `config`, and stores the resulting arrays into `*outputs`.
//
// Returns 0 on success. A std::exception (for example from an empty kernel
// or config handle) is reported through mlx_error() and results in 1.
extern "C" int mlx_fast_metal_kernel_apply(
    mlx_vector_array* outputs,
    mlx_fast_metal_kernel cls,
    const mlx_vector_array inputs,
    const mlx_fast_metal_kernel_config config,
    const mlx_stream stream) {
  try {
    // Bind by reference: a plain `auto` would deep-copy the whole config
    // struct (shape, dtype and template-argument vectors) on every call.
    const auto& config_ctx = mlx_fast_metal_kernel_config_get_(config);
    mlx_vector_array_set_(
        *outputs,
        mlx_fast_metal_kernel_get_(cls)(
            mlx_vector_array_get_(inputs),
            config_ctx.output_shapes,
            config_ctx.output_dtypes,
            config_ctx.grid,
            config_ctx.thread_group,
            config_ctx.template_args,
            config_ctx.init_value,
            config_ctx.verbose,
            mlx_stream_get_(stream)));
  } catch (const std::exception& e) {
    mlx_error(e.what());
    return 1;
  }
  return 0;
}
// C entry point wrapping mlx::core::fast::rms_norm.
//
// `weight` is optional: a handle whose `ctx` is null is forwarded as
// std::nullopt, otherwise the wrapped array is passed along.
// Stores the result into `*res` and returns 0; a std::exception is reported
// through mlx_error() and results in 1.
extern "C" int mlx_fast_rms_norm(
    mlx_array* res,
    const mlx_array x,
    const mlx_array weight,
    float eps,
    const mlx_stream s) {
  try {
    auto maybe_weight = weight.ctx
        ? std::make_optional(mlx_array_get_(weight))
        : std::nullopt;
    mlx_array_set_(
        *res,
        mlx::core::fast::rms_norm(
            mlx_array_get_(x), maybe_weight, eps, mlx_stream_get_(s)));
    return 0;
  } catch (const std::exception& err) {
    mlx_error(err.what());
  }
  return 1;
}
// C entry point wrapping mlx::core::fast::rope.
//
// Two arguments are optional and forwarded as std::nullopt when absent:
//   - `base`, when `base.has_value` is false;
//   - `freqs`, when its `ctx` is null.
// Stores the result into `*res` and returns 0; a std::exception is reported
// through mlx_error() and results in 1.
extern "C" int mlx_fast_rope(
    mlx_array* res,
    const mlx_array x,
    int dims,
    bool traditional,
    mlx_optional_float base,
    float scale,
    int offset,
    const mlx_array freqs,
    const mlx_stream s) {
  try {
    auto maybe_base = base.has_value
        ? std::make_optional<float>(base.value)
        : std::nullopt;
    auto maybe_freqs =
        freqs.ctx ? std::make_optional(mlx_array_get_(freqs)) : std::nullopt;
    mlx_array_set_(
        *res,
        mlx::core::fast::rope(
            mlx_array_get_(x),
            dims,
            traditional,
            maybe_base,
            scale,
            offset,
            maybe_freqs,
            mlx_stream_get_(s)));
    return 0;
  } catch (const std::exception& err) {
    mlx_error(err.what());
  }
  return 1;
}
// C entry point wrapping mlx::core::fast::scaled_dot_product_attention.
//
// `mask_mode` is copied into a std::string and therefore must be non-NULL; a
// NULL pointer is reported through mlx_error() rather than being handed to
// the std::string constructor (undefined behavior).
//
// Stores the result into `*res` and returns 0. Returns 1 if `mask_mode` is
// NULL or a std::exception was thrown (reported through mlx_error()).
extern "C" int mlx_fast_scaled_dot_product_attention(
    mlx_array* res,
    const mlx_array queries,
    const mlx_array keys,
    const mlx_array values,
    float scale,
    const char* mask_mode,
    const mlx_vector_array mask_arrs,
    const mlx_stream s) {
  if (mask_mode == nullptr) {
    mlx_error(
        "mlx_fast_scaled_dot_product_attention: mask_mode must not be NULL");
    return 1;
  }
  try {
    mlx_array_set_(
        *res,
        mlx::core::fast::scaled_dot_product_attention(
            mlx_array_get_(queries),
            mlx_array_get_(keys),
            mlx_array_get_(values),
            scale,
            std::string(mask_mode),
            mlx_vector_array_get_(mask_arrs),
            mlx_stream_get_(s)));
  } catch (const std::exception& e) {
    mlx_error(e.what());
    return 1;
  }
  return 0;
}