use std::{cell::RefCell, path::PathBuf};
use crate::{
array::Array,
error::{
ArithmeticOverflowPayload, EmptyInputPayload, Error, LengthMismatchPayload, OutOfRangePayload,
RankMismatchPayload, Result, try_extend_from_slice, try_with_capacity,
},
lm::{
cache::KvCache,
generate::{
FinishReason, GenConfig, GenStep, LogitsProcessor, Sampler, make_logits_processors,
make_sampler,
},
},
ops,
vlm::{
image::{ImageProcessorConfig, load_image, preprocess},
model::Model,
prompt::{MarkerPolicy, insert_image_tokens},
},
};
/// Generation settings for [`vlm_generate`]: the text-model [`GenConfig`] plus the
/// image-token layout used when splicing image features into the prompt.
#[derive(Debug, Clone)]
pub struct VlmGenConfig {
    /// Text-model generation settings (validation, sampling, penalties, EOS, `max_tokens`).
    lm: GenConfig,
    /// Image placeholder token id handed to `insert_image_tokens`; image features are
    /// spliced over the prompt positions holding it.
    image_token_id: u32,
    /// Marker token looked for in the text prompt; `None` means `image_token_id` is
    /// used as the marker.
    image_marker_id: Option<u32>,
    /// Prompt positions per image; `encode_image` must return exactly this many rows.
    num_tokens_per_image: usize,
    /// Forwarded to `insert_image_tokens`; presumably governs how missing markers are
    /// handled — confirm against `vlm::prompt`.
    marker_policy: MarkerPolicy,
}
impl VlmGenConfig {
    /// Creates a config from text-generation settings and the image-token layout.
    ///
    /// No separate marker id is set, so `vlm_generate` falls back to using
    /// `image_token_id` as the marker until [`VlmGenConfig::with_image_marker_id`]
    /// overrides it.
    pub fn new(
        lm: GenConfig,
        image_token_id: u32,
        num_tokens_per_image: usize,
        marker_policy: MarkerPolicy,
    ) -> Self {
        let image_marker_id = None;
        Self {
            lm,
            image_token_id,
            num_tokens_per_image,
            marker_policy,
            image_marker_id,
        }
    }

    /// Returns the config with its marker token id replaced by `v` (`None` clears it).
    #[must_use]
    pub fn with_image_marker_id(self, v: Option<u32>) -> Self {
        Self {
            image_marker_id: v,
            ..self
        }
    }

    /// Borrows the text-generation settings.
    #[inline]
    pub fn lm_ref(&self) -> &GenConfig {
        &self.lm
    }

    /// Mutably borrows the text-generation settings.
    #[inline]
    pub fn lm_mut(&mut self) -> &mut GenConfig {
        &mut self.lm
    }

    /// Image placeholder token id.
    #[inline]
    pub fn image_token_id(&self) -> u32 {
        self.image_token_id
    }

    /// Explicit marker token id, if one was configured.
    #[inline]
    pub fn image_marker_id(&self) -> Option<u32> {
        self.image_marker_id
    }

    /// Prompt positions (and encoder rows) per image.
    #[inline]
    pub fn num_tokens_per_image(&self) -> usize {
        self.num_tokens_per_image
    }

