use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use crate::ops::rope_multi::{
build_rope_multi_buffers, dispatch_rope_multi, RopeMultiMode, RopeMultiParams,
};
/// Problem shape and RoPE settings shared by the rope_train forward and backward dispatchers.
///
/// The constraints noted on each field are the ones enforced by `validate_params`.
#[derive(Debug, Clone, Copy)]
pub struct RopeTrainParams {
/// Batch size; must be > 0. It is multiplied with `seq_len` when converting to `RopeMultiParams`.
pub batch: u32,
/// Number of heads; must be > 0.
pub n_heads: u32,
/// Sequence length; must be > 0.
pub seq_len: u32,
/// Per-head dimension; must be > 0 and even.
pub head_dim: u32,
/// Rotated dimension; must be > 0, even, and <= `head_dim`.
pub rope_dim: u32,
/// Frequency base; must be finite and strictly positive. Forwarded as `freq_base`.
pub theta_base: f32,
/// Forwarded unchanged as `RopeMultiParams::sections`.
/// NOTE(review): presumably the per-axis section sizes used by the Imrope mode — confirm in `rope_multi`.
pub sections: [u32; 4],
}
fn to_rope_multi_params(p: &RopeTrainParams) -> RopeMultiParams {
RopeMultiParams {
head_dim: p.head_dim,
rope_dim: p.rope_dim,
n_heads: p.n_heads,
seq_len: p.seq_len * p.batch, freq_base: p.theta_base,
mode: RopeMultiMode::Imrope,
sections: p.sections,
}
}
/// Checks that `p` describes a usable RoPE problem before any buffer is inspected.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` when:
/// - any of `batch`, `n_heads`, `seq_len`, `head_dim`, `rope_dim` is zero;
/// - `head_dim` or `rope_dim` is odd;
/// - `rope_dim` exceeds `head_dim`;
/// - `theta_base` is not finite or not strictly positive;
/// - `batch * seq_len` does not fit in `u32`;
/// - the tensor or position element count does not fit in `usize`.
fn validate_params(p: &RopeTrainParams) -> Result<()> {
    if p.batch == 0 || p.n_heads == 0 || p.seq_len == 0 || p.head_dim == 0 || p.rope_dim == 0 {
        return Err(MlxError::InvalidArgument(
            "rope_train: batch, n_heads, seq_len, head_dim, rope_dim must all be > 0".into(),
        ));
    }
    if p.head_dim % 2 != 0 || p.rope_dim % 2 != 0 {
        return Err(MlxError::InvalidArgument(
            "rope_train: head_dim and rope_dim must be even".into(),
        ));
    }
    if p.rope_dim > p.head_dim {
        return Err(MlxError::InvalidArgument(format!(
            "rope_train: rope_dim ({}) must be <= head_dim ({})",
            p.rope_dim, p.head_dim
        )));
    }
    if !p.theta_base.is_finite() || p.theta_base <= 0.0 {
        return Err(MlxError::InvalidArgument(format!(
            "rope_train: theta_base must be finite and positive, got {}",
            p.theta_base
        )));
    }

    // `to_rope_multi_params` folds the batch into the sequence length with an unchecked `u32`
    // multiplication; reject shapes where that product would panic (debug) or wrap (release).
    if p.batch.checked_mul(p.seq_len).is_none() {
        return Err(MlxError::InvalidArgument(format!(
            "rope_train: batch ({}) * seq_len ({}) overflows u32",
            p.batch, p.seq_len
        )));
    }

    // `tensor_elems` and `pos_elems` multiply in `usize` without checks; make sure both
    // products are representable so the later buffer-size comparisons are meaningful.
    let tensor_fits = [p.batch, p.n_heads, p.seq_len, p.head_dim]
        .iter()
        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim as usize))
        .is_some();
    let pos_fits = 4usize
        .checked_mul(p.seq_len as usize)
        .and_then(|n| n.checked_mul(p.batch as usize))
        .is_some();
    if !tensor_fits || !pos_fits {
        return Err(MlxError::InvalidArgument(format!(
            "rope_train: element count overflows usize for batch={}, n_heads={}, seq_len={}, head_dim={}",
            p.batch, p.n_heads, p.seq_len, p.head_dim
        )));
    }

    Ok(())
}
/// Verifies that `buf` has the element count and dtype a rope_train dispatcher expects.
///
/// `label` names the buffer in the error text. The element count is checked first, so a
/// buffer that is wrong in both respects reports the count mismatch.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` describing the first mismatch found.
fn validate_io(
    label: &str,
    buf: &MlxBuffer,
    expected_elems: usize,
    expected_dtype: DType,
) -> Result<()> {
    let actual_elems = buf.element_count();
    if actual_elems != expected_elems {
        return Err(MlxError::InvalidArgument(format!(
            "rope_train: {label} element count {} != expected {}",
            actual_elems, expected_elems
        )));
    }

    let actual_dtype = buf.dtype();
    if actual_dtype != expected_dtype {
        return Err(MlxError::InvalidArgument(format!(
            "rope_train: {label} dtype {} != expected {}",
            actual_dtype, expected_dtype
        )));
    }

    Ok(())
}
/// Element count of one activation/gradient tensor: `batch * n_heads * seq_len * head_dim`.
///
/// Each dimension is widened to `usize` before multiplying; the product itself is unchecked.
fn tensor_elems(p: &RopeTrainParams) -> usize {
    [p.batch, p.n_heads, p.seq_len, p.head_dim]
        .iter()
        .fold(1usize, |acc, &dim| acc * dim as usize)
}
/// Element count the dispatchers require of `pos_buf`: `4 * seq_len * batch`.
///
/// Both dimensions are widened to `usize` before multiplying; the product itself is unchecked.
fn pos_elems(p: &RopeTrainParams) -> usize {
    let seq_len = p.seq_len as usize;
    let batch = p.batch as usize;
    4 * seq_len * batch
}
/// Encodes the forward RoPE pass for BF16 tensors.
///
/// Validates `params` and the three buffers, converts `params` with `to_rope_multi_params`,
/// builds the kernel parameter buffers on `mlx_device`, and hands everything to
/// `dispatch_rope_multi` with `in_buf`, `out_buf` and `pos_buf` in that order.
///
/// # Arguments
/// * `encoder`, `registry`, `device` - forwarded unchanged to `dispatch_rope_multi`.
/// * `mlx_device` - used by `build_rope_multi_buffers` to create the parameter buffers.
/// * `in_buf`, `out_buf` - BF16 buffers holding `batch * n_heads * seq_len * head_dim` elements.
/// * `pos_buf` - `i32` or `u32` buffer holding `4 * batch * seq_len` elements.
/// * `params` - problem shape and RoPE settings.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` when any validation fails, and propagates errors from
/// `build_rope_multi_buffers` and `dispatch_rope_multi`.
// The argument list mirrors the sibling dispatchers in this file.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_rope_forward_bf16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    mlx_device: &MlxDevice,
    in_buf: &MlxBuffer,
    pos_buf: &MlxBuffer,
    out_buf: &MlxBuffer,
    params: &RopeTrainParams,
) -> Result<()> {
    validate_params(params)?;

    let n_elems = tensor_elems(params);
    validate_io("in_buf", in_buf, n_elems, DType::BF16)?;
    validate_io("out_buf", out_buf, n_elems, DType::BF16)?;

    let expected_pos = pos_elems(params);
    if pos_buf.element_count() != expected_pos {
        return Err(MlxError::InvalidArgument(format!(
            "rope_train forward: pos_buf element count {} != 4 * batch({}) * seq_len({}) = {}",
            pos_buf.element_count(),
            params.batch,
            params.seq_len,
            expected_pos
        )));
    }
    match pos_buf.dtype() {
        DType::I32 | DType::U32 => {}
        other => {
            return Err(MlxError::InvalidArgument(format!(
                "rope_train forward: pos_buf dtype {other} must be i32 or u32"
            )));
        }
    }

    let mp = to_rope_multi_params(params);
    let (params_buf, rope_params_buf, sections_buf) = build_rope_multi_buffers(mlx_device, mp)?;
    dispatch_rope_multi(
        encoder,
        registry,
        device,
        in_buf,
        out_buf,
        pos_buf,
        &params_buf,
        &rope_params_buf,
        &sections_buf,
        mp,
    )
}
/// Encodes the backward RoPE pass for BF16 gradients.
///
/// Runs the same `rope_multi` dispatch as the forward pass, but with `grad_out_buf` in the
/// input slot, `grad_in_buf` in the output slot, and a negated copy of `pos_buf` as the
/// positions.
/// NOTE(review): this presumably relies on the inverse of a rotation being the rotation by the
/// negated angle — confirm against the kernel.
///
/// The negated copy is produced by `negate_pos_buf_i32`, which reads `pos_buf` through a host
/// slice while this function runs, so `pos_buf` needs to hold its final contents at call time.
///
/// # Arguments
/// * `encoder`, `registry`, `device` - forwarded unchanged to `dispatch_rope_multi`.
/// * `mlx_device` - used to allocate the negated position buffer and the parameter buffers.
/// * `grad_out_buf`, `grad_in_buf` - BF16 buffers holding
///   `batch * n_heads * seq_len * head_dim` elements.
/// * `pos_buf` - `i32` or `u32` buffer holding `4 * batch * seq_len` elements.
/// * `params` - problem shape and RoPE settings.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` when any validation fails, and propagates errors from
/// `negate_pos_buf_i32`, `build_rope_multi_buffers` and `dispatch_rope_multi`.
// The argument list mirrors the sibling dispatchers in this file.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_rope_backward_bf16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    mlx_device: &MlxDevice,
    grad_out_buf: &MlxBuffer,
    pos_buf: &MlxBuffer,
    grad_in_buf: &MlxBuffer,
    params: &RopeTrainParams,
) -> Result<()> {
    validate_params(params)?;

    let n_elems = tensor_elems(params);
    validate_io("grad_out_buf", grad_out_buf, n_elems, DType::BF16)?;
    validate_io("grad_in_buf", grad_in_buf, n_elems, DType::BF16)?;

    let expected_pos = pos_elems(params);
    if pos_buf.element_count() != expected_pos {
        return Err(MlxError::InvalidArgument(format!(
            "rope_train backward: pos_buf element count {} != 4 * batch({}) * seq_len({}) = {}",
            pos_buf.element_count(),
            params.batch,
            params.seq_len,
            expected_pos
        )));
    }
    match pos_buf.dtype() {
        DType::I32 | DType::U32 => {}
        other => {
            return Err(MlxError::InvalidArgument(format!(
                "rope_train backward: pos_buf dtype {other} must be i32 or u32"
            )));
        }
    }

    let neg_pos_buf = negate_pos_buf_i32(mlx_device, pos_buf)?;
    let mp = to_rope_multi_params(params);
    let (params_buf, rope_params_buf, sections_buf) = build_rope_multi_buffers(mlx_device, mp)?;
    dispatch_rope_multi(
        encoder,
        registry,
        device,
        grad_out_buf,
        grad_in_buf,
        &neg_pos_buf,
        &params_buf,
        &rope_params_buf,
        &sections_buf,
        mp,
    )
}
/// Builds a new `I32` buffer whose entries are the negation of the entries of `pos_buf`.
///
/// `I32` inputs are negated directly; `U32` inputs are first converted with an `as i32` cast.
/// Negation wraps, so `i32::MIN` maps to itself. The result has the same element count as
/// `pos_buf` and is allocated on `device`.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` for any dtype other than `I32`/`U32`, and propagates
/// errors from slice access and buffer allocation.
fn negate_pos_buf_i32(device: &MlxDevice, pos_buf: &MlxBuffer) -> Result<MlxBuffer> {
    // Read and negate in one pass; the source slice is borrowed only for this expression.
    let negated: Vec<i32> = match pos_buf.dtype() {
        DType::I32 => pos_buf
            .as_slice::<i32>()?
            .iter()
            .map(|&pos| pos.wrapping_neg())
            .collect(),
        DType::U32 => pos_buf
            .as_slice::<u32>()?
            .iter()
            .map(|&pos| (pos as i32).wrapping_neg())
            .collect(),
        other => {
            return Err(MlxError::InvalidArgument(format!(
                "negate_pos_buf: unsupported dtype {other}"
            )));
        }
    };

    let count = pos_buf.element_count();
    let byte_len = count * std::mem::size_of::<i32>();
    let mut out = device.alloc_buffer(byte_len, DType::I32, vec![count])?;
    out.as_mut_slice::<i32>()?.copy_from_slice(&negated);
    Ok(out)
}
/// Encodes the forward RoPE pass for F32 tensors.
///
/// Validates `params` and the three buffers, converts `params` with `to_rope_multi_params`,
/// builds the kernel parameter buffers on `mlx_device`, and hands everything to
/// `dispatch_rope_multi` with `in_buf`, `out_buf` and `pos_buf` in that order.
///
/// # Arguments
/// * `encoder`, `registry`, `device` - forwarded unchanged to `dispatch_rope_multi`.
/// * `mlx_device` - used by `build_rope_multi_buffers` to create the parameter buffers.
/// * `in_buf`, `out_buf` - F32 buffers holding `batch * n_heads * seq_len * head_dim` elements.
/// * `pos_buf` - `i32` or `u32` buffer holding `4 * batch * seq_len` elements.
/// * `params` - problem shape and RoPE settings.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` when any validation fails, and propagates errors from
/// `build_rope_multi_buffers` and `dispatch_rope_multi`.
// The argument list mirrors the sibling dispatchers in this file.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_rope_forward_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    mlx_device: &MlxDevice,
    in_buf: &MlxBuffer,
    pos_buf: &MlxBuffer,
    out_buf: &MlxBuffer,
    params: &RopeTrainParams,
) -> Result<()> {
    validate_params(params)?;

    let n_elems = tensor_elems(params);
    validate_io("in_buf", in_buf, n_elems, DType::F32)?;
    validate_io("out_buf", out_buf, n_elems, DType::F32)?;

    let expected_pos = pos_elems(params);
    if pos_buf.element_count() != expected_pos {
        return Err(MlxError::InvalidArgument(format!(
            "rope_train f32 forward: pos_buf element count {} != {}",
            pos_buf.element_count(),
            expected_pos
        )));
    }
    match pos_buf.dtype() {
        DType::I32 | DType::U32 => {}
        other => {
            return Err(MlxError::InvalidArgument(format!(
                "rope_train f32 forward: pos_buf dtype {other} must be i32 or u32"
            )));
        }
    }

    let mp = to_rope_multi_params(params);
    let (params_buf, rope_params_buf, sections_buf) = build_rope_multi_buffers(mlx_device, mp)?;
    dispatch_rope_multi(
        encoder,
        registry,
        device,
        in_buf,
        out_buf,
        pos_buf,
        &params_buf,
        &rope_params_buf,
        &sections_buf,
        mp,
    )
}
/// Encodes the backward RoPE pass for F32 gradients.
///
/// Runs the same `rope_multi` dispatch as the forward pass, but with `grad_out_buf` in the
/// input slot, `grad_in_buf` in the output slot, and a negated copy of `pos_buf` as the
/// positions.
/// NOTE(review): this presumably relies on the inverse of a rotation being the rotation by the
/// negated angle — confirm against the kernel.
///
/// The negated copy is produced by `negate_pos_buf_i32`, which reads `pos_buf` through a host
/// slice while this function runs, so `pos_buf` needs to hold its final contents at call time.
///
/// # Arguments
/// * `encoder`, `registry`, `device` - forwarded unchanged to `dispatch_rope_multi`.
/// * `mlx_device` - used to allocate the negated position buffer and the parameter buffers.
/// * `grad_out_buf`, `grad_in_buf` - F32 buffers holding
///   `batch * n_heads * seq_len * head_dim` elements.
/// * `pos_buf` - `i32` or `u32` buffer holding `4 * batch * seq_len` elements.
/// * `params` - problem shape and RoPE settings.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` when any validation fails, and propagates errors from
/// `negate_pos_buf_i32`, `build_rope_multi_buffers` and `dispatch_rope_multi`.
// The argument list mirrors the sibling dispatchers in this file.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_rope_backward_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    mlx_device: &MlxDevice,
    grad_out_buf: &MlxBuffer,
    pos_buf: &MlxBuffer,
    grad_in_buf: &MlxBuffer,
    params: &RopeTrainParams,
) -> Result<()> {
    validate_params(params)?;

    let n_elems = tensor_elems(params);
    validate_io("grad_out_buf", grad_out_buf, n_elems, DType::F32)?;
    validate_io("grad_in_buf", grad_in_buf, n_elems, DType::F32)?;

    let expected_pos = pos_elems(params);
    if pos_buf.element_count() != expected_pos {
        return Err(MlxError::InvalidArgument(format!(
            "rope_train f32 backward: pos_buf element count {} != {}",
            pos_buf.element_count(),
            expected_pos
        )));
    }
    match pos_buf.dtype() {
        DType::I32 | DType::U32 => {}
        other => {
            return Err(MlxError::InvalidArgument(format!(
                "rope_train f32 backward: pos_buf dtype {other} must be i32 or u32"
            )));
        }
    }

    let neg_pos_buf = negate_pos_buf_i32(mlx_device, pos_buf)?;
    let mp = to_rope_multi_params(params);
    let (params_buf, rope_params_buf, sections_buf) = build_rope_multi_buffers(mlx_device, mp)?;
    dispatch_rope_multi(
        encoder,
        registry,
        device,
        grad_out_buf,
        grad_in_buf,
        &neg_pos_buf,
        &params_buf,
        &rope_params_buf,
        &sections_buf,
        mp,
    )
}