#ifndef RLX_MLX_SHIM_H
#define RLX_MLX_SHIM_H
#include <stddef.h>
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
/* Opaque array handle. Only the struct tag is declared in this header, so
 * client code can pass rlx_mlx_array_t pointers around but cannot look
 * inside the struct or allocate one directly. */
typedef struct rlx_mlx_array_s rlx_mlx_array_t;
/**
 * Element-type tags used wherever this API takes an rlx_mlx_dtype_t.
 *
 * Every enumerator carries an explicit value (0..10). The suffixes read as
 * float (F32/F16/BF16/F64), signed int (I8..I64), unsigned int (U8/U32) and
 * BOOL. NOTE(review): the storage layout behind each tag is not defined in
 * this header — confirm against the implementation before relying on it.
 */
typedef enum {
    RLX_MLX_DTYPE_F32  = 0,
    RLX_MLX_DTYPE_F16  = 1,
    RLX_MLX_DTYPE_BF16 = 2,
    RLX_MLX_DTYPE_I32  = 3,
    RLX_MLX_DTYPE_F64  = 4,
    RLX_MLX_DTYPE_I8   = 5,
    RLX_MLX_DTYPE_I16  = 6,
    RLX_MLX_DTYPE_I64  = 7,
    RLX_MLX_DTYPE_U8   = 8,
    RLX_MLX_DTYPE_U32  = 9,
    RLX_MLX_DTYPE_BOOL = 10,
} rlx_mlx_dtype_t;
/* Integer status codes. NOTE(review): presumably these are the values
 * returned by the int-returning rlx_mlx_* functions — confirm against the
 * implementation. Kept as macros so existing #if/#ifdef checks keep working. */
#define RLX_MLX_OK            0
#define RLX_MLX_ERR_GENERIC   1
#define RLX_MLX_ERR_BAD_DTYPE 2
#define RLX_MLX_ERR_BAD_SHAPE 3
/* Accessor for the last-error message; returns a read-only C string.
 * NOTE(review): the lifetime of the returned pointer and whether the error
 * state is global or per-thread are not visible in this header — confirm
 * before caching the pointer or calling from multiple threads. */
const char* rlx_mlx_last_error(void);
/* Setter counterpart: takes a read-only message string.
 * NOTE(review): whether msg is copied or merely retained, and whether NULL
 * is accepted, cannot be determined from the declaration — confirm. */
void rlx_mlx_set_last_error(const char* msg);
/**
 * Array constructor that takes a float buffer (per its name and signature).
 *
 * @param shape   Read-only dimension list; assumed to hold ndim entries —
 *                TODO confirm.
 * @param ndim    Number of dimensions.
 * @param data    Read-only float source elements.
 * @param nelems  Number of floats in data.
 * @param dtype   Element-type tag for the result. NOTE(review): presumably
 *                the floats are converted to this type — confirm.
 * @param out     Location that receives the new handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_array_from_data(
    const int* shape,
    size_t ndim,
    const float* data,
    size_t nelems,
    rlx_mlx_dtype_t dtype,
    rlx_mlx_array_t** out);
/**
 * Array constructor that takes an untyped byte buffer (per its name and
 * signature).
 *
 * @param shape   Read-only dimension list; assumed to hold ndim entries —
 *                TODO confirm.
 * @param ndim    Number of dimensions.
 * @param data    Read-only source buffer of unspecified element type.
 * @param nbytes  Size of data in bytes (per the parameter name).
 * @param dtype   Element-type tag. NOTE(review): presumably describes how
 *                the bytes in data are interpreted — confirm.
 * @param out     Location that receives the new handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_array_from_bytes(
    const int* shape,
    size_t ndim,
    const void* data,
    size_t nbytes,
    rlx_mlx_dtype_t dtype,
    rlx_mlx_array_t** out);
/**
 * Exports an array's contents into a caller-provided buffer (as the name
 * suggests; the copy itself is not visible in this header).
 *
 * @param h           Source array handle.
 * @param dst         Writable destination buffer.
 * @param dst_cap     Capacity of dst — presumably in bytes; TODO confirm.
 * @param out_nbytes  Receives a byte count. NOTE(review): the header does
 *                    not say whether this is bytes written or bytes
 *                    required when dst_cap is too small — confirm.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_array_to_bytes(
    rlx_mlx_array_t* h,
    void* dst,
    size_t dst_cap,
    size_t* out_nbytes);
/* Maps a dtype tag to a size_t; by its name, the per-element size
 * (presumably in bytes — TODO confirm). The result for values outside
 * rlx_mlx_dtype_t is not visible in this header. */
size_t rlx_mlx_dtype_size(rlx_mlx_dtype_t dtype);
/**
 * Releases an array handle (per its name). Returns nothing.
 * NOTE(review): whether a NULL handle is tolerated is not visible in this
 * header — confirm before relying on it.
 */
void rlx_mlx_array_free(rlx_mlx_array_t* h);

