Skip to main content

rlx_locateanything/
generation.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Text generation — slow (AR), fast (MTP), and hybrid modes.
17
18use crate::config::LocateAnythingConfig;
19use crate::embed::argmax_token;
20
/// Decoding strategy used at inference time (mirrors HF `generation_mode`).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum GenerationMode {
    /// MTP only: box blocks are produced in parallel.
    Fast,
    /// Plain autoregressive decoding, one token per step.
    Slow,
    /// MTP, falling back to AR on uncertain boxes. This is the `Default`.
    #[default]
    Hybrid,
}

impl GenerationMode {
    /// Parses a mode name case-insensitively: `"fast"`, `"slow"` or `"hybrid"`.
    ///
    /// Any other string (including one with surrounding whitespace) yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        let lowered = s.to_lowercase();
        let mode = match lowered.as_str() {
            "hybrid" => Self::Hybrid,
            "slow" => Self::Slow,
            "fast" => Self::Fast,
            _ => return None,
        };
        Some(mode)
    }
}
43
44/// Special token ids used during MTP / box decoding.
45#[derive(Debug, Clone)]
46pub struct TokenIds {
47    pub box_start: u32,
48    pub box_end: u32,
49    pub coord_start: u32,
50    pub coord_end: u32,
51    pub ref_start: u32,
52    pub ref_end: u32,
53    pub none_token: u32,
54    pub null_token: u32,
55    pub switch_token: u32,
56    pub text_mask: u32,
57    pub im_end: u32,
58}
59
60impl TokenIds {
61    pub fn from_config(cfg: &LocateAnythingConfig) -> Self {
62        Self {
63            box_start: cfg.box_start_token_id,
64            box_end: cfg.box_end_token_id,
65            coord_start: cfg.coord_start_token_id,
66            coord_end: cfg.coord_end_token_id,
67            ref_start: cfg.ref_start_token_id,
68            ref_end: cfg.ref_end_token_id,
69            none_token: cfg.none_token_id,
70            null_token: cfg.text_config.null_token_id.unwrap_or(152_678),
71            switch_token: cfg.text_config.switch_token_id.unwrap_or(152_679),
72            text_mask: cfg.text_config.text_mask_token_id.unwrap_or(151_676),
73            im_end: cfg.text_config.eos_token_id,
74        }
75    }
76}
77
78#[derive(Debug, Clone)]
79pub struct SampleOpts {
80    pub temperature: f32,
81    pub top_p: f32,
82    pub repetition_penalty: f32,
83    pub max_new_tokens: usize,
84    pub mode: GenerationMode,
85}
86
87impl Default for SampleOpts {
88    fn default() -> Self {
89        Self {
90            temperature: 0.7,
91            top_p: 0.9,
92            repetition_penalty: 1.1,
93            max_new_tokens: 2048,
94            mode: GenerationMode::Hybrid,
95        }
96    }
97}
98
99/// Greedy or temperature-scaled sample from a single logits row `[vocab]`.
100pub fn sample_token(logits: &[f32], opts: &SampleOpts, history: &[u32]) -> u32 {
101    debug_assert!(!logits.is_empty());
102    let mut scores: Vec<f32> = logits.to_vec();
103    if opts.repetition_penalty != 1.0 {
104        for &tok in history {
105            let i = tok as usize;
106            if i < scores.len() {
107                if scores[i] > 0.0 {
108                    scores[i] /= opts.repetition_penalty;
109                } else {
110                    scores[i] *= opts.repetition_penalty;
111                }
112            }
113        }
114    }
115    if opts.temperature > 0.0 {
116        for s in &mut scores {
117            *s /= opts.temperature;
118        }
119        sample_stochastic(&scores, opts.top_p)
120    } else {
121        argmax_token(&scores)
122    }
123}
124
125fn sample_stochastic(logits: &[f32], top_p: f32) -> u32 {
126    let mut idx: Vec<usize> = (0..logits.len()).collect();
127    idx.sort_by(|&a, &b| {
128        logits[b]
129            .partial_cmp(&logits[a])
130            .unwrap_or(std::cmp::Ordering::Equal)
131    });
132    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
133    let mut probs = vec![0f32; logits.len()];
134    let mut sum = 0f32;
135    for &i in &idx {
136        let p = (logits[i] - max).exp();
137        probs[i] = p;
138        sum += p;
139    }
140    if sum > 0.0 {
141        for p in &mut probs {
142            *p /= sum;
143        }
144    }
145    if top_p < 1.0 {
146        let mut cum = 0f32;
147        for &i in &idx {
148            cum += probs[i];
149            if cum > top_p {
150                for j in idx.iter().position(|&x| x == i).unwrap() + 1..idx.len() {
151                    probs[idx[j]] = 0.0;
152                }
153                break;
154            }
155        }
156    }
157    let r: f32 = rand_uniform();
158    let mut c = 0f32;
159    for (i, &p) in probs.iter().enumerate() {
160        c += p;
161        if r <= c {
162            return i as u32;
163        }
164    }
165    argmax_token(logits)
166}
167
/// Returns a pseudo-random `f32` uniformly distributed in `[0, 1)` with 24-bit resolution.
///
/// This is neither cryptographically secure nor reproducible. Each call hashes the current
/// time with a freshly constructed `RandomState`. The std docs state that hashers built from
/// different `RandomState` instances are unlikely to produce the same result for the same
/// value, so back-to-back calls differ even when the clock has not advanced. A fixed-key
/// hasher would repeat its output for equal timestamps.
fn rand_uniform() -> f32 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::time::SystemTime;

    let mut hasher = RandomState::new().build_hasher();
    SystemTime::now().hash(&mut hasher);
    // Keep the top 24 bits. Every k / 2^24 is exactly representable in f32, so the result is
    // strictly below 1.0 and, unlike a `% 10_000` reduction, has no modulo bias.
    (hasher.finish() >> 40) as f32 / (1u64 << 24) as f32
}
175
176/// After sampling an MTP block, decide whether to continue in MTP or switch to AR (hybrid).
177pub fn hybrid_continue_mtp(out_type: &str, mode: GenerationMode) -> bool {
178    match mode {
179        GenerationMode::Fast => true,
180        GenerationMode::Slow => false,
181        GenerationMode::Hybrid => !matches!(out_type, "error_box"),
182    }
183}
184
185/// Classify a sampled AR token for hybrid mode switching.
186pub fn classify_ar_token(tok: u32, ids: &TokenIds) -> &'static str {
187    if tok == ids.im_end {
188        "im_end"
189    } else if tok == ids.box_end {
190        "box_end_ar"
191    } else if (ids.coord_start..=ids.coord_end).contains(&tok) || tok == ids.none_token {
192        "coord_ar"
193    } else {
194        "continue_ar"
195    }
196}