use crate::{
array::Array,
error::{
ArithmeticOverflowPayload, EmptyInputPayload, Error, InvariantViolationPayload,
LengthMismatchPayload, OutOfRangePayload, Result, try_with_capacity,
},
};
/// Which end of each token row receives padding when rows are brought to a common length.
///
/// Defaults to [`PaddingSide::Left`]. Displays as `"left"` / `"right"`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[derive(derive_more::Display, derive_more::IsVariant)]
#[display("{}", self.as_str())]
pub enum PaddingSide {
    /// Padding positions are placed before the real tokens of a row.
    #[default]
    Left,
    /// Padding positions are placed after the real tokens of a row.
    Right,
}
impl PaddingSide {
pub const fn as_str(&self) -> &'static str {
match self {
Self::Left => "left",
Self::Right => "right",
}
}
}
/// Bundle of arrays ready to feed a model, as produced by [`prepare_inputs`].
///
/// Holds the padded token ids and their attention mask plus the optional modality tensors,
/// which `prepare_inputs` passes through unchanged.
#[derive(Debug)]
pub struct PreparedInputs {
    /// Token ids. When built by `prepare_inputs` this is an `i32` array of shape
    /// `(batch_size, target_t)`.
    input_ids: Array,
    /// Per-position mask. When built by `prepare_inputs` this is a `bool` array of shape
    /// `(batch_size, target_t)` that is `false` at padded positions.
    attention_mask: Array,
    /// Optional image tensor, passed through as given (shape not checked here).
    pixel_values: Option<Array>,
    /// Optional audio-feature tensor, passed through as given (shape not checked here).
    input_features: Option<Array>,
    /// Optional video-frame tensor, passed through as given (shape not checked here).
    pixel_values_videos: Option<Array>,
}
impl PreparedInputs {
    /// Assembles a `PreparedInputs` from already-built arrays.
    ///
    /// No validation is performed; the arrays are stored exactly as given.
    pub fn new(
        input_ids: Array,
        attention_mask: Array,
        pixel_values: Option<Array>,
        input_features: Option<Array>,
        pixel_values_videos: Option<Array>,
    ) -> Self {
        Self {
            pixel_values,
            input_features,
            pixel_values_videos,
            input_ids,
            attention_mask,
        }
    }

    // ---- shared accessors ----

    /// Borrows the token-id array.
    #[inline(always)]
    pub fn input_ids_ref(&self) -> &Array {
        &self.input_ids
    }

    /// Borrows the attention-mask array.
    #[inline(always)]
    pub fn attention_mask_ref(&self) -> &Array {
        &self.attention_mask
    }

    /// Borrows the image tensor, if one was supplied.
    #[inline(always)]
    pub fn pixel_values_ref(&self) -> Option<&Array> {
        self.pixel_values.as_ref()
    }

    /// Borrows the audio-feature tensor, if one was supplied.
    #[inline(always)]
    pub fn input_features_ref(&self) -> Option<&Array> {
        self.input_features.as_ref()
    }

    /// Borrows the video-frame tensor, if one was supplied.
    #[inline(always)]
    pub fn pixel_values_videos_ref(&self) -> Option<&Array> {
        self.pixel_values_videos.as_ref()
    }

    // ---- mutable accessors ----

    /// Mutably borrows the token-id array.
    #[inline(always)]
    pub fn input_ids_mut(&mut self) -> &mut Array {
        &mut self.input_ids
    }

    /// Mutably borrows the attention-mask array.
    #[inline(always)]
    pub fn attention_mask_mut(&mut self) -> &mut Array {
        &mut self.attention_mask
    }

    /// Mutably borrows the image tensor, if one was supplied.
    #[inline(always)]
    pub fn pixel_values_mut(&mut self) -> Option<&mut Array> {
        self.pixel_values.as_mut()
    }

    /// Mutably borrows the audio-feature tensor, if one was supplied.
    #[inline(always)]
    pub fn input_features_mut(&mut self) -> Option<&mut Array> {
        self.input_features.as_mut()
    }