/**
 * Produces a second handle from an existing one (per its name).
 *
 * @param h    Source handle.
 * @param out  Location that receives the new handle.
 * NOTE(review): whether the clone shares or copies the underlying data is
 * not stated here — confirm.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_array_clone(
    rlx_mlx_array_t* h,
    rlx_mlx_array_t** out);
/**
 * Queries the shape of an array handle (per its name).
 *
 * @param h          Array handle to inspect.
 * @param out_shape  Writable int buffer for the dimension extents.
 * @param cap        Capacity of out_shape — presumably in entries, not
 *                   bytes; TODO confirm.
 * @param out_ndim   Receives a dimension count. NOTE(review): behavior when
 *                   the rank exceeds cap is not visible here — confirm.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_array_shape(
    rlx_mlx_array_t* h,
    int* out_shape,
    size_t cap,
    size_t* out_ndim);
/**
 * Exports an array into a caller-provided float buffer (per its name).
 *
 * @param h       Source array handle.
 * @param dst     Writable float destination.
 * @param nelems  Element count associated with dst. NOTE(review): presumably
 *                must match the array's element count — confirm, including
 *                what happens for non-F32 arrays.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_array_to_f32(
    rlx_mlx_array_t* h,
    float* dst,
    size_t nelems);
/**
 * Evaluation entry points. Both eval variants take a read-only array of
 * handle pointers plus its length n; synchronize takes no arguments.
 *
 * NOTE(review): the names mirror MLX's eval / async_eval / synchronize, so
 * presumably eval blocks until the arrays are materialized, async_eval
 * schedules the work without waiting, and synchronize waits for outstanding
 * work — none of this is visible in the header; confirm.
 *
 * Each returns an int status; presumably RLX_MLX_OK on success, else an
 * RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_eval(
    rlx_mlx_array_t* const* handles,
    size_t n);

int rlx_mlx_async_eval(
    rlx_mlx_array_t* const* handles,
    size_t n);

int rlx_mlx_synchronize(void);
/**
 * Two-operand ops sharing one signature: operands a and b, and out, which
 * receives the result handle.
 *
 * NOTE(review): by name these are matrix multiply and add / multiply /
 * subtract / divide; whether the last four are element-wise and how they
 * broadcast is not visible in this header — confirm.
 *
 * Each returns an int status; presumably RLX_MLX_OK on success, else an
 * RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_matmul(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t* b,
    rlx_mlx_array_t** out);

int rlx_mlx_op_add(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t* b,
    rlx_mlx_array_t** out);

int rlx_mlx_op_mul(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t* b,
    rlx_mlx_array_t** out);

int rlx_mlx_op_sub(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t* b,
    rlx_mlx_array_t** out);

int rlx_mlx_op_div(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t* b,
    rlx_mlx_array_t** out);
/**
 * Single-input ops named after softmax, GELU and SiLU. Each takes input a
 * and out, which receives the result handle; softmax additionally takes
 * the axis to operate along.
 *
 * NOTE(review): whether negative axis values are accepted, and which GELU
 * variant (exact vs. approximation) is used, cannot be read from the
 * declarations — confirm.
 *
 * Each returns an int status; presumably RLX_MLX_OK on success, else an
 * RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_softmax(
    rlx_mlx_array_t* a,
    int axis,
    rlx_mlx_array_t** out);

int rlx_mlx_op_gelu(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t** out);

int rlx_mlx_op_silu(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t** out);
/**
 * Type-conversion op (per its name).
 *
 * @param a      Input array handle.
 * @param dtype  Target element-type tag.
 * @param out    Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code (RLX_MLX_ERR_BAD_DTYPE looks relevant) —
 *         TODO confirm.
 */
int rlx_mlx_op_cast(
    rlx_mlx_array_t* a,
    rlx_mlx_dtype_t dtype,
    rlx_mlx_array_t** out);
