entrenar/finetune/instruct_pipeline/
generate.rs1#[allow(clippy::wildcard_imports)]
4use super::*;
5
6impl InstructPipeline {
7 pub fn generate(&self, prompt: &str, config: &GenerateConfig) -> crate::Result<String> {
23 let tokenizer = self.tokenizer.as_ref().ok_or_else(|| {
24 crate::Error::ConfigError("No tokenizer loaded — cannot generate text".into())
25 })?;
26
27 let mut token_ids = tokenizer.encode(prompt);
28 let prompt_len = token_ids.len();
29 let eos_token = tokenizer.eos_id().unwrap_or(151643); let vocab_size = self.model.config().vocab_size;
32
33 for _ in 0..config.max_new_tokens {
34 if token_ids.len() >= self.config.max_seq_len {
36 break;
37 }
38
39 let hidden = self.model.forward_hidden_with_lora(&token_ids, &self.lora_layers);
41 let seq_len = token_ids.len();
42 let hidden_size = self.model.config().hidden_size;
43
44 let lm_weight = self.model.lm_head_weight();
46 let logits =
47 crate::autograd::matmul_nt(&hidden, lm_weight, seq_len, hidden_size, vocab_size);
48
49 let logits_data = logits.data();
51 let logits_slice = logits_data.as_slice().unwrap_or(&[]);
52 let last_pos_start = (seq_len - 1) * vocab_size;
53 let last_pos_logits = &logits_slice[last_pos_start..last_pos_start + vocab_size];
54
55 let next_token = sample_token(last_pos_logits, config.temperature, config.top_k);
57
58 if next_token == eos_token {
59 break;
60 }
61
62 if config.stop_tokens.contains(&next_token) {
64 break;
65 }
66
67 token_ids.push(next_token);
68 }
69
70 let generated_ids = &token_ids[prompt_len..];
72 Ok(tokenizer.decode(generated_ids))
73 }
74
75 pub fn generate_chat(
91 &self,
92 system: &str,
93 user_message: &str,
94 config: &GenerateConfig,
95 ) -> crate::Result<String> {
96 let prompt = format!(
97 "<|im_start|>system\n{system}<|im_end|>\n\
98 <|im_start|>user\n{user_message}<|im_end|>\n\
99 <|im_start|>assistant\n"
100 );
101
102 let mut response = self.generate(&prompt, config)?;
103
104 if let Some(stripped) = response.strip_suffix("<|im_end|>") {
106 response = stripped.to_string();
107 }
108
109 Ok(response)
110 }
111}
112
113impl GenerateConfig {
114 #[must_use]
116 pub fn greedy(max_new_tokens: usize) -> Self {
117 contract_pre_greedy!();
118 Self { max_new_tokens, temperature: 0.0, top_k: 0, stop_tokens: Vec::new() }
119 }
120}
121
122impl Default for GenerateConfig {
123 fn default() -> Self {
124 Self { max_new_tokens: 256, temperature: 0.7, top_k: 50, stop_tokens: Vec::new() }
125 }
126}