1use anyhow::Result;
2use candle::{DType, Tensor};
3use candle_core as candle;
4use candle_nn::VarBuilder;
5use candle_transformers::generation::{LogitsProcessor, Sampling};
6use candle_transformers::models::llama as model;
7use hf_hub::{Repo, RepoType, api::sync::Api};
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10use std::fs;
11use std::path::PathBuf;
12use std::sync::{Arc, OnceLock, RwLock};
13use std::{
14 env,
15 time::{SystemTime, UNIX_EPOCH},
16};
17use tokenizers::Tokenizer;
18
/// Literal end-of-sequence token text. `ensure_engine` looks it up in the
/// tokenizer only when the model config does not declare an EOS token id.
const EOS_TOKEN: &str = "</s>";
20
/// Builds a plain-text prompt from a system message and a user message.
///
/// When `sys` is blank (empty or whitespace only) the user text is returned
/// unchanged. Otherwise the two texts are wrapped in `<|system|>` and
/// `<|user|>` markers, each marker followed by a newline. `sys` is inserted
/// verbatim (it is only trimmed for the blank check).
pub fn build_fallback_prompt(sys: &str, user: &str) -> String {
    let has_system = !sys.trim().is_empty();
    if has_system {
        return ["<|system|>\n", sys, "\n<|user|>\n", user].concat();
    }
    String::from(user)
}
30
31pub fn build_chat_messages(sys: &str, user: &str) -> Vec<Value> {
33 let mut messages = Vec::new();
34 if !sys.trim().is_empty() {
35 messages.push(json!({"role":"system","content": sys}));
36 }
37 messages.push(json!({"role":"user","content": user}));
38 messages
39}
40
/// Settings for loading the local model and sampling from it.
#[derive(Debug, Clone)]
pub struct CandleRunParams {
    /// Hub repository id of the model; `ensure_engine` falls back to
    /// `HuggingFaceTB/SmolLM2-1.7B-Instruct` when this is `None`.
    pub model_id: Option<String>,
    /// Repository revision; `ensure_engine` falls back to `main` when `None`.
    pub revision: Option<String>,
    /// Passed to `candle_examples::device` to choose the compute device.
    pub cpu: bool,
    /// Upper bound on the number of tokens sampled per generation call.
    pub sample_len: usize,
    /// Number of tokens that must be produced before an EOS token may end
    /// the generation.
    pub min_tokens: usize,
    /// Sampling temperature; values `<= 0.0` select greedy (arg-max) decoding.
    pub temperature: f32,
    /// Optional nucleus (top-p) cutoff.
    pub top_p: Option<f32>,
    /// Optional top-k cutoff.
    pub top_k: Option<usize>,
    /// Repetition penalty; exactly `1.0` disables the penalty step.
    pub repeat_penalty: f32,
    /// How many of the most recent tokens the repetition penalty looks at.
    pub repeat_last_n: usize,
    /// Sampler seed; generation uses `42` when this is `None`.
    pub seed: Option<u64>,
}
56
57impl Default for CandleRunParams {
58 fn default() -> Self {
59 Self {
60 model_id: None,
61 revision: Some("main".into()),
62 cpu: true,
63 sample_len: 128,
64 min_tokens: 0,
65 temperature: 0.7,
66 top_p: Some(0.95),
67 top_k: None,
68 repeat_penalty: 1.1,
69 repeat_last_n: 128,
70 seed: None,
71 }
72 }
73}
74
/// Everything needed to run the model; built once by `ensure_engine` and
/// shared behind an `Arc`.
struct CandleEngine {
    /// Device returned by `candle_examples::device` for the requested `cpu` flag.
    device: candle::Device,
    /// Element type used when loading the weights and creating KV caches
    /// (`ensure_engine` sets this to `DType::F16`).
    dtype: DType,
    /// Loaded Llama model.
    llama: model::Llama,
    /// Model configuration parsed from the repository's `config.json`.
    config: model::Config,
    /// Tokenizer loaded from the repository's `tokenizer.json`.
    tokenizer: Tokenizer,
    /// EOS token id(s): taken from the config, else looked up from `EOS_TOKEN`
    /// in the tokenizer; `None` when neither source provides one.
    eos_token_id: Option<model::LlamaEosToks>,
    /// Hub repository id the engine was loaded from.
    model_id: String,
    /// Repository revision the engine was loaded from.
    revision: String,
}
86
/// Process-wide engine; populated by `ensure_engine` after the first
/// successful load and reused by every later call.
static ENGINE: OnceLock<Arc<CandleEngine>> = OnceLock::new();
/// Lazily created lock around the active logit-bias vector. The inner value
/// is `None` until a bias has been loaded from disk or produced by training.
static LOGIT_BIAS_STORE: OnceLock<RwLock<Option<Vec<f32>>>> = OnceLock::new();
89
/// One supervised example: a prompt (optional system text plus user text)
/// and the assistant reply the bias should make more likely.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TrainExample {
    /// Optional system text; may be omitted when deserializing. Blank values
    /// are left out of the training prompt.
    #[serde(default)]
    pub system: Option<String>,
    /// User text placed after `User: ` in the training prompt.
    pub user: String,
    /// Target reply appended after `Assistant: ` in the training prompt.
    pub assistant: String,
}
97
/// Optional knobs for `train_logit_bias`. Every field may be omitted when
/// deserializing, in which case it is `None` and the per-field fallback
/// described below applies.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TrainParams {
    /// Step size applied to the accumulated gradient; `0.05` when `None`.
    #[serde(default)]
    pub learning_rate: Option<f32>,
    /// Number of passes over the examples; `1` when `None`, and never less
    /// than `1`.
    #[serde(default)]
    pub epochs: Option<u32>,
    /// Maximum number of examples consumed per epoch; unlimited when `None`.
    #[serde(default)]
    pub max_examples: Option<usize>,
    /// Updated bias entries are clamped to `[-bias_cap, bias_cap]`; `2.0`
    /// when `None`.
    #[serde(default)]
    pub bias_cap: Option<f32>,
    /// When `Some(k)` with `k > 0`, only the `k` entries with the largest
    /// absolute gradient are updated each epoch; `None` or `0` updates all.
    #[serde(default)]
    pub topk_updates: Option<usize>,
}
111
/// Summary returned by `train_logit_bias`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TrainResult {
    /// Path of the active bias file that was written.
    pub adapter_path: String,
    /// Number of epochs that were run.
    pub epochs: u32,
    /// Number of examples consumed in the final epoch (the counter is reset
    /// at the start of every epoch).
    pub examples: usize,
    /// Length of the trained bias vector.
    pub vocab: usize,
}
119
/// On-disk JSON layout of a persisted logit bias.
#[derive(Debug, Serialize, Deserialize)]
struct LogitBiasFile {
    /// Length of `bias` at the time the file was written.
    vocab: usize,
    /// One additive logit offset per vocabulary entry.
    bias: Vec<f32>,
    /// Write time in seconds since the Unix epoch.
    created_at: u64,
}
126
127fn bias_store() -> &'static RwLock<Option<Vec<f32>>> {
128 LOGIT_BIAS_STORE.get_or_init(|| RwLock::new(None))
129}
130
/// Directory where logit-bias adapter files are stored.
///
/// Uses the `ADAPTER_DIR` environment variable when it can be read (set and
/// valid Unicode); otherwise falls back to `/models/adapters`.
fn adapter_dir() -> PathBuf {
    match env::var("ADAPTER_DIR") {
        Ok(configured) => PathBuf::from(configured),
        Err(_) => PathBuf::from("/models/adapters"),
    }
}
136
137fn load_active_logit_bias() -> Option<Vec<f32>> {
138 let path = adapter_dir().join("active_logit_bias.json");
139 if !path.exists() {
140 return None;
141 }
142 match fs::read_to_string(&path) {
143 Ok(s) => match serde_json::from_str::<LogitBiasFile>(&s) {
144 Ok(f) => Some(f.bias),
145 Err(_) => None,
146 },
147 Err(_) => None,
148 }
149}
150
151fn persist_logit_bias(bias: &[f32]) -> Result<PathBuf> {
152 let dir = adapter_dir();
153 fs::create_dir_all(&dir).ok();
154 let now = SystemTime::now()
155 .duration_since(UNIX_EPOCH)
156 .unwrap_or_default()
157 .as_secs();
158 let file = LogitBiasFile {
159 vocab: bias.len(),
160 bias: bias.to_vec(),
161 created_at: now,
162 };
163 let active_path = dir.join("active_logit_bias.json");
164 let named_path = dir.join(format!("logit_bias_{}.json", now));
165 let data = serde_json::to_string_pretty(&file)?;
166 fs::write(&named_path, &data)?;
167 fs::write(&active_path, &data)?;
168 Ok(active_path)
169}
170
171impl Default for TrainParams {
172 fn default() -> Self {
173 Self {
174 learning_rate: Some(0.05),
175 epochs: Some(1),
176 max_examples: None,
177 bias_cap: Some(2.0),
178 topk_updates: Some(64),
179 }
180 }
181}
182
183fn ensure_engine(params: &CandleRunParams) -> Result<Arc<CandleEngine>> {
184 if let Some(engine) = ENGINE.get() {
185 return Ok(engine.clone());
187 }
188
189 let device = candle_examples::device(params.cpu)?;
191 let dtype = DType::F16;
192
193 let model_id = params
195 .model_id
196 .clone()
197 .unwrap_or_else(|| "HuggingFaceTB/SmolLM2-1.7B-Instruct".to_string());
198 let revision = params.revision.clone().unwrap_or_else(|| "main".into());
199
200 let api = Api::new()?;
202 let api = api.repo(Repo::with_revision(
203 model_id.clone(),
204 RepoType::Model,
205 revision.clone(),
206 ));
207
208 let tokenizer_filename = api.get("tokenizer.json")?;
210 let config_filename = api.get("config.json")?;
211 let llama_cfg: model::LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
212 let config = llama_cfg.into_config(false);
214
215 let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")
217 .unwrap_or_else(|_| vec![api.get("model.safetensors").expect("weights")]);
218
219 let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
221 let llama = model::Llama::load(vb, &config)?;
222
223 let tokenizer = Tokenizer::from_file(tokenizer_filename.clone()).map_err(anyhow::Error::msg)?;
225 let eos_token_id = config.eos_token_id.clone().or_else(|| {
226 tokenizer
227 .token_to_id(EOS_TOKEN)
228 .map(model::LlamaEosToks::Single)
229 });
230
231 let engine = Arc::new(CandleEngine {
232 device,
233 dtype,
234 llama,
235 config,
236 tokenizer,
237 eos_token_id,
238 model_id,
239 revision,
240 });
241
242 let _ = bias_store();
244 {
245 let mut w = bias_store().write().unwrap();
246 if w.is_none() {
247 if let Some(b) = load_active_logit_bias() {
248 *w = Some(b);
249 }
250 }
251 }
252
253 let _ = ENGINE.set(engine.clone());
255 Ok(engine)
256}
257
258pub fn preload_local_candle(params: &CandleRunParams) -> Result<()> {
260 let _ = ensure_engine(params)?;
261 Ok(())
262}
263
264pub fn generate_local_candle(
267 sys: &str,
268 user: &str,
269 stop: Option<Vec<String>>,
270 params: &CandleRunParams,
271) -> Result<String> {
272 let engine = ensure_engine(params)?;
274
275 let mut cache = model::Cache::new(true, engine.dtype, &engine.config, &engine.device)?;
277 let llama = &engine.llama;
278
279 let mut final_prompt = String::new();
281 if !sys.trim().is_empty() {
282 final_prompt.push_str("System: ");
283 final_prompt.push_str(sys);
284 final_prompt.push_str("\n");
285 }
286 final_prompt.push_str("User: ");
287 final_prompt.push_str(user);
288 final_prompt.push_str("\nAssistant: ");
289
290 let mut tokens = engine
292 .tokenizer
293 .clone()
294 .encode(final_prompt.as_str(), true)
295 .map_err(anyhow::Error::msg)?
296 .get_ids()
297 .to_vec();
298 let mut tok_stream = {
299 let t = engine.tokenizer.clone();
300 candle_examples::token_output_stream::TokenOutputStream::new(t)
301 };
302
303 let eos_ids: Option<Vec<u32>> = match engine.eos_token_id.clone() {
305 Some(model::LlamaEosToks::Single(id)) => Some(vec![id]),
306 Some(model::LlamaEosToks::Multiple(ids)) => Some(ids),
307 None => None,
308 };
309
310 let t = params.temperature as f64;
312 let sampling = if params.temperature <= 0.0 {
313 Sampling::ArgMax
314 } else {
315 match (params.top_k, params.top_p) {
316 (None, None) => Sampling::All { temperature: t },
317 (Some(k), None) => Sampling::TopK { k, temperature: t },
318 (None, Some(p)) => Sampling::TopP {
319 p: p as f64,
320 temperature: t,
321 },
322 (Some(k), Some(p)) => Sampling::TopKThenTopP {
323 k,
324 p: p as f64,
325 temperature: t,
326 },
327 }
328 };
329 let mut logits_processor = LogitsProcessor::from_sampling(params.seed.unwrap_or(42), sampling);
330
331 let mut index_pos = 0usize;
333 let mut generated = 0usize;
334 let mut out = String::new();
335
336 for index in 0..params.sample_len {
337 let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
338 (1, index_pos)
339 } else {
340 (tokens.len(), 0)
341 };
342 let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
343 let input = Tensor::new(ctxt, &engine.device)?.unsqueeze(0)?;
344 let logits = llama.forward(&input, context_index, &mut cache)?;
345 let logits = logits.squeeze(0)?;
346 let logits = if params.repeat_penalty == 1.0 {
347 logits
348 } else {
349 let start_at = tokens.len().saturating_sub(params.repeat_last_n);
350 candle_transformers::utils::apply_repeat_penalty(
351 &logits,
352 params.repeat_penalty,
353 &tokens[start_at..],
354 )?
355 };
356 index_pos += ctxt.len();
357
358 let logits = {
360 let r = bias_store().read().unwrap();
361 if let Some(bias) = &*r {
362 let mut data = logits.to_vec1::<f32>()?;
363 if data.len() == bias.len() {
364 for i in 0..data.len() {
365 data[i] += bias[i];
366 }
367 }
368 Tensor::new(&data[..], &engine.device)?
369 } else {
370 logits
371 }
372 };
373
374 let logits = if generated < params.min_tokens {
375 if let Some(ref ids) = eos_ids {
376 let mut data = logits.to_vec1::<f32>()?;
377 for id in ids {
378 let i = *id as usize;
379 if i < data.len() {
380 data[i] = f32::NEG_INFINITY;
381 }
382 }
383 Tensor::new(&data[..], &engine.device)?
384 } else {
385 logits
386 }
387 } else {
388 logits
389 };
390
391 let next_token = logits_processor.sample(&logits)?;
392 tokens.push(next_token);
393 generated += 1;
394
395 if let Some(ref ids) = eos_ids {
396 if generated >= params.min_tokens && ids.contains(&next_token) {
397 break;
398 }
399 }
400
401 if let Some(t) = tok_stream.next_token(next_token)? {
402 out.push_str(&t);
403 }
404
405 if let Some(stops) = &stop {
406 if stops.iter().any(|s| out.ends_with(s) || out.contains(s)) {
407 break;
408 }
409 }
410 }
411
412 if let Some(rest) = tok_stream.decode_rest().map_err(anyhow::Error::msg)? {
413 out.push_str(&rest);
414 }
415
416 Ok(out.trim().to_string())
417}
418
419pub fn train_logit_bias(
424 examples: &[TrainExample],
425 params: Option<TrainParams>,
426 run: &CandleRunParams,
427) -> Result<TrainResult> {
428 let params = params.unwrap_or_default();
429 let lr = params.learning_rate.unwrap_or(0.05);
430 let epochs = params.epochs.unwrap_or(1).max(1);
431 let max_examples = params.max_examples;
432 let bias_cap = params.bias_cap.unwrap_or(2.0);
433 let topk = params.topk_updates;
434
435 let engine = ensure_engine(run)?;
437 let llama = &engine.llama;
438
439 let mut bias_guard = bias_store().write().unwrap();
441 if bias_guard.is_none() {
442 *bias_guard = load_active_logit_bias();
443 }
444
445 let mut used_examples = 0usize;
448 let mut bias_vec: Vec<f32> = Vec::new();
450
451 for _epoch in 0..epochs {
452 used_examples = 0;
453 let mut grad_accum: Option<Vec<f32>> = None;
455
456 'outer: for ex in examples.iter() {
457 if let Some(m) = max_examples {
458 if used_examples >= m {
459 break 'outer;
460 }
461 }
462
463 let sys = ex.system.as_deref().unwrap_or("");
465 let mut prefix = String::new();
466 if !sys.trim().is_empty() {
467 prefix.push_str("System: ");
468 prefix.push_str(sys);
469 prefix.push('\n');
470 }
471 prefix.push_str("User: ");
472 prefix.push_str(&ex.user);
473 prefix.push('\n');
474 prefix.push_str("Assistant: ");
475
476 let full = format!("{}{}", prefix, ex.assistant);
477
478 let prefix_ids = engine
480 .tokenizer
481 .clone()
482 .encode(prefix.as_str(), true)
483 .map_err(anyhow::Error::msg)?
484 .get_ids()
485 .to_vec();
486 let full_ids = engine
487 .tokenizer
488 .clone()
489 .encode(full.as_str(), true)
490 .map_err(anyhow::Error::msg)?
491 .get_ids()
492 .to_vec();
493
494 if full_ids.len() <= prefix_ids.len() + 1 {
495 continue;
496 }
497
498 let mut cache = model::Cache::new(true, engine.dtype, &engine.config, &engine.device)?;
500
501 let mut index_pos = 0usize;
503
504 for pos in 0..(full_ids.len() - 1) {
505 let (context_size, context_index) = if cache.use_kv_cache && pos > 0 {
507 (1, index_pos)
508 } else {
509 (pos + 1, 0)
510 };
511 let ctxt = &full_ids[(pos + 1).saturating_sub(context_size)..=pos];
512 let input = Tensor::new(ctxt, &engine.device)?.unsqueeze(0)?;
513 let logits = llama.forward(&input, context_index, &mut cache)?;
514 let logits = logits.squeeze(0)?;
515 index_pos += ctxt.len();
516
517 let mut logv = logits.to_vec1::<f32>()?;
519 if bias_vec.is_empty() {
520 let vocab = logv.len();
521 bias_vec = match &*bias_guard {
522 Some(b) if b.len() == vocab => b.clone(),
523 _ => vec![0.0; vocab],
524 };
525 }
526 if grad_accum.is_none() {
527 grad_accum = Some(vec![0.0; logv.len()]);
528 }
529
530 if logv.len() == bias_vec.len() {
532 for i in 0..logv.len() {
533 logv[i] += bias_vec[i];
534 }
535 }
536
537 if pos < prefix_ids.len() {
539 continue;
540 }
541 let target = full_ids[pos + 1] as usize;
542 if target >= logv.len() {
543 continue;
544 }
545
546 let mut maxv = f32::NEG_INFINITY;
549 for &v in &logv {
550 if v > maxv {
551 maxv = v;
552 }
553 }
554 let mut sum = 0.0f32;
555 for v in &mut logv {
556 *v = (*v - maxv).exp();
557 sum += *v;
558 }
559 if sum == 0.0 {
560 continue;
561 }
562 for v in &mut logv {
563 *v /= sum;
564 }
565
566 if let Some(ga) = grad_accum.as_mut() {
568 for i in 0..ga.len() {
569 ga[i] += logv[i];
570 }
571 ga[target] -= 1.0;
572 }
573 }
574
575 used_examples += 1;
576 }
577
578 if let Some(ga) = grad_accum {
580 if topk.unwrap_or(0) > 0 {
581 let k = topk.unwrap();
583 let mut idxs: Vec<usize> = (0..ga.len()).collect();
584 idxs.sort_unstable_by(|&a, &b| {
585 ga[b]
586 .abs()
587 .partial_cmp(&ga[a].abs())
588 .unwrap_or(std::cmp::Ordering::Equal)
589 });
590 for &i in idxs.iter().take(k) {
591 bias_vec[i] = (bias_vec[i] - lr * ga[i]).clamp(-bias_cap, bias_cap);
592 }
593 } else {
594 for i in 0..ga.len() {
595 bias_vec[i] = (bias_vec[i] - lr * ga[i]).clamp(-bias_cap, bias_cap);
596 }
597 }
598 }
599 }
600
601 let path = persist_logit_bias(&bias_vec)?;
603 *bias_guard = Some(bias_vec.clone());
604 drop(bias_guard);
605
606 Ok(TrainResult {
607 adapter_path: path.to_string_lossy().to_string(),
608 epochs,
609 examples: used_examples,
610 vocab: bias_vec.len(),
611 })
612}