/**
 * Layer-normalization op (per its name).
 *
 * @param x             Input array handle.
 * @param gamma         Scale parameter handle (name follows the usual
 *                      layer-norm notation).
 * @param beta_or_null  Shift parameter handle; the name indicates NULL is an
 *                      accepted value meaning "no shift" — TODO confirm.
 * @param eps           Epsilon term; presumably added to the variance for
 *                      numerical stability — confirm.
 * @param out           Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_layernorm(
    rlx_mlx_array_t* x,
    rlx_mlx_array_t* gamma,
    rlx_mlx_array_t* beta_or_null,
    float eps,
    rlx_mlx_array_t** out);
/**
 * Two-operand ops named max, min and pow. Each takes operands a and b and
 * out, which receives the result handle.
 *
 * NOTE(review): since both operands are arrays these are presumably the
 * element-wise maximum / minimum / power (not reductions) — confirm.
 *
 * Each returns an int status; presumably RLX_MLX_OK on success, else an
 * RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_max(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t* b,
    rlx_mlx_array_t** out);

int rlx_mlx_op_min(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t* b,
    rlx_mlx_array_t** out);

int rlx_mlx_op_pow(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t* b,
    rlx_mlx_array_t** out);
/**
 * Solver op taking two array operands.
 *
 * @param a    First operand. NOTE(review): presumably the coefficient
 *             matrix of a linear system a * x = b — confirm.
 * @param b    Second operand; presumably the right-hand side — confirm.
 * @param out  Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_solve(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t* b,
    rlx_mlx_array_t** out);
/**
 * Dispatches a caller-supplied kernel described by strings (per its name,
 * a Metal kernel).
 *
 * @param name          Kernel name string.
 * @param source        Kernel source string. NOTE(review): presumably the
 *                      kernel body only, as in MLX's custom-kernel API —
 *                      confirm.
 * @param header        Header string; presumably prepended to the source —
 *                      confirm, including whether NULL is accepted.
 * @param input_names   Read-only array of input name strings; assumed to
 *                      hold n_inputs entries — TODO confirm.
 * @param n_inputs      Number of inputs.
 * @param output_name   Name string for the single output.
 * @param inputs        Read-only array of input handles; assumed to hold
 *                      n_inputs entries — TODO confirm.
 * @param output_shape  Read-only dimension list for the output; assumed to
 *                      hold output_rank entries — TODO confirm.
 * @param output_rank   Number of output dimensions.
 * @param output_dtype  Element-type tag for the output.
 * @param grid_x, grid_y, grid_z  Three grid extents. NOTE(review): whether
 *                      these count threads or threadgroups is not visible
 *                      here — confirm.
 * @param tg_x, tg_y, tg_z  Three extents, presumably the threadgroup size —
 *                      confirm.
 * @param out           Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_metal_kernel_dispatch(
    const char* name,
    const char* source,
    const char* header,
    const char* const* input_names,
    size_t n_inputs,
    const char* output_name,
    rlx_mlx_array_t* const* inputs,
    const int* output_shape,
    size_t output_rank,
    rlx_mlx_dtype_t output_dtype,
    int grid_x,
    int grid_y,
    int grid_z,
    int tg_x,
    int tg_y,
    int tg_z,
    rlx_mlx_array_t** out);
/**
 * Comparison ops sharing one signature: operands a and b, and out, which
 * receives the result handle. By name: eq (==), ne (!=), lt (<), le (<=),
 * gt (>), ge (>=).
 *
 * NOTE(review): presumably element-wise with a boolean-typed result —
 * confirm the result dtype and broadcasting rules.
 *
 * Each returns an int status; presumably RLX_MLX_OK on success, else an
 * RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_eq(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t* b,
    rlx_mlx_array_t** out);

int rlx_mlx_op_ne(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t* b,
    rlx_mlx_array_t** out);

int rlx_mlx_op_lt(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t* b,
    rlx_mlx_array_t** out);

int rlx_mlx_op_le(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t* b,
    rlx_mlx_array_t** out);

int rlx_mlx_op_gt(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t* b,
    rlx_mlx_array_t** out);

int rlx_mlx_op_ge(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t* b,
    rlx_mlx_array_t** out);
/**
 * Conditional-select op (per its name).
 *
 * @param cond  Condition array handle.
 * @param x     First value array. NOTE(review): presumably chosen where
 *              cond is true — confirm.
 * @param y     Second value array; presumably chosen where cond is false —
 *              confirm.
 * @param out   Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_where(
    rlx_mlx_array_t* cond,
    rlx_mlx_array_t* x,
    rlx_mlx_array_t* y,
    rlx_mlx_array_t** out);
/**
 * Selector for rlx_mlx_op_unary. Every enumerator has an explicit value
 * (0..14); the suffix names the function that is presumably applied
 * (RELU, SIGMOID, TANH, ...). NOTE(review): the exact numerical definition
 * of each (e.g. rounding mode for ROUND) is not visible in this header.
 */
