use candle_core::{DType, Device, IndexOp, Tensor};
use candle_nn::{Module, VarBuilder, VarMap};
use candle_transformers::models::llama::{Cache, Llama, LlamaConfig, LlamaEosToks};
use std::collections::HashMap;
use std::path::{Path, PathBuf};

use crate::config::{AdapterType, AxolotlConfig};
use crate::error::{AxolotlError, Result};

#[cfg(feature = "peft")]
use peft_rs::{LoraConfig as PeftLoraConfig, LoraLayer, SaveLoad};

#[cfg(feature = "qlora")]
use qlora_rs::{QLoraConfig, QuantizedLinear};

#[cfg(feature = "peft")]
use crate::lora_llama::LoraLlama;

#[cfg(all(feature = "peft", feature = "qlora"))]
use super::qlora_llama::{prepare_for_qlora_training, QLoraLlama};

#[cfg(test)]
use crate::config::{DatasetConfig, LoraSettings, QuantType, QuantizationSettings, TrainingConfig};

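/// A model loaded for training: the module itself, its tokenizer, the device
/// and dtype it runs on, optional adapter layers, and the `VarMap` that holds
/// the trainable parameters.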
pub struct LoadedModel {
    pub model: Box<dyn Module>,
    pub tokenizer: tokenizers::Tokenizer,
    #[allow(dead_code)]
    pub device: Device,
    #[allow(dead_code)]
    pub dtype: DType,
    #[allow(dead_code)]
    pub adapter_layers: Option<AdapterLayers>,
    pub trainable_params: VarMap,
}

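/// Adapter layers keyed by module path (e.g. `model.layers.0.self_attn.q_proj`),
/// holding LoRA or quantized QLoRA layers depending on the enabled features.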
#[derive(Default)]
pub struct AdapterLayers {
    #[cfg(feature = "peft")]
    pub lora_layers: HashMap<String, LoraLayer>,
    #[cfg(feature = "qlora")]
    pub qlora_layers: HashMap<String, QuantizedLinear>,
    #[allow(dead_code)]
    pub is_quantized: bool,
}

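// Without the `peft` feature there are no LoRA layers; expose an empty map so
// callers can still iterate without feature gates of their own.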
#[cfg(not(feature = "peft"))]
#[allow(dead_code)]
impl AdapterLayers {
    pub fn lora_layers(&self) -> &HashMap<String, ()> {
        static EMPTY: std::sync::OnceLock<HashMap<String, ()>> = std::sync::OnceLock::new();
        EMPTY.get_or_init(HashMap::new)
    }
}

#[allow(dead_code)]
impl AdapterLayers {
    #[must_use]
    pub fn new(is_quantized: bool) -> Self {
        Self {
            #[cfg(feature = "peft")]
            lora_layers: HashMap::new(),
            #[cfg(feature = "qlora")]
            qlora_layers: HashMap::new(),
            is_quantized,
        }
    }

    #[must_use]
    pub fn len(&self) -> usize {
        #[cfg(feature = "qlora")]
        if self.is_quantized {
            return self.qlora_layers.len();
        }
        #[cfg(feature = "peft")]
        return self.lora_layers.len();
        #[cfg(not(feature = "peft"))]
        0
    }

    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl LoadedModel {
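    /// Run a forward pass through the underlying model.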
    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.model
            .forward(input_ids)
            .map_err(|e| AxolotlError::Model(format!("Forward pass failed: {}", e)))
    }

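    /// Forward pass meant to include adapter layers. For now this delegates to
    /// the base model only; LoRA outputs are not yet merged in here.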
    pub fn forward_with_adapters(&self, input_ids: &Tensor) -> Result<Tensor> {
        let logits = self.forward(input_ids)?;

        tracing::trace!("Forward pass complete (base model only, LoRA not integrated yet)");

        Ok(logits)
    }

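    /// All trainable variables tracked by the `VarMap`.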
    #[must_use]
    #[allow(dead_code)]
    pub fn trainable_tensors(&self) -> Vec<candle_core::Var> {
        self.trainable_params.all_vars()
    }

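    /// Total number of trainable scalar parameters.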
    #[must_use]
    #[allow(dead_code)]
    pub fn trainable_param_count(&self) -> usize {
        self.trainable_tensors()
            .iter()
            .map(|v| v.elem_count())
            .sum()
    }

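    /// Serialize all adapter layer weights to `adapter_model.safetensors`
    /// inside the given directory, creating it if necessary.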
    #[cfg(feature = "peft")]
    pub fn save_adapter_weights<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let adapter_layers = self
            .adapter_layers
            .as_ref()
            .ok_or_else(|| AxolotlError::Model("No adapter layers to save".into()))?;

        let dir = path.as_ref();
        std::fs::create_dir_all(dir)?;

        let mut all_tensors: Vec<(String, Tensor)> = Vec::new();

        for (name, layer) in &adapter_layers.lora_layers {
            if let Ok(state) = layer.state_dict() {
                for (key, tensor) in state {
                    all_tensors.push((format!("{}.{}", name, key), tensor));
                }
            }
        }

        let weights_path = dir.join("adapter_model.safetensors");
        let tensors_ref: Vec<(&str, Tensor)> = all_tensors
            .iter()
            .map(|(name, tensor)| (name.as_str(), tensor.clone()))
            .collect();

        safetensors::tensor::serialize_to_file(tensors_ref, &None, &weights_path).map_err(|e| {
            AxolotlError::Checkpoint(format!("Failed to save adapter: {}", e).into())
        })?;

        tracing::info!("Saved {} adapter layers to {:?}", adapter_layers.len(), dir);
        Ok(())
    }

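    /// Load adapter tensors from `adapter_model.safetensors` in the given
    /// directory onto the model's device. Applying the loaded tensors back
    /// onto the adapter layers is not done here yet.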
    #[cfg(feature = "peft")]
    pub fn load_adapter_weights<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
        let dir = path.as_ref();
        let weights_path = dir.join("adapter_model.safetensors");

        let tensors = candle_core::safetensors::load(&weights_path, &self.device).map_err(|e| {
            AxolotlError::Checkpoint(format!("Failed to load adapter: {}", e).into())
        })?;

        tracing::info!("Loaded {} adapter tensors from {:?}", tensors.len(), dir);

        Ok(())
    }

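    /// Snapshot the LoRA A/B weights of every adapter module for later
    /// comparison. The weight vectors are currently empty placeholders.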
    #[cfg(feature = "peft")]
    pub fn capture_lora_weights(
        &self,
    ) -> Result<std::collections::HashMap<String, (Vec<f32>, Vec<f32>)>> {
        use std::collections::HashMap;

        let mut weights = HashMap::new();

        if let Some(adapter_layers) = &self.adapter_layers {
            for (module_name, _lora_layer) in &adapter_layers.lora_layers {
                weights.insert(module_name.clone(), (Vec::new(), Vec::new()));
            }
        }

        Ok(weights)
    }

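    /// Compare the current LoRA weights against a previous snapshot and report
    /// whether any A or B matrix has changed.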
    #[cfg(feature = "peft")]
    pub fn verify_lora_weight_updates(
        &self,
        initial_weights: &std::collections::HashMap<String, (Vec<f32>, Vec<f32>)>,
    ) -> Result<bool> {
        if initial_weights.is_empty() {
            return Ok(false);
        }

        let current_weights = self.capture_lora_weights()?;

        for (module_name, (initial_a, initial_b)) in initial_weights {
            if let Some((current_a, current_b)) = current_weights.get(module_name) {
                let a_changed = if !initial_a.is_empty() && !current_a.is_empty() {
                    let diff: f64 = initial_a
                        .iter()
                        .zip(current_a.iter())
                        .map(|(i, c)| ((i - c) as f64).abs())
                        .sum();
                    diff > 0.0
                } else {
                    false
                };

                let b_changed = if !initial_b.is_empty() && !current_b.is_empty() {
                    let diff: f64 = initial_b
                        .iter()
                        .zip(current_b.iter())
                        .map(|(i, c)| ((i - c) as f64).abs())
                        .sum();
                    diff > 0.0
                } else {
                    false
                };

                if a_changed || b_changed {
                    tracing::debug!(
                        "LoRA weights updated in {}: A={}, B={}",
                        module_name,
                        a_changed,
                        b_changed
                    );
                    return Ok(true);
                }
            }
        }

        Ok(false)
    }
}

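/// Key dimensions read from a model's config, used to size adapter layers.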
#[derive(Debug, Clone)]
pub struct ModelInfo {
    pub hidden_size: usize,
    pub num_layers: usize,
    #[allow(dead_code)]
    pub num_attention_heads: usize,
    pub num_kv_heads: usize,
    #[allow(dead_code)]
    pub intermediate_size: usize,
}

impl ModelInfo {
    pub fn from_llama_config(config: &LlamaConfig) -> Self {
        Self {
            hidden_size: config.hidden_size,
            num_layers: config.num_hidden_layers,
            num_attention_heads: config.num_attention_heads,
            num_kv_heads: config
                .num_key_value_heads
                .unwrap_or(config.num_attention_heads),
            intermediate_size: config.intermediate_size,
        }
    }

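    /// Return `(in_features, out_features)` for a named projection. K and V
    /// projections are shrunk by the KV-head ratio to account for
    /// grouped-query attention; unknown names fall back to a square
    /// `hidden_size` projection.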
    #[allow(dead_code)]
    pub fn get_target_dims(&self, target: &str) -> (usize, usize) {
        match target {
            "q_proj" | "o_proj" => (self.hidden_size, self.hidden_size),
            "k_proj" | "v_proj" => {
                let kv_dim = self.hidden_size * self.num_kv_heads / self.num_attention_heads;
                (self.hidden_size, kv_dim)
            }
            "gate_proj" | "up_proj" => (self.hidden_size, self.intermediate_size),
            "down_proj" => (self.intermediate_size, self.hidden_size),
            _ => (self.hidden_size, self.hidden_size),
        }
    }

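    /// LLaMA 7B dimensions, used as a test fixture.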
    #[cfg(test)]
    pub fn default_7b() -> Self {
        Self {
            hidden_size: 4096,
            num_layers: 32,
            num_attention_heads: 32,
            num_kv_heads: 32,
            intermediate_size: 11008,
        }
    }
}

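/// Load the base model and tokenizer, then attach LoRA or QLoRA adapters
/// according to `config.adapter`, returning everything needed for training.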
pub fn load_model(config: &AxolotlConfig, device: &Device) -> Result<LoadedModel> {
    tracing::info!("Loading model: {}", config.base_model);

    let model_path = resolve_model_path(&config.base_model)?;

    let tokenizer = load_tokenizer(&model_path)?;
    tracing::info!(
        "Loaded tokenizer with vocab size: {}",
        tokenizer.get_vocab_size(true)
    );

    let model_info = load_model_info(&model_path)?;
    tracing::info!(
        "Model info: hidden_size={}, num_layers={}, kv_heads={}",
        model_info.hidden_size,
        model_info.num_layers,
        model_info.num_kv_heads
    );

    let dtype = DType::F32;

    if config.quantization.is_some() {
        tracing::info!("QLoRA mode: using F32 for model (quantization applied to weights)");
    }

    let trainable_params = VarMap::new();

    let use_lora_model = config.adapter == AdapterType::Lora;
    let use_qlora_model = config.adapter == AdapterType::Qlora;

    let (model, adapter_layers) = if use_qlora_model {
        #[cfg(all(feature = "peft", feature = "qlora"))]
        {
            let quant_settings = config.quantization.as_ref().ok_or_else(|| {
                AxolotlError::Config("QLoRA requires quantization settings".into())
            })?;

            let qlora_config = qlora_rs::QLoraConfig {
                lora: peft_rs::LoraConfig {
                    r: config.lora.r,
                    alpha: config.lora.alpha,
                    dropout: config.lora.dropout,
                    target_modules: config.lora.target_modules.clone(),
                    ..Default::default()
                },
                quantization: qlora_rs::QuantizationConfig {
                    block_size: quant_settings.block_size,
                    double_quant: quant_settings.double_quant,
                    compute_dtype: qlora_rs::quantization::ComputeDType::BF16,
                    ..Default::default()
                },
                target_modules: config.lora.target_modules.clone(),
                cache_dequantized: false,
            };

            let model = load_qlora_model(
                config,
                &model_path,
                device,
                dtype,
                &qlora_config,
                &trainable_params,
            )?;

            (model, None)
        }
        #[cfg(not(all(feature = "peft", feature = "qlora")))]
        {
            return Err(AxolotlError::Model(
                "QLoRA requested but peft and/or qlora features not enabled".into(),
            ));
        }
    } else if use_lora_model {
        #[cfg(feature = "peft")]
        {
            let lora_config = PeftLoraConfig {
                r: config.lora.r,
                alpha: config.lora.alpha,
                dropout: config.lora.dropout,
                target_modules: config.lora.target_modules.clone(),
                ..Default::default()
            };

            let model = load_model_architecture(
                config,
                &model_path,
                device,
                dtype,
                None,
                Some((&model_info, &trainable_params, &lora_config)),
            )?;
            (model, None)
        }
        #[cfg(not(feature = "peft"))]
        {
            return Err(AxolotlError::Model(
                "LoRA requested but peft feature not enabled".into(),
            ));
        }
    } else {
        let model = load_model_architecture(config, &model_path, device, dtype, None, None)?;
        let adapter_layers =
            create_adapter_layers(config, &model_info, device, &trainable_params)?;
        (model, adapter_layers)
    };

    let adapter_count = adapter_layers.as_ref().map_or(0, AdapterLayers::len);
    let trainable_count: usize = trainable_params
        .all_vars()
        .iter()
        .map(|v| v.elem_count())
        .sum();

    tracing::info!(
        "Model loaded on {:?} with dtype {:?}, {} adapter layers, {} trainable params",
        device,
        dtype,
        adapter_count,
        trainable_count
    );

    Ok(LoadedModel {
        model,
        tokenizer,
        device: device.clone(),
        dtype,
        adapter_layers,
        trainable_params,
    })
}

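/// Build standalone adapter layers for every target module in every
/// transformer layer. Returns `None` when no adapter is configured or the
/// required feature is not enabled.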
#[allow(unused_variables)]
fn create_adapter_layers(
    config: &AxolotlConfig,
    model_info: &ModelInfo,
    device: &Device,
    trainable_params: &VarMap,
) -> Result<Option<AdapterLayers>> {
    match config.adapter {
        AdapterType::None => Ok(None),
        AdapterType::Lora => {
            #[cfg(feature = "peft")]
            {
                let mut layers = AdapterLayers::new(false);

                let lora_config = PeftLoraConfig {
                    r: config.lora.r,
                    alpha: config.lora.alpha,
                    dropout: config.lora.dropout,
                    target_modules: config.lora.target_modules.clone(),
                    ..Default::default()
                };

                let vb = VarBuilder::from_varmap(trainable_params, DType::F32, device);

                for target in &config.lora.target_modules {
                    let (in_features, out_features) = model_info.get_target_dims(target);

                    for layer_idx in 0..model_info.num_layers {
                        let layer_name = format!("model.layers.{}.self_attn.{}", layer_idx, target);

                        let layer_vb = vb.pp(&layer_name);
                        let lora_layer = LoraLayer::new(
                            in_features,
                            out_features,
                            lora_config.clone(),
                            layer_vb,
                        )
                        .map_err(|e| {
                            AxolotlError::Model(format!(
                                "Failed to create LoRA layer {}: {}",
                                layer_name, e
                            ))
                        })?;

                        layers.lora_layers.insert(layer_name, lora_layer);
                    }
                }

                tracing::info!(
                    "Created {} LoRA layers with r={}, alpha={}",
                    layers.len(),
                    config.lora.r,
                    config.lora.alpha
                );

                Ok(Some(layers))
            }
            #[cfg(not(feature = "peft"))]
            {
                tracing::warn!("LoRA requested but peft feature not enabled");
                Ok(None)
            }
        }
        AdapterType::Qlora => {
            #[cfg(feature = "qlora")]
            {
                let quant_settings = config.quantization.as_ref().ok_or_else(|| {
                    AxolotlError::Config("QLoRA requires quantization settings".into())
                })?;

                let mut layers = AdapterLayers::new(true);

                let qlora_config = QLoraConfig {
                    lora: peft_rs::LoraConfig {
                        r: config.lora.r,
                        alpha: config.lora.alpha,
                        dropout: config.lora.dropout,
                        target_modules: config.lora.target_modules.clone(),
                        ..Default::default()
                    },
                    quantization: qlora_rs::QuantizationConfig {
                        block_size: quant_settings.block_size,
                        double_quant: quant_settings.double_quant,
                        ..Default::default()
                    },
                    target_modules: config.lora.target_modules.clone(),
                    cache_dequantized: false,
                };

                let vb = VarBuilder::from_varmap(trainable_params, DType::F32, device);

                for target in &config.lora.target_modules {
                    let (in_features, out_features) = model_info.get_target_dims(target);

                    for layer_idx in 0..model_info.num_layers {
                        let layer_name = format!("model.layers.{}.self_attn.{}", layer_idx, target);

                        let weight =
                            Tensor::zeros(&[out_features, in_features], DType::F32, device)
                                .map_err(|e| {
                                    AxolotlError::Model(format!(
                                        "Failed to create weight tensor for {}: {}",
                                        layer_name, e
                                    ))
                                })?;

                        let layer_vb = vb.pp(&layer_name);
                        let qlora_layer = QuantizedLinear::from_weight_with_varbuilder(
                            &weight,
                            None,
                            &qlora_config,
                            layer_vb,
                        )
                        .map_err(|e| {
                            AxolotlError::Model(format!(
                                "Failed to create QLoRA layer {}: {}",
                                layer_name, e
                            ))
                        })?;

                        layers.qlora_layers.insert(layer_name, qlora_layer);
                    }
                }

                tracing::info!(
                    "Created {} QLoRA layers with r={}, alpha={}, {}bit quantization",
                    layers.len(),
                    config.lora.r,
                    config.lora.alpha,
                    quant_settings.bits
                );

                Ok(Some(layers))
            }
            #[cfg(not(feature = "qlora"))]
            {
                tracing::warn!("QLoRA requested but qlora feature not enabled");
                Ok(None)
            }
        }
    }
}

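/// Read model dimensions from `config.json`, falling back to LLaMA 7B
/// defaults when the file is missing.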
fn load_model_info(model_path: &PathBuf) -> Result<ModelInfo> {
    let config_path = model_path.join("config.json");

    if config_path.exists() {
        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| AxolotlError::Model(format!("Failed to read config.json: {}", e)))?;
        let llama_config: LlamaConfig = serde_json::from_str(&config_str)
            .map_err(|e| AxolotlError::Model(format!("Failed to parse config.json: {}", e)))?;
        Ok(ModelInfo::from_llama_config(&llama_config))
    } else {
        tracing::warn!("config.json not found, using default LLaMA-7B dimensions");
        Ok(ModelInfo {
            hidden_size: 4096,
            num_layers: 32,
            num_attention_heads: 32,
            num_kv_heads: 32,
            intermediate_size: 11008,
        })
    }
}

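/// Resolve a model id to a local path: use it directly if it exists on disk,
/// otherwise look for it in the Hugging Face cache.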
fn resolve_model_path(model_id: &str) -> Result<PathBuf> {
    let path = PathBuf::from(model_id);
    if path.exists() {
        return Ok(path);
    }

    let cache_dir = std::env::var("HF_HOME")
        .or_else(|_| std::env::var("HOME").map(|h| format!("{}/.cache/huggingface", h)))
        .unwrap_or_else(|_| "/tmp/huggingface".to_string());

    let hf_path = PathBuf::from(format!(
        "{}/hub/models--{}",
        cache_dir,
        model_id.replace("/", "--")
    ));

    if hf_path.exists() {
        Ok(hf_path)
    } else {
        Err(AxolotlError::Model(format!(
            "Model not found at '{}' or in HF cache at '{:?}'. Use `huggingface-cli download {}` to download.",
            model_id, hf_path, model_id
        )))
    }
}

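/// Load `tokenizer.json` from the model directory.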
fn load_tokenizer(model_path: &PathBuf) -> Result<tokenizers::Tokenizer> {
    let tokenizer_file = model_path.join("tokenizer.json");

    if !tokenizer_file.exists() {
        return Err(AxolotlError::Tokenizer(
            format!("tokenizer.json not found in {:?}", model_path).into(),
        ));
    }

    tokenizers::Tokenizer::from_file(&tokenizer_file)
        .map_err(|e| AxolotlError::Tokenizer(format!("Failed to load tokenizer: {}", e).into()))
}

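/// Pick a model architecture based on `config.json` (or the model name) and
/// load it; architectures that are not supported yet fall back to a small
/// stub model.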
fn load_model_architecture(
    config: &AxolotlConfig,
    model_path: &PathBuf,
    device: &Device,
    dtype: DType,
    _adapter_layers: Option<&AdapterLayers>,
    #[cfg(feature = "peft")] lora_params: Option<(&ModelInfo, &VarMap, &PeftLoraConfig)>,
    #[cfg(not(feature = "peft"))] lora_params: Option<(&ModelInfo, &VarMap)>,
) -> Result<Box<dyn Module>> {
    let config_path = model_path.join("config.json");
    let is_llama_arch = if config_path.exists() {
        let config_str = std::fs::read_to_string(&config_path).unwrap_or_default();
        config_str.contains("LlamaForCausalLM") || config_str.contains("\"model_type\": \"llama\"")
    } else {
        let name_lower = config.base_model.to_lowercase();
        name_lower.contains("llama")
            || name_lower.contains("smollm")
            || name_lower.contains("tinyllama")
    };

    if is_llama_arch {
        load_llama_model(config, model_path, device, dtype, lora_params)
    } else {
        tracing::warn!(
            "Architecture not supported yet: {}, using stub model",
            config.base_model
        );
        let vb = VarBuilder::zeros(dtype, device);
        let model = SimpleModel::new(vb)?;
        Ok(Box::new(model))
    }
}

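/// Load a LLaMA-family model from safetensors or a PyTorch checkpoint,
/// optionally wrapping it with per-layer LoRA injection when LoRA parameters
/// are supplied.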
fn load_llama_model(
    _axolotl_config: &AxolotlConfig,
    model_path: &PathBuf,
    device: &Device,
    dtype: DType,
    #[cfg(feature = "peft")] lora_params: Option<(&ModelInfo, &VarMap, &PeftLoraConfig)>,
    #[cfg(not(feature = "peft"))] _lora_params: Option<(&ModelInfo, &VarMap)>,
) -> Result<Box<dyn Module>> {
    let config_path = model_path.join("config.json");
    let llama_config: LlamaConfig = if config_path.exists() {
        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| AxolotlError::Model(format!("Failed to read config.json: {}", e)))?;
        let parsed: LlamaConfig = serde_json::from_str(&config_str)
            .map_err(|e| AxolotlError::Model(format!("Failed to parse config.json: {}", e)))?;
        parsed
    } else {
        tracing::warn!("config.json not found, using default LLaMA 2 7B config");
        LlamaConfig {
            vocab_size: 32000,
            hidden_size: 4096,
            intermediate_size: 11008,
            num_hidden_layers: 32,
            num_attention_heads: 32,
            num_key_value_heads: Some(32),
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            bos_token_id: Some(1),
            eos_token_id: Some(LlamaEosToks::Single(2)),
            max_position_embeddings: 4096,
            rope_scaling: None,
            tie_word_embeddings: None,
        }
    };

    let vb = if model_path.join("model.safetensors").exists() {
        let tensors = candle_core::safetensors::load(model_path.join("model.safetensors"), device)
            .map_err(|e| AxolotlError::Model(format!("Failed to load safetensors: {}", e)))?;
        VarBuilder::from_tensors(tensors, dtype, device)
    } else if model_path.join("pytorch_model.bin").exists() {
        VarBuilder::from_pth(model_path.join("pytorch_model.bin"), dtype, device)
            .map_err(|e| AxolotlError::Model(format!("Failed to load pytorch model: {}", e)))?
    } else {
        return Err(AxolotlError::Model(format!(
            "No model weights found in {}. Expected model.safetensors or pytorch_model.bin",
            model_path.display()
        )));
    };

    let config = candle_transformers::models::llama::Config {
        hidden_size: llama_config.hidden_size,
        intermediate_size: llama_config.intermediate_size,
        vocab_size: llama_config.vocab_size,
        num_hidden_layers: llama_config.num_hidden_layers,
        num_attention_heads: llama_config.num_attention_heads,
        num_key_value_heads: llama_config.num_key_value_heads(),
        use_flash_attn: false,
        rms_norm_eps: llama_config.rms_norm_eps,
        rope_theta: llama_config.rope_theta,
        bos_token_id: llama_config.bos_token_id,
        eos_token_id: llama_config.eos_token_id,
        rope_scaling: llama_config.rope_scaling,
        max_position_embeddings: llama_config.max_position_embeddings,
        tie_word_embeddings: llama_config.tie_word_embeddings.unwrap_or(false),
    };

    #[cfg(feature = "peft")]
    let model: Box<dyn Module> =
        if let Some((_model_info, trainable_params, lora_config)) = lora_params {
            tracing::info!("Loading LoraLlama with per-layer LoRA injection");

            let model = LoraLlama::new_with_lora(&config, vb, lora_config, trainable_params)
                .map_err(|e| AxolotlError::Model(format!("Failed to create LoraLlama: {}", e)))?;

            Box::new(model)
        } else {
            let model = Llama::load(vb, &config)
                .map_err(|e| AxolotlError::Model(format!("Failed to create LLaMA model: {}", e)))?;

            Box::new(LlamaWrapper::new(model, &config, device)?)
        };

    #[cfg(not(feature = "peft"))]
    let model: Box<dyn Module> = {
        let model = Llama::load(vb, &config)
            .map_err(|e| AxolotlError::Model(format!("Failed to create LLaMA model: {}", e)))?;

        Box::new(LlamaWrapper::new(model, &config, device)?)
    };

    tracing::info!(
        "Loaded LLaMA model with {} layers, {} hidden size",
        llama_config.num_hidden_layers,
        llama_config.hidden_size
    );

    Ok(model)
}

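/// Load a LLaMA-family model with quantized base weights plus LoRA adapters
/// (QLoRA) and prepare it for training.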
#[cfg(all(feature = "peft", feature = "qlora"))]
fn load_qlora_model(
    _axolotl_config: &AxolotlConfig,
    model_path: &PathBuf,
    device: &Device,
    dtype: DType,
    qlora_config: &qlora_rs::QLoraConfig,
    trainable_params: &VarMap,
) -> Result<Box<dyn Module>> {
    let config_path = model_path.join("config.json");
    let llama_config: LlamaConfig = if config_path.exists() {
        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| AxolotlError::Model(format!("Failed to read config.json: {}", e)))?;
        serde_json::from_str(&config_str)
            .map_err(|e| AxolotlError::Model(format!("Failed to parse config.json: {}", e)))?
    } else {
        return Err(AxolotlError::Model(
            "config.json required for QLoRA model loading".into(),
        ));
    };

    let vb = if model_path.join("model.safetensors").exists() {
        let tensors = candle_core::safetensors::load(model_path.join("model.safetensors"), device)
            .map_err(|e| AxolotlError::Model(format!("Failed to load safetensors: {}", e)))?;
        VarBuilder::from_tensors(tensors, dtype, device)
    } else if model_path.join("pytorch_model.bin").exists() {
        VarBuilder::from_pth(model_path.join("pytorch_model.bin"), dtype, device)
            .map_err(|e| AxolotlError::Model(format!("Failed to load pytorch model: {}", e)))?
    } else {
        return Err(AxolotlError::Model(format!(
            "No model weights found in {}. Expected model.safetensors or pytorch_model.bin",
            model_path.display()
        )));
    };

    let config = candle_transformers::models::llama::Config {
        hidden_size: llama_config.hidden_size,
        intermediate_size: llama_config.intermediate_size,
        vocab_size: llama_config.vocab_size,
        num_hidden_layers: llama_config.num_hidden_layers,
        num_attention_heads: llama_config.num_attention_heads,
        num_key_value_heads: llama_config.num_key_value_heads(),
        use_flash_attn: false,
        rms_norm_eps: llama_config.rms_norm_eps,
        rope_theta: llama_config.rope_theta,
        bos_token_id: llama_config.bos_token_id,
        eos_token_id: llama_config.eos_token_id,
        rope_scaling: llama_config.rope_scaling,
        max_position_embeddings: llama_config.max_position_embeddings,
        tie_word_embeddings: llama_config.tie_word_embeddings.unwrap_or(false),
    };

    tracing::info!(
        "Loading QLoraLlama with {} layers, {} hidden size, r={}, alpha={}",
        config.num_hidden_layers,
        config.hidden_size,
        qlora_config.lora.r,
        qlora_config.lora.alpha
    );

    let model = QLoraLlama::new_with_qlora(&config, vb, qlora_config, trainable_params)
        .map_err(|e| AxolotlError::Model(format!("Failed to create QLoraLlama: {}", e)))?;

    prepare_for_qlora_training(&model, trainable_params)
        .map_err(|e| AxolotlError::Model(format!("Failed to prepare QLoRA for training: {}", e)))?;

    let trainable_count: usize = trainable_params
        .all_vars()
        .iter()
        .map(|v| v.elem_count())
        .sum();
    let total_params = model.total_param_count();
    let trainable_pct = 100.0 * trainable_count as f64 / total_params as f64;

    tracing::info!(
        "QLoraLlama ready: {} total params, {} trainable ({:.2}%)",
        total_params,
        trainable_count,
        trainable_pct
    );

    Ok(Box::new(model))
}

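/// Minimal stand-in model (a single linear layer) used when the requested
/// architecture is not supported yet.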
struct SimpleModel {
    layer: candle_nn::Linear,
}

impl SimpleModel {
    fn new(vb: VarBuilder) -> Result<Self> {
        let layer = candle_nn::linear(10, 10, vb)?;
        Ok(Self { layer })
    }
}

impl Module for SimpleModel {
    fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
        self.layer.forward(xs)
    }
}

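/// Wraps a candle `Llama` together with its `Cache` so the pair can be used
/// through the `Module` trait.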
pub struct LlamaWrapper {
    model: Llama,
    cache: std::cell::RefCell<Cache>,
    #[allow(dead_code)]
    training_mode: bool,
}

impl LlamaWrapper {
    pub fn new(
        model: Llama,
        config: &candle_transformers::models::llama::Config,
        device: &Device,
    ) -> Result<Self> {
        let cache = Cache::new(false, DType::F32, config, device)
            .map_err(|e| AxolotlError::Model(format!("Failed to create cache: {}", e)))?;
        Ok(Self {
            model,
            cache: std::cell::RefCell::new(cache),
            training_mode: true,
        })
    }

    #[allow(dead_code)]
    pub fn set_training_mode(&mut self, training: bool) {
        self.training_mode = training;
    }
}

impl Module for LlamaWrapper {
    fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
        let mut cache = self.cache.borrow_mut();

        self.model.forward(xs, 0, &mut cache)
    }
}

impl LlamaWrapper {
    #[allow(dead_code)]
    fn forward_all_positions(&self, xs: &Tensor, cache: &mut Cache) -> candle_core::Result<Tensor> {
        let (_b_sz, seq_len) = xs.dims2()?;

        let mut all_logits = Vec::new();

        // Re-run the growing prefix for each position and collect the logits
        // produced at that step.
        for pos in 0..seq_len {
            let input_slice = xs.i((.., 0..=pos))?;
            let logits = self.model.forward(&input_slice, 0, cache)?;
            all_logits.push(logits);
        }

        // Stack the per-position logits along a new sequence dimension.
        let stacked = Tensor::stack(&all_logits, 1)?;
        Ok(stacked)
    }
}

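/// Merge trained adapter weights back into the base model. Not implemented yet.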
pub fn merge_adapter(
    _config: &AxolotlConfig,
    _adapter_path: &str,
    _output_path: &str,
) -> Result<()> {
    Err(AxolotlError::Model(
        "Adapter merging not yet implemented".into(),
    ))
}

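/// Download a model from the Hugging Face Hub. Not implemented yet.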
#[cfg(feature = "download")]
#[allow(dead_code)]
pub async fn download_model(_model_id: &str, _cache_dir: &str) -> Result<String> {
    Err(AxolotlError::Model(
        "Model download not yet implemented".into(),
    ))
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    #[test]
    fn test_model_info_target_dims() {
        let smollm2 = ModelInfo {
            hidden_size: 576,
            num_layers: 30,
            num_attention_heads: 9,
            num_kv_heads: 3,
            intermediate_size: 1536,
        };

        assert_eq!(smollm2.get_target_dims("q_proj"), (576, 576));
        assert_eq!(smollm2.get_target_dims("o_proj"), (576, 576));

        assert_eq!(smollm2.get_target_dims("k_proj"), (576, 192));
        assert_eq!(smollm2.get_target_dims("v_proj"), (576, 192));

        assert_eq!(smollm2.get_target_dims("gate_proj"), (576, 1536));
        assert_eq!(smollm2.get_target_dims("up_proj"), (576, 1536));
        assert_eq!(smollm2.get_target_dims("down_proj"), (1536, 576));
    }

    #[test]
    fn test_model_info_tinyllama() {
        let tinyllama = ModelInfo {
            hidden_size: 2048,
            num_layers: 22,
            num_attention_heads: 32,
            num_kv_heads: 4,
            intermediate_size: 5632,
        };

        assert_eq!(tinyllama.get_target_dims("q_proj"), (2048, 2048));

        assert_eq!(tinyllama.get_target_dims("k_proj"), (2048, 256));

        assert_eq!(tinyllama.get_target_dims("gate_proj"), (2048, 5632));
    }

    #[test]
    fn test_model_info_llama7b() {
        let llama7b = ModelInfo::default_7b();

        assert_eq!(llama7b.get_target_dims("q_proj"), (4096, 4096));
        assert_eq!(llama7b.get_target_dims("k_proj"), (4096, 4096));
        assert_eq!(llama7b.get_target_dims("v_proj"), (4096, 4096));
    }

    #[test]
    fn test_load_model_llama2() {
        let config = AxolotlConfig {
            base_model: "meta-llama/Llama-2-7b-hf".to_string(),
            adapter: AdapterType::Lora,
            lora: LoraSettings::default(),
            quantization: None,
            dataset: DatasetConfig::default(),
            training: TrainingConfig::default(),
            output_dir: "./test_output".to_string(),
            seed: 42,
        };

        let device = Device::Cpu;

        let result = load_model(&config, &device);
        assert!(result.is_err());
        if let Err(AxolotlError::Model(msg)) = result {
            assert!(msg.contains("Model not found"));
        } else {
            panic!("Expected Model error");
        }
    }

    #[test]
    fn test_load_model_mistral() {
        let config = AxolotlConfig {
            base_model: "mistralai/Mistral-7B-v0.1".to_string(),
            adapter: AdapterType::Lora,
            lora: LoraSettings::default(),
            quantization: None,
            dataset: DatasetConfig::default(),
            training: TrainingConfig::default(),
            output_dir: "./test_output".to_string(),
            seed: 42,
        };

        let device = Device::Cpu;

        let result = load_model(&config, &device);
        assert!(result.is_err());
        if let Err(AxolotlError::Model(msg)) = result {
            assert!(msg.contains("Model not found"));
        } else {
            panic!("Expected Model error");
        }
    }

    #[test]
    fn test_load_model_phi3() {
        let config = AxolotlConfig {
            base_model: "microsoft/Phi-3-mini-4k-instruct".to_string(),
            adapter: AdapterType::Lora,
            lora: LoraSettings::default(),
            quantization: None,
            dataset: DatasetConfig::default(),
            training: TrainingConfig::default(),
            output_dir: "./test_output".to_string(),
            seed: 42,
        };

        let device = Device::Cpu;

        let result = load_model(&config, &device);
        assert!(result.is_err());
        if let Err(AxolotlError::Model(msg)) = result {
            assert!(msg.contains("Model not found"));
        } else {
            panic!("Expected Model error");
        }
    }

    #[test]
    fn test_merge_adapter_lora() {
        let config = AxolotlConfig {
            base_model: "meta-llama/Llama-2-7b-hf".to_string(),
            adapter: AdapterType::Lora,
            lora: LoraSettings {
                r: 64,
                alpha: 16,
                dropout: 0.0,
                target_modules: vec!["q_proj".to_string(), "v_proj".to_string()],
            },
            quantization: None,
            dataset: DatasetConfig::default(),
            training: TrainingConfig::default(),
            output_dir: "./test_output".to_string(),
            seed: 42,
        };

        let temp_dir = TempDir::new().unwrap();
        let adapter_path = temp_dir.path().join("adapter");
        fs::create_dir(&adapter_path).unwrap();

        let result = merge_adapter(&config, adapter_path.to_str().unwrap(), "./output");
        assert!(result.is_err());
        let err = result.unwrap_err();
        match err {
            AxolotlError::Model(msg) => {
                assert!(msg.contains("Adapter merging not yet implemented"))
            }
            _ => panic!("Expected Model error, got {:?}", err),
        }
    }

    #[test]
    fn test_merge_adapter_qlora() {
        let config = AxolotlConfig {
            base_model: "meta-llama/Llama-2-7b-hf".to_string(),
            adapter: AdapterType::Qlora,
            lora: LoraSettings {
                r: 64,
                alpha: 16,
                dropout: 0.0,
                target_modules: vec!["q_proj".to_string(), "v_proj".to_string()],
            },
            quantization: Some(QuantizationSettings {
                bits: 4,
                quant_type: QuantType::Nf4,
                double_quant: true,
                block_size: 64,
            }),
            dataset: DatasetConfig::default(),
            training: TrainingConfig::default(),
            output_dir: "./test_output".to_string(),
            seed: 42,
        };

        let temp_dir = TempDir::new().unwrap();
        let adapter_path = temp_dir.path().join("adapter");
        fs::create_dir(&adapter_path).unwrap();

        let result = merge_adapter(&config, adapter_path.to_str().unwrap(), "./output");
        assert!(result.is_err());
        let err = result.unwrap_err();
        match err {
            AxolotlError::Model(msg) => {
                assert!(msg.contains("Adapter merging not yet implemented"))
            }
            _ => panic!("Expected Model error, got {:?}", err),
        }
    }

    #[test]
    fn test_merge_adapter() {
        let config = AxolotlConfig {
            base_model: "meta-llama/Llama-2-7b-hf".to_string(),
            adapter: AdapterType::Lora,
            lora: LoraSettings::default(),
            quantization: None,
            dataset: DatasetConfig::default(),
            training: TrainingConfig::default(),
            output_dir: "./test_output".to_string(),
            seed: 42,
        };

        let temp_dir = TempDir::new().unwrap();
        let adapter_path = temp_dir.path().join("adapter");
        fs::create_dir(&adapter_path).unwrap();

        let result = merge_adapter(&config, adapter_path.to_str().unwrap(), "./output");
        assert!(result.is_err());
        let err = result.unwrap_err();
        match err {
            AxolotlError::Model(msg) => {
                assert!(msg.contains("Adapter merging not yet implemented"))
            }
            _ => panic!("Expected Model error, got {:?}", err),
        }
    }

    #[test]
    #[cfg(feature = "download")]
    fn test_download_model_from_hub() {
        let result: Result<String> = tokio::runtime::Runtime::new()
            .unwrap()
            .block_on(async { download_model("meta-llama/Llama-2-7b-hf", "/tmp/cache").await });
        assert!(result.is_err());
        let err = result.unwrap_err();
        match err {
            AxolotlError::Model(msg) => assert!(msg.contains("Model download not yet implemented")),
            _ => panic!("Expected Model error, got {:?}", err),
        }
    }

    #[test]
    fn test_resolve_model_path_invalid() {
        let result = resolve_model_path("nonexistent-model-id");
        assert!(result.is_err());
        let err = result.unwrap_err();
        match err {
            AxolotlError::Model(msg) => assert!(msg.contains("Model not found")),
            _ => panic!("Expected Model error, got {:?}", err),
        }
    }

    #[test]
    fn test_load_tokenizer_missing_file() {
        let temp_dir = TempDir::new().unwrap();
        let result = load_tokenizer(&temp_dir.path().to_path_buf());
        assert!(result.is_err());
        let err = result.unwrap_err();
        match err {
            AxolotlError::Tokenizer(e) => {
                assert!(e.to_string().contains("tokenizer.json not found"))
            }
            _ => panic!("Expected Tokenizer error, got {:?}", err),
        }
    }

    #[test]
    fn test_load_model_architecture_stub() {
        let config = AxolotlConfig {
            base_model: "test-model".to_string(),
            adapter: AdapterType::Lora,
            lora: LoraSettings::default(),
            quantization: None,
            dataset: DatasetConfig::default(),
            training: TrainingConfig::default(),
            output_dir: "./test_output".to_string(),
            seed: 42,
        };
        let temp_dir = TempDir::new().unwrap();
        let device = Device::Cpu;
        let dtype = DType::F32;

        let result = load_model_architecture(
            &config,
            &temp_dir.path().to_path_buf(),
            &device,
            dtype,
            None,
            None,
        );
        assert!(result.is_ok());

        let model = result.unwrap();
        let input = Tensor::zeros((1, 10), dtype, &device).unwrap();
        let output = model.forward(&input).unwrap();
        assert_eq!(output.shape(), input.shape());
    }
}