use anyhow::Result;
use candle::{DType, Tensor};
use candle_core as candle;
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama as model;
use hf_hub::{Repo, RepoType, api::sync::Api};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::fs;
use std::path::PathBuf;
use std::sync::{Arc, OnceLock, RwLock};
use std::{
env,
time::{SystemTime, UNIX_EPOCH},
};
use tokenizers::Tokenizer;
const EOS_TOKEN: &str = "</s>";
pub fn build_fallback_prompt(sys: &str, user: &str) -> String {
if sys.trim().is_empty() {
user.to_string()
} else {
format!("<|system|>\n{}\n<|user|>\n{}", sys, user)
}
}
pub fn build_chat_messages(sys: &str, user: &str) -> Vec<Value> {
let mut messages = Vec::new();
if !sys.trim().is_empty() {
messages.push(json!({"role":"system","content": sys}));
}
messages.push(json!({"role":"user","content": user}));
messages
}
#[derive(Clone, Debug)]
pub struct CandleRunParams {
pub model_id: Option<String>, pub revision: Option<String>, pub cpu: bool, pub sample_len: usize, pub min_tokens: usize, pub temperature: f32,
pub top_p: Option<f32>,
pub top_k: Option<usize>,
pub repeat_penalty: f32,
pub repeat_last_n: usize,
pub seed: Option<u64>,
}
impl Default for CandleRunParams {
fn default() -> Self {
Self {
model_id: None,
revision: Some("main".into()),
cpu: true,
sample_len: 128,
min_tokens: 0,
temperature: 0.7,
top_p: Some(0.95),
top_k: None,
repeat_penalty: 1.1,
repeat_last_n: 128,
seed: None,
}
}
}
struct CandleEngine {
device: candle::Device,
dtype: DType,
llama: model::Llama,
config: model::Config,
tokenizer: Tokenizer,
eos_token_id: Option<model::LlamaEosToks>,
model_id: String,
revision: String,
}
static ENGINE: OnceLock<Arc<CandleEngine>> = OnceLock::new();
static LOGIT_BIAS_STORE: OnceLock<RwLock<Option<Vec<f32>>>> = OnceLock::new();
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TrainExample {
#[serde(default)]
pub system: Option<String>,
pub user: String,
pub assistant: String,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TrainParams {
#[serde(default)]
pub learning_rate: Option<f32>,
#[serde(default)]
pub epochs: Option<u32>,
#[serde(default)]
pub max_examples: Option<usize>,
#[serde(default)]
pub bias_cap: Option<f32>,
#[serde(default)]
pub topk_updates: Option<usize>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TrainResult {
pub adapter_path: String,
pub epochs: u32,
pub examples: usize,
pub vocab: usize,
}
#[derive(Debug, Serialize, Deserialize)]
struct LogitBiasFile {
vocab: usize,
bias: Vec<f32>,
created_at: u64,
}
fn bias_store() -> &'static RwLock<Option<Vec<f32>>> {
LOGIT_BIAS_STORE.get_or_init(|| RwLock::new(None))
}
fn adapter_dir() -> PathBuf {
env::var("ADAPTER_DIR")
.map(PathBuf::from)
.unwrap_or_else(|_| PathBuf::from("/models/adapters"))
}
fn load_active_logit_bias() -> Option<Vec<f32>> {
let path = adapter_dir().join("active_logit_bias.json");
if !path.exists() {
return None;
}
match fs::read_to_string(&path) {
Ok(s) => match serde_json::from_str::<LogitBiasFile>(&s) {
Ok(f) => Some(f.bias),
Err(_) => None,
},
Err(_) => None,
}
}
fn persist_logit_bias(bias: &[f32]) -> Result<PathBuf> {
let dir = adapter_dir();
fs::create_dir_all(&dir).ok();
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let file = LogitBiasFile {
vocab: bias.len(),
bias: bias.to_vec(),
created_at: now,
};
let active_path = dir.join("active_logit_bias.json");
let named_path = dir.join(format!("logit_bias_{}.json", now));
let data = serde_json::to_string_pretty(&file)?;
fs::write(&named_path, &data)?;
fs::write(&active_path, &data)?;
Ok(active_path)
}
impl Default for TrainParams {
fn default() -> Self {
Self {
learning_rate: Some(0.05),
epochs: Some(1),
max_examples: None,
bias_cap: Some(2.0),
topk_updates: Some(64),
}
}
}
fn ensure_engine(params: &CandleRunParams) -> Result<Arc<CandleEngine>> {
if let Some(engine) = ENGINE.get() {
return Ok(engine.clone());
}
let device = candle_examples::device(params.cpu)?;
let dtype = DType::F16;
let model_id = params
.model_id
.clone()
.unwrap_or_else(|| "HuggingFaceTB/SmolLM2-1.7B-Instruct".to_string());
let revision = params.revision.clone().unwrap_or_else(|| "main".into());
let api = Api::new()?;
let api = api.repo(Repo::with_revision(
model_id.clone(),
RepoType::Model,
revision.clone(),
));
let tokenizer_filename = api.get("tokenizer.json")?;
let config_filename = api.get("config.json")?;
let llama_cfg: model::LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let config = llama_cfg.into_config(false);
let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")
.unwrap_or_else(|_| vec![api.get("model.safetensors").expect("weights")]);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let llama = model::Llama::load(vb, &config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename.clone()).map_err(anyhow::Error::msg)?;
let eos_token_id = config.eos_token_id.clone().or_else(|| {
tokenizer
.token_to_id(EOS_TOKEN)
.map(model::LlamaEosToks::Single)
});
let engine = Arc::new(CandleEngine {
device,
dtype,
llama,
config,
tokenizer,
eos_token_id,
model_id,
revision,
});
let _ = bias_store();
{
let mut w = bias_store().write().unwrap();
if w.is_none() {
if let Some(b) = load_active_logit_bias() {
*w = Some(b);
}
}
}
let _ = ENGINE.set(engine.clone());
Ok(engine)
}
pub fn preload_local_candle(params: &CandleRunParams) -> Result<()> {
let _ = ensure_engine(params)?;
Ok(())
}
pub fn generate_local_candle(
sys: &str,
user: &str,
stop: Option<Vec<String>>,
params: &CandleRunParams,
) -> Result<String> {
let engine = ensure_engine(params)?;
let mut cache = model::Cache::new(true, engine.dtype, &engine.config, &engine.device)?;
let llama = &engine.llama;
let mut final_prompt = String::new();
if !sys.trim().is_empty() {
final_prompt.push_str("System: ");
final_prompt.push_str(sys);
final_prompt.push_str("\n");
}
final_prompt.push_str("User: ");
final_prompt.push_str(user);
final_prompt.push_str("\nAssistant: ");
let mut tokens = engine
.tokenizer
.clone()
.encode(final_prompt.as_str(), true)
.map_err(anyhow::Error::msg)?
.get_ids()
.to_vec();
let mut tok_stream = {
let t = engine.tokenizer.clone();
candle_examples::token_output_stream::TokenOutputStream::new(t)
};
let eos_ids: Option<Vec<u32>> = match engine.eos_token_id.clone() {
Some(model::LlamaEosToks::Single(id)) => Some(vec![id]),
Some(model::LlamaEosToks::Multiple(ids)) => Some(ids),
None => None,
};
let t = params.temperature as f64;
let sampling = if params.temperature <= 0.0 {
Sampling::ArgMax
} else {
match (params.top_k, params.top_p) {
(None, None) => Sampling::All { temperature: t },
(Some(k), None) => Sampling::TopK { k, temperature: t },
(None, Some(p)) => Sampling::TopP {
p: p as f64,
temperature: t,
},
(Some(k), Some(p)) => Sampling::TopKThenTopP {
k,
p: p as f64,
temperature: t,
},
}
};
let mut logits_processor = LogitsProcessor::from_sampling(params.seed.unwrap_or(42), sampling);
let mut index_pos = 0usize;
let mut generated = 0usize;
let mut out = String::new();
for index in 0..params.sample_len {
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
(1, index_pos)
} else {
(tokens.len(), 0)
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &engine.device)?.unsqueeze(0)?;
let logits = llama.forward(&input, context_index, &mut cache)?;
let logits = logits.squeeze(0)?;
let logits = if params.repeat_penalty == 1.0 {
logits
} else {
let start_at = tokens.len().saturating_sub(params.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
params.repeat_penalty,
&tokens[start_at..],
)?
};
index_pos += ctxt.len();
let logits = {
let r = bias_store().read().unwrap();
if let Some(bias) = &*r {
let mut data = logits.to_vec1::<f32>()?;
if data.len() == bias.len() {
for i in 0..data.len() {
data[i] += bias[i];
}
}
Tensor::new(&data[..], &engine.device)?
} else {
logits
}
};
let logits = if generated < params.min_tokens {
if let Some(ref ids) = eos_ids {
let mut data = logits.to_vec1::<f32>()?;
for id in ids {
let i = *id as usize;
if i < data.len() {
data[i] = f32::NEG_INFINITY;
}
}
Tensor::new(&data[..], &engine.device)?
} else {
logits
}
} else {
logits
};
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
generated += 1;
if let Some(ref ids) = eos_ids {
if generated >= params.min_tokens && ids.contains(&next_token) {
break;
}
}
if let Some(t) = tok_stream.next_token(next_token)? {
out.push_str(&t);
}
if let Some(stops) = &stop {
if stops.iter().any(|s| out.ends_with(s) || out.contains(s)) {
break;
}
}
}
if let Some(rest) = tok_stream.decode_rest().map_err(anyhow::Error::msg)? {
out.push_str(&rest);
}
Ok(out.trim().to_string())
}
pub fn train_logit_bias(
examples: &[TrainExample],
params: Option<TrainParams>,
run: &CandleRunParams,
) -> Result<TrainResult> {
let params = params.unwrap_or_default();
let lr = params.learning_rate.unwrap_or(0.05);
let epochs = params.epochs.unwrap_or(1).max(1);
let max_examples = params.max_examples;
let bias_cap = params.bias_cap.unwrap_or(2.0);
let topk = params.topk_updates;
let engine = ensure_engine(run)?;
let llama = &engine.llama;
let mut bias_guard = bias_store().write().unwrap();
if bias_guard.is_none() {
*bias_guard = load_active_logit_bias();
}
let mut used_examples = 0usize;
let mut bias_vec: Vec<f32> = Vec::new();
for _epoch in 0..epochs {
used_examples = 0;
let mut grad_accum: Option<Vec<f32>> = None;
'outer: for ex in examples.iter() {
if let Some(m) = max_examples {
if used_examples >= m {
break 'outer;
}
}
let sys = ex.system.as_deref().unwrap_or("");
let mut prefix = String::new();
if !sys.trim().is_empty() {
prefix.push_str("System: ");
prefix.push_str(sys);
prefix.push('\n');
}
prefix.push_str("User: ");
prefix.push_str(&ex.user);
prefix.push('\n');
prefix.push_str("Assistant: ");
let full = format!("{}{}", prefix, ex.assistant);
let prefix_ids = engine
.tokenizer
.clone()
.encode(prefix.as_str(), true)
.map_err(anyhow::Error::msg)?
.get_ids()
.to_vec();
let full_ids = engine
.tokenizer
.clone()
.encode(full.as_str(), true)
.map_err(anyhow::Error::msg)?
.get_ids()
.to_vec();
if full_ids.len() <= prefix_ids.len() + 1 {
continue;
}
let mut cache = model::Cache::new(true, engine.dtype, &engine.config, &engine.device)?;
let mut index_pos = 0usize;
for pos in 0..(full_ids.len() - 1) {
let (context_size, context_index) = if cache.use_kv_cache && pos > 0 {
(1, index_pos)
} else {
(pos + 1, 0)
};
let ctxt = &full_ids[(pos + 1).saturating_sub(context_size)..=pos];
let input = Tensor::new(ctxt, &engine.device)?.unsqueeze(0)?;
let logits = llama.forward(&input, context_index, &mut cache)?;
let logits = logits.squeeze(0)?;
index_pos += ctxt.len();
let mut logv = logits.to_vec1::<f32>()?;
if bias_vec.is_empty() {
let vocab = logv.len();
bias_vec = match &*bias_guard {
Some(b) if b.len() == vocab => b.clone(),
_ => vec![0.0; vocab],
};
}
if grad_accum.is_none() {
grad_accum = Some(vec![0.0; logv.len()]);
}
if logv.len() == bias_vec.len() {
for i in 0..logv.len() {
logv[i] += bias_vec[i];
}
}
if pos < prefix_ids.len() {
continue;
}
let target = full_ids[pos + 1] as usize;
if target >= logv.len() {
continue;
}
let mut maxv = f32::NEG_INFINITY;
for &v in &logv {
if v > maxv {
maxv = v;
}
}
let mut sum = 0.0f32;
for v in &mut logv {
*v = (*v - maxv).exp();
sum += *v;
}
if sum == 0.0 {
continue;
}
for v in &mut logv {
*v /= sum;
}
if let Some(ga) = grad_accum.as_mut() {
for i in 0..ga.len() {
ga[i] += logv[i];
}
ga[target] -= 1.0;
}
}
used_examples += 1;
}
if let Some(ga) = grad_accum {
if topk.unwrap_or(0) > 0 {
let k = topk.unwrap();
let mut idxs: Vec<usize> = (0..ga.len()).collect();
idxs.sort_unstable_by(|&a, &b| {
ga[b]
.abs()
.partial_cmp(&ga[a].abs())
.unwrap_or(std::cmp::Ordering::Equal)
});
for &i in idxs.iter().take(k) {
bias_vec[i] = (bias_vec[i] - lr * ga[i]).clamp(-bias_cap, bias_cap);
}
} else {
for i in 0..ga.len() {
bias_vec[i] = (bias_vec[i] - lr * ga[i]).clamp(-bias_cap, bias_cap);
}
}
}
}
let path = persist_logit_bias(&bias_vec)?;
*bias_guard = Some(bias_vec.clone());
drop(bias_guard);
Ok(TrainResult {
adapter_path: path.to_string_lossy().to_string(),
epochs,
examples: used_examples,
vocab: bias_vec.len(),
})
}