typedef enum {
    RLX_MLX_UN_RELU    = 0,
    RLX_MLX_UN_SIGMOID = 1,
    RLX_MLX_UN_TANH    = 2,
    RLX_MLX_UN_EXP     = 3,
    RLX_MLX_UN_LOG     = 4,
    RLX_MLX_UN_SQRT    = 5,
    RLX_MLX_UN_RSQRT   = 6,
    RLX_MLX_UN_NEG     = 7,
    RLX_MLX_UN_ABS     = 8,
    RLX_MLX_UN_ERF     = 9,
    RLX_MLX_UN_ROUND   = 10,
    RLX_MLX_UN_SIN     = 11,
    RLX_MLX_UN_COS     = 12,
    RLX_MLX_UN_TAN     = 13,
    RLX_MLX_UN_ATAN    = 14,
} rlx_mlx_unary_t;
/**
 * Single-input op selected at run time by an rlx_mlx_unary_t value.
 *
 * @param a     Input array handle.
 * @param kind  Which unary function to apply. NOTE(review): behavior for
 *              values outside the enum is not visible here — confirm.
 * @param out   Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_unary(
    rlx_mlx_array_t* a,
    rlx_mlx_unary_t kind,
    rlx_mlx_array_t** out);
/**
 * Reshape op (per its name).
 *
 * @param a          Input array handle.
 * @param new_shape  Read-only target dimension list; assumed to hold ndim
 *                   entries — TODO confirm. NOTE(review): whether a -1
 *                   "infer this dimension" entry is supported is not
 *                   visible here.
 * @param ndim       Number of target dimensions.
 * @param out        Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_reshape(
    rlx_mlx_array_t* a,
    const int* new_shape,
    size_t ndim,
    rlx_mlx_array_t** out);
/**
 * Axis-permutation op (per its name).
 *
 * @param a     Input array handle.
 * @param perm  Read-only axis permutation; assumed to hold ndim entries —
 *              TODO confirm.
 * @param ndim  Length of perm. NOTE(review): presumably must equal the rank
 *              of a — confirm.
 * @param out   Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_transpose(
    rlx_mlx_array_t* a,
    const int* perm,
    size_t ndim,
    rlx_mlx_array_t** out);
/**
 * Slicing op with per-dimension start/stop bounds (per its name and
 * parameters).
 *
 * @param a      Input array handle.
 * @param start  Read-only per-dimension start indices; assumed ndim
 *               entries — TODO confirm.
 * @param stop   Read-only per-dimension stop indices; assumed ndim entries.
 *               NOTE(review): presumably exclusive — confirm.
 * @param ndim   Number of entries in start and stop.
 * @param out    Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_slice(
    rlx_mlx_array_t* a,
    const int* start,
    const int* stop,
    size_t ndim,
    rlx_mlx_array_t** out);
/**
 * Concatenation op (per its name).
 *
 * @param arrays  Read-only array of input handles; assumed to hold n
 *                entries — TODO confirm.
 * @param n       Number of inputs. NOTE(review): handling of n == 0 is not
 *                visible here — confirm.
 * @param axis    Axis along which the inputs are joined.
 * @param out     Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_concat(
    rlx_mlx_array_t* const* arrays,
    size_t n,
    int axis,
    rlx_mlx_array_t** out);
/**
 * Broadcast-to-shape op (per its name).
 *
 * @param a      Input array handle.
 * @param shape  Read-only target dimension list; assumed to hold ndim
 *               entries — TODO confirm.
 * @param ndim   Number of target dimensions.
 * @param out    Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_broadcast_to(
    rlx_mlx_array_t* a,
    const int* shape,
    size_t ndim,
    rlx_mlx_array_t** out);
/**
 * Index-based gather op (per its name).
 *
 * @param a        Source array handle.
 * @param indices  Index array handle. NOTE(review): required index dtype
 *                 and out-of-range behavior are not visible here — confirm.
 * @param axis     Axis of a that indices select along.
 * @param out      Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_take(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t* indices,
    int axis,
    rlx_mlx_array_t** out);
/**
 * Selector for rlx_mlx_op_reduce. Explicit values 0..4; the suffixes name
 * the reduction presumably performed (sum, mean, max, min, product).
 */
typedef enum {
    RLX_MLX_RED_SUM  = 0,
    RLX_MLX_RED_MEAN = 1,
    RLX_MLX_RED_MAX  = 2,
    RLX_MLX_RED_MIN  = 3,
    RLX_MLX_RED_PROD = 4,
} rlx_mlx_reduce_t;
/**
 * Reduction op selected at run time by an rlx_mlx_reduce_t value.
 *
 * @param a         Input array handle.
 * @param kind      Which reduction to apply.
 * @param axes      Read-only list of axes; assumed to hold n_axes entries —
 *                  TODO confirm.
 * @param n_axes    Number of axes. NOTE(review): meaning of n_axes == 0
 *                  (no-op vs. reduce-all) is not visible here — confirm.
 * @param keep_dim  Int flag; presumably non-zero keeps reduced axes as
 *                  size-1 dimensions — confirm.
 * @param out       Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_reduce(
    rlx_mlx_array_t* a,
    rlx_mlx_reduce_t kind,
    const int* axes,
    size_t n_axes,
    int keep_dim,
    rlx_mlx_array_t** out);
/**
 * Cumulative-sum op (per its name).
 *
 * @param a          Input array handle.
 * @param axis       Axis to accumulate along.
 * @param exclusive  Int flag. NOTE(review): presumably non-zero excludes the
 *                   current element from each partial sum — confirm.
 * @param out        Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_cumsum(
    rlx_mlx_array_t* a,
    int axis,
    int exclusive,
    rlx_mlx_array_t** out);
/**
 * FFT op (per its name).
 *
 * @param a         Input array handle.
 * @param inverse   Int flag; presumably non-zero selects the inverse
 *                  transform — confirm.
 * @param norm_tag  Integer normalization selector. NOTE(review): the valid
 *                  values and their meanings are not defined in this
 *                  header — look them up in the implementation.
 * @param out       Location that receives the result handle.
 * NOTE(review): the transformed axis and the result dtype are not visible
 * here — confirm.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_fft(
    rlx_mlx_array_t* a,
    int inverse,
    int norm_tag,
    rlx_mlx_array_t** out);
/**
 * RMS-normalization op (per its name).
 *
 * @param x      Input array handle.
 * @param gamma  Scale parameter handle. NOTE(review): unlike the layernorm
 *               declaration there is no "_or_null" suffix, so NULL is
 *               presumably not accepted — confirm.
 * @param eps    Epsilon term; presumably for numerical stability — confirm.
 * @param out    Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_rmsnorm(
    rlx_mlx_array_t* x,
    rlx_mlx_array_t* gamma,
    float eps,
    rlx_mlx_array_t** out);
/**
 * Mask selector for rlx_mlx_op_attention. Explicit values 0..3.
 * NOTE(review): by name, NONE = no mask, CAUSAL = causal mask, SLIDING =
 * sliding-window mask, CUSTOM = caller-supplied mask array; the header has
 * no window-size parameter for SLIDING — confirm how it is configured.
 */
