rlx_locateanything/
generation.rs1use crate::config::LocateAnythingConfig;
19use crate::embed::argmax_token;
20
/// Decoding strategy selector.
///
/// `Hybrid` is the default (and is what `SampleOpts::default` uses). See
/// `hybrid_continue_mtp` for how each variant is interpreted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum GenerationMode {
    Fast,
    Slow,
    #[default]
    Hybrid,
}

impl GenerationMode {
    /// Parses a mode name case-insensitively: `"fast"`, `"slow"` or `"hybrid"`.
    ///
    /// Returns `None` for any other string. Surrounding whitespace is *not* trimmed, so
    /// `" fast"` is rejected.
    pub fn parse(s: &str) -> Option<Self> {
        let table = [
            ("fast", Self::Fast),
            ("slow", Self::Slow),
            ("hybrid", Self::Hybrid),
        ];
        table
            .iter()
            .find(|(name, _)| s.eq_ignore_ascii_case(name))
            .map(|&(_, mode)| mode)
    }
}
43
44#[derive(Debug, Clone)]
46pub struct TokenIds {
47 pub box_start: u32,
48 pub box_end: u32,
49 pub coord_start: u32,
50 pub coord_end: u32,
51 pub ref_start: u32,
52 pub ref_end: u32,
53 pub none_token: u32,
54 pub null_token: u32,
55 pub switch_token: u32,
56 pub text_mask: u32,
57 pub im_end: u32,
58}
59
60impl TokenIds {
61 pub fn from_config(cfg: &LocateAnythingConfig) -> Self {
62 Self {
63 box_start: cfg.box_start_token_id,
64 box_end: cfg.box_end_token_id,
65 coord_start: cfg.coord_start_token_id,
66 coord_end: cfg.coord_end_token_id,
67 ref_start: cfg.ref_start_token_id,
68 ref_end: cfg.ref_end_token_id,
69 none_token: cfg.none_token_id,
70 null_token: cfg.text_config.null_token_id.unwrap_or(152_678),
71 switch_token: cfg.text_config.switch_token_id.unwrap_or(152_679),
72 text_mask: cfg.text_config.text_mask_token_id.unwrap_or(151_676),
73 im_end: cfg.text_config.eos_token_id,
74 }
75 }
76}
77
78#[derive(Debug, Clone)]
79pub struct SampleOpts {
80 pub temperature: f32,
81 pub top_p: f32,
82 pub repetition_penalty: f32,
83 pub max_new_tokens: usize,
84 pub mode: GenerationMode,
85}
86
87impl Default for SampleOpts {
88 fn default() -> Self {
89 Self {
90 temperature: 0.7,
91 top_p: 0.9,
92 repetition_penalty: 1.1,
93 max_new_tokens: 2048,
94 mode: GenerationMode::Hybrid,
95 }
96 }
97}
98
99pub fn sample_token(logits: &[f32], opts: &SampleOpts, history: &[u32]) -> u32 {
101 debug_assert!(!logits.is_empty());
102 let mut scores: Vec<f32> = logits.to_vec();
103 if opts.repetition_penalty != 1.0 {
104 for &tok in history {
105 let i = tok as usize;
106 if i < scores.len() {
107 if scores[i] > 0.0 {
108 scores[i] /= opts.repetition_penalty;
109 } else {
110 scores[i] *= opts.repetition_penalty;
111 }
112 }
113 }
114 }
115 if opts.temperature > 0.0 {
116 for s in &mut scores {
117 *s /= opts.temperature;
118 }
119 sample_stochastic(&scores, opts.top_p)
120 } else {
121 argmax_token(&scores)
122 }
123}
124
125fn sample_stochastic(logits: &[f32], top_p: f32) -> u32 {
126 let mut idx: Vec<usize> = (0..logits.len()).collect();
127 idx.sort_by(|&a, &b| {
128 logits[b]
129 .partial_cmp(&logits[a])
130 .unwrap_or(std::cmp::Ordering::Equal)
131 });
132 let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
133 let mut probs = vec![0f32; logits.len()];
134 let mut sum = 0f32;
135 for &i in &idx {
136 let p = (logits[i] - max).exp();
137 probs[i] = p;
138 sum += p;
139 }
140 if sum > 0.0 {
141 for p in &mut probs {
142 *p /= sum;
143 }
144 }
145 if top_p < 1.0 {
146 let mut cum = 0f32;
147 for &i in &idx {
148 cum += probs[i];
149 if cum > top_p {
150 for j in idx.iter().position(|&x| x == i).unwrap() + 1..idx.len() {
151 probs[idx[j]] = 0.0;
152 }
153 break;
154 }
155 }
156 }
157 let r: f32 = rand_uniform();
158 let mut c = 0f32;
159 for (i, &p) in probs.iter().enumerate() {
160 c += p;
161 if r <= c {
162 return i as u32;
163 }
164 }
165 argmax_token(logits)
166}
167
/// Returns a pseudo-random `f32` uniformly distributed in `[0, 1)`, for token sampling.
///
/// Not intended for cryptographic use. Each call hashes a process-wide call counter and the
/// current time with a hasher from a fresh `RandomState`. Distinct `RandomState` instances
/// are keyed differently, so back-to-back calls within the same clock tick still give
/// unrelated values. The previous time-only, mod-10_000 scheme repeated values and had
/// only 10⁴ levels.
fn rand_uniform() -> f32 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::SystemTime;

    static CALLS: AtomicU64 = AtomicU64::new(0);

    let mut h = RandomState::new().build_hasher();
    CALLS.fetch_add(1, Ordering::Relaxed).hash(&mut h);
    SystemTime::now().hash(&mut h);
    // Keep the top 24 bits: exactly an f32 mantissa's precision, so the quotient is
    // representable and strictly below 1.0.
    let bits = h.finish() >> 40;
    bits as f32 / (1u64 << 24) as f32
}
175
176pub fn hybrid_continue_mtp(out_type: &str, mode: GenerationMode) -> bool {
178 match mode {
179 GenerationMode::Fast => true,
180 GenerationMode::Slow => false,
181 GenerationMode::Hybrid => !matches!(out_type, "error_box"),
182 }
183}
184
185pub fn classify_ar_token(tok: u32, ids: &TokenIds) -> &'static str {
187 if tok == ids.im_end {
188 "im_end"
189 } else if tok == ids.box_end {
190 "box_end_ar"
191 } else if (ids.coord_start..=ids.coord_end).contains(&tok) || tok == ids.none_token {
192 "coord_ar"
193 } else {
194 "continue_ar"
195 }
196}