use crate::{
array::Array,
error::{
ArithmeticOverflowPayload, EmptyInputPayload, Error, InvariantViolationPayload,
LengthMismatchPayload, OutOfRangePayload, RankMismatchPayload, Result, try_with_capacity,
},
ops,
vlm::image::ImageProcessorConfig,
};
/// Vision-language extension of the language-model trait
/// `crate::lm::model::Model`.
///
/// Implementors supply the two required embedding hooks; the remaining
/// methods have defaults that can be overridden.
pub trait Model: crate::lm::model::Model {
    /// Required hook producing an embedding array for `tokens`.
    ///
    /// NOTE(review): the expected shape of the result is defined by the
    /// implementor; the default [`Model::merge_embeddings`] only accepts a
    /// rank-3 `[1, T, D]` array as its text input.
    fn embed_tokens(&self, tokens: &Array) -> Result<Array>;

    /// Required hook producing an embedding array for `image`.
    ///
    /// NOTE(review): the expected shape of the result is defined by the
    /// implementor; the default [`Model::merge_embeddings`] only accepts a
    /// rank-2 `[N, D]` array as its image input.
    fn encode_image(&self, image: &Array) -> Result<Array>;

    /// Combines text and image embeddings into a single embedding array.
    ///
    /// `image_spans` holds half-open `(start, end)` ranges along the text
    /// sequence axis. The default implementation hands all three arguments to
    /// `default_merge_embeddings` unchanged and returns its result.
    fn merge_embeddings(
        &self,
        text_embeds: &Array,
        image_embeds: &Array,
        image_spans: &[(usize, usize)],
    ) -> Result<Array> {
        default_merge_embeddings(text_embeds, image_embeds, image_spans)
    }

    /// Runs the model on already-merged `embeddings`.
    ///
    /// The default implementation ignores `_image_spans` and `_cache_offset`
    /// and forwards `embeddings` and `cache` to the supertrait's
    /// `forward_embeddings`. Implementors that need the span or offset
    /// information override this method.
    fn forward_embeddings_multimodal(
        &self,
        embeddings: &Array,
        _image_spans: &[(usize, usize)],
        _cache_offset: usize,
        cache: &mut [Box<dyn crate::lm::cache::KvCache>],
    ) -> Result<Array> {
        // Fully-qualified form makes the supertrait delegation explicit.
        <Self as crate::lm::model::Model>::forward_embeddings(self, embeddings, cache)
    }

    /// Image preprocessing configuration for this model.
    ///
    /// Defaults to `ImageProcessorConfig::default()`.
    fn image_processor_config(&self) -> ImageProcessorConfig {
        ImageProcessorConfig::default()
    }
}
fn default_merge_embeddings(
text_embeds: &Array,
image_embeds: &Array,
image_spans: &[(usize, usize)],
) -> Result<Array> {
let text_shape = text_embeds.shape();
let text_rank = text_shape.len() as u32;
if text_shape.len() != 3 {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"merge_embeddings: text_embeds must be rank-3 [1, T, D]",
text_rank,
text_shape,
)));
}
if text_shape[0] != 1 {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"merge_embeddings: text_embeds batch dim must be 1 (single-batch prompt)",
1,
text_shape[0],
)));
}
let image_shape = image_embeds.shape();
let image_rank = image_shape.len() as u32;
if image_shape.len() != 2 {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"merge_embeddings: image_embeds must be rank-2 [N, D]",
image_rank,
image_shape,
)));
}
let t = text_shape[1];
let d_text = text_shape[2];
let n_total = image_shape[0];
let d_image = image_shape[1];
if d_text != d_image {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"merge_embeddings: hidden-dim D (text_embeds vs image_embeds)",
d_text,
d_image,
)));
}
if image_spans.is_empty() {
return Err(Error::EmptyInput(EmptyInputPayload::new(
"merge_embeddings: image_spans (use forward(tokens) for the text-only path)",
)));
}
let mut total_width: usize = 0;
let mut prev_end: usize = 0;
for &(s, e) in image_spans.iter() {
if s >= e {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"merge_embeddings: image span (start, end)",
"start must be strictly less than end (empty spans not allowed)",
)));
}
if e > t {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"merge_embeddings: image span end vs text seq_len T",
"must be <= T",
format!("end={e}, T={t}"),
)));
}
if s < prev_end {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"merge_embeddings: image span order (s vs prev_end)",
"spans must be monotone non-overlapping (assemble_multimodal_prompt emits them in order)",
)));
}
total_width = total_width.checked_add(e - s).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"merge_embeddings: cumulative span width (total_width + (e - s))",
"usize",
[
("total_width", total_width as u64),
("span_width", (e - s) as u64),
],
))
})?;
prev_end = e;
}
if total_width != n_total {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"merge_embeddings: sum of caller-supplied placeholder span widths vs image_embeds row \
count N (expected = total_width, actual = n_total)",
total_width,
n_total,
)));
}
let pieces_cap = image_spans
.len()
.checked_mul(2)
.and_then(|n| n.checked_add(1))
.ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"merge_embeddings: piece-count capacity (image_spans.len() * 2 + 1)",
"usize",
[("image_spans.len()", image_spans.len() as u64)],
))
})?;
let mut pieces: Vec<Array> = try_with_capacity(pieces_cap)?;
let d_i32 = d_text as i32;
let t_i32 = t as i32;
let mut text_cursor: usize = 0;
let mut image_cursor: usize = 0;
for &(s, e) in image_spans {
if s > text_cursor {
let start = [0_i32, text_cursor as i32, 0_i32];
let stop = [1_i32, s as i32, d_i32];
let strides = [1_i32, 1_i32, 1_i32];
pieces.push(ops::indexing::slice(text_embeds, &start, &stop, &strides)?);
}
let width = e - s;
let img_start = [image_cursor as i32, 0_i32];
let img_stop = [(image_cursor + width) as i32, d_i32];
let img_strides = [1_i32, 1_i32];
let img_slab = ops::indexing::slice(image_embeds, &img_start, &img_stop, &img_strides)?;
let img_slab = ops::shape::reshape(&img_slab, &(1_usize, width, d_text))?;
pieces.push(img_slab);
text_cursor = e;
image_cursor += width;
}
if text_cursor < t {
let start = [0_i32, text_cursor as i32, 0_i32];
let stop = [1_i32, t_i32, d_i32];
let strides = [1_i32, 1_i32, 1_i32];
pieces.push(ops::indexing::slice(text_embeds, &start, &stop, &strides)?);
}
let mut refs: Vec<&Array> = try_with_capacity(pieces.len())?;
refs.extend(pieces.iter());
ops::shape::concatenate(&refs, 1)
}