// car_inference/tasks/generate.rs
use candle_core::Tensor;
use serde::{Deserialize, Serialize};
5
6use crate::backend::CandleBackend;
7use crate::InferenceError;
8
/// Sampling and decoding parameters for a generation request.
///
/// Every field has a serde default, so a JSON request may omit any of them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateParams {
    /// Softmax temperature; values <= 0.0 select greedy (argmax) decoding.
    #[serde(default = "default_temperature")]
    pub temperature: f64,
    /// Nucleus-sampling cutoff; 1.0 disables top-p filtering.
    #[serde(default = "default_top_p")]
    pub top_p: f64,
    /// Top-k cutoff; 0 disables top-k filtering.
    #[serde(default)]
    pub top_k: usize,
    /// Maximum number of tokens to generate.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// Stop strings; generation halts once any appears in the decoded output.
    #[serde(default)]
    pub stop: Vec<String>,
}
28
/// Serde default: sampling temperature.
fn default_temperature() -> f64 {
    0.7
}

/// Serde default: nucleus-sampling cutoff.
fn default_top_p() -> f64 {
    0.9
}

/// Serde default: generation budget in tokens.
fn default_max_tokens() -> usize {
    512
}
32
33impl Default for GenerateParams {
34 fn default() -> Self {
35 Self {
36 temperature: default_temperature(),
37 top_p: default_top_p(),
38 top_k: 0,
39 max_tokens: default_max_tokens(),
40 stop: Vec::new(),
41 }
42 }
43}
44
/// A single text-generation request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateRequest {
    /// User prompt. If it already contains "<|im_start|>" it is sent
    /// verbatim; otherwise it is wrapped in the chat template.
    pub prompt: String,
    /// Optional model identifier. NOTE(review): not consumed anywhere in
    /// this file — presumably used by the caller for routing; confirm.
    pub model: Option<String>,
    /// Sampling parameters (see GenerateParams for per-field defaults).
    #[serde(default)]
    pub params: GenerateParams,
    /// Optional retrieval context injected into the system message.
    #[serde(default)]
    pub context: Option<String>,
}
61
/// Wrap `prompt` (and optional retrieval `context`) in the ChatML template.
///
/// A prompt that already contains "<|im_start|>" is assumed to be
/// pre-templated and is returned unchanged.
fn apply_chat_template(prompt: &str, context: Option<&str>) -> String {
    if prompt.contains("<|im_start|>") {
        return prompt.to_string();
    }
    // Build the system message first, then assemble the full transcript once.
    let system = if let Some(ctx) = context {
        format!(
            "You are a helpful assistant. Use the following context to inform your response. /no_think\n\n{ctx}"
        )
    } else {
        "You are a helpful assistant. /no_think".to_string()
    };
    format!(
        "<|im_start|>system\n{system}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
    )
}
83
/// Public wrapper around the private [`strip_thinking`] for callers
/// outside this module.
pub fn strip_thinking_pub(text: &str) -> String {
    strip_thinking(text)
}
88
/// Remove a leading `<think>…</think>` reasoning span from model output.
///
/// Returns everything after the first "</think>" with leading whitespace
/// trimmed. Text without a closing tag — including an unterminated
/// "<think>" span — is returned unchanged.
fn strip_thinking(text: &str) -> String {
    match text.find("</think>") {
        // Skip past the tag itself; use its length instead of a magic `8`.
        Some(end) => text[end + "</think>".len()..].trim_start().to_string(),
        // No closing tag: pass through. (The original had a separate
        // `starts_with("<think>")` branch that returned the same value —
        // dead code, collapsed here.)
        None => text.to_string(),
    }
}
100
/// Callback invoked with the partially decoded output; returns extra context
/// to splice into the prompt, or `None` to skip retrieval this time.
pub type RetrievalCallback = Box<dyn Fn(&str) -> Option<String> + Send>;
104
/// Run autoregressive generation for `req` and return the decoded text with
/// any leading `<think>…</think>` span stripped.
///
/// # Errors
/// Propagates `InferenceError` from tokenization, forward passes, sampling,
/// or decoding.
pub async fn generate(
    backend: &mut CandleBackend,
    req: GenerateRequest,
) -> Result<String, InferenceError> {
    // Fresh KV cache: all positions below are absolute from 0.
    backend.clear_kv_cache();

    let formatted = apply_chat_template(&req.prompt, req.context.as_deref());
    let tokens = backend.encode(&formatted)?;
    let eos = backend.eos_token_id();
    // Chat models may emit "<|im_end|>" instead of the tokenizer's EOS id.
    let eos_alt = backend.token_id("<|im_end|>");
    let params = &req.params;

    if tokens.is_empty() {
        return Ok(String::new());
    }

    // Budget: reserve room for the requested completion (capped at a quarter
    // of the context window), then truncate the prompt from the LEFT so the
    // most recent tokens survive.
    let max_ctx = backend.context_length().unwrap_or(32768);
    let headroom = params.max_tokens.min(max_ctx / 4);
    let max_prompt = max_ctx.saturating_sub(headroom);
    let tokens = if tokens.len() > max_prompt {
        eprintln!(
            "[car-inference] truncating prompt from {} to {} tokens (context_length={})",
            tokens.len(), max_prompt, max_ctx
        );
        tokens[tokens.len() - max_prompt..].to_vec()
    } else {
        tokens
    };

    let mut generated = Vec::new();

    // Prefill: one forward pass over the whole prompt, then sample the first
    // completion token.
    let logits = backend.forward(&tokens, 0)?;
    let mut next_token = sample_token(&logits, params)?;

    for _i in 0..params.max_tokens {
        if eos.map_or(false, |id| next_token == id)
            || eos_alt.map_or(false, |id| next_token == id)
        {
            break;
        }

        generated.push(next_token);

        // Stop-string check decodes the whole completion each step; O(n^2)
        // overall but acceptable for short outputs. The matched stop string
        // remains in the returned text.
        if !params.stop.is_empty() {
            let text_so_far = backend.decode(&generated)?;
            if params.stop.iter().any(|s| text_so_far.contains(s)) {
                break;
            }
        }

        // Decode step: feed only the new token at its absolute position
        // (prompt length + tokens generated so far - 1).
        let pos = tokens.len() + generated.len() - 1;
        let logits = backend.forward(&[next_token], pos)?;
        next_token = sample_token(&logits, params)?;
    }

    let text = backend.decode(&generated)?;
    Ok(strip_thinking(&text))
}
172
173pub async fn generate_with_retrieval(
179 backend: &mut CandleBackend,
180 mut req: GenerateRequest,
181 retrieval_cb: RetrievalCallback,
182) -> Result<String, InferenceError> {
183 backend.clear_kv_cache();
185 let formatted = apply_chat_template(&req.prompt, req.context.as_deref());
186 let tokens = backend.encode(&formatted)?;
187 let eos = backend.eos_token_id();
188 let eos_alt = backend.token_id("<|im_end|>");
189 let params = req.params.clone();
190
191 if tokens.is_empty() {
192 return Ok(String::new());
193 }
194
195 let mut generated = Vec::new();
196 let mut low_confidence_count = 0u32;
197 let mut retrieval_attempts = 0u32;
198 let max_retrievals = 2;
199 let confidence_threshold = 0.4f32;
200 let low_confidence_window = 3u32;
201
202 let logits = backend.forward(&tokens, 0)?;
203 let mut next_token = sample_token(&logits, ¶ms)?;
204
205 for _i in 0..params.max_tokens {
206 if eos.map_or(false, |id| next_token == id)
207 || eos_alt.map_or(false, |id| next_token == id)
208 {
209 break;
210 }
211
212 generated.push(next_token);
213
214 let pos = tokens.len() + generated.len() - 1;
216 let logits = backend.forward(&[next_token], pos)?;
217
218 let logits_f32: Vec<f32> = logits.squeeze(0)
220 .unwrap_or(logits.clone())
221 .to_dtype(candle_core::DType::F32)
222 .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
223 .to_vec1()
224 .unwrap_or_default();
225
226 if !logits_f32.is_empty() {
227 let max_logit = logits_f32.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
229 let exp_sum: f32 = logits_f32.iter().map(|&v| (v - max_logit).exp()).sum();
230 let max_prob = 1.0 / exp_sum; if max_prob < confidence_threshold {
233 low_confidence_count += 1;
234 } else {
235 low_confidence_count = 0;
236 }
237
238 if low_confidence_count >= low_confidence_window
240 && retrieval_attempts < max_retrievals
241 {
242 retrieval_attempts += 1;
243 low_confidence_count = 0;
244
245 let partial = backend.decode(&generated)?;
247 if let Some(new_context) = retrieval_cb(&partial) {
248 let combined_context = match req.context.take() {
250 Some(old) => format!("{}\n\n{}", old, new_context),
251 None => new_context,
252 };
253 req.context = Some(combined_context);
254
255 backend.clear_kv_cache();
257 let new_formatted = apply_chat_template(&req.prompt, req.context.as_deref());
258 let new_tokens = backend.encode(&new_formatted)?;
259 generated.clear();
260
261 let logits = backend.forward(&new_tokens, 0)?;
262 next_token = sample_token(&logits, ¶ms)?;
263 continue;
264 }
265 }
266 }
267
268 next_token = sample_token(&logits, ¶ms)?;
269 }
270
271 let text = backend.decode(&generated)?;
272 Ok(strip_thinking(&text))
273}
274
/// Sample a token while masking out the token ids in `suppress`.
///
/// Suppressed ids get logit -inf so they can never be chosen. With an empty
/// suppress list this is exactly [`sample_token`].
pub fn sample_token_suppress(logits: &Tensor, params: &GenerateParams, suppress: &[u32]) -> Result<u32, InferenceError> {
    if suppress.is_empty() {
        return sample_token(logits, params);
    }
    // NOTE(review): `to_vec1` only succeeds on a 1-D tensor, so a
    // [seq, vocab] input with seq > 1 errors out here even though
    // `sample_token` itself handles that shape by taking the last row —
    // confirm callers always pass [vocab] or [1, vocab].
    let mut logits_vec: Vec<f32> = logits.squeeze(0)
        .unwrap_or(logits.clone())
        .to_dtype(candle_core::DType::F32)
        .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
        .to_vec1()
        .map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;
    // Keep only the final vocab-sized row when the input had a leading dim.
    let dims = logits.dims();
    if dims.len() == 2 {
        let vocab = dims[dims.len() - 1];
        let start = logits_vec.len() - vocab;
        logits_vec = logits_vec[start..].to_vec();
    }
    // Mask the suppressed ids (out-of-range ids are ignored).
    for &id in suppress {
        if (id as usize) < logits_vec.len() {
            logits_vec[id as usize] = f32::NEG_INFINITY;
        }
    }
    // Rebuild a tensor with the masked logits and delegate to the sampler.
    let modified = Tensor::from_vec(logits_vec, logits.squeeze(0).unwrap_or(logits.clone()).shape(), logits.device())
        .map_err(|e| InferenceError::InferenceFailed(format!("from_vec: {e}")))?;
    sample_token(&modified, params)
}
303
304pub fn sample_token(logits: &Tensor, params: &GenerateParams) -> Result<u32, InferenceError> {
306 let logits = logits
307 .squeeze(0)
308 .map_err(|e| InferenceError::InferenceFailed(format!("squeeze: {e}")))?;
309 let logits = logits
310 .to_dtype(candle_core::DType::F32)
311 .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?;
312
313 let dim = logits.dims();
315 let logits = if dim.len() == 2 {
316 logits
317 .get(dim[0] - 1)
318 .map_err(|e| InferenceError::InferenceFailed(format!("get last: {e}")))?
319 } else {
320 logits
321 };
322
323 if params.temperature <= 0.0 {
325 let token = logits
326 .argmax(0)
327 .map_err(|e| InferenceError::InferenceFailed(format!("argmax: {e}")))?
328 .to_scalar::<u32>()
329 .map_err(|e| InferenceError::InferenceFailed(format!("scalar: {e}")))?;
330 return Ok(token);
331 }
332
333 let logits = (&logits / params.temperature)
335 .map_err(|e| InferenceError::InferenceFailed(format!("temp scale: {e}")))?;
336
337 let mut logits_vec: Vec<f32> = logits
338 .to_vec1()
339 .map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;
340
341 if params.top_k > 0 && params.top_k < logits_vec.len() {
343 let mut indexed: Vec<(usize, f32)> = logits_vec.iter().copied().enumerate().collect();
344 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
345 let threshold = indexed[params.top_k].1;
346 for v in &mut logits_vec {
347 if *v < threshold {
348 *v = f32::NEG_INFINITY;
349 }
350 }
351 }
352
353 let max_logit = logits_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
355 let exp: Vec<f32> = logits_vec.iter().map(|&v| (v - max_logit).exp()).collect();
356 let sum: f32 = exp.iter().sum();
357 let mut probs: Vec<f32> = exp.iter().map(|&v| v / sum).collect();
358
359 if params.top_p < 1.0 {
361 let mut sorted_indices: Vec<usize> = (0..probs.len()).collect();
362 sorted_indices.sort_by(|&a, &b| {
363 probs[b].partial_cmp(&probs[a]).unwrap_or(std::cmp::Ordering::Equal)
364 });
365
366 let mut cumsum = 0.0f32;
367 let mut cutoff_idx = sorted_indices.len();
368 for (i, &idx) in sorted_indices.iter().enumerate() {
369 cumsum += probs[idx];
370 if cumsum > params.top_p as f32 {
371 cutoff_idx = i + 1;
372 break;
373 }
374 }
375
376 let keep: std::collections::HashSet<usize> =
377 sorted_indices[..cutoff_idx].iter().copied().collect();
378 for (i, p) in probs.iter_mut().enumerate() {
379 if !keep.contains(&i) {
380 *p = 0.0;
381 }
382 }
383
384 let sum: f32 = probs.iter().sum();
386 if sum > 0.0 {
387 for p in &mut probs {
388 *p /= sum;
389 }
390 }
391 }
392
393 let r: f32 = rand_f32();
395 let mut cumsum = 0.0f32;
396 for (i, &p) in probs.iter().enumerate() {
397 cumsum += p;
398 if cumsum >= r {
399 return Ok(i as u32);
400 }
401 }
402
403 Ok(probs
405 .iter()
406 .enumerate()
407 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
408 .map(|(i, _)| i as u32)
409 .unwrap_or(0))
410}
411
/// Uniform random f32 in [0, 1), used as the CDF draw in `sample_token`.
fn rand_f32() -> f32 {
    rand::random::<f32>()
}