Skip to main content

rlx_locateanything/
mtp.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Parallel box decoding — MTP sampling (ported from HF `generate_utils`).
17
18use crate::generation::TokenIds;
19
/// Classification of a 6-position MTP probability block, as produced by
/// `is_valid_box_frame`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BoxFrameKind {
    /// The block confidently encodes an empty box (`box_start, none, box_end, null, null`).
    Empty,
    /// The block is closed plausibly at position 5 and should be decoded as coordinates.
    Legal,
    /// The block does not look like a well-formed box frame.
    Illegal,
}
26
/// Probability of `token` at block position `pos` in a row-major `[n_pos * vocab]` slab.
///
/// Returns `0.0` when `token` is not a valid id for this vocabulary (`token >= vocab`)
/// or when row `pos` lies past the end of `probs`, so a missing row or an
/// out-of-vocabulary special token contributes nothing to the frame scores.
fn prob_at(probs: &[f32], vocab: usize, pos: usize, token: u32) -> f32 {
    let token = token as usize;
    if token >= vocab {
        // Without this check `pos * vocab + token` would silently read from row `pos + 1`.
        return 0.0;
    }
    pos.checked_mul(vocab)
        .and_then(|base| base.checked_add(token))
        .and_then(|ix| probs.get(ix))
        .copied()
        .unwrap_or(0.0)
}
33
34pub fn is_valid_box_frame(probs: &[f32], vocab: usize, ids: &TokenIds) -> BoxFrameKind {
35    let p_start = prob_at(probs, vocab, 0, ids.box_start);
36    if p_start >= 0.6 {
37        let none = prob_at(probs, vocab, 1, ids.none_token);
38        let end = prob_at(probs, vocab, 2, ids.box_end);
39        let null3 = prob_at(probs, vocab, 3, ids.null_token);
40        let null4 = prob_at(probs, vocab, 4, ids.null_token);
41        if none > 0.2 && end > 0.2 && null3 > 0.1 && null4 > 0.1 {
42            return BoxFrameKind::Empty;
43        }
44    }
45    let end_score = prob_at(probs, vocab, 5, ids.box_end)
46        + prob_at(probs, vocab, 5, ids.none_token)
47        + prob_at(probs, vocab, 5, ids.im_end);
48    if end_score >= 0.2 {
49        BoxFrameKind::Legal
50    } else {
51        BoxFrameKind::Illegal
52    }
53}
54
55/// Decode 6-position MTP block from per-position logits rows `logits: [6 * vocab]`.
56pub fn decode_bbox_block(
57    logits: &[f32],
58    vocab: usize,
59    ids: &TokenIds,
60    generation_mode: &str,
61) -> Option<Vec<u32>> {
62    let block = 6usize;
63    if logits.len() < block * vocab {
64        return None;
65    }
66    let mut probs = vec![0f32; block * vocab];
67    for t in 0..block {
68        let row = &logits[t * vocab..(t + 1) * vocab];
69        let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
70        let sum: f32 = row.iter().map(|x| (x - max).exp()).sum();
71        for (i, &x) in row.iter().enumerate() {
72            probs[t * vocab + i] = (x - max).exp() / sum;
73        }
74    }
75    let frame = is_valid_box_frame(&probs[..vocab], vocab, ids);
76    match frame {
77        BoxFrameKind::Empty => Some(vec![
78            ids.box_start,
79            ids.none_token,
80            ids.box_end,
81            ids.null_token,
82            ids.null_token,
83            ids.null_token,
84        ]),
85        BoxFrameKind::Illegal => None,
86        BoxFrameKind::Legal => decode_bbox_coords(&probs, vocab, ids, generation_mode),
87    }
88}
89
90fn decode_bbox_coords(
91    probs: &[f32],
92    vocab: usize,
93    ids: &TokenIds,
94    generation_mode: &str,
95) -> Option<Vec<u32>> {
96    let keep_k = 5usize;
97    let mut coords = [0u32; 4];
98    for i in 0..4 {
99        let row = &probs[(1 + i) * vocab..(2 + i) * vocab];
100        let mut top: Vec<(f32, u32)> = row
101            .iter()
102            .enumerate()
103            .map(|(id, &p)| (p, id as u32))
104            .collect();
105        top.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
106        top.truncate(keep_k);
107        let valid: Vec<_> = top
108            .iter()
109            .filter(|(_, id)| *id >= ids.coord_start && *id <= ids.coord_end)
110            .collect();
111        if valid.is_empty() {
112            return None;
113        }
114        coords[i] = valid[0].1;
115        if generation_mode == "hybrid" && valid.len() > 1 && valid[0].0 < 0.9 {
116            let min_id = valid.iter().map(|(_, id)| *id).min().unwrap();
117            let max_id = valid.iter().map(|(_, id)| *id).max().unwrap();
118            if max_id - min_id > 60 {
119                coords[i] = 0;
120            }
121        }
122    }
123    Some(vec![
124        ids.box_start,
125        coords[0],
126        coords[1],
127        coords[2],
128        coords[3],
129        ids.box_end,
130    ])
131}
132
/// Result of classifying one MTP token block with `handle_pattern`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PatternOut {
    /// Pattern label. `handle_pattern` produces `"im_end"`, `"empty_box"`, `"coord_box"`,
    /// `"point_box"`, `"error_box"` or `"ref_object"`.
    pub kind: &'static str,
    /// Tokens to emit for this block.
    pub tokens: Vec<u32>,
    /// Set by `handle_pattern` only for `"error_box"` (a malformed box prefix).
    /// NOTE(review): presumably requests autoregressive continuation; confirm.
    pub need_ar: bool,
    /// Set by `handle_pattern` only for `"im_end"`; callers presumably stop generating.
    pub terminal: bool,
}
140
141pub fn handle_pattern(tokens: &[u32], ids: &TokenIds, generation_mode: &str) -> PatternOut {
142    if tokens.is_empty() {
143        return PatternOut {
144            kind: "im_end",
145            tokens: vec![ids.im_end],
146            need_ar: false,
147            terminal: true,
148        };
149    }
150    if tokens[0] == ids.null_token {
151        return PatternOut {
152            kind: "im_end",
153            tokens: vec![ids.im_end],
154            need_ar: false,
155            terminal: true,
156        };
157    }
158    if tokens[0] == ids.im_end {
159        return PatternOut {
160            kind: "im_end",
161            tokens: vec![ids.im_end],
162            need_ar: false,
163            terminal: true,
164        };
165    }
166    if tokens.len() >= 2 && tokens[0] == ids.box_start && tokens[1] == ids.none_token {
167        return PatternOut {
168            kind: "empty_box",
169            tokens: vec![ids.box_start, ids.none_token, ids.box_end],
170            need_ar: false,
171            terminal: false,
172        };
173    }
174    if tokens[0] == ids.box_start {
175        let mut coord_ix = 1usize;
176        for &c in &tokens[1..tokens.len().min(5)] {
177            if c >= ids.coord_start && c <= ids.coord_end {
178                coord_ix += 1;
179            } else {
180                break;
181            }
182        }
183        if coord_ix == 5 && tokens.get(5) == Some(&ids.box_end) {
184            return PatternOut {
185                kind: "coord_box",
186                tokens: tokens.to_vec(),
187                need_ar: false,
188                terminal: false,
189            };
190        }
191        if coord_ix == 3 && tokens.get(3) == Some(&ids.box_end) {
192            return PatternOut {
193                kind: "point_box",
194                tokens: tokens[..4].to_vec(),
195                need_ar: false,
196                terminal: false,
197            };
198        }
199        if generation_mode == "fast" {
200            return PatternOut {
201                kind: "coord_box",
202                tokens: tokens.to_vec(),
203                need_ar: false,
204                terminal: false,
205            };
206        }
207        return PatternOut {
208            kind: "error_box",
209            tokens: tokens[..coord_ix].to_vec(),
210            need_ar: true,
211            terminal: false,
212        };
213    }
214    let mut out: Vec<u32> = tokens.to_vec();
215    if let Some(pos) = out.iter().position(|&t| t == ids.null_token) {
216        out.truncate(pos);
217    }
218    if out.len() >= 2 && out[out.len() - 1] == out[out.len() - 2] {
219        out.pop();
220    }
221    PatternOut {
222        kind: "ref_object",
223        tokens: out,
224        need_ar: false,
225        terminal: false,
226    }
227}
228
/// Greedy argmax per row in a logits slab `[n_pos * vocab]`.
///
/// Returns one token id per position. On ties, the FIRST maximal index wins, matching
/// `torch.argmax` in the HF reference; `Iterator::max_by` would pick the last one. NaN
/// entries never win unless they come first in their row. Values past `n_pos * vocab`
/// are ignored.
///
/// # Panics
///
/// Panics if `logits.len() < n_pos * vocab`, or if `vocab == 0` while `n_pos > 0`.
pub fn argmax_rows(logits: &[f32], vocab: usize, n_pos: usize) -> Vec<u32> {
    let mut out = Vec::with_capacity(n_pos);
    for t in 0..n_pos {
        let row = &logits[t * vocab..(t + 1) * vocab];
        assert!(!row.is_empty(), "argmax_rows: vocab must be non-zero");
        let mut best = 0usize;
        for (i, &x) in row.iter().enumerate().skip(1) {
            // Strict `>` keeps the earliest maximal index.
            if x > row[best] {
                best = i;
            }
        }
        out.push(best as u32);
    }
    out
}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generation::TokenIds;

    /// Small synthetic vocabulary: coordinate tokens are ids 100..=200.
    fn test_ids() -> TokenIds {
        TokenIds {
            box_start: 10,
            box_end: 11,
            coord_start: 100,
            coord_end: 200,
            ref_start: 0,
            ref_end: 0,
            none_token: 12,
            null_token: 13,
            switch_token: 14,
            text_mask: 15,
            im_end: 99,
        }
    }

    #[test]
    fn handle_pattern_empty_box() {
        let ids = test_ids();
        let pat = handle_pattern(&[10, 12, 11, 13, 13, 13], &ids, "hybrid");
        assert_eq!(pat.kind, "empty_box");
        assert_eq!(pat.tokens.len(), 3);
    }

    #[test]
    fn handle_pattern_terminates_on_empty_null_or_im_end() {
        let ids = test_ids();
        let cases: [&[u32]; 3] = [&[], &[13, 1, 2], &[99]];
        for tokens in cases {
            let pat = handle_pattern(tokens, &ids, "hybrid");
            assert_eq!(pat.kind, "im_end");
            assert_eq!(pat.tokens, vec![99]);
            assert!(pat.terminal);
            assert!(!pat.need_ar);
        }
    }

    #[test]
    fn handle_pattern_coord_and_point_boxes() {
        let ids = test_ids();
        let coord = handle_pattern(&[10, 150, 160, 170, 180, 11], &ids, "hybrid");
        assert_eq!(coord.kind, "coord_box");
        assert_eq!(coord.tokens, vec![10, 150, 160, 170, 180, 11]);
        assert!(!coord.need_ar);

        let point = handle_pattern(&[10, 150, 160, 11, 13, 13], &ids, "hybrid");
        assert_eq!(point.kind, "point_box");
        assert_eq!(point.tokens, vec![10, 150, 160, 11]);
    }

    #[test]
    fn handle_pattern_malformed_box_depends_on_mode() {
        let ids = test_ids();
        let hybrid = handle_pattern(&[10, 150, 5, 6, 7, 8], &ids, "hybrid");
        assert_eq!(hybrid.kind, "error_box");
        assert_eq!(hybrid.tokens, vec![10, 150]);
        assert!(hybrid.need_ar);

        let fast = handle_pattern(&[10, 150, 5, 6, 7, 8], &ids, "fast");
        assert_eq!(fast.kind, "coord_box");
        assert_eq!(fast.tokens, vec![10, 150, 5, 6, 7, 8]);
        assert!(!fast.need_ar);
    }

    #[test]
    fn handle_pattern_ref_object_truncates_and_dedups() {
        let ids = test_ids();
        let pat = handle_pattern(&[1, 2, 2, 13, 4], &ids, "hybrid");
        assert_eq!(pat.kind, "ref_object");
        assert_eq!(pat.tokens, vec![1, 2]);
        assert!(!pat.terminal);
    }

    #[test]
    fn argmax_rows_without_ties() {
        assert_eq!(argmax_rows(&[0.1, 0.9, 0.0, 2.0, -1.0, 1.0], 3, 2), vec![1, 0]);
    }
}