typedef enum {
    RLX_MLX_MASK_NONE    = 0,
    RLX_MLX_MASK_CAUSAL  = 1,
    RLX_MLX_MASK_SLIDING = 2,
    RLX_MLX_MASK_CUSTOM  = 3,
} rlx_mlx_mask_t;
/**
 * Attention op over query/key/value arrays (per its name and parameters).
 *
 * @param q, k, v       Query, key and value array handles. NOTE(review):
 *                      the expected layout (e.g. batch/heads/seq/dim order)
 *                      is not visible here — confirm.
 * @param scale         Float scale factor; presumably applied to the q·k
 *                      scores — confirm.
 * @param mask_kind     Which masking mode to use.
 * @param mask_or_null  Optional mask array; the name indicates NULL is
 *                      accepted. Presumably only read when mask_kind is
 *                      RLX_MLX_MASK_CUSTOM — confirm.
 * @param out           Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_attention(
    rlx_mlx_array_t* q,
    rlx_mlx_array_t* k,
    rlx_mlx_array_t* v,
    float scale,
    rlx_mlx_mask_t mask_kind,
    rlx_mlx_array_t* mask_or_null,
    rlx_mlx_array_t** out);
/**
 * 2-D convolution op (per its name).
 *
 * @param input               Input array handle.
 * @param weight              Filter array handle. NOTE(review): the memory
 *                            layout expected for input and weight (channels
 *                            first vs. last) is not visible here — confirm.
 * @param stride_h, stride_w  Stride along height and width.
 * @param pad_h, pad_w        Padding along height and width; presumably
 *                            applied symmetrically — confirm.
 * @param dil_h, dil_w        Dilation along height and width.
 * @param groups              Group count.
 * @param out                 Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_conv2d(
    rlx_mlx_array_t* input,
    rlx_mlx_array_t* weight,
    int stride_h,
    int stride_w,
    int pad_h,
    int pad_w,
    int dil_h,
    int dil_w,
    int groups,
    rlx_mlx_array_t** out);
/**
 * 1-D convolution op (per its name).
 *
 * @param input     Input array handle.
 * @param weight    Filter array handle. NOTE(review): expected layouts are
 *                  not visible here — confirm.
 * @param stride    Stride along the single spatial axis.
 * @param padding   Padding amount; presumably applied to both ends —
 *                  confirm.
 * @param dilation  Dilation factor.
 * @param groups    Group count.
 * @param out       Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_conv1d(
    rlx_mlx_array_t* input,
    rlx_mlx_array_t* weight,
    int stride,
    int padding,
    int dilation,
    int groups,
    rlx_mlx_array_t** out);
/**
 * 3-D convolution op (per its name).
 *
 * @param input   Input array handle.
 * @param weight  Filter array handle. NOTE(review): expected layouts are
 *                not visible here — confirm.
 * @param stride_d, stride_h, stride_w  Stride along depth, height, width.
 * @param pad_d, pad_h, pad_w           Padding along depth, height, width;
 *                                      presumably symmetric — confirm.
 * @param dil_d, dil_h, dil_w           Dilation along depth, height, width.
 * @param groups  Group count.
 * @param out     Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_conv3d(
    rlx_mlx_array_t* input,
    rlx_mlx_array_t* weight,
    int stride_d,
    int stride_h,
    int stride_w,
    int pad_d,
    int pad_h,
    int pad_w,
    int dil_d,
    int dil_h,
    int dil_w,
    int groups,
    rlx_mlx_array_t** out);
/**
 * Takes an array handle and yields a result handle through out.
 * NOTE(review): by name this presumably produces a contiguous-in-memory
 * version of a — confirm, including whether an already-contiguous input
 * is copied.
 *
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_contiguous(rlx_mlx_array_t* a, rlx_mlx_array_t** out);
/**
 * Backward pass of 2-D max pooling (per its name, implemented with a Metal
 * kernel). All geometry is passed explicitly rather than read from the
 * array handles.
 *
 * @param x             Array handle; presumably the forward-pass input —
 *                      confirm.
 * @param dy            Array handle; presumably the gradient with respect
 *                      to the pooled output — confirm.
 * @param n, c, h, w    Four extents; presumably batch, channels, height and
 *                      width of x — confirm, including memory layout.
 * @param h_out, w_out  Presumably the pooled output height and width.
 * @param kh, kw        Presumably the pooling window height and width.
 * @param sh, sw        Presumably the stride along height and width.
 * @param ph, pw        Presumably the padding along height and width.
 * @param out           Location that receives the result handle;
 *                      presumably the gradient with respect to x.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_maxpool2d_backward_metal(
    rlx_mlx_array_t* x,
    rlx_mlx_array_t* dy,
    int n,
    int c,
    int h,
    int w,
    int h_out,
    int w_out,
    int kh,
    int kw,
    int sh,
    int sw,
    int ph,
    int pw,
    rlx_mlx_array_t** out);
/**
 * Gather-along-axis op (per its name).
 *
 * @param a        Source array handle.
 * @param indices  Index array handle. NOTE(review): presumably must match
 *                 a's rank, as in the NumPy function of the same name —
 *                 confirm.
 * @param axis     Axis to gather along.
 * @param out      Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_take_along_axis(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t* indices,
    int axis,
    rlx_mlx_array_t** out);
/**
 * Scatter-add op along one axis (per its name).
 *
 * @param a        Base array handle.
 * @param indices  Index array handle.
 * @param updates  Array of values; presumably added into a at the indexed
 *                 positions — confirm.
 * @param axis     Axis the indices refer to.
 * @param out      Location that receives the result handle.
 * NOTE(review): this header also declares rlx_mlx_op_scatter_add with an
 * identical parameter list; the difference between the two is not visible
 * from the declarations — check the implementation before choosing one.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_scatter_add_axis(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t* indices,
    rlx_mlx_array_t* updates,
    int axis,
    rlx_mlx_array_t** out);
/**
 * General N-D convolution op (per its name). Every per-dimension setting is
 * passed as a read-only int array together with its own length.
 *
 * @param input            Input array handle.
 * @param weight           Filter array handle.
 * @param stride           Per-dimension strides (stride_n entries).
 * @param padding_lo       Per-dimension leading padding (padding_lo_n
 *                         entries).
 * @param padding_hi       Per-dimension trailing padding (padding_hi_n
 *                         entries).
 * @param kernel_dilation  Per-dimension kernel dilation
 *                         (kernel_dilation_n entries).
 * @param input_dilation   Per-dimension input dilation
 *                         (input_dilation_n entries).
 * @param groups           Group count.
 * @param flip             Int flag. NOTE(review): presumably non-zero flips
 *                         the kernel (true convolution vs.
 *                         cross-correlation) — confirm.
 * @param out              Location that receives the result handle.
 * NOTE(review): whether each length must equal the number of spatial
 * dimensions, or may be 0/1 to mean "default/broadcast", is not visible in
 * this header — confirm.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_conv_general(
    rlx_mlx_array_t* input,
    rlx_mlx_array_t* weight,
    const int* stride,
    size_t stride_n,
    const int* padding_lo,
    size_t padding_lo_n,
    const int* padding_hi,
    size_t padding_hi_n,
    const int* kernel_dilation,
    size_t kernel_dilation_n,
    const int* input_dilation,
    size_t input_dilation_n,
    int groups,
    int flip,
    rlx_mlx_array_t** out);
/**
 * Arg-partition op (per its name).
 *
 * @param a     Input array handle.
 * @param kth   Partition position. NOTE(review): presumably the result
 *              holds indices such that the kth element along axis is in
 *              its sorted position — confirm.
 * @param axis  Axis to partition along.
 * @param out   Location that receives the result handle (presumably an
 *              index array — confirm its dtype).
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_argpartition(
    rlx_mlx_array_t* a,
    int kth,
    int axis,
    rlx_mlx_array_t** out);
/**
 * Scatter-add op (per its name).
 *
 * @param a        Base array handle.
 * @param indices  Index array handle.
 * @param updates  Array of values; presumably added into a at the indexed
 *                 positions — confirm.
 * @param axis     Axis the indices refer to.
 * @param out      Location that receives the result handle.
 * NOTE(review): this header also declares rlx_mlx_op_scatter_add_axis with
 * an identical parameter list; the difference between the two is not
 * visible from the declarations — check the implementation before choosing
 * one.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_scatter_add(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t* indices,
    rlx_mlx_array_t* updates,
    int axis,
    rlx_mlx_array_t** out);
/**
 * Index-driven matrix multiply (per its name).
 *
 * @param a    First operand handle.
 * @param b    Second operand handle.
 * @param idx  Index array handle. NOTE(review): which operand idx selects
 *             from (a's batches, b's batches, or both) is not visible in
 *             this header — confirm.
 * @param out  Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_gather_mm(
    rlx_mlx_array_t* a,
    rlx_mlx_array_t* b,
    rlx_mlx_array_t* idx,
    rlx_mlx_array_t** out);
/**
 * Matrix multiply against a quantized weight (per its name and parameters).
 *
 * @param x               Activation array handle.
 * @param w               Weight array handle; presumably packed quantized
 *                        values — confirm.
 * @param scales          Per-group scale array handle.
 * @param biases_or_null  Per-group bias array handle; the name indicates
 *                        NULL is accepted — TODO confirm what that means
 *                        for dequantization.
 * @param transpose       Int flag; presumably non-zero multiplies by the
 *                        transpose of w — confirm.
 * @param group_size      Quantization group size.
 * @param bits            Bits per quantized weight. NOTE(review): the
 *                        supported group_size/bits combinations are not
 *                        listed here — confirm.
 * @param out             Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_quantized_matmul(
    rlx_mlx_array_t* x,
    rlx_mlx_array_t* w,
    rlx_mlx_array_t* scales,
    rlx_mlx_array_t* biases_or_null,
    int transpose,
    int group_size,
    int bits,
    rlx_mlx_array_t** out);
/**
 * Categorical sampling op (per its name).
 *
 * @param logits  Array handle holding (per the name) unnormalized
 *                log-probabilities.
 * @param axis    Axis that enumerates the categories.
 * @param seed    64-bit seed. NOTE(review): presumably makes the draw
 *                deterministic for a given seed — confirm.
 * @param out     Location that receives the result handle (presumably
 *                sampled indices — confirm dtype).
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_categorical(
    rlx_mlx_array_t* logits,
    int axis,
    uint64_t seed,
    rlx_mlx_array_t** out);
/**
 * Arg-max op (per its name).
 *
 * @param a         Input array handle.
 * @param axis      Axis to search along.
 * @param keep_dim  Int flag; presumably non-zero keeps the reduced axis as
 *                  a size-1 dimension — confirm.
 * @param out       Location that receives the result handle (presumably an
 *                  index array — confirm its dtype).
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_argmax(
    rlx_mlx_array_t* a,
    int axis,
    int keep_dim,
    rlx_mlx_array_t** out);
/**
 * Slicing op with per-dimension start, stop and stride (per its name and
 * parameters).
 *
 * @param a        Input array handle.
 * @param start    Read-only per-dimension start indices; assumed ndim
 *                 entries — TODO confirm.
 * @param stop     Read-only per-dimension stop indices; assumed ndim
 *                 entries. Presumably exclusive — confirm.
 * @param strides  Read-only per-dimension step sizes; assumed ndim entries.
 *                 NOTE(review): support for negative steps is not visible
 *                 here — confirm.
 * @param ndim     Number of entries in each of the three arrays.
 * @param out      Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_slice_strided(
    rlx_mlx_array_t* a,
    const int* start,
    const int* stop,
    const int* strides,
    size_t ndim,
    rlx_mlx_array_t** out);
/**
 * Padding op (per its name).
 *
 * @param a          Input array handle.
 * @param low        Read-only per-dimension leading pad widths; assumed
 *                   ndim entries — TODO confirm.
 * @param high       Read-only per-dimension trailing pad widths; assumed
 *                   ndim entries — TODO confirm.
 * @param ndim       Number of entries in low and high.
 * @param pad_value  Fill value, passed as float. NOTE(review): how it is
 *                   converted for non-float arrays is not visible here —
 *                   confirm.
 * @param out        Location that receives the result handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_pad(
    rlx_mlx_array_t* a,
    const int* low,
    const int* high,
    size_t ndim,
    float pad_value,
    rlx_mlx_array_t** out);
/**
 * Top-k op returning values rather than indices (per its name).
 *
 * @param a     Input array handle.
 * @param k     Number of entries to keep along axis.
 * @param axis  Axis to select along.
 * @param out   Location that receives the result handle.
 * NOTE(review): whether the k values come back sorted is not visible in
 * this header — confirm before depending on their order.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_topk_values(
    rlx_mlx_array_t* a,
    int k,
    int axis,
    rlx_mlx_array_t** out);
/**
 * Sort op (per its name).
 *
 * @param a     Input array handle.
 * @param axis  Axis to sort along.
 * @param out   Location that receives the result handle.
 * NOTE(review): there is no direction parameter, so the order is presumably
 * ascending — confirm.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_op_sort(
    rlx_mlx_array_t* a,
    int axis,
    rlx_mlx_array_t** out);
/**
 * Callback type accepted by rlx_mlx_compile.
 *
 * @param ud             Opaque pointer; presumably the ud value that was
 *                       handed to rlx_mlx_compile — confirm.
 * @param inputs         Read-only array of input handles.
 * @param n_inputs       Number of entries in inputs.
 * @param out_outputs    Writable array the callback fills with output
 *                       handles.
 * @param cap            Capacity of out_outputs, in entries (presumably —
 *                       confirm).
 * @param out_n_outputs  Receives the number of outputs produced.
 * @return int; presumably RLX_MLX_OK on success and an RLX_MLX_ERR_* code
 *         to abort — TODO confirm how non-zero values are propagated.
 * NOTE(review): who owns the handles stored in out_outputs after the
 * callback returns is not visible in this header — confirm.
 */
