use crate::config::GemmaConfig;
use crate::multimodal::GemmaMultimodalConfig;
/// Bias value written into the attention mask for query/key pairs that may not attend.
///
/// It is a large but finite negative number, not `-inf`. Presumably it is added to the
/// attention logits before softmax by the consumer of the mask (TODO confirm).
const BLOCK: f32 = -10_000.0;
/// Half-open token-index range `[start, end)` covering the media payload tokens that sit
/// between an opening and a closing delimiter, as produced by `media_content_spans`.
/// The delimiter tokens themselves are not included.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct MediaContentSpan {
    /// Index of the first payload token (one past the opening delimiter).
    pub start: usize,
    /// Exclusive end: the closing delimiter's index, or the sequence length if none was found.
    pub end: usize,
}
/// Locates the media payload spans (image and audio) in a tokenized prompt.
///
/// Scans `token_ids` left to right. When a token equals an opening delimiter
/// (`boi_token_id` for images, `boa_token_id` for audio), the span starts at the next
/// token. It ends (exclusively) at the first subsequent matching closing delimiter
/// (`eoi_token_id` / `eoa_token_index`).
///
/// Behavior details:
/// - The delimiters themselves are never part of a span.
/// - Empty spans (opener immediately followed by closer) are skipped.
/// - Scanning resumes after the closing delimiter, so spans never overlap.
/// - If an opener is configured identically for image and audio, the image pair wins
///   because it is checked first.
/// - If no closing delimiter follows (which is always the case when the closing id is
///   `None`), the span runs to the end of `token_ids`.
///
/// NOTE(review): an unterminated span therefore also covers any trailing text tokens,
/// which then get bidirectional attention in the mask builder. Confirm this is intended
/// (e.g. for truncated prompts).
pub fn media_content_spans(
    token_ids: &[u32],
    mm_cfg: &GemmaMultimodalConfig,
) -> Vec<MediaContentSpan> {
    // (opening id, closing id) pairs, checked in order: image first, then audio.
    // NOTE(review): the audio closer is read from `eoa_token_index` while every other
    // delimiter field is named `*_token_id`. Confirm this is the intended field.
    let delimiters = [
        (mm_cfg.boi_token_id, mm_cfg.eoi_token_id),
        (mm_cfg.boa_token_id, mm_cfg.eoa_token_index),
    ];
    let mut spans = Vec::new();
    let mut i = 0usize;
    while i < token_ids.len() {
        let tok = Some(token_ids[i]);
        let close = match delimiters.iter().find(|&&(open, _)| open == tok) {
            Some(&(_, close)) => close,
            None => {
                i += 1;
                continue;
            }
        };
        let start = i + 1;
        // `start <= len` because `i < len`, so this subslice is always valid.
        let end = token_ids[start..]
            .iter()
            .position(|&t| Some(t) == close)
            .map_or(token_ids.len(), |offset| start + offset);
        if end > start {
            spans.push(MediaContentSpan { start, end });
        }
        // Resume after the closing delimiter (or past the end if none was found).
        i = end + 1;
    }
    spans
}
/// Returns `true` when positions `i` and `j` both fall inside one and the same span
/// (each span treated as the half-open range `start..end`).
fn in_span(spans: &[MediaContentSpan], i: usize, j: usize) -> bool {
    for span in spans {
        let range = span.start..span.end;
        if range.contains(&i) && range.contains(&j) {
            return true;
        }
    }
    false
}
/// Builds a `seq × seq` additive attention mask for prefill, stored row-major with the
/// query index as the row (entry `q * seq + k`), where `seq = token_ids.len()`.
///
/// Every query may attend to itself and to all earlier keys (causal). In addition, a
/// query inside a media span (see `media_content_spans`) may attend to later keys that
/// lie in the *same* span. Attention within each image/audio payload is therefore
/// bidirectional, while delimiters and text stay causal.
///
/// Allowed entries are `0.0`; disallowed entries are `BLOCK`.
pub fn build_vision_bidirectional_mask_2d(
    token_ids: &[u32],
    mm_cfg: &GemmaMultimodalConfig,
) -> Vec<f32> {
    let seq = token_ids.len();
    let spans = media_content_spans(token_ids, mm_cfg);
    // Start fully allowed; only future keys (k > q) can ever be blocked, so the lower
    // triangle and diagonal never need a span lookup.
    let mut out = vec![0f32; seq * seq];
    for q in 0..seq {
        let row = &mut out[q * seq..(q + 1) * seq];
        for k in q + 1..seq {
            if !in_span(&spans, q, k) {
                row[k] = BLOCK;
            }
        }
    }
    out
}
/// Tiles a `seq × seq` mask into a contiguous `[batch, num_heads, seq, seq]` buffer.
///
/// The `(b, h)` slice at offset `(b * num_heads + h) * seq * seq` is an exact copy of
/// `mask_2d`. Returns an empty vector when `batch` or `num_heads` is zero.
///
/// # Panics
///
/// - If `mask_2d.len() != seq * seq`.
/// - If `batch * num_heads`, or the total output size, overflows `usize`.
pub fn expand_attn_bias(mask_2d: &[f32], batch: usize, num_heads: usize, seq: usize) -> Vec<f32> {
    assert_eq!(mask_2d.len(), seq * seq);
    // Fail loudly instead of silently wrapping to a short buffer in release builds.
    let copies = batch
        .checked_mul(num_heads)
        .expect("batch * num_heads overflows usize");
    // The (b, h) copies are laid out back to back in row-major order, so tiling the
    // mask `batch * num_heads` times yields exactly the [batch, heads, seq, seq] layout.
    mask_2d.repeat(copies)
}
/// Builds the full per-batch, per-head prefill attention bias for a multimodal prompt.
///
/// Returns `None` when `lm_cfg.use_bidirectional_vision()` is false. Presumably the
/// caller then falls back to its default causal masking (TODO confirm).
///
/// Otherwise, the 2-D mask from `build_vision_bidirectional_mask_2d` is replicated
/// across `batch` and `lm_cfg.num_attention_heads` via `expand_attn_bias`. Every batch
/// element receives the mask computed from this single `token_ids` sequence.
pub fn build_multimodal_prefill_attn_bias(
    token_ids: &[u32],
    lm_cfg: &GemmaConfig,
    mm_cfg: &GemmaMultimodalConfig,
    batch: usize,
) -> Option<Vec<f32>> {
    lm_cfg.use_bidirectional_vision().then(|| {
        let mask_2d = build_vision_bidirectional_mask_2d(token_ids, mm_cfg);
        expand_attn_bias(&mask_2d, batch, lm_cfg.num_attention_heads, token_ids.len())
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Image-only config: BOI = 1, EOI = 2, image placeholder = 10.
    fn image_cfg() -> GemmaMultimodalConfig {
        GemmaMultimodalConfig {
            image_token_id: Some(10),
            boi_token_id: Some(1),
            eoi_token_id: Some(2),
            ..Default::default()
        }
    }

    #[test]
    fn media_span_bidirectional_not_causal_across_future() {
        let mm = image_cfg();
        let ids = [99, 1, 10, 10, 2, 88];
        let m = build_vision_bidirectional_mask_2d(&ids, &mm);
        let seq = ids.len();
        // Image token 2 may look ahead to image token 3 (same span).
        assert_eq!(m[2 * seq + 3], 0.0);
        // Trailing text attends causally back to the image.
        assert_eq!(m[5 * seq + 3], 0.0);
        // Leading text must not see the future.
        assert_eq!(m[5], BLOCK);
    }

    #[test]
    fn delimiters_stay_outside_the_bidirectional_span() {
        let mm = image_cfg();
        let ids = [99, 1, 10, 10, 2, 88];
        assert_eq!(
            media_content_spans(&ids, &mm),
            vec![MediaContentSpan { start: 2, end: 4 }]
        );
        let m = build_vision_bidirectional_mask_2d(&ids, &mm);
        let seq = ids.len();
        // BOI (index 1) stays causal: it cannot look ahead into the image.
        assert_eq!(m[seq + 2], BLOCK);
        // The last image token cannot look ahead to EOI or to trailing text.
        assert_eq!(m[3 * seq + 4], BLOCK);
        assert_eq!(m[3 * seq + 5], BLOCK);
    }

    #[test]
    fn empty_media_block_yields_no_span() {
        let mm = image_cfg();
        assert!(media_content_spans(&[1, 2, 5], &mm).is_empty());
    }

    #[test]
    fn audio_spans_are_detected() {
        let mm = GemmaMultimodalConfig {
            boa_token_id: Some(3),
            eoa_token_index: Some(4),
            ..Default::default()
        };
        let ids = [7, 3, 20, 20, 4];
        assert_eq!(
            media_content_spans(&ids, &mm),
            vec![MediaContentSpan { start: 2, end: 4 }]
        );
    }

    #[test]
    fn expand_repeats_mask_per_batch_and_head() {
        let mask = [0.0, BLOCK, 0.0, 0.0];
        let out = expand_attn_bias(&mask, 2, 3, 2);
        assert_eq!(out.len(), 2 * 3 * 4);
        assert!(out.chunks_exact(4).all(|c| c == &mask[..]));
    }
}