    /// Mutably borrows the video-frame tensor, if one was supplied.
    #[inline(always)]
    pub fn pixel_values_videos_mut(&mut self) -> Option<&mut Array> {
        self.pixel_values_videos.as_mut()
    }
}
/// Options controlling how [`prepare_inputs`] pads and masks token batches.
///
/// Construct with [`PrepareInputsOpts::new`] (or `Default`) and refine with the `with_*`
/// builder methods.
#[derive(Debug, Clone)]
pub struct PrepareInputsOpts {
    /// Id written into padded positions (default `0`).
    pad_token_id: u32,
    /// `true`: rows are padded to the longest row. `false`: every row must already have the
    /// same length (default `true`).
    padding: bool,
    /// Which end of each row receives padding (default [`PaddingSide::Left`]).
    padding_side: PaddingSide,
    /// Optional caller-supplied per-token mask bits for the real tokens. Empty means every real
    /// token is marked `true`. When non-empty it must have one row per batch and each row must
    /// match that batch's token count (default empty).
    attention_mask: Vec<Vec<bool>>,
}
impl Default for PrepareInputsOpts {
fn default() -> Self {
Self::new()
}
}
impl PrepareInputsOpts {
    /// Creates options with the defaults:
    /// - pad id `0`;
    /// - padding enabled;
    /// - left-side padding;
    /// - no caller-supplied attention mask.
    pub fn new() -> Self {
        Self {
            padding_side: PaddingSide::Left,
            padding: true,
            pad_token_id: 0,
            attention_mask: vec![],
        }
    }

    // ---- getters ----

    /// Id written into padded positions.
    #[inline(always)]
    pub fn pad_token_id(&self) -> u32 {
        self.pad_token_id
    }

    /// Whether rows of unequal length are padded (`true`) or rejected (`false`).
    #[inline(always)]
    pub fn padding(&self) -> bool {
        self.padding
    }

    /// Which end of each row receives padding.
    #[inline(always)]
    pub fn padding_side(&self) -> PaddingSide {
        self.padding_side
    }

    /// Caller-supplied per-token mask rows; empty when none was provided.
    #[inline(always)]
    pub fn attention_mask(&self) -> &[Vec<bool>] {
        self.attention_mask.as_slice()
    }

    // ---- builders ----

    /// Returns these options with the pad id replaced by `v`.
    #[must_use]
    pub fn with_pad_token_id(self, v: u32) -> Self {
        Self {
            pad_token_id: v,
            ..self
        }
    }

    /// Returns these options with padding enabled or disabled per `v`.
    #[must_use]
    pub fn with_padding(self, v: bool) -> Self {
        Self { padding: v, ..self }
    }

    /// Returns these options with the padding side replaced by `v`.
    #[must_use]
    pub fn with_padding_side(self, v: PaddingSide) -> Self {
        Self {
            padding_side: v,
            ..self
        }
    }

