
car_inference/tasks/generate.rs

//! Text generation with sampling.

use candle_core::Tensor;
use serde::{Deserialize, Serialize};

use crate::backend::CandleBackend;
use crate::InferenceError;

/// Parameters controlling generation behavior.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateParams {
    /// Sampling temperature (0.0 = greedy, 1.0 = full distribution).
    #[serde(default = "default_temperature")]
    pub temperature: f64,
    /// Top-p (nucleus) sampling threshold.
    #[serde(default = "default_top_p")]
    pub top_p: f64,
    /// Top-k sampling (0 = disabled).
    #[serde(default)]
    pub top_k: usize,
    /// Maximum tokens to generate.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// Stop sequences — generation halts when any is produced.
    #[serde(default)]
    pub stop: Vec<String>,
}

fn default_temperature() -> f64 { 0.7 }
fn default_top_p() -> f64 { 0.9 }
fn default_max_tokens() -> usize { 512 }

impl Default for GenerateParams {
    fn default() -> Self {
        Self {
            temperature: default_temperature(),
            top_p: default_top_p(),
            top_k: 0,
            max_tokens: default_max_tokens(),
            stop: Vec::new(),
        }
    }
}
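
// Example sketch: because every field carries a serde default, an empty JSON object
// deserializes to the same values as `GenerateParams::default()` (this assumes the caller
// has `serde_json` available; it is not imported by this module):
//
//     let p: GenerateParams = serde_json::from_str("{}")?;
//     assert_eq!((p.temperature, p.top_p, p.top_k, p.max_tokens), (0.7, 0.9, 0, 512));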

/// A text generation request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateRequest {
    /// The prompt to complete.
    pub prompt: String,
    /// Optional model override.
    pub model: Option<String>,
    /// Generation parameters.
    #[serde(default)]
    pub params: GenerateParams,
    /// Optional memory context to prepend to the prompt.
    /// When provided, this is injected as a system-level context block
    /// before the user prompt, grounding the model's response.
    #[serde(default)]
    pub context: Option<String>,
}

/// Wrap a raw prompt in Qwen3 chat format if it's not already formatted.
/// Disables thinking mode by default for faster, more direct responses.
/// When context is provided, it is injected into the system message to ground
/// the model's response with memory.
fn apply_chat_template(prompt: &str, context: Option<&str>) -> String {
    if prompt.contains("<|im_start|>") {
        return prompt.to_string();
    }
    match context {
        Some(ctx) => format!(
            "<|im_start|>system\nYou are a helpful assistant. Use the following context to inform your response. /no_think\n\n{ctx}<|im_end|>\n\
             <|im_start|>user\n{prompt}<|im_end|>\n\
             <|im_start|>assistant\n"
        ),
        None => format!(
            "<|im_start|>system\nYou are a helpful assistant. /no_think<|im_end|>\n\
             <|im_start|>user\n{prompt}<|im_end|>\n\
             <|im_start|>assistant\n"
        ),
    }
}
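
// Example: `apply_chat_template("What is Rust?", None)` renders the prompt below, while
// inputs that already contain "<|im_start|>" pass through untouched:
//
//     <|im_start|>system
//     You are a helpful assistant. /no_think<|im_end|>
//     <|im_start|>user
//     What is Rust?<|im_end|>
//     <|im_start|>assistant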

/// Strip Qwen3 thinking tags from output (public for vision module).
pub fn strip_thinking_pub(text: &str) -> String {
    strip_thinking(text)
}

/// Strip Qwen3 thinking tags from output.
fn strip_thinking(text: &str) -> String {
    if let Some(end) = text.find("</think>") {
        text[end + "</think>".len()..].trim_start().to_string()
    } else if text.starts_with("<think>") {
        // Still in thinking mode, return as-is
        text.to_string()
    } else {
        text.to_string()
    }
}

/// Callback for FLARE-style re-retrieval during generation.
/// Called with partial generation text, returns additional context or None.
pub type RetrievalCallback = Box<dyn Fn(&str) -> Option<String> + Send>;
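
// Sketch of building a retrieval callback; `memory.search` is a hypothetical stand-in for
// whatever retrieval layer the caller wires in:
//
//     let cb: RetrievalCallback = Box::new(move |partial| {
//         let hits: Vec<String> = memory.search(partial, 3);
//         if hits.is_empty() { None } else { Some(hits.join("\n")) }
//     });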

/// Generate text from a prompt using the loaded model.
/// For confidence-triggered re-retrieval (FLARE), see `generate_with_retrieval`.
pub async fn generate(
    backend: &mut CandleBackend,
    req: GenerateRequest,
) -> Result<String, InferenceError> {
    // Reset KV cache so each generation starts fresh (prevents cross-call state bleed)
    backend.clear_kv_cache();

    let formatted = apply_chat_template(&req.prompt, req.context.as_deref());
    let tokens = backend.encode(&formatted)?;
    let eos = backend.eos_token_id();
    let eos_alt = backend.token_id("<|im_end|>");
    let params = &req.params;

    if tokens.is_empty() {
        return Ok(String::new());
    }

    // Truncate to the model's max context length minus generation headroom.
    // This prevents KV cache overflow on long prompts.
    let max_ctx = backend.context_length().unwrap_or(32768);
    let headroom = params.max_tokens.min(max_ctx / 4);
    let max_prompt = max_ctx.saturating_sub(headroom);
    let tokens = if tokens.len() > max_prompt {
        eprintln!(
            "[car-inference] truncating prompt from {} to {} tokens (context_length={})",
            tokens.len(), max_prompt, max_ctx
        );
        tokens[tokens.len() - max_prompt..].to_vec()
    } else {
        tokens
    };
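    // Worked example of the arithmetic above: with max_ctx = 32768 and max_tokens = 512,
    // headroom = min(512, 8192) = 512 and max_prompt = 32256, so only prompts longer than
    // 32256 tokens are truncated, keeping their most recent tokens.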

    let mut generated = Vec::new();

    // Prefill: process all prompt tokens, sample first generated token from prefill logits
    let logits = backend.forward(&tokens, 0)?;
    let mut next_token = sample_token(&logits, params)?;

    for _i in 0..params.max_tokens {
        // Check EOS
        if eos.map_or(false, |id| next_token == id)
            || eos_alt.map_or(false, |id| next_token == id)
        {
            break;
        }

        generated.push(next_token);

        // Check stop sequences
        if !params.stop.is_empty() {
            let text_so_far = backend.decode(&generated)?;
            if params.stop.iter().any(|s| text_so_far.contains(s)) {
                break;
            }
        }

        // Generate next token
        let pos = tokens.len() + generated.len() - 1;
        let logits = backend.forward(&[next_token], pos)?;
        next_token = sample_token(&logits, params)?;
    }

    let text = backend.decode(&generated)?;
    Ok(strip_thinking(&text))
}
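
// Usage sketch; the loaded `backend` and the surrounding async/error-handling context are
// assumed rather than provided by this module:
//
//     let req = GenerateRequest {
//         prompt: "Summarize the KV cache in one sentence.".to_string(),
//         model: None,
//         params: GenerateParams { max_tokens: 128, ..Default::default() },
//         context: None,
//     };
//     let text = generate(&mut backend, req).await?;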

/// Generate with FLARE-style confidence-triggered re-retrieval.
///
/// Monitors token logit confidence during generation. When a window of
/// low-confidence tokens is detected, pauses, re-queries memory with the
/// partial generation, and resumes with augmented context.
pub async fn generate_with_retrieval(
    backend: &mut CandleBackend,
    mut req: GenerateRequest,
    retrieval_cb: RetrievalCallback,
) -> Result<String, InferenceError> {
    // First pass: generate normally
    backend.clear_kv_cache();
    let formatted = apply_chat_template(&req.prompt, req.context.as_deref());
    // Mutable so a retrieval restart can rebind the re-encoded prompt tokens below.
    let mut tokens = backend.encode(&formatted)?;
    let eos = backend.eos_token_id();
    let eos_alt = backend.token_id("<|im_end|>");
    let params = req.params.clone();

    if tokens.is_empty() {
        return Ok(String::new());
    }

    let mut generated = Vec::new();
    let mut low_confidence_count = 0u32;
    let mut retrieval_attempts = 0u32;
    let max_retrievals = 2;
    let confidence_threshold = 0.4f32;
    let low_confidence_window = 3u32;

    let logits = backend.forward(&tokens, 0)?;
    let mut next_token = sample_token(&logits, &params)?;

    for _i in 0..params.max_tokens {
        if eos.map_or(false, |id| next_token == id)
            || eos_alt.map_or(false, |id| next_token == id)
        {
            break;
        }

        generated.push(next_token);

        // Generate next token and check confidence
        let pos = tokens.len() + generated.len() - 1;
        let logits = backend.forward(&[next_token], pos)?;

        // Check max logit probability for confidence
        let logits_f32: Vec<f32> = logits.squeeze(0)
            .unwrap_or(logits.clone())
            .to_dtype(candle_core::DType::F32)
            .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
            .to_vec1()
            .unwrap_or_default();

        if !logits_f32.is_empty() {
            // Compute softmax max probability
            let max_logit = logits_f32.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exp_sum: f32 = logits_f32.iter().map(|&v| (v - max_logit).exp()).sum();
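            // The top token contributes exp(max_logit - max_logit) = 1 to exp_sum, so its
            // softmax probability is exactly 1.0 / exp_sum without computing a full softmax.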
            let max_prob = 1.0 / exp_sum; // probability of the top token

            if max_prob < confidence_threshold {
                low_confidence_count += 1;
            } else {
                low_confidence_count = 0;
            }

            // Trigger re-retrieval after sustained low confidence
            if low_confidence_count >= low_confidence_window
                && retrieval_attempts < max_retrievals
            {
                retrieval_attempts += 1;
                low_confidence_count = 0;

                // Use partial generation as re-retrieval query
                let partial = backend.decode(&generated)?;
                if let Some(new_context) = retrieval_cb(&partial) {
                    // Restart generation with augmented context
                    let combined_context = match req.context.take() {
                        Some(old) => format!("{}\n\n{}", old, new_context),
                        None => new_context,
                    };
                    req.context = Some(combined_context);

                    // Re-encode and restart; rebind `tokens` so the position offsets above
                    // stay consistent with the fresh KV cache.
                    backend.clear_kv_cache();
                    let new_formatted = apply_chat_template(&req.prompt, req.context.as_deref());
                    tokens = backend.encode(&new_formatted)?;
                    generated.clear();

                    let logits = backend.forward(&tokens, 0)?;
                    next_token = sample_token(&logits, &params)?;
                    continue;
                }
            }
        }

        next_token = sample_token(&logits, &params)?;
    }

    let text = backend.decode(&generated)?;
    Ok(strip_thinking(&text))
}
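
// Usage sketch; `backend`, `req`, and the `memory` handle inside the callback are
// assumptions for illustration:
//
//     let cb: RetrievalCallback = Box::new(move |partial| memory.lookup(partial));
//     let text = generate_with_retrieval(&mut backend, req, cb).await?;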

/// Sample a token, suppressing specific token IDs (set to -inf before sampling).
pub fn sample_token_suppress(logits: &Tensor, params: &GenerateParams, suppress: &[u32]) -> Result<u32, InferenceError> {
    if suppress.is_empty() {
        return sample_token(logits, params);
    }
    // Flatten the logits to 1D so both (vocab) and (seq, vocab) shapes can be handled.
    let mut logits_vec: Vec<f32> = logits
        .flatten_all()
        .map_err(|e| InferenceError::InferenceFailed(format!("flatten: {e}")))?
        .to_dtype(candle_core::DType::F32)
        .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
        .to_vec1()
        .map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;
    // Handle 2D logits (keep only the last row).
    let dims = logits.dims();
    if dims.len() == 2 {
        let vocab = dims[dims.len() - 1];
        let start = logits_vec.len() - vocab;
        logits_vec = logits_vec[start..].to_vec();
    }
    // Set suppressed token IDs to -inf so they can never be sampled.
    for &id in suppress {
        if (id as usize) < logits_vec.len() {
            logits_vec[id as usize] = f32::NEG_INFINITY;
        }
    }
    // Rebuild a 1D tensor over the (possibly truncated) vocab row.
    let vocab_len = logits_vec.len();
    let modified = Tensor::from_vec(logits_vec, vocab_len, logits.device())
        .map_err(|e| InferenceError::InferenceFailed(format!("from_vec: {e}")))?;
    sample_token(&modified, params)
}
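
// Sketch: ban the end-of-turn token so sampling cannot emit it (reuses `backend`, `logits`,
// and `params` as they appear in the generation loops above):
//
//     let banned: Vec<u32> = backend.token_id("<|im_end|>").into_iter().collect();
//     let token = sample_token_suppress(&logits, &params, &banned)?;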

/// Sample a token from logits using temperature + top-p + top-k.
pub fn sample_token(logits: &Tensor, params: &GenerateParams) -> Result<u32, InferenceError> {
    let logits = logits
        .squeeze(0)
        .map_err(|e| InferenceError::InferenceFailed(format!("squeeze: {e}")))?;
    let logits = logits
        .to_dtype(candle_core::DType::F32)
        .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?;

    // Get last position's logits
    let dim = logits.dims();
    let logits = if dim.len() == 2 {
        logits
            .get(dim[0] - 1)
            .map_err(|e| InferenceError::InferenceFailed(format!("get last: {e}")))?
    } else {
        logits
    };

    // Greedy decoding
    if params.temperature <= 0.0 {
        let token = logits
            .argmax(0)
            .map_err(|e| InferenceError::InferenceFailed(format!("argmax: {e}")))?
            .to_scalar::<u32>()
            .map_err(|e| InferenceError::InferenceFailed(format!("scalar: {e}")))?;
        return Ok(token);
    }

    // Temperature scaling
    let logits = (&logits / params.temperature)
        .map_err(|e| InferenceError::InferenceFailed(format!("temp scale: {e}")))?;

    let mut logits_vec: Vec<f32> = logits
        .to_vec1()
        .map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;

    // Top-k filtering: keep only values at or above the k-th largest logit.
    if params.top_k > 0 && params.top_k < logits_vec.len() {
        let mut indexed: Vec<(usize, f32)> = logits_vec.iter().copied().enumerate().collect();
        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        let threshold = indexed[params.top_k - 1].1;
        for v in &mut logits_vec {
            if *v < threshold {
                *v = f32::NEG_INFINITY;
            }
        }
    }

    // Softmax
    let max_logit = logits_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exp: Vec<f32> = logits_vec.iter().map(|&v| (v - max_logit).exp()).collect();
    let sum: f32 = exp.iter().sum();
    let mut probs: Vec<f32> = exp.iter().map(|&v| v / sum).collect();

    // Top-p (nucleus) filtering
    if params.top_p < 1.0 {
        let mut sorted_indices: Vec<usize> = (0..probs.len()).collect();
        sorted_indices.sort_by(|&a, &b| {
            probs[b].partial_cmp(&probs[a]).unwrap_or(std::cmp::Ordering::Equal)
        });

        let mut cumsum = 0.0f32;
        let mut cutoff_idx = sorted_indices.len();
        for (i, &idx) in sorted_indices.iter().enumerate() {
            cumsum += probs[idx];
            if cumsum > params.top_p as f32 {
                cutoff_idx = i + 1;
                break;
            }
        }

        let keep: std::collections::HashSet<usize> =
            sorted_indices[..cutoff_idx].iter().copied().collect();
        for (i, p) in probs.iter_mut().enumerate() {
            if !keep.contains(&i) {
                *p = 0.0;
            }
        }

        // Renormalize
        let sum: f32 = probs.iter().sum();
        if sum > 0.0 {
            for p in &mut probs {
                *p /= sum;
            }
        }
    }
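    // Worked example: with probs [0.5, 0.3, 0.15, 0.05] and top_p = 0.9, the running sums
    // are 0.5, 0.8, 0.95; the third token crosses 0.9, so the top three are kept and
    // renormalized to roughly [0.53, 0.32, 0.16].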

    // Categorical sample
    let r: f32 = rand_f32();
    let mut cumsum = 0.0f32;
    for (i, &p) in probs.iter().enumerate() {
        cumsum += p;
        if cumsum >= r {
            return Ok(i as u32);
        }
    }

    // Fallback: return highest prob token
    Ok(probs
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(i, _)| i as u32)
        .unwrap_or(0))
}

/// Random float in [0, 1) using the rand crate.
fn rand_f32() -> f32 {
    rand::random::<f32>()
}
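
// A minimal test sketch for the pure helpers above. It assumes a CPU-only candle_core
// build is available to the test runner and only exercises behavior visible in this file.
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    #[test]
    fn greedy_sampling_picks_argmax() {
        // temperature 0.0 takes the greedy (argmax) path in `sample_token`.
        let logits = Tensor::new(&[1.0f32, 5.0, 2.0], &Device::Cpu).unwrap();
        let params = GenerateParams { temperature: 0.0, ..Default::default() };
        assert_eq!(sample_token(&logits, &params).unwrap(), 1);
    }

    #[test]
    fn strip_thinking_removes_closed_tag() {
        assert_eq!(strip_thinking("<think>reasoning</think>  answer"), "answer");
    }

    #[test]
    fn chat_template_passes_through_preformatted_prompts() {
        let already = "<|im_start|>user\nhi<|im_end|>\n<|im_start|>assistant\n";
        assert_eq!(apply_chat_template(already, None), already);
    }
}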