    /// Marker policy forwarded to prompt assembly.
    #[inline]
    pub fn marker_policy(&self) -> MarkerPolicy {
        self.marker_policy
    }
}
pub fn vlm_generate<'a, M: Model + ?Sized>(
model: &'a M,
image_processor_config: &ImageProcessorConfig,
text_tokens: &[u32],
images: &[PathBuf],
cache: Vec<Box<dyn KvCache>>,
cfg: VlmGenConfig,
) -> Result<impl Iterator<Item = Result<GenStep>> + 'a> {
cfg.lm.validate()?;
if cfg.lm.max_tokens == 0 {
return Ok(Box::new(std::iter::empty()) as Box<dyn Iterator<Item = Result<GenStep>> + 'a>);
}
if images.is_empty() {
let mut lm = cfg.lm;
lm.collect_logprobs = true;
let iter = crate::lm::generate::generate_step(model, text_tokens, cache, lm);
return Ok(Box::new(iter) as Box<dyn Iterator<Item = Result<GenStep>> + 'a>);
}
let marker_id = cfg.image_marker_id.unwrap_or(cfg.image_token_id);
let assembled_tokens = insert_image_tokens(
text_tokens,
images.len(),
marker_id,
cfg.image_token_id,
cfg.num_tokens_per_image,
cfg.marker_policy,
)?;
let base: usize = text_tokens
.iter()
.position(|&t| t == marker_id)
.unwrap_or_default();
let mut image_spans: Vec<(usize, usize)> = try_with_capacity(images.len())?;
for i in 0..images.len() {
let start = base + i * cfg.num_tokens_per_image;
let end = start + cfg.num_tokens_per_image;
image_spans.push((start, end));
}
let mut image_slabs: Vec<Array> = try_with_capacity(images.len())?;
for path in images.iter() {
let img = load_image(path)?;
let pre = preprocess(&img, image_processor_config)?;
let encoded = model.encode_image(&pre)?;
let enc_shape = encoded.shape();
let (rows, _d) = match enc_shape.as_slice() {
[n, d] => (*n, *d),
_ => {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"vlm_generate: encode_image must return rank-2 [N, D]",
enc_shape.len() as u32,
enc_shape.clone(),
)));
}
};
if rows != cfg.num_tokens_per_image {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"vlm_generate: encode_image feature rows vs cfg.num_tokens_per_image (cross-model splice \
contract requires exactly num_tokens_per_image per image — a variable-per-image model \
must pad/truncate inside encode_image or override vlm_generate)",
cfg.num_tokens_per_image,
rows,
)));
}
image_slabs.push(encoded);
}
let built = (|| -> Result<(Sampler, Vec<LogitsProcessor>)> {
let sampler = make_sampler(
cfg.lm.temp,
cfg.lm.top_p,
cfg.lm.min_p,
cfg.lm.min_tokens_to_keep,
cfg.lm.top_k,
cfg.lm.xtc_probability,
cfg.lm.xtc_threshold,
&cfg.lm.xtc_special_tokens,
cfg.lm.seed,
)?;
let processors = make_logits_processors(
&cfg.lm.logit_bias,
cfg.lm.repetition_penalty,
cfg.lm.repetition_context_size,
cfg.lm.presence_penalty,
cfg.lm.presence_context_size,
cfg.lm.frequency_penalty,
cfg.lm.frequency_context_size,
)?;
Ok((sampler, processors))
})();
match built {
Ok((sampler, processors)) => Ok(Box::new(VlmDecode {
model,
cache: RefCell::new(cache),
sampler: RefCell::new(sampler),
processors,
history: RefCell::new(Vec::new()),
eos: cfg.lm.eos,
max_tokens: cfg.lm.max_tokens,
produced: 0,
prefill_step_size: cfg.lm.prefill_step_size.max(1),
last: None,
prefilled: false,
image_slabs: Some(image_slabs),
image_spans: Some(image_spans),
prompt_history: Some(assembled_tokens),
pending_err: None,
done: false,
}) as Box<dyn Iterator<Item = Result<GenStep>> + 'a>),
Err(e) => Ok(Box::new(VlmDecode {
model,
cache: RefCell::new(cache),
sampler: RefCell::new(Sampler::Argmax),
processors: Vec::new(),
history: RefCell::new(Vec::new()),
eos: Vec::new(),
max_tokens: cfg.lm.max_tokens,
produced: 0,
prefill_step_size: 1,
last: None,
prefilled: true, image_slabs: None,
image_spans: None,
prompt_history: None,
pending_err: Some(e),
done: false,
}) as Box<dyn Iterator<Item = Result<GenStep>> + 'a>),
}
}
/// Lazy multimodal decode loop built by [`vlm_generate`] when images are present.
///
/// The first `next()` runs the chunked prefill over the assembled prompt, consuming
/// `prompt_history`, `image_spans` and `image_slabs`. Each later call feeds the previously
/// sampled token back through `Model::forward`. The `RefCell`s let the `&self` step
/// helpers mutate the cache, the sampler and the processor history.
struct VlmDecode<'a, M: Model + ?Sized> {
    // Model being decoded; borrowed for the iterator's lifetime.
    model: &'a M,
    // Per-layer KV caches, threaded through every forward call.
    cache: RefCell<Vec<Box<dyn KvCache>>>,
    // Token sampler; applied to log-probabilities.
    sampler: RefCell<Sampler>,
    // Logits processors, applied in order before sampling.
    processors: Vec<LogitsProcessor>,
    // Tokens fed so far (prompt, then each decode input); only maintained when
    // `processors` is non-empty.
    history: RefCell<Vec<u32>>,
    // Tokens that end generation with `FinishReason::Eos`.
    eos: Vec<u32>,
    // Maximum number of steps to yield.
    max_tokens: usize,
    // Steps yielded so far; also the `step_index` of the next step.
    produced: usize,
    // Target prefill chunk length; a chunk is extended past it to keep an image span whole.
    prefill_step_size: usize,
    // Last sampled token, used as input to the next decode step.
    last: Option<u32>,
    // Set once the prefill has been attempted.
    prefilled: bool,
    // Encoded `[N, D]` features, one per image; taken by the prefill.
    image_slabs: Option<Vec<Array>>,
    // `[start, end)` prompt positions of each image; taken by the prefill.
    image_spans: Option<Vec<(usize, usize)>>,
    // Assembled prompt (text plus image tokens); taken by the prefill.
    prompt_history: Option<Vec<u32>>,
    // Construction error to yield as the first (and only) item.
    pending_err: Option<Error>,
    // Set once the iterator has finished; later calls return `None`.
    done: bool,
}
impl<M: Model + ?Sized> VlmDecode<'_, M> {
fn sample_from_logits(&self, logits: &Array, step_inputs: &[u32]) -> Result<GenStep> {
let logits = last_position(logits)?;
let mut logits = logits;
if !self.processors.is_empty() && !step_inputs.is_empty() {
let mut history = self.history.borrow_mut();
try_extend_from_slice(&mut history, step_inputs)?;
for p in &self.processors {
logits = p.apply(&history, &logits)?;
}
}
let lse = ops::reduction::logsumexp(&logits, true)?;
let logprobs = ops::arithmetic::subtract(&logits, &lse)?;
let mut sampled = self.sampler.borrow_mut().sample(&logprobs)?;
let token: u32 = sampled.item::<u32>()?;
let logprobs = ops::shape::squeeze_axes(&logprobs, &[0])?;
Ok(GenStep {
token,
logprobs: Some(logprobs),
step_index: self.produced,
finish_reason: None,
})
}
fn prefill_step(
&self,
prompt_tokens: &[u32],
image_spans: &[(usize, usize)],
image_slabs: &[Array],
) -> Result<GenStep> {
let t = prompt_tokens.len();
if t == 0 {
return Err(Error::EmptyInput(EmptyInputPayload::new(
"vlm_generate: assembled prompt (T=0); prefill cannot produce logits",
)));
}
let initial_offset = {
let cache = self.cache.borrow();
let mut iter = cache.iter();
match iter.next() {
None => 0, Some(first) => {
let off = first.offset();
for layer in iter {
if layer.offset() != off {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"vlm_generate: KV cache layer offsets must agree (layer 0 vs layer i; \
chunked-multimodal prefill needs one consistent cache offset to size per-chunk \
attention masks — a faithfully restored prompt cache has all layers at the \
same offset)",
off,
layer.offset(),
)));
}
}
off
}
}
};
let step = self.prefill_step_size.max(1);
let mut cursor: usize = 0;
let mut last_logits: Option<Array> = None;
while cursor < t {
let mut end = (cursor + step).min(t);
for &(s, e) in image_spans {
if s < end && end < e {
end = end.max(e);
}
}
let end = end.min(t);
let chunk_len = end - cursor;
let chunk_window = {
let mut row: Vec<i32> = try_with_capacity(chunk_len)?;
row.extend(prompt_tokens[cursor..end].iter().map(|&x| x as i32));
Array::from_slice::<i32>(&row, &(1_usize, chunk_len))?
};
let chunk_text_embeds = self.model.embed_tokens(&chunk_window)?;
let mut chunk_spans: Vec<(usize, usize)> = try_with_capacity(image_spans.len())?;
let mut chunk_slab_refs: Vec<&Array> = try_with_capacity(image_spans.len())?;
for (i, &(s, e)) in image_spans.iter().enumerate() {
if cursor <= s && e <= end {
chunk_spans.push((s - cursor, e - cursor));
chunk_slab_refs.push(&image_slabs[i]);
}
}
let chunk_merged = if chunk_spans.is_empty() {
chunk_text_embeds
} else {
let chunk_image_embeds = if chunk_slab_refs.len() == 1 {
chunk_slab_refs[0].try_clone()?
} else {
ops::shape::concatenate(&chunk_slab_refs, 0)?
};
self
.model
.merge_embeddings(&chunk_text_embeds, &chunk_image_embeds, &chunk_spans)?
};
let chunk_offset = initial_offset.checked_add(cursor).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"vlm_generate: initial cache offset + chunk cursor",
"usize",
[
("initial_offset", initial_offset as u64),
("cursor", cursor as u64),
],
))
})?;
let logits = self.model.forward_embeddings_multimodal(
&chunk_merged,
&chunk_spans,
chunk_offset,
&mut self.cache.borrow_mut(),
)?;
last_logits = Some(logits);
cursor = end;
}
let logits = last_logits.expect("at least one prefill chunk ran (t > 0 guarded above)");
self.sample_from_logits(&logits, prompt_tokens)
}
fn decode_step(&self, last_token: u32) -> Result<GenStep> {
let tokens = Array::from_slice::<i32>(&[last_token as i32], &(1_usize, 1_usize))?;
let logits = self.model.forward(&tokens, &mut self.cache.borrow_mut())?;
self.sample_from_logits(&logits, &[last_token])
}
}
impl<M: Model + ?Sized> Iterator for VlmDecode<'_, M> {
    type Item = Result<GenStep>;

    /// Produces the next generation step.
    ///
    /// A deferred construction error is yielded first and ends the stream. Otherwise the
    /// first call runs the prefill, and later calls decode from the last sampled token.
    /// The stream ends when any of these happens:
    /// - `max_tokens` steps have been produced;
    /// - an EOS token is sampled (that step carries `FinishReason::Eos`);
    /// - a step fails (the error is yielded, then the iterator ends).
    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }
        if let Some(err) = self.pending_err.take() {
            self.done = true;
            return Some(Err(err));
        }
        if self.produced >= self.max_tokens {
            self.done = true;
            return None;
        }

        let outcome = if self.prefilled {
            let previous = match self.last {
                Some(tok) => tok,
                None => {
                    self.done = true;
                    return None;
                }
            };
            self.decode_step(previous)
        } else {
            // The prompt and image inputs are consumed exactly once, by the prefill.
            self.prefilled = true;
            let prompt = self.prompt_history.take().unwrap_or_default();
            let spans = self.image_spans.take().unwrap_or_default();
            let slabs = self.image_slabs.take().unwrap_or_default();
            self.prefill_step(&prompt, &spans, &slabs)
        };

        let mut step = match outcome {
            Ok(step) => step,
            Err(err) => {
                self.done = true;
                return Some(Err(err));
            }
        };
        self.produced += 1;
        self.last = Some(step.token);
        if self.eos.contains(&step.token) {
            step.finish_reason = Some(FinishReason::Eos);
            self.done = true;
        }
        Some(Ok(step))
    }
}
/// Selects the final sequence position of `[B, S, V]` logits, yielding `[B, V]`.
///
/// # Errors
/// - `RankMismatch` when `logits` is not rank 3.
/// - `OutOfRange` when the `S` axis or the `V` axis is empty (`S` is checked first).
fn last_position(logits: &Array) -> Result<Array> {
    let shape = logits.shape();
    if shape.len() != 3 {
        let observed_rank = shape.len() as u32;
        return Err(Error::RankMismatch(RankMismatchPayload::new(
            "vlm_generate: expected [B, S, V] logits from forward (rank 3)",
            observed_rank,
            shape,
        )));
    }
    let (batch, seq, vocab) = (shape[0], shape[1], shape[2]);

    // Both empty-axis cases share one error shape; only the context and size differ.
    let empty_axis = if seq == 0 {
        Some((
            "vlm_generate: forward logits S axis (logits[:, -1, :] requires S >= 1)",
            seq,
        ))
    } else if vocab == 0 {
        Some((
            "vlm_generate: forward logits V axis (logits[:, -1, :] requires V >= 1)",
            vocab,
        ))
    } else {
        None
    };
    if let Some((context, size)) = empty_axis {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            context,
            "must be >= 1",
            format!("{size} (full shape {shape:?})"),
        )));
    }

    let stop = [batch as i32, seq as i32, vocab as i32];
    let start = [0, stop[1] - 1, 0];
    let last_row = ops::indexing::slice(logits, &start, &stop, &[1, 1, 1])?;
    ops::shape::squeeze_axes(&last_row, &[1])
}
#[cfg(test)]
mod tests {
    // Unit tests for the VLM decode loop. `VlmMock` stands in for a real model
    // (ramp logits, zero text embeddings, configurable `encode_image` output shape)
    // and `FixedOffsetCache` for a KV-cache layer whose offset is pinned.
    use super::*;
    use crate::lm::cache::{CacheConfig, KvCache, MaskMode, make_prompt_cache};
    // Output shape produced by `VlmMock::encode_image`.
    #[derive(Clone, Copy)]
    enum EncodeShape {
        // Well-formed `[rows, hidden]` features.
        Rank2 { rows: usize, hidden: usize },
        // Malformed rank-1 `[n]` features, to exercise the rank-2 contract check.
        Rank1 { n: usize },
    }
    // Minimal model whose outputs depend only on the shapes it is given.
    struct VlmMock {
        vocab: usize,
        hidden: usize,
        encode: EncodeShape,
    }
    impl VlmMock {
        // Default mock: `encode_image` returns a `[1, hidden]` slab.
        fn new(vocab: usize, hidden: usize) -> Self {
            Self {
                vocab,
                hidden,
                encode: EncodeShape::Rank2 { rows: 1, hidden },
            }
        }
        fn with_encode(mut self, encode: EncodeShape) -> Self {
            self.encode = encode;
            self
        }
        // `[batch, seq, vocab]` logits where every position holds the ramp
        // `0, 1, ..., vocab - 1`, so arg-max sampling always picks `vocab - 1`.
        fn logits(&self, batch: usize, seq: usize) -> Result<Array> {
            let mut data: Vec<f32> = Vec::with_capacity(batch * seq * self.vocab);
            for _ in 0..batch * seq {
                for v in 0..self.vocab {
                    data.push(v as f32);
                }
            }
            Array::from_slice::<f32>(&data, &(batch, seq, self.vocab))
        }
    }
    impl crate::lm::model::Model for VlmMock {
        // Token forward: logits sized from the input's `[B, S]` (or `[S]`) shape;
        // any other rank falls back to a single position.
        fn forward(&self, tokens: &Array, _cache: &mut [Box<dyn KvCache>]) -> Result<Array> {
            let shape = tokens.shape();
            let (b, s) = match shape.as_slice() {
                [b, s] => (*b, *s),
                [s] => (1, *s),
                _ => (1, 1),
            };
            self.logits(b, s)
        }
        // Embedding forward: one batch row; `S` comes from a rank-3 input, else 1.
        fn forward_embeddings(
            &self,
            embeddings: &Array,
            _cache: &mut [Box<dyn KvCache>],
        ) -> Result<Array> {
            let shape = embeddings.shape();
            let s = if shape.len() == 3 { shape[1] } else { 1 };
            self.logits(1, s)
        }
        fn supports_input_embeddings(&self) -> bool {
            true
        }
    }
    impl Model for VlmMock {
        // All-zero text embeddings of shape `[1, T, hidden]`.
        fn embed_tokens(&self, tokens: &Array) -> Result<Array> {
            let shape = tokens.shape();
            let t = match shape.as_slice() {
                [_b, t] => *t,
                [t] => *t,
                _ => 1,
            };
            let data = vec![0.0_f32; t * self.hidden];
            Array::from_slice::<f32>(&data, &(1_usize, t, self.hidden))
        }
        // All-ones features in the configured `EncodeShape`; the pixels are ignored.
        fn encode_image(&self, _image: &Array) -> Result<Array> {
            match self.encode {
                EncodeShape::Rank2 { rows, hidden } => {
                    let data = vec![1.0_f32; rows * hidden];
                    Array::from_slice::<f32>(&data, &(rows, hidden))
                }
                EncodeShape::Rank1 { n } => {
                    let data = vec![1.0_f32; n];
                    Array::from_slice::<f32>(&data, &(n,))
                }
            }
        }
    }
    // KV-cache layer that reports a fixed `offset`, echoes updates back and stores
    // no state.
    struct FixedOffsetCache {
        offset: usize,
    }
    impl KvCache for FixedOffsetCache {
        fn offset(&self) -> usize {
            self.offset
        }
        fn update(&mut self, keys: &Array, values: &Array) -> Result<(Array, Array)> {
            Ok((keys.try_clone()?, values.try_clone()?))
        }
        fn state(&self) -> Result<Vec<Array>> {
            Ok(Vec::new())
        }
        fn set_state(&mut self, _state: Vec<Array>) -> Result<()> {
            Ok(())
        }
        fn materialize(&mut self) -> Result<()> {
            Ok(())
        }
        fn make_mask(
            &self,
            _n: usize,
            _window_size: Option<usize>,
            _return_array: bool,
        ) -> Result<MaskMode> {
            Ok(MaskMode::None)
        }
        fn nbytes(&self) -> usize {
            0
        }
        fn is_empty(&self) -> bool {
            self.offset == 0
        }
        fn copy(&self) -> Result<Box<dyn KvCache>> {
            Ok(Box::new(FixedOffsetCache {
                offset: self.offset,
            }))
        }
        fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
            self
        }
        fn reference_class_name(&self) -> &'static str {
            "FixedOffsetCache"
        }
    }
    // Builds a not-yet-prefilled `VlmDecode` directly, bypassing `vlm_generate`:
    // arg-max sampler, no processors, no EOS tokens, `max_tokens == 8`.
    fn decode_with<'a>(
        model: &'a VlmMock,
        cache: Vec<Box<dyn KvCache>>,
        prefill_step_size: usize,
        prompt: Vec<u32>,
        spans: Vec<(usize, usize)>,
        slabs: Vec<Array>,
    ) -> VlmDecode<'a, VlmMock> {
        VlmDecode {
            model,
            cache: RefCell::new(cache),
            sampler: RefCell::new(Sampler::Argmax),
            processors: Vec::new(),
            history: RefCell::new(Vec::new()),
            eos: Vec::new(),
            max_tokens: 8,
            produced: 0,
            prefill_step_size,
            last: None,
            prefilled: false,
            image_slabs: Some(slabs),
            image_spans: Some(spans),
            prompt_history: Some(prompt),
            pending_err: None,
            done: false,
        }
    }
    // Constructor defaults, the marker builder and `lm_mut` edits leave the other
    // fields intact.
    #[test]
    fn vlm_gen_config_accessors_roundtrip() {
        let lm = GenConfig::default().with_max_tokens(7);
        let cfg = VlmGenConfig::new(lm, 99, 3, MarkerPolicy::Required);
        assert_eq!(cfg.lm_ref().max_tokens, 7);
        assert_eq!(cfg.image_token_id(), 99); assert_eq!(cfg.image_marker_id(), None); assert_eq!(cfg.num_tokens_per_image(), 3); assert!(cfg.marker_policy().is_required());
        let mut cfg = cfg.with_image_marker_id(Some(42));
        assert_eq!(cfg.image_marker_id(), Some(42));
        cfg.lm_mut().max_tokens = 11;
        assert_eq!(cfg.lm_ref().max_tokens, 11);
        assert_eq!(cfg.image_token_id(), 99);
        assert_eq!(cfg.num_tokens_per_image(), 3);
        assert!(cfg.marker_policy().is_required());
    }
    // Rank-2 logits are rejected, and the observed rank and shape are reported.
    #[test]
    fn last_position_rejects_non_rank3() {
        let two_d = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2_usize, 2)).unwrap();
        let err = last_position(&two_d).unwrap_err();
        match err {
            Error::RankMismatch(p) => {
                assert!(p.context().contains("rank 3"), "ctx: {}", p.context());
                assert_eq!(p.actual(), 2, "observed rank carried");
                assert_eq!(p.actual_shape(), &[2, 2], "observed shape carried");
            }
            other => panic!("expected RankMismatch, got {other:?}"),
        }
    }
    // An empty sequence axis has no "last" position to select.
    #[test]
    fn last_position_rejects_zero_s_axis() {
        let data: Vec<f32> = Vec::new();
        let z = Array::from_slice::<f32>(&data, &(1_usize, 0, 4)).unwrap();
        let err = last_position(&z).unwrap_err();
        match err {
            Error::OutOfRange(p) => {
                assert!(p.context().contains("S axis"), "ctx: {}", p.context());
                assert!(
                    p.value().starts_with('0'),
                    "value reports S=0: {}",
                    p.value()
                );
            }
            other => panic!("expected OutOfRange(S), got {other:?}"),
        }
    }
    // An empty vocabulary axis leaves nothing to sample from.
    #[test]
    fn last_position_rejects_zero_v_axis() {
        let data: Vec<f32> = Vec::new();
        let z = Array::from_slice::<f32>(&data, &(1_usize, 2, 0)).unwrap();
        let err = last_position(&z).unwrap_err();
        match err {
            Error::OutOfRange(p) => {
                assert!(p.context().contains("V axis"), "ctx: {}", p.context());
                assert!(
                    p.value().starts_with('0'),
                    "value reports V=0: {}",
                    p.value()
                );
            }
            other => panic!("expected OutOfRange(V), got {other:?}"),
        }
    }
    // `[1, 3, 2]` logits reduce to the final row `[4, 5]` with the S axis dropped.
    #[test]
    fn last_position_slices_final_position() {
        let data = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let logits = Array::from_slice::<f32>(&data, &(1_usize, 3, 2)).unwrap();
        let mut out = last_position(&logits).unwrap();
        assert_eq!(out.shape(), vec![1, 2], "S axis dropped");
        assert_eq!(
            out.to_vec::<f32>().unwrap(),
            vec![4.0, 5.0],
            "final position kept"
        );
    }
    // A deferred construction error is yielded exactly once, then the iterator fuses.
    #[test]
    fn vlm_decode_pending_err_is_first_then_fuses() {
        let model = VlmMock::new(4, 2);
        let cache: Vec<Box<dyn KvCache>> = Vec::new();
        let mut it = VlmDecode {
            model: &model,
            cache: RefCell::new(cache),
            sampler: RefCell::new(Sampler::Argmax),
            processors: Vec::new(),
            history: RefCell::new(Vec::new()),
            eos: Vec::new(),
            max_tokens: 5,
            produced: 0,
            prefill_step_size: 1,
            last: None,
            prefilled: true, image_slabs: None,
            image_spans: None,
            prompt_history: None,
            pending_err: Some(Error::EmptyInput(EmptyInputPayload::new(
                "sentinel pending error",
            ))),
            done: false,
        };
        let err = it.next().expect("yields the pending err").unwrap_err();
        assert!(
            matches!(err, Error::EmptyInput(ref p) if p.context().contains("sentinel")),
            "deferred pending_err surfaced, got {err:?}"
        );
        assert!(it.next().is_none(), "fuses after the single deferred Err");
    }
    // Prefilled with no last token: there is nothing to decode from, so it ends.
    #[test]
    fn vlm_decode_prefilled_but_no_last_ends() {
        let model = VlmMock::new(4, 2);
        let cache: Vec<Box<dyn KvCache>> = Vec::new();
        let mut it = VlmDecode {
            model: &model,
            cache: RefCell::new(cache),
            sampler: RefCell::new(Sampler::Argmax),
            processors: Vec::new(),
            history: RefCell::new(Vec::new()),
            eos: Vec::new(),
            max_tokens: 5,
            produced: 0,
            prefill_step_size: 1,
            last: None,
            prefilled: true,
            image_slabs: None,
            image_spans: None,
            prompt_history: None,
            pending_err: None,
            done: false,
        };
        assert!(
            it.next().is_none(),
            "prefilled + last==None ends the iterator"
        );
        assert!(it.next().is_none());
    }
    // An empty assembled prompt is rejected before any model call.
    #[test]
    fn prefill_step_empty_prompt_is_empty_input() {
        let model = VlmMock::new(4, 2);
        let cache: Vec<Box<dyn KvCache>> = Vec::new();
        let it = decode_with(&model, cache, 4, Vec::new(), Vec::new(), Vec::new());
        let err = it.prefill_step(&[], &[], &[]).unwrap_err();
        match err {
            Error::EmptyInput(p) => assert!(
                p.context().contains("T=0") || p.context().contains("prefill"),
                "ctx names the empty-prompt prefill: {}",
                p.context()
            ),
            other => panic!("expected EmptyInput, got {other:?}"),
        }
    }
    // Cache layers at offsets 0 and 5 cannot share one chunk offset.
    #[test]
    fn prefill_step_rejects_mismatched_layer_offsets() {
        let model = VlmMock::new(4, 2);
        let cache: Vec<Box<dyn KvCache>> = vec![
            Box::new(FixedOffsetCache { offset: 0 }),
            Box::new(FixedOffsetCache { offset: 5 }),
        ];
        let it = decode_with(&model, cache, 4, vec![1, 2, 3], Vec::new(), Vec::new());
        let err = it.prefill_step(&[1, 2, 3], &[], &[]).unwrap_err();
        match err {
            Error::LengthMismatch(p) => {
                assert!(
                    p.context().contains("offset"),
                    "ctx names the offset disagreement: {}",
                    p.context()
                );
                assert_eq!(p.expected(), 0, "layer-0 offset is the reference");
                assert_eq!(p.actual(), 5, "the disagreeing layer's offset");
            }
            other => panic!("expected LengthMismatch(offsets), got {other:?}"),
        }
    }
    // Offset `usize::MAX - 1` with 1-token chunks: the chunk at cursor 2 overflows.
    #[test]
    fn prefill_step_offset_overflow_is_arithmetic_overflow() {
        let model = VlmMock::new(4, 2);
        let cache: Vec<Box<dyn KvCache>> = vec![Box::new(FixedOffsetCache {
            offset: usize::MAX - 1,
        })];
        let it = decode_with(&model, cache, 1, vec![10, 11, 12], Vec::new(), Vec::new());
        let err = it.prefill_step(&[10, 11, 12], &[], &[]).unwrap_err();
        match err {
            Error::ArithmeticOverflow(p) => assert!(
                p.context().contains("cache offset") || p.context().contains("cursor"),
                "ctx names the offset+cursor add: {}",
                p.context()
            ),
            other => panic!("expected ArithmeticOverflow, got {other:?}"),
        }
    }
    // End to end through `vlm_generate`: a real 2x2 PNG is written to a temp dir,
    // and the mock's rank-1 encoder output must trip the eager rank-2 check.
    #[test]
    fn vlm_generate_rejects_rank1_encode_output() {
        let dir =
            std::env::temp_dir().join(format!("mlxrs-vlm-generate-encode-{}", std::process::id()));
        std::fs::create_dir_all(&dir).expect("create temp dir");
        let path = dir.join("tiny.png");
        let mut buf = ::image::RgbImage::new(2, 2);
        for y in 0..2 {
            for x in 0..2 {
                buf.put_pixel(x, y, ::image::Rgb([128, 64, 200]));
            }
        }
        ::image::DynamicImage::ImageRgb8(buf)
            .save_with_format(&path, ::image::ImageFormat::Png)
            .expect("encode tiny PNG");
        let model = VlmMock::new(4, 2).with_encode(EncodeShape::Rank1 { n: 1 });
        let proc_cfg = ImageProcessorConfig::default();
        let cfg = VlmGenConfig::new(
            GenConfig::default().with_max_tokens(4),
            7, 1, MarkerPolicy::Required,
        );
        let cache = make_prompt_cache(&CacheConfig {
            num_hidden_layers: 1,
            sliding_window: None,
        });
        let res = vlm_generate(
            &model,
            &proc_cfg,
            &[7u32],
            std::slice::from_ref(&path),
            cache,
            cfg,
        );
        // Best-effort cleanup before asserting, so a failure does not leak the file.
        let _ = std::fs::remove_file(&path);
        let _ = std::fs::remove_dir(&dir);
        let err = res.err().expect("rank-1 encode_image output must error");
        match err {
            Error::RankMismatch(p) => {
                assert!(
                    p.context().contains("encode_image") && p.context().contains("[N, D]"),
                    "ctx names the encode_image rank-2 contract: {}",
                    p.context()
                );
                assert_eq!(p.actual(), 1, "observed rank-1 carried");
            }
            other => panic!("expected RankMismatch from the encode-shape check, got {other:?}"),
        }
    }
    // `max_tokens == 0` short-circuits before the (nonexistent) image is opened.
    #[test]
    fn vlm_generate_zero_max_tokens_is_empty_no_vision() {
        let model = VlmMock::new(4, 2);
        let proc_cfg = ImageProcessorConfig::default();
        let cfg = VlmGenConfig::new(
            GenConfig::default().with_max_tokens(0),
            7,
            1,
            MarkerPolicy::Required,
        );
        let cache = make_prompt_cache(&CacheConfig {
            num_hidden_layers: 1,
            sliding_window: None,
        });
        let bogus = PathBuf::from("/nonexistent/mlxrs-vlm-no-such-image.png");
        let mut it = vlm_generate(&model, &proc_cfg, &[7u32], &[bogus], cache, cfg)
            .expect("max_tokens==0 short-circuits to Ok(empty) before any vision work");
        assert!(it.next().is_none(), "zero-budget run yields nothing");
    }
    // `cfg.lm.validate()` failures surface as an eager `Err`, not a deferred item.
    #[test]
    fn vlm_generate_invalid_cfg_is_eager_err() {
        let model = VlmMock::new(4, 2);
        let proc_cfg = ImageProcessorConfig::default();
        let cfg = VlmGenConfig::new(
            GenConfig::default().with_temp(-1.0),
            7,
            1,
            MarkerPolicy::Required,
        );
        let cache: Vec<Box<dyn KvCache>> = Vec::new();
        let res = vlm_generate(&model, &proc_cfg, &[7u32], &[], cache, cfg);
        match res.err().expect("invalid temp must be an eager Err") {
            Error::OutOfRange(p) => assert!(
                p.context().contains("temp"),
                "eager validate() surfaced temp range error: {}",
                p.context()
            ),
            other => panic!("expected eager OutOfRange(temp), got {other:?}"),
        }
    }
}