use std::collections::{HashMap, HashSet};
use std::io::{BufRead, BufReader};
use std::path::{Path, PathBuf};
use anyhow::{Context, Result};
use candle_core::{D, DType, Device, Tensor, Var};
use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder};
use tokenizers::Tokenizer;
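/// Configuration for a LoRA fine-tuning run: where the base model and JSONL
/// training data live, where to write the adapter, and the core hyperparameters
/// (LoRA rank/alpha, learning rate, epochs, maximum sequence length).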
pub struct FineTuneConfig {
pub model_dir: PathBuf,
pub data_path: PathBuf,
pub output_dir: PathBuf,
pub lora_rank: usize,
pub lora_alpha: f32,
pub learning_rate: f64,
pub epochs: usize,
pub max_seq_len: usize,
}
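/// Progress events reported over the channel while a fine-tuning run executes.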
pub enum FineTuneProgress {
Validating,
LoadingModel,
Tokenizing { done: usize, total: usize },
Training {
epoch: usize,
total_epochs: usize,
step: usize,
total_steps: usize,
loss: f32,
},
Saving,
Done { adapter_path: PathBuf },
Failed(String),
}
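/// Spawns a background OS thread that runs the fine-tuning loop and streams
/// progress over the provided unbounded channel; any error is forwarded as a
/// `FineTuneProgress::Failed` event rather than panicking.
///
/// A minimal consumer sketch (the `finetune` module path is an assumption and
/// may differ in your crate layout):
///
/// ```ignore
/// let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
/// finetune::start_finetune(config, tx);
/// while let Some(event) = rx.recv().await {
///     match event {
///         finetune::FineTuneProgress::Training { loss, .. } => println!("loss = {loss}"),
///         finetune::FineTuneProgress::Done { adapter_path } => println!("saved to {adapter_path:?}"),
///         finetune::FineTuneProgress::Failed(msg) => eprintln!("failed: {msg}"),
///         _ => {}
///     }
/// }
/// ```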
pub fn start_finetune(
config: FineTuneConfig,
tx: tokio::sync::mpsc::UnboundedSender<FineTuneProgress>,
) {
std::thread::spawn(move || {
if let Err(e) = run_finetune(&config, &tx) {
eprintln!("[finetune] error: {e:#}");
let _ = tx.send(FineTuneProgress::Failed(format!("{e:#}")));
}
});
}
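/// Runs the whole pipeline synchronously: validate inputs, load the frozen base
/// weights, tokenize the JSONL data, train the LoRA parameters with AdamW, and
/// save the resulting adapter to `config.output_dir`.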
fn run_finetune(
config: &FineTuneConfig,
tx: &tokio::sync::mpsc::UnboundedSender<FineTuneProgress>,
) -> Result<()> {
let _ = tx.send(FineTuneProgress::Validating);
let model_dir = &config.model_dir;
let config_json_path = model_dir.join("config.json");
if !config_json_path.exists() {
anyhow::bail!(
"model config not found at {:?}; is this a valid HuggingFace model directory?",
config_json_path
);
}
if !config.data_path.exists() {
anyhow::bail!("training data file not found at {:?}", config.data_path);
}
let config_text = std::fs::read_to_string(&config_json_path)
.with_context(|| format!("reading {:?}", config_json_path))?;
let model_cfg: ModelConfig = serde_json::from_str(&config_text)
.with_context(|| format!("parsing {:?}", config_json_path))?;
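    // Prefer the Metal backend on macOS and fall back to CPU; other platforms use the CPU.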
#[cfg(target_os = "macos")]
let device = Device::new_metal(0).unwrap_or(Device::Cpu);
#[cfg(not(target_os = "macos"))]
let device = Device::Cpu;
let _ = tx.send(FineTuneProgress::LoadingModel);
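    // Sharded checkpoints ship a `model.safetensors.index.json` mapping each tensor
    // name to the shard file that contains it; load every referenced shard once.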
let index_path = model_dir.join("model.safetensors.index.json");
let vb = if index_path.exists() {
let index_text = std::fs::read_to_string(&index_path)
.with_context(|| format!("reading {:?}", index_path))?;
let index: IndexJson = serde_json::from_str(&index_text)
.with_context(|| format!("parsing {:?}", index_path))?;
let mut seen: HashSet<String> = HashSet::new();
let mut shard_paths: Vec<PathBuf> = Vec::new();
for filename in index.weight_map.values() {
if seen.insert(filename.clone()) {
shard_paths.push(model_dir.join(filename));
}
}
shard_paths.sort();
unsafe {
VarBuilder::from_mmaped_safetensors(&shard_paths, DType::F32, &device)
.context("loading sharded safetensors")?
}
} else {
let single = model_dir.join("model.safetensors");
unsafe {
VarBuilder::from_mmaped_safetensors(&[single], DType::F32, &device)
.context("loading model.safetensors")?
}
};
    let model = LoraQwenModel::load(
        vb,
        &model_cfg,
        config.lora_rank,
        config.lora_alpha,
        config.max_seq_len,
        &device,
    )
    .context("building LoRA model")?;
let tokenizer_path = model_dir.join("tokenizer.json");
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("loading tokenizer from {:?}: {e}", tokenizer_path))?;
let file = std::fs::File::open(&config.data_path)
.with_context(|| format!("opening {:?}", config.data_path))?;
let reader = BufReader::new(file);
let raw_lines: Vec<String> = reader
.lines()
.collect::<std::io::Result<Vec<_>>>()
.context("reading training data")?;
let total_lines = raw_lines.len();
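    // Tokenize each JSONL line into one training example, truncating to
    // `max_seq_len` and skipping lines that are empty or yield fewer than two tokens.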
let mut token_batches: Vec<Vec<u32>> = Vec::new();
for (i, line) in raw_lines.iter().enumerate() {
let line = line.trim();
if line.is_empty() {
let _ = tx.send(FineTuneProgress::Tokenizing {
done: i + 1,
total: total_lines,
});
continue;
}
let entry: DataEntry =
serde_json::from_str(line).with_context(|| format!("parsing JSONL line {}", i + 1))?;
let encoding = tokenizer
.encode(entry.text.as_str(), false)
.map_err(|e| anyhow::anyhow!("tokenizing line {}: {e}", i + 1))?;
let ids = encoding.get_ids();
let ids: Vec<u32> = if ids.len() > config.max_seq_len {
ids[..config.max_seq_len].to_vec()
} else {
ids.to_vec()
};
if ids.len() >= 2 {
token_batches.push(ids);
}
let _ = tx.send(FineTuneProgress::Tokenizing {
done: i + 1,
total: total_lines,
});
}
if token_batches.is_empty() {
anyhow::bail!("no valid training examples found in {:?}", config.data_path);
}
let lora_vars = model.lora_vars();
let adamw_params = ParamsAdamW {
lr: config.learning_rate,
..ParamsAdamW::default()
};
let mut optimizer =
AdamW::new(lora_vars.clone(), adamw_params).context("creating AdamW optimizer")?;
let total_steps = token_batches.len();
let vocab_size = model_cfg.vocab_size;
let max_grad_norm: f64 = 1.0;
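    // Train with batch size 1: every tokenized example is a single optimizer step.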
for epoch in 0..config.epochs {
for (step, batch_tokens) in token_batches.iter().enumerate() {
let seq_len = batch_tokens.len();
let input_ids = Tensor::from_slice(batch_tokens.as_slice(), (1, seq_len), &device)
.context("building input_ids tensor")?;
let logits = model.forward(&input_ids).context("model forward pass")?;
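            // Next-token prediction: logits at positions 0..seq-1 are scored
            // against the tokens at positions 1..seq.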
let logits_shifted = logits
.narrow(1, 0, seq_len - 1)
.and_then(|t| t.reshape((seq_len - 1, vocab_size)))
.context("preparing shifted logits")?;
let targets = Tensor::from_slice(&batch_tokens[1..], (seq_len - 1,), &device)
.context("building target tensor")?;
let loss = candle_nn::loss::cross_entropy(&logits_shifted, &targets)
.context("computing cross-entropy loss")?;
let loss_val = loss.to_scalar::<f32>().context("reading loss scalar")?;
if loss_val.is_nan() || loss_val.is_infinite() {
eprintln!(
"[finetune] WARNING: loss is {} at epoch {} step {}, skipping",
loss_val,
epoch + 1,
step + 1
);
let _ = tx.send(FineTuneProgress::Training {
epoch: epoch + 1,
total_epochs: config.epochs,
step: step + 1,
total_steps,
loss: loss_val,
});
continue;
}
let grads = loss.backward().context("backward pass")?;
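            // Global-norm gradient clipping over the LoRA parameters: if the total
            // L2 norm exceeds `max_grad_norm`, rescale every gradient by the same factor.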
let mut total_norm_sq: f64 = 0.0;
for var in &lora_vars {
if let Some(grad) = grads.get(var.as_tensor()) {
let norm_sq = grad
.sqr()
.and_then(|t| t.sum_all())
.and_then(|t| t.to_scalar::<f32>())
.unwrap_or(0.0) as f64;
total_norm_sq += norm_sq;
}
}
let total_norm = total_norm_sq.sqrt();
let grads = if total_norm > max_grad_norm {
let clip_coef = max_grad_norm / (total_norm + 1e-6);
let mut clipped_grads = grads;
for var in &lora_vars {
if let Some(grad) = clipped_grads.get(var.as_tensor()) {
let clipped = (grad * clip_coef).context("clipping gradient")?;
clipped_grads.insert(var.as_tensor(), clipped);
}
}
clipped_grads
} else {
grads
};
optimizer.step(&grads).context("optimizer step")?;
let _ = tx.send(FineTuneProgress::Training {
epoch: epoch + 1,
total_epochs: config.epochs,
step: step + 1,
total_steps,
loss: loss_val,
});
}
}
let _ = tx.send(FineTuneProgress::Saving);
let adapter_path = model
.save_lora(&config.output_dir)
.context("saving LoRA adapter")?;
let _ = tx.send(FineTuneProgress::Done { adapter_path });
Ok(())
}
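/// The subset of HuggingFace `config.json` fields needed to rebuild the model.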
#[derive(serde::Deserialize)]
#[allow(dead_code)]
struct ModelConfig {
hidden_size: usize,
num_hidden_layers: usize,
num_attention_heads: usize,
num_key_value_heads: Option<usize>,
head_dim: Option<usize>,
intermediate_size: usize,
vocab_size: usize,
rms_norm_eps: Option<f64>,
rope_theta: Option<f64>,
#[serde(default)]
tie_word_embeddings: bool,
}
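/// One JSONL training example of the form `{"text": "..."}`.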
#[derive(serde::Deserialize)]
struct DataEntry {
text: String,
}
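/// Layout of `model.safetensors.index.json`: tensor name -> shard filename.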
#[derive(serde::Deserialize)]
struct IndexJson {
weight_map: HashMap<String, String>,
}
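/// RMSNorm: x / sqrt(mean(x^2) + eps), scaled element-wise by `weight`.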
fn rms_norm(x: &Tensor, weight: &Tensor, eps: f64) -> candle_core::Result<Tensor> {
let mean_sq = x.sqr()?.mean_keepdim(D::Minus1)?;
let normed = x.broadcast_div(&mean_sq.affine(1.0, eps)?.sqrt()?)?;
normed.broadcast_mul(weight)
}
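/// Applies rotary position embeddings in the rotate-half formulation:
/// out = x * cos + rotate_half(x) * sin, with `cos`/`sin` of shape (seq, head_dim).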
fn apply_rope(x: &Tensor, cos: &Tensor, sin: &Tensor) -> candle_core::Result<Tensor> {
let head_dim = x.dim(D::Minus1)?;
let half = head_dim / 2;
    let x1 = x.narrow(D::Minus1, 0, half)?;
    let x2 = x.narrow(D::Minus1, half, half)?;
let neg_x2 = x2.neg()?;
let rotated = Tensor::cat(&[&neg_x2, &x1], D::Minus1)?;
let cos = cos.unsqueeze(0)?.unsqueeze(0)?;
let sin = sin.unsqueeze(0)?.unsqueeze(0)?;
x.broadcast_mul(&cos)? + rotated.broadcast_mul(&sin)?
}
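/// Precomputes RoPE cos/sin tables of shape (max_seq_len, head_dim) from the
/// inverse frequencies 1 / theta^(2i / head_dim) for i in 0..head_dim/2.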
fn precompute_rope(
head_dim: usize,
max_seq_len: usize,
theta: f64,
device: &Device,
) -> candle_core::Result<(Tensor, Tensor)> {
let half = head_dim / 2;
let freqs: Vec<f32> = (0..half)
.map(|i| 1.0f32 / (theta as f32).powf(2.0 * i as f32 / head_dim as f32))
.collect();
let freqs = Tensor::from_slice(freqs.as_slice(), (1, half), device)?;
let pos: Vec<f32> = (0..max_seq_len).map(|i| i as f32).collect();
let pos = Tensor::from_slice(pos.as_slice(), (max_seq_len, 1), device)?;
let freqs = pos.broadcast_mul(&freqs)?;
let cos = Tensor::cat(&[&freqs.cos()?, &freqs.cos()?], 1)?;
let sin = Tensor::cat(&[&freqs.sin()?, &freqs.sin()?], 1)?;
Ok((cos, sin))
}
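/// A frozen linear layer with a trainable low-rank update:
/// y = W x + (alpha / rank) * B A x, where only A (`lora_a`) and B (`lora_b`)
/// receive gradients.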
struct LoraLinear {
weight: Tensor,
bias: Option<Tensor>,
lora_a: Var,
lora_b: Var,
scale: f64,
}
impl LoraLinear {
fn new(
weight: Tensor,
bias: Option<Tensor>,
rank: usize,
alpha: f32,
device: &Device,
) -> candle_core::Result<Self> {
let shape = weight.shape().dims();
let out_features = shape[0];
let in_features = shape[1];
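        // Standard LoRA initialization: A is Gaussian, B is zero, so the adapter
        // contributes nothing until training updates it.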
let std_val = (1.0 / (rank as f64).sqrt()) as f32;
let lora_a = Var::randn(0.0f32, std_val, (rank, in_features), device)?;
let lora_b = Var::zeros((out_features, rank), DType::F32, device)?;
let scale = alpha as f64 / rank as f64;
Ok(Self {
weight,
bias,
lora_a,
lora_b,
scale,
})
}
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
let (b, seq, in_f) = x.dims3()?;
let x_flat = x.reshape((b * seq, in_f))?;
let base = x_flat.matmul(&self.weight.t()?)?;
let lora = x_flat
.matmul(&self.lora_a.as_tensor().t()?)?
.matmul(&self.lora_b.as_tensor().t()?)?
.affine(self.scale, 0.0)?;
let out_f = self.weight.dim(0)?;
let combined = (base + lora)?.reshape((b, seq, out_f))?;
match &self.bias {
Some(bias) => combined.broadcast_add(bias),
None => Ok(combined),
}
}
fn vars(&self) -> Vec<Var> {
vec![self.lora_a.clone(), self.lora_b.clone()]
}
}
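/// A plain linear layer whose weights stay frozen during fine-tuning.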
struct FrozenLinear {
weight: Tensor,
bias: Option<Tensor>,
}
impl FrozenLinear {
fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
let (b, seq, in_f) = x.dims3()?;
let x_flat = x.reshape((b * seq, in_f))?;
let out_flat = x_flat.matmul(&self.weight.t()?)?;
let out_f = self.weight.dim(0)?;
let out = out_flat.reshape((b, seq, out_f))?;
match &self.bias {
Some(bias) => out.broadcast_add(bias),
None => Ok(out),
}
}
}
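/// One decoder block: RMSNorm -> self-attention (RoPE, grouped-query attention)
/// -> residual add -> RMSNorm -> SwiGLU MLP -> residual add. Only q_proj and
/// v_proj carry LoRA adapters; the remaining projections stay frozen.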
struct TransformerLayer {
    q_proj: LoraLinear,
    k_proj: FrozenLinear,
    v_proj: LoraLinear,
    o_proj: FrozenLinear,
    q_norm: Option<Tensor>,
    k_norm: Option<Tensor>,
    gate_proj: FrozenLinear,
up_proj: FrozenLinear,
down_proj: FrozenLinear,
    input_layernorm: Tensor,
    post_attention_layernorm: Tensor,
    num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
rms_norm_eps: f64,
}
impl TransformerLayer {
fn forward(&self, x: &Tensor, cos: &Tensor, sin: &Tensor) -> candle_core::Result<Tensor> {
let (b, seq, _hidden) = x.dims3()?;
let normed = rms_norm(x, &self.input_layernorm, self.rms_norm_eps)?;
        let q = self.q_proj.forward(&normed)?;
        let k = self.k_proj.forward(&normed)?;
        let v = self.v_proj.forward(&normed)?;
let q = q
.reshape((b, seq, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b, seq, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b, seq, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let q = if let Some(w) = &self.q_norm {
rms_norm(&q, w, self.rms_norm_eps)?
} else {
q
};
let k = if let Some(w) = &self.k_norm {
rms_norm(&k, w, self.rms_norm_eps)?
} else {
k
};
let q = apply_rope(&q.contiguous()?, cos, sin)?;
let k = apply_rope(&k.contiguous()?, cos, sin)?;
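        // Grouped-query attention: repeat each key/value head `groups` times so
        // K and V line up with the number of query heads.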
let groups = self.num_heads / self.num_kv_heads;
let (k, v) = if groups == 1 {
(k.contiguous()?, v.contiguous()?)
} else {
let k_exp = k
.unsqueeze(2)?
.expand((b, self.num_kv_heads, groups, seq, self.head_dim))?
.reshape((b, self.num_heads, seq, self.head_dim))?
.contiguous()?;
let v_exp = v
.unsqueeze(2)?
.expand((b, self.num_kv_heads, groups, seq, self.head_dim))?
.reshape((b, self.num_heads, seq, self.head_dim))?
.contiguous()?;
(k_exp, v_exp)
};
let scale = 1.0 / (self.head_dim as f64).sqrt();
let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
let device = x.device();
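        // Causal mask: positions above the diagonal receive a large negative bias
        // before the softmax, so each token attends only to itself and earlier tokens.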
let tril =
Tensor::tril2(seq, DType::F32, device)?.broadcast_as((b, self.num_heads, seq, seq))?;
        let attn_bias = tril.affine(-1.0, 1.0)?.affine(-1e9, 0.0)?;
let attn_weights = (attn_weights + attn_bias)?;
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
let attn_out = attn_weights.matmul(&v)?;
let attn_out = attn_out.transpose(1, 2)?.contiguous()?.reshape((
b,
seq,
self.num_heads * self.head_dim,
))?;
let attn_out = self.o_proj.forward(&attn_out)?;
let h = (x + &attn_out)?;
let normed2 = rms_norm(&h, &self.post_attention_layernorm, self.rms_norm_eps)?;
let gate = candle_nn::ops::silu(&self.gate_proj.forward(&normed2)?)?;
let up = self.up_proj.forward(&normed2)?;
let mlp_out = self.down_proj.forward(&(gate * up)?)?;
h + mlp_out
}
fn lora_vars(&self) -> Vec<Var> {
let mut v = self.q_proj.vars();
v.extend(self.v_proj.vars());
v
}
}
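/// A Qwen-style decoder-only transformer with frozen base weights and LoRA
/// adapters on the attention q/v projections, plus cached RoPE tables.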
struct LoraQwenModel {
    embed_tokens: Tensor,
    layers: Vec<TransformerLayer>,
    norm: Tensor,
    lm_head: Tensor,
    rope_cos: Tensor,
    rope_sin: Tensor,
    hidden_size: usize,
    vocab_size: usize,
rms_norm_eps: f64,
}
impl LoraQwenModel {
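    /// Loads the frozen base weights from `vb`, wraps the attention q/v
    /// projections in LoRA adapters, and precomputes RoPE tables sized to
    /// `max_seq_len` (the same limit used when tokenizing the training data).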
    fn load(
        vb: VarBuilder,
        cfg: &ModelConfig,
        rank: usize,
        alpha: f32,
        max_seq_len: usize,
        device: &Device,
    ) -> candle_core::Result<Self> {
        let hidden_size = cfg.hidden_size;
        let vocab_size = cfg.vocab_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads.unwrap_or(num_heads);
        let head_dim = cfg.head_dim.unwrap_or(hidden_size / num_heads);
        let intermediate_size = cfg.intermediate_size;
        let rms_norm_eps = cfg.rms_norm_eps.unwrap_or(1e-6);
        let rope_theta = cfg.rope_theta.unwrap_or(10_000.0);
let embed_tokens = vb
.pp("model")
.pp("embed_tokens")
.get((vocab_size, hidden_size), "weight")?;
let mut layers: Vec<TransformerLayer> = Vec::with_capacity(cfg.num_hidden_layers);
for i in 0..cfg.num_hidden_layers {
let vb_layer = vb.pp("model").pp("layers").pp(i.to_string());
let vb_attn = vb_layer.pp("self_attn");
let q_out = num_heads * head_dim;
let q_weight = vb_attn.pp("q_proj").get((q_out, hidden_size), "weight")?;
let q_bias = vb_attn.pp("q_proj").get((q_out,), "bias").ok();
let q_proj = LoraLinear::new(q_weight, q_bias, rank, alpha, device)?;
let kv_out = num_kv_heads * head_dim;
let k_weight = vb_attn.pp("k_proj").get((kv_out, hidden_size), "weight")?;
let k_bias = vb_attn.pp("k_proj").get((kv_out,), "bias").ok();
let k_proj = FrozenLinear {
weight: k_weight,
bias: k_bias,
};
let v_weight = vb_attn.pp("v_proj").get((kv_out, hidden_size), "weight")?;
let v_bias = vb_attn.pp("v_proj").get((kv_out,), "bias").ok();
let v_proj = LoraLinear::new(v_weight, v_bias, rank, alpha, device)?;
let o_weight = vb_attn
.pp("o_proj")
.get((hidden_size, num_heads * head_dim), "weight")?;
let o_proj = FrozenLinear {
weight: o_weight,
bias: None,
};
let q_norm = vb_attn.pp("q_norm").get((head_dim,), "weight").ok();
let k_norm = vb_attn.pp("k_norm").get((head_dim,), "weight").ok();
let vb_mlp = vb_layer.pp("mlp");
let gate_weight = vb_mlp
.pp("gate_proj")
.get((intermediate_size, hidden_size), "weight")?;
let up_weight = vb_mlp
.pp("up_proj")
.get((intermediate_size, hidden_size), "weight")?;
let down_weight = vb_mlp
.pp("down_proj")
.get((hidden_size, intermediate_size), "weight")?;
let input_layernorm = vb_layer
.pp("input_layernorm")
.get((hidden_size,), "weight")?;
let post_attention_layernorm = vb_layer
.pp("post_attention_layernorm")
.get((hidden_size,), "weight")?;
layers.push(TransformerLayer {
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
gate_proj: FrozenLinear {
weight: gate_weight,
bias: None,
},
up_proj: FrozenLinear {
weight: up_weight,
bias: None,
},
down_proj: FrozenLinear {
weight: down_weight,
bias: None,
},
input_layernorm,
post_attention_layernorm,
num_heads,
num_kv_heads,
head_dim,
rms_norm_eps,
});
}
let norm = vb.pp("model").pp("norm").get((hidden_size,), "weight")?;
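        // Models with tied word embeddings may not ship a separate lm_head tensor;
        // fall back to reusing the embedding matrix as the output projection.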
let lm_head = vb
.pp("lm_head")
.get((vocab_size, hidden_size), "weight")
.unwrap_or_else(|_| embed_tokens.clone());
let (rope_cos, rope_sin) = precompute_rope(head_dim, max_seq_len, rope_theta, device)?;
Ok(Self {
embed_tokens,
layers,
norm,
lm_head,
rope_cos,
rope_sin,
hidden_size,
vocab_size,
rms_norm_eps,
})
}
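    /// Forward pass over a (batch, seq) tensor of token ids, returning logits of
    /// shape (batch, seq, vocab_size).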
fn forward(&self, input_ids: &Tensor) -> candle_core::Result<Tensor> {
let (b, seq) = input_ids.dims2()?;
        let ids_flat = input_ids.flatten_all()?;
        let mut hidden = self
            .embed_tokens
            .embedding(&ids_flat)?
            .reshape((b, seq, self.hidden_size))?;
        let cos = self.rope_cos.narrow(0, 0, seq)?;
        let sin = self.rope_sin.narrow(0, 0, seq)?;
for layer in &self.layers {
hidden = layer.forward(&hidden, &cos, &sin)?;
}
let hidden = rms_norm(&hidden, &self.norm, self.rms_norm_eps)?;
let hidden_flat = hidden.reshape((b * seq, self.hidden_size))?;
let logits_flat = hidden_flat.matmul(&self.lm_head.t()?)?;
logits_flat.reshape((b, seq, self.vocab_size))
}
fn lora_vars(&self) -> Vec<Var> {
self.layers
.iter()
.flat_map(|layer| layer.lora_vars())
.collect()
}
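    /// Saves only the LoRA A/B matrices to `lora_adapter.safetensors` inside
    /// `output_dir`, keyed by the name of the projection they adapt.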
fn save_lora(&self, output_dir: &Path) -> Result<PathBuf> {
std::fs::create_dir_all(output_dir)
.with_context(|| format!("creating output directory {:?}", output_dir))?;
let mut tensors: HashMap<String, Tensor> = HashMap::new();
for (i, layer) in self.layers.iter().enumerate() {
tensors.insert(
format!("model.layers.{i}.self_attn.q_proj.lora_a"),
layer.q_proj.lora_a.as_tensor().clone(),
);
tensors.insert(
format!("model.layers.{i}.self_attn.q_proj.lora_b"),
layer.q_proj.lora_b.as_tensor().clone(),
);
tensors.insert(
format!("model.layers.{i}.self_attn.v_proj.lora_a"),
layer.v_proj.lora_a.as_tensor().clone(),
);
tensors.insert(
format!("model.layers.{i}.self_attn.v_proj.lora_b"),
layer.v_proj.lora_b.as_tensor().clone(),
);
}
let path = output_dir.join("lora_adapter.safetensors");
candle_core::safetensors::save(&tensors, &path)
.with_context(|| format!("writing {:?}", path))?;
Ok(path)
}
}