typedef int (*rlx_mlx_lower_fn)(
    void* ud,
    rlx_mlx_array_t* const* inputs,
    size_t n_inputs,
    rlx_mlx_array_t** out_outputs,
    size_t cap,
    size_t* out_n_outputs);
/* Opaque handle type for a compiled function. Only the struct tag is
 * declared in this header, so callers can hold rlx_mlx_compiled_t pointers
 * but cannot inspect the struct. */
typedef struct rlx_mlx_compiled_s rlx_mlx_compiled_t;
/**
 * Creates a compiled-function handle from a lowering callback (per its name
 * and signature).
 *
 * @param fn         Callback of type rlx_mlx_lower_fn.
 * @param ud         Opaque pointer stored alongside fn. NOTE(review):
 *                   presumably passed back as the callback's first argument
 *                   and required to outlive the compiled handle — confirm.
 * @param shapeless  Int flag; presumably mirrors MLX's "shapeless" compile
 *                   option (no retrace on shape change) — confirm.
 * @param out        Location that receives the compiled handle.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_compile(
    rlx_mlx_lower_fn fn,
    void* ud,
    int shapeless,
    rlx_mlx_compiled_t** out);
/**
 * Invokes a compiled-function handle (per its name).
 *
 * @param compiled       Handle to call.
 * @param inputs         Read-only array of input handles.
 * @param n_inputs       Number of entries in inputs.
 * @param out_outputs    Writable array that receives output handles.
 * @param cap            Capacity of out_outputs, in entries (presumably —
 *                       confirm).
 * @param out_n_outputs  Receives the number of outputs. NOTE(review):
 *                       behavior when the function yields more than cap
 *                       outputs is not visible here — confirm.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_compiled_call(
    rlx_mlx_compiled_t* compiled,
    rlx_mlx_array_t* const* inputs,
    size_t n_inputs,
    rlx_mlx_array_t** out_outputs,
    size_t cap,
    size_t* out_n_outputs);
/* Releases a compiled-function handle (per its name). Returns nothing.
 * NOTE(review): whether NULL is tolerated is not visible in this header —
 * confirm. */
