1use crate::generation::TokenIds;
19
/// Outcome of classifying a six-position box frame (see `is_valid_box_frame`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum BoxFrameKind {
    /// The frame reads as `box_start, none_token, box_end` followed by
    /// `null_token` padding, i.e. an explicit "no box" answer.
    Empty,
    /// Not an empty frame, but the last position carries enough probability
    /// on a terminating token for the coordinate rows to be decoded.
    Legal,
    /// Neither of the above; the frame is rejected.
    Illegal,
}
26
/// Returns the probability of `token` at position `pos` in a row-major
/// buffer whose rows are `vocab` entries wide.
///
/// Yields `0.0` whenever the requested cell does not exist:
///
/// * `token` is not a valid column (`token >= vocab`) — without this guard an
///   out-of-vocabulary id would silently read a cell that belongs to a later
///   position;
/// * the computed index overflows or lies past the end of `probs`.
fn prob_at(probs: &[f32], vocab: usize, pos: usize, token: u32) -> f32 {
    let column = token as usize;
    if column >= vocab {
        return 0.0;
    }
    pos.checked_mul(vocab)
        .and_then(|row_start| row_start.checked_add(column))
        .and_then(|ix| probs.get(ix))
        .copied()
        .unwrap_or(0.0)
}
33
34pub fn is_valid_box_frame(probs: &[f32], vocab: usize, ids: &TokenIds) -> BoxFrameKind {
35 let p_start = prob_at(probs, vocab, 0, ids.box_start);
36 if p_start >= 0.6 {
37 let none = prob_at(probs, vocab, 1, ids.none_token);
38 let end = prob_at(probs, vocab, 2, ids.box_end);
39 let null3 = prob_at(probs, vocab, 3, ids.null_token);
40 let null4 = prob_at(probs, vocab, 4, ids.null_token);
41 if none > 0.2 && end > 0.2 && null3 > 0.1 && null4 > 0.1 {
42 return BoxFrameKind::Empty;
43 }
44 }
45 let end_score = prob_at(probs, vocab, 5, ids.box_end)
46 + prob_at(probs, vocab, 5, ids.none_token)
47 + prob_at(probs, vocab, 5, ids.im_end);
48 if end_score >= 0.2 {
49 BoxFrameKind::Legal
50 } else {
51 BoxFrameKind::Illegal
52 }
53}
54
55pub fn decode_bbox_block(
57 logits: &[f32],
58 vocab: usize,
59 ids: &TokenIds,
60 generation_mode: &str,
61) -> Option<Vec<u32>> {
62 let block = 6usize;
63 if logits.len() < block * vocab {
64 return None;
65 }
66 let mut probs = vec![0f32; block * vocab];
67 for t in 0..block {
68 let row = &logits[t * vocab..(t + 1) * vocab];
69 let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
70 let sum: f32 = row.iter().map(|x| (x - max).exp()).sum();
71 for (i, &x) in row.iter().enumerate() {
72 probs[t * vocab + i] = (x - max).exp() / sum;
73 }
74 }
75 let frame = is_valid_box_frame(&probs[..vocab], vocab, ids);
76 match frame {
77 BoxFrameKind::Empty => Some(vec![
78 ids.box_start,
79 ids.none_token,
80 ids.box_end,
81 ids.null_token,
82 ids.null_token,
83 ids.null_token,
84 ]),
85 BoxFrameKind::Illegal => None,
86 BoxFrameKind::Legal => decode_bbox_coords(&probs, vocab, ids, generation_mode),
87 }
88}
89
90fn decode_bbox_coords(
91 probs: &[f32],
92 vocab: usize,
93 ids: &TokenIds,
94 generation_mode: &str,
95) -> Option<Vec<u32>> {
96 let keep_k = 5usize;
97 let mut coords = [0u32; 4];
98 for i in 0..4 {
99 let row = &probs[(1 + i) * vocab..(2 + i) * vocab];
100 let mut top: Vec<(f32, u32)> = row
101 .iter()
102 .enumerate()
103 .map(|(id, &p)| (p, id as u32))
104 .collect();
105 top.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
106 top.truncate(keep_k);
107 let valid: Vec<_> = top
108 .iter()
109 .filter(|(_, id)| *id >= ids.coord_start && *id <= ids.coord_end)
110 .collect();
111 if valid.is_empty() {
112 return None;
113 }
114 coords[i] = valid[0].1;
115 if generation_mode == "hybrid" && valid.len() > 1 && valid[0].0 < 0.9 {
116 let min_id = valid.iter().map(|(_, id)| *id).min().unwrap();
117 let max_id = valid.iter().map(|(_, id)| *id).max().unwrap();
118 if max_id - min_id > 60 {
119 coords[i] = 0;
120 }
121 }
122 }
123 Some(vec![
124 ids.box_start,
125 coords[0],
126 coords[1],
127 coords[2],
128 coords[3],
129 ids.box_end,
130 ])
131}
132
/// Result of matching a decoded token block against the known output shapes
/// (see `handle_pattern`).
#[derive(Clone, Debug)]
pub struct PatternOut {
    /// Label of the matched shape. `handle_pattern` emits `"im_end"`,
    /// `"empty_box"`, `"coord_box"`, `"point_box"`, `"error_box"` or
    /// `"ref_object"`.
    pub kind: &'static str,
    /// Tokens to keep for this block, after any trimming or normalisation.
    pub tokens: Vec<u32>,
    /// Set by `handle_pattern` only for `"error_box"`.
    // NOTE(review): presumably "needs autoregressive decoding" — confirm with the caller.
    pub need_ar: bool,
    /// Set by `handle_pattern` only for `"im_end"`.
    pub terminal: bool,
}
140
141pub fn handle_pattern(tokens: &[u32], ids: &TokenIds, generation_mode: &str) -> PatternOut {
142 if tokens.is_empty() {
143 return PatternOut {
144 kind: "im_end",
145 tokens: vec![ids.im_end],
146 need_ar: false,
147 terminal: true,
148 };
149 }
150 if tokens[0] == ids.null_token {
151 return PatternOut {
152 kind: "im_end",
153 tokens: vec![ids.im_end],
154 need_ar: false,
155 terminal: true,
156 };
157 }
158 if tokens[0] == ids.im_end {
159 return PatternOut {
160 kind: "im_end",
161 tokens: vec![ids.im_end],
162 need_ar: false,
163 terminal: true,
164 };
165 }
166 if tokens.len() >= 2 && tokens[0] == ids.box_start && tokens[1] == ids.none_token {
167 return PatternOut {
168 kind: "empty_box",
169 tokens: vec![ids.box_start, ids.none_token, ids.box_end],
170 need_ar: false,
171 terminal: false,
172 };
173 }
174 if tokens[0] == ids.box_start {
175 let mut coord_ix = 1usize;
176 for &c in &tokens[1..tokens.len().min(5)] {
177 if c >= ids.coord_start && c <= ids.coord_end {
178 coord_ix += 1;
179 } else {
180 break;
181 }
182 }
183 if coord_ix == 5 && tokens.get(5) == Some(&ids.box_end) {
184 return PatternOut {
185 kind: "coord_box",
186 tokens: tokens.to_vec(),
187 need_ar: false,
188 terminal: false,
189 };
190 }
191 if coord_ix == 3 && tokens.get(3) == Some(&ids.box_end) {
192 return PatternOut {
193 kind: "point_box",
194 tokens: tokens[..4].to_vec(),
195 need_ar: false,
196 terminal: false,
197 };
198 }
199 if generation_mode == "fast" {
200 return PatternOut {
201 kind: "coord_box",
202 tokens: tokens.to_vec(),
203 need_ar: false,
204 terminal: false,
205 };
206 }
207 return PatternOut {
208 kind: "error_box",
209 tokens: tokens[..coord_ix].to_vec(),
210 need_ar: true,
211 terminal: false,
212 };
213 }
214 let mut out: Vec<u32> = tokens.to_vec();
215 if let Some(pos) = out.iter().position(|&t| t == ids.null_token) {
216 out.truncate(pos);
217 }
218 if out.len() >= 2 && out[out.len() - 1] == out[out.len() - 2] {
219 out.pop();
220 }
221 PatternOut {
222 kind: "ref_object",
223 tokens: out,
224 need_ar: false,
225 terminal: false,
226 }
227}
228
/// Returns, for each of the first `n_pos` rows of a row-major logits buffer
/// with `vocab` entries per row, the index of the largest entry.
///
/// When several entries compare as equally large — including comparisons
/// involving NaN, which are treated as ties — the later one wins.
///
/// # Panics
///
/// Panics if `logits` holds fewer than `n_pos * vocab` values, or if `vocab`
/// is zero while `n_pos` is not.
pub fn argmax_rows(logits: &[f32], vocab: usize, n_pos: usize) -> Vec<u32> {
    (0..n_pos)
        .map(|pos| {
            let row = &logits[pos * vocab..(pos + 1) * vocab];
            let (winner, _) = row
                .iter()
                .copied()
                .enumerate()
                // Keep the running best only while it is strictly greater.
                .reduce(|best, next| if best.1 > next.1 { best } else { next })
                .unwrap();
            winner as u32
        })
        .collect()
}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generation::TokenIds;

    /// Token table used by the tests in this module.
    fn sample_ids() -> TokenIds {
        TokenIds {
            box_start: 10,
            box_end: 11,
            coord_start: 100,
            coord_end: 200,
            ref_start: 0,
            ref_end: 0,
            none_token: 12,
            null_token: 13,
            switch_token: 14,
            text_mask: 15,
            im_end: 99,
        }
    }

    /// A block of `box_start, none_token, box_end` plus `null_token` padding
    /// is reported as an empty box and reduced to its three leading tokens.
    #[test]
    fn handle_pattern_empty_box() {
        let ids = sample_ids();
        let block: [u32; 6] = [10, 12, 11, 13, 13, 13];

        let pat = handle_pattern(&block, &ids, "hybrid");

        assert_eq!(pat.kind, "empty_box");
        assert_eq!(pat.tokens.len(), 3);
    }
}