    /// Returns these options with the caller-supplied attention mask replaced by `v`.
    #[must_use]
    pub fn with_attention_mask(self, v: Vec<Vec<bool>>) -> Self {
        Self {
            attention_mask: v,
            ..self
        }
    }
}
pub fn prepare_inputs(
text_token_batches: &[&[u32]],
pixel_values: Option<Array>,
input_features: Option<Array>,
pixel_values_videos: Option<Array>,
opts: &PrepareInputsOpts,
) -> Result<PreparedInputs> {
if text_token_batches.is_empty() {
return Err(Error::EmptyInput(EmptyInputPayload::new(
"prepare_inputs: text_token_batches",
)));
}
if !opts.attention_mask.is_empty() {
let masks = &opts.attention_mask;
if masks.len() != text_token_batches.len() {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"prepare_inputs: opts.attention_mask outer vs text_token_batches",
text_token_batches.len(),
masks.len(),
)));
}
for (m, b) in masks.iter().zip(text_token_batches.iter()) {
if m.len() != b.len() {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"prepare_inputs: opts.attention_mask[i] vs text_token_batches[i] length",
b.len(),
m.len(),
)));
}
}
}
let batch_size = text_token_batches.len();
let lens: Vec<usize> = text_token_batches.iter().map(|b| b.len()).collect();
let target_t = if opts.padding {
*lens.iter().max().unwrap_or(&0)
} else {
let first = lens[0];
if !lens.iter().all(|&l| l == first) {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"prepare_inputs: text_token_batches",
"all batches must have equal length when padding=false (enable padding or pre-pad upstream)",
)));
}
first
};
if target_t > i32::MAX as usize {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"prepare_inputs: target_t",
"must be <= i32::MAX (mlx dimension limit)",
format!("{target_t}"),
)));
}
if batch_size > i32::MAX as usize {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"prepare_inputs: batch_size",
"must be <= i32::MAX (mlx dimension limit)",
format!("{batch_size}"),
)));
}
let total = batch_size.checked_mul(target_t).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"prepare_inputs: batch_size * target_t",
"usize",
[
("batch_size", batch_size as u64),
("target_t", target_t as u64),
],
))
})?;
let mut ids_buf: Vec<i32> = try_with_capacity(total)?;
let mut mask_buf: Vec<bool> = try_with_capacity(total)?;
let has_caller_mask = !opts.attention_mask.is_empty();
for (b, batch) in text_token_batches.iter().enumerate() {
let pad_count = target_t - lens[b];
let row_mask: Option<&[bool]> = if has_caller_mask {
Some(opts.attention_mask[b].as_slice())
} else {
None
};
match opts.padding_side {
PaddingSide::Left => {
for _ in 0..pad_count {
ids_buf.push(opts.pad_token_id as i32);
mask_buf.push(false);
}
for (i, &t) in batch.iter().enumerate() {
ids_buf.push(t as i32);
mask_buf.push(row_mask.is_none_or(|m| m[i]));
}
}
PaddingSide::Right => {
for (i, &t) in batch.iter().enumerate() {
ids_buf.push(t as i32);
mask_buf.push(row_mask.is_none_or(|m| m[i]));
}
for _ in 0..pad_count {
ids_buf.push(opts.pad_token_id as i32);
mask_buf.push(false);
}
}
}
}
let input_ids = Array::from_slice::<i32>(&ids_buf, &(batch_size, target_t))?;
let attention_mask = Array::from_slice::<bool>(&mask_buf, &(batch_size, target_t))?;
Ok(PreparedInputs::new(
input_ids,
attention_mask,
pixel_values,
input_features,
pixel_values_videos,
))
}
/// Decodes the audio file at `path`, returning `(samples, sample_rate)`.
///
/// This is a thin wrapper over `crate::audio::io::load_audio`. The samples are returned at the
/// file's own sample rate; no resampling is done here.
///
/// # Errors
///
/// Propagates any error from `load_audio`.
#[cfg(all(feature = "vlm", feature = "audio"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "vlm", feature = "audio"))))]
pub fn read_audio(path: &std::path::Path) -> Result<(Vec<f32>, u32)> {
    use crate::audio::io::load_audio;
    load_audio(path)
}
/// Loads the audio file at `path` and returns its samples at the requested rate `sr`.
///
/// The file is decoded via [`read_audio`]:
/// - If its native rate already equals `sr`, the decoded samples are returned as-is.
/// - Otherwise they are converted with `crate::audio::io::resample_linear`.
///
/// # Errors
///
/// Propagates errors from decoding and from resampling.
#[cfg(all(feature = "vlm", feature = "audio"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "vlm", feature = "audio"))))]
pub fn load_audio_vlm(path: &std::path::Path, sr: u32) -> Result<Vec<f32>> {
    let (decoded, native_rate) = read_audio(path)?;
    if native_rate == sr {
        return Ok(decoded);
    }
    crate::audio::io::resample_linear(&decoded, native_rate, sr)
}
/// Standardizes `features` as `(features - mean) / (std + 1e-6)`.
///
/// The `1e-6` epsilon, a one-element `f32` array, keeps the division finite when the standard
/// deviation is zero.
///
/// NOTE(review): `mean(false)` / `std(false, 0)` are assumed to reduce over all elements, with
/// `false` presumably being keep-dims and `0` the ddof. TODO confirm against `Array`.
///
/// # Errors
///
/// Propagates any error from the reductions, array construction, or arithmetic ops.
#[cfg(all(feature = "vlm", feature = "audio"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "vlm", feature = "audio"))))]
pub fn normalize_audio_features(features: &Array) -> Result<Array> {
    let mean = features.mean(false)?;
    let std = features.std(false, 0)?;
    let centered = crate::ops::arithmetic::subtract(features, &mean)?;
    let eps = Array::full::<f32>(&(1usize,), 1e-6_f32)?;
    let denom = crate::ops::arithmetic::add(&std, &eps)?;
    // Previously mis-encoded as `¢ered` (an HTML `&cent` entity decode), which does not compile.
    crate::ops::arithmetic::divide(&centered, &denom)
}
/// Turns already-decoded video `frames` into a single `Array`.
///
/// Processing uses the image settings in `cfg` and is delegated to
/// `crate::vlm::video::process_frames`.
///
/// # Errors
///
/// Propagates any error from `process_frames`.
#[cfg(feature = "vlm")]
pub fn load_video(
    frames: &[::image::DynamicImage],
    cfg: &crate::vlm::image::ImageProcessorConfig,
) -> Result<Array> {
    use crate::vlm::video::process_frames;
    process_frames(frames, cfg)
}