void rlx_mlx_compiled_free(rlx_mlx_compiled_t* compiled);
/* Getter / setter / reset for a size_t setting named "compile output cap".
 * NOTE(review): presumably the maximum number of outputs a compiled
 * function may produce; the default value restored by the reset function
 * and whether the setting is process-wide are not visible in this header —
 * confirm. */
size_t rlx_mlx_compile_output_cap(void);
void rlx_mlx_set_compile_output_cap(size_t cap);
void rlx_mlx_reset_compile_output_cap(void);
/* Informational string accessors; both take no arguments and return a
 * read-only C string. NOTE(review): whether the version refers to this shim
 * or to the underlying MLX library, and the lifetime of the returned
 * pointers, are not visible in this header — confirm. */
const char* rlx_mlx_version(void);
const char* rlx_mlx_device_name(void);
/**
 * Distributed-runtime queries and initialization. Results come back
 * through int out-parameters; the return value is a separate int status
 * (presumably RLX_MLX_OK on success, else an RLX_MLX_ERR_* code — TODO
 * confirm).
 *
 * - is_available: out receives an int flag.
 * - init: strict is an int flag and backend a read-only name string;
 *   out_rank / out_size receive this process's rank and the group size.
 *   NOTE(review): presumably strict makes a missing backend an error rather
 *   than falling back to a single-process group, and backend may be NULL
 *   for "default" — neither is visible in this header; confirm.
 * - rank / size: out-parameter receives the rank / group size.
 */
