Skip to main content

entrenar/finetune/instruct_pipeline/
generate.rs

1//! Text generation methods: `generate`, `generate_chat`.
2
3#[allow(clippy::wildcard_imports)]
4use super::*;
5
6impl InstructPipeline {
7    /// Autoregressive text generation with LoRA adapters (entrenar#246).
8    ///
9    /// Generates tokens one at a time using the transformer + LoRA forward pass.
10    /// Supports greedy decoding (temperature=0) and temperature-scaled sampling
11    /// with optional top-k filtering.
12    ///
13    /// # Arguments
14    /// * `prompt` - Input text to continue from
15    /// * `config` - Generation parameters (max tokens, temperature, top-k)
16    ///
17    /// # Returns
18    /// Generated text (excluding the input prompt)
19    ///
20    /// # Errors
21    /// Returns error if no tokenizer is loaded.
22    pub fn generate(&self, prompt: &str, config: &GenerateConfig) -> crate::Result<String> {
23        let tokenizer = self.tokenizer.as_ref().ok_or_else(|| {
24            crate::Error::ConfigError("No tokenizer loaded — cannot generate text".into())
25        })?;
26
27        let mut token_ids = tokenizer.encode(prompt);
28        let prompt_len = token_ids.len();
29        let eos_token = tokenizer.eos_id().unwrap_or(151643); // Qwen2 default EOS
30
31        let vocab_size = self.model.config().vocab_size;
32
33        for _ in 0..config.max_new_tokens {
34            // Truncate to max_seq_len if needed
35            if token_ids.len() >= self.config.max_seq_len {
36                break;
37            }
38
39            // Forward pass with LoRA
40            let hidden = self.model.forward_hidden_with_lora(&token_ids, &self.lora_layers);
41            let seq_len = token_ids.len();
42            let hidden_size = self.model.config().hidden_size;
43
44            // Apply lm_head to get logits
45            let lm_weight = self.model.lm_head_weight();
46            let logits =
47                crate::autograd::matmul_nt(&hidden, lm_weight, seq_len, hidden_size, vocab_size);
48
49            // Extract logits for last position
50            let logits_data = logits.data();
51            let logits_slice = logits_data.as_slice().unwrap_or(&[]);
52            let last_pos_start = (seq_len - 1) * vocab_size;
53            let last_pos_logits = &logits_slice[last_pos_start..last_pos_start + vocab_size];
54
55            // Sample next token
56            let next_token = sample_token(last_pos_logits, config.temperature, config.top_k);
57
58            if next_token == eos_token {
59                break;
60            }
61
62            // Check stop tokens
63            if config.stop_tokens.contains(&next_token) {
64                break;
65            }
66
67            token_ids.push(next_token);
68        }
69
70        // Decode only the generated part (not the prompt)
71        let generated_ids = &token_ids[prompt_len..];
72        Ok(tokenizer.decode(generated_ids))
73    }
74
75    /// Generate a chat response using ChatML format (entrenar#246).
76    ///
77    /// Formats messages as ChatML (`<|im_start|>` / `<|im_end|>`) and generates
78    /// the assistant's response.
79    ///
80    /// # Arguments
81    /// * `system` - System prompt
82    /// * `user_message` - User's input message
83    /// * `config` - Generation parameters
84    ///
85    /// # Returns
86    /// The assistant's generated response text.
87    ///
88    /// # Errors
89    /// Returns error if no tokenizer is loaded.
90    pub fn generate_chat(
91        &self,
92        system: &str,
93        user_message: &str,
94        config: &GenerateConfig,
95    ) -> crate::Result<String> {
96        let prompt = format!(
97            "<|im_start|>system\n{system}<|im_end|>\n\
98             <|im_start|>user\n{user_message}<|im_end|>\n\
99             <|im_start|>assistant\n"
100        );
101
102        let mut response = self.generate(&prompt, config)?;
103
104        // Strip trailing <|im_end|> if present
105        if let Some(stripped) = response.strip_suffix("<|im_end|>") {
106            response = stripped.to_string();
107        }
108
109        Ok(response)
110    }
111}
112
113impl GenerateConfig {
114    /// Create a greedy decoding config (deterministic, always picks highest probability token).
115    #[must_use]
116    pub fn greedy(max_new_tokens: usize) -> Self {
117        contract_pre_greedy!();
118        Self { max_new_tokens, temperature: 0.0, top_k: 0, stop_tokens: Vec::new() }
119    }
120}
121
122impl Default for GenerateConfig {
123    fn default() -> Self {
124        Self { max_new_tokens: 256, temperature: 0.7, top_k: 50, stop_tokens: Vec::new() }
125    }
126}