1use crate::tensor::DenseTensor;
4use super::model::LlamaModel;
5
/// Hyper-parameters controlling autoregressive text generation.
///
/// Field names mirror the common HuggingFace `GenerationConfig` knobs.
/// Not every field is consumed by every decoding strategy yet — see the
/// per-field notes.
#[derive(Debug, Clone)]
pub struct GenerationConfig {
    // Maximum number of decoding steps (new tokens appended per call).
    pub max_length: usize,
    // Minimum generated length. NOTE(review): declared but not yet
    // enforced anywhere in `TextGenerator` — confirm before relying on it.
    pub min_length: usize,
    // Softmax temperature; 1.0 leaves logits unscaled.
    pub temperature: f64,
    // Keep only the `top_k` most probable tokens when sampling (0 = off).
    pub top_k: usize,
    // Nucleus (top-p) sampling threshold (0.0 = off).
    pub top_p: f64,
    // Penalty for repeated tokens. NOTE(review): declared but not yet
    // applied by any decoding path.
    pub repetition_penalty: f64,
    // Token id that terminates generation, if any.
    pub eos_token_id: Option<usize>,
    // Padding token id. NOTE(review): currently unused by the generator.
    pub pad_token_id: Option<usize>,
    // When true, sample stochastically instead of greedy argmax.
    pub do_sample: bool,
    // Number of beams; a value > 1 selects beam search.
    pub num_beams: usize,
    // Length penalty for beam scoring. NOTE(review): declared but not yet
    // applied when ranking beams.
    pub length_penalty: f64,
}
32
33impl Default for GenerationConfig {
34 fn default() -> Self {
35 Self {
36 max_length: 256,
37 min_length: 0,
38 temperature: 1.0,
39 top_k: 0,
40 top_p: 0.0,
41 repetition_penalty: 1.0,
42 eos_token_id: None,
43 pad_token_id: None,
44 do_sample: false,
45 num_beams: 1,
46 length_penalty: 1.0,
47 }
48 }
49}
50
51impl GenerationConfig {
52 pub fn greedy() -> Self {
54 Self {
55 do_sample: false,
56 ..Self::default()
57 }
58 }
59
60 pub fn sampling(temperature: f64) -> Self {
62 Self {
63 do_sample: true,
64 temperature,
65 ..Self::default()
66 }
67 }
68
69 pub fn beam_search(num_beams: usize) -> Self {
71 Self {
72 do_sample: false,
73 num_beams,
74 ..Self::default()
75 }
76 }
77
78 pub fn with_max_length(mut self, max_length: usize) -> Self {
80 self.max_length = max_length;
81 self
82 }
83
84 pub fn with_temperature(mut self, temperature: f64) -> Self {
86 self.temperature = temperature;
87 self
88 }
89
90 pub fn with_top_k(mut self, top_k: usize) -> Self {
92 self.top_k = top_k;
93 self
94 }
95
96 pub fn with_top_p(mut self, top_p: f64) -> Self {
98 self.top_p = top_p;
99 self
100 }
101
102 pub fn with_eos_token_id(mut self, eos_token_id: usize) -> Self {
104 self.eos_token_id = Some(eos_token_id);
105 self
106 }
107}
108
/// Decodes token sequences from a borrowed [`LlamaModel`] according to a
/// [`GenerationConfig`]. Borrows the model immutably, so one model can
/// back several generators.
pub struct TextGenerator<'a> {
    // Model used for the forward passes; never mutated.
    model: &'a LlamaModel,
    // Decoding strategy and hyper-parameters.
    config: GenerationConfig,
}
116
117impl<'a> TextGenerator<'a> {
118 pub fn new(model: &'a LlamaModel, config: GenerationConfig) -> Self {
120 Self { model, config }
121 }
122
123 pub fn generate(&self, input_ids: &[usize]) -> Vec<usize> {
131 if self.config.num_beams > 1 {
132 self.generate_beam_search(input_ids)
133 } else if self.config.do_sample {
134 self.generate_sampling(input_ids)
135 } else {
136 self.generate_greedy(input_ids)
137 }
138 }
139
140 fn generate_greedy(&self, input_ids: &[usize]) -> Vec<usize> {
142 let mut current_ids = input_ids.to_vec();
143
144 for _ in 0..self.config.max_length {
145 let logits = self.model.forward_single(¤t_ids, None);
147
148 let seq_len = current_ids.len();
150 let last_logits = logits.get_row(seq_len - 1);
151
152 let mut probs = last_logits.clone();
154 if self.config.temperature != 1.0 {
155 probs = probs.scale(1.0 / self.config.temperature);
156 }
157
158 probs = probs.softmax(-1);
160
161 let next_token = self.argmax(probs.data());
163
164 if Some(next_token) == self.config.eos_token_id {
166 break;
167 }
168
169 current_ids.push(next_token);
170 }
171
172 current_ids
173 }
174
175 fn generate_sampling(&self, input_ids: &[usize]) -> Vec<usize> {
177 let mut current_ids = input_ids.to_vec();
178 let mut rng = rand::thread_rng();
179
180 for _ in 0..self.config.max_length {
181 let logits = self.model.forward_single(¤t_ids, None);
183
184 let seq_len = current_ids.len();
186 let last_logits = logits.get_row(seq_len - 1);
187
188 let mut probs = last_logits.clone();
190 if self.config.temperature != 1.0 {
191 probs = probs.scale(1.0 / self.config.temperature);
192 }
193
194 probs = probs.softmax(-1);
196
197 if self.config.top_k > 0 {
199 probs = self.top_k_filtering(&probs, self.config.top_k);
200 }
201
202 if self.config.top_p > 0.0 {
204 probs = self.top_p_filtering(&probs, self.config.top_p);
205 }
206
207 let next_token = self.sample_from_probs(probs.data(), &mut rng);
209
210 if Some(next_token) == self.config.eos_token_id {
212 break;
213 }
214
215 current_ids.push(next_token);
216 }
217
218 current_ids
219 }
220
221 fn generate_beam_search(&self, input_ids: &[usize]) -> Vec<usize> {
223 let mut beams: Vec<(Vec<usize>, f64)> = vec![(input_ids.to_vec(), 0.0)];
227
228 for _ in 0..self.config.max_length {
229 let mut candidates: Vec<(Vec<usize>, f64)> = Vec::new();
230
231 for (beam_ids, beam_score) in &beams {
232 let logits = self.model.forward_single(beam_ids, None);
234
235 let seq_len = beam_ids.len();
237 let last_logits = logits.get_row(seq_len - 1);
238
239 let top_indices = self.topk_indices(last_logits.data(), self.config.num_beams);
241
242 for &next_token in &top_indices {
243 let mut new_beam = beam_ids.clone();
244 new_beam.push(next_token);
245
246 let token_prob = last_logits.data()[next_token];
248 let new_score = beam_score + token_prob.ln();
249
250 candidates.push((new_beam, new_score));
251 }
252 }
253
254 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
256 beams = candidates.into_iter().take(self.config.num_beams).collect();
257
258 if beams.iter().all(|(ids, _)| {
260 ids.last() == self.config.eos_token_id.as_ref()
261 }) {
262 break;
263 }
264 }
265
266 beams.into_iter()
268 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
269 .map(|(ids, _)| ids)
270 .unwrap_or_else(|| input_ids.to_vec())
271 }
272
273 fn argmax(&self, data: &[f64]) -> usize {
275 data.iter()
276 .enumerate()
277 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
278 .map(|(i, _)| i)
279 .unwrap_or(0)
280 }
281
282 fn topk_indices(&self, data: &[f64], k: usize) -> Vec<usize> {
284 let mut indexed: Vec<(usize, &f64)> = data.iter().enumerate().collect();
285 indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
286 indexed.into_iter().take(k).map(|(i, _)| i).collect()
287 }
288
289 fn top_k_filtering(&self, probs: &DenseTensor, k: usize) -> DenseTensor {
291 let data = probs.data();
292 let top_indices = self.topk_indices(data, k);
293 let threshold = top_indices.iter()
294 .map(|&i| data[i])
295 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
296 .unwrap_or(0.0);
297
298 let mut filtered = probs.clone();
299 for (i, &prob) in data.iter().enumerate() {
300 if prob < threshold {
301 filtered.data_mut()[i] = 0.0;
302 }
303 }
304
305 let sum: f64 = filtered.data().iter().sum();
307 if sum > 0.0 {
308 filtered = filtered.scale(1.0 / sum);
309 }
310
311 filtered
312 }
313
314 fn top_p_filtering(&self, probs: &DenseTensor, p: f64) -> DenseTensor {
316 let data = probs.data();
317 let mut indexed: Vec<(usize, &f64)> = data.iter().enumerate().collect();
318 indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
319
320 let mut cumulative_prob = 0.0;
321 let mut cutoff_index = indexed.len();
322
323 for (i, (_, &prob)) in indexed.iter().enumerate() {
324 cumulative_prob += prob;
325 if cumulative_prob >= p {
326 cutoff_index = i + 1;
327 break;
328 }
329 }
330
331 let threshold = indexed.into_iter()
332 .take(cutoff_index)
333 .map(|(_, &prob)| prob)
334 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
335 .unwrap_or(0.0);
336
337 let mut filtered = probs.clone();
338 for (i, &prob) in data.iter().enumerate() {
339 if prob < threshold {
340 filtered.data_mut()[i] = 0.0;
341 }
342 }
343
344 let sum: f64 = filtered.data().iter().sum();
346 if sum > 0.0 {
347 filtered = filtered.scale(1.0 / sum);
348 }
349
350 filtered
351 }
352
353 fn sample_from_probs(&self, probs: &[f64], rng: &mut impl rand::Rng) -> usize {
355 let r: f64 = rng.gen();
356 let mut cumulative = 0.0;
357
358 for (i, &prob) in probs.iter().enumerate() {
359 cumulative += prob;
360 if r < cumulative {
361 return i;
362 }
363 }
364
365 probs.len() - 1
366 }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config is greedy with a 256-step budget, and the
    /// named constructors flip the expected knobs.
    #[test]
    fn test_generation_config() {
        let defaults = GenerationConfig::default();
        assert_eq!(defaults.max_length, 256);
        assert_eq!(defaults.temperature, 1.0);
        assert!(!defaults.do_sample);

        assert!(!GenerationConfig::greedy().do_sample);

        let sampled = GenerationConfig::sampling(0.8);
        assert!(sampled.do_sample);
        assert_eq!(sampled.temperature, 0.8);
    }

    /// argmax returns the position of the largest entry.
    #[test]
    fn test_argmax() {
        let model = create_test_model();
        let generator = TextGenerator::new(&model, GenerationConfig::default());

        assert_eq!(generator.argmax(&[0.1, 0.3, 0.5, 0.2, 0.4]), 2);
    }

    /// topk_indices lists the best positions in descending value order.
    #[test]
    fn test_topk_indices() {
        let model = create_test_model();
        let generator = TextGenerator::new(&model, GenerationConfig::default());

        assert_eq!(
            generator.topk_indices(&[0.1, 0.5, 0.3, 0.9, 0.2], 2),
            vec![3, 1]
        );
    }
}
413
/// Builds a small `LlamaModel` fixture with all-ones weights for tests.
///
/// Uses the 7B configuration's dimensions but stacks only two (cloned)
/// decoder layers, so tensor shapes are realistic while construction
/// stays cheap.
#[cfg(test)]
fn create_test_model() -> LlamaModel {
    use super::layers::{FeedForward, MultiHeadAttention, RMSNorm};
    use super::loader::LlamaConfig;

    let cfg = LlamaConfig::llama_7b();
    let hidden = cfg.hidden_size;
    let heads = cfg.num_attention_heads;
    let inner = cfg.intermediate_size;

    // Token embedding table: vocab_size x hidden_size.
    let embed_tokens = DenseTensor::ones(vec![cfg.vocab_size, hidden]);

    // Attention projections (Q, K, V, O) are square hidden x hidden.
    let self_attn = MultiHeadAttention::standard(
        DenseTensor::ones(vec![hidden, hidden]),
        DenseTensor::ones(vec![hidden, hidden]),
        DenseTensor::ones(vec![hidden, hidden]),
        DenseTensor::ones(vec![hidden, hidden]),
        heads,
    );

    // SwiGLU MLP: gate/up expand to the intermediate size, down projects
    // back to the hidden size.
    let mlp = FeedForward::swiglu(
        DenseTensor::ones(vec![hidden, inner]),
        DenseTensor::ones(vec![hidden, inner]),
        DenseTensor::ones(vec![inner, hidden]),
    );

    let layer = super::model::LlamaDecoderLayer::new(
        self_attn,
        mlp,
        RMSNorm::default(hidden),
        RMSNorm::default(hidden),
    );

    let layers = vec![layer; 2];
    let norm = RMSNorm::default(hidden);

    LlamaModel::new(cfg, embed_tokens, layers, norm, None)
}