int rlx_mlx_dist_is_available(int* out);

int rlx_mlx_dist_init(
    int strict,
    const char* backend,
    int* out_rank,
    int* out_size);

int rlx_mlx_dist_rank(int* out_rank);

int rlx_mlx_dist_size(int* out_size);
/**
 * Collective operations on plain float buffers (per their names).
 *
 * all_sum_f32: reads nelems floats from in and writes to out.
 * NOTE(review): presumably out holds nelems floats and receives the
 * element-wise sum across all ranks; whether in and out may alias is not
 * visible here — confirm.
 *
 * all_gather_f32: reads nelems floats from in and writes to out, whose
 * capacity is out_cap. NOTE(review): presumably out_cap is counted in
 * floats and must be at least nelems times the group size — confirm.
 *
 * Each returns an int status; presumably RLX_MLX_OK on success, else an
 * RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_dist_all_sum_f32(
    const float* in,
    float* out,
    size_t nelems);

int rlx_mlx_dist_all_gather_f32(
    const float* in,
    size_t nelems,
    float* out,
    size_t out_cap);
/**
 * Point-to-point transfer of float buffers (per their names).
 *
 * send_f32: reads nelems floats from data; dst is an int, presumably the
 * destination rank — confirm.
 * recv_f32: writes nelems floats into out; src is an int, presumably the
 * source rank — confirm.
 * NOTE(review): whether these calls block until the transfer completes is
 * not visible in this header.
 *
 * Each returns an int status; presumably RLX_MLX_OK on success, else an
 * RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_dist_send_f32(
    const float* data,
    size_t nelems,
    int dst);

int rlx_mlx_dist_recv_f32(
    float* out,
    size_t nelems,
    int src);
/* Barrier across the distributed group (per its name); takes no arguments.
 * Returns an int status — presumably RLX_MLX_OK on success, else an
 * RLX_MLX_ERR_* code; TODO confirm. */
int rlx_mlx_dist_barrier(void);
/**
 * All-sum collective operating on array handles instead of raw float
 * buffers (per its name).
 *
 * @param in   Input array handle.
 * @param out  Location that receives the result handle.
 * NOTE(review): whether the result is evaluated eagerly or stays lazy until
 * an eval call is not visible in this header — confirm.
 * @return int status; presumably RLX_MLX_OK on success, else an
 *         RLX_MLX_ERR_* code — TODO confirm.
 */
int rlx_mlx_dist_all_sum_array(
    rlx_mlx_array_t* in,
    rlx_mlx_array_t** out);
#ifdef __cplusplus
}
#endif
#endif